## Load libraries

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os,sys
import re
import math
from datetime import datetime
import time
sys.dont_write_bytecode = True

In [3]:
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
from skimage.color import rgb2gray
from skimage.transform import resize

from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable

from pprint import pprint
from ipdb import set_trace as brpt

In [4]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.linalg import norm as tnorm
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import loggers as pl_loggers
# Select the visible GPU: enumerate devices in PCI bus order (see issue #152)
# and expose only device "2" to this kernel.
# NOTE: CUDA reads these variables when it is first initialised; torch is already
# imported above, so this only takes effect if no CUDA call has been made yet.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="2",
)

## Set Path
1. Add this notebook's folder and the project root to `sys.path` (the project root makes `src.*` imports resolve)
2. Set DATA_ROOT to the `maptiles_v2` folder

In [5]:
# The notebook runs from `<project root>/nbs`, so the project root is the parent
# of the current working directory.
this_nb_path = Path.cwd()
ROOT = this_nb_path.parent
SRC = ROOT / 'src'
# The data location can be overridden without editing the notebook, e.g.
# `MAPTILES_DATA_ROOT=/path/to/maptiles_v2 jupyter lab`; the default is unchanged.
DATA_ROOT = Path(os.environ.get("MAPTILES_DATA_ROOT", "/data/hayley-old/maptiles_v2/"))
# ROOT makes `import src.<...>` resolve; the notebook folder is added as well.
paths2add = [this_nb_path, ROOT]

print("Project root: ", str(ROOT))
print('Src folder: ', str(SRC))
print("This nb path: ", str(this_nb_path))

for p in paths2add:
    p_str = str(p)
    # Prepend so that project modules take precedence over same-named installed ones.
    if p_str not in sys.path:
        sys.path.insert(0, p_str)
        print(f"\n{p_str} added to the path.")



Project root:  /data/hayley-old/Tenanbaum2000
Src folder:  /data/hayley-old/Tenanbaum2000/src
This nb path:  /data/hayley-old/Tenanbaum2000/nbs

/data/hayley-old/Tenanbaum2000 added to the path.


In [6]:
# from src.data.datasets.maptiles import Maptiles, MapStyles
from src.data.datamodules.mnist_datamodule import MNISTDataModule
from src.data.datamodules.maptiles_datamodule import MaptilesDataModule

from src.models.plmodules.three_fcs import ThreeFCs
from src.models.plmodules.vanilla_vae import VanillaVAE
from src.models.plmodules.beta_vae import BetaVAE

from src.visualize.utils import show_timgs

## Start experiment
Train a beta-VAE (ResNet encoder, conv decoder) on MNIST.
(Original goal for the maptiles data: given a maptile, predict its style as one of OSM, CartoVoyager.)

In [7]:
# Instantiate the MNIST datamodule; in_shape is (C, H, W): single-channel 32x32 inputs.
in_shape = (1, 32, 32)
batch_size = 32
dm = MNISTDataModule(
    data_root=ROOT / 'data',
    in_shape=in_shape,
    batch_size=batch_size,
)
dm.setup('fit')
print("DM: ", dm.name)

DM:  MNIST


In [84]:
# # Instantiate data module
# all_cities = ['la', 'charlotte', 'vegas', 'boston', 'paris', \
#               'amsterdam', 'shanghai', 'seoul', 'chicago', 'manhattan', \
#              'berlin', 'montreal', 'rome']
# cities = all_cities #['berlin']#['paris']
# styles = ['StamenTonerBackground']#['OSMDefault', 'CartoVoyagerNoLabels']
# zooms = ['14']
# in_shape = (1, 64, 64)
# batch_size = 32
# dm = MaptilesDataModule(data_root=DATA_ROOT,
#                         cities=cities,
#                         styles=styles,
#                         zooms=zooms,
#                        in_shape=in_shape,
#                        batch_size=batch_size
#                        )
# dm.setup('fit')
# print("DM: ", dm.name)

# # Instantiate the pl Module
# latent_dim = 10
# hidden_dims = [32,64,128,256,512]
# act_fn = nn.LeakyReLU()
# learning_rate = 3e-4
# model = VanillaVAE(
#     in_shape=in_shape,
#     latent_dim=latent_dim,
#     hidden_dims=hidden_dims,
#     learning_rate=learning_rate,
#     act_fn=act_fn
# )
# print(model.hparams)

