In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision 
from torchvision import transforms

from torch.utils.data import Dataset, DataLoader

from tqdm.autonotebook import tqdm

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow

import pandas as pd

from sklearn.metrics import accuracy_score

import time

from idlmam import train_network, Flatten, weight_reset, View, set_seed

%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('png', 'pdf')

  set_matplotlib_formats('png', 'pdf')


In [2]:
# Reproducibility: restrict cuDNN to deterministic kernels, then seed via the idlmam helper.
SEED = 42
torch.backends.cudnn.deterministic = True
set_seed(SEED)

In [3]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
def simpleGAN(latent_d, neurons, out_shape, sigmoidG=False, leak=0.2):
    """
    Build a pair of fully-connected networks for GAN training and return them
    as the tuple (G, D): the generator first, the discriminator second.

    latent_d: size of the latent vector given to the generator G.
    neurons: width of every hidden layer in both networks.
    out_shape: the shape of the tensors that G produces and D consumes. This
    should be the same shape as the real data; use -1 for the batch dimension.
    sigmoidG: True if G should end with a sigmoid activation (values in [0,1]),
    False if it should return unbounded activations.
    leak: leak value forwarded to every fcLayer.
    """
    #Number of values in a single sample: np.prod multiplies the entries of the shape together, and abs removes the sign introduced by the "-1" batch dimension.
    n_values = abs(np.prod(out_shape))

    #Generator: three hidden layers, a linear projection to n_values outputs, then a re-shape to whatever D expects.
    g_layers = [fcLayer(width_in, neurons, leak) for width_in in (latent_d, neurons, neurons)]
    g_layers.append(nn.Linear(neurons, n_values))
    g_layers.append(View(out_shape))
    G = nn.Sequential(*g_layers)
    #The sigmoid is optional, so it is added as an outer wrapper only when requested.
    if sigmoidG:
        G = nn.Sequential(G, nn.Sigmoid())

    #Discriminator: flatten the sample, three hidden layers, then 1 output for the binary (real vs fake) problem.
    d_layers = [nn.Flatten()]
    d_layers.extend(fcLayer(width_in, neurons, leak) for width_in in (n_values, neurons, neurons))
    d_layers.append(nn.Linear(neurons, 1))
    D = nn.Sequential(*d_layers)

    return G, D

In [7]:
class ConditionalWrapper(nn.Module):
    """
    Makes a network conditional on a class label. The label is embedded into a
    vector with as many values as one input sample, concatenated with the
    input, merged back down to the original input shape by a small 'combiner'
    network, and the result is passed to the wrapped network.
    """
    def __init__(self, input_shape, neurons, classes, main_network, leak=0.2):
        """ 
        input_shape: the shape the wrapped network expects for its input, with -1 for the 
        batch dimension (the latent shape when wrapping G, the data shape when wrapping D).
        neurons: neurons to use in hidden layers (kept in the signature, but not used by 
        this implementation: the combiner layers are sized from input_shape)
        classes: number of classes in labels y
        main_network: either the generator G or discriminator D
        leak: the leaky relu leak value used inside the combiner
        """
        super().__init__()
        
        self.input_shape = input_shape
        self.classes = classes
        #figure out the number of values per sample from the shape (abs removes the sign of the "-1" batch dimension)
        input_size = abs(np.prod(input_shape))
        #create an embedding layer to convert labels to vectors
        self.label_embedding = nn.Embedding(classes, input_size)
        
        #In the forward function we will concatenate the label and original data into one tensor. Then this 'combiner' will take that 
        #extra large tensor and create a new tensor that is the size of just the original 'input_shape'. This does the work of merging 
        #the conditional information (from label_embedding) into the input.
        self.combiner = nn.Sequential(
            nn.Flatten(),
            fcLayer(input_size*2, input_size, leak=leak),#one FC layer
            nn.Linear(input_size, input_size),#A second FC layer, but only the linear & activation are applied here
            nn.LeakyReLU(leak),
            View(input_shape), #So that we can re-shape the output and apply normalizing based on the target output shape. This makes the Conditional wrapper useful for linear and convolutional models. 
            nn.LayerNorm(input_shape[1:]),
        )
        self.net = main_network
        
    
    #The forward function is the code that takes an input and produces an output. 
    def forward(self, x, condition=None):
        """
        x: a batch of inputs that can be viewed as input_shape
        condition: integer class labels, one per item in x. If None, a label is 
        drawn uniformly at random for every item. 

        Returns the output of the wrapped network on the label-conditioned input.
        """
        if condition is None:#if no label was given, lets pick one at random
            #x.device is valid for CPU and GPU tensors alike. x.get_device() reports -1 for CPU tensors, which is not an acceptable device argument.
            condition = torch.randint(0, self.classes, size=(x.size(0),), device=x.device) 
        #embed the label and re-shape it as desired
        embd = self.label_embedding(condition)
        #make sure the label embd and data x have the same shape so that we can concatenate them
        embd = embd.view(self.input_shape)
        x = x.view(self.input_shape)
        #concatenate the input with the embedded label
        x_comb = torch.cat([x, embd], dim=1)
        #return the result of the network on the combined inputs
        return self.net(self.combiner(x_comb))
        

