In [1]:
## GAN: a convolutional generator/discriminator pair trained on MNIST

In [2]:
#create dataloaders
from random import randint as ri
import torch as torch
import torchvision
import torchvision.transforms as transforms

# MNIST images have a single (grayscale) channel, so Normalize takes exactly one
# mean/std value; (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
# (A 3-value mean/std fails on 1-channel tensors in current torchvision.)
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

BATCH_SIZE = 4

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

classes = ('zero', 'one', 'two', 'three',
           'four', 'five', 'six', 'seven', 'eight', 'nine')

In [3]:
#check if data is loaded properly
import matplotlib.pyplot as plt
import numpy as np

# function to show an image
def imshow(img):
    """Un-normalize an image tensor from [-1, 1] back to [0, 1] and display it.

    Args:
        img: a (C, H, W) tensor, e.g. the output of ``torchvision.utils.make_grid``,
            or a single-channel (1, H, W) / (H, W) tensor, which is shown in grayscale.
    """
    img = img / 2 + 0.5     # unnormalize: inverse of Normalize(mean=0.5, std=0.5)
    npimg = img.detach().cpu().numpy()
    if npimg.ndim == 3 and npimg.shape[0] == 1:
        # matplotlib rejects (H, W, 1) arrays, so drop the channel axis.
        npimg = npimg[0]
    if npimg.ndim == 2:
        plt.imshow(npimg, cmap='gray')
    else:
        # channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# get some random training images
dataiter = iter(trainloader)
# DataLoader iterators no longer provide a .next() method; use the built-in next().
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
# print labels (iterate the actual batch instead of assuming 4 items)
print(' '.join('      %4s' % classes[label] for label in labels.tolist()))

<Figure size 640x480 with 1 Axes>

       one       five       three       seven


In [4]:
#Generator model 
import torch.nn as nn
import torch.nn.functional as F


class Generator(nn.Module):
    """Map a (N, 1, 10, 10) noise tensor to a (N, 1, 28, 28) image in (-1, 1).

    Spatial sizes through ``self.block``:
    conv(k=7, pad=26) 10 -> 56, conv(k=2, pad=1) 56 -> 57, maxpool(2) 57 -> 28.
    """
    def __init__(self):
        super(Generator, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=26),  # 10x10 -> 56x56
            nn.ReLU(),
            nn.BatchNorm2d(1),
            nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=1, padding=1),   # 56x56 -> 57x57
            nn.MaxPool2d(2, 2),                                                              # 57x57 -> 28x28
        )

    def forward(self, x):
        x = self.block(x)
        # tanh squashes into (-1, 1), the same range real images have after
        # Normalize(0.5, 0.5). The previous abs(x % 1) could only produce [0, 1),
        # so the generator could never output the negative values that every
        # darker-than-mid-gray pixel of a real image has.
        return torch.tanh(x)

In [5]:
#Discriminator model 

class Discriminator(nn.Module):
    """Convolutional classifier for (N, 1, 28, 28) images.

    Returns raw logits of shape (N, num_outputs), suitable for nn.CrossEntropyLoss.
    Spatial sizes: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4, so each sample
    flattens to 6 * 4 * 4 = 96 features.

    Args:
        num_outputs: number of output logits. Defaults to 10 (the original width);
            real(1)/fake(0) targets only ever use two of them.
    """
    def __init__(self, num_outputs=10):
        super(Discriminator, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(1, 3, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.fc1 = nn.Linear(6 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_outputs)

    def forward(self, x):
        x = self.block(x)
        # Flatten per sample, keeping the batch dimension. view(-1, 96) would
        # silently re-split a wrongly sized input into a different batch size
        # instead of raising.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


In [6]:
#Create an instance of the generator and discriminator
# Build the generator first, then the discriminator (same construction order as before).
gen, dis = Generator(), Discriminator()

In [7]:
#create optimization function and loss function for both generator and discriminator
import torch.optim as optim

# Loss functions: cross-entropy over logits (criteriondis) and mean-squared error (criteriongen).
criteriondis, criteriongen = nn.CrossEntropyLoss(), nn.MSELoss()

# One Adam optimizer per network, both with learning rate 5e-5.
dis_optimizer = optim.Adam(dis.parameters(), lr=5e-5)
gen_optimizer = optim.Adam(gen.parameters(), lr=5e-5)

In [None]:
#Train the GAN
PRINT_EVERY = 2000  # mini-batches between progress reports

for epoch in range(1):

    running_loss_real = 0.0
    running_loss_fake = 0.0
    running_loss_gen  = 0.0

    for i, (real_images, _) in enumerate(trainloader):
        # Size targets and noise from the actual batch rather than a hard-coded 4.
        batch_size = real_images.size(0)
        ones = torch.ones(batch_size, dtype=torch.int64)    # target class 1 = "real"
        zeros = torch.zeros(batch_size, dtype=torch.int64)  # target class 0 = "fake"
        gen_ip = torch.randn((batch_size, 1, 10, 10))

        # ---- Discriminator step: push real -> 1 and fake -> 0 ----
        dis_optimizer.zero_grad()
        real_loss = criteriondis(dis(real_images), ones)
        fake_images = gen(gen_ip)
        # detach() stops the discriminator loss from leaking gradients into the
        # generator (they used to accumulate and be applied by gen_optimizer.step()).
        fake_loss = criteriondis(dis(fake_images.detach()), zeros)
        # One backward + one step on the summed loss. Stepping after each term
        # without zeroing in between applied the real-image gradients twice.
        (real_loss + fake_loss).backward()
        dis_optimizer.step()

        # ---- Generator step: make the discriminator classify fakes as real ----
        gen_optimizer.zero_grad()
        # Adversarial objective: the generator is trained through the discriminator.
        # The previous MSE against an unrelated batch of real images gave the
        # discriminator no influence on the generator at all.
        gen_loss = criteriondis(dis(fake_images), ones)
        gen_loss.backward()
        gen_optimizer.step()

        # print statistics
        running_loss_real += real_loss.item()
        running_loss_fake += fake_loss.item()   # was mistakenly added to running_loss_real
        running_loss_gen += gen_loss.item()
        if i % PRINT_EVERY == PRINT_EVERY - 1:    # print every PRINT_EVERY mini-batches
            print('Epoch: %d | Mini-batches: %5d | real_loss: %.3f | fake_loss: %.3f | gen_loss: %.3f' %
                  (epoch + 1, i + 1, running_loss_real / PRINT_EVERY,
                   running_loss_fake / PRINT_EVERY, running_loss_gen / PRINT_EVERY))
            running_loss_real = 0.0
            running_loss_fake = 0.0
            running_loss_gen  = 0.0

print('Finished Training')

In [None]:
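#Save the learned weights (not implemented yet; gen.block[0] is the first conv layer, gen.block[3] the second)
#conv1_weight = gen.block[0].weight
#conv1_bias = gen.block[0].bias
#conv2_weight = gen.block[3].weight
#conv2_bias = gen.block[3].bias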


In [None]:
#validate over test set
dataiter = iter(testloader)
# DataLoader iterators no longer provide a .next() method; use the built-in next().
images, labels = next(dataiter)

# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels.tolist()))

In [None]:
# `outputs` was never defined above this cell (it only worked through leftover
# kernel state); compute it from the test batch loaded by the previous cell.
# NOTE(review): dis is trained only on real(1)/fake(0) targets, not digit labels.
with torch.no_grad():
    outputs = dis(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[p]
                              for p in predicted.tolist()))

In [None]:
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        # gen() maps 10x10 noise to images, it does not classify them; dis is the
        # model whose logits line up with `classes`.
        outputs = dis(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# NOTE(review): dis is only trained on real(1)/fake(0) targets, never on digit
# labels, so this is not a meaningful digit-classification accuracy.
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

In [None]:
# Per-class accuracy on the test set.
# NOTE(review): dis is only trained on real(1)/fake(0) targets, never on digit labels.
class_correct = [0.0] * len(classes)
class_total = [0.0] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        # `net` is never defined in this notebook; dis is the model whose logits
        # line up with `classes`.
        outputs = dis(images)
        _, predicted = torch.max(outputs, 1)
        # Iterate the actual batch instead of a hard-coded range(4), and avoid
        # squeeze(), which turns a size-1 batch into an unindexable 0-d tensor.
        for label, hit in zip(labels.tolist(), (predicted == labels).tolist()):
            class_correct[label] += hit
            class_total[label] += 1


for i in range(len(classes)):
    if class_total[i]:
        print('Accuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))
    else:
        # avoid ZeroDivisionError for a class with no test samples
        print('Accuracy of %5s :  n/a' % classes[i])