# GAN synthetic image generation - TargetClass - 256 size - grayscale

## Import libraries

In [None]:
# Put these at the top of every notebook, to get automatic reloading and inline plotting
# (the two autoreload lines handle the reloading, the matplotlib line the inline plots)
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Set CUDA_VISIBLE_DEVICES to "1" for this process (setup for titan.sci.utah.edu)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
#Import libraries

from fastai.imports import *
from fastai.conv_learner import *
from fastai.transforms import *

import numpy as np
import pandas as pd

In [None]:
# Parameters and hyper-parameters

# Root folder of the image data. Python's file APIs do not expand '~' on their
# own, so resolve it to the user's home directory explicitly.
path = os.path.expanduser('~/Project_SEM/Project_GAN/data')

# Full label CSV, and the sub-sampled copy that is written from it further down
csv_full = os.path.join(os.getcwd(),'Dataset_GAN_TargetClass_size256_Analysis.csv')
csv = os.path.join(os.getcwd(),'Dataset_GAN_TargetClass_size256_Analysis_sample.csv')
# Image size
sz = 256
# Batch size
bs = 100
# nz: number of channels of the (b, nz, 1, 1) noise tensor fed to the generator
nz = 300
# Learning rate
lr = 1e-3

In [None]:
# Scratch directory 'tmp' inside the current working directory; created if missing
TMP_PATH = Path(os.getcwd()) / 'tmp'
TMP_PATH.mkdir(exist_ok=True)

In [None]:
# read CSV file - create sample dataset
df = pd.read_csv(csv_full, sep=',')
# Keep a random 50% of the rows and renumber them from 0
df = df.sample(frac=0.5).reset_index(drop=True)
# index=False: do not write the DataFrame's row index as an extra leading column,
# so the sample CSV has the same columns as the full CSV it was drawn from.
# NOTE(review): the header row is written back as well, while the data loader
# below is called with skip_header=False - confirm that combination is intended.
df.to_csv(csv, index=False)

In [None]:
# TODO: this cell held an unfinished assignment (`csv = ` with no right-hand side),
# which is a SyntaxError and stops a "run all". Until a value is chosen, `csv` keeps
# the sample-CSV path assigned in the parameters cell.

## WGAN

