In [39]:
gpu_info = !nvidia-smi
gpu_info = gpu_info[:10]
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

Mon Mar  8 19:03:05 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN V             On   | 00000000:17:00.0 Off |                  N/A |
| 29%   42C    P8    26W / 250W |   5069MiB / 12066MiB |      0%      Default |


In [40]:
# from google.colab import drive
# drive.mount('/content/drive')

In [41]:
from varname import nameof

import sys
import torch
import os
import numpy as np
import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm_notebook as tq

from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from pytorch_model_summary import summary
from tensorboardX import SummaryWriter

In [42]:
from src.dataset.utils import SavePath
from src.dataset.dataset import Cifar
from src.config import TrainConfig
from src.pytorch_msssim import MSSSIM, ssim

In [43]:
base_path = !pwd
base_path = base_path[0] + '/'

In [44]:
args = TrainConfig( base_path,              # project directory path
                    n_epochs = 200,         # number of epochs to train (default: 100)
                    batch_size = 128,       # input batch size for training (default: 128)
                    lr = 1e-3,            # learning rate (default: 0.0001)
                    dim_h = 128,            # hidden dimension (default: 128)')
                    n_z = 128,                # hidden dimension of z (default: 8)
                    LAMBDA = 10,            # regularization coef term (default: 10)
                    sigma = 1,              # variance of hidden dimension (default: 1)
                    n_channel = 3,          # input channels (default: 1)
                    img_size = 32 )         # image size

In [45]:
def unfreeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = True

def freeze_params(module: nn.Module):
    for p in module.parameters():
        p.requires_grad = False

In [46]:
def save_models(model_path, epoch_no, models):
    print("Saving models")
    for model_name, model in models.items():
        torch.save(model.state_dict(), model_path + model_name + "_" + "%d.pth" % epoch_no)

def save_values_to_tensorboard(writer, epoch_no, values_dict: dict):
    for name, val in values_dict.items():
        if type(val) == dict:
            writer.add_scalars(name, val, epoch_no)
        else:
            writer.add_scalar(name, val, epoch_no)

def save_images_to_tensorboard(writer, epoch_no, image, imname='im'):
    writer.add_image(imname +'_{}'.format(epoch_no), image, epoch_no)

In [47]:
sp = SavePath(args)

/home/pr/synth.data/Autoencoder/outs/Mon-Mar--8-19-03-07-2021/


In [48]:
transform = None # dont normalize

cdl = Cifar(args)
train_loader = cdl.get_data_loader(True, transform, [0,1,2,3,4])
test_loader = cdl.get_data_loader(False, transform, [0,1,2,3,4])

Files already downloaded and verified
Files already downloaded and verified


In [49]:
from src.models.model_cifar import Encoder as ConvEncoder
from src.models.model_cifar import Decoder as ConvDecoder
from src.models.model_cifar import GanDiscriminator