In [14]:
# Instantiate the pl Module: a beta-VAE whose encoder/decoder types are chosen below.
from src.models.plmodules.beta_vae import BetaVAE

# KL-weight candidates 0.1 * 3**i for i = 0..9 (kept for sweeps; not used in this cell).
betas = [0.1 * 3**i for i in range(10)]

enc_type = 'resnet'
dec_type = 'conv'

latent_dim = 10
learning_rate = 3e-4
kld_weight = 1.0  # alternative: betas[0]
act_fn = nn.LeakyReLU()
# The resnet encoder is configured with its own, longer list of hidden dims.
hidden_dims = [32, 32, 64, 128, 256] if enc_type == 'resnet' else [32, 64, 128, 256]

model = BetaVAE(
    in_shape=in_shape,
    latent_dim=latent_dim,
    hidden_dims=hidden_dims,
    learning_rate=learning_rate,
    act_fn=act_fn,
    kld_weight=kld_weight,
    enc_type=enc_type,
    dec_type=dec_type,
)


In [15]:
model.name  # shown as the cell output; also the prefix of `exp_name` in the next cell

'BetaVAE-resnet-conv-1.000'

In [16]:
# Build the TensorBoard logger and the `pl.Trainer` configuration for this run.
max_epochs = 200
exp_name = f'{model.name}_{dm.name}'
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir=f'{ROOT}/temp-logs',
    name=exp_name,
    log_graph=False,
    default_hp_metric=False,
)
print("Log dir: ", tb_logger.log_dir)

# Make sure the versioned log directory exists; only announce it when it is new.
log_dir = Path(tb_logger.log_dir)
if not log_dir.exists():
    log_dir.mkdir(parents=True)
    print("Created: ", log_dir)

# Keyword arguments for `pl.Trainer` (consumed by the next cell).
trainer_config = dict(
    gpus=1,
    max_epochs=max_epochs,
    progress_bar_refresh_rate=0,
    terminate_on_nan=True,
    check_val_every_n_epoch=10,
    logger=tb_logger,
    # callbacks=callbacks,
)

Log dir:  /data/hayley-old/Tenanbaum2000/temp-logs/BetaVAE-resnet-conv-1.000_MNIST/version_0
Created:  /data/hayley-old/Tenanbaum2000/temp-logs/BetaVAE-resnet-conv-1.000_MNIST/version_0


In [None]:
# Quick smoke-test alternative: trainer = pl.Trainer(fast_dev_run=3)
trainer = pl.Trainer(**trainer_config)
# trainer.tune(model=model, datamodule=dm)

# Fit model
trainer.fit(model, dm)
# Report epoch and batch index as two labelled numbers (not as a tuple).
print(f"Finished at ep {trainer.current_epoch}, batch {trainer.batch_idx}")

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


BetaVAE is called



  | Name              | Type       | Params | In sizes        | Out sizes      
-------------------------------------------------------------------------------------
0 | act_fn            | LeakyReLU  | 0      | [1, 32, 16, 16] | [1, 32, 16, 16]
1 | out_fn            | Tanh       | 0      | [1, 1, 32, 32]  | [1, 1, 32, 32] 
2 | encoder           | ResNet     | 2.8 M  | [1, 1, 32, 32]  | [1, 256, 2, 2] 
3 | fc_mu             | Linear     | 10.2 K | [1, 1024]       | [1, 10]        
4 | fc_var            | Linear     | 10.2 K | [1, 1024]       | [1, 10]        
5 | fc_latent2flatten | Linear     | 11.3 K | [1, 10]         | [1, 1024]      
6 | decoder           | Sequential | 388 K  | [1, 256, 2, 2]  | [1, 1, 32, 32] 
7 | out_layer         | Sequential | 10     | [1, 1, 32, 32]  | [1, 1, 32, 32] 
-------------------------------------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params


Ep: 0, batch: 0, loss: 31734.501953125
Ep: 0, batch: 0, loss: 32631.783203125


# ResNet Encoder
- Jan 27, 2021
- https://d2l.ai/chapter_convolutional-modern/resnet.html

