# Import Modules

In [1]:
import glob
import os
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import gc

import pandas as pd
from functools import partial
from tqdm.notebook import trange, tqdm
import umap
# import umap.plot 
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
torch.backends.cudnn.benchmark = True  # let cuDNN benchmark and cache the fastest conv algorithms
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The dataset classes come from a separate PyTorch-VAE checkout; put it first on sys.path
# so `datasets` resolves to that repository.
import sys
sys.path.insert(0, os.path.join(os.path.expanduser('~/Research/Github/'), 'PyTorch-VAE'))
from datasets import WCDataset, WCShotgunDataset, WC3dDataset, WCRNNDataset

import plotly.express as px
import plotly.graph_objects as go

import matplotlib as mpl

# Figure defaults: large Arial text, thick axes and ticks, no top/right spines.
mpl.rcParams.update({
    'font.size': 24,
    'font.family': "sans-serif",
    'font.sans-serif': "Arial",
    'axes.linewidth': 3,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'xtick.major.size': 5,
    'xtick.major.width': 2,
    'ytick.major.size': 5,
    'ytick.major.width': 2,
})

########## Checks if path exists, if not then creates directory ##########
def check_path(basepath, path):
    """Return ``basepath/path``, creating that directory if it does not exist yet.

    If ``path`` already appears inside ``basepath`` as a run of whole path components
    (e.g. ``check_path('/root/Figures', 'Figures')``), ``basepath`` is returned unchanged
    so the directory is not nested twice. Matching is per component, not per substring,
    so ``check_path('/root/run_10', 'run_1')`` still creates ``/root/run_10/run_1``.

    :param basepath: (str) Parent directory.
    :param path: (str) Sub-directory name (or relative path) to ensure under ``basepath``.
    :return: (str) The resolved directory path.
    """
    if not path:
        return basepath
    base_parts = os.path.normpath(basepath).split(os.sep)
    sub_parts = os.path.normpath(path).split(os.sep)
    width = len(sub_parts)
    # Already inside `path`? (contiguous component match anywhere in basepath)
    if any(base_parts[i:i + width] == sub_parts
           for i in range(len(base_parts) - width + 1)):
        return basepath
    target = os.path.join(basepath, path)
    if not os.path.exists(target):
        # exist_ok guards against the directory appearing between the check and the create
        os.makedirs(target, exist_ok=True)
        print('Added Directory:' + target)
    return target

# Project root; the helpers below resolve paths relative to it.
rootdir = os.path.expanduser('~/Research/FMEphys/')

# Root-bound helpers:
#   join(*parts)  -> os.path.join(rootdir, *parts)
#   checkDir(sub) -> check_path(rootdir, sub), i.e. rootdir/sub created if missing
join, checkDir = partial(os.path.join, rootdir), partial(check_path, rootdir)
FigurePath = checkDir('Figures')

savefigs = False  # presumably toggles figure saving in later cells

In [2]:
from ConvLSTM import ConvLSTM
from experiment import VAEXperiment
import yaml
import torchvision

In [3]:
import re

n = -1              # index into the epoch-sorted checkpoint list (-1 = most recent)
version = 7         # experiment version under logs2/VAE3dmp; also names the figure subdirectory
modeltype = '3dmp'  # one of 'shotgun', 'vanilla', '3d', '3dmp'

githubdir = os.path.expanduser('~/Research/Github/')
logdir = os.path.expanduser('~/Research/FMEphys/logs2/')

# model type -> (config YAML, checkpoint glob pattern)
model_paths = {
    'shotgun': (os.path.join(githubdir, 'PyTorch-VAE', 'configs/WC_vae_shotgun.yaml'),
                os.path.join(logdir, 'VanillaVAE/version_3/checkpoints/*.ckpt')),
    'vanilla': (os.path.join(githubdir, 'PyTorch-VAE', 'configs/WC_vae.yaml'),
                os.path.join(logdir, 'VanillaVAE/version_0/checkpoints/*.ckpt')),
    '3d':      (os.path.join(githubdir, 'PyTorch-VAE', 'configs/WC_vae3d.yaml'),
                os.path.join(logdir, 'VAE3d/version_4/checkpoints/*.ckpt')),
    '3dmp':    (os.path.join(logdir, 'VAE3dmp/version_{:d}/WC_vae3dmp.yaml'.format(version)),
                os.path.join(logdir, 'VAE3dmp/version_{:d}/checkpoints/*.ckpt'.format(version))),
}
if modeltype not in model_paths:
    raise ValueError(f'{modeltype} is not a valid model type')