In [50]:
class ZDiscriminator(nn.Module):
    def __init__(self, args):
        super(ZDiscriminator, self).__init__()

        self.n_channel = args.n_channel
        self.dim_h = args.dim_h
        self.n_z = args.n_z

        self.main = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 2),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 2, self.dim_h // 2),
            nn.ReLU(True),
            nn.Linear(self.dim_h // 2, self.dim_h // 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h // 4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.main(x)
        x = x.squeeze()
        return x

In [51]:
conv_encoder, conv_decoder = ConvEncoder(args).cuda(), ConvDecoder(args).cuda()
gan_discriminator = GanDiscriminator(args).cuda()
z_discriminator = ZDiscriminator(args).cuda()

mse_loss_fn = nn.MSELoss().cuda()
adversarial_loss_fn = nn.BCELoss().cuda()

opt_encoder = optim.Adam(conv_encoder.parameters(), lr = args.lr)
opt_decoder = optim.Adam(conv_decoder.parameters(), lr = args.lr)
opt_discriminator = optim.Adam(gan_discriminator.parameters(), lr = args.lr)
opt_zdiscriminator = optim.Adam(z_discriminator.parameters(), lr = args.lr)
scheduler_opt_encoder = ExponentialLR(opt_encoder, gamma=0.1)
scheduler_opt_decoder = ExponentialLR(opt_decoder, gamma=0.1)
scheduler_opt_discriminator = ExponentialLR(opt_discriminator, gamma=0.1)
scheduler_opt_zdiscriminator = ExponentialLR(opt_zdiscriminator, gamma=0.1)

In [52]:
# print(summary(conv_encoder, torch.zeros((1, 3, 32, 32)).cuda(), show_input=True, show_hierarchical=False))

In [53]:
def load_models(checkpoint_path, checkpoint):
    lp = SavePath(args, checkpoint_path)
    _, _, model_load_path = lp.get_save_paths()
    conv_encoder.load_state_dict(torch.load(model_load_path + "/conv_encoder_{}.pth".format(checkpoint)))
    conv_decoder.load_state_dict(torch.load(model_load_path + "/conv_decoder_{}.pth".format(checkpoint)))
    gan_discriminator.load_state_dict(torch.load(model_load_path + "/gan_discriminator_{}.pth".format(checkpoint)))

In [54]:
checkpoint = 0
if checkpoint:
    load_models('/home/pr/synth.data/Autoencoder/outs/Mon-Mar--1-12-35-57-2021/', checkpoint)

reconstr_loss, disc_loss = [], []
norms = []
writer = SummaryWriter(log_dir = sp.results_path + "logs")

In [55]:
image_path, list_path, model_path = sp.get_save_paths()

In [56]:
# one = torch.tensor(1)
# mone = one * -1
# if torch.cuda.is_available():
#     one, mone = one.cuda(), mone.cuda()

In [57]:
for epoch in range(checkpoint, checkpoint+args.n_epochs):
    pbar = tq(enumerate(train_loader))
    for step, (images, _) in pbar:
        
        current_batch_size = images.size()[0]        
        images = images.cuda()

        conv_encoder.zero_grad()
        conv_encoder.zero_grad()
        z_discriminator.zero_grad()
        
        # Adversarial ground truths
        ones = Variable(torch.cuda.FloatTensor(current_batch_size).fill_(1.0), requires_grad=False) # real
        zeros = Variable(torch.cuda.FloatTensor(current_batch_size).fill_(0.0), requires_grad=False) # fake

        # ======== Train Discriminator ======== #

#         freeze_params(decoder)
#         freeze_params(encoder)
        unfreeze_params(z_discriminator)

        z_prior = torch.randn(current_batch_size, args.n_z).cuda()
        d_prior = z_discriminator(z_prior)
        z_real = conv_encoder(images)
        d_real = z_discriminator(z_real.detach())
        
        # save norms of encodings
        norms.append([torch.norm(z_real, dim=1).mean().item()])

        # negate for gradient ascent
#         d_loss = -(torch.log(d_fake).mean() + torch.log(1 - d_real).mean())
#         d_loss = d_prior.mean() + (1 - d_real).mean()
        d_loss = adversarial_loss_fn(d_prior, ones).mean() + adversarial_loss_fn(d_real, zeros).mean()

        disc_loss.append(d_loss.item())

        d_loss.backward()

        opt_zdiscriminator.step()


        # ======== Train Generator ======== #

#         unfreeze_params(decoder)
#         unfreeze_params(encoder)
        freeze_params(z_discriminator)

#         z_real = encoder(images)
        x_recon = conv_decoder(z_real)
        d_real = z_discriminator(conv_encoder(Variable(images.data)))

        # recon_loss = criterion(x_recon, images) - (args.LAMBDA * (torch.log(d_real)).mean())

        recon_loss = 10 * mse_loss_fn(x_recon, images).mean()
        reg_loss = 0.005 * adversarial_loss_fn(d_real, ones).mean()
#         d_loss = args.LAMBDA * (d_real).mean()

        reconstr_loss.append([recon_loss.data.item(), reg_loss.data.item()])

        recon_loss.backward()
        reg_loss.backward()

        opt_encoder.step()
        opt_decoder.step()
        
        pbar.set_description("Recon Loss: [%.3f], Reg loss: [%.3f], disc loss: [%.3f]" % (
            recon_loss.item(), reg_loss.item(), d_loss.item()))

    if (epoch + 1) % 1 == 0:
        reconstr_loss_mean = np.mean(reconstr_loss, axis=0)
        print("Epoch: [%d/%d], Recon Loss: [%.3f], Reg loss: [%.3f], disc_loss: [%.3f], " %
                (epoch + 1, args.n_epochs, 
                 reconstr_loss_mean[0], reconstr_loss_mean[1], np.mean(disc_loss)))
        disc_loss.clear()
        reconstr_loss.clear()

    if (epoch + 1) % 1 == 0:

        batch_size = 64
        test_iter = iter(test_loader)
        test_images, _ = next(test_iter)
        test_images = Variable(test_images)[:64]
        
        z_real = conv_encoder(test_images.cuda()).detach()
        test_recon = conv_decoder(z_real)
        
        test_mse_loss = mse_loss_fn(test_images, test_recon.cpu()).data.item()

        random_sample = conv_decoder(torch.randn_like(z_real)).cpu().view(batch_size, 3, 32, 32)

        test_image_recon = torch.cat((test_images.view(batch_size, 3, 32, 32), 
                                      test_recon.cpu().view(batch_size, 3, 32, 32).data), axis=3)
        
        val_dict = {
        "train_recon_loss": {"recon": reconstr_loss_mean[0],
                            "reg": reconstr_loss_mean[1]
                            },
        "mse_ssim": {"train_mse": reconstr_loss_mean[0], 
                     "test_mse": test_mse_loss,
                     "train_ssim": ssim(images.cpu(), x_recon.cpu()).data.item()
                    },
        "norms": {"real": np.mean(norms),
                  "gaussian": torch.norm(z_prior, dim=1).mean().item()
                 }
        }
        
        norms.clear()
        
        save_values_to_tensorboard(writer, epoch + 1, val_dict)
        
        save_images_to_tensorboard(writer, epoch+1, make_grid(torch.cat((images, x_recon), dim=3)
                                                                           , normalize=False), 'train')
        save_images_to_tensorboard(writer, epoch+1, make_grid(test_image_recon, normalize=False), 'test')
        save_images_to_tensorboard(writer, epoch+1, make_grid(random_sample, normalize=False), 'sample')
        
        

    if (epoch + 1) % 25 == 0:
        models = {nameof(conv_encoder): conv_encoder, 
                  nameof(conv_decoder): conv_decoder}
        save_models(model_path, epoch+1, models)

0it [00:00, ?it/s]

Epoch: [1/200], Recon Loss: [0.650], Reg loss: [0.048], disc_loss: [0.309], 


0it [00:00, ?it/s]

Epoch: [2/200], Recon Loss: [0.623], Reg loss: [0.055], disc_loss: [0.067], 


0it [00:00, ?it/s]

Epoch: [3/200], Recon Loss: [0.619], Reg loss: [0.072], disc_loss: [0.054], 


0it [00:00, ?it/s]

Epoch: [4/200], Recon Loss: [0.620], Reg loss: [0.068], disc_loss: [0.049], 


0it [00:00, ?it/s]

Epoch: [5/200], Recon Loss: [0.619], Reg loss: [0.061], disc_loss: [0.082], 


0it [00:00, ?it/s]

Epoch: [6/200], Recon Loss: [0.616], Reg loss: [0.049], disc_loss: [0.058], 


0it [00:00, ?it/s]

Epoch: [7/200], Recon Loss: [0.613], Reg loss: [0.066], disc_loss: [0.087], 


0it [00:00, ?it/s]

Epoch: [8/200], Recon Loss: [0.613], Reg loss: [0.062], disc_loss: [0.052], 


0it [00:00, ?it/s]

Epoch: [9/200], Recon Loss: [0.612], Reg loss: [0.075], disc_loss: [0.135], 


0it [00:00, ?it/s]

Epoch: [10/200], Recon Loss: [0.608], Reg loss: [0.050], disc_loss: [0.111], 


0it [00:00, ?it/s]

Epoch: [11/200], Recon Loss: [0.609], Reg loss: [0.066], disc_loss: [0.125], 


0it [00:00, ?it/s]

Epoch: [12/200], Recon Loss: [0.609], Reg loss: [0.045], disc_loss: [0.107], 


0it [00:00, ?it/s]

Epoch: [13/200], Recon Loss: [0.598], Reg loss: [0.071], disc_loss: [0.134], 


0it [00:00, ?it/s]

Epoch: [14/200], Recon Loss: [0.609], Reg loss: [0.047], disc_loss: [0.163], 


0it [00:00, ?it/s]

Epoch: [15/200], Recon Loss: [0.586], Reg loss: [0.067], disc_loss: [0.101], 


0it [00:00, ?it/s]

Epoch: [16/200], Recon Loss: [0.570], Reg loss: [0.045], disc_loss: [0.142], 


0it [00:00, ?it/s]

Epoch: [17/200], Recon Loss: [0.534], Reg loss: [0.050], disc_loss: [0.113], 


0it [00:00, ?it/s]

Epoch: [18/200], Recon Loss: [0.512], Reg loss: [0.044], disc_loss: [0.108], 


0it [00:00, ?it/s]

Epoch: [19/200], Recon Loss: [0.492], Reg loss: [0.046], disc_loss: [0.087], 


0it [00:00, ?it/s]

Epoch: [20/200], Recon Loss: [0.475], Reg loss: [0.044], disc_loss: [0.066], 


0it [00:00, ?it/s]

Epoch: [21/200], Recon Loss: [0.458], Reg loss: [0.052], disc_loss: [0.055], 


0it [00:00, ?it/s]

Epoch: [22/200], Recon Loss: [0.444], Reg loss: [0.046], disc_loss: [0.092], 


0it [00:00, ?it/s]

Epoch: [23/200], Recon Loss: [0.436], Reg loss: [0.049], disc_loss: [0.060], 


0it [00:00, ?it/s]

Epoch: [24/200], Recon Loss: [0.417], Reg loss: [0.045], disc_loss: [0.052], 


0it [00:00, ?it/s]

Epoch: [25/200], Recon Loss: [0.403], Reg loss: [0.041], disc_loss: [0.049], 
Saving models


0it [00:00, ?it/s]

Epoch: [26/200], Recon Loss: [0.387], Reg loss: [0.045], disc_loss: [0.064], 


0it [00:00, ?it/s]

Epoch: [27/200], Recon Loss: [0.385], Reg loss: [0.048], disc_loss: [0.062], 


0it [00:00, ?it/s]

Epoch: [28/200], Recon Loss: [0.370], Reg loss: [0.038], disc_loss: [0.059], 


0it [00:00, ?it/s]

Epoch: [29/200], Recon Loss: [0.361], Reg loss: [0.043], disc_loss: [0.078], 


0it [00:00, ?it/s]

Epoch: [30/200], Recon Loss: [0.356], Reg loss: [0.043], disc_loss: [0.061], 


0it [00:00, ?it/s]

Epoch: [31/200], Recon Loss: [0.344], Reg loss: [0.041], disc_loss: [0.069], 


0it [00:00, ?it/s]

Epoch: [32/200], Recon Loss: [0.343], Reg loss: [0.040], disc_loss: [0.066], 


0it [00:00, ?it/s]

Epoch: [33/200], Recon Loss: [0.339], Reg loss: [0.039], disc_loss: [0.071], 


0it [00:00, ?it/s]

Epoch: [34/200], Recon Loss: [0.334], Reg loss: [0.038], disc_loss: [0.074], 


0it [00:00, ?it/s]

Epoch: [35/200], Recon Loss: [0.335], Reg loss: [0.035], disc_loss: [0.063], 


0it [00:00, ?it/s]

Epoch: [36/200], Recon Loss: [0.338], Reg loss: [0.035], disc_loss: [0.074], 


0it [00:00, ?it/s]

Epoch: [37/200], Recon Loss: [0.328], Reg loss: [0.035], disc_loss: [0.070], 


0it [00:00, ?it/s]

Epoch: [38/200], Recon Loss: [0.322], Reg loss: [0.033], disc_loss: [0.077], 


0it [00:00, ?it/s]

Epoch: [39/200], Recon Loss: [0.326], Reg loss: [0.035], disc_loss: [0.077], 


0it [00:00, ?it/s]

Epoch: [40/200], Recon Loss: [0.373], Reg loss: [0.042], disc_loss: [0.053], 


0it [00:00, ?it/s]

Epoch: [41/200], Recon Loss: [0.343], Reg loss: [0.038], disc_loss: [0.049], 


0it [00:00, ?it/s]

Epoch: [42/200], Recon Loss: [0.318], Reg loss: [0.033], disc_loss: [0.060], 


0it [00:00, ?it/s]

Epoch: [43/200], Recon Loss: [0.316], Reg loss: [0.034], disc_loss: [0.071], 


0it [00:00, ?it/s]

Epoch: [44/200], Recon Loss: [0.319], Reg loss: [0.033], disc_loss: [0.084], 


0it [00:00, ?it/s]

Epoch: [45/200], Recon Loss: [0.317], Reg loss: [0.032], disc_loss: [0.091], 


0it [00:00, ?it/s]

Epoch: [46/200], Recon Loss: [0.315], Reg loss: [0.033], disc_loss: [0.076], 


0it [00:00, ?it/s]

Epoch: [47/200], Recon Loss: [0.310], Reg loss: [0.032], disc_loss: [0.091], 


0it [00:00, ?it/s]

Epoch: [48/200], Recon Loss: [0.311], Reg loss: [0.031], disc_loss: [0.087], 


0it [00:00, ?it/s]

Epoch: [49/200], Recon Loss: [0.307], Reg loss: [0.030], disc_loss: [0.086], 


0it [00:00, ?it/s]

Epoch: [50/200], Recon Loss: [0.308], Reg loss: [0.032], disc_loss: [0.101], 
Saving models


0it [00:00, ?it/s]

Epoch: [51/200], Recon Loss: [0.308], Reg loss: [0.031], disc_loss: [0.097], 


0it [00:00, ?it/s]

Epoch: [52/200], Recon Loss: [0.300], Reg loss: [0.031], disc_loss: [0.091], 


0it [00:00, ?it/s]

Epoch: [53/200], Recon Loss: [0.307], Reg loss: [0.032], disc_loss: [0.085], 


0it [00:00, ?it/s]

Epoch: [54/200], Recon Loss: [0.294], Reg loss: [0.029], disc_loss: [0.095], 


0it [00:00, ?it/s]

Epoch: [55/200], Recon Loss: [0.308], Reg loss: [0.032], disc_loss: [0.098], 


0it [00:00, ?it/s]

Epoch: [56/200], Recon Loss: [0.301], Reg loss: [0.032], disc_loss: [0.110], 


0it [00:00, ?it/s]

Epoch: [57/200], Recon Loss: [0.304], Reg loss: [0.030], disc_loss: [0.096], 


0it [00:00, ?it/s]

Epoch: [58/200], Recon Loss: [0.300], Reg loss: [0.030], disc_loss: [0.103], 


0it [00:00, ?it/s]

Epoch: [59/200], Recon Loss: [0.383], Reg loss: [0.036], disc_loss: [0.090], 


0it [00:00, ?it/s]

Epoch: [60/200], Recon Loss: [0.313], Reg loss: [0.035], disc_loss: [0.082], 


0it [00:00, ?it/s]

Epoch: [61/200], Recon Loss: [0.303], Reg loss: [0.031], disc_loss: [0.113], 


0it [00:00, ?it/s]

Epoch: [62/200], Recon Loss: [0.303], Reg loss: [0.029], disc_loss: [0.114], 


0it [00:00, ?it/s]

Epoch: [63/200], Recon Loss: [0.297], Reg loss: [0.030], disc_loss: [0.119], 


0it [00:00, ?it/s]

Epoch: [64/200], Recon Loss: [0.309], Reg loss: [0.028], disc_loss: [0.143], 


0it [00:00, ?it/s]

Epoch: [65/200], Recon Loss: [0.287], Reg loss: [0.029], disc_loss: [0.138], 


0it [00:00, ?it/s]

Epoch: [66/200], Recon Loss: [0.308], Reg loss: [0.028], disc_loss: [0.134], 


0it [00:00, ?it/s]

Epoch: [67/200], Recon Loss: [0.305], Reg loss: [0.028], disc_loss: [0.134], 


0it [00:00, ?it/s]

Epoch: [68/200], Recon Loss: [0.292], Reg loss: [0.028], disc_loss: [0.126], 


0it [00:00, ?it/s]

Epoch: [69/200], Recon Loss: [0.305], Reg loss: [0.027], disc_loss: [0.148], 


0it [00:00, ?it/s]

Epoch: [70/200], Recon Loss: [0.290], Reg loss: [0.027], disc_loss: [0.132], 


0it [00:00, ?it/s]

Epoch: [71/200], Recon Loss: [0.292], Reg loss: [0.027], disc_loss: [0.139], 


0it [00:00, ?it/s]

Epoch: [72/200], Recon Loss: [0.302], Reg loss: [0.026], disc_loss: [0.142], 


0it [00:00, ?it/s]

Epoch: [73/200], Recon Loss: [0.290], Reg loss: [0.027], disc_loss: [0.152], 


0it [00:00, ?it/s]

Epoch: [74/200], Recon Loss: [0.295], Reg loss: [0.027], disc_loss: [0.141], 


0it [00:00, ?it/s]

Epoch: [75/200], Recon Loss: [0.299], Reg loss: [0.027], disc_loss: [0.143], 
Saving models


0it [00:00, ?it/s]

Epoch: [76/200], Recon Loss: [0.278], Reg loss: [0.025], disc_loss: [0.154], 


0it [00:00, ?it/s]

Epoch: [77/200], Recon Loss: [0.298], Reg loss: [0.027], disc_loss: [0.147], 


0it [00:00, ?it/s]

Epoch: [78/200], Recon Loss: [0.293], Reg loss: [0.027], disc_loss: [0.147], 


0it [00:00, ?it/s]

Epoch: [79/200], Recon Loss: [0.281], Reg loss: [0.026], disc_loss: [0.163], 


0it [00:00, ?it/s]

Epoch: [80/200], Recon Loss: [0.302], Reg loss: [0.026], disc_loss: [0.172], 


0it [00:00, ?it/s]

Epoch: [81/200], Recon Loss: [0.307], Reg loss: [0.026], disc_loss: [0.182], 


0it [00:00, ?it/s]

Epoch: [82/200], Recon Loss: [0.412], Reg loss: [0.035], disc_loss: [0.105], 


0it [00:00, ?it/s]

Epoch: [83/200], Recon Loss: [0.314], Reg loss: [0.034], disc_loss: [0.072], 


0it [00:00, ?it/s]

Epoch: [84/200], Recon Loss: [0.309], Reg loss: [0.031], disc_loss: [0.090], 


0it [00:00, ?it/s]

Epoch: [85/200], Recon Loss: [0.281], Reg loss: [0.030], disc_loss: [0.103], 


0it [00:00, ?it/s]

Epoch: [86/200], Recon Loss: [0.301], Reg loss: [0.028], disc_loss: [0.127], 


0it [00:00, ?it/s]

Epoch: [87/200], Recon Loss: [0.302], Reg loss: [0.027], disc_loss: [0.125], 


0it [00:00, ?it/s]

Epoch: [88/200], Recon Loss: [0.283], Reg loss: [0.027], disc_loss: [0.134], 


0it [00:00, ?it/s]

Epoch: [89/200], Recon Loss: [0.290], Reg loss: [0.027], disc_loss: [0.126], 


0it [00:00, ?it/s]

Epoch: [90/200], Recon Loss: [0.306], Reg loss: [0.027], disc_loss: [0.135], 


0it [00:00, ?it/s]

Epoch: [91/200], Recon Loss: [0.281], Reg loss: [0.027], disc_loss: [0.131], 


0it [00:00, ?it/s]

Epoch: [92/200], Recon Loss: [0.279], Reg loss: [0.028], disc_loss: [0.140], 


0it [00:00, ?it/s]

Epoch: [93/200], Recon Loss: [0.301], Reg loss: [0.028], disc_loss: [0.141], 


0it [00:00, ?it/s]

Epoch: [94/200], Recon Loss: [0.286], Reg loss: [0.027], disc_loss: [0.145], 


0it [00:00, ?it/s]

Epoch: [95/200], Recon Loss: [0.275], Reg loss: [0.028], disc_loss: [0.144], 


0it [00:00, ?it/s]

Epoch: [96/200], Recon Loss: [0.291], Reg loss: [0.027], disc_loss: [0.150], 


0it [00:00, ?it/s]

Epoch: [97/200], Recon Loss: [0.295], Reg loss: [0.027], disc_loss: [0.134], 


0it [00:00, ?it/s]

Epoch: [98/200], Recon Loss: [0.270], Reg loss: [0.026], disc_loss: [0.163], 


0it [00:00, ?it/s]

Epoch: [99/200], Recon Loss: [0.286], Reg loss: [0.027], disc_loss: [0.185], 


0it [00:00, ?it/s]

Epoch: [100/200], Recon Loss: [0.295], Reg loss: [0.026], disc_loss: [0.178], 
Saving models


0it [00:00, ?it/s]

Epoch: [101/200], Recon Loss: [0.294], Reg loss: [0.026], disc_loss: [0.136], 


0it [00:00, ?it/s]

Epoch: [102/200], Recon Loss: [0.423], Reg loss: [0.034], disc_loss: [0.108], 


0it [00:00, ?it/s]

Epoch: [103/200], Recon Loss: [0.363], Reg loss: [0.044], disc_loss: [0.032], 


0it [00:00, ?it/s]

Epoch: [104/200], Recon Loss: [0.299], Reg loss: [0.036], disc_loss: [0.068], 


0it [00:00, ?it/s]

Epoch: [105/200], Recon Loss: [0.276], Reg loss: [0.039], disc_loss: [0.090], 


0it [00:00, ?it/s]

Epoch: [106/200], Recon Loss: [0.289], Reg loss: [0.034], disc_loss: [0.097], 


0it [00:00, ?it/s]

Epoch: [107/200], Recon Loss: [0.299], Reg loss: [0.034], disc_loss: [0.114], 


0it [00:00, ?it/s]

Epoch: [108/200], Recon Loss: [0.284], Reg loss: [0.030], disc_loss: [0.147], 


0it [00:00, ?it/s]

Epoch: [109/200], Recon Loss: [0.271], Reg loss: [0.028], disc_loss: [0.138], 


0it [00:00, ?it/s]

Epoch: [110/200], Recon Loss: [0.299], Reg loss: [0.029], disc_loss: [0.134], 


0it [00:00, ?it/s]

Epoch: [111/200], Recon Loss: [0.292], Reg loss: [0.027], disc_loss: [0.150], 


0it [00:00, ?it/s]

Epoch: [112/200], Recon Loss: [0.271], Reg loss: [0.027], disc_loss: [0.163], 


0it [00:00, ?it/s]

Epoch: [113/200], Recon Loss: [0.281], Reg loss: [0.027], disc_loss: [0.160], 


0it [00:00, ?it/s]

Epoch: [114/200], Recon Loss: [0.297], Reg loss: [0.027], disc_loss: [0.152], 


0it [00:00, ?it/s]

Epoch: [115/200], Recon Loss: [0.283], Reg loss: [0.025], disc_loss: [0.168], 


0it [00:00, ?it/s]

Epoch: [116/200], Recon Loss: [0.263], Reg loss: [0.024], disc_loss: [0.156], 


0it [00:00, ?it/s]

Epoch: [117/200], Recon Loss: [0.290], Reg loss: [0.026], disc_loss: [0.155], 


0it [00:00, ?it/s]

Epoch: [118/200], Recon Loss: [0.295], Reg loss: [0.026], disc_loss: [0.165], 


0it [00:00, ?it/s]

Epoch: [119/200], Recon Loss: [0.268], Reg loss: [0.026], disc_loss: [0.156], 


0it [00:00, ?it/s]

Epoch: [120/200], Recon Loss: [0.272], Reg loss: [0.027], disc_loss: [0.146], 


0it [00:00, ?it/s]

Epoch: [121/200], Recon Loss: [0.291], Reg loss: [0.026], disc_loss: [0.152], 


0it [00:00, ?it/s]

Epoch: [122/200], Recon Loss: [0.304], Reg loss: [0.027], disc_loss: [0.141], 


0it [00:00, ?it/s]

Epoch: [123/200], Recon Loss: [0.456], Reg loss: [0.033], disc_loss: [0.109], 


0it [00:00, ?it/s]

Epoch: [124/200], Recon Loss: [0.525], Reg loss: [0.048], disc_loss: [0.033], 


0it [00:00, ?it/s]

Epoch: [125/200], Recon Loss: [0.388], Reg loss: [0.042], disc_loss: [0.038], 
Saving models


0it [00:00, ?it/s]

Epoch: [126/200], Recon Loss: [0.331], Reg loss: [0.039], disc_loss: [0.042], 


0it [00:00, ?it/s]

Epoch: [127/200], Recon Loss: [0.298], Reg loss: [0.047], disc_loss: [0.052], 


0it [00:00, ?it/s]

Epoch: [128/200], Recon Loss: [0.313], Reg loss: [0.040], disc_loss: [0.063], 


0it [00:00, ?it/s]

Epoch: [129/200], Recon Loss: [0.334], Reg loss: [0.037], disc_loss: [0.069], 


0it [00:00, ?it/s]

Epoch: [130/200], Recon Loss: [0.325], Reg loss: [0.038], disc_loss: [0.074], 


0it [00:00, ?it/s]

Epoch: [131/200], Recon Loss: [0.290], Reg loss: [0.036], disc_loss: [0.067], 


0it [00:00, ?it/s]

Epoch: [132/200], Recon Loss: [0.300], Reg loss: [0.036], disc_loss: [0.058], 


0it [00:00, ?it/s]

Epoch: [133/200], Recon Loss: [0.319], Reg loss: [0.034], disc_loss: [0.067], 


0it [00:00, ?it/s]

Epoch: [134/200], Recon Loss: [0.300], Reg loss: [0.036], disc_loss: [0.065], 


0it [00:00, ?it/s]

Epoch: [135/200], Recon Loss: [0.276], Reg loss: [0.035], disc_loss: [0.078], 


0it [00:00, ?it/s]

Epoch: [136/200], Recon Loss: [0.291], Reg loss: [0.034], disc_loss: [0.080], 


0it [00:00, ?it/s]

Epoch: [137/200], Recon Loss: [0.320], Reg loss: [0.032], disc_loss: [0.083], 


0it [00:00, ?it/s]

Epoch: [138/200], Recon Loss: [0.313], Reg loss: [0.034], disc_loss: [0.083], 


0it [00:00, ?it/s]

Epoch: [139/200], Recon Loss: [0.331], Reg loss: [0.035], disc_loss: [0.090], 


0it [00:00, ?it/s]

Epoch: [140/200], Recon Loss: [0.354], Reg loss: [0.040], disc_loss: [0.074], 


0it [00:00, ?it/s]

Epoch: [141/200], Recon Loss: [0.329], Reg loss: [0.038], disc_loss: [0.063], 


0it [00:00, ?it/s]

Epoch: [142/200], Recon Loss: [0.305], Reg loss: [0.035], disc_loss: [0.067], 


0it [00:00, ?it/s]

Epoch: [143/200], Recon Loss: [0.271], Reg loss: [0.034], disc_loss: [0.070], 


0it [00:00, ?it/s]

Epoch: [144/200], Recon Loss: [0.283], Reg loss: [0.034], disc_loss: [0.087], 


0it [00:00, ?it/s]

Epoch: [145/200], Recon Loss: [0.317], Reg loss: [0.033], disc_loss: [0.082], 


0it [00:00, ?it/s]

Epoch: [146/200], Recon Loss: [0.300], Reg loss: [0.033], disc_loss: [0.072], 


0it [00:00, ?it/s]

Epoch: [147/200], Recon Loss: [0.270], Reg loss: [0.033], disc_loss: [0.085], 


0it [00:00, ?it/s]

Epoch: [148/200], Recon Loss: [0.277], Reg loss: [0.033], disc_loss: [0.086], 


0it [00:00, ?it/s]

Epoch: [149/200], Recon Loss: [0.302], Reg loss: [0.035], disc_loss: [0.076], 


0it [00:00, ?it/s]

Epoch: [150/200], Recon Loss: [0.284], Reg loss: [0.032], disc_loss: [0.099], 
Saving models


0it [00:00, ?it/s]

Epoch: [151/200], Recon Loss: [0.267], Reg loss: [0.030], disc_loss: [0.083], 


0it [00:00, ?it/s]

Epoch: [152/200], Recon Loss: [0.279], Reg loss: [0.033], disc_loss: [0.074], 


0it [00:00, ?it/s]

Epoch: [153/200], Recon Loss: [0.300], Reg loss: [0.031], disc_loss: [0.102], 


0it [00:00, ?it/s]

Epoch: [154/200], Recon Loss: [0.276], Reg loss: [0.033], disc_loss: [0.095], 


0it [00:00, ?it/s]

Epoch: [155/200], Recon Loss: [0.270], Reg loss: [0.033], disc_loss: [0.089], 


0it [00:00, ?it/s]

Epoch: [156/200], Recon Loss: [0.286], Reg loss: [0.033], disc_loss: [0.094], 


0it [00:00, ?it/s]

Epoch: [157/200], Recon Loss: [0.286], Reg loss: [0.031], disc_loss: [0.088], 


0it [00:00, ?it/s]

Epoch: [158/200], Recon Loss: [0.260], Reg loss: [0.031], disc_loss: [0.102], 


0it [00:00, ?it/s]

Epoch: [159/200], Recon Loss: [0.268], Reg loss: [0.031], disc_loss: [0.100], 


0it [00:00, ?it/s]

Epoch: [160/200], Recon Loss: [0.291], Reg loss: [0.030], disc_loss: [0.102], 


0it [00:00, ?it/s]

Epoch: [161/200], Recon Loss: [0.274], Reg loss: [0.030], disc_loss: [0.090], 


0it [00:00, ?it/s]

Epoch: [162/200], Recon Loss: [0.260], Reg loss: [0.030], disc_loss: [0.085], 


0it [00:00, ?it/s]

Epoch: [163/200], Recon Loss: [0.280], Reg loss: [0.032], disc_loss: [0.107], 


0it [00:00, ?it/s]

Epoch: [164/200], Recon Loss: [0.292], Reg loss: [0.030], disc_loss: [0.109], 


0it [00:00, ?it/s]

Epoch: [165/200], Recon Loss: [0.262], Reg loss: [0.030], disc_loss: [0.101], 


0it [00:00, ?it/s]

Epoch: [166/200], Recon Loss: [0.273], Reg loss: [0.031], disc_loss: [0.113], 


0it [00:00, ?it/s]

Epoch: [167/200], Recon Loss: [0.289], Reg loss: [0.029], disc_loss: [0.114], 


0it [00:00, ?it/s]

Epoch: [168/200], Recon Loss: [0.280], Reg loss: [0.030], disc_loss: [0.110], 


0it [00:00, ?it/s]

Epoch: [169/200], Recon Loss: [0.259], Reg loss: [0.029], disc_loss: [0.107], 


0it [00:00, ?it/s]

Epoch: [170/200], Recon Loss: [0.282], Reg loss: [0.030], disc_loss: [0.115], 


0it [00:00, ?it/s]

Epoch: [171/200], Recon Loss: [0.286], Reg loss: [0.029], disc_loss: [0.101], 


0it [00:00, ?it/s]

Epoch: [172/200], Recon Loss: [0.261], Reg loss: [0.031], disc_loss: [0.108], 


0it [00:00, ?it/s]

Epoch: [173/200], Recon Loss: [0.291], Reg loss: [0.028], disc_loss: [0.105], 


0it [00:00, ?it/s]

Epoch: [174/200], Recon Loss: [0.292], Reg loss: [0.030], disc_loss: [0.089], 


0it [00:00, ?it/s]

Epoch: [175/200], Recon Loss: [0.266], Reg loss: [0.029], disc_loss: [0.094], 
Saving models


0it [00:00, ?it/s]

Epoch: [176/200], Recon Loss: [0.262], Reg loss: [0.030], disc_loss: [0.096], 


0it [00:00, ?it/s]

Epoch: [177/200], Recon Loss: [0.291], Reg loss: [0.031], disc_loss: [0.091], 


0it [00:00, ?it/s]

Epoch: [178/200], Recon Loss: [0.270], Reg loss: [0.029], disc_loss: [0.114], 


0it [00:00, ?it/s]

Epoch: [179/200], Recon Loss: [0.260], Reg loss: [0.030], disc_loss: [0.111], 


0it [00:00, ?it/s]

Epoch: [180/200], Recon Loss: [0.287], Reg loss: [0.027], disc_loss: [0.111], 


0it [00:00, ?it/s]

Epoch: [181/200], Recon Loss: [0.277], Reg loss: [0.028], disc_loss: [0.121], 


0it [00:00, ?it/s]

Epoch: [182/200], Recon Loss: [0.258], Reg loss: [0.028], disc_loss: [0.106], 


0it [00:00, ?it/s]

Epoch: [183/200], Recon Loss: [0.284], Reg loss: [0.029], disc_loss: [0.119], 


0it [00:00, ?it/s]

Epoch: [184/200], Recon Loss: [0.286], Reg loss: [0.027], disc_loss: [0.115], 


0it [00:00, ?it/s]

Epoch: [185/200], Recon Loss: [0.253], Reg loss: [0.029], disc_loss: [0.116], 


0it [00:00, ?it/s]

Epoch: [186/200], Recon Loss: [0.281], Reg loss: [0.028], disc_loss: [0.122], 


0it [00:00, ?it/s]

Epoch: [187/200], Recon Loss: [0.288], Reg loss: [0.026], disc_loss: [0.142], 


0it [00:00, ?it/s]

Epoch: [188/200], Recon Loss: [0.264], Reg loss: [0.026], disc_loss: [0.132], 


0it [00:00, ?it/s]

Epoch: [189/200], Recon Loss: [0.275], Reg loss: [0.026], disc_loss: [0.133], 


0it [00:00, ?it/s]

Epoch: [190/200], Recon Loss: [0.299], Reg loss: [0.027], disc_loss: [0.116], 


0it [00:00, ?it/s]

Epoch: [191/200], Recon Loss: [0.265], Reg loss: [0.027], disc_loss: [0.115], 


0it [00:00, ?it/s]

Epoch: [192/200], Recon Loss: [0.263], Reg loss: [0.028], disc_loss: [0.104], 


0it [00:00, ?it/s]

Epoch: [193/200], Recon Loss: [0.297], Reg loss: [0.027], disc_loss: [0.127], 


0it [00:00, ?it/s]

Epoch: [194/200], Recon Loss: [0.275], Reg loss: [0.027], disc_loss: [0.117], 


0it [00:00, ?it/s]

Epoch: [195/200], Recon Loss: [0.265], Reg loss: [0.028], disc_loss: [0.111], 


0it [00:00, ?it/s]

Epoch: [196/200], Recon Loss: [0.292], Reg loss: [0.027], disc_loss: [0.120], 


0it [00:00, ?it/s]

Epoch: [197/200], Recon Loss: [0.290], Reg loss: [0.028], disc_loss: [0.124], 


0it [00:00, ?it/s]

Epoch: [198/200], Recon Loss: [0.254], Reg loss: [0.027], disc_loss: [0.138], 


0it [00:00, ?it/s]

Epoch: [199/200], Recon Loss: [0.275], Reg loss: [0.026], disc_loss: [0.144], 


0it [00:00, ?it/s]

Epoch: [200/200], Recon Loss: [0.302], Reg loss: [0.029], disc_loss: [0.109], 
Saving models


In [55]:
torch.norm(z_prior, dim=1).mean().item()

11.355522155761719

In [27]:
reg_loss

tensor(32.2346, device='cuda:0', grad_fn=<MulBackward0>)

In [None]:
image = torch.cat((test_data[0].view(batch_size, 1, 28, 28), 
                                reconst.data), axis=3)

In [None]:
image.shape

torch.Size([128, 1, 28, 56])

In [None]:
np.mean(disc_loss, axis=0)

In [None]:
disc_loss[0]

Output hidden; open in https://colab.research.google.com to view.

In [3]:
# encoder block (used in encoder and discriminator)
class EncoderBlock(nn.Module):
    def __init__(self, channel_in, channel_out):
        super(EncoderBlock, self).__init__()
        # convolution to halve the dimensions
        self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=5, padding=2, stride=2,
                              bias=False)
        self.bn = nn.BatchNorm2d(num_features=channel_out, momentum=0.9)

    def forward(self, ten, out=False,t = False):
        # here we want to be able to take an intermediate output for reconstruction error
        if out:
            ten = self.conv(ten)
            ten_out = ten
            ten = self.bn(ten)
            ten = F.relu(ten, False)
            return ten, ten_out
        else:
            ten = self.conv(ten)
            ten = self.bn(ten)
            ten = F.relu(ten, True)
            return ten


