In [None]:
### For Google Colab: uncomment the lines below to install wandb and to download/unzip the dataset

# !pip install wandb
# !wget https://huggingface.co/datasets/student/celebA/resolve/main/Dataset.zip?download=true
# !unzip -q /content/Dataset.zip?download=true

In [4]:
import os
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data as data_utils
import wandb
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath("__file__"))))
import custom_loaders
from conv_layers import UpTranspose2d, DownConv2d, UpTranspose2d, UpSampleConv

In [None]:
class GeneratorTrans(nn.Module):
    """Generator built from ``UpTranspose2d`` blocks plus a transposed-conv output layer.

    ``channels`` lists the channel count at every stage. Each consecutive
    pair except the last becomes one ``UpTranspose2d`` block; the last pair
    feeds a stride-2 ``nn.ConvTranspose2d`` whose result is passed through tanh.

    Args:
        channels: sequence of channel counts (at least two entries).
        kernelSize: kernel size shared by every layer.
    """

    def __init__(self, channels, kernelSize=4):
        super().__init__()

        # One block per (in, out) pair, stopping before the final pair.
        hidden_blocks = []
        for in_ch, out_ch in zip(channels[:-2], channels[1:-1]):
            hidden_blocks.append(UpTranspose2d(in_ch, out_ch, kernelSize))
        self.gen = nn.ModuleList(hidden_blocks)

        # With stride 2 and this padding an even kernel size doubles H and W.
        self.output = nn.ConvTranspose2d(
            channels[-2],
            channels[-1],
            kernel_size=kernelSize,
            stride=2,
            padding=kernelSize // 2 - 1,
        )

    def forward(self, image):
        """Apply every block in order, then the output layer followed by tanh."""
        features = image
        for layer in self.gen:
            features = layer(features)
        return torch.tanh(self.output(features))

