In [11]:
# imports
from pathlib import Path
import sys  

# Get my_package directory path from Notebook
parent_dir = str(Path().resolve().parents[0])

# Add to sys.path
sys.path.insert(0, parent_dir)

In [20]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch import distributions as torchd

import tools
import einops
from typing import Union
from torchinfo import summary

In [21]:
parent_dir

'/Users/intuinno/codegit/cwvae-torch'

In [22]:
from networks import Conv3dVAE

In [23]:
model = Conv3dVAE()

Layer (type:depth-idx)                   Output Shape              Param #
Conv3dVAE                                [10, 1000, 67, 67, 1]     --
├─Sequential: 1-1                        [2500, 4, 2, 16, 16]      --
│    └─Conv3d: 2-1                       [2500, 2, 3, 32, 32]      66
│    └─ELU: 2-2                          [2500, 2, 3, 32, 32]      --
│    └─Conv3d: 2-3                       [2500, 4, 2, 16, 16]      260
├─Sequential: 1-2                        [2500, 1, 4, 67, 67]      --
│    └─ConvTranspose3d: 2-4              [2500, 2, 3, 33, 33]      258
│    └─ConvTranspose3d: 2-5              [2500, 1, 4, 67, 67]      65
Total params: 649
Trainable params: 649
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 5.86
Input size (MB): 163.84
Forward/backward pass size (MB): 653.64
Params size (MB): 0.00
Estimated Total Size (MB): 817.48

In [31]:
class Conv3dVAE(nn.Module):
  
  def __init__(self, channels_factor=2, 
               num_conv_layers=2, 
               act=nn.ELU,
               kernels=(2,4,4),
               stride=(1,2,2),
               input_num_seq=4,
               input_width=64,
               input_height=64,
               input_channels=1,
               temp_abs_factor=4):
    super(Conv3dVAE, self).__init__()
    self._act = act 
    
    enc_layers =[]
    in_channels = input_channels
    for level in range(num_conv_layers):
      out_channels = in_channels * channels_factor
      enc_layers.append(nn.Conv3d(in_channels, 
                                  out_channels,
                                  kernels, stride, 
                                  padding=(0,1,1)
                                  ))
      if level < num_conv_layers-1:
        enc_layers.append(act())
      in_channels = out_channels 
    self.encoder = nn.Sequential(*enc_layers)
    
    
    dec_layers =[]
    in_channels = out_channels 
    for _ in range(num_conv_layers):
      out_channels = in_channels // channels_factor 
      dec_layers.append(nn.ConvTranspose3d(in_channels, 
                                           out_channels,
                                           kernels,
                                           stride,
                                           padding=(0,1,1),
#                                            output_padding=(0,1,1),
                                           ))
      if level < num_conv_layers-1:
        dec_layers.append(act())
      in_channels = out_channels
    self.decoder = nn.Sequential(*dec_layers)
    self._temp_abs_factor = temp_abs_factor
    
  def forward(self, x):
    # Assume x is (b t h w c)
    B, T, H, W, C = x.shape
    t1 = T // self._temp_abs_factor
    x = einops.rearrange(x, 'b (t t2) h w c -> (b t) c t2 h w', t2=self._temp_abs_factor) 
    z = self.encoder(x)
    z = torch.clip(z, -0.5, 0.5)
    # logits = einops.rearrange(logits, 'b c t h w -> b t h w c')
    # dist = torchd.OneHotCategoricalStraightThrough(logits=logits)
    # dist = torchd.independent.Independent(dist, 3)
    # z = dist.rsample()
    # dec_z = einops.rearrange(z, 'b t h w c -> b c t h w')
    recon = self.decoder(z)
    recon = torch.clip(recon, -0.5, 0.5)
    recon = einops.rearrange(recon, '(b t1) c t2 h w -> b (t1 t2) h w c', t1=t1)
    z = einops.rearrange(z, '(b t1) c t h w -> b t1 h w (c t)', t1=t1)
    return recon, z
  
  def decode(self, emb):
    # Assume emb is (b t h w c)
    B, T, H, W, C = emb.shape
    t2 = C // self._temp_abs_factor
    z = einops.rearrange(emb, 'b t h w (c t2) -> (b t) c t2 h w', t2 = t2)
    recon = self.decoder(z)
    recon = einops.rearrange(recon, '(b t1) c t h w -> b (t1 t) h w c', t1=T)
    return recon 
      
    
      
    

In [32]:
model = Conv3dVAE()
summary(model, input_size=(10,1000,64,64,1))

Layer (type:depth-idx)                   Output Shape              Param #
Conv3dVAE                                [10, 1000, 64, 64, 1]     --
├─Sequential: 1-1                        [2500, 4, 2, 16, 16]      --
│    └─Conv3d: 2-1                       [2500, 2, 3, 32, 32]      66
│    └─ELU: 2-2                          [2500, 2, 3, 32, 32]      --
│    └─Conv3d: 2-3                       [2500, 4, 2, 16, 16]      260
├─Sequential: 1-2                        [2500, 1, 4, 64, 64]      --
│    └─ConvTranspose3d: 2-4              [2500, 2, 3, 32, 32]      258
│    └─ConvTranspose3d: 2-5              [2500, 1, 4, 64, 64]      65
Total params: 649
Trainable params: 649
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 5.48
Input size (MB): 163.84
Forward/backward pass size (MB): 614.40
Params size (MB): 0.00
Estimated Total Size (MB): 778.24