In [57]:
from varname import nameof
from collections import defaultdict, OrderedDict
import sys
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import interpolate
from torch.autograd import Variable
from torchvision.transforms import transforms
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import StepLR, ExponentialLR, ReduceLROnPlateau, CyclicLR
from pytorch_model_summary import summary
from tensorboardX import SummaryWriter
from tqdm.notebook import tqdm_notebook as tq

In [58]:
from torch.utils.data import Dataset, DataLoader
from src.dataset.dataset import Cifar
from src.dataset.utils import SavePath
from src.config import TrainConfig
from src.pytorch_msssim import MSSSIM, ssim

In [59]:
# Project root = the notebook's current working directory, with a trailing
# slash because downstream code builds paths by plain string concatenation.
# os.getcwd() replaces the IPython-only `!pwd` shell escape so the cell is
# valid Python and does not depend on a POSIX shell being available.
base_path = os.getcwd() + '/'

In [60]:
# Run configuration. NOTE(review): the previous inline comments quoted "default"
# values that did not match the values set here; they are replaced by notes on
# how each field is used in the visible cells.
args = TrainConfig( base_path,              # project directory path (working directory + '/')
                    n_epochs = 200,         # number of epochs the training loop runs
                    batch_size = 64,        # batch size; also sizes the dummy inputs passed to summary()
                    lr = 1e-4,              # learning rate handed to the Adam optimizer
                    dim_h = 128,            # base hidden width; NOTE(review): models below hard-code 128 instead of reading this
                    n_z = 128,              # latent vector size; NOTE(review): also hard-coded at model construction
                    LAMBDA = 10,            # regularization coefficient; not referenced in the visible cells
                    sigma = 1,              # variance of the hidden dimension; not referenced in the visible cells
                    n_channel = 3,          # number of input image channels
                    img_size = (160, 120))         # image size; NOTE(review): summaries below use 32x32 inputs — confirm which is intended

In [61]:
def unfreeze_params(module: nn.Module):
    """Mark every parameter of ``module`` as trainable (``requires_grad=True``)."""
    for param in module.parameters():
        param.requires_grad_(True)

def freeze_params(module: nn.Module):
    """Exclude every parameter of ``module`` from gradient computation."""
    for param in module.parameters():
        param.requires_grad_(False)

In [62]:
def save_models(model_path, epoch_no, models):
    """Write the state dict of each model to ``<model_path><name>_<epoch_no>.pth``.

    Args:
        model_path: Path prefix; it is concatenated as-is, so it must already
            end with a path separator if it names a directory.
        epoch_no: Epoch number embedded in the file name (formatted with ``%d``).
        models: Mapping of model name -> ``nn.Module``.
    """
    print("Saving models")
    for model_name in models:
        target_file = model_path + model_name + "_%d.pth" % epoch_no
        torch.save(models[model_name].state_dict(), target_file)

def save_values_to_tensorboard(writer, epoch_no, values_dict: dict):
    """Log a dictionary of values to TensorBoard for one epoch.

    Args:
        writer: Object exposing ``add_scalar`` and ``add_scalars``.
        epoch_no: Global step under which the values are recorded.
        values_dict: Mapping of tag -> value. A value that is itself a dict
            (including subclasses such as ``defaultdict`` / ``OrderedDict``)
            is logged as a group via ``add_scalars``; anything else is logged
            as a single scalar via ``add_scalar``.
    """
    for name, val in values_dict.items():
        # isinstance instead of an exact type comparison, so dict subclasses
        # are grouped too rather than being passed to add_scalar as a "scalar".
        if isinstance(val, dict):
            writer.add_scalars(name, val, epoch_no)
        else:
            writer.add_scalar(name, val, epoch_no)

def save_images_to_tensorboard(writer, epoch_no, image, imname='im'):
    """Log ``image`` through ``writer.add_image`` under the tag ``<imname>_<epoch_no>``.

    The epoch number is used both in the tag and as the global step.
    """
    tag = imname + f"_{epoch_no}"
    writer.add_image(tag, image, epoch_no)

In [63]:
# Alternative kept for reference: point SavePath at an existing run directory.
# sp = SavePath(args, checkpoint_path='/x1/data/synth.data/Autoencoder/outs/Fri-Apr-16-02-46-03-2021/')
# Output-path helper for this run; presumably creates a new timestamped
# directory (the cell output prints one) — confirm in src.dataset.utils.
sp = SavePath(args)

