## Variational autoencoder / GAN / SimCLR
- 2 pt: implement and test variational (convolutional) autoencoder or GAN or SimCLR

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import imageio
import numpy as np
import matplotlib

from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
# Plot styling, plus a fixed seed so weight init and noise sampling repeat.
matplotlib.style.use('ggplot')
torch.manual_seed(1)

# learning parameters
batch_size = 256   # images per optimisation step
epochs = 100       # full passes over the training set
sample_size = 64   # number of images in the fixed preview grid
nz = 128           # latent vector size
k = 1              # discriminator updates per generator update

# Prefer the GPU when one is visible, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 then
# applies (x - 0.5) / 0.5, i.e. rescales to [-1, 1] (the range of the
# generator's final Tanh).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Converter used later to turn image tensors into PIL images for the GIF.
to_pil_image = transforms.ToPILImage()

# MNIST training split, downloaded into ./data on first use.
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)


In [None]:
class netG(nn.Module):
    """Fully connected generator: latent vector -> 1x28x28 image.

    Three Linear + LeakyReLU(0.2) stages widen the latent vector
    (nz -> 256 -> 512 -> 1024), then a final Linear + Tanh produces 784
    values in [-1, 1] which `forward` reshapes into single-channel
    28x28 images.

    Args:
        nz: size of the latent (noise) vector fed to the network.
    """

    def __init__(self, nz):
        super().__init__()
        self.nz = nz

        # Hidden stack is built in order, so the Sequential indices (and
        # therefore the state_dict keys) are main.0 ... main.7.
        widths = (self.nz, 256, 512, 1024)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2))

        # Output layer: one value per pixel, squashed into [-1, 1].
        layers.append(nn.Linear(1024, 784))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of latent vectors to a batch of (1, 28, 28) images."""
        flat_pixels = self.main(x)
        return flat_pixels.view(-1, 1, 28, 28)

class netD(nn.Module):
    """Fully connected discriminator: image -> probability of being real.

    The input is flattened to `n_input` features and passed through three
    Linear + LeakyReLU(0.2) + Dropout(0.3) stages (1024 -> 512 -> 256),
    followed by a single-unit Linear + Sigmoid output.

    Args:
        n_input: number of features per flattened sample. Defaults to 784
            (28 * 28), so `netD()` behaves as before.
    """

    def __init__(self, n_input=784):
        super().__init__()
        self.n_input = n_input
        self.main = nn.Sequential(
            nn.Linear(self.n_input, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid scores for `x`."""
        # Flatten using the configured width instead of a second hard-coded
        # 784; reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(-1, self.n_input)
        return self.main(x)

# Build both networks and move them to the selected device.
# NOTE: the instances are bound to the same names as their classes, so the
# classes are no longer reachable afterwards; re-running this cell requires
# re-running the class definitions first.
netG = netG(nz).to(device)
netD = netD().to(device)

# Print each architecture between a title line and a closing rule.
print('##### netG #####', netG, '######################', sep='\n')
print('\n##### netD #####', netD, '######################', sep='\n')

In [None]:
# optimizers: one Adam instance per network, both with the same step size
optim_g = optim.Adam(netG.parameters(), lr=2e-4)
optim_d = optim.Adam(netD.parameters(), lr=2e-4)

# loss function: binary cross-entropy on the discriminator's sigmoid output
loss_fn = nn.BCELoss()

# per-epoch history:
#   losses_g - netG loss after each epoch
#   losses_d - netD loss after each epoch
#   images   - image grids generated by netG
losses_g, losses_d, images = [], [], []

# to create real labels (1s)
def label_real(size):
    """Return a (size, 1) tensor of ones placed on the module-level `device`."""
    return torch.ones(size, 1, device=device)

# to create fake labels (0s)
def label_fake(size):
    """Return a (size, 1) tensor of zeros placed on the module-level `device`."""
    return torch.zeros(size, 1, device=device)

# function to create the noise vector
def create_noise(sample_size, nz):
    """Draw `sample_size` standard-normal latent vectors of length `nz`.

    The sample is drawn first and then moved to the module-level `device`.
    """
    latent = torch.randn(sample_size, nz)
    return latent.to(device)

# to save the images generated by the netG
def save_generator_image(image_batch, path):
    """Write `image_batch` to `path` using torchvision's `save_image`.

    Args:
        image_batch: image tensor (or grid) to write.
        path: destination handed to `save_image`. When it is a filesystem
            path, its parent directory is created first if missing, so a
            fresh checkout does not fail on the first save.
    """
    import os

    # Only path-like destinations have a directory to create; anything else
    # (e.g. an open file object) is passed through untouched.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)

    save_image(image_batch, path)

# function to train the netD network
def train_discriminator(optimizer, real_image_batch, fake_image_batch):
    """Run one optimisation step of the module-level `netD`.

    Real images are scored against a target of 1 and fake images against a
    target of 0, using the module-level `loss_fn`.

    Args:
        optimizer: optimizer whose `zero_grad`/`step` are called here.
        real_image_batch: batch of real images.
        fake_image_batch: batch of generated images. The caller in this
            file detaches it from the generator before passing it in.

    Returns:
        The sum of the real and fake losses as a tensor detached from the
        autograd graph, so accumulating it over an epoch does not keep
        graph history alive.
    """
    real_label = label_real(real_image_batch.size(0))
    # Size the fake targets from the fake batch itself so the two batches
    # are not required to have the same length.
    fake_label = label_fake(fake_image_batch.size(0))

    optimizer.zero_grad()

    # Forward passes stay in the original order (real first, then fake).
    output_real = netD(real_image_batch)
    loss_real = loss_fn(output_real, real_label)

    output_fake = netD(fake_image_batch)
    loss_fake = loss_fn(output_fake, fake_label)

    # Backpropagating the sum gives the same gradients as backpropagating
    # each term separately and letting them accumulate.
    total_loss = loss_real + loss_fake
    total_loss.backward()
    optimizer.step()

    return total_loss.detach()

