### Setup

In [None]:
import torch

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from torchvision.utils import make_grid


### Dataset

In [None]:
# Source image size (presumably the native size of the Alzheimer dataset
# images — verify against the data on disk).
SRC_IMG_HEIGHT, SRC_IMG_WIDTH = 208, 176

# Target size passed to transforms.Resize. Alternatives used while
# experimenting:
#   RESIZE = None                              -> keep the source size
#   RESIZE = (SRC_IMG_HEIGHT, SRC_IMG_WIDTH)   -> explicit source size
RESIZE = (128, 128)

In [None]:
# Inspect the folder-name -> label mapping. rm_dataset is created in the
# dataset cell below, so skip instead of raising NameError when the notebook
# is executed top to bottom.
if "rm_dataset" in globals():
    print(rm_dataset.class_to_idx)

In [None]:
# Preprocessing for the Alzheimer images: convert to one channel, optionally
# resize, then ToTensor + Normalize(0.5, 0.5) maps pixels to [-1, 1].
rm_transforms = [
    transforms.Grayscale(),
    *([transforms.Resize(RESIZE)] if RESIZE is not None else []),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
print(rm_transforms)

rm_dataset = datasets.ImageFolder(
    root="data/Alzheimer_s Dataset/train/",
    transform=transforms.Compose(rm_transforms),
)

dataloader = DataLoader(rm_dataset, batch_size=32, shuffle=True)

In [None]:
# NOTE: this rebinds `dataloader`, replacing the Alzheimer loader built in the
# previous cell; the training cells below use whichever ran last.
# Resize is only applied when RESIZE is set, mirroring the Alzheimer pipeline
# (transforms.Resize(None) is rejected by torchvision).
mnist_transforms = []
if RESIZE is not None:
    mnist_transforms.append(transforms.Resize(RESIZE))
mnist_transforms += [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]

dataloader = DataLoader(
    datasets.MNIST(
        "/tmp/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(mnist_transforms),
    ),
    batch_size=32,
    shuffle=True,
)

### GAN implementation

In [None]:
import numpy as np
import math
from tqdm import tqdm, trange

from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Both input pipelines yield single-channel tensors: the Alzheimer transforms
# start with transforms.Grayscale() and MNIST images are single-channel.
# Sizing the models for 3 channels would make the discriminator's first Linear
# layer expect 3x more features than a flattened real batch provides.
IMG_CHANNELS = 1
if RESIZE is not None:
    img_shape = IMG_CHANNELS, RESIZE[0], RESIZE[1]
else:
    img_shape = IMG_CHANNELS, SRC_IMG_HEIGHT, SRC_IMG_WIDTH
print('img_shape:', img_shape)

class Generator(nn.Module):
    """MLP generator mapping a latent noise vector to an image tensor.

    Args:
        latent_dim: size of the input noise vector ``z``.
        out_shape: ``(C, H, W)`` shape of each generated image. When omitted,
            the notebook-level ``img_shape`` is used (read once, at
            construction), so ``Generator(latent_dim)`` keeps working.
    """

    def __init__(self, latent_dim, out_shape=None):
        super().__init__()
        self.latent_dim = latent_dim
        # Store the shape so forward() does not depend on a global that may
        # change after the layers were sized.
        self.img_shape = tuple(img_shape if out_shape is None else out_shape)

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (BatchNorm1d), LeakyReLU] layers for one MLP stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the original passed 0.8 positionally, which
                # BatchNorm1d interprets as `eps`, not `momentum`. Kept as
                # eps=0.8 to preserve behaviour — confirm which was intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a ``(batch, latent_dim)`` batch to ``(batch, *img_shape)`` images in [-1, 1]."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """MLP discriminator that scores flattened images with a probability.

    In this notebook, targets are 1 for real images and 0 for generated ones.

    Args:
        in_shape: ``(C, H, W)`` shape of the input images. When omitted, the
            notebook-level ``img_shape`` is used (read once, at construction),
            so ``Discriminator()`` keeps working.
    """

    def __init__(self, in_shape=None):
        super().__init__()
        self.img_shape = tuple(img_shape if in_shape is None else in_shape)

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(self.img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Flatten each image and return a ``(batch, 1)`` tensor of probabilities."""
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)