In [None]:
class Residual_V1(nn.Module):
    """A single post-activation residual unit (original ResNet ordering).

    Both main-branch convs use a `kernel_size` x `kernel_size` kernel (3x3 by
    default) and `padding` (1 by default). The first conv applies `stride`,
    which shrinks the input's (h, w), and maps the channels from in_c to out_c.
    The second conv keeps the shape.


    input ---> conv2d-bn-act -> z1 ---> conv2d-bn -----> z2 -(+)-> act -> out
            |                                             ^
            |                                             |
            ----------------->(1x1 conv2d, optional)------


    Parameters
    ----------
    in_c, out_c : int
        Number of input / output channels.
    stride : int
        Stride of the first conv layer (and of the optional 1x1 shortcut conv).
        Use stride = 2 as a way to halve the width, height of the input;
        similar to applying a pooling operation.
    use_1x1conv : bool
        Applies a 1x1 conv (with `stride`) to the input on the shortcut so that
        its n_channel and (h, w) match the second conv's output (z2).

        It must be set to True when the input's num channel or (h,w) need to
        be adjusted, ie:
        - in_c is different from out_c, or
        - `stride` != 1.
        It may also be set when the shapes already match, to add a learned
        1x1 projection to the shortcut.
    act_fn : nn.Module
        Activation applied after the first bn and after the residual addition.
        NOTE: the default instance is created once, when the class is defined,
        and is shared by every unit relying on the default (harmless for ReLU,
        which holds no state).
    kernel_size, padding : int
        Kernel size and padding of the two main-branch convs.

    Raises
    ------
    ValueError
        If `use_1x1conv` is False while `in_c != out_c` or `stride != 1`. The
        raw input could not be added to z2 (forward would fail or, for
        in_c == 1, silently broadcast).

    `forward(x)` returns
    --------------------
    out = model(x) returns a batch of tensors whose size is :
        (BS, out_c, in_h/stride, in_w/stride)
    """

    def __init__(self, in_c, out_c,
                 *,
                 stride=1,
                 use_1x1conv=False,
                 act_fn=nn.ReLU(inplace=True),
                 kernel_size=3, padding=1):
        super().__init__()
        # Fail fast: without a projection the shortcut is the raw input, which is
        # only addable to z2 when channels and resolution are unchanged.
        if not use_1x1conv and (in_c != out_c or stride != 1):
            raise ValueError(
                "use_1x1conv must be True when in_c != out_c or stride != 1 "
                f"(got in_c={in_c}, out_c={out_c}, stride={stride})"
            )
        self.stride = stride
        self.use_1x1conv = use_1x1conv
        self.conv1 = nn.Conv2d(in_c, out_c,
                               kernel_size=kernel_size, padding=padding, stride=self.stride)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c,
                               kernel_size=kernel_size, padding=padding, stride=1)
        self.bn2 = nn.BatchNorm2d(out_c)

        # Optional shortcut projection; None means the identity shortcut is used.
        self.conv3 = None
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_c, out_c,
                                   kernel_size=1, padding=0, stride=self.stride)
        self.act_fn = act_fn

    def forward(self, x):
        """
        Returns
        -------
        out = model(x) returns a batch of tensors whose size is :
        (BS, out_c, in_h/stride, in_w/stride)
        """
        z = self.act_fn(self.bn1(self.conv1(x)))
        z = self.bn2(self.conv2(z))
        shortcut = self.conv3(x) if self.use_1x1conv else x
        return self.act_fn(z + shortcut)
            
        
        