filename, ckpt_glob = model_paths[modeltype]


def _ckpt_order(path):
    """Sort key (epoch, step) parsed from names like 'epoch=23-step=82521.ckpt'; (-1, -1) if absent."""
    match = re.search(r'epoch=(\d+)-step=(\d+)', os.path.basename(path))
    return (int(match.group(1)), int(match.group(2))) if match else (-1, -1)


# glob returns files in arbitrary order, and plain string sorting would put 'epoch=9'
# after 'epoch=23', so order numerically by (epoch, step) before indexing with n.
ckpt_paths = sorted(glob.glob(ckpt_glob), key=_ckpt_order)
if not ckpt_paths:
    raise FileNotFoundError(f'no checkpoints match {ckpt_glob}')
ckpt_path = ckpt_paths[n]
print(ckpt_path)
Epoch = int(os.path.basename(ckpt_path).split('=')[1].split('-')[0])

# A malformed file raises yaml.YAMLError here instead of leaving `config` undefined.
with open(filename, 'r') as file:
    config = yaml.safe_load(file)

# Point data/log locations at this machine's copies.
if modeltype == 'shotgun':
    config['exp_params']['data_path'] = os.path.expanduser('~/Research/FMEphys/')
    config['exp_params']['csv_path_train'] = os.path.expanduser('~/Research/FMEphys//WCShotgun_Train_Data.csv')
    config['exp_params']['csv_path_val'] = os.path.expanduser('~/Research/FMEphys//WCShotgun_Val_Data.csv')
    config['logging_params']['save_dir'] = os.path.expanduser('~/Research/FMEphys/logs2/')
elif modeltype == 'vanilla':
    config['exp_params']['data_path'] = os.path.expanduser('~/Research/FMEphys/')
    config['exp_params']['csv_path_train'] = os.path.expanduser('~/Research/FMEphys//WC_Train_Data.csv')
    config['exp_params']['csv_path_val'] = os.path.expanduser('~/Research/FMEphys//WC_Val_Data.csv')
    config['logging_params']['save_dir'] = os.path.expanduser('~/Research/FMEphys/logs2/')
elif modeltype in ('3d', '3dmp'):
    config['exp_params']['data_path'] = os.path.expanduser('~/Research/FMEphys/data')
    config['exp_params']['csv_path_train'] = os.path.expanduser('~/Research/FMEphys/WC3d_Train_Data_SingVid.csv')
    config['exp_params']['csv_path_val'] = os.path.expanduser('~/Research/FMEphys/WC3d_Val_Data_SingVid.csv')
    config['logging_params']['save_dir'] = os.path.expanduser('~/Research/FMEphys/logs2/')

# Per-version figure directory (displayed as the cell output).
check_path(FigurePath, 'version_{:d}'.format(version))

/home/seuss/Research/FMEphys/logs2/VAE3dmp/version_7/checkpoints/epoch=23-step=82521.ckpt


'/home/seuss/Research/FMEphys/Figures/version_7'

In [4]:
from torchvision import transforms

# Rescale ToTensor's [0, 1] pixel range to [-1, 1].
SetRange = transforms.Lambda(lambda img: 2 * img - 1.)

# Grayscale -> resize to the configured (H, W) -> tensor -> [-1, 1].
# (No RandomHorizontalFlip augmentation is applied in this pipeline.)
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((config['exp_params']['imgH_size'], config['exp_params']['imgW_size'])),
    transforms.ToTensor(),
    SetRange,
])

In [5]:
# Training split as sequences, using the transform defined above.
dataset = WCRNNDataset(root_dir=config['exp_params']['data_path'],
                       csv_file=config['exp_params']['csv_path_train'],
                       N_fm=config['exp_params']['N_fm'],
                       transform=transform)

StartInd = 0      # first sample of the inspected slice
NumBatches = 100  # batches to expose; len(dataset) also works since the slice is clipped below
# Overrides of the loaded config, applied for this exploration only.
config['exp_params']['batch_size'] = 30
config['model_params']['tstrides'] = [1, 1, 1, 1]

# Contiguous slice of plain int indices, clipped so it never runs past the dataset end.
StopInd = min(StartInd + config['exp_params']['batch_size'] * NumBatches, len(dataset))
train_dataset = Subset(dataset, range(StartInd, StopInd))
train_dataloader = DataLoader(train_dataset,
                              batch_size=config['exp_params']['batch_size'],
                              shuffle=False,
                              drop_last=False,
                              num_workers=7,
                              pin_memory=False)
batch = next(iter(train_dataloader))

