In [None]:
import torch

from torch import nn

from torchvision import transforms

from torchvision.utils import make_grid

import torchvision.datasets as datasets

from torch.utils.data import DataLoader

from torchsummary import summary

import os

import numpy as np

import matplotlib.pyplot as plt

import random

import urllib.request

import tarfile

import warnings

# Install a blanket "ignore" filter: every warning category is suppressed from
# here on, including deprecation notices raised by the libraries imported above.
warnings.simplefilter("ignore")

In [None]:
# Uncomment to install the torchsummary package imported in the first cell:
#%pip install torchsummary

In [None]:
# Disabled dataset download. Everything between the triple quotes is a plain
# string literal, so nothing is fetched or extracted when this cell runs.
# Presumably the archive provides the image folder read later by ImageFolder
# -- TODO confirm the archive layout before re-enabling.
# NOTE(review): if re-enabled, tar.extractall() unpacks a remotely fetched
# archive without validating member paths; consider checking it first.
'''
imgs_tar_url = 'https://github.com/kbmurali/hindi_hw_deep_gan/blob/main/hindi_alps.tar.gz?raw=true'

tar_file = 'hindi_alps.tar.gz'

urllib.request.urlretrieve( imgs_tar_url, tar_file )

tar = tarfile.open( tar_file )

# Extract all files to the current directory
tar.extractall()

# Close the tar file
tar.close()
'''

# Prints an empty line; presumably here so the string above is not the cell's
# last expression and therefore is not echoed as the cell output.
print()

In [None]:
# Single seed shared by every random number generator the notebook touches
seed = 42

# Seed torch, NumPy and the stdlib generator (the calls are independent)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this value is
# only picked up by processes launched after this point, not by this kernel.
os.environ["PYTHONHASHSEED"] = str(seed)

In [None]:
# Optional CUDA memory diagnostics, left disabled:
#torch.cuda.empty_cache()

#torch.cuda.memory_summary(device=None, abbreviated=True)

# Reports whether a CUDA device can be used; being the last expression of the
# cell, the boolean is shown as the cell's output.
torch.cuda.is_available()

In [None]:
class HindiHWAlphabetGenerator( nn.Module ):
    '''
    Generator built from a stack of transposed convolutions.

    A noise vector is viewed as an (input_channels, 1, 1) feature map and
    pushed through five ConvTranspose2d -> BatchNorm2d -> ReLU blocks followed
    by a ConvTranspose2d -> Tanh output block. With the kernel sizes and
    strides used here (and no padding) the spatial size grows
    1 -> 3 -> 8 -> 11 -> 24 -> 29 -> 32, so each output image is 32 x 32.

    Parameters:
        input_channels: length of each noise vector
        final_image_channels: number of channels of the generated image
        conv_filter_factor: base channel count; the blocks use 4x, 8x, 4x, 2x
                            and 1x this value
    '''
    def __init__(self, input_channels=10, final_image_channels=1, conv_filter_factor=64 ):
        super().__init__()

        self.input_channels = input_channels

        # One row per upsampling block: (channel multiplier, kernel_size, stride)
        upsampling_plan = (
            ( 4, 3, 2 ),
            ( 8, 4, 2 ),
            ( 4, 4, 1 ),
            ( 2, 4, 2 ),
            ( 1, 6, 1 ),
        )

        self.model = nn.Sequential()

        in_channels = input_channels

        for position, ( multiplier, kernel_size, stride ) in enumerate( upsampling_plan, start=1 ):
            out_channels = conv_filter_factor * multiplier

            block = self._convolution_layer( in_channels,
                                             out_channels,
                                             kernel_size=kernel_size,
                                             stride=stride )

            self.model.add_module( 'conv_' + str( position ), block )

            in_channels = out_channels

        # Final block maps the last feature map to image channels in [-1, 1]
        self.model.add_module( 'output', self._output_layer( in_channels,
                                                             final_image_channels,
                                                             kernel_size=4,
                                                             stride=1 ) )

    def forward( self, noise_tensors ):
        '''
        Generate one image per noise vector.
        Parameters:
            noise_tensors: a noise tensor with dimensions (n_samples, input_channels)
        Returns:
            the output of the layer stack for the reshaped noise
        '''
        return self.model( self.channelize_noise_inputs( noise_tensors ) )

    def channelize_noise_inputs( self, noise_tensors ):
        '''
        View the noise as (n_samples, input_channels, 1, 1) so that it can be
        fed to the first transposed convolution.
        '''
        num_samples = len( noise_tensors )

        return noise_tensors.view( num_samples, self.input_channels, 1, 1 )

    def _convolution_layer( self, input_channels, output_channels, kernel_size=4, stride=1 ):
        '''
        Build one upsampling block: ConvTranspose2d, BatchNorm2d, in-place ReLU.
        '''
        upsample = nn.ConvTranspose2d( input_channels,
                                       output_channels,
                                       kernel_size=kernel_size,
                                       stride=stride )

        normalize = nn.BatchNorm2d( output_channels )

        activate = nn.ReLU( inplace=True )

        return nn.Sequential( upsample, normalize, activate )

    def _output_layer( self, input_channels, output_channels, kernel_size=4, stride=1 ):
        '''
        Build the image-producing block: ConvTranspose2d followed by Tanh.
        '''
        to_image = nn.ConvTranspose2d( input_channels,
                                       output_channels,
                                       kernel_size=kernel_size,
                                       stride=stride )

        return nn.Sequential( to_image, nn.Tanh() )