# decoder block (used in the decoder)
class DecoderBlock(nn.Module):
    def __init__(self, channel_in, channel_out):
        super(DecoderBlock, self).__init__()
        # transpose convolution to double the dimensions
        self.conv = nn.ConvTranspose2d(channel_in, channel_out, kernel_size=5, padding=2, stride=2, output_padding=1,
                                       bias=False)
        self.bn = nn.BatchNorm2d(channel_out, momentum=0.9)

    def forward(self, ten):
        ten = self.conv(ten)
        ten = self.bn(ten)
        ten = F.relu(ten, True)
        return ten


class Encoder(nn.Module):
    def __init__(self, channel_in=3, z_size=128):
        super(Encoder, self).__init__()
        self.size = channel_in
        layers_list = []
        # the first time 3->64, for every other double the channel size
        for i in range(3):
            if i == 0:
                layers_list.append(EncoderBlock(channel_in=self.size, channel_out=64))
                self.size = 64
            else:
                layers_list.append(EncoderBlock(channel_in=self.size, channel_out=self.size * 2))
                self.size *= 2

        # final shape Bx256x8x8
        self.conv = nn.Sequential(*layers_list)
        self.fc = nn.Sequential(nn.Linear(in_features=8 * 8 * self.size, out_features=1024, bias=False),
                                nn.BatchNorm1d(num_features=1024,momentum=0.9),
                                nn.ReLU(True))
        # two linear to get the mu vector and the diagonal of the log_variance
        self.l_mu = nn.Linear(in_features=1024, out_features=z_size)
        self.l_var = nn.Linear(in_features=1024, out_features=z_size)

    def forward(self, ten):
        ten = self.conv(ten)
        ten = ten.view(len(ten), -1)
        ten = self.fc(ten)
        mu = self.l_mu(ten)
        logvar = self.l_var(ten)
        return mu, logvar

    def __call__(self, *args, **kwargs):
        return super(Encoder, self).__call__(*args, **kwargs)


