In [1]:
import os
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl

from PIL import Image

from glob import glob

# Data augmentation and Dataset class

In [2]:
# Random Crop Augment
class RandomCrop():
    """Randomly crop the same square window out of a pair of PIL images.

    The side length is drawn uniformly from ``[min_crop, max_crop]`` (both ends
    inclusive) and then clamped to the smaller image dimension, so an image
    smaller than the drawn size yields the largest square that fits instead of
    an error. The same box is applied to both images so the pair stays aligned.
    """

    def __init__(self, min_crop = 256, max_crop = 1024):
        self.min_crop = min_crop
        self.max_crop = max_crop

    @staticmethod
    def random_crop(I1, crop_size):
        """Return a random ``(left, upper, right, lower)`` box for ``Image.crop``.

        Args:
            I1: image with a PIL-style ``size`` attribute, i.e. ``(width, height)``.
                Both images of the pair are assumed to share this size.
            crop_size: ``(crop_width, crop_height)``; must not exceed the image size.

        Returns:
            4-tuple ``(left, upper, right, lower)`` in pixel coordinates.
        """
        # PIL reports size as (width, height).
        w, h = I1.size
        cw, ch = crop_size

        if w == cw and h == ch:
            return 0, 0, w, h

        left = torch.randint(0, w - cw + 1, size=(1, )).item()
        top = torch.randint(0, h - ch + 1, size=(1, )).item()
        return left, top, left + cw, top + ch

    def __call__(self, img_set):
        I1, I2 = img_set
        # torch.randint excludes its upper bound; +1 makes max_crop reachable.
        cs = torch.randint(self.min_crop, self.max_crop + 1, size=(1, )).item()
        # Never request a window larger than the image (random_crop would raise).
        cs = min(cs, *I1.size)
        bbox = self.random_crop(I1, (cs, cs))

        return I1.crop(bbox), I2.crop(bbox)
    
class RandomRotate():
    """Rotate both images of a pair by the same random integer angle in degrees.

    The angle is drawn uniformly from ``[min_angle, max_angle]``, both ends
    inclusive, so the default range is symmetric around zero.

    NOTE(review): if the second image is a tangent-space normal map (TextureDataset
    calls it ``I_normal``), rotating its pixels does not rotate the encoded normal
    vectors. Confirm this is acceptable for the training target.
    """
    def __init__(self, min_angle = -25, max_angle = 25):
        self.min_angle = min_angle
        self.max_angle = max_angle

    def __call__(self, img_set):
        I1, I2 = img_set
        # torch.randint excludes its upper bound; +1 makes max_angle reachable
        # (the default range becomes [-25, 25] instead of [-25, 24]).
        angle = torch.randint(self.min_angle, self.max_angle + 1, size=(1, )).item()
        return TF.rotate(I1, angle, Image.BILINEAR), TF.rotate(I2, angle, Image.BILINEAR)

class RandomFlip():
    """Independently flip a pair horizontally and vertically, each with probability 0.5.

    Both images receive the same flips so they stay aligned.

    NOTE(review): if the second image is a tangent-space normal map, mirroring
    its pixels without negating the matching vector component may produce
    inconsistent normals. Confirm this is intended.
    """
    def __init__(self):
        pass

    def __call__(self, img_set):
        I1, I2 = img_set
        # Draw from torch's RNG, like the other transforms in this pipeline.
        # DataLoader re-seeds torch in every worker, whereas older PyTorch
        # releases left numpy's global state copied from the parent process.
        # With num_workers > 0, workers could then repeat the same flips.
        horizontal_flip, vertical_flip = (torch.rand(2) < 0.5).tolist()
        if horizontal_flip:
            I1 = I1.transpose(Image.FLIP_LEFT_RIGHT)
            I2 = I2.transpose(Image.FLIP_LEFT_RIGHT)
        if vertical_flip:
            I1 = I1.transpose(Image.FLIP_TOP_BOTTOM)
            I2 = I2.transpose(Image.FLIP_TOP_BOTTOM)

        return I1, I2

