In [1]:
import torch
import torchvision
from torchvision import transforms, datasets, models
from torchvision.utils import save_image

import torch.optim as optim


import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import math


import os 
import glob
import time
import cv2

from tqdm.notebook import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import imageio



In [2]:
# Use the first GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
# (The previous string "cuda: 0" contained a space, which newer PyTorch
# versions reject as an invalid device string.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [3]:
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1] so real images live in the same range as the generator's tanh output.
train = datasets.MNIST("", train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),
                                                     transforms.Normalize((0.5,), (0.5,))]))

In [4]:
# Same preprocessing as the training split: [0, 1] -> [-1, 1], matching G's tanh range.
test = datasets.MNIST("", train=False, download=True,
                      transform=transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5,), (0.5,))]))

In [5]:
# Number of images in the MNIST training split.
n_train = len(train)
n_train

60000

In [6]:
# Hyper-parameters.
# NOTE: the labels are FLIPPED relative to the usual GAN convention:
#   - D is trained to output `labelSmoothing` (close to 0) on real images
#     and `1 - labelSmoothing` (close to 1) on generated ones;
#   - G is trained so that D outputs the "real" target `labelSmoothing` on its fakes.
bs = 100              # mini-batch size
gIns = 100            # length of the latent noise vector fed to G
EPOCHS = 500
lrG = 0.0002          # Adam learning rate for the generator
lrD = 0.0002          # Adam learning rate for the discriminator
labelSmoothing = 0.1  # keep in [0,1)
# Instance-noise amplitude; NOTE(review): not referenced by any other cell yet.
# https://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/
noiseAmp = 1
loss_fn = nn.BCELoss()

# Seed torch's RNG (all devices) so weight init, shuffling and noise are reproducible.
SEED = 0
torch.manual_seed(SEED)


In [7]:
# drop_last=True guarantees every training batch has exactly `bs` samples, which
# keeps batches consistent with `(bs, 1, 1)`-shaped label tensors.
trainset = torch.utils.data.DataLoader(train, batch_size=bs, shuffle=True, drop_last=True)
# The test split is only evaluated, so there is no reason to shuffle it.
testset = torch.utils.data.DataLoader(test, batch_size=bs, shuffle=False)




In [8]:
class D(nn.Module):
    """MLP discriminator: flattened 28x28 image -> probability in (0, 1).

    Args:
        n_in: number of input features (default 28*28 = 784).
        p_drop: dropout probability applied after each hidden layer.
        slope: negative slope of the leaky-ReLU activations.
    """

    def __init__(self, n_in=28 * 28, p_drop=0.3, slope=0.2):
        super(D, self).__init__()
        self.p_drop = p_drop
        self.slope = slope
        # Layer names are unchanged so existing state_dicts still load.
        self.fc1 = nn.Linear(n_in, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)

    def forward(self, x):
        """Return D(x); the last dimension of ``x`` must equal ``n_in``."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(layer(x), self.slope)
            # training=self.training: F.dropout defaults to training=True, which
            # would keep dropping activations even after calling .eval().
            x = F.dropout(x, self.p_drop, training=self.training)
        return torch.sigmoid(self.fc4(x))
    
    
    
            
# Build the discriminator on `device`, show its layers, and give it its own Adam optimizer.
dsk = D()
dsk = dsk.to(device)
# To resume from the discriminator checkpoint written by the save cell further down:
# dsk.load_state_dict(torch.load(".dskT.pth"))
print(dsk)
optimizerD = optim.Adam(params=dsk.parameters(), lr=lrD)


D(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)


In [9]:
class G(nn.Module):
    """MLP generator: latent noise vector -> flattened 28x28 image with pixels in [-1, 1].

    Args:
        n_in: length of the latent noise vector. When None (default) the
            notebook-level ``gIns`` is used, looked up at construction time.
        slope: negative slope of the leaky-ReLU activations.
    """

    def __init__(self, n_in=None, slope=0.2):
        super().__init__()
        if n_in is None:
            n_in = gIns
        self.slope = slope
        # Layer names are unchanged so existing state_dicts still load.
        self.fc1 = nn.Linear(n_in, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, 784)

    def forward(self, x):
        """Map noise whose last dimension is ``n_in`` to 784 pixel values."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(layer(x), self.slope)
        # tanh bounds every output pixel to [-1, 1].
        return torch.tanh(self.fc4(x))
    
    
    
            
