<a href="https://colab.research.google.com/github/beaumahadev/GAN_image_gen/blob/main/CELEB_GENERATOR_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Advanced GAN


import torch, torchvision, os, PIL, pdb
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def show(tensor, num=25, name=''):
  """Display the first `num` images of a batch as a grid with 5 images per row.

  Args:
    tensor: batch of images; it is detached and moved to the CPU before plotting.
      Presumably shaped (batch, channels, height, width) -- confirm with callers.
    num: maximum number of images taken from the start of the batch.
    name: accepted for caller compatibility; not used by this function.
  """
  images = tensor.detach().cpu()[:num]
  tiled = make_grid(images, nrow=5)
  # Move the channel axis last for imshow and clamp values into [0, 1].
  picture = tiled.permute(1, 2, 0).clip(0, 1)
  plt.imshow(picture)
  plt.show()

### hyperparameters and general parameters
n_epochs = 10000   # number of passes over the dataset
batch_size = 128   # images per training batch
lr = 1e-4          # learning rate shared by both optimizers
z_dim = 200        # dimensionality of the latent (noise) space
# Prefer the GPU, but fall back to the CPU so the notebook does not crash on
# machines without CUDA. Helpers that take a `device` argument should be
# passed this value explicitly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cur_step = 0       # global training-step counter
# Train the critic for 5 cycles per single generator cycle; otherwise the
# generator might overpower the critic (fool it too early).
crit_cycles = 5
gen_losses = []    # generator loss history
crit_losses = []   # critic loss history
show_step = 35     # step interval for showing progress
save_step = 35     # step interval for saving


In [2]:
# generator model

class Generator(nn.Module):
  """Generator that upsamples a latent vector into a 3-channel image.

  The latent vector is reshaped to a z_dim-channel 1x1 map and passed through
  six transposed convolutions (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 pixels)
  while the channel count shrinks from d_dim*32 down to 3.

  Args:
    z_dim: dimensionality of the latent space (channels of the 1x1 input).
    d_dim: base channel multiplier for the hidden layers.
  """

  def __init__(self, z_dim=64, d_dim=16):
    super(Generator, self).__init__()
    self.z_dim = z_dim

    self.gen = nn.Sequential(
        ## ConvTranspose2d: in_channels, out_channels, kernel_size, stride=1, padding=0
        ## New width and height: (n-1)*stride - 2*padding + kernel_size
        ## Start with z_dim channels on a 1x1 map, then decrease channels and
        ## increase spatial size. Channel counts in the comments below assume
        ## z_dim=200 and d_dim=16.
        nn.ConvTranspose2d(z_dim, d_dim * 32, 4, 1, 0),  ## 4x4 (ch: 200 -> 512)
        nn.BatchNorm2d(d_dim * 32),
        nn.ReLU(True),

        # Input channels must match the previous layer's output (d_dim*32).
        nn.ConvTranspose2d(d_dim * 32, d_dim * 16, 4, 2, 1),  ## 8x8 (ch: 512 -> 256)
        nn.BatchNorm2d(d_dim * 16),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim * 16, d_dim * 8, 4, 2, 1),  ## 16x16 (ch: 256 -> 128)
        nn.BatchNorm2d(d_dim * 8),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim * 8, d_dim * 4, 4, 2, 1),  ## 32x32 (ch: 128 -> 64)
        nn.BatchNorm2d(d_dim * 4),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim * 4, d_dim * 2, 4, 2, 1),  ## 64x64 (ch: 64 -> 32)
        nn.BatchNorm2d(d_dim * 2),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim * 2, 3, 4, 2, 1),  ## 128x128 (ch: 32 -> 3)
        nn.Tanh(),  ## squash the output into the range [-1, 1]
    )

  def forward(self, noise):
    """Generate images from a batch of latent vectors.

    Args:
      noise: tensor holding len(noise) latent vectors of z_dim values each.

    Returns:
      Tensor of shape (len(noise), 3, 128, 128) with values in [-1, 1].
    """
    # Treat every latent vector as a z_dim-channel 1x1 image.
    x = noise.view(len(noise), self.z_dim, 1, 1)
    return self.gen(x)

def gen_noise(num, z_dim, device="cuda"):
  """Sample a batch of latent vectors from a standard normal distribution.

  Args:
    num: number of noise vectors (batch size).
    z_dim: dimensionality of the latent space.
    device: device on which the tensor is created.

  Returns:
    Tensor of shape (num, z_dim).
  """
  # torch.random is a module, not a callable sampler; randn draws N(0, 1) noise.
  return torch.randn(num, z_dim, device=device)


In [3]:
## critic model


