In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Lambda

import cv2

import numpy as np

import matplotlib.pyplot as plt

import json
import os
import random

In [38]:
# Prefer the GPU whenever PyTorch can see one; fall back to the CPU otherwise.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
torch.device(device)

device(type='cuda')

In [3]:
class EnvMapNetDataset(Dataset):
    """Dataset of ``.npy`` images stored under ``<images_path>/<type>``.

    Each item is the first three channels of the stored (H, W, C) array,
    returned as a (3, H, W) tensor with the array's dtype.

    Args:
        images_path: dataset root directory.
        type: split sub-directory name (e.g. 'train' / 'test'); the name shadows
            the builtin but is kept because callers pass it by keyword.
    """

    def __init__(self, images_path, type):
        self.__images_path = os.path.join(images_path, type)
        # os.listdir order is arbitrary; sorting makes index -> file stable across runs.
        self.__images = sorted(os.listdir(self.__images_path))

    def __len__(self):
        return len(self.__images)

    def __getitem__(self, index):
        image_path = os.path.join(self.__images_path, self.__images[index])
        data = np.load(image_path)
        if data.ndim != 3 or data.shape[2] < 3:
            raise ValueError(
                'expected an (H, W, C>=3) array in {}, got shape {}'.format(image_path, data.shape)
            )
        # (H, W, C) -> (3, H, W); contiguous copy so torch.from_numpy gets a plain layout.
        chw = np.ascontiguousarray(np.moveaxis(data[:, :, :3], -1, 0))
        return torch.from_numpy(chw)

In [None]:
def L1Loss(orig, pred, mask):
    """Masked L1 distance: sum over all elements of |orig * mask - pred * mask|."""
    masked_orig = orig * mask
    masked_pred = pred * mask
    return torch.sum(torch.abs(masked_orig - masked_pred))

In [None]:
def L2Loss(orig, pred):
    """Euclidean norm of ``orig - pred`` taken over all elements (a 0-dim tensor)."""
    residual = orig - pred
    return torch.sum(torch.norm(residual))

In [None]:
def FakeRealLoss(orig, pred, eps=1e-9):
    """Standard GAN discriminator loss, expressed as a quantity to *minimise*.

    The discriminator maximises log D(real) + log(1 - D(fake)); this returns the
    negation so that it can be fed to a minimising optimiser, consistent with
    ``AdversarialLoss`` and ``ClusterLoss``, which are also negated log terms.

    Args:
        orig: discriminator outputs on real samples, presumably probabilities
            in (0, 1) — TODO confirm against the caller.
        pred: discriminator outputs on generated samples, same assumption.
        eps: small constant that keeps the logs finite at 0.

    Returns:
        0-dim tensor ``-sum(log(orig) + log(1 - pred))``.
    """
    return -((orig + eps).log() + (1 - pred + eps).log()).sum()

In [None]:
def AdversarialLoss(pred, eps=1e-9):
    """Generator adversarial loss: ``-sum(log(pred))``.

    Args:
        pred: discriminator outputs on generated samples, presumably
            probabilities in (0, 1] — TODO confirm.
        eps: small constant so that a zero output yields a large finite loss
            instead of ``inf`` (same guard as ``ClusterLoss``).
    """
    return -(pred + eps).log().sum()

In [None]:
class ORBPatcher:
    """Turns an image into a flat binary map of its ORB keypoint locations.

    Args:
        key_points: maximum number of keypoints ORB keeps (``nfeatures`` of
            ``cv2.ORB_create``).
    """

    def __init__(self, *, key_points=1000):
        self.__orb = cv2.ORB_create(key_points)

    def __call__(self, image):
        """Return a 1-D int array of length H * W (row-major).

        Entries are 1 at every pixel holding a detected keypoint and 0 elsewhere.

        Args:
            image: uint8 image accepted by OpenCV's ORB detector.
        """
        key_points, _ = self.__orb.detectAndCompute(image, None)

        # Builtin int: the removed `np.int` alias was exactly this type.
        res = np.zeros(image.shape[:2], dtype=int)
        for point in key_points:
            # Keypoint positions are sub-pixel (x, y) floats; truncate to pixel indices.
            x, y = map(int, point.pt)
            res[y, x] = 1
        return res.flatten()

