In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
import numpy as np
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import imageio


# Critic configuration switches:
#   enSN -> wrap every critic Linear layer with spectral normalisation.
#   enGP -> train the critic with a gradient penalty instead of weight clipping.
enSN, enGP = True, True

In [22]:
# Train on the first GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) instead of failing on the first .to().
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [23]:
from torch.nn.utils import spectral_norm as SN_
# Wrapper applied to each critic Linear layer: spectral normalisation when
# enSN is set, otherwise an identity that hands the layer back unchanged.
SN = SN_ if enSN else (lambda x: x)
    
class Generator(nn.Module):
    """MLP generator mapping an input vector to values in [0, 1] (final Sigmoid).

    Args:
        n_features: input width. Defaults to 36; the training loop feeds
            ``row(n, 36)``.
        n_out: output width. Defaults to 36.
        hidden: widths of the hidden Linear layers, each followed by
            ``LeakyReLU(0.01)``.

    The defaults reproduce the original fixed 36-256-512-1024-512-256-36
    network layer for layer (same ``self.model`` indices and creation order),
    so previously saved ``state_dict`` files still load.
    """

    def __init__(self, n_features=36, n_out=36, hidden=(256, 512, 1024, 512, 256)):
        super().__init__()
        self.n_features = n_features
        self.n_out = n_out
        layers = []
        width = n_features
        for h in hidden:
            layers += [nn.Linear(width, h), nn.LeakyReLU(0.01)]
            width = h
        layers += [nn.Linear(width, n_out), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x`` (width ``n_features``)."""
        return self.model(x)

class Critic(nn.Module):
    """MLP critic mapping a sample to an unbounded score (no output activation).

    Every Linear layer is passed through the module-level ``SN`` wrapper
    (spectral normalisation or identity, chosen by ``enSN``).

    Args:
        n_in: input width. Defaults to 36, matching the generator output.
        n_out: output width. Defaults to 1 (one score per sample).
        hidden: widths of the hidden Linear layers, each followed by
            ``LeakyReLU(0.01)``.
        dropout: if > 0, insert ``nn.Dropout(dropout)`` after every hidden
            activation (replaces the previously commented-out Dropout lines).

    With the defaults the network is identical to the original fixed
    architecture (same ``self.model`` indices and creation order).
    """

    def __init__(self, n_in=36, n_out=1, hidden=(256, 512, 1024, 512, 256), dropout=0.0):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        layers = []
        width = n_in
        for h in hidden:
            layers.append(SN(nn.Linear(width, h)))
            layers.append(nn.LeakyReLU(0.01))
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            width = h
        layers.append(SN(nn.Linear(width, n_out)))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x`` along its last dimension (width ``n_in``)."""
        return self.model(x)

In [24]:
# Build both networks (generator first, then critic) and move them onto the
# training device; nn.Module.to moves parameters in place.
generator, critic = Generator(), Critic()
generator.to(device)
critic.to(device)

alpha = 5e-5  # learning rate shared by both RMSprop optimizers

g_optim, c_optim = (
    optim.RMSprop(net.parameters(), lr=alpha) for net in (generator, critic)
)

# Histories collected by the training loop.
g_losses, c_losses, images = [], [], []

# Module-level slot not read by the visible code; kept so any existing
# references to it still resolve.
row_state = 0.0

def row(n, n_features=1):
    """Sample ``n`` cosine rows of width ``n_features`` with values in [0, 1].

    Row ``r`` is ``0.5 + 0.5 * cos(phase_r + k * step_r)`` for
    ``k = 0 .. n_features - 1``. Here ``step_r`` and ``phase_r`` are uniform on
    [0, 2*pi), drawn from NumPy's global RNG in that order (step first, then
    phase).

    Returns:
        float32 tensor of shape ``(n, n_features)`` on ``device``.
    """
    step = np.random.rand(n, 1) * 2 * np.pi
    k = np.arange(n_features)
    phase = np.random.rand(n, 1) * 2 * np.pi
    # Broadcasting (n, 1) * (n_features,) gives the (n, n_features) angle grid.
    angles = phase + step * k
    values = np.cos(angles) * 0.5 + 0.5
    return torch.from_numpy(values.astype(np.float32)).to(device)
    
def noise(n, n_features=1):
    """Return an ``(n, n_features)`` tensor of uniform [0, 1) samples on ``device``.

    Samples are drawn with torch's default (CPU) generator, then moved.
    """
    samples = torch.rand(n, n_features)
    return samples.to(device)

In [26]:
if not enGP:
    def train_critic(optimizer, real_data, fake_data, c=0.01):
        """One critic update with weight clipping (used when gradient penalty is off).

        Minimises ``-(mean D(real) - mean D(fake))`` and then clamps every critic
        parameter to ``[-c, c]``.

        Args:
            optimizer: optimizer over ``critic.parameters()``.
            real_data: batch of real samples.
            fake_data: batch of generated samples (should already be detached).
            c: clipping bound for the critic weights.

        Returns:
            The estimate ``mean D(real) - mean D(fake)`` as a detached scalar
            tensor, so callers can accumulate or store it without keeping the
            autograd graph alive.
        """
        optimizer.zero_grad()
        error_real = critic(real_data).mean()
        error_fake = critic(fake_data).mean()
        total_error = -(error_real - error_fake)
        total_error.backward()
        optimizer.step()
        # Weight clipping: the parameter update itself must not be tracked.
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-c, c)
        return (-total_error).detach()

