In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import sys

notebook_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(notebook_dir, '..'))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
print(sys.path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%reload_ext autoreload
%autoreload 2

['/mnt/ceph/users/blyo1/projects', '/mnt/sw/nix/store/71ksmx7k6xy3v9ksfkv5mp5kxxp64pd6-python-3.10.13-view/lib/python310.zip', '/mnt/sw/nix/store/71ksmx7k6xy3v9ksfkv5mp5kxxp64pd6-python-3.10.13-view/lib/python3.10', '/mnt/sw/nix/store/71ksmx7k6xy3v9ksfkv5mp5kxxp64pd6-python-3.10.13-view/lib/python3.10/lib-dynload', '', '/mnt/home/blyo1/venvs/py310/lib/python3.10/site-packages']


# hVAE

In [59]:
from b_models.hvae import ConvBlock, TopDownBlock, BottomUpBlock, LadderVAE

lvae = LadderVAE(input_dim=32, z_dims = [2, 3], c_in = [1, 5], c_out = [5, 8], num_blocks=1)

# c_in, c_out has to match each other, like this: c_in = [a, b, c], c_out = [b, c, d]
# which means we can just use one array that lists the channels and iterate through that.

lvae

LadderVAE(
  (encoder): ModuleList(
    (0): BottomUpBlock(
      (conv_block): ConvBlock(
        (pre_conv): Conv2d(1, 5, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (block): Sequential(
          (0): BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
          (2): Conv2d(5, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (compress_block): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Flatten(start_dim=1, end_dim=-1)
        (2): Linear(in_features=5, out_features=4, bias=True)
      )
    )
    (1): BottomUpBlock(
      (conv_block): ConvBlock(
        (pre_conv): Conv2d(5, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (block): Sequential(
          (0): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (1): ReLU()
          (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      

In [56]:
# count number of parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Number of trainable parameters:", count_parameters(lvae))

Number of trainable parameters: 27615


In [53]:
test_input = torch.randn(1, 1, 32, 32)

lvae(test_input)

i: 0
torch.Size([1, 1, 32, 32])
torch.Size([1, 8, 16, 16])
torch.Size([1, 8, 16, 16])
i: 1
torch.Size([1, 8, 16, 16])
torch.Size([1, 8, 8, 8])
torch.Size([1, 8, 8, 8])
j: 0
torch.Size([1, 8, 8, 8])
torch.Size([1, 8, 16, 16])
torch.Size([1, 8, 16, 16])
j: 1
torch.Size([1, 8, 16, 16])
torch.Size([1, 1, 32, 32])
torch.Size([1, 1, 32, 32])


tensor([[[[-0.2183, -0.0263, -0.0677,  ...,  0.1675,  0.2517,  0.1360],
          [-0.1993,  0.1291, -0.1218,  ...,  0.1240,  0.2605,  0.3418],
          [-0.0156, -0.2398,  0.0745,  ...,  0.3267,  0.3944,  0.2639],
          ...,
          [-0.2797, -0.0138,  0.0043,  ...,  0.1777,  0.1372,  0.2306],
          [-0.0066,  0.1069,  0.0167,  ...,  0.3334,  0.2229,  0.3079],
          [ 0.0056,  0.0389, -0.0461,  ...,  0.1615,  0.2783,  0.0793]]]],
       grad_fn=<ConvolutionBackward0>)

In [27]:
import torch.nn.functional as F


F.conv_transpose2d(torch.randn(1, 3), 5, 4, 2, 1)

TypeError: conv_transpose2d(): argument 'weight' (position 2) must be Tensor, not int

# scrap

In [14]:
import argparse
from b_models.vae import VariableConvEncoder, VariableConvDecoder, VAE, UNetDecoder

num_channels = 1
image_dims = 32
latent_dims = 5

enc = VariableConvEncoder(
    num_channels,
    image_dims,
    latent_dims, 
    channels=[32, 32, 32], 
    bias=True,
)

dec = VariableConvDecoder(
    num_channels,
    image_dims,
    latent_dims,
    channels=[32, 32, 32],  # output_channels, kernel_size, stride, padding
    output_channels=1,
    bias=True,
)

vae = VAE(
    encoder=enc,
    decoder=dec,
    kl_reduction="mean"
)

In [44]:
from b_models.vae import VariableConvEncoder

enc = VariableConvEncoder(num_channels=1, image_dims=32, latent_dims=5, channels=[4, 8, 16])
print(enc.d1_input_dims)
print(enc.d2_input_dims)
print(enc.d3_input_dims)

print(enc.conv_d1.weight.shape)
print(enc.conv_d2.weight.shape)
print(enc.conv_d3.weight.shape)

print(enc.linear_mu.weight.shape)

32
16
8
torch.Size([4, 1, 3, 3])
torch.Size([8, 4, 3, 3])
torch.Size([16, 8, 3, 3])
torch.Size([5, 256])


In [119]:
from b_models.hvae import EncoderConvBlock

example_input = torch.randn(1, 1, 32, 32)
enc_block = EncoderConvBlock(num_channels=1, img_dim=example_input.shape[3], latent_dim=5, channels=[4, 8, 16])

# for i in [0, 2, 4]:
#     print(enc_block.conv_network[i].weight.shape)
# print(enc_block.linear_mu.weight.shape)
d, mu, var = enc_block(example_input)

torch.Size([1, 1, 32, 32])
torch.Size([1, 4, 16, 16])
torch.Size([1, 4, 16, 16])
torch.Size([1, 8, 8, 8])
torch.Size([1, 8, 8, 8])
torch.Size([1, 16, 4, 4])
torch.Size([1, 16, 4, 4])


In [5]:
# testing adaptive avg pool
rand_input = torch.randn(1, 16, 7, 7)
import torch.nn.functional as F
output = F.adaptive_avg_pool2d(rand_input, 1)
print(output.shape)

torch.Size([1, 16, 1, 1])


In [3]:
from b_models.hvae import SpatialEncoderConvBlock

example_input = torch.randn(1, 1, 32, 32)
enc_block = SpatialEncoderConvBlock(num_channels=1, img_dim=example_input.shape[3], latent_dim=5, channels=[4, 8, 16])

# for i in [0, 2, 4]:
#     print(enc_block.conv_network[i].weight.shape)
# print(enc_block.linear_mu.weight.shape)
d, mu, var = enc_block(example_input)

torch.Size([1, 1, 32, 32])
torch.Size([1, 4, 16, 16])
torch.Size([1, 4, 16, 16])
torch.Size([1, 8, 8, 8])
torch.Size([1, 8, 8, 8])
torch.Size([1, 16, 4, 4])
torch.Size([1, 16, 4, 4])
torch.Size([1, 5])


In [94]:
from b_models.hvae import DecoderConvBlock

dec_block = DecoderConvBlock(output_dim=16, latent_dim=5, channels=[16, 8, 4])

example_input = torch.randn(1, 5)
dec_out = dec_block(example_input)
print(dec_out.shape)

# for i in range(len(dec_block.deconv_network)):
#     if hasattr(dec_block.deconv_network[i], 'weight'):
#         print(dec_block.deconv_network[i].weight.shape)

torch.Size([1, 4, 16, 16])


In [117]:
from b_models.hvae import DecoderConvBlock

dec_block = DecoderConvBlock(latent_in_dim=5, latent_out_dim=2, output_dim=8, channels=[4, 8, 16])

example_input = torch.randn(1, 5)
dec_out, _, _ = dec_block(example_input)
print(dec_out.shape)

# for i in range(len(dec_block.deconv_network)):
#     if hasattr(dec_block.deconv_network[i], 'weight'):
#         print(dec_block.deconv_network[i].weight.shape)

torch.Size([1, 5])
torch.Size([1, 16])
torch.Size([1, 16, 1, 1])
torch.Size([1, 8, 2, 2])
torch.Size([1, 8, 2, 2])
torch.Size([1, 4, 4, 4])
torch.Size([1, 4, 4, 4])
torch.Size([1, 16, 8, 8])
torch.Size([1, 16, 8, 8])
end of deconv network
torch.Size([1, 2])