In [None]:
class KMeansFiles:
    """K-means over ORB keypoint maps of ``.npy`` image files.

    The files are split into ``partitions`` interleaved groups. Each group is
    converted by ``orb`` into a (files, d) matrix (one flattened keypoint map per
    file) and cached on disk as ``<num>.pt``. Fitting then runs ``n_iter`` passes;
    the centers are updated after every partition, so an update only sees one
    partition at a time (mini-batch style rather than a full-data Lloyd step).

    Args:
        orb: callable mapping a uint8 image to a flat keypoint map
            (``ORBPatcher`` in this notebook).
        clusters: number of centers.
        n_iter: number of passes over all partitions.
        partitions: number of on-disk partitions the files are split into.
    """

    def __init__(self, orb, *, clusters=5, n_iter=300, partitions=1):
        self.__orb = orb
        self.__clusters = clusters
        self.__n_iter = n_iter
        self.__partitions = partitions

        # (clusters, d) float tensor once fitted or loaded; None before that.
        self.__centers = None

        self.__centers_save_path = 'k_means.bin'

    def fit(self, files):
        """Fit the centers on ``files`` (paths to ``.npy`` images).

        Side effects: writes ``<num>.pt`` per partition to the working directory,
        checkpoints the centers to ``k_means.bin`` every 10 iterations, and
        finally moves the centers to the global ``device``.
        """
        print('preparing partitions\n')
        for num in range(self.__partitions):
            tmp = self.__prepare_partition(files[num::self.__partitions])
            torch.save(tmp, self.__partition_path(num))
            print('partition {} prepared'.format(num))
        print('partitions prepared\n')

        print('fitting k-means on {} samples\n'.format(len(files)))
        for iteration in range(self.__n_iter):
            for num in range(self.__partitions):
                tmp = torch.load(self.__partition_path(num))
                print('fitting on partition {}'.format(num))
                self.__fit_partition(tmp)
                print('partition fitted!\n')
            print('iteration {} fitted!\n'.format(iteration))
            # Checkpoint on every 10th iteration.
            if (iteration + 1) % 10 == 0:
                self.save()
        print('k-means ready!')

        if self.__centers is not None:
            self.__centers = self.__centers.to(device)

    @staticmethod
    def __partition_path(num):
        """File name of the cached tensor for partition ``num``."""
        return '{}.pt'.format(num)

    def __prepare_partition(self, files):
        """Return a (len(files), d) tensor with one keypoint map per file."""
        # Stored images are mapped (x + 1) / 2 * 255 to uint8, i.e. presumably
        # they live in [-1, 1] — TODO confirm against the dataset preparation.
        patches = [
            self.__orb(((np.load(file) + 1) / 2 * 255).astype(np.uint8))
            for file in files
        ]
        if not patches:
            return torch.empty((0, 0))
        return torch.from_numpy(np.stack(patches))

    def __assign(self, x):
        """Index of the nearest center (squared Euclidean distance) for each row of ``x``."""
        # (n, 1, d) - (1, k, d) -> (n, k) distances; memory grows as n * k * d.
        distances = ((x[:, None, :] - self.__centers[None, :, :]) ** 2).sum(-1)
        return distances.argmin(dim=1).long().view(-1)

    def __fit_partition(self, orig):
        """Run one k-means update of the centers using the rows of ``orig``."""
        x = orig * 1.0
        # An empty partition (more partitions than files) carries no samples.
        if x.dim() != 2 or x.shape[0] == 0:
            return

        if self.__centers is None:
            if x.shape[0] < self.__clusters:
                raise ValueError(
                    'need at least {} samples to seed the centers, got {}'.format(
                        self.__clusters, x.shape[0]
                    )
                )
            # Seed the centers with the first `clusters` samples.
            self.__centers = x[:self.__clusters, :].clone()
        else:
            x = x.to(device=self.__centers.device, dtype=self.__centers.dtype)

        clusters = self.__assign(x)

        sums = torch.zeros_like(self.__centers)
        sums.index_add_(0, clusters, x)
        counts = torch.bincount(clusters, minlength=self.__clusters)

        # Clusters that received no samples keep their previous center instead
        # of becoming 0 / 0 = NaN.
        filled = counts > 0
        self.__centers[filled] = sums[filled] / counts[filled].to(sums.dtype)[:, None]

    def predict(self, x):
        """One-hot (n, clusters) float32 assignment of each row of ``x``.

        The result lives on the same device as the centers and ``x``. Prints a
        message and returns None if the model has not been fitted or loaded.
        """
        if self.__centers is None:
            print('k-means wasn\'t fitted')
            return

        clusters = self.__assign(x)
        one_hot = torch.nn.functional.one_hot(clusters, num_classes=self.__clusters)
        return one_hot.to(torch.float32)

    def save(self):
        torch.save(self.__centers, self.__centers_save_path)

    def load(self):
        self.__centers = torch.load(self.__centers_save_path).to(device)

