In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import autograd
from torch.autograd import Variable
import torch.nn.functional as F
import cv2
import random as rand

In [14]:
INPUT, OUT = 28**2, 28**2
HIDDEN_1, HIDDEN_2 = int(0.8*INPUT), int(0.8*INPUT)
CLASSES = 2

class Discriminator(torch.nn.Module):
    '''
     Network to discriminate between real and fake images.
    '''
    def __init__(self):
        super().__init__()
        self.l_one = torch.nn.Conv2d(1, 8, kernel_size=12, stride=1)
        self.l_two = torch.nn.Conv2d(8, 16, kernel_size=12, stride=1)
        self.l_three = torch.nn.Conv2d(16, 2, kernel_size=1, stride=1)
        
    def forward(self, x):
        x = x.unsqueeze(1)
        x = F.relu(self.l_one(x))
        x = F.relu(self.l_two(x))
        n, c, h, w = x.size()
        x = F.avg_pool2d(x, kernel_size=[h, w])
        x = self.l_three(x).view(-1, CLASSES)
        return x
    
class Generator(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        self.l_one = torch.nn.Conv2d(1, 3, kernel_size=4, stride=1, padding=2) 
        self.l_two = torch.nn.Conv2d(3, 12, kernel_size=8, stride=1, padding=4) 
        self.l_three = torch.nn.Conv2d(12, 24, kernel_size=4, stride=1, padding=2) 
        self.l_four = torch.nn.Conv2d(24, 1, kernel_size=4, stride=1, padding=2) 
        
    def forward(self, x):
        x = x.view(28, 28).unsqueeze(0).unsqueeze(1)
        x = F.relu(self.l_one(x))
        print(x.data.shape)
        x = F.relu(self.l_two(x))
        print(x.data.shape)
        x = F.relu(self.l_three(x))
        print(x.data.shape)
        x = F.relu(self.l_four(x))
        print(x.data.shape)
        x = x.view(-1, 28, 28)
        return x.squeeze()
'''
class Generator(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.l_one = torch.nn.Linear(INPUT, HIDDEN_1)
        self.l_two = torch.nn.Linear(HIDDEN_1, HIDDEN_2)
        self.l_three = torch.nn.Linear(HIDDEN_2, OUT)
        
    def forward(self, x):
        x = F.relu(self.l_one(x))
        x = F.relu(self.l_two(x))
        x = F.relu(self.l_three(x))
        return x.view(-1, 28, 28)
'''

'\nclass Generator(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        self.l_one = torch.nn.Linear(INPUT, HIDDEN_1)\n        self.l_two = torch.nn.Linear(HIDDEN_1, HIDDEN_2)\n        self.l_three = torch.nn.Linear(HIDDEN_2, OUT)\n        \n    def forward(self, x):\n        x = F.relu(self.l_one(x))\n        x = F.relu(self.l_two(x))\n        x = F.relu(self.l_three(x))\n        return x.view(-1, 28, 28)\n'

In [15]:
images = np.load('minst_test_images.npy')
N = len(images)
images = (images.reshape((N, 28, 28)) / 255).astype(np.float_) # reshape data to 28 x 28 image and normalize
loss = torch.nn.CrossEntropyLoss()

In [16]:
def train_disc(batch_size):
    # train Discriminator on real data
    dis_model.zero_grad()
    perm = np.random.choice(N, size=batch_size, replace=False)
    real_data = Variable(torch.Tensor(images[perm]))
    real_out = dis_model(real_data)
    real_error = loss(real_out, Variable(torch.LongTensor(np.ones(batch_size))))
    real_error.backward()
    # train Discriminator on generated data
    input_ = Variable(torch.Tensor(np.random.randn(batch_size, INPUT)))
    fake_data = gen_model(input_).detach()
    fake_out = dis_model(fake_data)
    fake_error = loss(fake_out, Variable(torch.LongTensor(np.zeros(batch_size))))
    fake_error.backward()
    dis_optim.step()
    return real_error.data[0], fake_error.data[0]

def train_gen(batch_size):
    gen_model.zero_grad()
    input_ = Variable(torch.Tensor(np.random.randn(batch_size, INPUT)))
    gen_data = gen_model(input_)
    gen_out = dis_model(gen_data)
    gen_error = loss(gen_out, Variable(torch.LongTensor(np.ones(batch_size))))
    gen_error.backward()
    gen_optim.step()
    return gen_error.data[0]

def get_dis_class():
    dis_model.eval()
    perm = np.random.choice(N, size=100, replace=False)
    data = Variable(torch.Tensor(images[perm]))
    max_, out = torch.max(dis_model(data), dim=1)
    return (out.data.numpy() == np.ones(100)).astype(float).mean()

def get_gen_class():
    gen_model.eval()
    dis_model.eval()
    input_ = Variable(torch.Tensor(np.random.randn(100, INPUT)))
    fake_data = gen_model(input_)
    max_, out = torch.max(dis_model(fake_data), dim=1)
    return (out.data.numpy() == np.ones(100)).astype(float).mean()

In [17]:
lr = 1e-5
gen_model = Generator()
dis_model = Discriminator()
gen_optim = torch.optim.Adam(gen_model.parameters(), lr)
dis_optim = torch.optim.Adam(dis_model.parameters(), lr)

In [18]:
ITERS, BATCH_SIZE = 30000, 100
for i in range(ITERS):
    if i % 1000 == 0:
        print('Iteration %d' % (i))
        input_ = Variable(torch.Tensor(np.random.randn(INPUT)))
        img = gen_model(input_).data.numpy()[0]
        plt.imshow(img)
        plt.show()
    train_disc(BATCH_SIZE)
    train_gen(BATCH_SIZE)
print('TESTING')
input_ = Variable(torch.Tensor(np.random.randn(INPUT)))
img = gen_model(input_).data.numpy()[0]
plt.imshow(img)
plt.show()

Iteration 0


RuntimeError: invalid argument 2: size '[-1 x 28 x 28]' is invalid for input with 1024 elements at /opt/conda/conda-bld/pytorch_1512387374934/work/torch/lib/TH/THStorage.c:37