from torch.autograd import grad as tg
if enGP:
    def train_critic(optimizer, real_data, fake_data, c=0.01):
        """One critic update with a gradient penalty on random interpolates.

        Args:
            optimizer: optimizer over ``critic.parameters()``.
            real_data: batch of real samples, shape ``(batch, ...)``.
            fake_data: batch of generated samples (should already be detached).
            c: unused here; kept so both ``train_critic`` variants share a signature.

        Returns:
            The estimate ``mean D(real) - mean D(fake)`` as a detached scalar
            tensor (the penalty term is not included), so callers can accumulate
            or store it without keeping the autograd graph alive.
        """
        optimizer.zero_grad()
        error_real = critic(real_data).mean()
        error_fake = critic(fake_data).mean()
        total_error = -(error_real - error_fake)

        # One mixing ratio per sample, broadcast across the feature dimension.
        batch_size = real_data.size(0)
        ep = torch.from_numpy(np.random.rand(batch_size, 1).astype(np.float32)).to(device)

        middle_data = ep * real_data + (1 - ep) * fake_data
        middle_data.requires_grad_()
        critic_middle = critic(middle_data)

        # create_graph=True so the penalty itself can be back-propagated.
        grad = tg(outputs=critic_middle, inputs=middle_data,
                  grad_outputs=torch.ones_like(critic_middle),
                  create_graph=True, retain_graph=True)[0]
        GP_COEF = 5.
        # NOTE(review): this is a zero-centred penalty on ||grad||^2, not the
        # (||grad|| - 1)^2 form of WGAN-GP — confirm this is intended.
        gradient_penalty = GP_COEF * (torch.norm(grad.view(batch_size, -1), dim=1) ** 2).mean()
        loss = total_error + gradient_penalty
        loss.backward()
        optimizer.step()
        return (-total_error).detach()

def train_generator(optimizer, fake_data):
    """One generator update: raise the critic's mean score on generated samples.

    Args:
        optimizer: optimizer over ``generator.parameters()``.
        fake_data: generator output that is still attached to the generator's
            graph (not detached), so gradients reach the generator.

    Returns:
        The generator loss ``-mean D(fake)`` as a detached scalar tensor, so
        callers can accumulate it without keeping the autograd graph alive.
    """
    optimizer.zero_grad()
    error = -critic(fake_data).mean()
    error.backward()
    optimizer.step()
    return error.detach()

In [None]:
# Main training loop: n_critic critic updates per generator update, a loss plot
# every LOG_EVERY epochs, and a bulk dump of generator output every DUMP_EVERY
# epochs.
from IPython.display import clear_output

num_epochs = 900001
n_critic = 5
STEP_PER_EPOCH = 1    # generator updates per epoch
n = 64                # batch size
N_FEATURES = 36       # width of generator input/output and of the real samples
LOG_EVERY = 100       # epochs between loss-plot updates
DUMP_EVERY = 3000     # epochs between output dumps (the conversion cell uses 3000)
DUMP_BATCHES = 80000  # batches of n samples written per dump