In [None]:
class ClusterLoss:
    """Cross-entropy between k-means cluster labels of images and predicted probabilities.

    Each image is reduced to its ORB keypoint map, assigned to a k-means cluster
    (one-hot ``base``), and the loss is ``-sum(base * log(pred + 1e-9))``.

    Args:
        orb: callable mapping a uint8 (H, W, 3) image to a flat keypoint map.
            Defaults to a fresh ``ORBPatcher()`` when omitted.
        train_k_means: fit k-means now (and save it) instead of loading ``k_means.bin``.
        partitions: forwarded to ``KMeansFiles``.
        clusters: forwarded to ``KMeansFiles``.
    """

    def __init__(self, orb=None, *, train_k_means=False, partitions=5, clusters=5):
        if orb is None:
            orb = ORBPatcher()
        self.__orb = orb

        self.__k_means = KMeansFiles(orb, clusters=clusters, partitions=partitions)
        if train_k_means:
            self.__k_means.fit(self.__prepare_data())
            self.__k_means.save()
        else:
            self.__k_means.load()

    def __call__(self, images, pred):
        """Return the cluster cross-entropy for a batch.

        Args:
            images: (B, C>=3, H, W) tensor; presumably values in [-1, 1] given the
                mapping to [0, 255] below — TODO confirm.
            pred: tensor broadcastable against the (B, clusters) one-hot labels,
                presumably cluster probabilities — NOTE(review): callers in this
                notebook pass images / discriminator maps here; confirm the shape.
        """
        # (B, C, H, W) -> (B, H, W, 3): OpenCV expects channels-last images.
        images = images.detach().cpu().numpy()[:, :3].transpose(0, 2, 3, 1)

        key_point_maps = np.stack([
            # [-1, 1] -> [0, 255]; OpenCV wants a contiguous uint8 buffer.
            self.__orb(np.ascontiguousarray(((image + 1) * 255 / 2).astype(np.uint8)))
            for image in images
        ])

        base = self.__k_means.predict(torch.from_numpy(key_point_maps).to(pred.device))

        return -(base * torch.log(pred + 1e-9)).sum()

    def __prepare_data(self):
        """Paths of every file used to fit k-means."""
        # Each (dataset, split) directory is listed once, so no file is fed to
        # k-means twice. TODO confirm whether the Laval test / Pano train splits
        # were also meant to be included.
        catalogs = [
            os.path.join('LavalIndoorHDRDatasetReady', 'train'),
            os.path.join('PanoIndoorLDRDatasetReady', 'test'),
        ]
        return [
            os.path.join(catalog, file)
            for catalog in catalogs
            for file in sorted(os.listdir(catalog))
        ]

