In [1]:
# imports
from pathlib import Path
import sys  

# Get my_package directory path from Notebook
parent_dir = str(Path().resolve().parents[0])

# Add to sys.path
sys.path.insert(0, parent_dir)

In [2]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd

import tools
import einops
from typing import Union
from torchinfo import summary

In [3]:
parent_dir

'/Users/intuinno/codegit/cwvae-torch'

In [4]:
from networks import Conv3dVAE

In [5]:
model = Conv3dVAE()

In [6]:
class Conv3dVAE(nn.Module):
  
  def __init__(self, channels_factor=2, 
               num_conv_layers=2, 
               act=nn.ELU,
               kernels=(4,4,4),
               stride=(2,2,2),
               input_num_seq=4,
               input_width=64,
               input_height=64,
               input_channels=1,
               temp_abs_factor=4):
    super(Conv3dVAE, self).__init__()
    self._act = act 
    
    enc_layers =[]
    in_channels = input_channels
    for level in range(num_conv_layers):
      out_channels = in_channels * channels_factor
      enc_layers.append(nn.Conv3d(in_channels, 
                                  out_channels,
                                  kernels, stride, 
                                  padding=(1,1,1)
                                  ))
      if level < num_conv_layers-1:
        enc_layers.append(act())
      in_channels = out_channels 
    self.encoder = nn.Sequential(*enc_layers)
    
    
    dec_layers =[]
    in_channels = out_channels 
    for _ in range(num_conv_layers):
      out_channels = in_channels // channels_factor 
      dec_layers.append(nn.ConvTranspose3d(in_channels, 
                                           out_channels,
                                           kernels,
                                           stride,
                                           padding=(1,1,1),
#                                            output_padding=(0,1,1),
                                           ))
      if level < num_conv_layers-1:
        dec_layers.append(act())
      in_channels = out_channels
    self.decoder = nn.Sequential(*dec_layers)
    self._temp_abs_factor = temp_abs_factor
    
  def forward(self, x):
    # Assume x is (b t h w c)
    B, T, H, W, C = x.shape
    t1 = T // self._temp_abs_factor
    x = einops.rearrange(x, 'b t h w c -> b  c t h w') 
    z = self.encoder(x)
    z = torch.clip(z, -0.5, 0.5)
    # logits = einops.rearrange(logits, 'b c t h w -> b t h w c')
    # dist = torchd.OneHotCategoricalStraightThrough(logits=logits)
    # dist = torchd.independent.Independent(dist, 3)
    # z = dist.rsample()
    # dec_z = einops.rearrange(z, 'b t h w c -> b c t h w')
    recon = self.decoder(z)
    recon = torch.clip(recon, -0.5, 0.5)
    recon = einops.rearrange(recon, 'b c t h w -> b t h w c')
    z = einops.rearrange(z, 'b c t h w -> b t h w c')
    return recon, z
  
  def decode(self, emb):
    # Assume emb is (b t h w c)
    B, T, H, W, C = emb.shape
    t2 = C // self._temp_abs_factor
    z = einops.rearrange(emb, 'b t h w (c t2) -> (b t) c t2 h w', t2 = t2)
    recon = self.decoder(z)
    recon = einops.rearrange(recon, '(b t1) c t h w -> b (t1 t) h w c', t1=T)
    return recon 
      
    
      
    

In [7]:
model = Conv3dVAE(channels_factor=4)
summary(model, input_size=(10,4,64,64,1))

  action_fn=lambda data: sys.getsizeof(data.storage()),


Layer (type:depth-idx)                   Output Shape              Param #
Conv3dVAE                                [10, 4, 64, 64, 1]        --
├─Sequential: 1-1                        [10, 16, 1, 16, 16]       --
│    └─Conv3d: 2-1                       [10, 4, 2, 32, 32]        260
│    └─ELU: 2-2                          [10, 4, 2, 32, 32]        --
│    └─Conv3d: 2-3                       [10, 16, 1, 16, 16]       4,112
├─Sequential: 1-2                        [10, 1, 4, 64, 64]        --
│    └─ConvTranspose3d: 2-4              [10, 4, 2, 32, 32]        4,100
│    └─ConvTranspose3d: 2-5              [10, 1, 4, 64, 64]        257
Total params: 8,729
Trainable params: 8,729
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 141.93
Input size (MB): 0.66
Forward/backward pass size (MB): 2.95
Params size (MB): 0.03
Estimated Total Size (MB): 3.64

In [13]:
class Conv3dVAE(nn.Module):
  
  def __init__(self, channels_factor=4, 
               num_conv_layers=2, 
               act=nn.ELU,
               kernels=(3,3,3),
               stride=(2,2,2),
               input_num_seq=4,
               input_width=64,
               input_height=64,
               input_channels=1,
               temp_abs_factor=4):
    super(Conv3dVAE, self).__init__()
    self._act = act 
    
    enc_layers =[]
    in_channels = input_channels
    for level in range(num_conv_layers):
      out_channels = in_channels * channels_factor
      enc_layers.append(nn.Conv3d(in_channels, 
                                  out_channels,
                                  kernels, stride, 
                                  padding=(0,1,1)
                                  ))
      if level < num_conv_layers-1:
        enc_layers.append(act())
      in_channels = out_channels 
    enc_layers.append(nn.Tanh())
    self.encoder = nn.Sequential(*enc_layers)
    
    
    dec_layers =[]
    in_channels = out_channels 
    for level in range(num_conv_layers):
      out_channels = in_channels // channels_factor 
      dec_layers.append(nn.ConvTranspose3d(in_channels, 
                                           out_channels,
                                           kernels,
                                           stride,
                                           padding=(0,1,1),
                                           output_padding=(0,1,1),
                                           ))
      if level < num_conv_layers-1:
        dec_layers.append(act())
      in_channels = out_channels
    dec_layers.append(nn.Tanh())
    self.decoder = nn.Sequential(*dec_layers)
    self._temp_abs_factor = temp_abs_factor
    
  def forward(self, x):
    # Assume x is (b t h w c)
    x = einops.rearrange(x, 'b t h w c -> b c t h w' ) 
    z = self.encoder(x)
    recon = self.decoder(z)
    recon = einops.rearrange(recon, 'b c t h w -> b t h w c')
    z = einops.rearrange(z, 'b c t h w -> b t h w c ')
    return recon, z
  
  def decode(self, emb):
    # Assume emb is (b t h w c)
    z = einops.rearrange(emb, 'b t h w c -> b c t h w')
    recon = self.decoder(z)
    recon = einops.rearrange(recon, 'b c t h w -> b t h w c')
    return recon 
      
    

