<a href="https://colab.research.google.com/github/magicbycalvin/GANTheftAuto/blob/dev/GANav.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Implementation of a Vanilla GAN to generate trajectories

In [None]:
import torch
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import numpy as np
from torch.autograd import Variable
from torch.utils.data import TensorDataset, DataLoader
from skimage.util import random_noise

In [None]:
class generator(nn.Module):
    """Generator network mapping a latent vector to an image-like tensor.

    Network architecture based on infoGAN (https://arxiv.org/abs/1606.03657):
    two fully connected layers, a reshape to a feature map of shape
    ``(2 * base_size, input_size // 4, input_size // 4)``, then two stride-2
    transposed convolutions followed by ``tanh``.

    Args:
        input_dim: Size of the latent vector expected by the first linear layer.
        output_dim: Number of channels produced by the last transposed
            convolution.
        input_size: Side length used to size the intermediate feature map
            (``input_size // 4`` per side before the transposed convolutions).
        base_size: Base channel count; the reshaped feature map has
            ``2 * base_size`` channels.
    """

    def __init__(self, input_dim=100, output_dim=3, input_size=32,base_size=64):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.base_size = base_size

        # Number of features needed to fill the reshaped feature map.
        side = self.input_size // 4
        flat_features = 2 * self.base_size * side * side

        self.cn = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(),
        )

        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(2*self.base_size, self.base_size, 4, 2, 1),
            nn.BatchNorm2d(self.base_size),
            nn.ReLU(),
            nn.ConvTranspose2d(self.base_size, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        initialize_weights(self)

        # Conditioning branch. These layers are created after the
        # initialize_weights call above, so they keep PyTorch's default
        # initialisation.
        # NOTE(review): confirm whether that ordering is intentional.
        self.l1 = nn.Linear(2,10)
        self.l2 = nn.Linear(10, 10)
        self.l3 = nn.Linear(10,3)
        # NOTE(review): cf3 is not used by forward(); kept so the set of
        # registered parameters is unchanged.
        self.cf3 = nn.Linear(33, 3)

    def forward(self, input, y=None):
        """Generate a batch of samples from latent vectors.

        Args:
            input: Latent batch whose last dimension is ``input_dim``.
            y: Optional conditioning tensor whose last dimension is 2 (the
                input width of ``l1``). When omitted, only the output of the
                transposed-convolution stack is returned.

        Returns:
            The ``tanh``-activated output of the transposed-convolution stack
            when ``y`` is ``None``; otherwise that output concatenated with the
            transformed ``y`` along dimension 1.
        """
        side = self.input_size // 4
        x = self.cn(input)
        x = x.view(-1, 2*self.base_size, side, side)
        x = self.deconv(x)
        if y is None:
            return x

        y = self.l1(y)
        y = self.l2(y)
        y = self.l3(y)
        # NOTE(review): x is 4-D at this point (it comes out of the 2-D
        # transposed convolutions) and torch.cat needs tensors with the same
        # number of dimensions, so y presumably has to be reshaped by the
        # caller or here -- confirm the intended layout.
        z = torch.cat([x,y],1)

        return z

In [None]:
class discriminator(nn.Module):
    """Convolutional discriminator scoring a batch of single- or multi-channel maps.

    Network architecture follows infoGAN (https://arxiv.org/abs/1606.03657).
    Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    The final linear layer has no activation, so ``forward`` returns raw
    scores.

    Args:
        input_dim: Number of input channels of the first convolution.
        output_dim: Width of the final linear layer.
        input_size: Accepted for interface compatibility (see note in
            ``__init__``).
        base_size: Channel count of the first convolution; the second
            convolution produces ``2 * base_size`` channels.
    """

    def __init__(self, input_dim=1, output_dim=1, input_size=32,base_size=64):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # NOTE(review): the ``input_size`` argument is ignored; the spatial
        # size is fixed to 28 here. Confirm whether that is intentional.
        self.input_size = 28
        self.base_size = base_size

        wide_channels = 2 * self.base_size
        # Two stride-2 convolutions shrink each side by a factor of four.
        reduced_side = self.input_size // 4

        conv_layers = [
            nn.Conv2d(self.input_dim, self.base_size, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(self.base_size, wide_channels, 4, 2, 1),
            nn.BatchNorm2d(wide_channels),
            nn.LeakyReLU(0.2),
        ]
        self.conv = nn.Sequential(*conv_layers)

        fc_layers = [
            nn.Linear(wide_channels * reduced_side * reduced_side, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),  # Note: no activation at the output.
        ]
        self.fc = nn.Sequential(*fc_layers)

        initialize_weights(self)

    def forward(self, input):
        """Return the un-activated score(s) for each sample in ``input``.

        Args:
            input: Batch passed to the convolution stack; its channel count
                must equal ``input_dim``.

        Returns:
            Output of the fully connected head, one row per flattened sample.
        """
        reduced_side = self.input_size // 4
        flat_width = 2 * self.base_size * reduced_side * reduced_side

        features = self.conv(input)
        flattened = features.view(-1, flat_width)
        return self.fc(flattened)


In [None]:
def initialize_weights(net):
    """Re-initialise the conv, transposed-conv and linear layers of ``net`` in place.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation 0.02 and biases are set to zero. Modules of any other type
    (for example batch-norm layers) are left untouched.

    Args:
        net: Module whose sub-modules, as yielded by ``net.modules()``, are
            visited.
    """
    target_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    for m in net.modules():
        if not isinstance(m, target_types):
            continue
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        # Layers built with bias=False have no bias tensor to reset.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
class GAN():
    """Training wrapper owning a generator, a discriminator and their optimizers.

    The dataset is read from the module-level ``mnist_data_reduced``;
    ``train`` additionally relies on the module-level ``Ninner`` (number of
    inner update steps) and ``visualize_results``.

    Args:
        params: Mapping providing ``'max_epochs'``, ``'z_dim'``,
            ``'base_size'``, ``'lr_g'``, ``'lr_d'``, ``'beta1'`` and
            ``'beta2'``.
    """

    def __init__(self,params):
        # parameters
        self.epoch = params['max_epochs']
        self.sample_num = 100
        self.batch_size = 300
        self.input_size = 300
        self.z_dim = params['z_dim']
        self.base_size = params['base_size']

        # Run on the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # load dataset
        self.data_loader = torch.utils.data.DataLoader(mnist_data_reduced,
                                                       batch_size=self.batch_size,
                                                       shuffle=True)
        # Peek at one batch to read the channel count from dimension 1.
        data = next(iter(self.data_loader))[0]

        # initialization of the generator and discriminator
        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1],
                           input_size=self.input_size, base_size=self.base_size).to(self.device)
        self.D = discriminator(input_dim=data.shape[1], output_dim=1,
                               input_size=self.input_size, base_size=self.base_size).to(self.device)
        betas = (params['beta1'], params['beta2'])
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=params['lr_g'], betas=betas)
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=params['lr_d'], betas=betas)

        # initialization of the loss function
        # The discriminator in this file ends in a bare Linear layer (no
        # sigmoid), so its outputs are logits. nn.BCELoss requires inputs in
        # [0, 1]; BCEWithLogitsLoss applies the sigmoid internally.
        self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)

        # Getting a batch of noise to generate the fake data
        self.sample_z_ = torch.rand((self.batch_size, self.z_dim)).to(self.device)

    # Function to train the GAN, alternating between the training of the
    # discriminator and the generator
    # ----------------------------------------------------------------------

    def train(self):
        """Run adversarial training for ``self.epoch`` epochs.

        For every full batch the discriminator is updated ``Ninner`` times,
        then the generator is updated ``Ninner`` times. The last losses of
        each batch are appended to ``self.train_hist['D_loss']`` and
        ``self.train_hist['G_loss']``. ``visualize_results(self)`` is called
        after every epoch with gradients disabled.
        """
        # Setting empty lists for storing the losses
        self.train_hist = {'D_loss': [], 'G_loss': []}

        # Setting up the labels for real and fake images
        self.y_real_ = torch.ones(self.batch_size, 1).to(self.device)
        self.y_fake_ = torch.zeros(self.batch_size, 1).to(self.device)

        print('training start!!')

        # The label tensors have a fixed batch dimension, so only full
        # batches are used; a trailing partial batch is skipped.
        n_batches = len(self.data_loader.dataset) // self.batch_size

        # Epoch loops
        for epoch in range(self.epoch):

            for step, (x_, _) in enumerate(self.data_loader):
                if step == n_batches:
                    break

                # Generate random noise to push through the generator
                z_ = torch.rand((self.batch_size, self.z_dim))
                x_, z_ = x_.to(self.device), z_.to(self.device)

                # update D network
                for _inner in range(Ninner):
                    self.D_optimizer.zero_grad()
                    # Perturb the clean batch each time so the noise does not
                    # accumulate across inner iterations.
                    x_noisy = x_ + 0.025 * torch.randn(x_.shape, device=x_.device)
                    D_real = self.D(x_noisy)
                    D_real_loss = self.BCE_loss(D_real, self.y_real_)

                    # Detach: the discriminator step must not back-propagate
                    # into the generator.
                    G_ = self.G(z_).detach()
                    D_fake = self.D(G_)
                    D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                    D_loss = D_real_loss + D_fake_loss
                    D_loss.backward()
                    self.D_optimizer.step()

                # update G network
                for _inner in range(Ninner):
                    self.G_optimizer.zero_grad()
                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    # The generator is rewarded when D labels its samples real.
                    G_loss = self.BCE_loss(D_fake, self.y_real_)

                    G_loss.backward()
                    self.G_optimizer.step()

                self.train_hist['D_loss'].append(D_loss.item())
                self.train_hist['G_loss'].append(G_loss.item())

                # Print iterations and losses
                if ((step + 1) % 50) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (step + 1), n_batches, D_loss.item(), G_loss.item()))

            # Visualize results
            with torch.no_grad():
                visualize_results(self)

        print("Training finished!")

    def train_ae(self):
        """Train ``G(D(x))`` as a reconstruction of ``x`` with an MSE loss.

        Seeds the global RNG with 42, then for ``self.epoch`` epochs feeds
        every batch through the discriminator followed by the generator and
        minimises the mean squared error against the input. The last batch
        loss of each epoch is printed.

        NOTE(review): only ``self.D_optimizer`` is stepped, so the generator
        weights are not updated here even though the generator is part of the
        forward pass -- confirm this is intended.

        Returns:
            This ``GAN`` instance (the original returned the undefined name
            ``model``).
        """
        torch.manual_seed(42)
        criterion = nn.MSELoss() # mean square error loss

        for epoch in range(self.epoch):
            # Reported as NaN if the loader yields no batches.
            last_loss = float('nan')
            for img, _ in self.data_loader:
                img = img.to(self.device, dtype=torch.float32)
                recon = self.G(self.D(img))
                loss = criterion(recon, img)
                self.D_optimizer.zero_grad()
                loss.backward()
                self.D_optimizer.step()
                last_loss = float(loss)

            print('Epoch:{}, Loss:{:.4f}'.format(epoch+1, last_loss))

        return self