/x1/data/synth.data/Autoencoder/outs/Mon-May-31-23-27-06-2021/


In [64]:
# No transform is passed to the loaders (images are not normalized).
transform = None

cdl = Cifar(args)
# First argument presumably selects the train (True) / test (False) split and
# the last one the class labels to keep — confirm against src.dataset.dataset.
train_loader = cdl.get_data_loader(True, transform, list(range(10)))
test_loader = cdl.get_data_loader(False, transform, list(range(10)))

Files already downloaded and verified
Files already downloaded and verified


In [65]:
class ConvBlock(nn.Module):
    """Conv2d -> ReLU -> BatchNorm2d block.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, padding: Forwarded to ``nn.Conv2d``; the defaults
            (4, 2, 1) halve even spatial dimensions.
        conv_out: When True, ``forward`` returns ``(block_output, conv_output)``
            where ``conv_output`` is the raw convolution result taken before
            the ReLU; when False only ``block_output`` is returned.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, conv_out=False):
        super(ConvBlock, self).__init__()
        self.conv_out = conv_out
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # An in-place ReLU overwrites the very tensor that forward() hands back
        # as the conv output, so the caller would receive the rectified values
        # instead. Only rectify in place when nobody else keeps that tensor.
        self.relu = nn.ReLU(inplace=not conv_out)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        conv_features = self.conv(x)
        x = self.bn(self.relu(conv_features))
        if self.conv_out:
            return x, conv_features
        return x

class ConvTransposeBlock(nn.Module):
    """ConvTranspose2d followed by either ReLU + BatchNorm2d or a bare Sigmoid.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size, stride, padding: Forwarded to ``nn.ConvTranspose2d``.
        sigmoid_out: When True the activation is a Sigmoid and the batch-norm
            layer is skipped in ``forward``. The batch-norm module is still
            created in that case, so it is still part of the state dict.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, sigmoid_out=False):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        self.sigmoid_out = sigmoid_out
        self.activation = nn.Sigmoid() if sigmoid_out else nn.ReLU(True)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        activated = self.activation(self.conv(x))
        return activated if self.sigmoid_out else self.bn(activated)

class Encoder(nn.Module):
    """Stack of five ``ConvBlock``s followed by a two-layer MLP producing ``n_z`` features.

    Args:
        n_channel: Number of channels of the input image.
        dim_h: Base channel width; the blocks widen to ``dim_h * 8``.
        n_z: Size of the output of the final linear layer.
    """

    def __init__(self, n_channel, dim_h, n_z):
        super().__init__()

        self.n_channel = n_channel
        self.dim_h = dim_h
        self.n_z = n_z

        # Channel count at the input of the first block and at the output of each block.
        widths = [self.n_channel, self.dim_h, self.dim_h * 2, self.dim_h * 4,
                  self.dim_h * 8, self.dim_h * 8]
        self.blocks = nn.ModuleList(
            [ConvBlock(c_in, c_out, conv_out=True)
             for c_in, c_out in zip(widths[:-1], widths[1:])]
        )

        self.fc = nn.Sequential(
            nn.Linear(self.dim_h * 8, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.n_z),
        )

    def forward(self, x, skip_from=6):
        """Encode ``x``.

        If ``skip_from`` equals the 1-based index of one of the blocks, the
        pass stops right after that block and returns the pair that block
        produced. Otherwise all blocks run, the result is reshaped to rows of
        ``dim_h * 8`` values (any spatial extent left over is folded into the
        first dimension) and passed through ``fc``.
        """
        for depth, block in enumerate(self.blocks, start=1):
            x, x1 = block(x)
            if depth == skip_from:
                return x, x1

        flat = x.reshape(-1, self.dim_h * 8)
        return self.fc(flat)

