# PyTorch GANs

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard


class Discriminator(nn.Module):
    """Fully connected GAN discriminator for flattened images.

    Maps each flattened input to a single score in ``[0, 1]`` (Sigmoid
    output). Because the Sigmoid is built in, pair it with ``nn.BCELoss``,
    not ``nn.BCEWithLogitsLoss``.

    Args:
        in_features: Length of each flattened input vector
            (e.g. ``28 * 28 * 1`` for MNIST).
        hidden_dim: Width of the single hidden layer (default 128).
        negative_slope: LeakyReLU slope for negative inputs (default 0.01).
    """

    def __init__(self, in_features, *, hidden_dim=128, negative_slope=0.01):
        super().__init__()
        # The attribute name and layer order are kept so existing state_dict
        # keys (disc.0.weight, disc.2.bias, ...) still load.
        self.disc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` of shape ``(..., in_features)``; returns shape ``(..., 1)``."""
        return self.disc(x)


class Generator(nn.Module):
    """Fully connected GAN generator producing flattened images.

    Maps a latent noise vector to a flattened image whose values lie in
    ``[-1, 1]`` (Tanh output). That range matches real images normalized
    with ``transforms.Normalize((0.5,), (0.5,))`` after ``ToTensor()``.

    Args:
        z_dim: Length of the input noise vector.
        img_dim: Length of the flattened output image (e.g. ``28 * 28 * 1``).
        hidden_dim: Width of the single hidden layer (default 256).
        negative_slope: LeakyReLU slope for negative inputs (default 0.01).
    """

    def __init__(self, z_dim, img_dim, *, hidden_dim=256, negative_slope=0.01):
        super().__init__()
        # The attribute name and layer order are kept so existing state_dict
        # keys (gen.0.weight, gen.2.bias, ...) still load.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            # The output activation depends on how the real images are
            # normalized: inputs are normalized to [-1, 1], so Tanh makes the
            # outputs [-1, 1] too.
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise ``x`` of shape ``(..., z_dim)`` to images of shape ``(..., img_dim)``."""
        return self.gen(x)


# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Noise (latent) vector length fed to the generator:
z_dim = 64
# Flattened image length: MNIST images are 28x28 with one channel.
image_dim = 28 * 28 * 1  # 784
# Defined before the DataLoader below, which needs it.
batch_size = 32
num_epochs = 50


# Define the Models
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# Optimizers
lr = 3e-4
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Loss: the Discriminator ends in a Sigmoid, so plain BCE on probabilities.
criterion = nn.BCELoss()


# Dataset
## Transform: ToTensor() then Normalize(0.5, 0.5), so pixels land in [-1, 1]
## to match the Generator's Tanh output. Bound to `transform` (not
## `transforms`) so the torchvision module is not shadowed and the cell can
## be re-run.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
# Dataset
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
# Dataloader
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        # The last batch may be smaller than `batch_size`; size the noise to
        # the actual batch without clobbering the hyperparameter.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): this step only updates D, so don't backprop into G. It also
        # leaves G's graph under `fake` intact for the generator step below,
        # which is why retain_graph is not needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )


# PyTorch DCGANs

In [None]:
# Train loop Option 2:

# g (Generator) and d (Discriminator) are models defined elsewhere; `mb` (the
# epoch iterable) and `dataloader` must also already exist.
# NOTE(review): `mb` looks like a fastprogress master_bar over epochs — confirm.
# NOTE(review): g must expose `input_size` (latent length) and produce outputs
# shaped like X.view(X.size(0), -1); d is assumed to return raw logits of shape
# (batch, 1) — confirm against the model definitions.
g.to(device)
d.to(device)

g_optimizer = torch.optim.Adam(g.parameters(), lr=3e-4)
d_optimizer = torch.optim.Adam(d.parameters(), lr=3e-4)

# BCEWithLogitsLoss applies the sigmoid itself, so d should not end in Sigmoid.
crit = nn.BCEWithLogitsLoss()
g_loss, d_loss = [], []


for epoch in mb:
    for X, _ in dataloader:
        # The real batch must be on the same device as the generated images,
        # otherwise torch.cat below fails when device is "cuda".
        X = X.to(device)
        n = X.size(0)

        # train discriminator
        g.eval()
        d.train()
        #   generate fake images; no_grad because this step only updates d
        with torch.no_grad():
            noise = torch.randn((n, g.input_size), device=device)
            generated_images = g(noise)
        #   input of discriminator: fakes first, then the flattened real batch
        d_input = torch.cat([generated_images, X.view(n, -1)])
        #   ground truth in the same order: 0 for fakes, 1 for reals -> (2n, 1)
        d_gt = torch.cat(
            [torch.zeros(n, 1, device=device), torch.ones(n, 1, device=device)]
        )
        #   optimization
        d_optimizer.zero_grad()
        d_output = d(d_input)
        d_l = crit(d_output, d_gt)
        d_l.backward()
        d_optimizer.step()
        d_loss.append(d_l.item())

        # train generator
        g.train()
        d.eval()
        #  generate fake images (with grad, so g receives gradients)
        noise = torch.randn((n, g.input_size), device=device)
        generated_images = g(noise)
        #  test those images in discriminator
        d_output = d(generated_images)
        #   ground truth for generator: it wants d to label its fakes as real
        g_gt = torch.ones(n, 1, device=device)
        #   optimization of generator; only g_optimizer steps, and d's stray
        #   grads are cleared by d_optimizer.zero_grad() next iteration
        g_optimizer.zero_grad()
        g_l = crit(d_output, g_gt)
        g_l.backward()
        g_optimizer.step()

        g_loss.append(g_l.item())
        