In [None]:
# Hyper-parameters used by the cells below.
num_epochs = 10   # passes over the training data in the BCE training loop
batch_size = 128  # items per mini-batch
neurons = 512     # hidden neurons per hidden layer
latent_d = 128    # number of latent variables given to the generator
# Shape of one batch of images; -1 stands for the batch dimension.
# (-1, 1, 28, 28) would add an explicit channel, at the cost of slightly more cumbersome numpy code later.
out_shape = (-1, 28, 28)

def fcLayer(in_neurons, out_neurons, leak=0.1): #our helper function
    """
    Build one fully-connected hidden layer: Linear, then LeakyReLU, then LayerNorm.

    in_neurons: number of input features of the linear layer
    out_neurons: number of output features, also the size normalized by LayerNorm
    leak: the value passed to nn.LeakyReLU
    """
    linear = nn.Linear(in_neurons, out_neurons)
    activation = nn.LeakyReLU(leak)
    normalize = nn.LayerNorm(out_neurons)
    return nn.Sequential(linear, activation, normalize)

In [None]:
# Settings for the conditional GAN. out_shape now carries an explicit channel dimension.
classes = 10
latent_d = 128
in_shape = (-1, latent_d)
out_shape = (-1, 1, 28, 28)

# Plain generator / discriminator pair; G ends in a sigmoid.
base_G, base_D = simpleGAN(latent_d, neurons, out_shape, sigmoidG=True)

# Wrap each network so that it also receives a class label:
# G is conditioned at its latent input, D at its image input.
G = ConditionalWrapper(input_shape=in_shape, neurons=neurons, classes=classes, main_network=base_G)
D = ConditionalWrapper(input_shape=out_shape, neurons=neurons, classes=classes, main_network=base_D)

In [None]:
def train_c_wgan(D, G, loader, latent_d, epochs=20, device="gpu", gp_weight=10):
    """
    Train a conditional Wasserstein GAN with a gradient penalty.

    D: the discriminator (critic), called as D(x, class_label)
    G: the generator, called as G(noise, class_label)
    loader: iterable yielding (data, class_label) pairs
    latent_d: number of latent variables drawn for each generated sample
    epochs: number of passes over loader
    device: device to train on. The string "gpu" is accepted as an alias for "cuda".
    gp_weight: multiplier applied to the gradient penalty in D's loss

    Returns (D_losses, G_losses): two lists holding one loss value per batch.
    Raises ValueError if loader yields something other than a (data, class_label) pair.
    """
    #"gpu" is not a device string PyTorch understands; map it to "cuda" so the default value works.
    if isinstance(device, str) and device == "gpu":
        device = "cuda"

    G_losses = []
    D_losses = []

    G = G.to(device)
    D = D.to(device)

    # Setup AdamW optimizers for both G and D
    optimizerD = torch.optim.AdamW(D.parameters(), lr=0.0001, betas=(0.0, 0.9))
    optimizerG = torch.optim.AdamW(G.parameters(), lr=0.0001, betas=(0.0, 0.9))

    for epoch in tqdm(range(epochs)):
        for batch in tqdm(loader, leave=False):
            #The labels are required for conditioning, so fail with a clear message rather than a NameError further down.
            if not (isinstance(batch, (tuple, list)) and len(batch) == 2):
                raise ValueError("train_c_wgan expects loader to yield (data, class_label) pairs")
            data, class_label = batch
            batch_size = data.size(0)
            real = data.to(device)
            class_label = class_label.to(device)

            # Step 1) D-score, G-score, and gradient penalty
            D.zero_grad()
            G.zero_grad()
            #How well does D work on real data 
            D_success = D(real, class_label)

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(batch_size, latent_d, device=device)
            # Generate fake image batch with G. Only D is updated in this step, so no graph through G is needed.
            with torch.no_grad():
                fake = G(noise, class_label)
            # Classify all fake batch with D
            D_failure = D(fake, class_label)

            #Now calculate the gradient penalty on random mixes of real and fake samples
            eps_shape = [batch_size]+[1]*(real.dim()-1)
            eps = torch.rand(eps_shape, device=device)
            mixed = (eps*real + (1-eps)*fake).requires_grad_(True)
            output = D(mixed, class_label)

            grad = torch.autograd.grad(outputs=output, inputs=mixed,
                                       grad_outputs=torch.ones_like(output),
                                       create_graph=True, retain_graph=True)[0]

            #Flatten first so that the norm is taken over ALL values of each sample. Using dim=1 on an image shaped
            #tensor would only reduce over the channel dimension and give per-pixel values instead of one norm per sample.
            grad_norm = grad.reshape(batch_size, -1).norm(2, dim=1)
            D_grad_penalty = ((grad_norm - 1) ** 2).mean()

            # Calculate D's loss: score on fakes minus score on reals, plus the weighted penalty
            errD = (D_failure-D_success).mean() + D_grad_penalty*gp_weight
            errD.backward()
            # Update D
            optimizerD.step()

            D_losses.append(errD.item())

            # Step 2) -D(G(z))
            D.zero_grad()
            G.zero_grad()
            # Since we just updated D, perform another forward pass of an all-fake batch through D

            noise = torch.randn(batch_size, latent_d, device=device)
            output = -D(G(noise, class_label), class_label)
            # Calculate G's loss based on this output
            errG = output.mean()
            # Calculate gradients for G
            errG.backward()
            # Update G
            optimizerG.step()

            G_losses.append(errG.item())

    return D_losses, G_losses