# Build the generator on `device`, show its layers, and give it its own Adam optimizer.
gen = G()
gen = gen.to(device)
print(gen)
optimizerG = optim.Adam(params=gen.parameters(), lr=lrG)



G(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)


In [10]:
# Output folder for one preview sheet per epoch: create it if needed, then clear it.
imgDir = "./mnistGanImgG"
os.makedirs(imgDir, exist_ok=True)
for file in glob.glob(os.path.join(imgDir, "*")):
    os.remove(file)


with open("toyGanMnistTuningLog.log", "w") as f:

    # Header kept byte-for-byte: the analysis cell reads these columns back.
    f.write(f"time, dskOnReal, dskOnFake, epoch\n")

    # Fixed latent vectors, so the per-epoch preview sheets are directly comparable.
    seed = torch.randn(64, 1, gIns, device=device)

    for epoch in tqdm(range(EPOCHS)):

        # Save a sheet of the 64 `seed` samples at the start of every epoch.
        #https://ezgif.com/maker/ezgif-5-d0395d50-gif  CAN make gifs here
        with torch.no_grad():
            img = gen(seed)
            imgpath = os.path.join(imgDir, f"{epoch}.png")
            # normalize=True rescales the sheet's min..max to [0, 1]. G ends in tanh,
            # so without it every negative pixel would be clipped to black.
            save_image(img.view(img.size(0), 1, 28, 28), imgpath, normalize=True)

        for batchOfData, _ in trainset:
            batchOfData = batchOfData.view(-1, 1, 28*28).to(device)
            # Size targets and noise from the actual batch, not `bs`, so a short
            # final batch cannot cause a shape mismatch in the loss.
            n = batchOfData.size(0)
            # Labels are FLIPPED: "real" target = labelSmoothing (~0),
            # "fake" target = 1 - labelSmoothing (~1). Built directly on `device`.
            realTarget = torch.full((n, 1, 1), labelSmoothing, device=device)
            fakeTarget = torch.full((n, 1, 1), 1 - labelSmoothing, device=device)
########################################################################
            # Discriminator step: real batch -> realTarget, fake batch -> fakeTarget.
            dsk.zero_grad()
            dskOnReal = dsk(batchOfData)
            lossD_real = loss_fn(dskOnReal, realTarget)
            lossD_real.backward()

            batchOfFake = gen(torch.randn(n, 1, gIns, device=device))
            # detach(): the discriminator step must not send gradients into G.
            dskOnFake = dsk(batchOfFake.detach())
            #https://github.com/pytorch/examples/issues/116
            lossD_fake = loss_fn(dskOnFake, fakeTarget)
            lossD_fake.backward()

            optimizerD.step()
################################################################################
            # Generator step: push D's output on the fakes toward the "real" target.
            gen.zero_grad()
            dskOnFake = dsk(batchOfFake)
            lossG = loss_fn(dskOnFake, realTarget)
            lossG.backward()
            optimizerG.step()

            # Log batch means of D's outputs. Because the labels are flipped,
            # dskOnReal should start near 0 and dskOnFake near 1; theoretically
            # both converge to 0.5 as G gets better.
            f.write(f"{round(time.time(), 4)}, {round(float(dskOnReal.mean()),3)}, {round(float(dskOnFake.mean()), 4)}, {epoch}\n")