class Decoder(nn.Module):
    """MLP projection followed by five ``ConvTransposeBlock``s ending in a Sigmoid.

    Args:
        n_channel: Number of channels of the generated image.
        dim_h: Base channel width; the projection outputs ``dim_h * 8`` features.
        n_z: Size of the latent input vector.
    """

    def __init__(self, n_channel, dim_h, n_z):
        super().__init__()

        self.n_channel = n_channel
        self.dim_h = dim_h
        self.n_z = n_z

        self.proj = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 4),
            nn.ReLU(True),
            nn.Linear(self.dim_h * 4, self.dim_h * 8),
            nn.ReLU(True),
        )

        # Channel widths shrink from dim_h * 8 to dim_h; the last block maps to
        # n_channel and uses a Sigmoid instead of ReLU + BatchNorm.
        hidden_widths = [self.dim_h * 8, self.dim_h * 8, self.dim_h * 4,
                         self.dim_h * 2, self.dim_h * 1]
        upsamplers = [ConvTransposeBlock(c_in, c_out)
                      for c_in, c_out in zip(hidden_widths[:-1], hidden_widths[1:])]
        upsamplers.append(ConvTransposeBlock(self.dim_h * 1, self.n_channel, sigmoid_out=True))
        self.blocks = nn.ModuleList(upsamplers)

    def forward(self, x, start_from=0):
        """Decode ``x``.

        With ``start_from == 0`` the input is projected by ``proj``, viewed as
        ``(-1, dim_h * 8, 1, 1)`` and run through every block. With a value
        ``k >= 1`` the projection is skipped and ``x`` is fed directly to the
        ``k``-th block (1-based) and all blocks after it.
        """
        if start_from == 0:
            x = self.proj(x).view(-1, self.dim_h * 8, 1, 1)
            start_from = 1

        for idx in range(start_from - 1, len(self.blocks)):
            x = self.blocks[idx](x)
        return x



class AutoEncoder(nn.Module):
    """Encoder ``E`` and decoder ``G`` with an optional shortcut between mirrored blocks.

    Args:
        n_channel: Number of image channels (input and reconstruction).
        dim_h: Base channel width forwarded to both halves.
        n_z: Latent size forwarded to both halves.
    """

    def __init__(self, n_channel, dim_h, n_z):
        super(AutoEncoder, self).__init__()
        self.E = Encoder(n_channel, dim_h, n_z)
        self.G = Decoder(n_channel, dim_h, n_z)

    def forward(self, x, skip_from=6):
        """Reconstruct ``x``.

        Args:
            x: Input batch.
            skip_from: ``len(E.blocks) + 1`` (6 for the five-block encoder, the
                default) runs the full encoder and decoder through the latent
                vector. A value ``k`` in ``1..len(E.blocks)`` stops the encoder
                after its ``k``-th block and feeds that block's convolution
                output straight into the matching decoder block.

        Returns:
            ``(z, x_tilde)``: ``z`` is the latent vector on the full path, or
            the output of the ``k``-th encoder block on a shortcut path;
            ``x_tilde`` is the decoder output.

        Raises:
            ValueError: If ``skip_from`` is outside ``1..len(E.blocks) + 1``.
        """
        full_depth = len(self.E.blocks) + 1  # 6 for the five-block encoder
        # Out-of-range values used to fall through to tuple-unpacking a plain
        # tensor, which fails obscurely (or silently "works" when the batch
        # size happens to be 2). Reject them explicitly instead.
        if not 1 <= skip_from <= full_depth:
            raise ValueError(
                "skip_from must be between 1 and {}, got {}".format(full_depth, skip_from)
            )

        if skip_from == full_depth:
            z = self.E(x, full_depth)
            x_tilde = self.G(z)
        else:
            z, x1 = self.E(x, skip_from)
            x_tilde = self.G(x1, full_depth - skip_from)

        return z, x_tilde
            

In [45]:
e = Encoder(3, 128, 128)
d = Decoder(3, 128, 128)

In [46]:
# Layer-by-layer summary of the stand-alone encoder on a zero batch of 32x32 RGB
# inputs (not args.img_size); each ConvBlock row lists both returned tensors.
print(summary(e, torch.zeros((args.batch_size, 3, 32, 32)), show_input=False, show_hierarchical=False))

--------------------------------------------------------------------------------------------
      Layer (type)                             Output Shape         Param #     Tr. Param #
       ConvBlock-1     [16, 128, 16, 16], [16, 128, 16, 16]           6,528           6,528
       ConvBlock-2         [16, 256, 8, 8], [16, 256, 8, 8]         525,056         525,056
       ConvBlock-3         [16, 512, 4, 4], [16, 512, 4, 4]       2,098,688       2,098,688
       ConvBlock-4       [16, 1024, 2, 2], [16, 1024, 2, 2]       8,391,680       8,391,680
       ConvBlock-5       [16, 1024, 1, 1], [16, 1024, 1, 1]      16,780,288      16,780,288
          Linear-6                                [16, 512]         524,800         524,800
            ReLU-7                                [16, 512]               0               0
          Linear-8                                [16, 128]          65,664          65,664