class GAN:
    """Fully-connected GAN trained with the standard adversarial BCE objective.

    Wraps a ``Generator`` / ``Discriminator`` pair and moves them to CUDA when
    it is available.

    Args:
        b1, b2: Adam beta coefficients used for both optimizers.
        lr: Adam learning rate used for both optimizers.
        latent_dim: size of the noise vector fed to the generator.
        n_epochs: number of passes over the dataloader in :meth:`fit`.
    """

    def __init__(self, b1=0.5, b2=0.999, lr=0.0002, latent_dim=100, n_epochs=100):
        self.b1 = b1
        self.b2 = b2
        self.lr = lr
        self.n_epochs = n_epochs
        self.latent_dim = latent_dim
        self.cuda = torch.cuda.is_available()
        # One device handle for every tensor created during training/sampling.
        self.device = torch.device("cuda" if self.cuda else "cpu")

        self.adversarial_loss = torch.nn.BCELoss()
        self.generator = Generator(latent_dim)
        self.discriminator = Discriminator()

        if self.cuda:
            self.generator.cuda()
            self.discriminator.cuda()
            self.adversarial_loss.cuda()

    def fit(self, dataloader):
        """Train the generator and discriminator for ``self.n_epochs`` epochs.

        Args:
            dataloader: iterable of ``(imgs, labels)`` batches; labels are
                ignored. Images are assumed to match the generator's output
                shape and [-1, 1] range — TODO confirm for new datasets.
        """
        betas = (self.b1, self.b2)
        optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=betas)
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=betas)

        # BatchNorm must use batch statistics while training.
        self.generator.train()
        self.discriminator.train()

        logger = trange(self.n_epochs, desc="Epoch: 0, G_Loss: 0, D_Loss: 0")
        for epoch in logger:
            g_loss = d_loss = None
            for imgs, _ in dataloader:
                batch_size = imgs.size(0)
                valid = torch.ones(batch_size, 1, device=self.device)
                fake = torch.zeros(batch_size, 1, device=self.device)
                real_imgs = imgs.to(self.device, dtype=torch.float32)

                # Generator step: push D to label generated images as real.
                optimizer_G.zero_grad()
                z = torch.randn(batch_size, self.latent_dim, device=self.device)
                gen_imgs = self.generator(z)
                g_loss = self.adversarial_loss(self.discriminator(gen_imgs), valid)
                g_loss.backward()
                optimizer_G.step()

                # Discriminator step: detach so no gradient flows into G.
                optimizer_D.zero_grad()
                real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)
                fake_loss = self.adversarial_loss(self.discriminator(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                optimizer_D.step()

            # Report the last batch's losses (skipped for an empty dataloader).
            if g_loss is not None:
                logger.set_description(
                    f"Epoch: {epoch}, G_Loss: {g_loss.item():.4f}, D_Loss: {d_loss.item():.4f}"
                )

    def generate(self, n_samples):
        """Sample ``n_samples`` images from the generator.

        Runs in eval mode under ``torch.no_grad()``, so BatchNorm uses its
        running statistics and no autograd graph is built. In train mode,
        BatchNorm1d rejects a batch of a single sample. The generator's
        previous train/eval mode is restored afterwards.

        Returns:
            Tensor of shape ``(n_samples, *img_shape)`` with values in [-1, 1]
            (Tanh output), on the model's device.
        """
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                z = torch.randn(n_samples, self.latent_dim, device=self.device)
                return self.generator(z)
        finally:
            self.generator.train(was_training)

              

##### Learn

In [None]:
# Latent size follows the resize height when RESIZE is set. RESIZE[0] would
# raise TypeError when RESIZE is None, so fall back to GAN's default of 100.
gan_latent_dim = RESIZE[0] if RESIZE is not None else 100
gan = GAN(n_epochs=1, latent_dim=gan_latent_dim)
gan.fit(dataloader)

##### Generate/Plot

In [None]:
with torch.no_grad():
    samples = gan.generate(5)
    # The generator ends in Tanh (values in [-1, 1]); undo Normalize(0.5, 0.5)
    # so imshow receives [0, 1] floats instead of clipping negative values.
    grid_img = make_grid(samples * 0.5 + 0.5, nrow=5)
    plt.imshow(grid_img.permute(1, 2, 0).cpu().numpy())


In [None]:
# Longer run: 10 epochs with GAN's default latent size (100).
gan2 = GAN(latent_dim=100, n_epochs=10)
gan2.fit(dataloader)

In [None]:
with torch.no_grad():
    samples = gan2.generate(5)
    # The generator ends in Tanh (values in [-1, 1]); undo Normalize(0.5, 0.5)
    # so imshow receives [0, 1] floats instead of clipping negative values.
    grid_img = make_grid(samples * 0.5 + 0.5, nrow=5)
    plt.imshow(grid_img.permute(1, 2, 0).cpu().numpy())


In [None]:
# View the (C, H, W) grid from make_grid as (H, W, C), the layout imshow uses.
grid_img.permute((1, 2, 0))