In [93]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
import torchvision.datasets as dset
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [97]:
class DatasetLoop(torch.utils.data.Dataset):
    def __init__(self, ndim, nsamples, loop_type="raw"):
        self.ndim = ndim
        self.nsamples = nsamples
        self.loop_type = loop_type
        self.dir = "data/{0}/{0}_{1}/".format(self.loop_type, self.ndim)
        self.loops = {}
        self.label = 0
        
        self.get_loops()
        
    def get_loops(self):
        for i in range(self.nsamples):
            try:
                self.loops[i] = (np.loadtxt(self.dir + "{}.loop".format(i)).astype("int"), self.label)
            except OSError as e:
                print(self.dir + "{}.loop".format(i) + " loop not found.")
           
    def __getitem__(self, i):
        return self.loops[i]
    
    def __len__(self):
        return self.nsamples
    

In [98]:
dataset = DatasetLoop(opt.img_size, 10000)

In [99]:
# os.makedirs("images", exist_ok=True)

# parser = argparse.ArgumentParser()
# parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
# parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
# parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
# parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
# parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
# parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
# parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
# parser.add_argument("--channels", type=int, default=1, help="number of image channels")
# parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
# opt = parser.parse_args()
# print(opt)

class Opt:
    def __init__(self):
        self.n_epochs = 10
        self.batch_size = 128
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 8
        self.latent_dim = 20
        self.img_size = 8
        self.channels = 1
        self.sample_interval = 100

opt = Opt()
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
# os.makedirs("../../data/mnist", exist_ok=True)
# dataroot = "data/raw/raw_8_png"
# dataset = dset.ImageFolder(root=dataroot,
#                            transform=transforms.Compose([
#                                transforms.Grayscale(),
#                                transforms.ToTensor(),
#                            ]))

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        
        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))
        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)


