In [None]:
gpu_info = !nvidia-smi
gpu_info = gpu_info[:10]
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Wed Sep 30 19:42:26 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm_notebook as tq

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR

In [None]:
base_path = "/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/"
sys.path.append(base_path)
torch.manual_seed(123)

<torch._C.Generator at 0x7f456f1da678>

In [None]:
from src.data_utils import LoadDataset, SavePath
from src.model import Encoder, Decoder, Discriminator
from src.config import TrainConfig

In [None]:
args = TrainConfig( base_path,              # project directory path
                    n_epochs = 100,         # number of epochs to train (default: 100)
                    batch_size = 128,       # input batch size for training (default: 128)
                    lr = 0.0002,            # learning rate (default: 0.0001)
                    dim_h = 128,            # hidden dimension (default: 128)')
                    n_z = 8,                # hidden dimension of z (default: 8)
                    LAMBDA = 10,            # regularization coef term (default: 10)
                    sigma = 1,              # variance of hidden dimension (default: 1)
                    n_channel = 1,          # input channels (default: 1)
                    img_size = 28 )         # image size

In [None]:
# sp = SavePath(args, checkpoint_path=None)

/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Fri-Sep-18-19-47-27-2020/


In [None]:
# cdl = CustomDataLoader(args)
# train_loader = cdl.get_data_loader(train=True)
# test_loader = cdl.get_data_loader(train=False)

In [None]:
# criterion = nn.MSELoss()

In [None]:
# ae = AutoEncoderGAN(args, criterion, torch.cuda.is_available(), 
                    # train_loader, test_loader, optimizer="Adam")

In [None]:
# ae.train(out_frequency=1, save_model_frequency=25, save_paths=sp)

In [None]:
sp = SavePath(args, checkpoint_path="/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Wed-Sep-30-18-59-15-2020/")
# sp = SavePath(args)

/content/drive/My Drive/UC Davis Synthetic Data/Prashanth's/Autoencoder/Autoencoder/outs/Wed-Sep-30-18-59-15-2020/


In [None]:
def unfreeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def freeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

In [None]:
def save_models(model_path, epoch_no, encoder, decoder, discriminator):
    print("Saving models")
    if encoder:
        torch.save(encoder.state_dict(), model_path + "/encoder_" + "%d.pth" % epoch_no)
    if decoder:
        torch.save(decoder.state_dict(), model_path + "/decoder_" + "%d.pth" % epoch_no)
    if discriminator:
        torch.save(discriminator.state_dict(), model_path + "/discriminator_" + "%d.pth" % epoch_no)

def save_lists(list_path, epoch_no, reconstr_loss):
    print("Saving list")
    if reconstr_loss:
        np.savetxt(list_path + '/reconstr_loss_' + '%d.txt'% epoch_no, reconstr_loss)

In [None]:
cdl = LoadDataset(args)
train_loader = cdl.get_data_loader(train=True)
test_loader = cdl.get_data_loader(train=False)

In [None]:
encoder, decoder, discriminator = Encoder(args), Decoder(args), Discriminator(args)
criterion = nn.MSELoss()

encoder.train()
decoder.train()
discriminator.train()

# Optimizers
enc_optim = optim.Adam(encoder.parameters(), lr = args.lr)
dec_optim = optim.Adam(decoder.parameters(), lr = args.lr)
dis_optim = optim.Adam(discriminator.parameters(), lr = 0.5 * args.lr)

enc_scheduler = StepLR(enc_optim, step_size=30, gamma=0.5)
dec_scheduler = StepLR(dec_optim, step_size=30, gamma=0.5)
dis_scheduler = StepLR(dis_optim, step_size=30, gamma=0.5)

In [None]:
if torch.cuda.is_available():
    encoder, decoder, discriminator = encoder.cuda(), decoder.cuda(), discriminator.cuda()

In [None]:
one = torch.tensor(1)
mone = one * -1
if torch.cuda.is_available():
    one, mone = one.cuda(), mone.cuda()

In [29]:
checkpoint = 200