# function to train the netG network
def train_generator(optimizer, fake_image_batch):
    """Run one optimisation step for the generator.

    The generated images are scored by the module-level `netD` and compared
    against "real" targets (1s): the loss is low when the discriminator is
    fooled.

    Args:
        optimizer: optimizer whose `zero_grad`/`step` are called here.
        fake_image_batch: generated images, still attached to the
            generator's graph so gradients can reach its parameters.

    Returns:
        The loss as a tensor detached from the autograd graph, so
        accumulating it over an epoch does not keep graph history alive.
    """
    b_size = fake_image_batch.size(0)
    real_label = label_real(b_size)

    optimizer.zero_grad()

    output = netD(fake_image_batch)
    loss = loss_fn(output, real_label)

    loss.backward()
    optimizer.step()

    return loss.detach()

In [None]:
## Training
import os

# The sample grids and the final weights are written here; create the
# directory up front so the first save does not fail.
os.makedirs('data/training', exist_ok=True)

# Fixed noise vector, reused every epoch so the sample grids are comparable.
noise = create_noise(sample_size, nz)

netG.train()
netD.train()

# Number of batches per epoch, including the final (possibly smaller) one.
n_batches = len(train_dataloader)

for epoch in range(epochs):
    print(f"Epoch {epoch} of {epochs}")
    # Running sums are plain Python floats: no autograd history is retained
    # and the values can be plotted directly later on.
    loss_g = 0.0
    loss_d = 0.0
    for image_batch, _ in tqdm(train_dataloader, total=n_batches):
        image_batch = image_batch.to(device)
        b_size = len(image_batch)
        # run the netD for k number of steps
        for _ in range(k):
            # detach: the discriminator step must not backpropagate into netG
            fake_image_batch = netG(create_noise(b_size, nz)).detach()
            real_image_batch = image_batch
            # train the netD network
            loss_d += float(train_discriminator(optim_d, real_image_batch, fake_image_batch))
        fake_image_batch = netG(create_noise(b_size, nz))
        # train the netG network
        loss_g += float(train_generator(optim_g, fake_image_batch))

    # create the final fake image_batch for the epoch (no gradients needed)
    with torch.no_grad():
        generated_img = netG(noise).cpu()
    # netG ends in Tanh, so pixels are in [-1, 1]; map them to [0, 1], the
    # range the image saving / PIL conversion / plotting code works with.
    generated_img = (generated_img + 1) / 2
    # make the images as grid
    generated_img = make_grid(generated_img)
    # save the generated grid to disk
    save_generator_image(generated_img, f"data/training/gen_img{epoch}.png")
    images.append(generated_img)
    # Mean loss per batch: divide by the batch count, not the last index.
    epoch_loss_g = loss_g / n_batches
    epoch_loss_d = loss_d / n_batches
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)
    print(f"netG loss: {epoch_loss_g:.8f}, netD loss: {epoch_loss_d:.8f}")

print('DONE TRAINING')
torch.save(netG.state_dict(), 'data/training/netG.pth')

In [None]:
import os

# Both artefacts below are written into this directory; create it if needed.
os.makedirs('data/outputs', exist_ok=True)

# save the generated images as GIF file
# Clamp to [0, 1] before the PIL conversion: it is a no-op for grids already
# in that range and keeps out-of-range values from being converted to 8-bit.
imgs = [np.array(to_pil_image(img.clamp(0, 1))) for img in images]
imageio.mimsave('data/outputs/generator_images.gif', imgs)

# plot and save the netG and netD loss
# float() accepts both plain numbers and 0-dim tensors (on any device, with
# or without autograd history), so the curves can always be plotted.
plt.figure()
plt.plot([float(loss) for loss in losses_g], label='netG loss')
plt.plot([float(loss) for loss in losses_d], label='netD Loss')
plt.legend()
plt.savefig('data/outputs/loss.png')

In [None]:
## real vs fake
# Grab a batch of real images from the dataloader
real_batch = next(iter(train_dataloader))

plt.figure(figsize=(15,15))

# Left panel: grid of the first 64 real images, rescaled for display
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(
    make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True)
    .cpu()
    .permute(1, 2, 0)  # channels-first -> channels-last for imshow
    .numpy()
)

# Right panel: the grid generated after the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(images[-1].permute(1, 2, 0).numpy())
plt.show()

In [None]:
#%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
#%%capture
# One animation frame per stored epoch grid, shown on a single axis-less figure.
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [
    [plt.imshow(grid.permute(1, 2, 0).numpy(), animated=True)]  # CHW -> HWC
    for grid in images
]
# 1 s per frame, 1 s pause before the animation repeats.
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

# Render the animation as an interactive HTML/JS widget in the notebook.
HTML(ani.to_jshtml())