In [None]:
class HindiHWAlphabetDiscriminator( nn.Module ):
    '''
    Fully connected discriminator operating on flattened images.

    The hidden layers narrow by a factor of two each:
    hidden_dim * 2^(num_hidden-1), ..., hidden_dim * 2, hidden_dim.
    A final Linear layer produces one score per image. No sigmoid is applied,
    so the score is a raw logit.

    Parameters:
        image_dim: number of values in one flattened image
        hidden_dim: width of the last (narrowest) hidden layer
        num_hidden: number of hidden layers
    '''
    def __init__(self, image_dim=1024, hidden_dim=128, num_hidden=3 ):
        super().__init__()

        self.model = nn.Sequential()

        # Widest layer first, e.g. num_hidden=3 -> [4*h, 2*h, h]
        layer_widths = [ hidden_dim * (2 ** exponent)
                         for exponent in reversed( range( num_hidden ) ) ]

        in_features = image_dim

        for layer_number, out_features in enumerate( layer_widths, start=1 ):
            self.model.add_module( 'hidden_' + str( layer_number ),
                                   self._hidden_layer( in_features, out_features ) )

            in_features = out_features

        ##Output layer: a single score per input image
        self.model.add_module( 'output', nn.Linear( in_features, 1 ) )

    def forward( self, image_tensors ):
        '''
        Score a batch of images.
        Parameters:
            image_tensors: a tensor whose first dimension indexes the samples;
                           all remaining dimensions are flattened together
        Returns:
            a tensor with dimensions (n_samples, 1) holding one raw score per image
        '''
        num_samples = len( image_tensors )

        flattened_images = image_tensors.view( num_samples, -1 )

        return self.model( flattened_images )

    def _hidden_layer( self, input_dim, output_dim ):
        '''
        Parameters:
            input_dim: size of the vector coming from the previous layer
            output_dim: size of the vector produced by this layer
        Returns:
            a nn.Sequential holding a Linear transformation followed by a
            LeakyReLU activation with a negative slope of 0.2
        '''
        projection = nn.Linear( input_dim, output_dim )

        activation = nn.LeakyReLU( 0.2 )

        return nn.Sequential( projection, activation )