In [None]:
class Residual(nn.Module):
    """A single pre-activation residual unit (bn-act-conv ordering).

    Both main-branch convs use a `kernel_size` x `kernel_size` kernel (3x3 by
    default) and `padding` (1 by default). The first conv applies `stride`,
    which shrinks the input's (h, w), and maps the channels from in_c to out_c.
    The second conv keeps the shape.


    input ---> (bn-act-conv2d) -> z1 ---> (bn-act-conv2d) -> z2 -(+)-> act -> out
            |                                                ^
            |                                                |
            ----------------->  (1x1 conv2d, optional) -------

    NOTE: an activation is applied after the addition as well. The
    pre-activation design of He et al. (2016) keeps that path activation-free.
    This is kept as-is for compatibility with models built on this unit.


    Parameters
    ----------
    in_c, out_c : int
        Number of input / output channels.
    stride : int
        Stride of the first conv layer (and of the optional 1x1 shortcut conv).
        Use stride = 2 as a way to halve the width, height of the input;
        similar to applying a pooling operation.
    use_1x1conv : bool
        Applies a 1x1 conv (with `stride`) to the input on the shortcut so that
        its n_channel and (h, w) match the second conv's output (z2).

        It must be set to True when the input's num channel or (h,w) need to
        be adjusted, ie:
        - in_c is different from out_c, or
        - `stride` != 1.
        It may also be set when the shapes already match, to add a learned
        1x1 projection to the shortcut.
    act_fn : nn.Module
        Activation used before each conv and after the residual addition.
        NOTE: the default instance is created once, when the class is defined,
        and is shared by every unit relying on the default (harmless for ReLU,
        which holds no state).
    kernel_size, padding : int
        Kernel size and padding of the two main-branch convs.

    Raises
    ------
    ValueError
        If `use_1x1conv` is False while `in_c != out_c` or `stride != 1`. The
        raw input could not be added to z2 (forward would fail or, for
        in_c == 1, silently broadcast).

    `forward(x)` returns
    --------------------
    out = model(x) returns a batch of tensors whose size is :
        (BS, out_c, in_h/stride, in_w/stride)
    """

    def __init__(self, in_c, out_c,
                 *,
                 stride=1,
                 use_1x1conv=False,
                 act_fn=nn.ReLU(inplace=True),
                 kernel_size=3,
                 padding=1):
        super().__init__()
        # Fail fast: without a projection the shortcut is the raw input, which is
        # only addable to z2 when channels and resolution are unchanged.
        if not use_1x1conv and (in_c != out_c or stride != 1):
            raise ValueError(
                "use_1x1conv must be True when in_c != out_c or stride != 1 "
                f"(got in_c={in_c}, out_c={out_c}, stride={stride})"
            )
        self.stride = stride
        self.use_1x1conv = use_1x1conv

        # Pre-activation: each BatchNorm normalises the *input* of its conv.
        self.bn1 = nn.BatchNorm2d(in_c)
        self.conv1 = nn.Conv2d(in_c, out_c,
                               kernel_size=kernel_size, padding=padding, stride=self.stride)

        self.bn2 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c,
                               kernel_size=kernel_size, padding=padding, stride=1)

        # Optional shortcut projection; None means the identity shortcut is used.
        self.conv3 = None
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_c, out_c,
                                   kernel_size=1, padding=0, stride=self.stride)
        self.act_fn = act_fn

    def forward(self, x):
        """
        Returns
        -------
        out = model(x) returns a batch of tensors whose size is :
        (BS, out_c, in_h/stride, in_w/stride)
        """
        z = self.conv1(self.act_fn(self.bn1(x)))
        z = self.conv2(self.act_fn(self.bn2(z)))
        shortcut = self.conv3(x) if self.use_1x1conv else x
        return self.act_fn(z + shortcut)
            
        
        

In [None]:
# Test Residual: in_c --> out_c is not the same (3 -> 32), and stride=2 halves (h, w)
in_shape = (3, 64, 64)
in_c = in_shape[0]
out_c = 32
stride = 2
m = Residual(in_c, out_c,
             use_1x1conv=True,
             stride=stride)

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, out_c, in_shape[1] // stride, in_shape[2] // stride), out.shape

In [None]:
# Test Residual: in_c --> out_c is the same (32 -> 32), but stride=2 still halves (h, w),
# so the 1x1 shortcut conv is required
in_shape = (32, 8, 8)
in_c = in_shape[0]
out_c = in_c
stride = 2

m = Residual(in_c, out_c,
             use_1x1conv=True,
             stride=stride)

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, out_c, in_shape[1] // stride, in_shape[2] // stride), out.shape

In [None]:
# Input -> first block (classic ResNet stem): conv7x7/2 then maxpool/2,
# i.e. (h, w) shrink by 4 and channels go to 64
in_shape = (3, 64, 64)
in_c = in_shape[0]