def to_image(x):
    """Convert a (C, H, W) tensor in [0, 1] to an (H, W, C) uint8 array."""
    return (np.clip(x.numpy().copy().transpose(1, 2, 0), 0, 1) * 255).astype(np.uint8)


generator.train()
critic.train()
for epoch in range(num_epochs):
    g_error = 0.0
    c_error = 0.0
    for i in range(n_critic * STEP_PER_EPOCH):
        # Generator input is the cosine signal from row(); the "real"
        # distribution shown to the critic is uniform noise.
        with torch.no_grad():
            fake_data = generator(row(n, N_FEATURES))
        real_data = noise(n, N_FEATURES)
        c_error += train_critic(c_optim, real_data, fake_data)
        if (i + 1) % n_critic == 0:
            fake_data = generator(row(n, N_FEATURES))
            g_error += train_generator(g_optim, fake_data)

    if epoch % LOG_EVERY == 0:
        with torch.no_grad():
            img = generator(row(n, N_FEATURES)).cpu()
        img = to_image(make_grid(img))
        images.append(img)
        # Store plain floats: the accumulated errors may be CUDA tensors still
        # attached to the autograd graph, which matplotlib cannot plot and
        # which would keep graph memory alive in the history lists.
        c_losses.append(float(c_error))
        g_losses.append(float(g_error))

        r_img = to_image(make_grid(real_data.cpu().detach()))

        clear_output()
        plt.clf()
        plt.plot(c_losses, label='Critic Losses')
        # A log axis cannot show values <= 0, so no zero reference line is drawn.
        plt.yscale("log")
        plt.legend()
        plt.savefig('loss.png', dpi=300)

    if epoch % DUMP_EVERY == 0:
        # Dump generator output, one value per line. Mode 'w' so re-running the
        # notebook overwrites the file instead of appending duplicate data.
        with torch.no_grad(), open('G_output_10_' + str(epoch) + '.txt', 'w') as G_output:
            for _ in range(DUMP_BATCHES):
                rand_img = generator(row(n, N_FEATURES)).cpu().numpy().flatten()
                np.savetxt(G_output, rand_img)

        # Side-by-side picture: generated rows vs. uniform random rows.
        with torch.no_grad():
            rand_img = generator(row(n, N_FEATURES)).cpu()
        rand_img_real = torch.rand(n, 1, 6, 6)
        im_list = np.asarray(rand_img.view(-1, 6 * 6))
        im_list_real = np.asarray(rand_img_real.view(-1, 6 * 6))
        fig = plt.figure(figsize=(10, 10))
        plt.subplot(1, 2, 1)
        plt.imshow(im_list, cmap='gray', vmin=0, vmax=1)
        plt.colorbar()
        plt.subplot(1, 2, 2)
        plt.imshow(im_list_real, cmap='gray', vmin=0, vmax=1)
        plt.colorbar()
        plt.savefig('rand_image_' + str(epoch) + '.jpg')
        # One figure is created per dump; close it so pyplot does not keep
        # hundreds of figures alive (and so the loss plot reuses its own figure).
        plt.close(fig)

print('Training Finished')
torch.save(generator.state_dict(), 'mnist_generator.pth')




In [None]:
# Convert each decimal value dumped by the training loop back to its 32-bit
# float pattern and write bits 13..19 (0-based, counted from the most
# significant bit) as a continuous stream of '0'/'1' characters.
import os
from ctypes import Union, c_float, c_uint


class POINT(Union):
    # Overlays a float32 and a uint32 on the same 4 bytes, so constructing it
    # from a float and reading .y yields the float's raw bit pattern.
    _fields_ = [("x", c_float), ("y", c_uint)]


OUT_DIR = 'G_output_2_13to19bit'
os.makedirs(OUT_DIR, exist_ok=True)

for num in range(0, 900001, 3000):
    with open('G_output_10_' + str(num) + '.txt', 'r') as read_data, \
            open(OUT_DIR + '/G_output_2_' + str(num) + '.txt', 'w') as write_data:
        for line in read_data:
            point = POINT(float(line))
            # Zero-pad to the full 32 bits: bin() drops leading zeros, so
            # slicing its output shifts the selected bits for 0.0 and very
            # small values (below 2**-63).
            write_data.write(format(point.y, '032b')[13:20])