In [1]:
# to do:

# 1- change the loss function to variational loss (recon and kl)
# 2- play with hyper-parameters, and loss
# 3- check on test error


import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from lfads import LFADS_Net
from objective import *

class conv_block(nn.Module):# *args, **kwargs 
    def __init__(self, in_f, out_f):
        super(conv_block,self).__init__()
        
        self.conv1 = nn.Conv3d(in_f, out_f, 
                  kernel_size=3, 
                  padding=1)
        self.pool1 = nn.MaxPool3d(kernel_size=(1,2,2),
                     return_indices=True)
        self.relu1 = nn.ReLU()
        
    def forward(self,x):
        
        x = self.conv1(x)
        x, ind = self.pool1(x)
        x = self.relu1(x)
        
        return x, ind
        


class deconv_block(nn.Module):
    def __init__(self, in_f, out_f):
        super(deconv_block,self).__init__()
        
        self.unpool1 = nn.MaxUnpool3d(kernel_size=(1,2,2))
        
        self.deconv1 = nn.ConvTranspose3d(in_channels=in_f,
                                          out_channels=out_f,
                                          kernel_size=3,
                                          padding=1, 
                                         )
        self.relu1 = nn.ReLU()
        
    def forward(self,x,ind):
        
        x = self.unpool1(x,ind)
        x = self.deconv1(x)
        x = self.relu1(x)
        
        return x



class convVAE(nn.Module):
    def __init__(self):
        super(convVAE,self).__init__()
        
        in_f = 1
        out_f = [2,3]
        all_f = [in_f,*out_f]
        self.n_layers = 2
        
        self.video_dim_space = 128
        self.video_dim_time = 50
        self.final_size = 32
        self.final_f = 3
        
        self.convlayers = nn.ModuleList()
        for n in range(0,self.n_layers):
            self.convlayers.add_module('{}{}'.format('ce', n),conv_block(all_f[n], all_f[n+1]))
#         self.convlayers.add_module('ce1',conv_block(out_f1, out_f2))
        
        self.deconvlayers = nn.ModuleList()
        for n in range(0,self.n_layers):
            self.deconvlayers.add_module('{}{}'.format('dec', n),deconv_block(all_f[self.n_layers-n], all_f[self.n_layers-n-1]))
#         self.deconvlayers.add_module('dec0',deconv_block(out_f2,out_f1))
#         self.deconvlayers.add_module('dec1',deconv_block(out_f1,in_f))
#         self.ce1 = conv_block(in_f, out_f1) 
#         self.ce2 = conv_block(out_f1, out_f2)

#         self.dec1 = deconv_block(out_f2,out_f1)
#         self.dec2 = deconv_block(out_f1,in_f) 

        self.lfads = LFADS_Net(self.final_size * self.final_size * self.final_f, output_size = None, factor_size = 4,
                 g_encoder_size  = 64, c_encoder_size = 64,
                 g_latent_size   = 64, u_latent_size  = 1,
                 controller_size = 64, generator_size = 64,
                 prior = {'g0' : {'mean' : {'value': 0.0, 'learnable' : True},
                                  'var'  : {'value': 0.1, 'learnable' : False}},
                          'u'  : {'mean' : {'value': 0.0, 'learnable' : False},
                                  'var'  : {'value': 0.1, 'learnable' : True},
                                  'tau'  : {'value': 10,  'learnable' : True}}},
                 clip_val=5.0, dropout=0.0, max_norm = 200, deep_freeze = False,
                 do_normalize_factors=True, device='cpu')

        
    def forward(self,video):
        x = video
        Ind = list()
        for n, layer in enumerate(self.convlayers):
            x, ind1 = layer(x)
            Ind.append(ind1)
        
        x = x.permute(0,2,1,3,4)
        x = x.reshape(x.shape[0],x.shape[1],-1)
        
        x = x.permute(1,0,2)
        r,_ = self.lfads(x)
        x = r['data']
        x = x.permute(1,0,2)
        # call LFADS here:
        # x should be reshaped for LFADS [time x batch x cells]:
        # 
        # LFADS output should be also reshaped back for the conv decoder
        
        x = x.reshape(x.shape[0],x.shape[1],self.final_f,self.final_size, self.final_size)
        x = x.permute(0,2,1,3,4)

        
        
        for n, layer in enumerate(self.deconvlayers):     
            x = layer(x,Ind[self.n_layers-n-1])
            

#         x, ind1 = self.ce0(video)
#         x, ind2 = self.ce1(x)
#         x = self.dec0(x,ind2)
#         v_p = self.dec1(x,ind1)
        

#         return v_p
        return x


def get_data():
    
    from synthetic_data import generate_lorenz_data, SyntheticCalciumVideoDataset

    # convert data to torch.FloatTensor
    transform = transforms.ToTensor()

    # load the training and test datasets

    data_dict = generate_lorenz_data(20, 65, 50, 50, save=False)
    cells = data_dict['cells']
    traces = data_dict['train_fluor']
    train_data = SyntheticCalciumVideoDataset(traces=traces, cells=cells)
    test_data = SyntheticCalciumVideoDataset(traces=traces, cells=cells)
    
    num_workers = 0
    # how many samples per batch to load
    batch_size = 20

    # prepare data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, num_workers=num_workers)

    return train_data, train_loader, test_loader

class convLFADS_loss(nn.Module):
    def __init__():
        pass
    
    
    
def train_convVAE(train_loader,test_loader,n_epochs): #model,
    model = convVAE()

    # number of epochs to train the model
#     n_epochs = 30
#     train_loader, test_loader = get_data()
#     model = convVAE()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    

    for epoch in range(1, n_epochs+1):
        # monitor training loss
        train_loss = 0.0

        ###################
        # train the model #
        ###################
        i = 0
        for data in train_loader:
#             print(i)
            # _ stands in for labels, here
            # no need to flatten images
            videos = data
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            outputs = model(videos)
            # calculate the loss
            loss = criterion(outputs, videos)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update running training loss
            train_loss += loss.item()*videos.size(0)
            i += 1

        # print avg training statistics 
        train_loss = train_loss/len(train_loader)
        print('Epoch: {} \tTraining Loss: {:.6f}'.format(
            epoch, 
            train_loss
            ))

        

In [2]:
train_data, train_loader, test_loader = get_data()
train_convVAE(train_loader,test_loader,10)

Generating Lorenz data
Converting to rates and spikes
Converting to fluorescence
Train and test split
Saving to .//synth_data/lorenz_100
Epoch: 1 	Training Loss: 20.495386
Epoch: 2 	Training Loss: 17.015049
Epoch: 3 	Training Loss: 13.311833
Epoch: 4 	Training Loss: 11.157027
Epoch: 5 	Training Loss: 9.906223
Epoch: 6 	Training Loss: 9.051613
Epoch: 7 	Training Loss: 8.469227
Epoch: 8 	Training Loss: 8.045379
Epoch: 9 	Training Loss: 7.779731
Epoch: 10 	Training Loss: 7.615138