HBox(children=(IntProgress(value=0, max=500), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# skipinitialspace=True strips the blank after each comma in the log, so the
# columns are plain 'dskOnReal' / 'dskOnFake' (not ' dskOnReal' / ' dskOnFake').
df = pd.read_csv("toyGanMnistTuningLog.log", skipinitialspace=True)

# One log row per training iteration; smooth the noisy per-batch means with a
# rolling mean over `bs` iterations.
window = bs
fig, ax = plt.subplots()
df["dskOnReal"].rolling(window=window).mean().plot(ax=ax, legend=True)
df["dskOnFake"].rolling(window=window).mean().plot(ax=ax, legend=True)
ax.set(title="Discriminator outputs during training (rolling mean)",
       xlabel="iteration", ylabel="batch-mean D output")
df.shape



In [11]:
# Draw one sample from the trained generator. no_grad(): this is inspection
# only, so there is no need to record an autograd graph.
with torch.no_grad():
    see = gen(torch.randn(1, 1, gIns, device=device))


In [12]:
see.size()  # shape of the raw generator output

torch.Size([1, 1, 784])

In [15]:
torch.randn(1, 1, gIns, device=device).size()  # shape of a single latent input

torch.Size([1, 1, 100])

In [14]:
gIns  # latent-vector length fed to G (set in the hyper-parameter cell)

100

In [None]:
see = see.reshape(28, 28)  # (1, 1, 784) -> (28, 28) image for display

In [None]:
see.size()  # should now be a 28x28 image

In [None]:
# Persist the generator's weights (reloaded into `netG` further down).
genState = gen.state_dict()
torch.save(genState, ".genT.pth")

In [None]:
# Persist the discriminator's weights.
dskState = dsk.state_dict()
torch.save(dskState, ".dskT.pth")

In [None]:
netG = G()  # NOTE(review): redundant - a later cell re-creates netG and loads the saved weights

In [None]:
netD = D()  # NOTE(review): netD is not used by any cell shown here

In [None]:
# Rebuild the generator on the CPU and load the checkpoint saved above.
# map_location="cpu" lets a checkpoint written from a CUDA model load on a
# CPU-only machine; without it torch.load restores tensors to their saved device.
netG = G()
netG.load_state_dict(torch.load(".genT.pth", map_location="cpu"))
netG.eval()

In [None]:
#RESPONSE TO NOISE
# Decode one random latent vector with the CPU copy of the generator and show it.
with torch.no_grad():
    see = netG(torch.randn(1, 1, gIns))

see = see.view(28, 28)

plt.imshow(see.numpy(), cmap="gray")
plt.title("netG(random noise)");

In [None]:
#RESPONSE TO ZEROS
# Decode the all-zeros latent vector; the final expression shows the darkest pixel.
with torch.no_grad():
    see = netG(torch.zeros(1, gIns))

see = see.view(28, 28)

plt.imshow(see.numpy(), cmap="gray")
plt.title("netG(zeros)")
see.min()

In [None]:
# Scratch value: Windows-style path of the epoch-3 preview sheet (not used by any other cell shown).
stadf = f"./mnistGanImgG\\{3}.png"


In [None]:
# Assemble the per-epoch preview sheets into an animated GIF.
imgDir = "./mnistGanImgG"
frames = []


for i in range(EPOCHS):
    # os.path.join instead of a hard-coded "\" separator: portable, and avoids
    # the invalid "\{" escape sequence the old f-string relied on.
    imgpath = os.path.join(imgDir, f"{i}.png")
    if not os.path.exists(imgpath):
        # Training may stop (e.g. KeyboardInterrupt) before all EPOCHS sheets are
        # written; sheets are saved for epochs 0, 1, 2, ... so stop at the first gap.
        break
    # imread takes a single path (the old code passed glob's list of matches).
    frames.append(imageio.imread(imgpath))

if frames:
    imageio.mimwrite('epochs.gif', frames)
else:
    print(f"No preview images found in {imgDir}; epochs.gif not written.")