class Decoder(nn.Module):
    def __init__(self, z_size, size):
        super(Decoder, self).__init__()
        # start from B*z_size
        self.fc = nn.Sequential(nn.Linear(in_features=z_size, out_features=8 * 8 * size, bias=False),
                                nn.BatchNorm1d(num_features=8 * 8 * size,momentum=0.9),
                                nn.ReLU(True))
        self.size = size
        layers_list = []
        layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size))
        layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//2))
        self.size = self.size//2
        layers_list.append(DecoderBlock(channel_in=self.size, channel_out=self.size//4))
        self.size = self.size//4
        # final conv to get 3 channels and tanh layer
        layers_list.append(nn.Sequential(
            nn.Conv2d(in_channels=self.size, out_channels=3, kernel_size=5, stride=1, padding=2),
            nn.Tanh()
        ))

        self.conv = nn.Sequential(*layers_list)

    def forward(self, ten):

        ten = self.fc(ten)
        ten = ten.view(len(ten), -1, 8, 8)
        ten = self.conv(ten)
        return ten

    def __call__(self, *args, **kwargs):
        return super(Decoder, self).__call__(*args, **kwargs)


class Discriminator(nn.Module):
    def __init__(self, channel_in=3,recon_level=3):
        super(Discriminator, self).__init__()
        self.size = channel_in
        self.recon_levl = recon_level
        # module list because we need need to extract an intermediate output
        self.conv = nn.ModuleList()
        self.conv.append(nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True)))
        self.size = 32
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=128))
        self.size = 128
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256))
        self.size = 256
        self.conv.append(EncoderBlock(channel_in=self.size, channel_out=256))
        # final fc to get the score (real or fake)
        self.fc = nn.Sequential(
            nn.Linear(in_features=8 * 8 * self.size, out_features=512, bias=False),
            nn.BatchNorm1d(num_features=512,momentum=0.9),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=1),
        )

    def forward(self, ten_orig, ten_predicted, ten_sampled, mode='REC'):
        if mode == "REC":
            ten = torch.cat((ten_orig, ten_predicted, ten_sampled), 0)
            for i, lay in enumerate(self.conv):
                # we take the 9th layer as one of the outputs
                if i == self.recon_levl:
                    ten, layer_ten = lay(ten, True)
                    # we need the layer representations just for the original and reconstructed,
                    # flatten, because it's a convolutional shape
                    layer_ten = layer_ten.view(len(layer_ten), -1)
                    return layer_ten
                else:
                    ten = lay(ten)
        else:
            ten = torch.cat((ten_orig, ten_predicted, ten_sampled), 0)
            for i, lay in enumerate(self.conv):
                    ten = lay(ten)

            ten = ten.view(len(ten), -1)
            ten = self.fc(ten)
            return F.sigmoid(ten)


    def __call__(self, *args, **kwargs):
        return super(Discriminator, self).__call__(*args, **kwargs)

In [4]:
disc = Discriminator()

In [None]:
# print(summary(conv_encoder, torch.zeros((1, 3, 32, 32)).cuda(), show_input=False, show_hierarchical=False))
# print(summary(conv_decoder, torch.zeros((1, 1, 100)).cuda(), show_input=False, show_hierarchical=False))
print(summary(disc, torch.zeros((1, 1, 32, 32)).cuda(), show_input=False, show_hierarchical=False))