ResNet & PatchGAN

In [1]:
import torch 
import torch.nn as nn 
import random 

In [2]:
class ResidualBlock(nn.Module):
    def __init__(self , in_fts):
        super(ResidualBlock , self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels = in_fts , out_channels = in_fts , kernel_size=3),
            nn.InstanceNorm2d(in_fts),
            nn.ReLU(inplace = True),

            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels = in_fts , out_channels= in_fts , kernel_size = 3),
            nn.InstanceNorm2d(in_fts)
        )


    def forward (self , x):
        return x +self.block(x)

In [3]:
class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks=6):
        super(GeneratorResNet, self).__init__()
        channels = input_shape[0]
        
        # Initial Convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling: 64 -> 128 -> 256
        in_features = 64
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Residual Blocks: stays at 256
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling: 256 -> 128 -> 64
        # The error happened here. We must ensure in_features starts at 256.
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features

        # Output Layer
        model += [nn.ReflectionPad2d(3), nn.Conv2d(64, channels, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
'''
class GeneratorResNet(nn.Module):
    def __init__(self , input_shape , num_residual_blocs = 6):
        super(GeneratorResNet , self ).__init__()
        channels = input_shape[0]

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels = channels , out_channels = 64 , kernel_size = 7 ),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace = True)
        ]

        in_fts = 64

        for _ in range(2):
            
            out_fts = in_fts * 2
            model += [
                nn.Conv2d(in_channels = in_fts , out_channels = out_fts , kernel_size = 3 , stride = 2 , padding = 1 ),
                nn.InstanceNorm2d(out_fts),
                nn.ReLU(inplace = True)
            ]
            in_fts = out_fts
             
        for _ in range(num_residual_blocs):
            model += [ResidualBlock(in_fts)]

        for _ in range(2):
            out_fts = in_fts //2
            model += [
                nn.ConvTranspose2d(in_channels = in_fts , out_channels = out_fts , kernel_size = 3 , stride  = 2 , padding = 1 , output_padding = 1),
                nn.InstanceNorm2d(out_fts),
                nn.ReLU(inplace = True)
            ] 
        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(in_channels = 64 , out_channels = channels , kernel_size = 7),
                  nn.Tanh()
                  ] 
        self.model = nn.Sequential(*model)

    def forward(self , x):
        return self.model(x)

'''

In [4]:
class Discriminator(nn.Module):
    def __init__(self , input_shape):
        super(Discriminator , self ).__init__()
        channels , H , W = input_shape

        def discriminator_block(in_filters , out_filters , normalize = True):
            layers = [nn.Conv2d(in_channels = in_filters ,out_channels = out_filters , kernel_size  = 4 , stride = 2  , padding = 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2 , inplace = True))
            return layers 
        
        self.model = nn.Sequential(
            *discriminator_block(channels , 64 , normalize = False),
            *discriminator_block(64 , 128),
            *discriminator_block(128 , 256),
            *discriminator_block(256 , 512),

            nn.ZeroPad2d((1,0,1,0)),
            nn.Conv2d(512 , 1 , 4, padding = 1)
        )

    def forward(self , img):
        return self.model(img)


Utilities (Buffer & Visualization)

In [None]:
import os

import torchvision.utils as vutils

class ReplayBuffer:
    def __init__(self , max_size = 50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self , data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element , 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform( 0, 1 ) >0.5:
                    i = random.randint(0 , self.max_size -1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element

                else:
                    to_return.append(element)
        return torch.cat(to_return)
    

def save_sample(epoch , batch_i , real_A , real_B , G_AB , G_BA):
    G_AB.eval() ; G_BA.eval()

    with torch.no_grad():
        fake_B , fake_A = G_AB(real_A) , G_BA(real_B)
        recov_A , recov_B = G_BA(fake_B) , G_AB(fake_A)

        grid = torch.cat((real_A[:4] , fake_B[:4], recov_A[:4] , real_B[:4] , fake_A[:4] , recov_B[:4]) , 0 )
        os.makedirs("output_img_Cycle_GAN", exist_ok=True)
        vutils.save_image((grid + 1.0) / 2, f"output_img_Cycle_GAN/output_epoch_{epoch}.png", nrow=4)
    G_AB.train() ; G_BA.train()

Training Loop

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_shape = (3 , 32 , 32)

G_AB = GeneratorResNet(input_shape).to(device)
G_BA = GeneratorResNet(input_shape).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

opt_G = torch.optim.Adam(list(G_AB.parameters()) + list(G_BA.parameters()) , lr = 0.0002 , betas = (0.5 , 0.999))
opt_DA = torch.optim.Adam(D_A.parameters() , lr = 0.0002 , betas = (0.5,0.999))
opt_DB = torch.optim.Adam(D_B.parameters() , lr = 0.0002 , betas = (0.5 , 0.999))

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()


def train(dataloader , epochs = 200):
    for epoch in range(epochs):
        for i , batch in enumerate(dataloader):

            real_A = batch['A'].to(device)
            real_B = batch['B'].to(device)

            valid = torch.ones((real_A.size(0) , 1 , 2 , 2)).to(device)
            fake = torch.zeros((real_A.size(0) , 1 , 2 , 2)).to(device)

            opt_G.zero_grad()
            loss_id = (criterion_identity(G_BA(real_A) , real_A) + criterion_identity(G_AB(real_B) , real_B)) /2

            fake_B , fake_A = G_AB(real_A) , G_AB(real_B)
            loss_GAN = (criterion_GAN(D_B(fake_B) , valid) + criterion_cycle(G_AB(fake_A), real_B)) /2
            loss_cycle = (criterion_cycle(G_BA(fake_B) , real_A) + criterion_cycle(G_AB(fake_A) , real_B)) /2

            loss_G = loss_GAN + (10.0 * loss_cycle) + (5.0 * loss_id)
            loss_G.backward()
            opt_G.step()

            opt_DA.zero_grad()
            f_A_buff = fake_A_buffer.push_and_pop(fake_A)
            loss_DA = (criterion_GAN(D_A(real_A), valid) + criterion_GAN(D_A(f_A_buff.detach()) , fake)) /2
            loss_DA.backward();opt_DB.step()


            if i % 100 == 0:
                print(f"Epoch {epoch} | Batch{i} | Loss D_A: {loss_DA.item()} | Loss G: {loss_G.item()}")
                save_sample(epoch , i , real_A , real_B , G_AB , G_BA)


Data Preparation

In [7]:
from torch.utils.data import Dataset
import random

class UnpairedDataset(Dataset):
    def __init__(self, data_A, data_B, transform=None):
        self.transform = transform
        self.files_A = data_A
        self.files_B = data_B

    def __getitem__(self, index):
        image_A = self.files_A[index % len(self.files_A)]

        image_B = self.files_B[random.randint(0, len(self.files_B) - 1)]

        if self.transform:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)

        return {"A": image_A, "B": image_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [8]:
from torchvision import datasets , transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(), 
    transforms.Normalize((0.5 , 0.5 , 0.5) , (0.5 ,0.5 , 0.5))

])


full_ds =  datasets.CIFAR10(root = './data' , download = True , transform = transform)
data_A = [img for img , label in full_ds if label ==0]
data_B = [img for img , label in full_ds if label ==1]

dataset = UnpairedDataset(data_A , data_B , transform = None)
dataloader = DataLoader(dataset , batch_size = 4 , shuffle = True)

In [9]:
train(dataloader , epochs = 200)

Epoch 0 | Batch0 | Loss D_A: 0.7149485349655151 | Loss G: 8.337106704711914


KeyboardInterrupt: 