b1 = nn.Sequential(
    nn.Conv2d(in_c, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
m = b1

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, 64, in_shape[1] // 4, in_shape[2] // 4), out.shape

In [None]:
def get_resnet_block(in_c, out_c, *,
                     n_residuals=2,
                     first_block=False,
                     ) -> List[nn.Module]:
    """Build the list of `Residual` units that make up one ResNet block.

    The first unit maps `in_c` -> `out_c`. The remaining `n_residuals - 1` units
    keep `out_c` channels and the resolution unchanged (identity shortcut).
    With the typical `out_c = 2 * in_c`, each non-first block doubles the
    channels and halves the resolution (h, w).

    Parameters
    ----------
    in_c, out_c : int
        Input / output channels of the block.
    n_residuals : int
        Number of residual units in the block (the first unit is always created).
    first_block : bool
        If True, the first unit keeps (h, w) (stride 1), because the input has
        already been processed by a MaxPool layer. It uses a 1x1 shortcut conv
        only when the channel count changes. Otherwise the first unit halves
        (h, w) with stride 2 and always uses the 1x1 shortcut conv.

    Returns
    -------
    List[nn.Module]
        The residual units, in order, ready to be unpacked into `nn.Sequential`.
    """
    if first_block:
        head = Residual(in_c, out_c, stride=1, use_1x1conv=(in_c != out_c))
    else:
        head = Residual(in_c, out_c, stride=2, use_1x1conv=True)

    tail = [Residual(out_c, out_c, stride=1, use_1x1conv=False)
            for _ in range(n_residuals - 1)]
    return [head, *tail]
        

In [None]:
# Test a non-first ResNet block: the first unit halves (h, w) and maps channels 3 -> 6
in_shape = (3, 64, 64)
in_c = in_shape[0]

m = nn.Sequential(
    *get_resnet_block(in_c, 2*in_c, first_block=False)
)
x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, 2 * in_c, in_shape[1] // 2, in_shape[2] // 2), out.shape


## Let's create a full ResNet module
- b1 : convolution before residual blocks start
  input -> (conv-bn-relu) -> maxpool(1/2)
  


In [56]:
class ResNetEncoder(nn.Module):
    """Convolutional ResNet feature extractor (no Flatten / FC head).

    input -> b1: conv3x3(stride 1)-bn-act-maxpool(1/2)    # (BS, hidden_dims[0], h/2, w/2)
          -> resnet_blocks: nn.Sequential of the units built by `get_resnet_block`,
             one block per consecutive pair in `hidden_dims`. Each block has 2
             pre-activation residual units (bn-act-conv, bn-act-conv).
             The first block keeps (h, w); every later block halves it.

    Parameters
    ----------
    in_c : int
        Number of channels of the input image.
    hidden_dims : List[int]
        hidden_dims[0] is the stem's output channels. Each pair
        (hidden_dims[i], hidden_dims[i+1]) adds one block.
    act_fn : nn.Module
        Activation used in the stem `b1`.
        NOTE(review): it is not forwarded to the residual units, which use
        `Residual`'s default activation. Confirm this is intended.
    """
    def __init__(self,
                 in_c: int,
                 hidden_dims: List[int],
                 act_fn=nn.ReLU(inplace=True)):
        super().__init__()

        self.in_c = in_c
        self.hidden_dims = hidden_dims
        self.act_fn = act_fn

        n0 = hidden_dims[0]
        self.b1 = nn.Sequential(
            nn.Conv2d(in_c, n0, kernel_size=3, stride=1, padding=1),  # (bs, n0, h, w)
            nn.BatchNorm2d(n0),
            self.act_fn,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)           # (bs, n0, h/2, w/2)
        )

        blocks = []
        # Distinct loop names so the `in_c` argument is not shadowed.
        for i, (block_in, block_out) in enumerate(zip(hidden_dims, hidden_dims[1:])):
            # Only the first block skips downsampling (the stem's maxpool already did it).
            blocks.extend(get_resnet_block(block_in, block_out, first_block=(i == 0)))

        self.resnet_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """
        x -> b1 (conv-bn-act-maxpool(1/2)) -> resnet_blocks -> out

        Returns
        -------
        out : (BS, hidden_dims[-1], last_h, last_w)
            For len(hidden_dims) >= 2, last_h is about in_h / 2**(len(hidden_dims) - 1)
            (the stem halves it, the first block keeps it, each later block halves it).
        """
        out = self.b1(x)
        out = self.resnet_blocks(out)
        return out
    
                          
                    
            
        

In [57]:
# Test ResNetEncoder (requires the `Residual` and `get_resnet_block` cells above to have run)
in_shape = (3, 64, 64)
in_c = in_shape[0]
hidden_dims = [32, 32, 64, 128]
m = ResNetEncoder(in_c, hidden_dims)

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
# The stem halves (h, w), the first block keeps it, and every later block halves it.
n_halvings = len(hidden_dims) - 1
assert out.shape == (1, hidden_dims[-1],
                     in_shape[1] // 2**n_halvings, in_shape[2] // 2**n_halvings), out.shape

NameError: name 'get_resnet_block' is not defined

# Resnet Decoder
- Jan 29, 2021

## Residual unit with deconvolutions

In [51]:
class ResidualDeconv(nn.Module):
    """A single residual unit built from transposed convolutions (for decoders).

    Each transposed conv uses kernel_size=3 and padding=1 by default
    (overridable via `**kwargs`). The first transposed conv uses `stride` to
    upsample the input's (h, w) and maps the channels from in_c to out_c. The
    second transposed conv keeps the shape.


    input ---> [bn-act]-convTranspose2d -> z1 ---> bn-act-convTranspose2d -> z2 -(+)-> act -> out
            |                                                                  ^
            |                                                                  |
            ---------------------> (upsampling, optional) ----------------------

    The leading [bn-act] is skipped when `norm_input=False`.


    Parameters
    ----------
    in_c, out_c : int
        Number of input / output channels.
    stride : int
        Stride of the first convTranspose layer. Use stride = 2 to double the
        (h, w) of the input. With the default kernel_size=3 / padding=1,
        output_padding is set to `stride - 1`, so (h, w) is multiplied by
        exactly `stride`.
    use_upsampling : bool
        Pass the shortcut through `self.upsampling` so that its (h, w) and
        n_channel match the second deconv's output (z2). It must be True
        whenever in_c != out_c or `stride` > 1. When False, no upsampling
        module is created (no unused parameters), and the raw input is added.
    upsampling_type : str
        Only used when `use_upsampling` is True.
        'nearest': nearest-neighbour upsampling by `stride` (no extra
                   parameters), followed by a stride-1 convTranspose2d that
                   maps the channels.
        'deconv':  a single convTranspose2d with `stride` that adjusts both
                   (h, w) and the num of channels.
    norm_input : bool
        Apply bn-act to the input before the first deconv. Set it to False
        when the input is already normalised. In that case no `bn1` module is
        created (`self.bn1` is None).
    act_fn : nn.Module
        Activation used inside the unit and after the residual addition.
    **kwargs
        Extra keyword arguments for every nn.ConvTranspose2d (e.g. kernel_size,
        padding). They must not contain `stride` or `output_padding`, which are
        passed explicitly.

    Raises
    ------
    AssertionError
        If the input needs adjusting (in_c != out_c or stride > 1) but
        `use_upsampling` is False.
    ValueError
        If `use_upsampling` is True and `upsampling_type` is not recognised.


    `forward(x)` returns
    --------------------
    out = model(x) returns a batch of tensors whose size is :
        (BS, out_c, in_h * stride, in_w * stride)
    """

    def __init__(self, in_c, out_c,
                 *,
                 stride=2,
                 use_upsampling: bool = True,
                 upsampling_type: str = 'deconv',
                 norm_input: bool = True,
                 act_fn=nn.ReLU(inplace=True),
                 **kwargs
                 ):
        super().__init__()
        deconv_kwargs = {'kernel_size': 3, 'padding': 1}
        deconv_kwargs.update(kwargs)

        self.stride = stride
        # With kernel_size=3, padding=1 a transposed conv outputs
        # (h - 1) * stride + 1 + output_padding, so output_padding = stride - 1
        # gives exactly h * stride (0 for stride 1, 1 for stride 2).
        self.outp = stride - 1
        self.use_upsampling = use_upsampling
        if in_c != out_c or stride > 1:
            assert self.use_upsampling, "Input needs to be adjusted in (h,w) and/or num channels. Set use_upsampling=True"

        # Shortcut path; only built when it is actually used in `forward`.
        self.upsampling = None
        if self.use_upsampling:
            if upsampling_type == 'nearest':
                self.upsampling = nn.Sequential(
                    nn.UpsamplingNearest2d(scale_factor=self.stride),
                    nn.ConvTranspose2d(in_c, out_c, **deconv_kwargs, stride=1, output_padding=0)
                )
            elif upsampling_type == 'deconv':
                # One transposed conv does both the channel-wise and resolution-wise expansion.
                self.upsampling = nn.ConvTranspose2d(in_c, out_c, **deconv_kwargs,
                                                     stride=self.stride, output_padding=self.outp)
            else:
                raise ValueError(
                    f"upsampling_type must be 'nearest' or 'deconv', got {upsampling_type!r}"
                )
        self.norm_input = norm_input

        # bn1 normalises the unit's input; it is only needed when norm_input is True.
        self.bn1 = nn.BatchNorm2d(in_c) if norm_input else None
        self.deconv1 = nn.ConvTranspose2d(in_c, out_c, **deconv_kwargs,
                                          stride=self.stride, output_padding=self.outp)

        self.bn2 = nn.BatchNorm2d(out_c)
        self.deconv2 = nn.ConvTranspose2d(out_c, out_c, **deconv_kwargs, stride=1, output_padding=0)

        self.act_fn = act_fn

    def forward(self, x):
        """
        Returns
        -------
        out = model(x) returns a batch of tensors whose size is :
        (BS, out_c, in_h * stride, in_w * stride)
        """
        z = self.act_fn(self.bn1(x)) if self.norm_input else x
        z = self.deconv1(z)
        z = self.deconv2(self.act_fn(self.bn2(z)))

        shortcut = self.upsampling(x) if self.use_upsampling else x
        return self.act_fn(z + shortcut)
            
        
        

In [52]:
# Test ResidualDeconv: in_c --> out_c is not the same (3 -> 32), and stride=2 doubles (h, w)
in_shape = (3, 64, 64)
in_c = in_shape[0]
out_c = 32
stride = 2
use_upsampling = True
upsampling_type = 'deconv'
norm_input = True
m = ResidualDeconv(in_c,
                   out_c,
                   use_upsampling=use_upsampling,
                   upsampling_type=upsampling_type,
                   norm_input=norm_input,
                   stride=stride
                   )

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, out_c, in_shape[1] * stride, in_shape[2] * stride), out.shape

torch.Size([1, 3, 64, 64]) torch.Size([1, 32, 128, 128])


In [53]:
# Test ResidualDeconv: in_c --> out_c is the same (32 -> 32), stride=2 doubles (h, w),
# and the input is not normalised (as for the first unit of a decoder)
in_shape = (32, 8, 8)
in_c = in_shape[0]
out_c = in_c
stride = 2
use_upsampling = True
upsampling_type = 'deconv'
norm_input = False

m = ResidualDeconv(in_c,
                   out_c,
                   use_upsampling=use_upsampling,
                   upsampling_type=upsampling_type,
                   norm_input=norm_input,
                   stride=stride
                   )

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, out_c, in_shape[1] * stride, in_shape[2] * stride), out.shape