In [6]:
# Bind every entry of config['model_params'] as a notebook-level variable.
# NOTE: list values (e.g. hidden_dims) are the same objects stored in `config`,
# so mutating them in place also changes the config.
globals().update(config['model_params'])

(in_channels, latent_dim, depth_dim, xystrides, tstrides,
 kernels, mpkernels, input_size, hidden_dims)

(1,
 128,
 16,
 [2, 2, 2, 2],
 [1, 1, 1, 1],
 [5, 5, 5, 5],
 [2, 2, 2, 2],
 [1, 16, 64, 64],
 [32, 64, 128, 256])

In [7]:
# Build Encoder
# Build Encoder: each stage is ConvLSTM -> BatchNorm3d -> spatial MaxPool2d -> LeakyReLU.
# `stage_in` carries the channel count between stages so the config value `in_channels`
# is not overwritten and the cell can be re-run.
encoder = nn.ModuleList()
stage_in = in_channels
enc_h, enc_w = config['exp_params']['imgH_size'], config['exp_params']['imgW_size']
for layer_n, h_dim in enumerate(hidden_dims):
    encoder.add_module('ConvLSTM{}'.format(layer_n), ConvLSTM(input_dim=stage_in,
                                                             hidden_dim=[h_dim],
                                                             kernel_size=(kernels[layer_n], kernels[layer_n]),
                                                             num_layers=1,
                                                             batch_first=True,
                                                             ))
    encoder.add_module('batchnorm%i' % layer_n, nn.BatchNorm3d(h_dim))
    encoder.add_module('maxpool%i' % layer_n,
                       nn.MaxPool2d(kernel_size=mpkernels[layer_n],
                                    stride=(xystrides[layer_n], xystrides[layer_n]),
                                    padding=0, return_indices=False))
    encoder.add_module('relu%i' % layer_n, nn.LeakyReLU(0.05))
    # Spatial size after this stage's MaxPool2d (padding=0): floor((size - kernel) / stride) + 1.
    # The ConvLSTM itself keeps H and W unchanged (see the shape trace of the forward pass).
    enc_h = (enc_h - mpkernels[layer_n]) // xystrides[layer_n] + 1
    enc_w = (enc_w - mpkernels[layer_n]) // xystrides[layer_n] + 1
    stage_in = h_dim

# Flattened encoder output: last hidden dim x depth_dim (time) x pooled H x pooled W.
flat_dim = hidden_dims[-1] * depth_dim * enc_h * enc_w
fc_mu = nn.Linear(flat_dim, latent_dim)
fc_var = nn.Linear(flat_dim, latent_dim)

In [8]:
# Encoder forward pass on one batch, printing each layer's input and output shape.
# ConvLSTM layers (batch_first=True) take time-first tensors (B, T, C, H, W); the
# BatchNorm3d / pooling / activation layers work on channel-first (B, C, T, H, W).
# The current layout is tracked explicitly: inferring it from `x.shape[1] != T`
# misfires whenever a channel count happens to equal the sequence length T.
x = batch
B, T, C, H, W = x.shape
time_first = True  # the batch is unpacked above as (B, T, C, H, W)
for name, layer in encoder.named_children():
    print(name, x.shape)
    if isinstance(layer, ConvLSTM):
        if not time_first:
            x = x.permute(0, 2, 1, 3, 4)
        x, state = layer(x)
        # Take the first returned output (num_layers=1) and switch back to channel-first.
        x = x[0].permute(0, 2, 1, 3, 4)
        time_first = False
    elif isinstance(layer, nn.MaxPool2d):
        # Fold C and T into one axis so the 2-D pool acts on H and W only.
        # reshape (unlike view) also accepts a non-contiguous, freshly permuted input.
        shape = x.shape
        x = layer(x.reshape(shape[0], shape[1] * shape[2], shape[3], shape[4]))
        x = x.view(shape[0], shape[1], shape[2], x.shape[-2], x.shape[-1])
    else:
        x = layer(x)
    print('OutShape', x.shape)
encoder_shapes = [x.shape]
x = torch.flatten(x, start_dim=1)
mu = fc_mu(x)
log_var = fc_var(x)

print('mu:', mu.shape, 'log_var', log_var.shape)