In [14]:
model = Conv3dVAE(channels_factor=4, input_channels=16)
summary(model, input_size=(8,252,16,16,16))

Layer (type:depth-idx)                   Output Shape              Param #
Conv3dVAE                                [8, 251, 16, 16, 16]      --
├─Sequential: 1-1                        [8, 256, 62, 4, 4]        --
│    └─Conv3d: 2-1                       [8, 64, 125, 8, 8]        27,712
│    └─ELU: 2-2                          [8, 64, 125, 8, 8]        --
│    └─Conv3d: 2-3                       [8, 256, 62, 4, 4]        442,624
│    └─Tanh: 2-4                         [8, 256, 62, 4, 4]        --
├─Sequential: 1-2                        [8, 16, 251, 16, 16]      --
│    └─ConvTranspose3d: 2-5              [8, 64, 125, 8, 8]        442,432
│    └─ELU: 2-6                          [8, 64, 125, 8, 8]        --
│    └─ConvTranspose3d: 2-7              [8, 16, 251, 16, 16]      27,664
│    └─Tanh: 2-8                         [8, 16, 251, 16, 16]      --
Total params: 940,432
Trainable params: 940,432
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 47.82
Input size (MB): 33.03
F

In [24]:
class Conv3dVAE(nn.Module):
  
  def __init__(self, channels_factor=4, 
               num_conv_layers=4, 
               act=nn.GELU,
               kernels=(3,3,3),
               stride=(2,2,2),
               input_num_seq=4,
               input_width=64,
               input_height=64,
               input_channels=1,
               temp_abs_factor=4):
    super(Conv3dVAE, self).__init__()
    
    c_hid = channels_factor * input_channels 
    
    self.encoder = nn.Sequential(
        nn.Conv3d(input_channels, c_hid, kernel_size=3, padding=1, stride=2),  # 64x64 => 32x32
        act(),
        nn.Conv3d(c_hid, c_hid, kernel_size=3, padding=1),
        act(),
        nn.Conv3d(c_hid, channels_factor * c_hid, kernel_size=3, padding=1, stride=2),  # 32x32 => 16x16
        act(),
        nn.Conv3d(channels_factor * c_hid, channels_factor * c_hid, kernel_size=3, padding=1),
        nn.Tanh(),
        )


    self.decoder = nn.Sequential(
        nn.ConvTranspose3d(
            channels_factor * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2
        ),  # 16x16 => 32x32
        act(),
        nn.Conv3d( c_hid,  c_hid, kernel_size=3, padding=1),
        act(),
        nn.ConvTranspose3d(c_hid, input_channels, kernel_size=3, output_padding=1, padding=1, stride=2),  # 32x32 => 64x64
        act(),
        nn.Conv3d(input_channels, input_channels, kernel_size=3, padding=1),
        nn.Tanh(),  # The input images is scaled between -1 and 1, hence the output has to be bounded as well
    )
    
    
    
    self._temp_abs_factor = temp_abs_factor
    
  def forward(self, x):
    # Assume x is (b t h w c)
    x = einops.rearrange(x, 'b t h w c -> b c t h w' ) 
    z = self.encoder(x)
    recon = self.decoder(z)
    recon = einops.rearrange(recon, 'b c t h w -> b t h w c')
    z = einops.rearrange(z, 'b c t h w -> b t h w c ')
    return recon, z
  
  def decode(self, emb):
    # Assume emb is (b t h w c)
    z = einops.rearrange(emb, 'b t h w c -> b c t h w')
    recon = self.decoder(z)
    recon = einops.rearrange(recon, 'b c t h w -> b t h w c')
    return recon 
      

In [25]:
model = Conv3dVAE(channels_factor=4, input_channels=1)
summary(model, input_size=(8,1000,64,64,1))

Layer (type:depth-idx)                   Output Shape              Param #
Conv3dVAE                                [8, 1000, 64, 64, 1]      --
├─Sequential: 1-1                        [8, 16, 250, 16, 16]      --
│    └─Conv3d: 2-1                       [8, 4, 500, 32, 32]       112
│    └─GELU: 2-2                         [8, 4, 500, 32, 32]       --
│    └─Conv3d: 2-3                       [8, 4, 500, 32, 32]       436
│    └─GELU: 2-4                         [8, 4, 500, 32, 32]       --
│    └─Conv3d: 2-5                       [8, 16, 250, 16, 16]      1,744
│    └─GELU: 2-6                         [8, 16, 250, 16, 16]      --
│    └─Conv3d: 2-7                       [8, 16, 250, 16, 16]      6,928
│    └─Tanh: 2-8                         [8, 16, 250, 16, 16]      --
├─Sequential: 1-2                        [8, 1, 1000, 64, 64]      --
│    └─ConvTranspose3d: 2-9              [8, 4, 500, 32, 32]       1,732
│    └─GELU: 2-10                        [8, 4, 500, 32, 32]       --
│   