Total params: 28,392,704
Trainable params: 28,392,704
Non-trainable params: 0
-

In [47]:
# Layer-by-layer summary of the stand-alone decoder on a zero batch of 128-dim latent vectors.
print(summary(d, torch.zeros((args.batch_size, 128)), show_input=False, show_hierarchical=False))

------------------------------------------------------------------------------
           Layer (type)          Output Shape         Param #     Tr. Param #
               Linear-1             [16, 512]          66,048          66,048
                 ReLU-2             [16, 512]               0               0
               Linear-3            [16, 1024]         525,312         525,312
                 ReLU-4            [16, 1024]               0               0
   ConvTransposeBlock-5      [16, 1024, 2, 2]      16,780,288      16,780,288
   ConvTransposeBlock-6       [16, 512, 4, 4]       8,390,144       8,390,144
   ConvTransposeBlock-7       [16, 256, 8, 8]       2,097,920       2,097,920
   ConvTransposeBlock-8     [16, 128, 16, 16]         524,672         524,672
   ConvTransposeBlock-9       [16, 3, 32, 32]           6,153           6,153
Total params: 28,390,537
Trainable params: 28,390,537
Non-trainable params: 0
---------------------------------------------------------------

In [48]:
# Image autoencoder on the GPU: 3 channels, dim_h=128, n_z=128 (hard-coded
# here rather than read from args).
AE_img = AutoEncoder(3, 128, 128).cuda()
# AE_depth = AutoEncoder(1, 128, 128).cuda()

# Reconstruction loss used by the training loop.
mse_loss_fn = nn.MSELoss().cuda()
# NOTE(review): adversarial_loss_fn is not referenced in any other visible cell.
adversarial_loss_fn = nn.BCELoss().cuda()

# Single Adam optimizer over encoder and decoder parameters.
opt_AE_img = optim.Adam(AE_img.parameters(), lr = args.lr)

# opt_AE_depth = optim.Adam(AE_depth.parameters(), lr = args.lr)

In [49]:
AE_img

