In [1]:
gpu_info = !nvidia-smi
gpu_info = gpu_info[:10]
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Wed Sep 23 20:04:02 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.66       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8     7W /  75W |      0MiB /  7611MiB |      0%      Default |


In [2]:
import sys
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm_notebook as tq

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR

In [3]:
base_path = "/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/"
sys.path.append(base_path)
torch.manual_seed(123)

<torch._C.Generator at 0x7feea4832210>

In [4]:
from src.data_utils import CustomDataLoader, SavePath
from src.model import Encoder, Decoder, Discriminator
from src.config import TrainConfig

In [5]:
args = TrainConfig( base_path,              # project directory path
                    n_epochs = 100,         # number of epochs to train (default: 100)
                    batch_size = 128,       # input batch size for training (default: 128)
                    lr = 0.0001,            # learning rate (default: 0.0001)
                    dim_h = 128,            # hidden dimension (default: 128)')
                    n_z = 8,                # hidden dimension of z (default: 8)
                    LAMBDA = 10,            # regularization coef term (default: 10)
                    sigma = 1,              # variance of hidden dimension (default: 1)
                    n_channel = 1,          # input channels (default: 1)
                    img_size = 28 )         # image size

In [None]:
# sp = SavePath(args, checkpoint_path=None)

/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Fri-Sep-18-19-47-27-2020/


In [None]:
# cdl = CustomDataLoader(args)
# train_loader = cdl.get_data_loader(train=True)
# test_loader = cdl.get_data_loader(train=False)

In [None]:
# criterion = nn.MSELoss()

In [None]:
# ae = AutoEncoderGAN(args, criterion, torch.cuda.is_available(), 
                    # train_loader, test_loader, optimizer="Adam")

In [None]:
# ae.train(out_frequency=1, save_model_frequency=25, save_paths=sp)

In [6]:
sp = SavePath(args, checkpoint_path="/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Fri-Sep-18-19-57-30-2020/")

/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Fri-Sep-18-19-57-30-2020/


In [7]:
def unfreeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def freeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

In [8]:
def save_models(model_path, epoch_no, encoder, decoder, discriminator):
    print("Saving models")
    if encoder:
        torch.save(encoder.state_dict(), model_path + "/encoder_" + "%d.pth" % epoch_no)
    if decoder:
        torch.save(decoder.state_dict(), model_path + "/decoder_" + "%d.pth" % epoch_no)
    if discriminator:
        torch.save(discriminator.state_dict(), model_path + "/discriminator_" + "%d.pth" % epoch_no)

def save_lists(list_path, epoch_no, reconstr_loss):
    print("Saving list")
    if reconstr_loss:
        np.savetxt(list_path + '/reconstr_loss_' + '%d.txt'% epoch_no, reconstr_loss)

In [9]:
cdl = CustomDataLoader(args)
train_loader = cdl.get_data_loader(train=True)
test_loader = cdl.get_data_loader(train=False)

In [10]:
encoder, decoder, discriminator = Encoder(args), Decoder(args), Discriminator(args)
criterion = nn.MSELoss()

encoder.train()
decoder.train()
discriminator.train()