In [None]:
class ProjectionLoss:
    """Sum over region masks of |sum(orig * mask) - sum(pred * mask)|.

    Masks are (H, W, 3) integer arrays, 0 inside a few random rotated rectangles
    and 1 elsewhere. They are stored as ``<masks_path>/<k>.npy`` for
    k = 1..masks_count.

    Args:
        generate_masks: build fresh random masks and save them instead of loading.
        masks_count: number of masks to build or load.
        base_shape: (H, W) of generated masks.
        masks_path: directory holding the ``.npy`` masks (created when generating).
    """

    def __init__(self, *, generate_masks=False, masks_count=50, base_shape=(1024, 2048),
                 masks_path='masks'):
        self.__masks_path = masks_path

        if generate_masks:
            self.__masks = [
                self.__build_mask(base_shape)
                for _ in range(masks_count)
            ]
            os.makedirs(self.__masks_path, exist_ok=True)
            for num, mask in enumerate(self.__masks):
                np.save(self.__mask_file(num), mask)
        else:
            self.__masks = [np.load(self.__mask_file(num)) for num in range(masks_count)]

    def __mask_file(self, num):
        """Path of the ``num``-th (0-based) mask; files are numbered from 1."""
        # np.save appends '.npy' itself, so spelling it out keeps save and load in sync.
        return os.path.join(self.__masks_path, '{}.npy'.format(num + 1))

    def __call__(self, orig, pred):
        """Return a shape-(1,) tensor on ``orig``'s device.

        Args:
            orig, pred: (B, 3, H, W) tensors whose H x W match the masks.
        """
        res = torch.zeros(1, device=orig.device, dtype=orig.dtype)

        for mask_np in self.__masks:
            mask = self.__mask_tensor(mask_np, orig)
            res = res + torch.abs((orig * mask).sum() - (pred * mask).sum())

        return res

    @staticmethod
    def __mask_tensor(mask_np, like):
        """Convert a stored (H, W, 3) mask into a (3, H, W) tensor matching ``like``."""
        mask = torch.from_numpy(np.ascontiguousarray(mask_np))
        if mask.dim() == 3:
            mask = mask.permute(2, 0, 1)  # HWC -> CHW, to broadcast against NCHW batches
        return mask.to(device=like.device, dtype=like.dtype)

    def __build_mask(self, base_shape):
        """Build one random (H, W, 3) int mask: 0 inside rotated rectangles, 1 elsewhere."""
        base_shape = (*base_shape, 3)
        image = np.zeros(base_shape)

        # Number of rectangles: a Gaussian around 4, truncated to int (may be 0).
        for _ in range(int(random.gauss(4, 0.7))):
            # Rectangle sides are 10-40 % of the image's sides.
            scale = random.randint(10, 40)

            h, w, _ = base_shape
            h = h * scale // 100
            w = w * scale // 100

            tmp = np.ones((h, w, 3)) * 255

            # Rotate the filled rectangle inside its own h x w canvas.
            angle = random.randint(0, 180)

            M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1)
            tmp = cv2.warpAffine(tmp, M, (w, h))

            # Random centre in the middle third of the image along each axis.
            h_c = random.randint(base_shape[0] // 3, 2 * base_shape[0] // 3)
            w_c = random.randint(base_shape[1] // 3, 2 * base_shape[1] // 3)

            # Pads around the patch add up to exactly the full image size.
            top_pad = base_shape[0] - h_c - h // 2
            bottom_pad = h_c - h // 2 - h % 2

            right_pad = base_shape[1] - w_c - w // 2
            left_pad = w_c - w // 2 - w % 2

            tmp = np.pad(tmp, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)))

            image = cv2.add(image, tmp)

        return (image == 0).astype(int)