class ColorJitter():
    """Color-jitter only the first image of a pair; the second is returned untouched.

    Presumably the second image (the normal map in TextureDataset) is skipped
    because its colors encode geometry rather than appearance. Confirm.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.color_jitter = transforms.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, img_set):
        first, second = img_set
        jittered = self.color_jitter(first)
        return jittered, second

class CenterCrop():
    """Center-crop both images of a pair with one shared ``transforms.CenterCrop``."""

    def __init__(self, crop):
        self.center_crop = transforms.CenterCrop(crop)

    def __call__(self, img_set):
        first, second = img_set
        crop_fn = self.center_crop
        return crop_fn(first), crop_fn(second)
    

class Resize():
    """Resize both images of a pair with one shared ``transforms.Resize(size)``.

    See torchvision's ``Resize`` for how an int vs. a tuple ``size`` is interpreted.
    """

    def __init__(self, size):
        self.resize = transforms.Resize(size)

    def __call__(self, img_set):
        first, second = img_set
        return tuple(map(self.resize, (first, second)))
    
class ToTensor():
    """Convert both images of a pair with one shared ``transforms.ToTensor``."""

    def __init__(self):
        self.tensor = transforms.ToTensor()

    def __call__(self, img_set):
        first, second = img_set
        convert = self.tensor
        return convert(first), convert(second)
    

class Normalize():
    """Apply one shared ``transforms.Normalize(mean, std)`` to both tensors of a pair.

    With the defaults, ``(x - 0.5) / 0.5`` maps values in ``[0, 1]`` to ``[-1, 1]``.
    """

    def __init__(self, mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5)):
        self.norm = transforms.Normalize(mean, std)

    def __call__(self, img_set):
        first, second = img_set
        return tuple(self.norm(t) for t in (first, second))

# Paired augmentation pipeline for [diffuse, normal] images. The rotate, crop,
# resize, center-crop and flip steps apply identical geometry to both images;
# ColorJitter alters only the first one.
# NOTE(review): Resize(340) followed by CenterCrop(256) presumably trims the
# black corners RandomRotate can introduce. Confirm.
data_transform = transforms.Compose([
    RandomRotate(min_angle=-25, max_angle=25),
    RandomCrop(min_crop=256, max_crop=1024),
    Resize(340),
    CenterCrop(256),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.5),
    RandomFlip(),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
    
class TextureDataset(Dataset):
    """Dataset of side-by-side texture sheets split into (diffuse, normal) pairs.

    Each file is expected to hold the diffuse map in the top-left
    ``tile_size`` x ``tile_size`` square and the normal map in the square
    immediately to its right.

    Args:
        imgs: sequence of image file paths.
        transform: callable taking ``[diffuse, normal]`` and returning the
            transformed pair; if falsy, the raw PIL crops are returned.
        tile_size: side length in pixels of each square tile (default 1024).
    """

    def __init__(self, imgs, transform = data_transform, tile_size = 1024):

        self.imgs = imgs
        self.num_samples = len(self.imgs)
        self.transform = transform
        self.tile_size = tile_size

    def __getitem__(self, i):
        ts = self.tile_size
        # The context manager closes the file handle promptly. convert() returns
        # a new, fully loaded image, so the crops stay valid after the close.
        with Image.open(self.imgs[i]) as img:
            I = img.convert("RGB")
        I_diffuse = I.crop((0, 0, ts, ts))
        I_normal = I.crop((ts, 0, 2 * ts, ts))
        if self.transform:
            I_diffuse, I_normal = self.transform([I_diffuse, I_normal])
        return I_diffuse, I_normal

    def __len__(self):
        return self.num_samples

In [3]:
class TextureLightDataModule(pl.LightningDataModule):
    """LightningDataModule serving ``TextureDataset`` samples from a single folder.

    Args:
        data_fldr: folder containing the texture sheets.
        data_ext: file extension to match, including the dot.
        batch_size: DataLoader batch size.
        train_val_split: stored but currently unused; no validation split is made.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, data_fldr: str = '', data_ext: str = '.png', batch_size: int = 64, train_val_split:float = 0.9, num_workers: int = 0):
        super().__init__()
        self.data_fldr = data_fldr
        self.data_ext = data_ext
        self.batch_size = batch_size
        self.train_val_split = train_val_split
        self.num_workers = num_workers

    def prepare_data(self):
        """List the matching image files, sorted so the listing is deterministic."""
        pattern = os.path.join(self.data_fldr, "*" + self.data_ext)
        self.imgs = sorted(glob(pattern))
        self.num_imgs = len(self.imgs)

    def setup(self, stage_name = None):
        """Build the training dataset.

        Raises:
            FileNotFoundError: if no file matches ``data_fldr/*data_ext``.
        """
        # Re-list here: state assigned in prepare_data is not guaranteed to
        # exist in every process, and setup may be called on its own.
        self.prepare_data()
        if not self.imgs:
            raise FileNotFoundError(
                "No '*%s' images found in %r" % (self.data_ext, self.data_fldr))

        self.train_dataset = TextureDataset(self.imgs)

    def train_dataloader(self):
        # Reshuffle every epoch; the file list itself stays sorted.
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