In [None]:
class ConvBlock(nn.Module):
    """Bias-free Conv2d followed by LeakyReLU(0.2) and, optionally, BatchNorm2d.

    ni, no: input / output channel counts.
    ks, stride: kernel size and stride of the convolution.
    bn: when false, no batch-norm layer is created and `self.bn` is None.
    pad: explicit padding; when None it defaults to ks//2//stride.
    """
    def __init__(self, ni, no, ks, stride, bn=True, pad=None):
        super().__init__()
        padding = ks // 2 // stride if pad is None else pad
        self.conv = nn.Conv2d(ni, no, ks, stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(no) if bn else None
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        # Activation is applied before the (optional) normalisation
        activated = self.relu(self.conv(x))
        if not self.bn:
            return activated
        return self.bn(activated)

In [None]:
class DCGAN_D(nn.Module):
    """DCGAN-style discriminator built from ConvBlocks.

    isize: input image side length; must be a multiple of 16.
    nc: number of input image channels.
    ndf: number of feature maps produced by the first convolution.
    n_extra_layers: number of stride-1 3x3 ConvBlocks inserted after the first layer.

    forward() averages the final convolution's output over the batch dimension
    and returns it as a tensor of shape (1,).
    """
    def __init__(self, isize, nc, ndf, n_extra_layers=0):
        super().__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        # First layer halves the resolution and has no batch norm
        self.initial = ConvBlock(nc, ndf, 4, 2, bn=False)
        extras = [ConvBlock(ndf, ndf, 3, 1) for _ in range(n_extra_layers)]
        self.extra = nn.Sequential(*extras)

        # Keep halving the resolution (and doubling the width) until it reaches 4
        downsamplers = []
        width, side = ndf, isize / 2
        while side > 4:
            downsamplers.append(ConvBlock(width, width * 2, 4, 2))
            width, side = width * 2, side / 2
        self.pyramid = nn.Sequential(*downsamplers)

        # Unpadded 4x4 convolution down to a single output channel
        self.final = nn.Conv2d(width, 1, 4, padding=0, bias=False)

    def forward(self, input):
        features = self.pyramid(self.extra(self.initial(input)))
        score_map = self.final(features)
        return score_map.mean(0).view(1)

In [None]:
class DeconvBlock(nn.Module):
    """Bias-free ConvTranspose2d followed by ReLU and, optionally, BatchNorm2d.

    ni, no: input / output channel counts.
    ks, stride, pad: kernel size, stride and padding of the transposed convolution.
    bn: when false, no batch-norm layer is created and `self.bn` is None
        (previously this flag was accepted but ignored).
    """
    def __init__(self, ni, no, ks, stride, pad, bn=True):
        super().__init__()
        self.conv = nn.ConvTranspose2d(ni, no, ks, stride, padding=pad, bias=False)
        # Honour the `bn` flag, mirroring ConvBlock
        self.bn = nn.BatchNorm2d(no) if bn else None
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Activation is applied before the (optional) normalisation
        x = self.relu(self.conv(x))
        return self.bn(x) if self.bn is not None else x

In [None]:
class DCGAN_G(nn.Module):
    """DCGAN-style generator built from DeconvBlocks.

    isize: output image side length; must be a multiple of 16 and reachable
           from 4 by repeated doubling (16, 32, 64, ...).
    nz: number of channels of the input noise tensor.
    nc: number of output image channels.
    ngf: base number of feature maps; the first layer's width is scaled up
         from ngf//2 by one doubling per resolution doubling between 4 and isize.
    n_extra_layers: number of stride-1 3x3 DeconvBlocks added before the output layer.

    Raises ValueError when isize cannot be reached from 4 by doubling.
    forward() applies tanh to the output of the layer stack.
    """
    def __init__(self, isize, nz, nc, ngf, n_extra_layers=0):
        super().__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"

        # Width of the first layer: double once per doubling from 4 up to isize.
        # Looping on `<` (not `!=`) guarantees termination; sizes that are skipped
        # over (e.g. 48, 80, 0) used to spin forever and are now rejected.
        cngf, tisize = ngf//2, 4
        while tisize < isize: cngf *= 2; tisize *= 2
        if tisize != isize:
            raise ValueError(f"isize has to be 4 times a power of two, got {isize}")
        layers = [DeconvBlock(nz, cngf, 4, 1, 0)]

        # Double the resolution (halving the width) until half the target size
        csize = 4
        while csize < isize//2:
            layers.append(DeconvBlock(cngf, cngf//2, 4, 2, 1))
            cngf //= 2; csize *= 2

        layers += [DeconvBlock(cngf, cngf, 3, 1, 1) for t in range(n_extra_layers)]
        # Final doubling to isize, down to nc channels, with no ReLU / batch norm
        layers.append(nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        self.features = nn.Sequential(*layers)

    # torch.tanh replaces the deprecated F.tanh
    def forward(self, input): return torch.tanh(self.features(input))

In [None]:
# Transforms derived from `inception_stats` for images of size `sz`
tfms = tfms_from_stats(inception_stats, sz)
# Model data built from the sample CSV located relative to `path`
md = ImageClassifierData.from_csv(
    path, 'TargetClass', csv,
    tfms=tfms, bs=bs,
    skip_header=False, continuous=True,
)

In [None]:
#md = md.resize(128)

In [None]:
# Validation dataset
# NOTE: despite its name, `list_val` is an iterator over the validation data loader
# (consumed with next() below), not a list.
list_val = iter(md.val_dl)

In [None]:
# Show up to 9 images from the next validation batch in a 3x3 grid
x,y = next(list_val)

# Denormalise the whole batch once, instead of once per subplot
images = md.val_ds.denorm(x)

fig,axes = plt.subplots(3,3, figsize=(12,12))
# zip stops at the shorter of the two, so a batch with fewer than 9 images
# leaves the remaining axes empty rather than raising IndexError
for ax,ima in zip(axes.flat, images):
    ax.imshow(ima)

In [None]:
# Note: 1 channel
# Generator and discriminator for sz x sz single-channel images, each with
# 64 base feature maps and one extra layer, moved to the GPU
netG = DCGAN_G(isize=sz, nz=nz, nc=1, ngf=64, n_extra_layers=1).cuda()
netD = DCGAN_D(isize=sz, nc=1, ndf=64, n_extra_layers=1).cuda()

In [None]:
# Display the generator's module structure
netG

In [None]:
# Display the discriminator's module structure
netD

In [None]:
def create_noise(b):
    """Return `b` noise samples of shape (b, nz, 1, 1) drawn from N(0, 1), wrapped with `V`."""
    noise = torch.zeros(b, nz, 1, 1)
    noise.normal_(0, 1)  # fill in place with standard-normal values
    return V(noise)

In [None]:
# Sanity check: run 4 noise samples through the generator in its current state
# and plot the denormalised outputs in a 2x2 grid
preds = netG(create_noise(4))
pred_ims = md.trn_ds.denorm(preds)

fig, axes = plt.subplots(2, 2, figsize=(6, 6))
for ax, pred_im in zip(axes.flat, pred_ims):
    ax.imshow(pred_im)

In [None]:
def gallery(x, nc=3):
    """Tile a batch of images into one image with `nc` images per row.

    x: array of shape (n, h, w, c); n must be a multiple of nc (asserted).
    Returns an array of shape (h * n//nc, w * nc, c), filled row by row.
    """
    count, height, width, channels = x.shape
    rows, leftover = divmod(count, nc)
    assert leftover == 0
    # (rows, nc, h, w, c) -> (rows, h, nc, w, c) so that rows/columns flatten correctly
    grid = x.reshape(rows, nc, height, width, channels)
    grid = grid.swapaxes(1, 2)
    return grid.reshape(height * rows, width * nc, channels)

In [None]:
# Note: initial lr = 1e-4
# Use the `lr` hyper-parameter from the parameters cell (1e-3) instead of
# repeating the value as a literal for each optimizer
optimizerD = optim.RMSprop(netD.parameters(), lr=lr)
optimizerG = optim.RMSprop(netG.parameters(), lr=lr)

In [None]:
def train(niter, first=True):
    """Alternate discriminator and generator updates for `niter` passes over `md.trn_dl`.

    Operates on the module-level `netD`, `netG`, `optimizerD`, `optimizerG`,
    `md` and `bs`; returns nothing and prints the last losses after every epoch.

    niter: number of epochs (full passes over the training data loader).
    first: when True, each of the first 25 generator iterations of this call is
           preceded by up to 100 discriminator steps instead of 5.
    """
    # Counts generator updates across all epochs of this call (not reset per epoch)
    gen_iterations = 0
    for epoch in trange(niter):
        netD.train(); netG.train()
        data_iter = iter(md.trn_dl)
        # i counts batches consumed this epoch, n is the number of batches available
        i,n = 0,len(md.trn_dl)
        with tqdm(total=n) as pbar:
            while i < n:
                # --- discriminator phase ---
                set_trainable(netD, True)
                set_trainable(netG, False)
                # `and` binds tighter than `or`: 100 steps when (first and gen_iterations < 25),
                # or whenever gen_iterations is a multiple of 500 (including 0); otherwise 5
                d_iters = 100 if (first and (gen_iterations < 25) or (gen_iterations % 500 == 0)) else 5
                j = 0
                # Stop early if the epoch runs out of batches
                while (j < d_iters) and (i < n):
                    j += 1; i += 1
                    # Clip every discriminator weight into [-0.01, 0.01] before the step
                    for p in netD.parameters(): p.data.clamp_(-0.01, 0.01)
                    real = V(next(data_iter)[0])
                    real_loss = netD(real)
                    # Generate as many fakes as there are real images in this batch
                    fake = netG(create_noise(real.size(0)))
                    # NOTE(review): re-wrapping `fake.data` presumably keeps this backward
                    # pass out of the generator's graph - confirm against fastai's `V`
                    fake_loss = netD(V(fake.data))
                    netD.zero_grad()
                    lossD = real_loss-fake_loss
                    lossD.backward()
                    optimizerD.step()
                    pbar.update()

                # --- generator phase: one update on a fresh batch of `bs` noise samples ---
                set_trainable(netD, False)
                set_trainable(netG, True)
                netG.zero_grad()
                lossG = netD(netG(create_noise(bs))).mean(0).view(1)
                lossG.backward()
                optimizerG.step()
                gen_iterations += 1
            
        # Values from the last discriminator / generator step of the epoch
        print(f'Loss_D {to_np(lossD)}; Loss_G {to_np(lossG)}; '
              f'D_real {to_np(real_loss)}; Loss_D_fake {to_np(fake_loss)}')

In [None]:
# Turn on cuDNN's benchmark flag before training
torch.backends.cudnn.benchmark=True

In [None]:
# Stage 1: 100 epochs; first=False disables the 100-step discriminator schedule
# for the first 25 generator iterations
train(100, False)

In [None]:
# One fixed batch of `bs` noise samples, reused by the sample-rendering cells below
fixed_noise = create_noise(bs)

In [None]:
# Switch both networks to eval mode and render samples from the fixed noise batch
netD.eval()
netG.eval()
fake = netG(fixed_noise).data.cpu()
faked = np.clip(md.trn_ds.denorm(fake),0,1)

# gallery() asserts that the image count is a multiple of the column count.
# The fixed noise batch holds bs = 100 samples, which 8 does not divide, so
# trim to the largest multiple of n_cols before tiling.
n_cols = 8
n_show = len(faked) // n_cols * n_cols

plt.figure(figsize=(9,9))
plt.imshow(gallery(faked[:n_show], n_cols));

In [None]:
# Mark both networks trainable again and rebuild the RMSprop optimizers
# with a 10x smaller learning rate for the second training stage
for net in (netD, netG):
    set_trainable(net, True)
# Note: initial lr = 1e-5
optimizerD = optim.RMSprop(netD.parameters(), lr=1e-4)
optimizerG = optim.RMSprop(netG.parameters(), lr=1e-4)

In [None]:
# Stage 2: another 100 epochs with the optimizers created in the previous cell
train(100, False)

In [None]:
# Switch both networks to eval mode and render samples from the same fixed noise batch
netD.eval()
netG.eval()
fake = netG(fixed_noise).data.cpu()
faked = np.clip(md.trn_ds.denorm(fake),0,1)

# gallery() asserts that the image count is a multiple of the column count.
# The fixed noise batch holds bs = 100 samples, which 8 does not divide, so
# trim to the largest multiple of n_cols before tiling.
n_cols = 8
n_show = len(faked) // n_cols * n_cols

plt.figure(figsize=(9,9))
plt.imshow(gallery(faked[:n_show], n_cols));

In [None]:
# Write the state dicts of both networks into TMP_PATH
for net, fname in ((netG, 'netG_200.h5'), (netD, 'netD_200.h5')):
    torch.save(net.state_dict(), TMP_PATH/fname)