In [1]:
import os
import numpy as np
import pandas as pd
import PIL.Image as Image
import matplotlib.pyplot as plt
import torch
from torch import nn,optim
from torchvision import transforms, datasets

In [2]:
# Parse the text file "dogs", restore the (5000, 32, 32, 3) layout and store
# the values as float32 so they match the default dtype of the torch layers.
traindatanp = np.loadtxt("dogs").reshape(5000, 32, 32, 3).astype('float32')

# Wrap the array as a tensor and divide by 255.
# NOTE(review): presumably the file holds 8-bit pixel values, making the
# result lie in [0, 1] - confirm against the data source.
traindata = torch.from_numpy(traindatanp) / 255

In [3]:
class Generative(nn.Module):
    """Fully connected generator mapping noise vectors to flattened images.

    The network is three linear layers with ReLU after the first two and a
    sigmoid after the last, so every output value lies in the unit interval.

    Args:
        noise_dim: Number of values in each (flattened) noise sample.
        hidden_dims: Widths of the two hidden layers as ``(first, second)``.
        image_dim: Number of values in each generated sample. The default,
            3072, equals 32 * 32 * 3.

    Raises:
        ValueError: If ``hidden_dims`` does not contain exactly two widths.
    """

    def __init__(self, noise_dim=16384, hidden_dims=(8192, 4096), image_dim=3072):
        super(Generative, self).__init__()

        hidden_dims = tuple(hidden_dims)
        if len(hidden_dims) != 2:
            raise ValueError(
                "hidden_dims must contain exactly two layer widths, got %d"
                % len(hidden_dims)
            )
        first_hidden, second_hidden = hidden_dims

        # Attribute names are part of the state-dict keys, so they are kept
        # as linear1/linear2/linear3.
        self.linear1 = nn.Linear(noise_dim, first_hidden)
        self.linear2 = nn.Linear(first_hidden, second_hidden)
        self.linear3 = nn.Linear(second_hidden, image_dim)

        self.relu = nn.ReLU()
        # Not used by forward(); retained so existing references to `.tanh`
        # keep working.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        """Generate one flattened sample per row of ``X``.

        Args:
            X: Noise tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``noise_dim`` values.

        Returns:
            Tensor of shape ``(batch, image_dim)`` with sigmoid-activated
            values.
        """
        # Collapse everything after the batch dimension, e.g. (N, D, 1, 1) -> (N, D).
        X = X.view(X.size(0), -1)

        X = self.relu(self.linear1(X))
        X = self.relu(self.linear2(X))
        X = self.sigmoid(self.linear3(X))

        return X

In [4]:
class Discriminative(nn.Module):
    """Fully connected discriminator producing one probability per sample.

    The network is three linear layers with ReLU after the first two and a
    sigmoid after the last, so the single output per sample lies in the unit
    interval and can be fed directly to ``nn.BCELoss``.

    Args:
        image_dim: Number of values in each (flattened) input sample. The
            default, 3072, equals 32 * 32 * 3.
        hidden_dims: Widths of the two hidden layers as ``(first, second)``.

    Raises:
        ValueError: If ``hidden_dims`` does not contain exactly two widths.
    """

    def __init__(self, image_dim=3072, hidden_dims=(1024, 256)):
        super(Discriminative, self).__init__()

        hidden_dims = tuple(hidden_dims)
        if len(hidden_dims) != 2:
            raise ValueError(
                "hidden_dims must contain exactly two layer widths, got %d"
                % len(hidden_dims)
            )
        first_hidden, second_hidden = hidden_dims

        # Attribute names are part of the state-dict keys, so they are kept
        # as linear1/linear2/linear3.
        self.linear1 = nn.Linear(image_dim, first_hidden)
        self.linear2 = nn.Linear(first_hidden, second_hidden)
        self.linear3 = nn.Linear(second_hidden, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        """Score each sample in the batch.

        Args:
            X: Input tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must total ``image_dim`` values.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid-activated scores.
        """
        # Collapse everything after the batch dimension, e.g. (N, H, W, C) -> (N, H*W*C).
        X = X.view(X.size(0), -1)

        X = self.relu(self.linear1(X))
        X = self.relu(self.linear2(X))
        X = self.sigmoid(self.linear3(X))

        return X

In [5]:
lr = 0.001
batch_size = 100          # images per optimisation step
log_every_n_epochs = 1    # print the per-batch statistics on every n-th epoch

disc = Discriminative()
disc_optimizer = optim.SGD(params=disc.parameters(), lr=lr, momentum=0.9)

loss = nn.BCELoss()

gen = Generative()
gen_optimizer = optim.SGD(params=gen.parameters(), lr=lr, momentum=0.9)

# Size of each noise sample, read from the generator so the two cannot drift apart.
noise_dim = gen.linear1.in_features

n_epochs = 10

# Label convention used by this loop: real images are labelled 0 and generated
# images 1, so a discriminator output close to 0 means "looks real".
for e in range(n_epochs):
    # Walk over the training tensor in consecutive slices of `batch_size`;
    # the number of batches follows from the data instead of being hard-coded.
    for start in range(0, traindata.size(0), batch_size):

        real_images = traindata[start:start + batch_size]
        n_samples = real_images.size(0)

        disc_optimizer.zero_grad()

        # Targets must be floating point: nn.BCELoss rejects integer targets,
        # and torch.full infers an integer dtype from an integer fill value.
        real_labels = torch.full((n_samples, 1), 0.0)
        fake_labels = torch.full((n_samples, 1), 1.0)
        # One noise sample per real image in the batch.
        noise_vectors = torch.randn(n_samples, noise_dim, 1, 1)

        # Train discriminative network one step using batch of real images
        output = disc(real_images)
        real_loss = loss(output, real_labels)
        real_loss.backward()
        disc_real_avg = output.mean().item()

        # Generate fake images from noise and pass them through disc. net.
        # detach() keeps this backward pass from reaching the generator.
        fake_images = gen(noise_vectors)
        output = disc(fake_images.detach())
        fake_loss = loss(output, fake_labels)
        fake_loss.backward()

        # Aggregate real and fake loss (for reporting) and update weights;
        # the gradients of both backward passes have accumulated in disc.
        step_loss = real_loss + fake_loss
        disc_fake_avg = output.mean().item()
        disc_optimizer.step()

        # Train generative network: score the same fakes with the updated
        # discriminator and push its output towards the "real" label.
        gen_optimizer.zero_grad()

        output = disc(fake_images)
        gen_loss = loss(output, real_labels)
        gen_loss.backward()
        gen_avg = output.mean().item()
        gen_optimizer.step()

        if e % log_every_n_epochs == 0:
            print("Epoch: "+str(e+1)+" | Disc. Loss: "+str(step_loss.item())+" | Gen. Loss: "+str(gen_loss.item())+
             " | D(X): "+str(disc_real_avg)+" | D(G(Z)): "+str(disc_fake_avg)+" -> "+str(gen_avg))


Epoch: 1 | Disc. Loss: 1.396163821220398 | Gen. Loss: 0.7283696532249451 | D(X): 0.5206355452537537 | D(G(Z)): 0.516430139541626 -> 0.517304539680481
Epoch: 1 | Disc. Loss: 1.3945224285125732 | Gen. Loss: 0.7316297292709351 | D(X): 0.5207528471946716 | D(G(Z)): 0.5174075961112976 -> 0.5188755393028259
Epoch: 1 | Disc. Loss: 1.393312931060791 | Gen. Loss: 0.7359839081764221 | D(X): 0.521486759185791 | D(G(Z)): 0.5188178420066833 -> 0.5209659337997437
Epoch: 1 | Disc. Loss: 1.3903647661209106 | Gen. Loss: 0.741080641746521 | D(X): 0.522063136100769 | D(G(Z)): 0.5209885239601135 -> 0.5234013199806213
Epoch: 1 | Disc. Loss: 1.3871430158615112 | Gen. Loss: 0.746279239654541 | D(X): 0.5226706862449646 | D(G(Z)): 0.5233340263366699 -> 0.5258722901344299
Epoch: 1 | Disc. Loss: 1.383317232131958 | Gen. Loss: 0.7521868348121643 | D(X): 0.523200273513794 | D(G(Z)): 0.5259231328964233 -> 0.5286650657653809
Epoch: 1 | Disc. Loss: 1.3818652629852295 | Gen. Loss: 0.758110523223877 | D(X): 0.524943172

KeyboardInterrupt: 

In [None]:
# Sample 50 noise vectors and run them through the generator.
noise = torch.randn(50, 16384, 1, 1)
p = gen(noise)

# 10 x 5 grid: one panel per generated sample.
fig, axes = plt.subplots(10, 5, figsize=(32, 32))

for panel, sample in zip(axes.flat, p):
    # Detach from the autograd graph and reinterpret the flat output as a
    # 32 x 32 image with 3 channels before handing it to matplotlib.
    panel.imshow(sample.detach().view(32, 32, 3))