if checkpoint:
    _, list_path, model_path = sp.get_save_paths()

    encoder.load_state_dict(torch.load(
                    model_path + "/encoder_{}.pth".format(checkpoint)))
    decoder.load_state_dict(torch.load(
                    model_path + "/decoder_{}.pth".format(checkpoint)))
    discriminator.load_state_dict(torch.load(
            model_path + "/discriminator_{}.pth".format(checkpoint)))
    reconstr_loss_epoch = np.loadtxt(
        list_path + '/reconstr_loss_{}.txt'.format(checkpoint)).tolist()

if not checkpoint:
    reconstr_loss_epoch = []
reconstr_loss = []
disc_loss = []
d_real_val = []

image_path, list_path, model_path = sp.get_save_paths()

In [30]:
for epoch in range(checkpoint, checkpoint+args.n_epochs):
    for step, (images, _) in tq(enumerate(train_loader)):
        
        reconstr_loss.clear() 
        # discriminator_loss.clear()
        
        if torch.cuda.is_available():
            images = images.cuda()

        encoder.zero_grad()
        decoder.zero_grad()
        discriminator.zero_grad()

        # ======== Train Discriminator ======== #

        freeze_params(decoder)
        freeze_params(encoder)
        unfreeze_params(discriminator)

        z_fake = torch.randn(images.size()[0], args.n_z) * args.sigma

        if torch.cuda.is_available():
            z_fake = z_fake.cuda()

        d_fake = discriminator(z_fake)

        z_real = encoder(images)
        d_real = discriminator(z_real)

        # df = round(, 3)
        # dr = round(torch.true_divide((d_real > 0.5).sum(), d_real.shape[0]).item(), 3)

        # negate for gradient ascent
        d_loss = -(torch.log(d_fake).mean() + torch.log(1 - d_real).mean())

        disc_loss.append([d_fake.mean().item(), d_real.mean().item(), d_loss.item()])

        d_loss.backward()

        dis_optim.step()

        if (step + 1) % 10 == 0:

            # ======== Train Generator ======== #

            unfreeze_params(decoder)
            unfreeze_params(encoder)
            freeze_params(discriminator)

            batch_size = images.size()[0]

            z_real = encoder(images)
            x_recon = decoder(z_real)
            d_real = discriminator(encoder(Variable(images.data)))

            d_real_val.append((torch.log(d_real)).mean().item())
            # recon_loss = criterion(x_recon, images) - (args.LAMBDA * (torch.log(d_real)).mean())

            recon_loss = criterion(x_recon, images)
            d_loss = args.LAMBDA * (torch.log(d_real)).mean()

            recon_loss.backward(one)
            d_loss.backward(mone)

            enc_optim.step()
            dec_optim.step()

            reconstr_loss.append(recon_loss.data.item())

    if (epoch + 1) % 1 == 0:
        # print("Epoch: [%d/%d], Step: [%d/%d], Reconstruction Loss: %.4f" %
        #         (epoch + 1, args.n_epochs, step + 1, len(train_loader), recon_loss.data.item()))
        temp = np.mean(disc_loss, axis=0)
        print("Epoch: [%d/%d], Recon Loss: [%.4f], Reg loss: [%.4f], D_fake: [%.4f], D_real: [%.4f], D_loss: [%.4f]" %
                (epoch + 1, args.n_epochs, 
                 np.mean(reconstr_loss), np.mean(d_real_val), temp[0], temp[1], temp[2]))
        disc_loss.clear()
        d_real_val.clear()

    if (epoch + 1) % 1 == 0:

        reconstr_loss_epoch.append(np.mean(reconstr_loss))
        reconstr_loss.clear()

        batch_size = 128
        test_iter = iter(test_loader)
        test_data = next(test_iter)

        z_real = encoder(Variable(test_data[0]).cuda())
        reconst = decoder(z_real).cpu().view(batch_size, 1, 28, 28)
        sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 1, 28, 28)

        image = torch.cat((test_data[0].view(batch_size, 1, 28, 28), 
                                reconst.data), axis=3)
        
        save_image(image, image_path +
                                '/inputs_reconstr_{}.png'.format(epoch+1))
        
        save_image(sample, image_path + 
                                '/sample_{}.png'.format(epoch+1))

        # save_image(test_data[0].view(batch_size, 1, 28, 28), image_path + '/wae_gan_input.png')
        # save_image(reconst.data, image_path + '/wae_gan_images_%d.png' % (epoch + 1))
    
    if (epoch + 1) % 25 == 0:
        save_models(model_path, epoch+1, encoder, decoder, discriminator)
        save_lists(list_path, epoch+1, reconstr_loss_epoch)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [201/100], Recon Loss: [nan], Reg loss: [-0.7097], D_fake: [0.5062], D_real: [0.4949], D_loss: [1.3749]


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [202/100], Recon Loss: [nan], Reg loss: [-0.7193], D_fake: [0.5050], D_real: [0.4917], D_loss: [1.3733]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [203/100], Recon Loss: [nan], Reg loss: [-0.7117], D_fake: [0.5058], D_real: [0.4931], D_loss: [1.3740]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [204/100], Recon Loss: [nan], Reg loss: [-0.7125], D_fake: [0.5040], D_real: [0.4944], D_loss: [1.3777]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [205/100], Recon Loss: [nan], Reg loss: [-0.7043], D_fake: [0.5047], D_real: [0.4966], D_loss: [1.3780]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [206/100], Recon Loss: [nan], Reg loss: [-0.7131], D_fake: [0.5054], D_real: [0.4941], D_loss: [1.3751]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [207/100], Recon Loss: [nan], Reg loss: [-0.7140], D_fake: [0.5052], D_real: [0.4938], D_loss: [1.3752]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [208/100], Recon Loss: [nan], Reg loss: [-0.7067], D_fake: [0.5059], D_real: [0.4956], D_loss: [1.3762]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [209/100], Recon Loss: [nan], Reg loss: [-0.7097], D_fake: [0.5073], D_real: [0.4944], D_loss: [1.3732]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [210/100], Recon Loss: [nan], Reg loss: [-0.7163], D_fake: [0.5057], D_real: [0.4926], D_loss: [1.3737]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [211/100], Recon Loss: [nan], Reg loss: [-0.7135], D_fake: [0.5048], D_real: [0.4940], D_loss: [1.3757]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [212/100], Recon Loss: [nan], Reg loss: [-0.7180], D_fake: [0.5068], D_real: [0.4921], D_loss: [1.3715]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [213/100], Recon Loss: [nan], Reg loss: [-0.7144], D_fake: [0.5065], D_real: [0.4929], D_loss: [1.3728]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [214/100], Recon Loss: [nan], Reg loss: [-0.7098], D_fake: [0.5070], D_real: [0.4949], D_loss: [1.3744]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [215/100], Recon Loss: [nan], Reg loss: [-0.7133], D_fake: [0.5076], D_real: [0.4942], D_loss: [1.3735]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [216/100], Recon Loss: [nan], Reg loss: [-0.7144], D_fake: [0.5067], D_real: [0.4942], D_loss: [1.3740]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [217/100], Recon Loss: [nan], Reg loss: [-0.7156], D_fake: [0.5032], D_real: [0.4932], D_loss: [1.3761]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [218/100], Recon Loss: [nan], Reg loss: [-0.7192], D_fake: [0.5061], D_real: [0.4917], D_loss: [1.3724]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [219/100], Recon Loss: [nan], Reg loss: [-0.7131], D_fake: [0.5052], D_real: [0.4934], D_loss: [1.3750]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [220/100], Recon Loss: [nan], Reg loss: [-0.7131], D_fake: [0.5074], D_real: [0.4946], D_loss: [1.3735]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [221/100], Recon Loss: [nan], Reg loss: [-0.7155], D_fake: [0.5060], D_real: [0.4926], D_loss: [1.3731]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [222/100], Recon Loss: [nan], Reg loss: [-0.7145], D_fake: [0.5073], D_real: [0.4925], D_loss: [1.3718]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [223/100], Recon Loss: [nan], Reg loss: [-0.7131], D_fake: [0.5074], D_real: [0.4944], D_loss: [1.3734]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [224/100], Recon Loss: [nan], Reg loss: [-0.7102], D_fake: [0.5058], D_real: [0.4944], D_loss: [1.3755]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [225/100], Recon Loss: [nan], Reg loss: [-0.7173], D_fake: [0.5064], D_real: [0.4924], D_loss: [1.3725]
Saving models
Saving list


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [226/100], Recon Loss: [nan], Reg loss: [-0.7171], D_fake: [0.5059], D_real: [0.4929], D_loss: [1.3739]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [227/100], Recon Loss: [nan], Reg loss: [-0.7115], D_fake: [0.5055], D_real: [0.4947], D_loss: [1.3754]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [228/100], Recon Loss: [nan], Reg loss: [-0.7125], D_fake: [0.5081], D_real: [0.4942], D_loss: [1.3730]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [229/100], Recon Loss: [nan], Reg loss: [-0.7215], D_fake: [0.5028], D_real: [0.4911], D_loss: [1.3749]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [230/100], Recon Loss: [nan], Reg loss: [-0.7202], D_fake: [0.5087], D_real: [0.4924], D_loss: [1.3709]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [231/100], Recon Loss: [nan], Reg loss: [-0.7150], D_fake: [0.5072], D_real: [0.4927], D_loss: [1.3719]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [232/100], Recon Loss: [nan], Reg loss: [-0.7227], D_fake: [0.5059], D_real: [0.4917], D_loss: [1.3730]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [233/100], Recon Loss: [nan], Reg loss: [-0.7136], D_fake: [0.5072], D_real: [0.4941], D_loss: [1.3730]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [234/100], Recon Loss: [nan], Reg loss: [-0.7154], D_fake: [0.5066], D_real: [0.4937], D_loss: [1.3739]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [235/100], Recon Loss: [nan], Reg loss: [-0.7129], D_fake: [0.5035], D_real: [0.4939], D_loss: [1.3766]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [236/100], Recon Loss: [nan], Reg loss: [-0.7112], D_fake: [0.5051], D_real: [0.4936], D_loss: [1.3748]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [237/100], Recon Loss: [nan], Reg loss: [-0.7066], D_fake: [0.5058], D_real: [0.4959], D_loss: [1.3766]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [238/100], Recon Loss: [nan], Reg loss: [-0.7122], D_fake: [0.5047], D_real: [0.4941], D_loss: [1.3758]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [239/100], Recon Loss: [nan], Reg loss: [-0.7082], D_fake: [0.5063], D_real: [0.4951], D_loss: [1.3751]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [240/100], Recon Loss: [nan], Reg loss: [-0.7117], D_fake: [0.5053], D_real: [0.4941], D_loss: [1.3754]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [241/100], Recon Loss: [nan], Reg loss: [-0.7118], D_fake: [0.5048], D_real: [0.4933], D_loss: [1.3745]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [242/100], Recon Loss: [nan], Reg loss: [-0.7169], D_fake: [0.5064], D_real: [0.4930], D_loss: [1.3739]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [243/100], Recon Loss: [nan], Reg loss: [-0.7129], D_fake: [0.5070], D_real: [0.4947], D_loss: [1.3741]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [244/100], Recon Loss: [nan], Reg loss: [-0.7154], D_fake: [0.5057], D_real: [0.4926], D_loss: [1.3732]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [245/100], Recon Loss: [nan], Reg loss: [-0.7070], D_fake: [0.5079], D_real: [0.4965], D_loss: [1.3753]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [246/100], Recon Loss: [nan], Reg loss: [-0.7123], D_fake: [0.5058], D_real: [0.4935], D_loss: [1.3746]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [247/100], Recon Loss: [nan], Reg loss: [-0.7115], D_fake: [0.5068], D_real: [0.4946], D_loss: [1.3737]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [248/100], Recon Loss: [nan], Reg loss: [-0.7148], D_fake: [0.5037], D_real: [0.4934], D_loss: [1.3771]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [249/100], Recon Loss: [nan], Reg loss: [-0.7181], D_fake: [0.5033], D_real: [0.4927], D_loss: [1.3756]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [250/100], Recon Loss: [nan], Reg loss: [-0.7078], D_fake: [0.5063], D_real: [0.4957], D_loss: [1.3762]
Saving models
Saving list


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [251/100], Recon Loss: [nan], Reg loss: [-0.7071], D_fake: [0.5066], D_real: [0.4955], D_loss: [1.3749]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [252/100], Recon Loss: [nan], Reg loss: [-0.7083], D_fake: [0.5049], D_real: [0.4952], D_loss: [1.3768]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [253/100], Recon Loss: [nan], Reg loss: [-0.7082], D_fake: [0.5045], D_real: [0.4948], D_loss: [1.3768]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [254/100], Recon Loss: [nan], Reg loss: [-0.7137], D_fake: [0.5063], D_real: [0.4931], D_loss: [1.3731]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [255/100], Recon Loss: [nan], Reg loss: [-0.7132], D_fake: [0.5063], D_real: [0.4935], D_loss: [1.3737]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [256/100], Recon Loss: [nan], Reg loss: [-0.7122], D_fake: [0.5087], D_real: [0.4948], D_loss: [1.3730]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [257/100], Recon Loss: [nan], Reg loss: [-0.7153], D_fake: [0.5027], D_real: [0.4921], D_loss: [1.3760]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [258/100], Recon Loss: [nan], Reg loss: [-0.7119], D_fake: [0.5059], D_real: [0.4942], D_loss: [1.3743]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [259/100], Recon Loss: [nan], Reg loss: [-0.7059], D_fake: [0.5068], D_real: [0.4964], D_loss: [1.3759]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [260/100], Recon Loss: [nan], Reg loss: [-0.7093], D_fake: [0.5038], D_real: [0.4948], D_loss: [1.3773]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [261/100], Recon Loss: [nan], Reg loss: [-0.7105], D_fake: [0.5064], D_real: [0.4947], D_loss: [1.3746]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [262/100], Recon Loss: [nan], Reg loss: [-0.7170], D_fake: [0.5037], D_real: [0.4926], D_loss: [1.3754]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [263/100], Recon Loss: [nan], Reg loss: [-0.7087], D_fake: [0.5057], D_real: [0.4958], D_loss: [1.3764]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [264/100], Recon Loss: [nan], Reg loss: [-0.7114], D_fake: [0.5043], D_real: [0.4943], D_loss: [1.3765]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [265/100], Recon Loss: [nan], Reg loss: [-0.7137], D_fake: [0.5042], D_real: [0.4934], D_loss: [1.3754]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [266/100], Recon Loss: [nan], Reg loss: [-0.7102], D_fake: [0.5055], D_real: [0.4945], D_loss: [1.3753]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [267/100], Recon Loss: [nan], Reg loss: [-0.7040], D_fake: [0.5050], D_real: [0.4963], D_loss: [1.3776]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [268/100], Recon Loss: [nan], Reg loss: [-0.7081], D_fake: [0.5047], D_real: [0.4950], D_loss: [1.3768]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [269/100], Recon Loss: [nan], Reg loss: [-0.7096], D_fake: [0.5062], D_real: [0.4958], D_loss: [1.3761]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [270/100], Recon Loss: [nan], Reg loss: [-0.7107], D_fake: [0.5051], D_real: [0.4940], D_loss: [1.3753]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [271/100], Recon Loss: [nan], Reg loss: [-0.7140], D_fake: [0.5075], D_real: [0.4947], D_loss: [1.3736]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [272/100], Recon Loss: [nan], Reg loss: [-0.7174], D_fake: [0.5059], D_real: [0.4931], D_loss: [1.3737]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [273/100], Recon Loss: [nan], Reg loss: [-0.7164], D_fake: [0.5054], D_real: [0.4915], D_loss: [1.3725]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [274/100], Recon Loss: [nan], Reg loss: [-0.7144], D_fake: [0.5091], D_real: [0.4946], D_loss: [1.3719]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [275/100], Recon Loss: [nan], Reg loss: [-0.7127], D_fake: [0.5038], D_real: [0.4931], D_loss: [1.3759]
Saving models
Saving list


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [276/100], Recon Loss: [nan], Reg loss: [-0.7069], D_fake: [0.5068], D_real: [0.4967], D_loss: [1.3769]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [277/100], Recon Loss: [nan], Reg loss: [-0.7120], D_fake: [0.5051], D_real: [0.4947], D_loss: [1.3762]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [278/100], Recon Loss: [nan], Reg loss: [-0.7128], D_fake: [0.5037], D_real: [0.4937], D_loss: [1.3770]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [279/100], Recon Loss: [nan], Reg loss: [-0.7082], D_fake: [0.5048], D_real: [0.4956], D_loss: [1.3769]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [280/100], Recon Loss: [nan], Reg loss: [-0.7060], D_fake: [0.5062], D_real: [0.4966], D_loss: [1.3770]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [281/100], Recon Loss: [nan], Reg loss: [-0.7083], D_fake: [0.5035], D_real: [0.4956], D_loss: [1.3785]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [282/100], Recon Loss: [nan], Reg loss: [-0.7108], D_fake: [0.5052], D_real: [0.4941], D_loss: [1.3752]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [283/100], Recon Loss: [nan], Reg loss: [-0.7173], D_fake: [0.5053], D_real: [0.4926], D_loss: [1.3736]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [284/100], Recon Loss: [nan], Reg loss: [-0.7146], D_fake: [0.5048], D_real: [0.4934], D_loss: [1.3758]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [285/100], Recon Loss: [nan], Reg loss: [-0.7092], D_fake: [0.5072], D_real: [0.4952], D_loss: [1.3741]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [286/100], Recon Loss: [nan], Reg loss: [-0.7059], D_fake: [0.5079], D_real: [0.4964], D_loss: [1.3748]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [287/100], Recon Loss: [nan], Reg loss: [-0.7099], D_fake: [0.5061], D_real: [0.4944], D_loss: [1.3745]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [288/100], Recon Loss: [nan], Reg loss: [-0.7114], D_fake: [0.5048], D_real: [0.4945], D_loss: [1.3764]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [289/100], Recon Loss: [nan], Reg loss: [-0.7202], D_fake: [0.5030], D_real: [0.4909], D_loss: [1.3743]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [290/100], Recon Loss: [nan], Reg loss: [-0.7130], D_fake: [0.5082], D_real: [0.4949], D_loss: [1.3731]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [291/100], Recon Loss: [nan], Reg loss: [-0.7122], D_fake: [0.5093], D_real: [0.4950], D_loss: [1.3717]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [292/100], Recon Loss: [nan], Reg loss: [-0.7206], D_fake: [0.5036], D_real: [0.4912], D_loss: [1.3741]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [293/100], Recon Loss: [nan], Reg loss: [-0.7180], D_fake: [0.5044], D_real: [0.4921], D_loss: [1.3744]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [294/100], Recon Loss: [nan], Reg loss: [-0.7142], D_fake: [0.5066], D_real: [0.4938], D_loss: [1.3735]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [295/100], Recon Loss: [nan], Reg loss: [-0.7097], D_fake: [0.5051], D_real: [0.4946], D_loss: [1.3761]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [296/100], Recon Loss: [nan], Reg loss: [-0.7101], D_fake: [0.5059], D_real: [0.4942], D_loss: [1.3749]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [297/100], Recon Loss: [nan], Reg loss: [-0.7104], D_fake: [0.5060], D_real: [0.4956], D_loss: [1.3758]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [298/100], Recon Loss: [nan], Reg loss: [-0.7085], D_fake: [0.5047], D_real: [0.4944], D_loss: [1.3758]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [299/100], Recon Loss: [nan], Reg loss: [-0.7088], D_fake: [0.5031], D_real: [0.4951], D_loss: [1.3785]


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


Epoch: [300/100], Recon Loss: [nan], Reg loss: [-0.7043], D_fake: [0.5043], D_real: [0.4969], D_loss: [1.3791]
Saving models
Saving list


In [None]:
reconstr_loss

[]

In [None]:
z_real = encoder(Variable(test_data[0]).cuda())        
reconst = decoder(z_real).cpu().view(batch_size, 1, 28, 28)
sample = decoder(torch.randn_like(z_real)).cpu().view(batch_size, 1, 28, 28)

RuntimeError: ignored

In [None]:
image = torch.cat((test_data[0].view(batch_size, 1, 28, 28), 
                                reconst.data), axis=3)

In [None]:
image.shape

torch.Size([128, 1, 28, 56])

In [None]:
np.mean(disc_loss, axis=0)

In [None]:
disc_loss[0]

Output hidden; open in https://colab.research.google.com to view.