Discriminator(
  (main): Sequential(
    (0): Linear(in_features=8, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU(inplace=True)
    (8): Linear(in_features=512, out_features=1, bias=True)
    (9): Sigmoid()
  )
)

In [11]:
# Optimizers
enc_optim = optim.Adam(encoder.parameters(), lr = args.lr)
dec_optim = optim.Adam(decoder.parameters(), lr = args.lr)
dis_optim = optim.Adam(discriminator.parameters(), lr = 0.5 * args.lr)

enc_scheduler = StepLR(enc_optim, step_size=30, gamma=0.5)
dec_scheduler = StepLR(dec_optim, step_size=30, gamma=0.5)
dis_scheduler = StepLR(dis_optim, step_size=30, gamma=0.5)

In [12]:
if torch.cuda.is_available():
    encoder, decoder, discriminator = encoder.cuda(), decoder.cuda(), discriminator.cuda()

In [13]:
one = torch.tensor(1)
mone = one * -1

In [14]:
if torch.cuda.is_available():
    one, mone = one.cuda(), mone.cuda()

In [15]:
checkpoint = 200

_, list_path, model_path = sp.get_save_paths()

encoder.load_state_dict(torch.load(
                model_path + "/encoder_{}.pth".format(checkpoint)))
decoder.load_state_dict(torch.load(
                model_path + "/decoder_{}.pth".format(checkpoint)))
discriminator.load_state_dict(torch.load(
        model_path + "/discriminator_{}.pth".format(checkpoint)))
reconstr_loss_epoch = np.loadtxt(
    list_path + '/reconstr_loss_{}.txt'.format(checkpoint)).tolist()

In [16]:
image_path, list_path, model_path = sp.get_save_paths()

In [17]:
# reconstr_loss_epoch = []
reconstr_loss = []

In [18]:
for epoch in range(checkpoint, checkpoint+args.n_epochs):
    for step, (images, _) in tq(enumerate(train_loader)):
        
        reconstr_loss.clear() 
        # discriminator_loss.clear()
        
        if torch.cuda.is_available():
            images = images.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        discriminator.zero_grad()

        # ======== Train Discriminator ======== #

        freeze_params(decoder)
        freeze_params(encoder)
        unfreeze_params(discriminator)

        z_fake = torch.randn(images.size()[0], args.n_z) * args.sigma

        if torch.cuda.is_available():
            z_fake = z_fake.cuda()

        d_fake = discriminator(z_fake)

        z_real = encoder(images)
        d_real = discriminator(z_real)

        torch.log(d_fake).mean().backward(mone)
        torch.log(1 - d_real).mean().backward(mone)

        dis_optim.step()

        # ======== Train Generator ======== #

        unfreeze_params(decoder)
        unfreeze_params(encoder)
        freeze_params(discriminator)

        batch_size = images.size()[0]

        z_real = encoder(images)
        x_recon = decoder(z_real)
        d_real = discriminator(encoder(Variable(images.data)))

        recon_loss = criterion(x_recon, images)
        d_loss = args.LAMBDA * (torch.log(d_real)).mean()

        recon_loss.backward(one)
        d_loss.backward(mone)

        enc_optim.step()
        dec_optim.step()

        reconstr_loss.append(recon_loss.data.item())

    if (epoch + 1) % 1 == 0:
        print("Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.4f" %
                (epoch + 1, args.n_epochs, step + 1, len(train_loader), recon_loss.data.item()))

    if (epoch + 1) % 1 == 0:

        reconstr_loss_epoch.append(np.mean(reconstr_loss))

        batch_size = 128
        test_iter = iter(test_loader)
        test_data = next(test_iter)

        z_real = encoder(Variable(test_data[0]).cuda())
        reconst = decoder(z_real).cpu().view(batch_size, 1, 28, 28)
        sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 1, 28, 28)

        image = torch.cat((test_data[0].view(batch_size, 1, 28, 28), 
                                reconst.data), axis=3)
        
        save_image(image, image_path +
                                '/inputs_reconstr_{}.png'.format(epoch+1))
        
        save_image(sample, image_path + 
                                '/sample_{}.png'.format(epoch+1))

        # save_image(test_data[0].view(batch_size, 1, 28, 28), image_path + '/wae_gan_input.png')
        # save_image(reconst.data, image_path + '/wae_gan_images_%d.png' % (epoch + 1))
    
    if (epoch + 1) % 25 == 0:
        save_models(model_path, epoch+1, encoder, decoder, discriminator)
        save_lists(list_path, epoch+1, reconstr_loss_epoch)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [201/100], Step: [469/469], Reconstruction Loss: 0.0250


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [202/100], Step: [469/469], Reconstruction Loss: 0.0297


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [203/100], Step: [469/469], Reconstruction Loss: 0.0241


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [204/100], Step: [469/469], Reconstruction Loss: 0.0260


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [205/100], Step: [469/469], Reconstruction Loss: 0.0236


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [206/100], Step: [469/469], Reconstruction Loss: 0.0265


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [207/100], Step: [469/469], Reconstruction Loss: 0.0294


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [208/100], Step: [469/469], Reconstruction Loss: 0.0240


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [209/100], Step: [469/469], Reconstruction Loss: 0.0257


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [210/100], Step: [469/469], Reconstruction Loss: 0.0244


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [211/100], Step: [469/469], Reconstruction Loss: 0.0252


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [212/100], Step: [469/469], Reconstruction Loss: 0.0225


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [213/100], Step: [469/469], Reconstruction Loss: 0.0230


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [214/100], Step: [469/469], Reconstruction Loss: 0.0282


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [215/100], Step: [469/469], Reconstruction Loss: 0.0238


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [216/100], Step: [469/469], Reconstruction Loss: 0.0228


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [217/100], Step: [469/469], Reconstruction Loss: 0.0218


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [218/100], Step: [469/469], Reconstruction Loss: 0.0246


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [219/100], Step: [469/469], Reconstruction Loss: 0.0217


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [220/100], Step: [469/469], Reconstruction Loss: 0.0244


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [221/100], Step: [469/469], Reconstruction Loss: 0.0224


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [222/100], Step: [469/469], Reconstruction Loss: 0.0231


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [223/100], Step: [469/469], Reconstruction Loss: 0.0254


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [224/100], Step: [469/469], Reconstruction Loss: 0.0267


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [225/100], Step: [469/469], Reconstruction Loss: 0.0273
Saving models
Saving list


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [226/100], Step: [469/469], Reconstruction Loss: 0.0232


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [227/100], Step: [469/469], Reconstruction Loss: 0.0241


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [228/100], Step: [469/469], Reconstruction Loss: 0.0225


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [229/100], Step: [469/469], Reconstruction Loss: 0.0247


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [230/100], Step: [469/469], Reconstruction Loss: 0.0266


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [231/100], Step: [469/469], Reconstruction Loss: 0.0242


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [232/100], Step: [469/469], Reconstruction Loss: 0.0243


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [233/100], Step: [469/469], Reconstruction Loss: 0.0221


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [234/100], Step: [469/469], Reconstruction Loss: 0.0267


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [235/100], Step: [469/469], Reconstruction Loss: 0.0238


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [236/100], Step: [469/469], Reconstruction Loss: 0.0258


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [237/100], Step: [469/469], Reconstruction Loss: 0.0287


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [238/100], Step: [469/469], Reconstruction Loss: 0.0221


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [239/100], Step: [469/469], Reconstruction Loss: 0.0237


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [240/100], Step: [469/469], Reconstruction Loss: 0.0272


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [241/100], Step: [469/469], Reconstruction Loss: 0.0263


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [242/100], Step: [469/469], Reconstruction Loss: 0.0207


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [243/100], Step: [469/469], Reconstruction Loss: 0.0244


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [244/100], Step: [469/469], Reconstruction Loss: 0.0260


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [245/100], Step: [469/469], Reconstruction Loss: 0.0216


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [246/100], Step: [469/469], Reconstruction Loss: 0.0224


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [247/100], Step: [469/469], Reconstruction Loss: 0.0263


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [248/100], Step: [469/469], Reconstruction Loss: 0.0208


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [249/100], Step: [469/469], Reconstruction Loss: 0.0304


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [250/100], Step: [469/469], Reconstruction Loss: 0.0234
Saving models
Saving list


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [251/100], Step: [469/469], Reconstruction Loss: 0.0224


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [252/100], Step: [469/469], Reconstruction Loss: 0.0223


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [253/100], Step: [469/469], Reconstruction Loss: 0.0244


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [254/100], Step: [469/469], Reconstruction Loss: 0.0224


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [255/100], Step: [469/469], Reconstruction Loss: 0.0237


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [256/100], Step: [469/469], Reconstruction Loss: 0.0278


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [257/100], Step: [469/469], Reconstruction Loss: 0.0241


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [258/100], Step: [469/469], Reconstruction Loss: 0.0319


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [259/100], Step: [469/469], Reconstruction Loss: 0.0230


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [260/100], Step: [469/469], Reconstruction Loss: 0.0247


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [261/100], Step: [469/469], Reconstruction Loss: 0.0207


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [262/100], Step: [469/469], Reconstruction Loss: 0.0264


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [263/100], Step: [469/469], Reconstruction Loss: 0.0271


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [264/100], Step: [469/469], Reconstruction Loss: 0.0223


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [265/100], Step: [469/469], Reconstruction Loss: 0.0246


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [266/100], Step: [469/469], Reconstruction Loss: 0.0238


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [267/100], Step: [469/469], Reconstruction Loss: 0.0211


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [268/100], Step: [469/469], Reconstruction Loss: 0.0234


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [269/100], Step: [469/469], Reconstruction Loss: 0.0216


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [270/100], Step: [469/469], Reconstruction Loss: 0.0241


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [271/100], Step: [469/469], Reconstruction Loss: 0.0260


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [272/100], Step: [469/469], Reconstruction Loss: 0.0229


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [273/100], Step: [469/469], Reconstruction Loss: 0.0229


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [274/100], Step: [469/469], Reconstruction Loss: 0.0211


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [275/100], Step: [469/469], Reconstruction Loss: 0.0224
Saving models
Saving list


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [276/100], Step: [469/469], Reconstruction Loss: 0.0230


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [277/100], Step: [469/469], Reconstruction Loss: 0.0210


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [278/100], Step: [469/469], Reconstruction Loss: 0.0219


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [279/100], Step: [469/469], Reconstruction Loss: 0.0215


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [280/100], Step: [469/469], Reconstruction Loss: 0.0271


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [281/100], Step: [469/469], Reconstruction Loss: 0.0222


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [282/100], Step: [469/469], Reconstruction Loss: 0.0230


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [283/100], Step: [469/469], Reconstruction Loss: 0.0217


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [284/100], Step: [469/469], Reconstruction Loss: 0.0224


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [285/100], Step: [469/469], Reconstruction Loss: 0.0203


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [286/100], Step: [469/469], Reconstruction Loss: 0.0203


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [287/100], Step: [469/469], Reconstruction Loss: 0.0220


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [288/100], Step: [469/469], Reconstruction Loss: 0.0233


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [289/100], Step: [469/469], Reconstruction Loss: 0.0223


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [290/100], Step: [469/469], Reconstruction Loss: 0.0227


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [291/100], Step: [469/469], Reconstruction Loss: 0.0299


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [292/100], Step: [469/469], Reconstruction Loss: 0.0215


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [293/100], Step: [469/469], Reconstruction Loss: 0.0227


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [294/100], Step: [469/469], Reconstruction Loss: 0.0214


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [295/100], Step: [469/469], Reconstruction Loss: 0.0225


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [296/100], Step: [469/469], Reconstruction Loss: 0.0210


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [297/100], Step: [469/469], Reconstruction Loss: 0.0226


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [298/100], Step: [469/469], Reconstruction Loss: 0.0217


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [299/100], Step: [469/469], Reconstruction Loss: 0.0208


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [300/100], Step: [469/469], Reconstruction Loss: 0.0232
Saving models
Saving list


