In [None]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from models.building_blocks import ResBlock, UpBlock, DownBlock, ModulatingNet
from models.resnet_cond import cResNet_Generator, cResNet_Discriminator

In [None]:
# Path to the 50-d GloVe vectors: each line is a word followed by 50 space-separated floats.
# Raw string so the Windows backslashes are not interpreted as (invalid) escape sequences.
GLOVE_PATH = r"F:\CollegeStuff\Sem3\Experiments_Tensorflow\GAN Stuff\Datasets for GANs\glove.6B.50d.txt"

print("Loading embeddings...")
# quoting=3 is csv.QUOTE_NONE: GloVe tokens include bare quote characters.
# keep_default_na=False keeps tokens such as "null" / "nan" as words instead of NaN keys.
df = pd.read_csv(GLOVE_PATH, sep=" ", quoting=3, header=None, index_col=0,
                 keep_default_na=False)
# word -> embedding vector. Pair the index with the row array directly instead of
# transposing the whole (vocab x 50) frame just to iterate over its rows.
wv = dict(zip(df.index, df.to_numpy(dtype=np.float32)))
print("Done")

# CIFAR-10 label index -> class name, used as the key into `wv`.
cifar_classes = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog",
                 7: "horse", 8: "ship", 9: "truck"}

# Fail early if any class name has no GloVe vector, rather than with a KeyError mid-training.
missing_classes = [name for name in cifar_classes.values() if name not in wv]
assert not missing_classes, f"no GloVe embedding for: {missing_classes}"

In [None]:
batch_size_train = 64

# ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 maps them to [-1, 1]
# via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Use the configured batch size rather than repeating the literal 64, so changing
# `batch_size_train` actually takes effect.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size_train,
                                          num_workers=2, shuffle=True)

In [None]:
# Sanity check: pull one batch, report its shape and value range, and show the first image.
data_loader = iter(trainloader)
data, target = next(data_loader)

print(data.shape)
print(data.max(), data.min())

# CHW -> HWC for matplotlib, then shift the normalized [-1, 1] range back to [0, 1].
img = data[0].permute(1, 2, 0)
plt.imshow((img + 1) / 2)
plt.show()

print(target[0].shape, target[0])

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

noise_dim = 128
emb_dim = 50  # must match the GloVe vector size loaded above (glove.6B.50d)
netG = cResNet_Generator(image_size=(32, 32), noise_dim=noise_dim, emb_dim=emb_dim).to(device)
netD = cResNet_Discriminator(image_size=(32, 32), emb_dim=emb_dim, mbd_features=16).to(device)

n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    netG = nn.DataParallel(netG, list(range(n_gpus)))
    netD = nn.DataParallel(netD, list(range(n_gpus)))

# Optional orthogonal weight initialization. Per the original note it does not work with
# spectral normalization, so it is enabled for G only (NOTE(review): presumably D uses SN — confirm).
ORTHO_INIT_G = True
ORTHO_INIT_D = False


def orthogonal_init(net):
    """Orthogonally initialize the weight of every Conv2d / Linear / ConvTranspose2d in `net`."""
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
            nn.init.orthogonal_(m.weight)


if ORTHO_INIT_G:
    orthogonal_init(netG)
if ORTHO_INIT_D:
    orthogonal_init(netD)

# Two Timescale Update Rule: the discriminator uses twice the generator's learning rate.
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0001, betas=(0.5, 0.999))

In [None]:
epochs = 500
disc_steps = 1      # discriminator updates per generator update; must be >= 1
path = "./saved_models/"
save_every = None   # e.g. 100 to checkpoint both networks every 100 epochs; None disables saving

# The generator step reuses the last fake batch produced inside the discriminator loop.
assert disc_steps >= 1, "disc_steps must be >= 1"

# Lookup table built once and kept on the training device: row k is the GloVe
# embedding of CIFAR-10 class k. Replaces rebuilding a CPU tensor with a Python
# loop on every step (which also left the labels on the wrong device under CUDA).
class_embeddings = torch.as_tensor(
    np.stack([wv[cifar_classes[k]] for k in range(len(cifar_classes))]),
    dtype=torch.float32, device=device)

for epoch in range(epochs):
    for i, (data, target) in enumerate(trainloader):
        real_images = data.to(device)
        # Conditioning vectors for this batch; loop-invariant across discriminator steps.
        embed_labels = class_embeddings[target.to(device)]
        b_size = real_images.size(0)

        # Discriminator: hinge loss, `disc_steps` updates on the same real batch.
        for _ in range(disc_steps):
            netD.zero_grad()

            output = netD(real_images, embed_labels).view(-1)
            errD_real = torch.mean(F.relu(1 - output))
            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(b_size, noise_dim, device=device)
            fake = netG(noise, embed_labels)

            # detach() so the discriminator update does not backprop into the generator.
            output = netD(fake.detach(), embed_labels).view(-1)
            errD_fake = torch.mean(F.relu(1 + output))
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            errD = errD_fake + errD_real
            optimizerD.step()

        # Generator: hinge generator loss on the last fake batch.
        # D's grads from this backward are discarded by netD.zero_grad() on the next step.
        netG.zero_grad()

        output = netD(fake, embed_labels).view(-1)
        errG = -torch.mean(output)

        D_G_z2 = output.mean().item()
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, epochs, i, len(trainloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    if save_every and (epoch + 1) % save_every == 0:
        # Under nn.DataParallel the state_dict keys carry a "module." prefix.
        os.makedirs(path, exist_ok=True)
        torch.save(netG.state_dict(), path + "sngan_res_cifar_G.pth")
        torch.save(netD.state_dict(), path + "sngan_res_cifar_D.pth")

In [None]:
# Sample grid: one column per CIFAR-10 class (all of them, not just the first 8),
# each row an independent noise draw conditioned on that class's GloVe embedding.
n_samples_per_class = 4
n_classes = len(cifar_classes)
f, a = plt.subplots(n_samples_per_class, n_classes,
                    figsize=(2 * n_classes, 2 * n_samples_per_class), squeeze=False)
for i in range(n_classes):
    noise = torch.randn(n_samples_per_class, noise_dim, device=device)
    # Same class embedding repeated per row, placed on the generator's device.
    y_sam = torch.as_tensor(
        np.array([wv[cifar_classes[i]]] * n_samples_per_class, dtype=np.float32),
        device=device)
    with torch.no_grad():
        fake = netG(noise, y_sam)

    a[0][i].set_title(cifar_classes[i])
    for j in range(n_samples_per_class):
        # CHW -> HWC, map the assumed [-1, 1] output range to [0, 1], and clamp
        # so imshow never receives out-of-range values.
        img = ((fake[j].cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1).numpy()
        a[j][i].imshow(img)
        a[j][i].axis("off")

plt.savefig("CIFCol.png")
plt.show()