In [None]:
class EnvMapNetLoss:
    """Total generator objective: weighted sum of the EnvMapNet loss terms.

    Terms: 0.5 * masked L1, 0.01 * L2, projection loss, adversarial loss on the
    discriminator's verdict for ``pred``, and the cluster loss.
    """

    def __init__(self):
        # ClusterLoss needs a keypoint extractor as its first argument.
        self.ClusterLoss = ClusterLoss(ORBPatcher())
        self.ProjectionLoss = ProjectionLoss()

    def __call__(self, orig, mask, pred, discriminator):
        # NOTE(review): ClusterLoss receives generated images as its `pred`
        # argument, which it feeds to log() against (B, clusters) one-hot labels;
        # confirm the intended input (probably a classifier output).
        return (
            0.5 * L1Loss(orig, pred, mask)
            + 0.01 * L2Loss(orig, pred)
            + self.ProjectionLoss(orig, pred)
            + AdversarialLoss(discriminator(pred))
            + self.ClusterLoss(orig, pred)
        )

In [None]:
class DiscriminatorLoss:
    """Discriminator objective: real/fake GAN loss plus the cluster loss."""

    def __init__(self):
        # ClusterLoss needs a keypoint extractor as its first argument.
        self.ClusterLoss = ClusterLoss(ORBPatcher())

    def __call__(self, orig, pred):
        # NOTE(review): train_loop passes real images as `orig` and the
        # discriminator's output on them as `pred`, while FakeRealLoss takes logs
        # of both as if they were probabilities — confirm the intended inputs.
        return FakeRealLoss(orig, pred) + self.ClusterLoss(orig, pred)

In [4]:
class EnvMapNetConvBlock(nn.Module):
    """Densely connected convolution block.

    ``sets_count`` sets of BatchNorm -> LeakyReLU(0.2) -> 3x3 conv are applied in
    order. Each set emits ``conv_channels`` new maps, which are concatenated (new
    maps first) onto that set's input. The output therefore has
    ``channels + sets_count * conv_channels`` channels and the input's H x W.

    Args:
        channels: number of input channels.
        sets_count: number of BN / activation / conv sets.
        conv_channels: feature maps produced by each set.
        padding_mode: ``nn.Conv2d`` padding mode of the 3x3 convs
            ('zeros', 'reflect', 'replicate' or 'circular').
    """

    def __init__(self, channels, *, sets_count=5, conv_channels=16, padding_mode='zeros'):
        super(EnvMapNetConvBlock, self).__init__()
        # Flat Sequential of [BN, act, conv] * sets_count; forward() walks it in triples.
        self.__blocks = nn.Sequential(*[
            layer
            for i in range(sets_count)
            for layer in self.__get_set(channels, conv_channels, i, padding_mode)
        ])

    def forward(self, x):
        layers = list(self.__blocks)
        for start in range(0, len(layers), 3):
            norm, act, conv = layers[start:start + 3]
            # Dense connection: new features are prepended to everything seen so far.
            x = torch.cat((conv(act(norm(x))), x), 1)
        return x

    def __get_set(self, in_channels, conv_channels, i, padding_mode='zeros'):
        """Layers of set ``i``, whose input already carries ``i`` sets of new maps."""
        channels = in_channels + conv_channels * i
        return [
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(
                channels,
                conv_channels,
                kernel_size=(3, 3),
                padding=(1, 1),
                padding_mode=padding_mode,
            ),
        ]

In [5]:
class EnvMapNetDownsampleBlock(nn.Module):
    """3x3 conv (``in_channels`` -> ``out_channels``) followed by 2x2 average pooling.

    Halves H and W (floor). ``padding_mode`` is passed to ``nn.Conv2d``
    ('zeros', 'reflect', 'replicate' or 'circular').
    """

    def __init__(self, in_channels, out_channels, *, padding_mode='zeros'):
        super(EnvMapNetDownsampleBlock, self).__init__()
        self.__conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            padding_mode=padding_mode,
        )
        self.__downsample = nn.AvgPool2d((2, 2))

    def forward(self, x):
        return self.__downsample(self.__conv(x))

