<a href="https://colab.research.google.com/github/drago467/Generative-AI-networks/blob/main/Advanced_Generative_Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# importing the libraries
import torch, torchvision, os, PIL, pdb
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import pdb

def show(tensor, num = 25, wandb = 0, name = ''):
    """Display the first ``num`` images of a batch as a grid with 5 images per row.

    Args:
        tensor: image batch passed to ``torchvision.utils.make_grid``
            (assumed (B, C, H, W) float images intended to lie in [0, 1] —
            TODO confirm; Tanh-based generator output in [-1, 1] may need
            rescaling by the caller before display).
        num: maximum number of images taken from the front of the batch.
        wandb: currently unused; kept so existing call sites keep working.
        name: currently unused; kept so existing call sites keep working.
    """
    # Detach from the autograd graph and move to host memory for plotting.
    data = tensor.detach().cpu()
    # make_grid returns (C, H, W); matplotlib wants (H, W, C).
    grid = make_grid(data[:num], nrow=5).permute(1, 2, 0)

    # The original called an undefined ``clip``; clip the grid itself so
    # imshow receives floats in the [0, 1] range it expects for RGB data.
    plt.imshow(grid.clip(0, 1).numpy())
    plt.show()

### hyperparameters and general parameters

n_epochs = 10000
batch_size = 128
lr = 1e-4
# Latent size; presumably passed to Generator(z_dim=...) and gen_noise — keep in sync.
z_dim = 200
# Prefer the GPU, but fall back to CPU so the notebook still runs where CUDA
# is unavailable (a hard-coded 'cuda' fails on such machines).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

cur_step = 0
# Presumably the number of critic updates per generator update (training loop not shown).
crit_cycles = 5

In [7]:
import torch.nn as nn

class Generator(nn.Module):
    """Map a latent noise vector to a 3-channel image with values in [-1, 1].

    The first transposed convolution expands the 1x1 latent map to 4x4. Each
    of the following five stride-2 transposed convolutions doubles the spatial
    size (4 -> 8 -> 16 -> 32 -> 64 -> 128) while halving the channel count,
    and a final Tanh bounds the output.

    Args:
        z_dim: length of each noise vector.
        d_dim: base channel width; the widest layer uses ``d_dim * 32``.
    """

    def __init__(self, z_dim=64, d_dim=16):
        super().__init__()
        self.z_dim = z_dim

        # Stem: 1x1 latent -> 4x4 feature map with d_dim * 32 channels.
        stages = [
            nn.ConvTranspose2d(z_dim, d_dim * 32, 4, 1, 0),
            nn.BatchNorm2d(d_dim * 32),
            nn.ReLU(True),
        ]
        # Four upsampling stages, each doubling H/W and halving channels:
        # 32 -> 16 -> 8 -> 4 -> 2 (in multiples of d_dim).
        for width in (32, 16, 8, 4):
            c_in = d_dim * width
            c_out = d_dim * width // 2
            stages += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsample to 3 RGB channels, squashed into [-1, 1].
        stages += [
            nn.ConvTranspose2d(d_dim * 2, 3, 4, 2, 1),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*stages)

    def forward(self, noise):
        """Reshape (batch, z_dim) noise to (batch, z_dim, 1, 1) and decode it."""
        batch = len(noise)
        latent = noise.view(batch, self.z_dim, 1, 1)
        return self.gen(latent)

def gen_noise(num, z_dim, device = 'cuda'):
    """Draw ``num`` standard-normal latent vectors of length ``z_dim``.

    Args:
        num: number of vectors (batch size).
        z_dim: length of each vector.
        device: torch device on which the tensor is allocated.

    Returns:
        A float tensor of shape (num, z_dim).
    """
    shape = (num, z_dim)
    return torch.randn(*shape, device=device)


In [8]:
class Critic(nn.Module):
    """Score 3-channel images with an unbounded real value, one per image.

    This is a critic (name and the notebook's ``crit_cycles`` suggest a
    Wasserstein-style setup), so its output must NOT be squashed into (0, 1).
    The original ended in ``nn.Sigmoid()``, which bounds and saturates the
    score and starves the Wasserstein loss of gradient. It has been removed.
    Sigmoid has no parameters, so saved ``state_dict`` files still load.

    For 128x128 input, five stride-2 convolutions reduce 128 -> 64 -> 32 ->
    16 -> 8 -> 4, and the final 4x4 convolution yields a 1x1 score.

    Args:
        d_dim: base channel width; the widest layer uses ``d_dim * 16``.
    """

    def __init__(self, d_dim = 16):
        super(Critic, self).__init__()

        # InstanceNorm normalizes each sample independently (no batch
        # coupling), unlike BatchNorm.
        self.crit = nn.Sequential(
            nn.Conv2d(3, d_dim, 4, 2, 1),
            nn.InstanceNorm2d(d_dim),
            nn.LeakyReLU(0.2),

            nn.Conv2d(d_dim, d_dim * 2, 4, 2, 1),
            nn.InstanceNorm2d(d_dim * 2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(d_dim * 2, d_dim * 4, 4, 2, 1),
            nn.InstanceNorm2d(d_dim * 4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(d_dim * 4, d_dim * 8, 4, 2, 1),
            nn.InstanceNorm2d(d_dim * 8),
            nn.LeakyReLU(0.2),

            nn.Conv2d(d_dim * 8, d_dim * 16, 4, 2, 1),
            nn.InstanceNorm2d(d_dim * 16),
            nn.LeakyReLU(0.2),

            # Raw, unbounded score: deliberately no Sigmoid here.
            nn.Conv2d(d_dim * 16, 1, 4, 1, 0),
        )

    def forward(self, image):
        """Return critic scores flattened to shape (batch, -1).

        For 128x128 input this is (batch, 1). Other sizes yield a larger
        spatial map, which is flattened per sample.
        """
        crit_pred = self.crit(image)
        return crit_pred.view(len(crit_pred), -1)