In [None]:
# MNIST train / test splits, downloaded into the current directory when missing.
to_tensor = transforms.ToTensor()
train_data = torchvision.datasets.MNIST("./", train=True, download=True, transform=to_tensor)
test_data = torchvision.datasets.MNIST("./", train=False, download=True, transform=to_tensor)

# Training batches are shuffled, and an incomplete final batch is dropped so every batch has batch_size items.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, drop_last=True, shuffle=True)
# Test batches keep their order and the final (possibly smaller) batch.
test_loader = DataLoader(dataset=test_data, batch_size=batch_size)

In [None]:
# Train the conditional WGAN and keep the per-batch losses of both networks.
D_losses, G_losses = train_c_wgan(D=D, G=G, loader=train_loader, latent_d=latent_d, epochs=20, device=device)

# Switch both networks to evaluation mode. eval() changes the module itself, so no re-binding is needed.
for trained_net in (G, D):
    trained_net.eval()

In [None]:
# Convention for the real and fake labels used during training
real_label, fake_label = 1, 0

# The BCE loss is for binary classification problems, which ours is (real vs fake).
# The "WithLogits" variant is used, so D's raw outputs are passed in directly.
loss_func = nn.BCEWithLogitsLoss()

# One AdamW optimizer per network, both with the same settings
adamw_settings = dict(lr=0.0001, betas=(0.0, 0.9))
optimizerD = torch.optim.AdamW(D.parameters(), **adamw_settings)
optimizerG = torch.optim.AdamW(G.parameters(), **adamw_settings)

In [None]:
# Per-batch losses of this BCE training loop.
# NOTE: these names are re-bound here, replacing the losses returned by train_c_wgan above.
G_losses = []
D_losses = []

for epoch in tqdm(range(num_epochs)):
    for data, class_label in tqdm(train_loader, leave=False):
        # prep batch & make labels
        real_data = data.to(device)
        # G and D are conditional, so both must be given the SAME class labels. Leaving the label out
        # makes each call draw its own random labels, which no longer match the data or each other.
        class_label = class_label.to(device)
        # Size the targets from the batch itself rather than the global batch_size, so a smaller batch still works.
        n_batch = real_data.size(0)
        y_real = torch.full((n_batch,1), real_label, dtype=torch.float32, device=device)
        y_fake = torch.full((n_batch,1), fake_label, dtype=torch.float32, device=device)
        
        # Step 1) $\ell ( D(\boldsymbol{x}_{\mathit{real}}) ,\ y_{\mathit{real}})$ and $\ell ( D(\boldsymbol{x}_{\mathit{fake}}) ,\ y_{\mathit{fake}})$
        D.zero_grad()

        # Calculate loss on all-real batch
        errD_real = loss_func(D(real_data, class_label), y_real)
        # Calculate gradients for D in backward pass
        errD_real.backward()

        ## Train with all-fake batch
        # Generate batch of latent vectors $z \sim \mathcal{N}(\vec{0}, 1)$
        z = torch.randn(n_batch, latent_d, device=device)
        # Generate fake image batch with G. We will save this to re-use for the 2nd step. 
        fake = G(z, class_label) 
        #Why do we detach here? Because we don't want the gradient to impact G. 
        #Our goal right now is to update _just_ the discriminator. 
        #BUT, we will re-use this fake data for updating the generator, so we want to save the 
        #non-detached version! 
        # Calculate D's loss on the all-fake batch
        errD_fake = loss_func(D(fake.detach(), class_label), y_fake)
        # Calculate the gradients for this batch
        errD_fake.backward()
        # Total D loss over the all-real and all-fake batches (their gradients were already accumulated above)
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        # Step 2) $\ell ( D(\boldsymbol{x}_{\mathit{fake}}) ,\ y_{\mathit{real}})$
        G.zero_grad()
        # Since we just updated D, perform another forward pass of all-fake batch through D
        # Calculate G's loss based on this output
        errG = loss_func(D(fake, class_label), y_real)
        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizerG.step()
        
        G_losses.append(errG.item())
        D_losses.append(errD.item())

