In [25]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torchviz import make_dot
from torch.autograd import Variable

Heavily adapted from https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py

In [103]:
x = np.load('x_10000-18000.npy')
x_1 = x[:1030]
x_2 = x[1030:2060]
y_1 = np.load('y_10000-18000.npy')

In [9]:
# here, we get the real data to sample against
def get_data(numpy_arr):
    return torch.Tensor(numpy_arr)

In [85]:
# generates potential real images based on the noisy data
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.f = f
    
    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        return x
    

In [96]:
# tries to tell apart real data vs fake looking data
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, 4000)
        self.f = f
    
    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        return x

In [118]:
def extract(v):
    return v.data.storage().tolist()

def get_stat(v):
    return[np.mean(v), np.std(v)]

In [104]:
# define the constants
num_epochs = 10
input_size = 1030
data_output_size = 784 # image size
clean_data = get_data(x_1)
noisy_data = get_data(x_2)

In [None]:
data_input_size = input_size
sgd_momentum = 0.9
# generator info
g_activation_fn = torch.sigmoid
g_in_size = input_size
g_hid_size = 1030
# discrimintator info
d_activation_fn = torch.sigmoid
d_hid_size = 900

# generator
G = Generator(input_size=data_input_size, hidden_size=g_hid_size, output_size=data_output_size, 
              f=g_activation_fn)
D = Discriminator(input_size=data_input_size, hidden_size=900, output_size=data_output_size, 
              f=d_activation_fn)

d_learning_rate = 1e-3
g_learning_rate = 1e-3
criterion = nn.MSELoss()
d_optimizer = optim.SGD(D.parameters(), lr= d_learning_rate, momentum=sgd_momentum)
g_optimizer = optim.SGD(D.parameters(), lr= g_learning_rate, momentum=sgd_momentum)

d_steps = 2
g_steps = 2

for epoch in range(num_epochs):
    # training the detective
    for i in range(d_steps):
        D.zero_grad()
        
        # on real 
        d_r_data = Variable(clean_data) # get sample of data
        d_r_decision = D(d_r_data)
        d_r_error = criterion(d_r_decision,  Variable(torch.zeros([len(d_r_decision),4000])))
        d_r_error.backward() # Compute/store gradients, but don't change params
        
        # on fake
        d_gen_input = Variable(noisy_data)
        d_fake_data = G(d_gen_input).detach()
        print(len(d_fake_data[0]))
        d_fake_decision = D(d_fake_data.t())
        d_fake_error = criterion(d_fake_decision, Variable(torch.ones([len(d_fake_decision), \
                                        len(d_fake_decision[0])])))
        d_fake_error.backward()
        d_optimizer.step()
        d_re, d_fe = extract(d_r_error)[0], extract(d_fake_error)[0]
    
    # training the forger
    for j in range(g_steps):
        G.zero_grad()
        
        gen_input = Variable(noisy_data)
        g_fake_data = G(gen_input)
        dg_f_decision = D(g_fake_data.t())
        g_error = criterion(dg_f_decision, Variable(torch.ones([len(dg_f_decision), \
                                        len(dg_f_decision[0])])))
        g_error.backward()
        g_optimizer.step()
        ge = extract(g_error)[0]
    
    if epoch % 1 == 0:
        print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
            (epoch, d_re, d_fe, ge, get_stat(extract(d_r_data)), get_stat(extract(d_fake_data))))

1030
1030
Epoch 0: D (0.25720641016960144 real_err, 0.25292137265205383 fake_err) G (0.25291723012924194 err); Real Dist ([0.1362411574090119, 0.3380443160561318]),  Fake Dist ([0.5028033388217754, 0.07119051639517802]) 
1030
1030


KeyboardInterrupt: 