In [1]:
import numpy as np
import matplotlib.pyplot as plt
# import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from plotly.offline import init_notebook_mode, iplot
import plotly.io as pio
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import datetime
import os
import sys
from tqdm.notebook import tqdm

In [2]:
# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1]
# ((x - 0.5) / 0.5), the same range as the generator's final Tanh.
transform = transforms.Compose(
    transforms=[
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [3]:
# download=True fetches MNIST under `root` when it is missing (and is a no-op when
# it is already there), so the notebook also runs on a fresh checkout.
train_dataset = torchvision.datasets.MNIST(root=".", train=True, transform=transform,
                                           download=True)

In [4]:
batch_size = 128
# Mini-batches of (image, label) pairs, reshuffled at every epoch.
train_iter = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [5]:
def get_discriminator_block(input_dim, output_dim):
    """
    Description : Build one hidden block of the discriminator:
    Linear followed by LeakyReLU (negative slope 0.2).

    Parameters :
    @param input_dim -- a python integer, number of input features
    @param output_dim -- a python integer, number of output features

    Return :
    @ret neural_block -- an nn.Sequential holding [Linear, LeakyReLU]
    """
    layers = [
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(negative_slope=0.2),
    ]
    return nn.Sequential(*layers)

In [6]:
def get_generator_block(input_dim, output_dim):
    """
    Description : Build one hidden block of the generator:
    Linear -> BatchNorm1d(momentum=0.7) -> LeakyReLU (negative slope 0.2).

    Parameters:
    @param input_dim -- a python integer, number of input features
    @param output_dim -- a python integer, number of output features

    Return :
    @ret neural_block -- an nn.Sequential holding the three layers above
    """
    # NOTE(review): PyTorch updates BatchNorm running stats as
    # (1 - momentum) * running + momentum * batch, the opposite convention from
    # Keras. If 0.7 was carried over from a Keras model, the PyTorch equivalent
    # would be 0.3 -- confirm the intended value.
    linear = nn.Linear(input_dim, output_dim)
    norm = nn.BatchNorm1d(output_dim, momentum=0.7)
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(linear, norm, activation)

In [7]:
class Discriminator(nn.Module):
    """
    MLP discriminator: input -> two LeakyReLU hidden blocks (widths
    hidden_dim*4 then hidden_dim*2) -> Linear(output_dim).

    There is no final sigmoid, so the outputs are raw logits; the notebook
    pairs this model with nn.BCEWithLogitsLoss.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        wide, narrow = hidden_dim * 4, hidden_dim * 2
        self.disc = nn.Sequential(
            get_discriminator_block(input_dim, wide),
            get_discriminator_block(wide, narrow),
            nn.Linear(narrow, output_dim),
        )

    def forward(self, X):
        # X is expected to be flattened to (batch, input_dim); the training loop
        # reshapes images before calling D.
        return self.disc(X)

In [8]:
class Generator(nn.Module):
    """
    MLP generator: latent noise -> two BatchNorm/LeakyReLU hidden blocks
    (widths hidden_dim*2 then hidden_dim*4) -> Linear(output_dim) -> Tanh.

    The final Tanh keeps every output value in [-1, 1].
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        narrow, wide = hidden_dim * 2, hidden_dim * 4
        self.gen = nn.Sequential(
            get_generator_block(input_dim, narrow),
            get_generator_block(narrow, wide),
            nn.Linear(wide, output_dim),
            nn.Tanh(),
        )

    def forward(self, X):
        # X: noise of shape (batch, input_dim), e.g. from generate_noise.
        return self.gen(X)
        

In [9]:
def generate_noise(batch_size, latent_dim=100, device=None):
    """
    Description : Function to generate noise from Gaussian Distribution

    Parameters :
    @param batch_size -- a python integer representing the batch size
    @param latent_dim -- a python integer representing the dimension of the noise
    @param device -- optional torch device (or string such as 'cuda:0') to sample
                     on directly; None (default) uses torch's default device,
                     exactly as before. Note that CPU and CUDA have separate
                     random generators, so the drawn values differ by device.

    Return :
    A (batch_size, latent_dim) tensor of samples from the standard normal N(0, 1)
    """
    # Sampling on the target device avoids a CPU allocation followed by a copy.
    return torch.randn(batch_size, latent_dim, device=device)

In [10]:
latent_dim = 100
# Prefer the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# The discriminator emits raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()

In [11]:
# 784 = 28 * 28: images are handled as flattened MNIST digits.
# G is built before D so the weight-initialisation RNG order is unchanged.
G = Generator(input_dim=latent_dim, output_dim=784).to(device)
D = Discriminator(input_dim=784, output_dim=1).to(device)

# Both players use Adam with the same learning rate of 2e-4.
g_optim = torch.optim.Adam(G.parameters(), lr=2e-4)
d_optim = torch.optim.Adam(D.parameters(), lr=2e-4)

In [12]:
def discriminator_loss(G, D, criterion, real,
                       batch_size, latent_dim, device):
    """
    Description : Function to calculate the loss for the discriminator

    D is scored on a batch of real images against target 1 and on a freshly
    generated batch against target 0; the two losses are averaged.

    Parameters:
    @param G -- The Generator Network
    @param D -- The Discriminator Network
    @param criterion -- The loss function (applied to D's outputs and the 0/1 targets)
    @param real -- a torch tensor of real images, already on `device` and flattened
                   to D's input size (the training loop passes shape (batch_size, 784))
    @param batch_size -- python integer; number of real images in `real`, also used
                         as the number of fake images to generate
    @param latent_dim -- python integer representing the latent dimension
    @param device -- the torch device (e.g. 'cpu' or 'cuda:0')

    Return :
    @ret disc_loss -- the Discriminator Loss tensor
    """
    # Targets: 0 for generated images, 1 for real ones. Allocate them on
    # `device` directly instead of creating on the CPU and copying.
    zeros_ = torch.zeros(batch_size, 1, device=device)
    ones_ = torch.ones(batch_size, 1, device=device)

    real_loss = criterion(D(real), ones_)

    noise = generate_noise(batch_size, latent_dim).to(device)
    # detach(): the discriminator update must not backpropagate into G.
    fake_img = G(noise).detach()

    fake_loss = criterion(D(fake_img), zeros_)

    # Average the two terms so the loss is on the scale of a single BCE term.
    disc_loss = (fake_loss + real_loss) * 0.5

    return disc_loss

In [13]:
def generator_loss(G, D, criterion, batch_size, latent_dim, device):
    """
    Description : Function to calculate the generator loss

    A fresh batch of generated images is scored by D against the "real"
    target 1, so minimising this loss pushes G towards images D accepts as
    real. With a BCE criterion this is the usual non-saturating generator loss.

    Parameters:
    @param G -- The Generator Network
    @param D -- The Discriminator Network
    @param criterion -- the Loss function
    @param batch_size -- the number of images to generate
    @param latent_dim -- the latent dimension
    @param device -- the torch device (e.g. 'cpu' or 'cuda:0')

    Return:
    @ret gen_loss -- the Generator Loss tensor (still attached to G's graph)
    """
    noise = generate_noise(batch_size, latent_dim).to(device)
    # Allocate the targets on `device` directly instead of CPU + copy.
    ones_ = torch.ones(batch_size, 1, device=device)

    # No detach here: gradients must flow back through D into G.
    fake_image = G(noise)

    gen_loss = criterion(D(fake_image), ones_)

    return gen_loss

In [14]:
# makedirs also creates the missing "./IMAGES" parent (os.mkdir raises
# FileNotFoundError on a fresh checkout); exist_ok keeps the cell re-runnable.
os.makedirs("./IMAGES/SimpleGANs", exist_ok=True)

In [15]:
def scale_image(img):
    """Map generator output from the Tanh range [-1, 1] back to [0, 1] for saving."""
    shifted = img + 1
    return shifted / 2

In [16]:
def train(G, D, criterion, g_optim, d_optim, data_iter, latent_dim, epochs=200, save=True,
          device=None):
    """
    Description : To train the GANs model

    Each batch performs one discriminator update followed by two generator
    updates. Mean losses are recorded per epoch and, when `save` is True, a
    grid of generated samples is written to ./IMAGES/SimpleGANs/gan_<epoch>.png.

    Parameters:
    @param G -- The Generator Network
    @param D -- The Discriminator Network
    @param criterion -- the loss function forwarded to discriminator_loss / generator_loss
    @param g_optim -- the Generator Optimizer
    @param d_optim -- the Discriminator Optimizer
    @param data_iter -- iterable of (images, labels) batches to train on
    @param latent_dim -- the dimension for the noise
    @param epochs -- number of epochs to run (default=200)
    @param save -- whether to save the images or not (default=True)
    @param device -- device for the real batches and the noise; None (default)
                     uses the device that G's parameters live on

    Return :
    @ret g_losses -- list of per-epoch mean Generator losses
    @ret d_losses -- list of per-epoch mean Discriminator losses
    """
    # Derive the device from the model instead of reading a notebook global.
    if device is None:
        device = next(G.parameters()).device

    g_losses = []
    d_losses = []

    for epoch in range(epochs):

        d_loss = []
        g_loss = []
        batch_size = 0
        for inputs, _ in tqdm(data_iter):

            batch_size = inputs.size(0)
            # Flatten each image into a vector for the MLP networks.
            # reshape replaces the deprecated non-inplace Tensor.resize.
            inputs = inputs.reshape(batch_size, -1).to(device)

            ##############################################
            ########### TRAIN DISCRIMINATOR ############
            #############################################
            d_optim.zero_grad()

            dLoss = discriminator_loss(G, D, criterion, inputs, batch_size,
                                       latent_dim, device)

            dLoss.backward()
            d_optim.step()

            d_loss.append(dLoss.item())

            ##############################################
            ########### TRAIN GENERATOR ############
            #############################################
            # Two generator steps per discriminator step, each on fresh noise.
            gLoss = []
            for _ in range(2):
                g_optim.zero_grad()
                _gLoss = generator_loss(G, D, criterion, batch_size, latent_dim, device)

                _gLoss.backward()
                g_optim.step()
                gLoss.append(_gLoss.item())
            g_loss.append(np.mean(gLoss))

        d_loss = np.mean(d_loss)
        g_loss = np.mean(g_loss)

        g_losses.append(g_loss)
        d_losses.append(d_loss)

        print(f"Epoch:{epoch+1}/{epochs} || Disc Loss: {d_loss} || Gen Loss : {g_loss}")

        if save:
            # Sample as many images as the last batch held. Sampling needs no
            # autograd graph. G stays in train mode, so BatchNorm normalises with
            # this sample batch's statistics, as in the original loop.
            with torch.no_grad():
                noise = generate_noise(batch_size, latent_dim)
                fake_img = G(noise.to(device))
                fake_img = fake_img.reshape(-1, 1, 28, 28)
            save_image(scale_image(fake_img), f"./IMAGES/SimpleGANs/gan_{epoch}.png")

    return g_losses, d_losses
            
            
            
            
            
            
            
            

In [17]:
# Runs the default 200 epochs, saving a sample grid after each one.
g_losses, d_losses = train(
    G, D, criterion, g_optim, d_optim,
    data_iter=train_iter,
    latent_dim=latent_dim,
)

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


non-inplace resize is deprecated




Epoch:1/200 || Disc Loss: 0.25948005675602315 || Gen Loss : 3.5365867766934924


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:2/200 || Disc Loss: 0.22322057572000825 || Gen Loss : 4.065337242602285


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:3/200 || Disc Loss: 0.3088390910104394 || Gen Loss : 4.110610055770955


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:4/200 || Disc Loss: 0.2643479453697642 || Gen Loss : 3.9793219167286398


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:5/200 || Disc Loss: 0.27618916211987354 || Gen Loss : 4.5745916394536685


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:6/200 || Disc Loss: 0.2801161622251275 || Gen Loss : 3.85561735124222


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:7/200 || Disc Loss: 0.36805373204669467 || Gen Loss : 3.7193260546177944


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:8/200 || Disc Loss: 0.3421338169114676 || Gen Loss : 3.1592105815151355


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:9/200 || Disc Loss: 0.33762428485381324 || Gen Loss : 3.06870951174673


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:10/200 || Disc Loss: 0.3492627377067802 || Gen Loss : 2.8946260753979307


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:11/200 || Disc Loss: 0.3766948907677807 || Gen Loss : 3.033590363159871


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:12/200 || Disc Loss: 0.4136512265213009 || Gen Loss : 2.5481435965373316


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:13/200 || Disc Loss: 0.4360938586278765 || Gen Loss : 2.2122439175272293


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:14/200 || Disc Loss: 0.45245728741830854 || Gen Loss : 2.065919626877506


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:15/200 || Disc Loss: 0.46343463424172227 || Gen Loss : 2.397614436617284


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:16/200 || Disc Loss: 0.4547085899915268 || Gen Loss : 2.159361891273751


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:17/200 || Disc Loss: 0.4612529446194167 || Gen Loss : 1.9249720794559797


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:18/200 || Disc Loss: 0.4575925083684006 || Gen Loss : 1.874892636911193


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:19/200 || Disc Loss: 0.46506978936795235 || Gen Loss : 2.1786709712512455


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:20/200 || Disc Loss: 0.4530769294894326 || Gen Loss : 1.9610130883483237


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:21/200 || Disc Loss: 0.501845785740342 || Gen Loss : 1.888008281366149


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:22/200 || Disc Loss: 0.5064354374972996 || Gen Loss : 1.6127155793628205


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:23/200 || Disc Loss: 0.5165283432139008 || Gen Loss : 1.5868683259116052


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:24/200 || Disc Loss: 0.5336710284513705 || Gen Loss : 1.6665287951289465


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:25/200 || Disc Loss: 0.5304457099834231 || Gen Loss : 1.7094916782653662


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:26/200 || Disc Loss: 0.5175309225694457 || Gen Loss : 1.8351826489861331


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:27/200 || Disc Loss: 0.5549970129405511 || Gen Loss : 1.6186070301766589


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:28/200 || Disc Loss: 0.5355281047602453 || Gen Loss : 1.5082367049859786


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:29/200 || Disc Loss: 0.5463538474238503 || Gen Loss : 1.400563839719748


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:30/200 || Disc Loss: 0.5885228631593017 || Gen Loss : 1.3208818577372952


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:31/200 || Disc Loss: 0.599081797831094 || Gen Loss : 1.4758109288937502


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:32/200 || Disc Loss: 0.5503141216631892 || Gen Loss : 1.3032855277757909


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:33/200 || Disc Loss: 0.5889696615464144 || Gen Loss : 1.3248316402882656


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:34/200 || Disc Loss: 0.5908077671837959 || Gen Loss : 1.1469149184760763


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:35/200 || Disc Loss: 0.6165459011790595 || Gen Loss : 1.1903375031343146


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:36/200 || Disc Loss: 0.6394494012601848 || Gen Loss : 1.3989211948695721


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:37/200 || Disc Loss: 0.5787467619122219 || Gen Loss : 1.4200630412300004


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:38/200 || Disc Loss: 0.5697527184669398 || Gen Loss : 1.271472767789735


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:39/200 || Disc Loss: 0.5503899545938984 || Gen Loss : 1.268044768238881


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:40/200 || Disc Loss: 0.5910920202732086 || Gen Loss : 1.266985959462774


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:41/200 || Disc Loss: 0.5825782356612972 || Gen Loss : 1.2110839969059553


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:42/200 || Disc Loss: 0.6089207670454786 || Gen Loss : 1.1482176955447776


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:43/200 || Disc Loss: 0.6178619320204517 || Gen Loss : 1.3039857354372548


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:44/200 || Disc Loss: 0.585934220092383 || Gen Loss : 1.199356532490838


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:45/200 || Disc Loss: 0.610100908256543 || Gen Loss : 1.1928590137694182


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:46/200 || Disc Loss: 0.5981348840666733 || Gen Loss : 1.2083358882205573


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:47/200 || Disc Loss: 0.6014075739297278 || Gen Loss : 1.174898240357828


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:48/200 || Disc Loss: 0.5833047295430067 || Gen Loss : 1.181887387721015


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:49/200 || Disc Loss: 0.5956164498064818 || Gen Loss : 1.3024521690568944


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:50/200 || Disc Loss: 0.5884387681860406 || Gen Loss : 1.2530758201059249


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:51/200 || Disc Loss: 0.578362832636213 || Gen Loss : 1.2012755507980581


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:52/200 || Disc Loss: 0.6083240936051554 || Gen Loss : 1.1646497099638493


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:53/200 || Disc Loss: 0.5928539560674858 || Gen Loss : 1.1368864744838112


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:54/200 || Disc Loss: 0.625286822634211 || Gen Loss : 1.0763349613146995


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:55/200 || Disc Loss: 0.6285134313076035 || Gen Loss : 1.0213804030850497


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:56/200 || Disc Loss: 0.6293547349189644 || Gen Loss : 1.0424914002926873


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:57/200 || Disc Loss: 0.6250638800389223 || Gen Loss : 1.0156662633805387


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:58/200 || Disc Loss: 0.6253723191427015 || Gen Loss : 1.083938527717265


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:59/200 || Disc Loss: 0.6100040339966064 || Gen Loss : 1.0924964333012668


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:60/200 || Disc Loss: 0.6094193813134866 || Gen Loss : 1.094285752092089


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:61/200 || Disc Loss: 0.6174144878316281 || Gen Loss : 1.083309571816723


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:62/200 || Disc Loss: 0.6122817046987985 || Gen Loss : 1.0608623173318183


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:63/200 || Disc Loss: 0.6073338914908835 || Gen Loss : 1.0586527449362821


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:64/200 || Disc Loss: 0.6065881751747783 || Gen Loss : 1.0928294528394873


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:65/200 || Disc Loss: 0.6074245367477189 || Gen Loss : 1.0761038338833018


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:66/200 || Disc Loss: 0.6182178837149891 || Gen Loss : 1.0146859726036535


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:67/200 || Disc Loss: 0.6178198905387667 || Gen Loss : 0.9832983395056938


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:68/200 || Disc Loss: 0.6290886004342199 || Gen Loss : 0.9874259958516306


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:69/200 || Disc Loss: 0.629076826801178 || Gen Loss : 1.0290287699399472


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:70/200 || Disc Loss: 0.61923708920794 || Gen Loss : 1.112348984299438


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:71/200 || Disc Loss: 0.6130176278065517 || Gen Loss : 1.1034394834341525


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:72/200 || Disc Loss: 0.6217809499962244 || Gen Loss : 1.0558801105917135


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:73/200 || Disc Loss: 0.619221405815214 || Gen Loss : 0.9672460029882662


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:74/200 || Disc Loss: 0.6255185071593409 || Gen Loss : 0.9784116017411767


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:75/200 || Disc Loss: 0.6275393568884844 || Gen Loss : 0.9877090060126299


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:76/200 || Disc Loss: 0.6289478246845416 || Gen Loss : 1.005021088730806


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:77/200 || Disc Loss: 0.6239772311278752 || Gen Loss : 1.0925178294624094


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:78/200 || Disc Loss: 0.615694543318962 || Gen Loss : 1.0394123515594742


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:79/200 || Disc Loss: 0.6265480891982121 || Gen Loss : 0.9729347432345978


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:80/200 || Disc Loss: 0.6370764670849863 || Gen Loss : 0.9613458573309852


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:81/200 || Disc Loss: 0.6362183498167026 || Gen Loss : 0.9718550171043827


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:82/200 || Disc Loss: 0.6339493829812577 || Gen Loss : 1.0415587293059587


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:83/200 || Disc Loss: 0.6189476330397226 || Gen Loss : 1.0733242387583515


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:84/200 || Disc Loss: 0.6285532622703357 || Gen Loss : 1.012818298042456


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:85/200 || Disc Loss: 0.6357858217855507 || Gen Loss : 0.9418327476678372


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:86/200 || Disc Loss: 0.6333192816929523 || Gen Loss : 0.9308608641375357


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:87/200 || Disc Loss: 0.6412879736947098 || Gen Loss : 0.9284990108979028


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:88/200 || Disc Loss: 0.6213657758129176 || Gen Loss : 0.9861058548315248


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:89/200 || Disc Loss: 0.6253369791167123 || Gen Loss : 1.0680003295829301


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:90/200 || Disc Loss: 0.6323253819937391 || Gen Loss : 0.9865535467164095


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:91/200 || Disc Loss: 0.6353958406682207 || Gen Loss : 0.9437905331410325


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:92/200 || Disc Loss: 0.6441489379289054 || Gen Loss : 0.909084147481776


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:93/200 || Disc Loss: 0.6248110807272417 || Gen Loss : 0.9400016471012823


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:94/200 || Disc Loss: 0.6431521238294492 || Gen Loss : 0.9334568129673696


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:95/200 || Disc Loss: 0.6378083159166105 || Gen Loss : 0.9423186914371783


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:96/200 || Disc Loss: 0.6509470233022531 || Gen Loss : 1.0266678686589321


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:97/200 || Disc Loss: 0.6264301655389098 || Gen Loss : 0.9905919487288257


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:98/200 || Disc Loss: 0.6361490434675074 || Gen Loss : 0.9168886061289163


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:99/200 || Disc Loss: 0.6395130649304339 || Gen Loss : 0.9045221238121041


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:100/200 || Disc Loss: 0.6559561466865702 || Gen Loss : 0.9078495829090126


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:101/200 || Disc Loss: 0.642890529274178 || Gen Loss : 0.9409640069836492


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:102/200 || Disc Loss: 0.6445057062960383 || Gen Loss : 0.9605889958994729


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:103/200 || Disc Loss: 0.6416848497604256 || Gen Loss : 0.9070874557439198


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:104/200 || Disc Loss: 0.6503415832133181 || Gen Loss : 0.8896549277976632


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:105/200 || Disc Loss: 0.6475361501738461 || Gen Loss : 0.8736505448691118


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:106/200 || Disc Loss: 0.644409924300749 || Gen Loss : 0.9032728567179332


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:107/200 || Disc Loss: 0.6466970758905797 || Gen Loss : 0.9201373124935988


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:108/200 || Disc Loss: 0.6458910176240559 || Gen Loss : 0.9225017254286484


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:109/200 || Disc Loss: 0.6476598395975922 || Gen Loss : 0.919373977984955


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:110/200 || Disc Loss: 0.6443984714398252 || Gen Loss : 0.9202587054863668


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:111/200 || Disc Loss: 0.6431124681857094 || Gen Loss : 0.9102796172536513


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:112/200 || Disc Loss: 0.6477170446788324 || Gen Loss : 0.8908265767448238


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:113/200 || Disc Loss: 0.6527847859905218 || Gen Loss : 0.9421266633183208


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:114/200 || Disc Loss: 0.6427087842274323 || Gen Loss : 0.9899106131814944


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:115/200 || Disc Loss: 0.645676264630706 || Gen Loss : 0.9509968737295187


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:116/200 || Disc Loss: 0.6409303817921864 || Gen Loss : 0.9220656999774071


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:117/200 || Disc Loss: 0.6537591729845319 || Gen Loss : 0.8559206201831924


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:118/200 || Disc Loss: 0.6621607235753968 || Gen Loss : 0.8422191066782612


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:119/200 || Disc Loss: 0.6518023120823191 || Gen Loss : 0.85958777098005


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:120/200 || Disc Loss: 0.6638145093470494 || Gen Loss : 0.871382007847971


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:121/200 || Disc Loss: 0.6520755682418595 || Gen Loss : 0.87370193811622


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:122/200 || Disc Loss: 0.654213923126904 || Gen Loss : 0.8904824114557522


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:123/200 || Disc Loss: 0.6547912591810165 || Gen Loss : 0.8946315088887205


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:124/200 || Disc Loss: 0.6527042382561576 || Gen Loss : 0.9049471689185609


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:125/200 || Disc Loss: 0.6583102504327607 || Gen Loss : 0.8897551779808012


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:126/200 || Disc Loss: 0.6481239425856421 || Gen Loss : 0.910632451396507


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:127/200 || Disc Loss: 0.6462410737964899 || Gen Loss : 0.9295251695459077


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:128/200 || Disc Loss: 0.6503266247350779 || Gen Loss : 0.8849503823053608


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:129/200 || Disc Loss: 0.6532371309774516 || Gen Loss : 0.8646981644986281


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:130/200 || Disc Loss: 0.6603711500350855 || Gen Loss : 0.8535297795780686


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:131/200 || Disc Loss: 0.654464724475641 || Gen Loss : 0.8867576915953459


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:132/200 || Disc Loss: 0.6576784851708646 || Gen Loss : 0.8875024380333134


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:133/200 || Disc Loss: 0.6506852818958795 || Gen Loss : 0.9113278524962061


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:134/200 || Disc Loss: 0.6538452389143677 || Gen Loss : 0.8798338939894491


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:135/200 || Disc Loss: 0.6574127480927815 || Gen Loss : 0.8541306477746984


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:136/200 || Disc Loss: 0.6591639230246229 || Gen Loss : 0.8400547571146666


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:137/200 || Disc Loss: 0.661907055357626 || Gen Loss : 0.8315784883524563


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:138/200 || Disc Loss: 0.661968173502859 || Gen Loss : 0.8385201416798492


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:139/200 || Disc Loss: 0.6540625916361046 || Gen Loss : 0.8548837810564143


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:140/200 || Disc Loss: 0.6572971209280019 || Gen Loss : 0.865273578525352


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:141/200 || Disc Loss: 0.6474513274266013 || Gen Loss : 0.8900808013324291


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:142/200 || Disc Loss: 0.6532799276207556 || Gen Loss : 0.8947275101121809


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:143/200 || Disc Loss: 0.6506068097757124 || Gen Loss : 0.9005862659991168


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:144/200 || Disc Loss: 0.6528325967951369 || Gen Loss : 0.87926749462512


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:145/200 || Disc Loss: 0.6533001849392076 || Gen Loss : 0.8606717317089089


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:146/200 || Disc Loss: 0.6574222484885502 || Gen Loss : 0.8548194485813824


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:147/200 || Disc Loss: 0.6523398313441002 || Gen Loss : 0.8646242376456637


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:148/200 || Disc Loss: 0.652002986560244 || Gen Loss : 0.8627476666146504


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:149/200 || Disc Loss: 0.6531206655349813 || Gen Loss : 0.8579670882174201


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:150/200 || Disc Loss: 0.6515815076289146 || Gen Loss : 0.86002949261462


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:151/200 || Disc Loss: 0.6545600709376305 || Gen Loss : 0.8721548146657598


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:152/200 || Disc Loss: 0.6489160637865697 || Gen Loss : 0.8760315142969078


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:153/200 || Disc Loss: 0.6568527011983176 || Gen Loss : 0.889382083088096


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:154/200 || Disc Loss: 0.6501005560096139 || Gen Loss : 0.8775047959168074


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:155/200 || Disc Loss: 0.6551526904360317 || Gen Loss : 0.8555078756199208


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:156/200 || Disc Loss: 0.6519254704019917 || Gen Loss : 0.8639814299561067


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:157/200 || Disc Loss: 0.651178237725931 || Gen Loss : 0.8630675297937414


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:158/200 || Disc Loss: 0.6598968106800559 || Gen Loss : 0.9066262187352822


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:159/200 || Disc Loss: 0.6567865537682067 || Gen Loss : 0.8880376946697357


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:160/200 || Disc Loss: 0.6527725796201336 || Gen Loss : 0.8622346020329481


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:161/200 || Disc Loss: 0.6530471911816709 || Gen Loss : 0.8571086293344559


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:162/200 || Disc Loss: 0.6497246755211592 || Gen Loss : 0.8674692560487719


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:163/200 || Disc Loss: 0.6539823663260128 || Gen Loss : 0.84429874590465


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:164/200 || Disc Loss: 0.6540381265347446 || Gen Loss : 0.858598248091842


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:165/200 || Disc Loss: 0.6543749967363598 || Gen Loss : 0.8475251835800691


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:166/200 || Disc Loss: 0.6554129836655883 || Gen Loss : 0.8706973777142669


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:167/200 || Disc Loss: 0.6524091348973419 || Gen Loss : 0.8841891799654279


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:168/200 || Disc Loss: 0.655505674480121 || Gen Loss : 0.8804770111402215


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:169/200 || Disc Loss: 0.6508912217896631 || Gen Loss : 0.8520020218546203


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:170/200 || Disc Loss: 0.6525765726052876 || Gen Loss : 0.8639882985971121


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:171/200 || Disc Loss: 0.6480821956957835 || Gen Loss : 0.8745612273973696


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:172/200 || Disc Loss: 0.6501210495861355 || Gen Loss : 0.8636534575587397


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:173/200 || Disc Loss: 0.6545008006634743 || Gen Loss : 0.8526143354774792


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:174/200 || Disc Loss: 0.6546251761125349 || Gen Loss : 0.8436771046632389


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:175/200 || Disc Loss: 0.6542411282626804 || Gen Loss : 0.86575144808938


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:176/200 || Disc Loss: 0.6483564168405431 || Gen Loss : 0.8797941314004886


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:177/200 || Disc Loss: 0.6508744191259209 || Gen Loss : 0.8586900947825995


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:178/200 || Disc Loss: 0.65149741144831 || Gen Loss : 0.8674867169053824


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:179/200 || Disc Loss: 0.6477557821060295 || Gen Loss : 0.8713098740908128


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:180/200 || Disc Loss: 0.6485202236216205 || Gen Loss : 0.8665701905166162


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:181/200 || Disc Loss: 0.6472511853236379 || Gen Loss : 0.8743118650750565


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:182/200 || Disc Loss: 0.6514857657936844 || Gen Loss : 0.8616876276825537


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:183/200 || Disc Loss: 0.6453692832989479 || Gen Loss : 0.8679358682144426


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:184/200 || Disc Loss: 0.650584371105186 || Gen Loss : 0.8617275249220924


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:185/200 || Disc Loss: 0.6455188453324568 || Gen Loss : 0.8721440567899106


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:186/200 || Disc Loss: 0.6468222929216397 || Gen Loss : 0.8699577374498981


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:187/200 || Disc Loss: 0.6454171323572903 || Gen Loss : 0.883594560089396


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:188/200 || Disc Loss: 0.6446111781764895 || Gen Loss : 0.8850846941918452


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:189/200 || Disc Loss: 0.6472735243565493 || Gen Loss : 0.878437474719497


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:190/200 || Disc Loss: 0.6438693984993485 || Gen Loss : 0.8642200957229142


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:191/200 || Disc Loss: 0.6430340654560244 || Gen Loss : 0.8727532932753248


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:192/200 || Disc Loss: 0.6476270579325873 || Gen Loss : 0.8766417475397399


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:193/200 || Disc Loss: 0.6432969299460779 || Gen Loss : 0.8832931010199508


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:194/200 || Disc Loss: 0.6420051568606769 || Gen Loss : 0.8823439098243266


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:195/200 || Disc Loss: 0.6427634153793107 || Gen Loss : 0.8944445549170854


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:196/200 || Disc Loss: 0.6337229205347074 || Gen Loss : 0.9089722222229565


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:197/200 || Disc Loss: 0.6436653449845466 || Gen Loss : 0.8843934754572952


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:198/200 || Disc Loss: 0.637661492773719 || Gen Loss : 0.8883445601600574


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:199/200 || Disc Loss: 0.6377832948017731 || Gen Loss : 0.8895135456438004


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


Epoch:200/200 || Disc Loss: 0.6374745320663777 || Gen Loss : 0.8913704658240906


In [21]:
# Plot the discriminator (d_losses) and generator (g_losses) loss histories
# as two line traces on one interactive plotly figure, x indexed from 0.
# NOTE(review): the x-axis is labelled "Epochs", which assumes each list holds
# one entry per epoch; confirm against the training loop that fills them.
fig = go.Figure()
fig.add_trace(go.Scatter(name="Discriminator", x=list(range(len(d_losses))),
                         y=d_losses))

fig.add_trace(go.Scatter(name="Generator", x=list(range(len(g_losses))),
                         y=g_losses))

fig.update_layout(title="Epochs vs Losses",
                  xaxis_title="Epochs",
                  yaxis_title="Loss")
fig.show()