## Guesstimating the batch size
- Generator memory ≈ 770 MB (float64, batch size 1)
- Discriminator memory ≈ 75 MB (float64, batch size 1)

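CycleGAN has two generators and two discriminators, so it needs roughly:

(770 + 75) * 2 = 1690 MB with float64 at batch size 1

i.e., with float32 (half the bytes per value) it needs about 845 MB at batch size 1.

So on a 4 GB GPU: 4096 / 845 ≈ 4.85, which gives a maximum batch size of 4 (a rough estimate).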

In [1]:
import os
import random
import torch
import torchvision
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.nn import functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary

from PIL import Image
import itertools
import pytorch_lightning as pl

In [2]:
from src import PatchDiscriminator, ResnetGenerator
from src.UnpairedDataset import UnpairedDataset
from src.image_pool import ImagePool

In [3]:
def set_requires_grad(nets, requires_grad):
    """Set ``requires_grad`` on every parameter of the given network(s).

    The training loop uses this to freeze the discriminators while the
    generators are optimized.

    Args:
        nets: a single ``nn.Module`` or an iterable of modules. ``None``
            entries are skipped.
        requires_grad: value assigned to ``param.requires_grad`` for every
            parameter of every network.
    """
    # Accept a lone module as well as a list/tuple of modules.
    if isinstance(nets, nn.Module):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad

In [4]:
class CycleGan(pl.LightningModule):
    """CycleGAN for unpaired image-to-image translation between domains A and B.

    Networks:
        genX: generator mapping domain A -> domain B.
        genY: generator mapping domain B -> domain A.
        disX: discriminator judging domain-A images.
        disY: discriminator judging domain-B images.

    Adversarial label convention (BCE targets): real -> 1, fake -> 0.
    The discriminators are trained on that mapping. Each generator is trained
    so that the opposing discriminator outputs 1 ("real") on its fakes.

    ``configure_optimizers`` returns two optimizers (0: generators,
    1: discriminators). ``discriminator_training_step`` reuses the fakes
    stored by ``generator_training_step``, so the generator step must run
    first for each batch.
    """

    def __init__(self, lr=2e-4, betas=(0.5, 0.999), lambda_cycle=10.0,
                 lambda_identity=0.5, n_epochs=100, n_epochs_decay=100):
        """Build both generator/discriminator pairs and the fake-image pools.

        The defaults reproduce the original hard-coded settings.

        Args:
            lr: initial learning rate of both Adam optimizers.
            betas: Adam beta coefficients of both optimizers.
            lambda_cycle: weight of the cycle-consistency loss (kept as ``self.lm``).
            lambda_identity: weight of the identity loss relative to ``lambda_cycle``.
            n_epochs: number of epochs trained at the initial learning rate.
            n_epochs_decay: length (in epochs) of the following linear LR decay.
        """
        super().__init__()
        # generator pair
        self.genX = ResnetGenerator.get_generator()
        self.genY = ResnetGenerator.get_generator()

        # discriminator pair
        self.disX = PatchDiscriminator.get_model()
        self.disY = PatchDiscriminator.get_model()

        self.lm = lambda_cycle
        self.lambda_identity = lambda_identity
        self.lr = lr
        self.betas = betas
        self.n_epochs = n_epochs
        self.n_epochs_decay = n_epochs_decay

        # Pools that feed generated images to the discriminators
        # (src.image_pool.ImagePool; presumably a history/replay buffer -- confirm).
        self.fakePoolA = ImagePool()
        self.fakePoolB = ImagePool()

        # Most recent losses and detached fakes, filled in by the training steps.
        self.genLoss = None
        self.disLoss = None
        self.fakeA = None
        self.fakeB = None

    def configure_optimizers(self):
        """Create the Adam optimizers and their learning-rate schedulers.

        Returns:
            ``([optG, optD], [schG, schD])``. Index 0 updates both generators
            and index 1 updates both discriminators.
        """
        optG = Adam(
            itertools.chain(self.genX.parameters(), self.genY.parameters()),
            lr=self.lr, betas=self.betas)

        optD = Adam(
            itertools.chain(self.disX.parameters(), self.disY.parameters()),
            lr=self.lr, betas=self.betas)

        n_epochs, n_epochs_decay = self.n_epochs, self.n_epochs_decay

        def gamma(epoch):
            # LR multiplier: 1 for the first n_epochs scheduler steps, then a
            # linear decay over n_epochs_decay steps.
            return 1 - max(0, epoch + 1 - n_epochs) / (n_epochs_decay + 1)

        schG = LambdaLR(optG, lr_lambda=gamma)
        schD = LambdaLR(optD, lr_lambda=gamma)
        return [optG, optD], [schG, schD]

    def generator_training_step(self, imgA, imgB):
        """Compute the combined generator loss for one batch.

        loss = BCE(disY(genX(A)), real) + BCE(disX(genY(B)), real)
               + lm * (cycle L1 + lambda_identity * identity L1)

        Also stores detached fakes in ``self.fakeA`` / ``self.fakeB`` for the
        discriminator step.

        Args:
            imgA: batch of real domain-A images.
            imgB: batch of real domain-B images.

        Returns:
            The scalar generator loss tensor (also kept in ``self.genLoss``).
        """
        # translate and cycle back: A -> B -> A and B -> A -> B
        fakeB = self.genX(imgA)
        cycledA = self.genY(fakeB)

        fakeA = self.genY(imgB)
        cycledB = self.genX(fakeA)

        # identity mapping: a generator fed an image of its output domain
        # should return it unchanged
        sameB = self.genX(imgB)
        sameA = self.genY(imgA)

        # genX must fool disY: push disY's prediction on fakeB towards "real" (1).
        # binary_cross_entropy_with_logits assumes the discriminators output raw
        # logits (no final sigmoid) -- TODO confirm in src.PatchDiscriminator.
        predFakeB = self.disY(fakeB)
        bceGenB = F.binary_cross_entropy_with_logits(predFakeB, torch.ones_like(predFakeB))

        # genY must fool disX: push disX's prediction on fakeA towards "real" (1)
        predFakeA = self.disX(fakeA)
        bceGenA = F.binary_cross_entropy_with_logits(predFakeA, torch.ones_like(predFakeA))

        identityLoss = F.l1_loss(sameA, imgA) + F.l1_loss(sameB, imgB)
        cycleLoss = F.l1_loss(cycledA, imgA) + F.l1_loss(cycledB, imgB)

        # gather all losses
        extraLoss = cycleLoss + self.lambda_identity * identityLoss
        self.genLoss = bceGenA + bceGenB + self.lm * extraLoss
        self.log('gen_loss', self.genLoss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)

        # Detach so the discriminator update cannot backpropagate into the generators.
        self.fakeA = fakeA.detach()
        self.fakeB = fakeB.detach()

        return self.genLoss

    def discriminator_training_step(self, imgA, imgB):
        """Compute the combined discriminator loss for one batch.

        Uses the fakes stored by the preceding ``generator_training_step``,
        passed through the image pools. BCE targets: real -> 1, fake -> 0.

        Args:
            imgA: batch of real domain-A images.
            imgB: batch of real domain-B images.

        Returns:
            The scalar discriminator loss tensor (also kept in ``self.disLoss``).

        Raises:
            RuntimeError: if no generator step has produced fakes yet.
        """
        if self.fakeA is None or self.fakeB is None:
            raise RuntimeError(
                'discriminator_training_step requires generator_training_step '
                'to have run first')
        fakeA = self.fakePoolA.query(self.fakeA)
        fakeB = self.fakePoolB.query(self.fakeB)

        # disX judges domain-A images: real -> 1, fake -> 0
        predRealA = self.disX(imgA)
        bceRealA = F.binary_cross_entropy_with_logits(predRealA, torch.ones_like(predRealA))

        predFakeA = self.disX(fakeA)
        bceFakeA = F.binary_cross_entropy_with_logits(predFakeA, torch.zeros_like(predFakeA))

        # disY judges domain-B images: real -> 1, fake -> 0
        predRealB = self.disY(imgB)
        bceRealB = F.binary_cross_entropy_with_logits(predRealB, torch.ones_like(predRealB))

        predFakeB = self.disY(fakeB)
        bceFakeB = F.binary_cross_entropy_with_logits(predFakeB, torch.zeros_like(predFakeB))

        # 0.5 * (real + fake) for each discriminator, summed over both
        self.disLoss = 0.5 * (bceFakeA + bceRealA + bceFakeB + bceRealB)
        self.log('dis_loss', self.disLoss.item(), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return self.disLoss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Dispatch to the generator (optimizer 0) or discriminator (optimizer 1) step.

        ``batch`` is a dict with real images under keys ``'A'`` and ``'B'``.
        The discriminators' parameters are frozen during the generator step
        and unfrozen during the discriminator step.
        """
        imgA, imgB = batch['A'], batch['B']
        discriminator_requires_grad = (optimizer_idx == 1)
        set_requires_grad([self.disX, self.disY], discriminator_requires_grad)

        if optimizer_idx == 0:
            return self.generator_training_step(imgA, imgB)
        else:
            return self.discriminator_training_step(imgA, imgB)

In [5]:
# Preprocessing pipelines keyed by split. ToTensor yields values in [0, 1]
# for 8-bit images; Normalize(mean=.5, std=.5) then maps them to [-1, 1].
data_transforms = {
    'train': transforms.Compose([
        # Resize the shorter side to 286 px, then take a random 256x256 crop.
        # InterpolationMode replaces the deprecated PIL integer constant.
        transforms.Resize(286, transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(256),
        transforms.ToTensor(),
        transforms.Normalize([.5, .5, .5], [.5, .5, .5])
    ]),
    'test': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.5, .5, .5], [.5, .5, .5])
    ]),
}

In [6]:
# Dataset root. Set the CYCLEGAN_DATA_ROOT environment variable to run on
# another machine; otherwise the original author's path is used.
root = os.environ.get('CYCLEGAN_DATA_ROOT', 'C:/Users/Deepak H R/Desktop/data/monet2photo/')
train_dataset = UnpairedDataset(root, 'train', transforms=data_transforms['train'])
# batch_size=4 matches the GPU-memory estimate at the top of the notebook.
train = DataLoader(train_dataset, batch_size=4, shuffle=True)
# Pull one batch to check that the data pipeline works end to end.
batch = next(iter(train))

Found 1072 images of trainA and 6287 images of trainB


In [None]:
# Fit the CycleGAN for one epoch on a single GPU with 16-bit precision.
# NOTE: with max_epochs=1 the LR schedule never leaves its constant phase.
model = CycleGan()
trainer = pl.Trainer(
    max_epochs=1,
    precision=16,
    gpus=1,
)
trainer.fit(model, train)