ConvLSTM0 torch.Size([30, 16, 1, 64, 64])
OutShape torch.Size([30, 32, 16, 64, 64])
batchnorm0 torch.Size([30, 32, 16, 64, 64])
OutShape torch.Size([30, 32, 16, 64, 64])
maxpool0 torch.Size([30, 32, 16, 64, 64])
OutShape torch.Size([30, 32, 16, 32, 32])
relu0 torch.Size([30, 32, 16, 32, 32])
OutShape torch.Size([30, 32, 16, 32, 32])
ConvLSTM1 torch.Size([30, 32, 16, 32, 32])
OutShape torch.Size([30, 64, 16, 32, 32])
batchnorm1 torch.Size([30, 64, 16, 32, 32])
OutShape torch.Size([30, 64, 16, 32, 32])
maxpool1 torch.Size([30, 64, 16, 32, 32])
OutShape torch.Size([30, 64, 16, 16, 16])
relu1 torch.Size([30, 64, 16, 16, 16])
OutShape torch.Size([30, 64, 16, 16, 16])
ConvLSTM2 torch.Size([30, 64, 16, 16, 16])
OutShape torch.Size([30, 128, 16, 16, 16])
batchnorm2 torch.Size([30, 128, 16, 16, 16])
OutShape torch.Size([30, 128, 16, 16, 16])
maxpool2 torch.Size([30, 128, 16, 16, 16])
OutShape torch.Size([30, 128, 16, 8, 8])
relu2 torch.Size([30, 128, 16, 8, 8])
OutShape torch.Size([30, 128, 16,

In [9]:
def reparameterize(mu, logvar):
    """
    Draw z ~ N(mu, exp(logvar)) with the reparameterization trick.

    Noise eps ~ N(0, I) is sampled separately and combined as z = mu + eps * std,
    so gradients flow through mu and logvar rather than through the sampling.
    :param mu: (Tensor) Mean of the latent Gaussian [B x D]
    :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
    :return: (Tensor) Sample with the same shape as mu [B x D]
    """
    std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
    eps = torch.randn_like(std)
    return eps * std + mu

In [10]:
z = reparameterize(mu=mu, logvar=log_var)  # one draw from N(mu, exp(log_var)) per batch row

In [11]:
# Build Decoder
# Build Decoder
# Project the latent vector back to the flattened encoder output volume (C x T x H x W).
decoder_input = nn.Linear(latent_dim, encoder_shapes[-1][1:].numel())

# Mirror the encoder: channels and spatial strides in reverse order. Reversed *copies* are
# used because `hidden_dims` is the same list object as config['model_params']['hidden_dims'];
# reversing it in place would corrupt the config and break re-running this cell.
dec_dims = hidden_dims[::-1]
dec_strides = xystrides[::-1]

decoder = nn.ModuleList()
for i in range(len(dec_dims) - 1):
    decoder.add_module('ConvLSTM{}_transpose'.format(i), ConvLSTM(input_dim=dec_dims[i],
                                                                  hidden_dim=dec_dims[i + 1],
                                                                  kernel_size=(kernels[i], kernels[i]),
                                                                  num_layers=1,
                                                                  batch_first=True,
                                                                  use_transpose=True,
                                                                  ))
    # Undo the matching encoder MaxPool2d: upsample H and W only, never time.
    decoder.add_module('upsample%i' % i,
                       nn.Upsample(scale_factor=(1, dec_strides[i], dec_strides[i]), mode='nearest'))
    decoder.add_module('batchnorm%i' % i, nn.BatchNorm3d(dec_dims[i + 1]))
    decoder.add_module('relu%i' % i, nn.LeakyReLU(0.05))

# Final stage: one more ConvLSTM + upsample back to input resolution, then a Conv3d down to a
# single channel and Tanh. `last_n` repeats the last decoder stage number so module names
# (and therefore state_dict keys) stay e.g. 'ConvLSTM2_transpose' as before.
last_n = max(len(dec_dims) - 2, 0)
final_layer = nn.ModuleList()
final_layer.add_module('ConvLSTM{}_transpose'.format(last_n), ConvLSTM(input_dim=dec_dims[-1],
                                                                       hidden_dim=dec_dims[-1],
                                                                       kernel_size=(kernels[-1], kernels[-1]),
                                                                       num_layers=1,
                                                                       batch_first=True,
                                                                       use_transpose=True,
                                                                       ))
final_layer.add_module('upsample%i' % last_n,
                       nn.Upsample(scale_factor=(1, dec_strides[-1], dec_strides[-1]), mode='nearest'))
final_layer.add_module('batchnorm%i' % last_n, nn.BatchNorm3d(dec_dims[-1]))
final_layer.add_module('relu%i' % last_n, nn.LeakyReLU(0.05))
# padding = kernel // 2 keeps T, H, W unchanged for odd kernel sizes.
final_layer.add_module('last_conv%i' % 0, nn.Conv3d(dec_dims[-1], out_channels=1,
                                                    kernel_size=kernels[0], padding=kernels[0] // 2))
final_layer.add_module('last_Tanh%i' % 0, nn.Tanh())


In [12]:
# Map z back up to the encoder's output volume (B, C, T, H, W), then lay it out
# time-first (B, T, C, H, W) for the first decoder ConvLSTM.
result = decoder_input(z).view(-1, *encoder_shapes[-1][1:]).permute(0, 2, 1, 3, 4)
result.shape

torch.Size([30, 16, 256, 4, 4])

In [13]:

# Decoder forward pass, printing each layer's input and output shape.
# `result` arrives time-first (B, T, C, H, W) from the reshape cell. ConvLSTM layers take
# that layout; the other layers work channel-first (B, C, T, H, W). Layout is tracked
# explicitly rather than guessed from `result.shape[1] != T`, which misfires when a
# channel count equals T.
time_first = True
for name, layer in decoder.named_children():
    print(name, result.shape)
    if isinstance(layer, ConvLSTM):
        if not time_first:
            result = result.permute(0, 2, 1, 3, 4)
        result, state = layer(result)
        # Take the first returned output (num_layers=1) and switch back to channel-first.
        result = result[0].permute(0, 2, 1, 3, 4)
        time_first = False
    else:
        result = layer(result)
    print('OutShape', result.shape)

ConvLSTM0_transpose torch.Size([30, 16, 256, 4, 4])
OutShape torch.Size([30, 128, 16, 4, 4])
upsample0 torch.Size([30, 128, 16, 4, 4])
OutShape torch.Size([30, 128, 16, 8, 8])
batchnorm0 torch.Size([30, 128, 16, 8, 8])
OutShape torch.Size([30, 128, 16, 8, 8])
relu0 torch.Size([30, 128, 16, 8, 8])
OutShape torch.Size([30, 128, 16, 8, 8])
ConvLSTM1_transpose torch.Size([30, 128, 16, 8, 8])
OutShape torch.Size([30, 64, 16, 8, 8])
upsample1 torch.Size([30, 64, 16, 8, 8])
OutShape torch.Size([30, 64, 16, 16, 16])
batchnorm1 torch.Size([30, 64, 16, 16, 16])
OutShape torch.Size([30, 64, 16, 16, 16])
relu1 torch.Size([30, 64, 16, 16, 16])
OutShape torch.Size([30, 64, 16, 16, 16])
ConvLSTM2_transpose torch.Size([30, 64, 16, 16, 16])
OutShape torch.Size([30, 32, 16, 16, 16])
upsample2 torch.Size([30, 32, 16, 16, 16])
OutShape torch.Size([30, 32, 16, 32, 32])
batchnorm2 torch.Size([30, 32, 16, 32, 32])
OutShape torch.Size([30, 32, 16, 32, 32])
relu2 torch.Size([30, 32, 16, 32, 32])
OutShape torch

In [14]:
# Final-stage forward pass, printing each layer's input and output shape.
# The decoder cell leaves `result` channel-first (B, C, T, H, W), since its last layer is a
# LeakyReLU after the permute back from ConvLSTM. ConvLSTM needs time-first (B, T, C, H, W).
# Layout is tracked explicitly instead of guessed from `result.shape[1] != T`.
time_first = False
for name, layer in final_layer.named_children():
    print(name, result.shape)
    if isinstance(layer, ConvLSTM):
        if not time_first:
            result = result.permute(0, 2, 1, 3, 4)
        result, state = layer(result)
        # Take the first returned output (num_layers=1) and switch back to channel-first.
        result = result[0].permute(0, 2, 1, 3, 4)
        time_first = False
    else:
        result = layer(result)
    print('OutShape', result.shape)

ConvLSTM2_transpose torch.Size([30, 32, 16, 32, 32])
OutShape torch.Size([30, 32, 16, 32, 32])
upsample2 torch.Size([30, 32, 16, 32, 32])
OutShape torch.Size([30, 32, 16, 64, 64])
batchnorm2 torch.Size([30, 32, 16, 64, 64])
OutShape torch.Size([30, 32, 16, 64, 64])
relu2 torch.Size([30, 32, 16, 64, 64])
OutShape torch.Size([30, 32, 16, 64, 64])
last_conv0 torch.Size([30, 32, 16, 64, 64])
OutShape torch.Size([30, 1, 16, 64, 64])
last_Tanh0 torch.Size([30, 1, 16, 64, 64])
OutShape torch.Size([30, 1, 16, 64, 64])


In [15]:
result.size()  # reconstruction shape (B, 1, T, H, W)

torch.Size([30, 1, 16, 64, 64])