torch.Size([1, 32, 8, 8]) torch.Size([1, 32, 16, 16])


## ResNetDeconv Blocks


In [None]:
# Input -> first block: same encoder stem as in the ResNet Encoder section
# (conv7x7/2 then maxpool/2, i.e. (h, w) shrink by 4 and channels go to 64)
in_shape = (3, 64, 64)
in_c = in_shape[0]

b1 = nn.Sequential(
    nn.Conv2d(in_c, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
m = b1

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, 64, in_shape[1] // 4, in_shape[2] // 4), out.shape

In [54]:
def get_resnet_deconv_block(
    in_c, out_c, *,
    n_residuals=2,
    first_block=False,
) -> List[nn.Module]:
    """Build the list of `ResidualDeconv` units that make up one decoder block.

    The first unit doubles the resolution (h, w) (stride 2, upsampling shortcut)
    and maps `in_c` -> `out_c` channels. The remaining `n_residuals - 1` units
    keep both the channels and the resolution (stride 1, identity shortcut).

    Parameters
    ----------
    in_c, out_c : int
        Input / output channels of the block.
    n_residuals : int
        Number of residual units in the block (the first unit is always created).
    first_block : bool
        If True, the first unit does not apply bn-act to its input
        (`norm_input=False`), on the assumption that the decoder's input has
        already been processed with a batchnorm.

    Returns
    -------
    List[nn.Module]
        The residual units, in order, ready to be unpacked into `nn.Sequential`.
    """
    head = ResidualDeconv(in_c, out_c, stride=2, use_upsampling=True,
                          norm_input=not first_block)
    tail = [
        ResidualDeconv(out_c, out_c, stride=1, use_upsampling=False, norm_input=True)
        for _ in range(n_residuals - 1)
    ]
    return [head, *tail]
        

In [55]:
# Test a non-first deconv block: the first unit doubles (h, w) and maps channels 3 -> 6
in_shape = (3, 64, 64)
in_c = in_shape[0]

m = nn.Sequential(
    *get_resnet_deconv_block(in_c, 2*in_c, first_block=False)
)
x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
assert out.shape == (1, 2 * in_c, in_shape[1] * 2, in_shape[2] * 2), out.shape


torch.Size([1, 3, 64, 64]) torch.Size([1, 6, 128, 128])


## Let's create a full ResNetDecoder module
Each block contains 2 residual deconv units.
If the block is the first block in the net, then the first unit of the block does not apply
batchnorm to its input.

In each block, the first residual unit doubles the resolution (height, width) by setting
  stride = 2
and maps the number of channels from in_c to out_c (in a decoder, typically out_c < in_c, e.g. out_c = in_c // 2).
  


In [68]:
class ResNetDecoder(nn.Module):
    """Stack of residual transposed-conv blocks: x -> resnet_blocks -> out.

    `resnet_blocks` is an nn.Sequential with one block (built by
    `get_resnet_deconv_block`) per consecutive pair in `nfs`. Each block has 2
    residual units (bn-act-deconv, bn-act-deconv). The first unit of the first
    block skips the bn-act on its input. Every block doubles the resolution, so
    an input of shape (BS, nfs[0], h, w) becomes
    (BS, nfs[-1], h * 2**(len(nfs) - 1), w * 2**(len(nfs) - 1)).

    Parameters
    ----------
    nfs : List[int]
        Channel sizes, from the input's channels nfs[0] to the output's nfs[-1].
    act_fn : nn.Module
        Stored as `self.act_fn`.
        NOTE(review): it is not forwarded to the residual units, which use
        `ResidualDeconv`'s default activation. Confirm this is intended.
    """
    def __init__(self,
                 nfs: List[int],
                 act_fn=nn.ReLU(inplace=True)):
        super().__init__()
        self.act_fn = act_fn

        blocks = []
        for i, (in_c, out_c) in enumerate(zip(nfs, nfs[1:])):
            blocks.extend(get_resnet_deconv_block(in_c, out_c, first_block=(i == 0)))
        # Registered under the same attribute name that `forward` uses.
        self.resnet_blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """
        x -> resnet_blocks -> out

        Returns
        -------
        out : (BS, nfs[-1], in_h * 2**(len(nfs) - 1), in_w * 2**(len(nfs) - 1))
        """
        return self.resnet_blocks(x)
    
                          
                    
            
        

In [79]:
# NOTE: this imports the packaged implementation, which shadows the notebook's
# `ResNetDecoder` class defined above.
from src.models.resnet_deconv import ResNetDecoder

in_shape = (128, 2, 2)
in_c = in_shape[0]
nfs = [in_c, in_c//2, in_c//4, 3]
m = ResNetDecoder(nfs)

x = torch.ones(1, *in_shape)
out = m(x)

print(x.shape, out.shape)
# Every block doubles (h, w): len(nfs) - 1 blocks in total.
n_doublings = len(nfs) - 1
assert out.shape == (1, nfs[-1],
                     in_shape[1] * 2**n_doublings, in_shape[2] * 2**n_doublings), out.shape

torch.Size([1, 128, 2, 2]) torch.Size([1, 3, 16, 16])