In [4]:
def conv_block(in_ch = 3, out_ch = 32, ker = 3, stride = 1, groups = 1):
    """Build a Conv2d -> InstanceNorm2d -> ReLU unit as a plain list of layers.

    The list is meant to be splatted into ``nn.Sequential(*...)``.

    Args:
        in_ch: input channels.
        out_ch: output channels (also the InstanceNorm2d feature count).
        ker: square kernel size; padding is ``ker // 2`` (0 for ``ker <= 1``).
        stride: convolution stride.
        groups: convolution groups.

    Returns:
        list of three modules: ``[Conv2d(bias=True), InstanceNorm2d, ReLU(inplace=True)]``.
    """
    padding = 0 if ker <= 1 else ker // 2
    conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=ker,
                     padding=padding, bias=True, stride=stride, groups=groups)
    return [conv, nn.InstanceNorm2d(out_ch), nn.ReLU(inplace=True)]

class Reshape(nn.Module):
    """Reshape each sample of a batch to ``reshape``, keeping the batch dimension.

    Args:
        reshape: per-sample target shape (any sequence of ints). Defaults to
            ``(6, 256, 256)``; a tuple keeps the shared default immutable.
    """
    def __init__(self, reshape = (6, 256, 256)):
        super().__init__()
        self.reshape = reshape

    def forward(self, z):
        # reshape() is a zero-copy view on contiguous input, like view(), but
        # it also accepts non-contiguous tensors instead of raising.
        return z.reshape(z.size(0), *self.reshape)


class ResBlock(nn.Module):
    """Residual block: conv(in -> out/2) -> InstanceNorm -> ReLU -> conv(out/2 -> out), plus a skip.

    The skip path is the identity when ``in_ch == out_ch``. Otherwise it is a
    bias-free 1x1 convolution, so the addition is always shape-compatible.
    Spatial size is preserved for odd ``ker``.
    """
    def __init__(self, in_ch = 3, out_ch = 32, ker = 3):
        super().__init__()
        mid_ch = out_ch // 2
        pad = ker//2 if ker > 1 else 0
        layers = [nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=ker, padding = pad, bias=False, stride = 1),
                  # Normalize the mid_ch channels the first conv actually produces.
                  nn.InstanceNorm2d(mid_ch),
                  nn.ReLU(True),
                  nn.Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=ker, padding = pad, bias=False, stride = 1)
                  ]
        self.layers = nn.Sequential(*layers)
        if in_ch == out_ch:
            # Parameter-free: state_dict keys match a plain identity skip.
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, bias=False)

    def forward(self, z):
        return self.layers(z) + self.shortcut(z)

