In [1]:
import os
import random
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary

from PIL import Image
import pytorch_lightning as pl

In [2]:
from src import PatchDiscriminator, ResnetGenerator
from src.UnpairedDataset import UnpairedDataset

In [3]:
# model = PatchDiscriminator.get_model()
# summary(model, input_size=(3, 256, 256), device='cpu')

In [4]:
# from src import ResnetGenerator
# model = ResnetGenerator.get_generator()
# summary(model, input_size=(3, 256, 256), device='cpu')

## Guestimating batchsize possible
- Generator size = 770
- Discriminator size = 75

Therefore Cycle gan will need:

(770 + 75) * 2 = 1690 MB for float64 batch_size 1

ie, for float32 will have 845 MB for batch_size 1

for 4GB GPU we can have 4096/845 = 4.85 => max batch_size is 4

In [5]:
def ganLoss(imgA, isRealA, fakeB, isFakeB, cycledA):
    """Computes the loass
        isRealA, isFakeB -> Tensors of shape [b 1 30 30]
        imgA, fakeB, cycledA -> Tensors of shape [b 3 256 256]
    """
    
    # generator must fool the discriminator
    discLoss = F.binary_cross_entropy_with_logits(isFakeB, torch.ones_like(isFakeB))
    
    # cycledA and imgA must be same
    cycleLoss = F.l1_loss(cycledA, imgA)
    

In [6]:
class CycleGan(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # generator pair
        self.genX = ResnetGenerator.get_generator()
        self.genY = ResnetGenerator.get_generator()
        
        # discriminator pair
        self.disX = PatchDiscriminator.get_model()
        self.disY = PatchDiscriminator.get_model()
        
        # losses defined here
        
    def cycle_image(self, imgA, name='A'):
        genX, genY = self.genX, self.genY
        disX, disY = self.disX, self.disY
        
        if name=='B': # swap the nets
            genX, genY = genY, genX
            disX, disY = disY, disX
        
        # reinforce that A is a real image according to disX
        isRealA = disX(imgA)
        
        # fakeB must be such that disY must be fooled
        fakeB = genX(imgA)
        isFakeB = disY(fakeB)
        
        # cycledA should be such that disX must be fooled
        cycledA = genY(fakeB)
        
        L = "isRealA, fakeB, isFakeB, cycledA".split(", ")
        V = isRealA, fakeB, isFakeB, cycledA
        for name, val in zip(L, V):
            print(name, val.shape)
        return isRealA, fakeB, isFakeB, cycledA
        

    def training_step(self, batch, batch_idx):
        """
        Dataloader will feed batch like so
        {
            'A': imgA, 'pathA': pathA,
            'B': imgB, 'pathB': pathB
        }
        
        genX: A->fakeB, genY: B->fakeA
        """
        imgA, imgB = batch['A'], batch['B']
        
        # we get 8 return values
        isRealA, fakeB, isFakeB, cycledA = self.cycle_image(imgA, name='A')
        isRealB, fakeA, isFakeA, cycledB = self.cycle_image(imgB, name='B')
        
        # ensure tint of the image is maintained using
        # identity loss |A - sameA| + |B - sameB|
        # this also ensures that generators dont change the image too much
        sameB = self.genX(B)
        sameA = self.genY(A)

In [7]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [8]:
root = 'C:/Users/Deepak H R/Desktop/data/monet2photo/'
train = UnpairedDataset(root, 'train', transforms=data_transforms['train'])
train = DataLoader(train, batch_size=2, shuffle=True)
batch = next(iter(train))

Found 1072 images of trainA and 6287 images of trainB


In [9]:
model = CycleGan()
with torch.no_grad():
    isRealA, fakeB, isFakeB, cycledA = model.cycle_image(batch['A'], name='A')

isRealA torch.Size([2, 1, 30, 30])
fakeB torch.Size([2, 3, 256, 256])
isFakeB torch.Size([2, 1, 30, 30])
cycledA torch.Size([2, 3, 256, 256])


In [10]:
batch.keys()

dict_keys(['A', 'pathA', 'B', 'pathB'])