In [6]:
class EnvMapNetUpsampleBlock(nn.Module):
    """2x nearest-neighbour upsampling followed by a 3x3 conv.

    Doubles H and W and maps ``in_channels`` to ``out_channels``. ``padding_mode``
    is passed to ``nn.Conv2d`` ('zeros', 'reflect', 'replicate' or 'circular').
    """

    def __init__(self, in_channels, out_channels, *, padding_mode='zeros'):
        super(EnvMapNetUpsampleBlock, self).__init__()
        self.__conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            padding_mode=padding_mode,
        )
        self.__upsample = nn.Upsample(scale_factor=2)

    def forward(self, x):
        return self.__conv(self.__upsample(x))

In [7]:
class Generator(nn.Module):
    """Encoder / bottleneck / decoder image-to-image network.

    Encoder: for each width in ``dk_orig``, a dense ``EnvMapNetConvBlock`` followed
    by an ``EnvMapNetDownsampleBlock`` (conv + 2x2 average pool).
    Bottleneck: a 1x1 conv to ``neck_channels``.
    Decoder: for each width in ``uk_orig``, an ``EnvMapNetUpsampleBlock`` (2x
    upsample + conv) followed by a dense conv block.
    Output: a 3x3 conv to 3 channels squashed into [-1, 1] by tanh.

    With ``len(dk_orig) == len(uk_orig)`` the output keeps the input's size
    provided H and W are divisible by ``2 ** len(dk_orig)``.
    """

    def __init__(self,
                 dk_orig, uk_orig, *, sets_count=5,
                 conv_channels=16, neck_channels=64):
        super(Generator, self).__init__()
        # Extra channels every dense conv block appends to its input.
        growth = sets_count * conv_channels

        dk = [3] + dk_orig
        down_layers = []
        for i in range(1, len(dk)):
            down_layers.append(EnvMapNetConvBlock(
                dk[i - 1],
                sets_count=sets_count,
                conv_channels=conv_channels
            ))
            down_layers.append(EnvMapNetDownsampleBlock(dk[i - 1] + growth, dk[i]))
        self.__downsampling_blocks = nn.Sequential(*down_layers)

        self.__neck = nn.Conv2d(
            dk[-1],
            neck_channels,
            kernel_size=(1, 1),
        )

        uk = [neck_channels] + uk_orig
        up_layers = []
        for i in range(1, len(uk)):
            # The first upsample reads the neck directly; later ones read the
            # previous dense block, whose output carries `growth` extra channels.
            up_in = uk[i - 1] if i == 1 else uk[i - 1] + growth
            up_layers.append(EnvMapNetUpsampleBlock(up_in, uk[i]))
            up_layers.append(EnvMapNetConvBlock(
                uk[i],
                sets_count=sets_count,
                conv_channels=conv_channels
            ))
        self.__upsampling_blocks = nn.Sequential(*up_layers)

        self.__out = nn.Conv2d(
            uk[-1] + growth,
            3,
            kernel_size=(3, 3),
            padding=(1, 1),
            padding_mode='zeros',
        )

    def forward(self, x):
        x = self.__downsampling_blocks(x)
        x = self.__neck(x)
        x = self.__upsampling_blocks(x)
        return torch.tanh(self.__out(x))

In [8]:
class DiscriminatorResidualBlock(nn.Module):
    """Residual block used by the Discriminator.

    Main path: ``sets_count`` sets of BatchNorm -> LeakyReLU(0.2) -> 3x3 conv,
    mapping ``in_channels`` to ``out_channels`` at the input's H x W.
    Shortcut: 2x2 average pooling then a 3x3 conv, zero-padded (centred) back to
    the main path's H x W before the two are summed.

    NOTE(review): the main path never downsamples, so the block keeps the
    spatial size — confirm that is intended.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        sets_count: number of BN / activation / conv sets on the main path.
        padding_mode: ``nn.Conv2d`` padding mode ('zeros', 'reflect',
            'replicate' or 'circular').
    """

    def __init__(self, in_channels, out_channels, *, sets_count=2, padding_mode='zeros'):
        super(DiscriminatorResidualBlock, self).__init__()
        self.__avg = nn.AvgPool2d((2, 2))
        self.__conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            padding_mode=padding_mode,
        )

        self.__blocks = nn.Sequential(*[
            layer
            for i in range(sets_count)
            for layer in self.__get_set(
                in_channels if i == 0 else out_channels, out_channels, padding_mode
            )
        ])

    def forward(self, inp):
        x = self.__blocks(inp)

        sc = self.__conv(self.__avg(inp))
        # Centre the half-resolution shortcut inside a zero border so it matches
        # the main path; splitting the remainder also handles odd sizes.
        dh = x.shape[-2] - sc.shape[-2]
        dw = x.shape[-1] - sc.shape[-1]
        sc = nn.functional.pad(sc, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))

        return x + sc

    def __get_set(self, in_channels, out_channels, padding_mode='zeros'):
        """One BN -> LeakyReLU -> 3x3 conv set."""
        return [
            nn.BatchNorm2d(in_channels),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(3, 3),
                padding=(1, 1),
                padding_mode=padding_mode,
            ),
        ]