In [None]:
z_real = encoder(Variable(test_data[0]).cuda())        
reconst = decoder(z_real).cpu().view(batch_size, 1, 28, 28)
sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 1, 28, 28)

RuntimeError: ignored

In [None]:
image = torch.cat((test_data[0].view(batch_size, 1, 28, 28), 
                                reconst.data), axis=3)

In [None]:
image.shape

torch.Size([128, 1, 28, 56])

In [None]:
!pwd

/content


In [None]:
 save_image(image, 'inputs_reconstr_{}.png'.format(epoch+1))

In [None]:
class Nested:
    def __init__(self, nestedList):
            
            self.curr = nestedList
            self.stack, self.index = [], 0
            
    def flatten(self, curr, stack, index):

        if index < len(curr):
            if type(curr[index]) is int:
                # return curr[index], curr, stack, index+1
                retval = curr[index]
                index += 1
            else:
                stack.append((curr, index+1))
                curr = curr[index]
                index = 0
                retval = None
                
        elif stack:
            curr, index = stack.pop()
            retval = None
        else:
            curr = None
            retval = None

        return retval, curr, stack, index

    def value_generator(self):
        while self.curr or self.stack:
            val, self.curr, self.stack, self.index = self.flatten(self.curr, self.stack, self.index)
            print(val, self.curr, self.stack, self.index)
            if val:
                # print(val, self.curr, self.stack, self.index)
                yield val