class Generator(nn.Module):
    """Map a latent tensor to a 6-channel 256x256 image in ``[-1, 1]``.

    Expects ``z`` of shape ``(B, noise_size, 1, 1)``. Two 1x1 convolutions lift
    it to 512 features. These are reshaped to a 32x4x4 map, which six
    PixelShuffle upsampling blocks grow to 256x256. A final ``tanh`` bounds the
    output.

    Args:
        noise_size: number of latent channels in ``z``.
    """
    def __init__(self, noise_size):
        super().__init__()
        self.noise_size = noise_size

        # The first two layers operate on a 1x1 spatial map. InstanceNorm2d
        # there would normalize each channel over a single pixel, which always
        # gives 0 (or raises, depending on the PyTorch version) and makes the
        # output independent of z. nn.Identity placeholders keep the Sequential
        # indices, and hence the state_dict keys, unchanged.
        self.noise_block = nn.Sequential(
            nn.Conv2d(in_channels=noise_size, out_channels=512, kernel_size=1, bias=True),
            nn.Identity(),
            nn.ReLU(True),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, bias=True),
            nn.Identity(),
            nn.ReLU(True),
            Reshape([32, 4, 4]),  # 512 = 32 * 4 * 4; spatial statistics now exist
            *conv_block(in_ch = 32, out_ch = 256, ker = 1, stride = 1),
        )

        self.feat_block_1 = self._upsample_block()  # 8x8
        self.feat_block_2 = self._upsample_block()  # 16x16
        self.feat_block_3 = self._upsample_block()  # 32x32
        self.feat_block_4 = self._upsample_block()  # 64x64
        self.feat_block_5 = self._upsample_block()  # 128x128
        self.feat_block_6 = self._upsample_block()  # 256x256

        self.final_block = nn.Sequential(
            ResBlock(in_ch = 256, out_ch = 256, ker = 3),
            nn.Conv2d(in_channels=256, out_channels=6, kernel_size=1),
            nn.Tanh()
        )

    @staticmethod
    def _upsample_block(ch = 256):
        """ResBlock, 1x1 conv to 2*ch, PixelShuffle(2), 1x1 conv back to ch: doubles H and W."""
        return nn.Sequential(
            ResBlock(in_ch = ch, out_ch = ch, ker = 3),
            *conv_block(in_ch = ch, out_ch = 2 * ch, ker = 1, stride = 1),
            nn.PixelShuffle(2),  # 2*ch channels -> ch // 2 channels, 2x spatial
            *conv_block(in_ch = ch // 2, out_ch = ch, ker = 1, stride = 1),
        )

    def forward(self, z):
        feat = self.noise_block(z)
        for block in (self.feat_block_1, self.feat_block_2, self.feat_block_3,
                      self.feat_block_4, self.feat_block_5, self.feat_block_6):
            feat = block(feat)
        return self.final_block(feat)

In [5]:
class Discriminator(nn.Module):
    """Convolutional critic for 6-channel inputs (two stacked RGB maps).

    Five stride-2 Conv/BatchNorm/LeakyReLU blocks halve the spatial size each
    time (256 -> 8 for a 256x256 input). A final unpadded 2x2 conv then
    produces a 7x7 score map, flattened to shape ``(B, 49)``.
    """
    def __init__(self):
        super().__init__()

        # (in_ch, out_ch) for each stride-2 block, in order.
        channel_plan = [(6, 32), (32, 64), (64, 128), (128, 128), (128, 256)]
        layers = []
        for c_in, c_out in channel_plan:
            layers.extend(self.conv_block(in_ch=c_in, out_ch=c_out, ker=3, stride=2))
        layers.append(nn.Conv2d(in_channels=256, out_channels=1, kernel_size=2))
        layers.append(nn.Flatten())
        self.model = nn.Sequential(*layers)

    @staticmethod
    def conv_block(in_ch = 3, out_ch = 32, ker = 3, stride = 1, groups = 1):
        """Return ``[Conv2d(bias=True), BatchNorm2d, LeakyReLU(0.2, inplace=True)]``.

        Padding is ``ker // 2`` (0 when ``ker <= 1``).
        """
        padding = 0 if ker <= 1 else ker // 2
        return [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=ker,
                      padding=padding, bias=True, stride=stride, groups=groups),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        ]

    def forward(self, I):
        return self.model(I)