In [9]:
class Discriminator(nn.Module):
    """Stack of ``DiscriminatorResidualBlock``s with channel widths 3 -> ak_orig[0] -> ...

    Returns the last block's raw feature map; no sigmoid or pooling is applied.
    """

    def __init__(self, ak_orig, *, sets_count=2):
        super(Discriminator, self).__init__()

        widths = [3, *ak_orig]
        self.__blocks = nn.Sequential(*(
            DiscriminatorResidualBlock(c_in, c_out, sets_count=sets_count)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ))

    def forward(self, x):
        return self.__blocks(x)

In [None]:
# Channel widths of, respectively: the generator's downsampling stages, the
# generator's upsampling stages, and the discriminator's residual blocks.
dk = [64, *[128] * 3, *[256] * 2, 512]
uk = [512, *[256] * 2, *[128] * 3, 64]
ak = [64, 128, *[256] * 5]

In [None]:
# Generator batch size (the discriminator loaders use half of it) and epoch count.
minibatch_size, epochs = 16, 500

In [None]:
# Wrap each network for multi-GPU data parallelism, then move it to `device`.
generator = nn.DataParallel(Generator(dk, uk))
generator = generator.to(device)
discriminator = nn.DataParallel(Discriminator(ak))
discriminator = discriminator.to(device)

In [None]:
def _env_map_loader(dataset_dir, split, batch_size):
    """Shuffled DataLoader over the .npy images in ``<dataset_dir>/<split>``."""
    return DataLoader(
        EnvMapNetDataset(dataset_dir, type=split), batch_size=batch_size, shuffle=True
    )