class Critic(nn.Module):
  """Convolutional critic that reduces an image to a single score per sample.

  Five stride-2 convolution stages (Conv2d -> InstanceNorm2d -> LeakyReLU)
  halve the spatial size each time while the channel count grows from d_dim
  to d_dim*16; a final 4x4 convolution collapses the map to one channel.

  Args:
    d_dim: base channel multiplier for the convolution stages.
  """

  def __init__(self, d_dim=16):
    super().__init__()

    # Conv2d output size: ((n + 2*pad - kernel_size) // stride) + 1
    # For a 128x128 input the stages give 64 -> 32 -> 16 -> 8 -> 4 pixels.
    widths = [3] + [d_dim * mult for mult in (1, 2, 4, 8, 16)]
    stages = []
    for c_in, c_out in zip(widths, widths[1:]):
      stages.extend([
          nn.Conv2d(c_in, c_out, 4, 2, 1),
          # Normalization stabilizes values between layers; the author notes
          # that instance normalization works best for the critic.
          nn.InstanceNorm2d(c_out),
          nn.LeakyReLU(0.2),  # leaky slope avoids dying ReLU units
      ])

    # Final layer: (4 + 2*0 - 4)//1 + 1 = 1x1 map with a single channel.
    stages.append(nn.Conv2d(widths[-1], 1, 4, 1, 0))

    self.crit = nn.Sequential(*stages)

  def forward(self, image):
    """Score a batch of images.

    Args:
      image: batch of 3-channel images (128x128 in this notebook).

    Returns:
      Tensor with one row per image; (batch, 1) for 128x128 inputs.
    """
    scores = self.crit(image)
    # Flatten everything after the batch dimension.
    return scores.view(scores.shape[0], -1)


In [4]:
#Example of overwriting pytorch initial weights

def __init_weights(m):
  """Re-initialize conv, transposed-conv and batch-norm layers in place.

  Intended for `module.apply(...)`: weights are drawn from N(0, 0.02) and
  biases are set to 0. Other layer types are left untouched.

  Args:
    m: a single submodule, as passed by `nn.Module.apply`.
  """
  # NOTE(review): batch-norm weights are also drawn around mean 0.0 here; the
  # common DCGAN recipe uses mean 1.0 for them -- confirm this is intended.
  if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
    # weight/bias are None for layers built with affine=False or bias=False.
    if m.weight is not None:
      torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if m.bias is not None:
      torch.nn.init.constant_(m.bias, 0)


# Public alias: the usage examples in this notebook call `init_weights`.
init_weights = __init_weights

# example gen = gen.apply(init_weights)

In [None]:
#### Training dataset download address:
# CelebA Google Drive: https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg?resourcekey=0-rJlzl934LzC-Xp28GeIBzQ
# Kaggle: https://www.kaggle.com/jessicali9530/celeba-dataset

In [None]:
### Dataset, DataLoader, declare gen,crit, test dataset

# The base class is spelled out in full: this class reuses the name `Dataset`,
# so re-running the cell would otherwise make it inherit from its own
# previous definition instead of torch's Dataset.
class Dataset(torch.utils.data.Dataset):
  """Image-folder dataset yielding square RGB tensors scaled to [0, 1].

  Args:
    path: directory containing the image files.
    size: target width and height in pixels after resizing.
    lim: maximum number of directory entries to use.
  """

  def __init__(self, path, size=128, lim=10000):
    self.sizes = [size, size]

    # e.g. path: './data/celeba/img_align_celeba', entry: '114568.jpg'
    # os.listdir returns entries in arbitrary order, so which `lim` files are
    # kept is platform dependent.
    names = os.listdir(path)[:lim]
    self.items = [os.path.join(path, entry) for entry in names]  # full paths
    self.labels = list(names)  # the file name doubles as the label

  def __len__(self):
    """Return the number of images in the dataset."""
    return len(self.items)

  def __getitem__(self, idx):
    """Load, resize and normalize the image at position `idx`.

    Returns:
      (data, label): a float32 tensor of shape 3 x size x size with values in
      [0, 1], and the image's file name.
    """
    # Close the file handle deterministically; convert() returns a loaded copy.
    with PIL.Image.open(self.items[idx]) as raw:
      rgb = raw.convert('RGB')
    data = np.asarray(torchvision.transforms.Resize(self.sizes)(rgb))  # size x size x 3
    data = np.transpose(data, (2, 0, 1)).astype(np.float32, copy=False)  # 3 x size x size, 0..255
    data = torch.from_numpy(data).div(255)  # 0..1
    return data, self.labels[idx]

## Dataset: up to 10000 images from the CelebA folder, resized to 128x128
data_path='./data/celeba/img_align_celeba'
ds = Dataset(data_path, size=128, lim=10000)

## DataLoader: shuffled mini-batches of `batch_size` images
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)

## Models: generator (latent size z_dim) and critic, both moved to `device`
gen = Generator(z_dim).to(device)
crit = Critic().to(device)

## Optimizers: one Adam optimizer per network, same learning rate and betas
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5,0.9))
crit_opt= torch.optim.Adam(crit.parameters(), lr=lr, betas=(0.5,0.9))

## Initializations (optional custom weight init, currently disabled so the
## PyTorch default initialization is used)
##gen=gen.apply(init_weights)
##crit=crit.apply(init_weights)

## Sanity check: fetch one batch (x: images, y: file-name labels) and display
## the first images as a grid
x,y=next(iter(dataloader))
show(x)