class GeneratorUpSample(nn.Module):
    """Generator that alternates nearest-neighbour upsampling with convolutions.

    ``channels`` lists the channel count at every stage. Each consecutive
    pair except the last becomes one ``UpSampleConv`` block; the last pair
    feeds a plain ``nn.Conv2d``. The input of every block and of the output
    layer is first upsampled by a factor of 2, and the final result is
    passed through tanh.

    Args:
        channels: sequence of channel counts (at least two entries).
        kernelSize: kernel size shared by every layer.
    """

    def __init__(self, channels, kernelSize=3):
        super().__init__()

        stage_pairs = zip(channels[:-2], channels[1:-1])
        self.gen = nn.ModuleList(
            [UpSampleConv(in_ch, out_ch, kernelSize) for in_ch, out_ch in stage_pairs]
        )
        self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
        # (kernelSize - 1) // 2 padding keeps H and W unchanged for odd kernel sizes.
        self.output = nn.Conv2d(
            channels[-2],
            channels[-1],
            kernel_size=kernelSize,
            padding=(kernelSize - 1) // 2,
        )

    def forward(self, image):
        """Upsample before each block and before the output layer, then apply tanh."""
        features = image
        for layer in self.gen:
            features = layer(self.upSample(features))
        return torch.tanh(self.output(self.upSample(features)))
    
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs one probability per image.

    ``channels`` lists the channel count at every stage. Each consecutive
    pair except the last becomes one ``DownConv2d`` block; the last pair
    feeds a stride-2 ``nn.Conv2d`` followed by a sigmoid.

    Args:
        channels: sequence of channel counts (at least two entries).
        kernelSize: kernel size shared by every layer.
    """

    ## The sigmoid output pairs with nn.BCELoss. A numerically more stable
    ## alternative is to drop the sigmoid and use BCEWithLogitsLoss instead of
    ## BCELoss -- both changes have to be made together.
    ## https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/

    def __init__(self, channels, kernelSize=4):
        super().__init__()

        self.dis = nn.ModuleList(
            [DownConv2d(channels[i], channels[i + 1], kernelSize) for i in range(len(channels) - 2)]
        )
        # With stride 2 and this padding an even kernel size halves H and W.
        self.out = nn.Conv2d(
            in_channels=channels[-2],
            out_channels=channels[-1],
            kernel_size=kernelSize,
            stride=2,
            padding=kernelSize // 2 - 1,
        )
        # forward() uses self.sig; it was never created before, so every
        # forward pass failed with AttributeError.
        self.sig = nn.Sigmoid()

    def forward(self, image):
        """Return sigmoid probabilities with the shape produced by the last conv."""
        for block in self.dis:
            image = block(image)

        output = self.sig(self.out(image))
        return output

In [5]:
# --- Data / model dimensions -------------------------------------------------
latent_len = 100   # channel count of the latent noise tensor fed to the generator
img_size = 64      # side length used for the training images
n_channels = 3     # image channels

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Channel count at every stage of the generator / discriminator.
channelsG = [latent_len, 256, 128, 128, 64, 32, n_channels]
channelsD = [n_channels, 32, 64, 128, 128, 256, 1]

# --- Label tricks used by the training loop (0 = off) ------------------------
label_flip = 0     # randomly flip the discriminator labels of a batch
add_noise = 0      # add Gaussian noise to the discriminator labels

# --- Optimisation ------------------------------------------------------------
G_lr = 0.0002
D_lr = 0.0002
epochs = 20
D_epochs = 1       # recorded in the wandb run config

# Both networks must have the same number of stages, and the image side must
# equal 2 ** (number of layers) -- presumably because every layer changes the
# spatial size by a factor of 2 (TODO confirm for the conv_layers blocks).
assert len(channelsD) == len(channelsG)
assert img_size == 2 ** (len(channelsD) - 1)

In [6]:
class Args():
    """Plain option holder handed to ``custom_loaders.get_data_loader``.

    The loader is not defined in this notebook, so how each option is
    interpreted should be confirmed against ``custom_loaders``.
    """

    def __init__(self):
        options = {
            'dataset': 'GAN',
            # Colab location of the unzipped CelebA training images.
            'imagePath': '/content/Dataset/CelebA_train/img_align_celeba',
            'image_size': img_size,
            'download': False,
            'imgC': n_channels,
            'num_images': 20000,
            'convert2bw': False,
        }
        # Set as instance attributes, in the order listed above.
        for name, value in options.items():
            setattr(self, name, value)

# Announce first: building the dataset is the slow part of this cell.
print("Loading data...")

# The returned object is wrapped in a DataLoader in a later cell.
args = Args()
train_dataset = custom_loaders.get_data_loader(args)

Loading data...


100%|██████████| 20000/20000 [00:30<00:00, 660.25it/s]


In [11]:
bs = 32  # mini-batch size
train_loader = data_utils.DataLoader(train_dataset, batch_size=bs, shuffle=True)

# Use the model classes defined in this notebook. The earlier
# `conv_layers.GeneratorTrans` / `conv_layers.Discriminator` references raise
# NameError on a fresh kernel: only individual layers are imported from
# `conv_layers`, the module name itself is never bound.
G = GeneratorTrans(channelsG).to(device)
D = Discriminator(channelsD).to(device)
# G = GeneratorUpSample(channelsG).to(device)

# BCELoss expects the discriminator to output probabilities in [0, 1].
criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(G.parameters(), lr=G_lr, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(D.parameters(), lr=D_lr, betas=(0.5, 0.999))

# Fixed latent batch, reused whenever sample images are logged so that the
# logged images are comparable across training iterations.
fixed_noise = torch.rand(bs, latent_len, 1, 1).to(device)

In [None]:
# Hyperparameters stored with the wandb run.
config = {
    "epochs": epochs,
    "batch_size": bs,
    "D_epochs": D_epochs,
    "D_lr": D_lr,
    "G_lr": G_lr,
    "img_size": img_size,
    "n_channels": n_channels,
    "latent_len": latent_len,
}

wandb.init(
    project='pytorch-gan-celeba',
    entity='basujindal123',
    config=config,
)

In [12]:
log_iter = 200   # log to wandb every `log_iter` training iterations
log = True       # set to False to disable wandb logging

# Running sums of the discriminator losses over the current logging window.
lossD_Real = 0
lossD_Fake = 0
lossG = 0        # most recent generator loss
# 1-based global iteration counter. NOTE: this shadows the builtin `iter`;
# the name is kept because other cells may read it.
iter = 0


def make_labels(n, value):
    """Build the float target vector for one discriminator pass.

    Args:
        n: number of labels (size of the discriminator output).
        value: nominal label, 0 for fake images and 1 for real images.

    Returns:
        A float32 tensor of shape ``(n,)`` on ``device``. When ``label_flip``
        is set, the whole batch is flipped to ``1 - value`` with probability
        0.05. When ``add_noise`` is set, N(0, 0.05) noise is added and the
        result is clamped to [0, 1], the target range BCELoss is defined for.
    """
    if label_flip and np.random.random() > 0.95:
        value = 1 - value

    labels = torch.full((n,), float(value), device=device)
    if add_noise:
        noise = torch.from_numpy(np.random.normal(0, 0.05, n)).float().to(device)
        labels = (labels + noise).clamp(0.0, 1.0)
    return labels


for epoch in range(epochs):
    for data in tqdm(train_loader):
        real_imgs = data.to(device)
        iter += 1

        # ---- Discriminator step ------------------------------------------
        D.zero_grad()

        # The generator is not updated here, so no graph is needed for it.
        with torch.no_grad():
            z = torch.rand(bs, latent_len, 1, 1).to(device)
            fake_imgs = G(z)

        output = D(fake_imgs).flatten()
        lossF = criterion(output, make_labels(output.shape[0], 0))
        lossF.backward()

        output = D(real_imgs).flatten()
        lossR = criterion(output, make_labels(output.shape[0], 1))
        lossR.backward()

        # Gradients of both passes have accumulated; apply them in one step.
        optimizerD.step()

        lossD_Real += lossR.item()
        lossD_Fake += lossF.item()

        # ---- Generator step ----------------------------------------------
        # One generator update every `D_epochs` discriminator updates
        # (D_epochs must be >= 1; with D_epochs == 1 this runs every iteration).
        if iter % D_epochs == 0:
            G.train()
            optimizerG.zero_grad()
            optimizerD.zero_grad()
            z = torch.rand(bs, latent_len, 1, 1).to(device)
            fake_imgs = G(z)
            output = D(fake_imgs).flatten()

            # The generator is trained to make D label its images as real.
            label = torch.ones_like(output)
            lossG = criterion(output, label)
            lossG.backward()
            optimizerG.step()
            lossG = lossG.item()

        # ---- Logging -----------------------------------------------------
        # `iter` is already 1-based here, so this fires after exactly
        # `log_iter` iterations (the former `(iter+1)` check was one early).
        if log and iter % log_iter == 0:

            G.eval()
            with torch.no_grad():
                fixed_fake_imgs = G(fixed_noise[:16])
            # Restore training mode so the next discriminator step does not
            # sample fakes from a generator left in eval mode.
            G.train()

            # Log window means rather than sums so the discriminator losses
            # are on the same scale as the single-batch generator loss.
            mean_real = lossD_Real / log_iter
            mean_fake = lossD_Fake / log_iter

            wandb.log({
                'lossG': lossG,
                'lossD_Real': mean_real,
                'lossD_Fake': mean_fake,
                'lossD': mean_real + mean_fake,
                'Fake Images': [wandb.Image(img) for img in fixed_fake_imgs],
                'Real Images': [wandb.Image(img) for img in real_imgs[:16]],
                })

            lossD_Real = 0
            lossD_Fake = 0

100%|██████████| 625/625 [00:15<00:00, 39.60it/s]
100%|██████████| 625/625 [00:15<00:00, 40.71it/s]
100%|██████████| 625/625 [00:15<00:00, 40.23it/s]
100%|██████████| 625/625 [00:17<00:00, 35.09it/s]
100%|██████████| 625/625 [00:18<00:00, 33.46it/s]
100%|██████████| 625/625 [00:16<00:00, 36.97it/s]
100%|██████████| 625/625 [00:16<00:00, 37.56it/s]
100%|██████████| 625/625 [00:19<00:00, 32.37it/s]
100%|██████████| 625/625 [00:18<00:00, 33.43it/s]
100%|██████████| 625/625 [00:17<00:00, 35.56it/s]
100%|██████████| 625/625 [00:15<00:00, 40.57it/s]
100%|██████████| 625/625 [00:19<00:00, 32.63it/s]
100%|██████████| 625/625 [00:15<00:00, 39.70it/s]
100%|██████████| 625/625 [00:15<00:00, 40.68it/s]
100%|██████████| 625/625 [00:15<00:00, 39.53it/s]
100%|██████████| 625/625 [00:15<00:00, 39.84it/s]
100%|██████████| 625/625 [00:18<00:00, 33.02it/s]
100%|██████████| 625/625 [00:16<00:00, 36.79it/s]
100%|██████████| 625/625 [00:17<00:00, 35.63it/s]
100%|██████████| 625/625 [00:17<00:00, 34.90it/s]