# Generator loaders read the Laval HDR set; discriminator loaders read the Pano
# LDR set with half the batch size.
generator_train_dataloader = _env_map_loader('LavalIndoorHDRDatasetReady', 'train', minibatch_size)
generator_test_dataloader = _env_map_loader('LavalIndoorHDRDatasetReady', 'test', minibatch_size)
discriminator_train_dataloader = _env_map_loader('PanoIndoorLDRDatasetReady', 'train', minibatch_size // 2)
discriminator_test_dataloader = _env_map_loader('PanoIndoorLDRDatasetReady', 'test', minibatch_size // 2)

In [None]:
# Loss callables for the two players of the GAN (built in this order).
generator_loss, discriminator_loss = EnvMapNetLoss(), DiscriminatorLoss()

In [None]:
# Both players use Adam with the same learning rate of 2e-4.
generator_optimiser = torch.optim.Adam(params=generator.parameters(), lr=2e-4)
discriminator_optimiser = torch.optim.Adam(params=discriminator.parameters(), lr=2e-4)

In [None]:
# FIXME(review): placeholder. `torch.from_numpy` needs an array argument, and
# the intended source of this mask is not visible here. `window_mask` is not
# referenced elsewhere in the visible notebook; the training loops instead call
# `build_random_masks()`, which is also not defined here — TODO supply it.
window_mask = None

In [41]:
def train_loop(generator,
               generator_train_dataloader,
               generator_loss,
               generator_optimiser,
               discriminator,
               discriminator_train_dataloader,
               discriminator_loss,
               discriminator_optimiser
              ):
    """Run one training epoch, alternating a generator step and a discriminator step.

    The two loaders are consumed in lock-step via ``zip``, so the epoch ends when
    the shorter one is exhausted. Both losses are printed every 100 batches.
    """
    generator.train()
    discriminator.train()

    g_size = len(generator_train_dataloader.dataset)
    d_size = len(discriminator_train_dataloader.dataset)

    for batch, (g_data, d_data) in enumerate(zip(
        generator_train_dataloader,
        discriminator_train_dataloader
    )):
        # NOTE(review): `build_random_masks` is not defined in the visible
        # notebook — TODO confirm where it is provided.
        mask = build_random_masks()

        # Generator step. The adversarial term also back-propagates into the
        # discriminator's parameters; those gradients are cleared by the
        # discriminator_optimiser.zero_grad() call below before its own step.
        g_data = g_data.to(device)
        g_pred = generator(g_data * mask)
        g_loss = generator_loss(g_data, mask, g_pred, discriminator)
        generator_optimiser.zero_grad()
        g_loss.backward()
        generator_optimiser.step()

        if batch % 100 == 0:
            loss, current = g_loss.item(), batch * len(g_data)
            print(f"loss: {loss:>7f}  [{current:>5d}/{g_size:>5d}]")

        # Discriminator step.
        d_data = d_data.to(device)
        d_pred = discriminator(d_data)
        d_loss = discriminator_loss(d_data, d_pred)
        discriminator_optimiser.zero_grad()
        d_loss.backward()
        discriminator_optimiser.step()

        if batch % 100 == 0:
            loss, current = d_loss.item(), batch * len(d_data)
            print(f"loss: {loss:>7f}  [{current:>5d}/{d_size:>5d}]")
            print()


def test_loop(generator,
              generator_test_dataloader,
              generator_loss,
              discriminator,
              discriminator_test_dataloader,
              discriminator_loss
             ):
    """Evaluate both networks on the given test loaders without gradients.

    Prints the summed losses divided by each test dataset's size. The loaders
    are zipped, so evaluation stops at the shorter one.
    """
    generator.eval()
    discriminator.eval()

    g_size = len(generator_test_dataloader.dataset)
    d_size = len(discriminator_test_dataloader.dataset)
    g_loss, d_loss = 0, 0

    with torch.no_grad():
        # Evaluate on the *test* loaders passed in, not the global training loaders.
        for g_data, d_data in zip(
            generator_test_dataloader,
            discriminator_test_dataloader
        ):
            # NOTE(review): `build_random_masks` is not defined in the visible
            # notebook — TODO confirm where it is provided.
            mask = build_random_masks()

            g_data = g_data.to(device)
            g_pred = generator(g_data * mask)
            g_loss += generator_loss(g_data, mask, g_pred, discriminator).item()

            d_data = d_data.to(device)
            d_pred = discriminator(d_data)
            d_loss += discriminator_loss(d_data, d_pred).item()

    g_loss /= g_size
    d_loss /= d_size

    print("Test Error:")
    print(f" Avg generator loss: {g_loss:>8f}")
    print(f" Avg discriminator loss: {d_loss:>8f}")
    print()

In [39]:
for t in range(epochs):
    epoch = t + 1
    print(f"Epoch {epoch} started")
    train_loop(
        generator, generator_train_dataloader, generator_loss, generator_optimiser,
        discriminator, discriminator_train_dataloader, discriminator_loss, discriminator_optimiser,
    )
    print(f"Epoch {epoch} trained")
    test_loop(
        generator, generator_test_dataloader, generator_loss,
        discriminator, discriminator_test_dataloader, discriminator_loss,
    )
    print(f"Epoch {epoch} tested\n")

    # Checkpoint both networks every 10 epochs, overwriting the previous files.
    if epoch % 10 == 0:
        torch.save(generator.state_dict(), 'generator')
        torch.save(discriminator.state_dict(), 'discriminator')
        print(f"Model by epoch {epoch} saved\n")