In [None]:
class GANTrainer:
    '''
    Alternating trainer for a generator / discriminator pair.

    For every batch of real images the discriminator is updated first (fake
    images labelled 0, real images labelled 1) and then the generator is
    updated so that its images are scored as real (label 1).

    Parameters:
        input_noise_dim: length of each noise vector given to the generator
        generator: module mapping noise vectors to images
        discriminator: module mapping images to one score per image
        gen_optimizer: optimizer holding the generator parameters
        disc_optimizer: optimizer holding the discriminator parameters
        noise_inputs_generator_func: callable (num_samples, input_noise_dim, device)
                                     returning a batch of noise vectors
        num_epochs: number of passes over real_images_loader
        real_images_loader: iterable yielding (images, labels) batches; the
                            labels are ignored
        criterion: loss comparing discriminator scores with the 0/1 targets
        device: device the real images and the noise are placed on
        display_step: number of batches between progress reports
    '''
    def __init__( self,
                  input_noise_dim,
                  generator,
                  discriminator,
                  gen_optimizer,
                  disc_optimizer,
                  noise_inputs_generator_func,
                  num_epochs,
                  real_images_loader,
                  criterion = nn.BCEWithLogitsLoss(),
                  device='cpu',
                  display_step=250 ):
        self.input_noise_dim = input_noise_dim
        self.generator = generator
        self.discriminator = discriminator
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.noise_inputs_generator_func = noise_inputs_generator_func
        self.num_epochs = num_epochs
        self.real_images_loader = real_images_loader
        self.criterion = criterion
        self.device = device
        self.display_step = display_step
    
    @staticmethod
    def plt_imgs( images_tensor, num_imgs=25, size=(1, 32, 32), nrow=5 ):
        '''
        Show the first num_imgs images as a grid with nrow images per row.
        Parameters:
            images_tensor: batch of images; values are mapped with (x + 1) / 2,
                           i.e. from [-1, 1] to [0, 1], before plotting
            num_imgs: maximum number of images to show
            size: currently unused; kept so existing callers keep working
            nrow: number of images per grid row
        '''
        images_tensor = (images_tensor + 1) / 2
        unflattened_imgs = images_tensor.detach().cpu()
        img_quilt = make_grid( unflattened_imgs[:num_imgs], nrow=nrow )
        # make_grid returns channels-first; imshow expects channels-last
        plt.imshow( img_quilt.permute(1, 2, 0).squeeze() )
        plt.show()
    
    def train( self ):
        '''
        Run num_epochs passes over real_images_loader. Every display_step
        batches the mean generator / discriminator loss over that window is
        printed and a grid of freshly generated images is plotted.
        '''
        current_step = 1

        # Plain Python floats: accumulating the loss tensors themselves would
        # keep their autograd history referenced until the next report.
        running_discriminator_loss = 0.0
        running_generator_loss = 0.0

        for epoch in range( self.num_epochs ):
            for real_images, _ in self.real_images_loader:
                cur_batch_size = len( real_images )
                real_images = real_images.to( self.device )

                discriminator_loss, generator_loss = self._train( real_images )

                running_discriminator_loss += discriminator_loss.item()
                running_generator_loss += generator_loss.item()

                if current_step % self.display_step == 0:
                    mean_discriminator_loss = running_discriminator_loss / self.display_step
                    mean_generator_loss = running_generator_loss / self.display_step

                    print(f"Epoch {epoch}: Step {current_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")

                    fake_gen_input_vectors = self.noise_inputs_generator_func( cur_batch_size,
                                                                               self.input_noise_dim,
                                                                               self.device )

                    # Preview only: no gradients are needed for plotting
                    with torch.no_grad():
                        fake_images = self.generator( fake_gen_input_vectors )
                    
                    GANTrainer.plt_imgs( fake_images )
                    #GANTrainer.plt_imgs( real_images )

                    running_discriminator_loss = 0.0
                    running_generator_loss = 0.0

                current_step += 1
    
    def __str__(self):
        '''
        Multi-line description containing the generator and discriminator reprs.
        '''
        ret_str = str( self.__class__.__name__ ) + '(\n' + \
                  '(generator): ' + self.generator.__str__() + \
                  '\n(discriminator): ' + self.discriminator.__str__() + '\n)'
        
        return ret_str
    
    def _discriminator_loss( self, real_images ):
        '''
        Discriminator loss for one batch, computed with self.criterion
        against the targets fake = 0 and real = 1.
        Parameters:
            real_images: a mini batch of real images
        Returns:
            discriminator_loss: the average of the loss on generated images
                                and the loss on real images
        '''
        input_noise_vectors = self.noise_inputs_generator_func( len( real_images ),
                                                                self.input_noise_dim,
                                                                self.device )

        # detach(): the discriminator update must not back-propagate into the generator
        fake_images = self.generator( input_noise_vectors ).detach()

        y_pred_fake = self.discriminator( fake_images )
        y_expected_fake = torch.zeros_like( y_pred_fake )
        fake_loss = self.criterion( y_pred_fake, y_expected_fake )

        y_pred_real = self.discriminator( real_images )
        y_expected_real = torch.ones_like( y_pred_real )
        real_loss = self.criterion( y_pred_real, y_expected_real )

        discriminator_loss = (fake_loss + real_loss) / 2

        return discriminator_loss
    
    def _generator_loss( self, num_samples ):
        '''
        Generator loss for one batch: self.criterion applied to the
        discriminator scores of generated images with target 1, i.e. the
        generator is rewarded when its images are scored as real.
        Parameters:
            num_samples: the number of images the generator should produce
        Returns:
            generator_loss: the loss value for the current batch
        '''
        input_noise_vectors = self.noise_inputs_generator_func( num_samples, self.input_noise_dim, self.device )

        fake_images = self.generator( input_noise_vectors )

        y_pred = self.discriminator( fake_images )
        y_expected = torch.ones_like( y_pred )

        generator_loss = self.criterion( y_pred, y_expected )

        return generator_loss
    
    def _train( self, real_images ):
        '''
        Perform one discriminator update followed by one generator update.
        Parameters:
            real_images: a mini batch of real images already on self.device
        Returns:
            (discriminator_loss, generator_loss): the two batch losses as
            tensors detached from the autograd graph
        '''
        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        self.disc_optimizer.zero_grad()

        discriminator_loss = self._discriminator_loss( real_images )

        # Each loss comes from its own forward pass, so neither graph is
        # reused and retain_graph is not needed.
        discriminator_loss.backward()

        # Update discriminator optimizer
        self.disc_optimizer.step()
        
        ### Update generator ###
        self.gen_optimizer.zero_grad()
        
        num_samples = len( real_images )
        
        generator_loss = self._generator_loss( num_samples )
        
        generator_loss.backward()
        
        self.gen_optimizer.step()
        
        return discriminator_loss.detach(), generator_loss.detach()