In [None]:
with torch.no_grad():
    noise = torch.randn(batch_size, latent_d, device=device) #$\boldsymbol{z} \sim \mathcal{N}(\vec{0}, \boldsymbol{I})$
    # Draw the random labels once and give them to BOTH networks. If they are left out, G and D each draw
    # their own labels, so D would score every image under a label G was never asked to produce.
    labels = torch.randint(0, classes, size=(batch_size,), device=device)
    fake_digits = G(noise, labels) 
    # D returns logits; the sigmoid turns them into values in (0, 1)
    scores = torch.sigmoid(D(fake_digits, labels))
    
    # Move the results to the CPU for plotting
    fake_digits = fake_digits.cpu()
    scores = scores.cpu().numpy().flatten()

In [None]:
def plot_gen_imgs(fake_digits, scores=None):
    """
    Show a batch of generated images as a roughly square grid.

    fake_digits: tensor of images with the batch as its first dimension. The images 
    are treated as black-and-white squares whose side is the size of the last dimension.
    scores: optional sequence with one value per image, written in red over each image 
    (rounded to 2 decimals).

    Images that do not fit into the grid are not drawn. 
    Raises ValueError if fake_digits holds no images.
    """
    batch_size = fake_digits.size(0)
    if batch_size == 0:
        raise ValueError("plot_gen_imgs needs at least one image")
    #This code assumes we are working with black-and-white images
    #detach & cpu so that tensors still attached to a graph or living on the GPU can be plotted too
    fake_digits = fake_digits.detach().cpu().reshape(-1, fake_digits.size(-1), fake_digits.size(-1))
    i_max = int(round(np.sqrt(batch_size)))
    j_max = int(np.floor(batch_size/float(i_max)))
    #squeeze=False keeps axarr 2-D even when there is only one row or column, so axarr[i,j] is always valid
    f, axarr = plt.subplots(i_max,j_max, figsize=(10,10), squeeze=False)
    for i in range(i_max):
        for j in range(j_max):
            indx = i*j_max+j
            axarr[i,j].imshow(fake_digits[indx,:].numpy(), cmap='gray', vmin=0, vmax=1)
            axarr[i,j].set_axis_off()
            if scores is not None:
                axarr[i,j].text(0.0, 0.5, str(round(scores[indx],2)), dict(size=20, color='red'))
plot_gen_imgs(fake_digits, scores)

In [None]:
# Loss curves of both networks, one point per training iteration
fig, ax = plt.subplots(figsize=(10,5))
ax.set_title("Generator and Discriminator Loss During Training")
for curve, curve_name in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(curve, label=curve_name)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
with torch.no_grad():
    #Draw 10 latent vectors. Repeating each one 10 times along the feature axis and then viewing the result
    #as rows of latent_d values yields 100 rows, where every run of 10 consecutive rows is the same latent code.
    base_noise = torch.randn(10, latent_d, device=device)
    noise = base_noise.repeat((1,10)).view(-1, latent_d)
    #Labels count 0, 1, ..., 9 and then wrap around back to 0, once per run of 10 rows
    row_ids = torch.arange(0, noise.size(0), device=device)
    labels = row_ids % classes
    #Each latent code is now paired with every label in turn
    fake_digits = G(noise, labels)
    scores = D(fake_digits, labels)

    #Bring the results back to the CPU
    fake_digits = fake_digits.cpu()
    scores = scores.cpu().numpy().flatten()
#When we plot the results, we should see a grid of digits going from 0 to 9, where each row all used the same latent vector and share similar visual properties.
plot_gen_imgs(fake_digits)