In [None]:
nestedList = [1, [2, [3]],4,[5,6],7]

In [None]:
nested = Nested(nestedList)

In [None]:
while 1:
    try:
        next(nested.value_generator())
    except StopIteration:
        print(False)
        break

1 [1, [2, [3]], 4, [5, 6], 7] [] 1
None [2, [3]] [([1, [2, [3]], 4, [5, 6], 7], 2)] 0
2 [2, [3]] [([1, [2, [3]], 4, [5, 6], 7], 2)] 1
None [3] [([1, [2, [3]], 4, [5, 6], 7], 2), ([2, [3]], 2)] 0
3 [3] [([1, [2, [3]], 4, [5, 6], 7], 2), ([2, [3]], 2)] 1
None [2, [3]] [([1, [2, [3]], 4, [5, 6], 7], 2)] 2
None [1, [2, [3]], 4, [5, 6], 7] [] 2
4 [1, [2, [3]], 4, [5, 6], 7] [] 3
None [5, 6] [([1, [2, [3]], 4, [5, 6], 7], 4)] 0
5 [5, 6] [([1, [2, [3]], 4, [5, 6], 7], 4)] 1
6 [5, 6] [([1, [2, [3]], 4, [5, 6], 7], 4)] 2
None [1, [2, [3]], 4, [5, 6], 7] [] 4
7 [1, [2, [3]], 4, [5, 6], 7] [] 5
None None [] 5
False