In [None]:
def get_generator_inputs( num_samples, input_noise_dim, device='cpu'):
    '''
    Sample a batch of noise vectors for the generator.
    Every entry is drawn from the standard normal distribution.
    Parameters:
        num_samples: a scalar for number of noise vectors to generate
        input_noise_dim: a scalar representing the dimension of each noise vector
        device: the device on which the tensor is created
    Returns:
        a tensor of shape ( num_samples, input_noise_dim )
    '''
    noise_shape = ( num_samples, input_noise_dim )

    return torch.randn( noise_shape, device=device )

In [None]:
# Train on the GPU when one is available, otherwise on the CPU
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )

batch_size = 25

# Every image becomes a single-channel 32 x 32 tensor scaled from [0, 1] to
# [-1, 1], the range of the generator's Tanh output.
img_transformer = transforms.Compose([
                        transforms.Resize(32),
                        transforms.CenterCrop(32),
                        transforms.Grayscale(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,))
                    ])


dataset = datasets.ImageFolder( 'alps/imgs', transform=img_transformer)

real_images_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

criterion = nn.BCEWithLogitsLoss()
num_epochs = 2
display_step = 250
input_noise_dim = 72

# The generator views each noise vector as (input_channels, 1, 1), so its
# input_channels must equal the noise length; derive it from input_noise_dim
# instead of repeating the literal.
gen = HindiHWAlphabetGenerator( input_channels=input_noise_dim, final_image_channels=1, conv_filter_factor=72 ).to( device )

# image_dim = 1 channel * 32 * 32 pixels of the flattened image
disc = HindiHWAlphabetDiscriminator( image_dim=1024, hidden_dim=320, num_hidden=4 ).to( device )

# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002

# These parameters control the optimizer's momentum
# https://distill.pub/2017/momentum/
beta_1 = 0.5 
beta_2 = 0.999

gen_optimizer = torch.optim.Adam( gen.parameters(), lr=lr,  betas=(beta_1, beta_2) )
disc_optimizer = torch.optim.Adam( disc.parameters(), lr=lr,  betas=(beta_1, beta_2) )

def weights_init(m):
    '''
    DCGAN-style initialiser, meant to be passed to nn.Module.apply.
    Parameters:
        m: a single module; modules of any other type are left untouched
    Effects:
        Conv2d / ConvTranspose2d: weight drawn from N(0.0, 0.02)
        BatchNorm2d: weight drawn from N(1.0, 0.02), bias set to 0
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # The scale is centred on 1.0 so a freshly initialised BatchNorm passes
        # the normalised activations through instead of shrinking them to ~0.
        # weight / bias are None when the layer was built with affine=False.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

# Module.apply visits every sub-module and returns the module itself.
# weights_init only acts on Conv2d / ConvTranspose2d / BatchNorm2d layers, so
# the Linear-only discriminator defined in this notebook is left unchanged.
gen, disc = gen.apply(weights_init), disc.apply(weights_init)

In [None]:
# Wire the models, optimizers, noise sampler and data loader into the trainer
trainer = GANTrainer( input_noise_dim = input_noise_dim,
                      generator = gen,
                      discriminator = disc,
                      gen_optimizer = gen_optimizer,
                      disc_optimizer = disc_optimizer,
                      noise_inputs_generator_func = get_generator_inputs,
                      num_epochs = num_epochs,
                      real_images_loader = real_images_loader,
                      criterion = criterion,
                      device = device,
                      display_step = display_step )

In [None]:
# Run the adversarial training loop for num_epochs passes over the data; every
# display_step batches it prints the mean losses and plots generated samples.
trainer.train()