In [6]:
class GAN(pl.LightningModule):
    """Hinge-loss GAN that generates stacked (diffuse, normal) texture pairs.

    Args:
        channels, width, height: data shape; only recorded in ``hparams``.
        latent_dim: channel count of the ``(B, latent_dim, 1, 1)`` noise input.
        lr: base learning rate. The generator uses ``10 * lr`` and the
            discriminator ``lr / 10`` (see ``configure_optimizers``).
        b1, b2: Adam betas.
        batch_size: recorded in ``hparams``.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        latent_dim: int = 64,
        lr: float = 0.0001,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 32,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        # networks
        self.generator = Generator(noise_size=self.hparams.latent_dim)
        self.discriminator = Discriminator()

        # Fixed noise so the samples logged at each epoch end are comparable.
        self.validation_z = torch.randn(2, self.hparams.latent_dim, 1, 1)

        # Lets Lightning report input/output sizes in the model summary.
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim, 1, 1)

    def forward(self, z):
        return self.generator(z)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """One hinge-loss step: optimizer 0 trains G, optimizer 1 trains D."""
        imgs_diffuse, imgs_normal = batch
        # Stack diffuse (channels 0-2) and normal (3-5) into one 6-channel image.
        imgs = torch.cat([imgs_diffuse, imgs_normal], dim = 1)

        # sample noise
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim, 1, 1).type_as(imgs)

        # train generator
        if optimizer_idx == 0:
            # A single generator pass serves both the cached samples and the loss.
            self.generated_imgs = self(z)
            g_loss = -self.discriminator(self.generated_imgs).mean()

            self.log('g_loss', g_loss, prog_bar=True, on_step = True)
            return g_loss

        # train discriminator
        if optimizer_idx == 1:
            real_loss = F.relu(1.0 - self.discriminator(imgs)).mean()

            # The fakes need no generator graph during the discriminator update.
            with torch.no_grad():
                fake_imgs = self(z)
            fake_loss = F.relu(1.0 + self.discriminator(fake_imgs)).mean()

            # Hinge loss: sum of the real and fake terms.
            d_loss = real_loss + fake_loss

            self.log('d_loss', d_loss, prog_bar=True, on_step = True)
            return d_loss

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr*10.0, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr/10.0, betas=(b1, b2))
        return [opt_g, opt_d], []

    def on_epoch_end(self):
        """Log samples from the fixed ``validation_z`` as an image grid, one step per epoch."""
        # Match the generator's device/dtype directly, so this does not depend
        # on a generator step having run first.
        z = self.validation_z.type_as(next(self.generator.parameters()))

        with torch.no_grad():
            sample_imgs = self(z)
        # Put diffuse (channels 0-2) and normal (3-5) side by side along width.
        sample_imgs = torch.cat([sample_imgs[:, :3, :, :], sample_imgs[:, 3:, :, :]], dim = 3)
        # The generator ends in tanh ([-1, 1]); float images are logged in [0, 1].
        grid = torchvision.utils.make_grid((sample_imgs + 1) / 2)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)

In [7]:
# Paths default to the original author's machine. Override them with the
# TEXTURE_DATA_DIR / TEXTURE_LOG_DIR environment variables instead of editing
# the notebook.
data_fldr = os.environ.get("TEXTURE_DATA_DIR", r"G:\texture\data_1")
root_dir = os.environ.get("TEXTURE_LOG_DIR", r"E:\SP_2021\synthTEX_logging")
exp_name = "model_exp_01"
batch_size = 4

In [8]:
dm = TextureLightDataModule(data_fldr, batch_size=batch_size)
# 6 channels = diffuse RGB + normal RGB, stacked in GAN.training_step.
model = GAN(6, 256, 256, batch_size=batch_size)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=10,
    progress_bar_refresh_rate=1,
    default_root_dir=os.sep.join([root_dir, exp_name]),
)
trainer.fit(model, dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params | In sizes      | Out sizes       
-----------------------------------------------------------------------------------
0 | generator     | Generator     | 5.4 M  | [2, 64, 1, 1] | [2, 6, 256, 256]
1 | discriminator | Discriminator | 539 K  | ?             | ?               
-----------------------------------------------------------------------------------
6.0 M     Trainable params
0         Non-trainable params
6.0 M     Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…






1