[Epoch 0/10] [Batch 0/79] [D loss: 0.626474] [G loss: 0.706296]
[Epoch 0/10] [Batch 1/79] [D loss: 0.570631] [G loss: 0.705580]
[Epoch 0/10] [Batch 2/79] [D loss: 0.522406] [G loss: 0.705029]
[Epoch 0/10] [Batch 3/79] [D loss: 0.487363] [G loss: 0.704383]
[Epoch 0/10] [Batch 4/79] [D loss: 0.460753] [G loss: 0.703659]
[Epoch 0/10] [Batch 5/79] [D loss: 0.440937] [G loss: 0.702783]
[Epoch 0/10] [Batch 6/79] [D loss: 0.419770] [G loss: 0.702260]
[Epoch 0/10] [Batch 7/79] [D loss: 0.409598] [G loss: 0.700072]
[Epoch 0/10] [Batch 8/79] [D loss: 0.399852] [G loss: 0.698955]
[Epoch 0/10] [Batch 9/79] [D loss: 0.392688] [G loss: 0.697356]
[Epoch 0/10] [Batch 10/79] [D loss: 0.388570] [G loss: 0.695473]
[Epoch 0/10] [Batch 11/79] [D loss: 0.384205] [G loss: 0.692455]
[Epoch 0/10] [Batch 12/79] [D loss: 0.381006] [G loss: 0.689505]
[Epoch 0/10] [Batch 13/79] [D loss: 0.379109] [G loss: 0.686606]
[Epoch 0/10] [Batch 14/79] [D loss: 0.377874] [G loss: 0.681837]
[Epoch 0/10] [Batch 15/79] [D loss:

[Epoch 1/10] [Batch 51/79] [D loss: 0.282806] [G loss: 0.929213]
[Epoch 1/10] [Batch 52/79] [D loss: 0.285075] [G loss: 0.924746]
[Epoch 1/10] [Batch 53/79] [D loss: 0.280509] [G loss: 0.919785]
[Epoch 1/10] [Batch 54/79] [D loss: 0.276718] [G loss: 0.938131]
[Epoch 1/10] [Batch 55/79] [D loss: 0.268490] [G loss: 0.968905]
[Epoch 1/10] [Batch 56/79] [D loss: 0.265905] [G loss: 0.985333]
[Epoch 1/10] [Batch 57/79] [D loss: 0.262630] [G loss: 0.978073]
[Epoch 1/10] [Batch 58/79] [D loss: 0.256958] [G loss: 0.995288]
[Epoch 1/10] [Batch 59/79] [D loss: 0.252840] [G loss: 1.016137]
[Epoch 1/10] [Batch 60/79] [D loss: 0.248391] [G loss: 1.044282]
[Epoch 1/10] [Batch 61/79] [D loss: 0.254468] [G loss: 1.027235]
[Epoch 1/10] [Batch 62/79] [D loss: 0.254051] [G loss: 1.010519]
[Epoch 1/10] [Batch 63/79] [D loss: 0.258900] [G loss: 1.000717]
[Epoch 1/10] [Batch 64/79] [D loss: 0.257402] [G loss: 1.017622]
[Epoch 1/10] [Batch 65/79] [D loss: 0.256615] [G loss: 1.014609]
[Epoch 1/10] [Batch 66/79

[Epoch 3/10] [Batch 25/79] [D loss: 0.053580] [G loss: 2.369500]
[Epoch 3/10] [Batch 26/79] [D loss: 0.053778] [G loss: 2.369090]
[Epoch 3/10] [Batch 27/79] [D loss: 0.052457] [G loss: 2.376623]
[Epoch 3/10] [Batch 28/79] [D loss: 0.051094] [G loss: 2.402036]
[Epoch 3/10] [Batch 29/79] [D loss: 0.050422] [G loss: 2.438753]
[Epoch 3/10] [Batch 30/79] [D loss: 0.048711] [G loss: 2.460680]
[Epoch 3/10] [Batch 31/79] [D loss: 0.048146] [G loss: 2.480382]
[Epoch 3/10] [Batch 32/79] [D loss: 0.046464] [G loss: 2.501016]
[Epoch 3/10] [Batch 33/79] [D loss: 0.045299] [G loss: 2.527896]
[Epoch 3/10] [Batch 34/79] [D loss: 0.043213] [G loss: 2.567840]
[Epoch 3/10] [Batch 35/79] [D loss: 0.042387] [G loss: 2.597712]
[Epoch 3/10] [Batch 36/79] [D loss: 0.041884] [G loss: 2.612894]
[Epoch 3/10] [Batch 37/79] [D loss: 0.041345] [G loss: 2.617947]
[Epoch 3/10] [Batch 38/79] [D loss: 0.041069] [G loss: 2.623634]
[Epoch 3/10] [Batch 39/79] [D loss: 0.041015] [G loss: 2.632008]
[Epoch 3/10] [Batch 40/79

[Epoch 4/10] [Batch 73/79] [D loss: 0.013673] [G loss: 3.722352]
[Epoch 4/10] [Batch 74/79] [D loss: 0.013455] [G loss: 3.727642]
[Epoch 4/10] [Batch 75/79] [D loss: 0.013121] [G loss: 3.744153]
[Epoch 4/10] [Batch 76/79] [D loss: 0.013089] [G loss: 3.763997]
[Epoch 4/10] [Batch 77/79] [D loss: 0.012567] [G loss: 3.786461]
[Epoch 4/10] [Batch 78/79] [D loss: 0.012084] [G loss: 3.816429]
[Epoch 5/10] [Batch 0/79] [D loss: 0.012188] [G loss: 3.838733]
[Epoch 5/10] [Batch 1/79] [D loss: 0.012145] [G loss: 3.841089]
[Epoch 5/10] [Batch 2/79] [D loss: 0.012124] [G loss: 3.840488]
[Epoch 5/10] [Batch 3/79] [D loss: 0.012047] [G loss: 3.838442]
[Epoch 5/10] [Batch 4/79] [D loss: 0.011912] [G loss: 3.862681]
[Epoch 5/10] [Batch 5/79] [D loss: 0.011753] [G loss: 3.870749]
[Epoch 5/10] [Batch 6/79] [D loss: 0.011604] [G loss: 3.881207]
[Epoch 5/10] [Batch 7/79] [D loss: 0.011643] [G loss: 3.887780]
[Epoch 5/10] [Batch 8/79] [D loss: 0.011656] [G loss: 3.896701]
[Epoch 5/10] [Batch 9/79] [D loss:

[Epoch 6/10] [Batch 42/79] [D loss: 0.005677] [G loss: 4.590944]
[Epoch 6/10] [Batch 43/79] [D loss: 0.005697] [G loss: 4.592676]
[Epoch 6/10] [Batch 44/79] [D loss: 0.005531] [G loss: 4.604144]
[Epoch 6/10] [Batch 45/79] [D loss: 0.005612] [G loss: 4.593297]
[Epoch 6/10] [Batch 46/79] [D loss: 0.005541] [G loss: 4.623785]
[Epoch 6/10] [Batch 47/79] [D loss: 0.005423] [G loss: 4.634469]
[Epoch 6/10] [Batch 48/79] [D loss: 0.005538] [G loss: 4.612797]
[Epoch 6/10] [Batch 49/79] [D loss: 0.005438] [G loss: 4.622783]
[Epoch 6/10] [Batch 50/79] [D loss: 0.005299] [G loss: 4.646350]
[Epoch 6/10] [Batch 51/79] [D loss: 0.005444] [G loss: 4.635669]
[Epoch 6/10] [Batch 52/79] [D loss: 0.005472] [G loss: 4.631164]
[Epoch 6/10] [Batch 53/79] [D loss: 0.005425] [G loss: 4.637330]
[Epoch 6/10] [Batch 54/79] [D loss: 0.005540] [G loss: 4.605769]
[Epoch 6/10] [Batch 55/79] [D loss: 0.005553] [G loss: 4.615571]
[Epoch 6/10] [Batch 56/79] [D loss: 0.005430] [G loss: 4.621800]
[Epoch 6/10] [Batch 57/79

[Epoch 8/10] [Batch 11/79] [D loss: 0.003098] [G loss: 5.175171]
[Epoch 8/10] [Batch 12/79] [D loss: 0.003076] [G loss: 5.174009]
[Epoch 8/10] [Batch 13/79] [D loss: 0.003015] [G loss: 5.187126]
[Epoch 8/10] [Batch 14/79] [D loss: 0.003009] [G loss: 5.198951]
[Epoch 8/10] [Batch 15/79] [D loss: 0.003025] [G loss: 5.196106]
[Epoch 8/10] [Batch 16/79] [D loss: 0.003019] [G loss: 5.199329]
[Epoch 8/10] [Batch 17/79] [D loss: 0.002992] [G loss: 5.221085]
[Epoch 8/10] [Batch 18/79] [D loss: 0.002970] [G loss: 5.212900]
[Epoch 8/10] [Batch 19/79] [D loss: 0.002992] [G loss: 5.205290]
[Epoch 8/10] [Batch 20/79] [D loss: 0.002985] [G loss: 5.212064]
[Epoch 8/10] [Batch 21/79] [D loss: 0.002946] [G loss: 5.220957]
[Epoch 8/10] [Batch 22/79] [D loss: 0.002953] [G loss: 5.214458]
[Epoch 8/10] [Batch 23/79] [D loss: 0.002943] [G loss: 5.227901]
[Epoch 8/10] [Batch 24/79] [D loss: 0.002944] [G loss: 5.227380]
[Epoch 8/10] [Batch 25/79] [D loss: 0.002884] [G loss: 5.235212]
[Epoch 8/10] [Batch 26/79

[Epoch 9/10] [Batch 61/79] [D loss: 0.002049] [G loss: 5.604762]
[Epoch 9/10] [Batch 62/79] [D loss: 0.002082] [G loss: 5.593681]
[Epoch 9/10] [Batch 63/79] [D loss: 0.002089] [G loss: 5.581590]
[Epoch 9/10] [Batch 64/79] [D loss: 0.002115] [G loss: 5.564958]
[Epoch 9/10] [Batch 65/79] [D loss: 0.002132] [G loss: 5.567775]
[Epoch 9/10] [Batch 66/79] [D loss: 0.002140] [G loss: 5.562744]
[Epoch 9/10] [Batch 67/79] [D loss: 0.002141] [G loss: 5.574406]
[Epoch 9/10] [Batch 68/79] [D loss: 0.002171] [G loss: 5.548184]
[Epoch 9/10] [Batch 69/79] [D loss: 0.002132] [G loss: 5.563546]
[Epoch 9/10] [Batch 70/79] [D loss: 0.002136] [G loss: 5.559607]
[Epoch 9/10] [Batch 71/79] [D loss: 0.002160] [G loss: 5.545588]
[Epoch 9/10] [Batch 72/79] [D loss: 0.002119] [G loss: 5.576668]
[Epoch 9/10] [Batch 73/79] [D loss: 0.002079] [G loss: 5.597938]
[Epoch 9/10] [Batch 74/79] [D loss: 0.002132] [G loss: 5.582868]
[Epoch 9/10] [Batch 75/79] [D loss: 0.002112] [G loss: 5.577043]
[Epoch 9/10] [Batch 76/79

In [None]:
print(gen_imgs.shape)

def change_a_value(arr):
    arr[2, 2] = 10
    return arr

imgs = np.array([change_a_value(dataset[i][0]) for i in range(20)]).astype('float')
a = discriminator(torch.tensor(imgs, dtype=torch.float))
a

In [100]:
from lib.loop_gen import *

N = 8

def renormalize(arr):
    loop = np.zeros((N, N))
    base_elts = arr[0]
    for i in range(N):
        for j in range(N):
            loop[i, j] = np.argmin(np.abs(arr[i, j] - base_elts))
    return loop

for tensor in gen_imgs:
    arr = tensor.detach().numpy()[0]
    loop_arr = renormalize(arr)
    my_loop = LoopModel(loop_arr.astype('int'))
    print("It's a loop :", my_loop.is_loop())

It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
It's a loop : False