AutoEncoder(
  (E): Encoder(
    (blocks): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(3, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu): ReLU(inplace=True)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ConvBlock(
        (conv): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu): ReLU(inplace=True)
        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (2): ConvBlock(
        (conv): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu): ReLU(inplace=True)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): ConvBlock(
        (conv): Conv2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (relu): ReLU(inplace=True)
        (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, tra

In [67]:
len(train_loader)

782

In [51]:
# Summary of the full autoencoder on a zero batch of 32x32 RGB inputs placed on the GPU.
print(summary(AE_img, torch.zeros((args.batch_size, 3, 32, 32)).cuda(), show_input=False, show_hierarchical=False))

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
         Encoder-1           [16, 128]      28,392,704      28,392,704
         Decoder-2     [16, 3, 32, 32]      28,390,537      28,390,537
Total params: 56,783,241
Trainable params: 56,783,241
Non-trainable params: 0
-----------------------------------------------------------------------


In [53]:
def init_weights(m):
    """Xavier-uniform initialise the weight of Linear / Conv2d / ConvTranspose2d layers.

    Meant to be passed to ``nn.Module.apply``. Only modules whose type is
    exactly one of the three classes are touched (subclasses are skipped), and
    biases are left as they are.
    """
    if type(m) in (nn.Linear, nn.Conv2d, nn.ConvTranspose2d):
        nn.init.xavier_uniform_(m.weight)
        
# Re-initialise the weights of every matching layer in each listed model
# (Module.apply visits all submodules recursively).
for model in [AE_img]:
    model.apply(init_weights)

In [54]:
def load_models(checkpoint_path, checkpoint, models=None):
    """Restore model weights saved by ``save_models`` from a previous run.

    Args:
        checkpoint_path: Run directory handed to ``SavePath``.
        checkpoint: Epoch number embedded in the checkpoint file names.
        models: Optional mapping of model name -> ``nn.Module`` to restore;
            each model is loaded from ``<model dir>/<name>_<checkpoint>.pth``.
            Defaults to ``{"AE_img": AE_img}``.
    """
    # The depth autoencoder (AE_depth) is commented out in this notebook, so
    # loading it unconditionally raised NameError on every resume. Only load
    # what the caller asks for, and by default only the image autoencoder.
    if models is None:
        models = {"AE_img": AE_img}

    lp = SavePath(args, checkpoint_path)
    _, _, model_load_path = lp.get_save_paths()
    for model_name, model in models.items():
        weights_file = model_load_path + "/" + model_name + "_{}.pth".format(checkpoint)
        model.load_state_dict(torch.load(weights_file))

In [55]:
# Epoch to resume from; 0 means start from scratch (no weights are loaded).
checkpoint = 0
if checkpoint:
    # NOTE(review): hard-coded absolute path of an earlier run.
    load_models('/x1/data/synth.data/Autoencoder/outs/Fri-Apr-16-02-46-03-2021/', checkpoint)

# Per-epoch accumulators keyed by metric name.
running_losses = defaultdict(list)
# NOTE(review): running_norms is not used in the visible cells.
running_norms = defaultdict(list)

# TensorBoard event files go to "<results_path>logs".
writer = SummaryWriter(log_dir = sp.results_path + "logs")

In [56]:
# Output locations for this run; only model_path is used by the training loop below.
image_path, list_path, model_path = sp.get_save_paths()

In [None]:
# checkpoint = 300
# 6 = run the full encoder -> latent vector -> full decoder (no shortcut).
skip_from = 6
steps_per_epoch = len(train_loader)
last_epoch = checkpoint + args.n_epochs

for epoch in range(checkpoint, last_epoch):
    # enumerate() hides the loader's length from tqdm, so pass it explicitly;
    # otherwise the bar cannot show progress.
    pbar = tq(enumerate(train_loader), total=steps_per_epoch)
    for step, (img, _) in pbar:
        x = img.cuda()

        # --- one optimisation step on the reconstruction loss ---
        AE_img.zero_grad()
        z, x_tilde = AE_img(x, skip_from)

        l_ae = mse_loss_fn(x, x_tilde)
        l_ae.backward()
        opt_AE_img.step()

        running_losses["l_ae"].append(l_ae.item())
        pbar.set_description('Losses: l_ae: ' + str(l_ae.item()))

        # End-of-epoch reporting. Compare against the real number of batches
        # instead of a hard-coded step count, so every batch of the epoch is
        # included in the average and none leaks into the next epoch.
        if step + 1 == steps_per_epoch:
            train_l_ae = np.mean(running_losses["l_ae"], axis=0)
            print('Epoch:[{}/{}] '.format(epoch + 1, last_epoch) + str(train_l_ae.round(4).item()))

            # Evaluate one test batch in eval mode and without autograd, so the
            # test data neither updates the BatchNorm running statistics nor
            # builds a graph that is never used.
            AE_img.eval()
            with torch.no_grad():
                test_img, _ = next(iter(test_loader))
                test_x = test_img.cuda()
                _, test_x_tilde = AE_img(test_x)
                test_l_x = mse_loss_fn(test_x, test_x_tilde)
            AE_img.train()

            val_dict = {
                "train_losses": {
                    "l_ae": float(train_l_ae),
                    "test_l_ae": test_l_x.item(),
                },
            }

            # Concatenate along the last dimension so each grid cell shows the
            # input next to its reconstruction.
            test_img = torch.cat((test_x, test_x_tilde), dim=3).cpu()
            train_img = torch.cat((x, x_tilde.detach()), dim=3).cpu()

            save_values_to_tensorboard(writer, epoch + 1, val_dict)
            save_images_to_tensorboard(writer, epoch + 1, make_grid(test_img, normalize=False), 'test_img')
            save_images_to_tensorboard(writer, epoch + 1, make_grid(train_img, normalize=False), 'train_img')

            running_losses.clear()

    # Checkpoint the weights every 50 epochs.
    if (epoch + 1) % 50 == 0:
        models = {nameof(AE_img): AE_img}
        save_models(model_path, epoch + 1, models)

0it [00:00, ?it/s]

Epoch:[1/200] 0.0206


0it [00:00, ?it/s]

Epoch:[2/200] 0.0144


0it [00:00, ?it/s]

Epoch:[3/200] 0.0128


0it [00:00, ?it/s]

Epoch:[4/200] 0.0118


0it [00:00, ?it/s]

Epoch:[5/200] 0.0111


0it [00:00, ?it/s]

Epoch:[6/200] 0.0106


0it [00:00, ?it/s]

Epoch:[7/200] 0.0102


0it [00:00, ?it/s]

Epoch:[8/200] 0.0098


0it [00:00, ?it/s]

Epoch:[9/200] 0.0095


0it [00:00, ?it/s]

Epoch:[10/200] 0.0093


0it [00:00, ?it/s]

Epoch:[11/200] 0.009


0it [00:00, ?it/s]

Epoch:[12/200] 0.0088


0it [00:00, ?it/s]

Epoch:[13/200] 0.0085


0it [00:00, ?it/s]

Epoch:[14/200] 0.0084


0it [00:00, ?it/s]

Epoch:[15/200] 0.0082


0it [00:00, ?it/s]

Epoch:[16/200] 0.008


0it [00:00, ?it/s]

Epoch:[17/200] 0.0079


0it [00:00, ?it/s]

Epoch:[18/200] 0.0078


0it [00:00, ?it/s]

Epoch:[19/200] 0.0076


0it [00:00, ?it/s]

Epoch:[20/200] 0.0075


0it [00:00, ?it/s]

Epoch:[21/200] 0.0074


0it [00:00, ?it/s]

Epoch:[22/200] 0.0072


0it [00:00, ?it/s]

Epoch:[23/200] 0.0071


0it [00:00, ?it/s]

Epoch:[24/200] 0.007


0it [00:00, ?it/s]

Epoch:[25/200] 0.0069


0it [00:00, ?it/s]

Epoch:[26/200] 0.0068


0it [00:00, ?it/s]

Epoch:[27/200] 0.0066


0it [00:00, ?it/s]

Epoch:[28/200] 0.0066


0it [00:00, ?it/s]

Epoch:[29/200] 0.0065


0it [00:00, ?it/s]

Epoch:[30/200] 0.0064


0it [00:00, ?it/s]

Epoch:[31/200] 0.0063


0it [00:00, ?it/s]

Epoch:[32/200] 0.0062


0it [00:00, ?it/s]

Epoch:[33/200] 0.0061


0it [00:00, ?it/s]

Epoch:[34/200] 0.006


0it [00:00, ?it/s]

Epoch:[35/200] 0.006


0it [00:00, ?it/s]

Epoch:[36/200] 0.0059


0it [00:00, ?it/s]

Epoch:[37/200] 0.0058


0it [00:00, ?it/s]

Epoch:[38/200] 0.0057


0it [00:00, ?it/s]

Epoch:[39/200] 0.0057


0it [00:00, ?it/s]

Epoch:[40/200] 0.0056


0it [00:00, ?it/s]

Epoch:[41/200] 0.0055


0it [00:00, ?it/s]

Epoch:[42/200] 0.0055


0it [00:00, ?it/s]

Epoch:[43/200] 0.0054


0it [00:00, ?it/s]

Epoch:[44/200] 0.0054


0it [00:00, ?it/s]

Epoch:[45/200] 0.0053


0it [00:00, ?it/s]

Epoch:[46/200] 0.0053


0it [00:00, ?it/s]

Epoch:[47/200] 0.0052


0it [00:00, ?it/s]

Epoch:[48/200] 0.0052


0it [00:00, ?it/s]

Epoch:[49/200] 0.0051


0it [00:00, ?it/s]

Epoch:[50/200] 0.0051
Saving models


0it [00:00, ?it/s]

Epoch:[51/200] 0.005


0it [00:00, ?it/s]

Epoch:[52/200] 0.005


0it [00:00, ?it/s]

Epoch:[53/200] 0.0049


0it [00:00, ?it/s]

Epoch:[54/200] 0.0049


0it [00:00, ?it/s]

Epoch:[55/200] 0.0049


0it [00:00, ?it/s]

Epoch:[56/200] 0.0049


0it [00:00, ?it/s]

Epoch:[57/200] 0.0048


0it [00:00, ?it/s]

Epoch:[58/200] 0.0048


0it [00:00, ?it/s]

Epoch:[59/200] 0.0047


0it [00:00, ?it/s]

Epoch:[60/200] 0.0047


0it [00:00, ?it/s]

Epoch:[61/200] 0.0047


0it [00:00, ?it/s]

Epoch:[62/200] 0.0046


0it [00:00, ?it/s]

Epoch:[63/200] 0.0046


0it [00:00, ?it/s]

Epoch:[64/200] 0.0046


0it [00:00, ?it/s]

Epoch:[65/200] 0.0046


0it [00:00, ?it/s]

Epoch:[66/200] 0.0045


0it [00:00, ?it/s]

Epoch:[67/200] 0.0045


0it [00:00, ?it/s]

Epoch:[68/200] 0.0044


0it [00:00, ?it/s]

Epoch:[69/200] 0.0044


0it [00:00, ?it/s]

Epoch:[70/200] 0.0044


0it [00:00, ?it/s]

Epoch:[71/200] 0.0044


0it [00:00, ?it/s]

Epoch:[72/200] 0.0044


0it [00:00, ?it/s]

Epoch:[73/200] 0.0043


0it [00:00, ?it/s]

Epoch:[74/200] 0.0043


0it [00:00, ?it/s]

Epoch:[75/200] 0.0043


0it [00:00, ?it/s]

Epoch:[76/200] 0.0043


0it [00:00, ?it/s]

Epoch:[77/200] 0.0042


0it [00:00, ?it/s]

Epoch:[78/200] 0.0042


0it [00:00, ?it/s]

Epoch:[79/200] 0.0042


0it [00:00, ?it/s]

Epoch:[80/200] 0.0042


0it [00:00, ?it/s]

Epoch:[81/200] 0.0042


0it [00:00, ?it/s]

Epoch:[82/200] 0.0041


0it [00:00, ?it/s]

Epoch:[83/200] 0.0041


0it [00:00, ?it/s]

Epoch:[84/200] 0.0041


0it [00:00, ?it/s]

Epoch:[85/200] 0.0041


0it [00:00, ?it/s]

Epoch:[86/200] 0.0041


0it [00:00, ?it/s]

Epoch:[87/200] 0.0041


0it [00:00, ?it/s]

Epoch:[88/200] 0.004


0it [00:00, ?it/s]

Epoch:[89/200] 0.004


0it [00:00, ?it/s]

Epoch:[90/200] 0.004


0it [00:00, ?it/s]

Epoch:[91/200] 0.004


0it [00:00, ?it/s]

Epoch:[92/200] 0.004


0it [00:00, ?it/s]

Epoch:[93/200] 0.004


0it [00:00, ?it/s]

Epoch:[94/200] 0.0039


0it [00:00, ?it/s]

Epoch:[95/200] 0.0039


0it [00:00, ?it/s]

Epoch:[96/200] 0.0039


0it [00:00, ?it/s]

Epoch:[97/200] 0.0039


0it [00:00, ?it/s]

Epoch:[98/200] 0.0039


0it [00:00, ?it/s]

Epoch:[99/200] 0.0038


0it [00:00, ?it/s]

Epoch:[100/200] 0.0038
Saving models


0it [00:00, ?it/s]

Epoch:[101/200] 0.0038


0it [00:00, ?it/s]

In [None]:
e = Encoder(3, 128,128)

In [None]:
# Probe the encoder with a single 160x120 image. After five stride-2 convs
# (kernel 4, padding 1) the feature map is 5x3 rather than 1x1, so the
# reshape(-1, dim_h * 8) in Encoder.forward folds those 15 positions into the
# first dimension: the result has 15 rows, not 1.
o = e(torch.zeros((1, 3, 160, 120)))

In [None]:
o[1].shape

In [None]:
print(summary(e, torch.zeros((1, 3, 160, 120)), show_input=False, show_hierarchical=False))

In [None]:
d = Decoder(3, 128, 128)

In [None]:
d

In [None]:
a = ['1', '2', '3', '4', '5']
a[:6]

In [None]:
class Conv2dAuto(nn.Conv2d):
    """``nn.Conv2d`` whose ``padding`` attribute is overwritten with ``kernel_size // 2`` per dimension.

    Any ``padding`` argument given at construction is replaced after the base
    class has been initialised.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Derive the padding from the kernel size of each spatial dimension.
        k_h, k_w = self.kernel_size[0], self.kernel_size[1]
        self.padding = (k_h // 2, k_w // 2)

In [None]:
from functools import partial
# Factory for bias-free 3x3 Conv2dAuto layers (padding becomes 1 per the class above);
# remaining Conv2d arguments are supplied at call time.
conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)  