Compared to the autoencoder, this architecture does not include reconstruction.

variants to try:
* data segment number: 3, 5, 8, 10
* eye video vs. head video

fixed values:
* kernel sizes (3, 3, 3). If we have more time, could try (3, 5, 9) (best config from hyperparameter search)

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Subset, DataLoader, ConcatDataset
from mouse_model.data_utils_new import MouseDatasetSegNewBehav
import numpy as np
from mouse_model.evaluation import cor_in_time
from torchvision import models
from sklearn.metrics import r2_score, mean_squared_error
import timm
from kornia.geometry.transform import get_affine_matrix2d, warp_affine

In [2]:
# Debugging aid: drop it between stages of an nn.Sequential to see tensor shapes.
class PrintLayer(nn.Module):
    """Identity module that prints the shape of its input and passes it through."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # The same object is returned, so inserting this layer changes nothing
        # about the computation.
        shape = x.shape
        print(shape)
        return x

def size_helper(in_length, kernel_size, padding=0, dilation=1, stride=1):
    """Return the output length of a convolution along one spatial dimension.

    Implements the size formula from the ``torch.nn.Conv2d`` documentation:
    ``floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d

    Args:
        in_length: input length; a scalar, or a NumPy array of lengths (for
            example both sides of an image) handled element-wise.
        kernel_size: kernel extent along the same dimension.
        padding: padding added to each side.
        dilation: spacing between kernel elements.
        stride: convolution stride.

    Returns:
        The floored output length as returned by ``np.floor`` (a floating-point
        scalar, or a float array when ``in_length`` is an array).
    """
    span = in_length + 2 * padding - dilation * (kernel_size - 1) - 1
    # Out-of-place arithmetic on purpose: an in-place ``/=`` raises a casting
    # error when ``span`` is an integer NumPy array.
    return np.floor(span / stride + 1)

# CNN, the last fully connected layer maps to output_dim
class VisualEncoder(nn.Module):
    """EfficientNet-B0 backbone adapted to single-channel frames.

    The model is created through ``timm`` with ``pretrained=True``; afterwards
    the stem convolution is replaced by a freshly initialised 1-input-channel
    convolution and the classifier by a linear layer with ``output_dim``
    outputs.

    Args:
        output_dim: number of output units of the final linear layer.
        input_shape: (height, width) of the input frames. It is only stored on
            the instance; none of the layers built here depend on it.
        k1, k2, k3: accepted so existing callers keep working, but not used by
            this EfficientNet-based encoder.
    """

    def __init__(self, output_dim, input_shape=(60, 80), k1=3, k2=3, k3=3):
        super().__init__()

        # Previously this was hard-coded to (60, 80) and the argument was
        # silently ignored; the default keeps the old value.
        self.input_shape = tuple(input_shape)

        efficientnet = timm.create_model('efficientnet_b0', pretrained=True)
        # Single-channel stem: this new convolution starts from random weights
        # rather than the pretrained 3-channel stem.
        efficientnet.conv_stem = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        # Replace the classification head by a linear map to ``output_dim``.
        efficientnet.classifier = nn.Linear(in_features=1280, out_features=output_dim, bias=True)
        self.layers = efficientnet

    def forward(self, x):
        """Run the frames ``x`` through the adapted EfficientNet."""
        return self.layers(x)


class Shifter(nn.Module):
    """MLP that maps behavioural variables to three bounded shift parameters.

    The network ends in ``tanh``; a learned 3-element bias is added and the
    three outputs are scaled by 80/4, 60/4 and 180/4 respectively.

    Note: ``forward`` indexes output columns 0-2 and the bias always has three
    entries, so the module is written for ``output_dim == 3``.
    """

    def __init__(self, input_dim=4, output_dim=3, hidden_dim=256):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        stages = [
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*stages)
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """Return shift parameters with shape ``(-1, 1, output_dim)``."""
        flat = x.reshape(-1, self.input_dim)
        squashed = self.layers(flat)
        # One scale per output column, applied after adding the learned bias.
        scaled = [
            (squashed[..., idx] + self.bias[idx]) * extent / 4
            for idx, extent in enumerate((80, 60, 180))
        ]
        out = torch.stack(scaled, dim=-1)
        return out.reshape(-1, 1, self.output_dim)

In [3]:
class Predictor(nn.Module):
    """Predict non-negative outputs for ``num_neurons`` units from frames.

    Frames are optionally warped by an affine transform whose translation and
    rotation come from a ``Shifter`` fed with behavioural variables, then passed
    through ``VisualEncoder`` and a softplus.

    Args:
        num_neurons: number of outputs of the encoder.
        k1, k2, k3: forwarded to ``VisualEncoder``.
        use_shifter: ``True``/``False`` fixes whether the shifter warp is applied.
            ``None`` (default) keeps the previous behaviour of reading the
            module-level ``args.shifter`` on every forward pass.
    """

    def __init__(self, num_neurons, k1, k2, k3, use_shifter=None):
        super().__init__()

        self.encoder = VisualEncoder(output_dim=num_neurons, k1=k1, k2=k2, k3=k3)
        self.softplus = nn.Softplus()
        # Built unconditionally, so its parameters are part of the state_dict
        # even when the shift is disabled.
        self.shifter = Shifter()
        self.use_shifter = use_shifter

    def forward(self, images, behav):
        """Return the softplus-activated encoder output for ``images``.

        ``behav`` is only read when the shifter is enabled.
        """
        use_shifter = args.shifter if self.use_shifter is None else self.use_shifter
        if use_shifter:
            bs = images.size()[0]
            # Shifter input: behaviour columns 4, 3, 1, 2 (labelled theta, phi,
            # pitch, roll in the original code).
            behav_shifter = torch.stack(
                (behav[..., 4], behav[..., 3], behav[..., 1], behav[..., 2]), dim=-1
            )
            shift_param = self.shifter(behav_shifter).reshape(-1, 3)
            translations = shift_param[..., 0:2]
            # Unit scale: only translation and rotation are applied.
            scale_param = torch.ones_like(translations)
            # NOTE(review): kornia takes points as (x, y); for dsize=(60, 80),
            # i.e. (height, width), the geometric centre would be about
            # (40, 30). Kept at (30, 40) so results stay comparable with
            # checkpoints trained with this code. Confirm whether it is intended.
            center = torch.tensor(
                [[30, 40]], dtype=torch.float, device=shift_param.device
            ).repeat(bs, 1)
            affine_mat = get_affine_matrix2d(
                translations=translations,
                scale=scale_param,
                center=center,
                angle=shift_param[..., 2],
            )
            # warp_affine expects the top two rows of the 3x3 matrix.
            affine_mat = affine_mat[:, :2, :]
            images = warp_affine(images, affine_mat, dsize=(60, 80))

        pred = self.encoder(images)
        return self.softplus(pred)

In [4]:
class Args:
    """Run configuration shared by the training and evaluation cells.

    Every field is a class attribute. The ``None`` entries are filled in by the
    run cells before a model is built.
    """

    # --- reproducibility / optimisation ---
    seed = 0                # passed to torch.manual_seed
    epochs = 100            # maximum number of training epochs
    batch_size = 256        # batch size of the train/val DataLoaders
    learning_rate = 0.0001  # Adam learning rate

    # --- data selection (forwarded to MouseDatasetSegNewBehav) ---
    file_id = None
    vid_type = None
    segment_num = None
    seq_len = 1

    # --- model ---
    num_neurons = None      # number of outputs of the Predictor
    shifter = True          # read by Predictor.forward to toggle the shifter warp

    # --- checkpoint locations ---
    best_train_path = None  # state_dict with the lowest training loss
    best_val_path = None    # state_dict with the lowest validation loss


args = Args()

In [5]:
def load_train_val_ds(train_ratio=0.8):
    """Build the training and validation datasets for the current ``args``.

    One ``MouseDatasetSegNewBehav`` (``data_split="train"``) is created per
    segment index in ``range(args.segment_num)``. Within each segment the first
    ``train_ratio`` fraction of samples goes to training and the rest to
    validation, so both sets draw from every segment. The sizes of the two
    resulting datasets are printed.

    Args:
        train_ratio: fraction of each segment assigned to training, within
            [0, 1]. Defaults to the previously hard-coded 0.8.

    Returns:
        ``(train_ds, val_ds)``, each a ``ConcatDataset`` over the per-segment
        subsets.

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError("train_ratio must be within [0, 1], got {}".format(train_ratio))

    train_parts, val_parts = [], []
    for seg_idx in range(args.segment_num):
        ds = MouseDatasetSegNewBehav(file_id=args.file_id, segment_num=args.segment_num, seg_idx=seg_idx,
                                     data_split="train", vid_type=args.vid_type, seq_len=args.seq_len,
                                     predict_offset=1)
        # Contiguous split (no shuffling): leading samples train, trailing validate.
        split = int(len(ds) * train_ratio)
        train_parts.append(Subset(ds, np.arange(0, split, 1)))
        val_parts.append(Subset(ds, np.arange(split, len(ds), 1)))

    train_ds = ConcatDataset(train_parts)
    val_ds = ConcatDataset(val_parts)
    print(len(train_ds), len(val_ds))
    return train_ds, val_ds

In [6]:
def load_test_ds():
    """Return the test split for the current ``args`` as one ``ConcatDataset``.

    One ``MouseDatasetSegNewBehav`` with ``data_split="test"`` is created per
    segment index in ``range(args.segment_num)`` and the segments are
    concatenated in index order.
    """
    segments = []
    for seg_idx in range(args.segment_num):
        segment = MouseDatasetSegNewBehav(
            file_id=args.file_id,
            segment_num=args.segment_num,
            seg_idx=seg_idx,
            data_split="test",
            vid_type=args.vid_type,
            seq_len=args.seq_len,
            predict_offset=1,
        )
        segments.append(segment)
    return ConcatDataset(segments)

In [7]:
def _run_epoch(dataloader, optimizer=None):
    """Run one pass over ``dataloader`` and return the mean per-batch Poisson NLL.

    Uses the module-level ``model`` and ``device``. When ``optimizer`` is given
    the model is updated after every batch; otherwise gradients are disabled.
    The caller sets ``model.train()`` / ``model.eval()``.
    """
    training = optimizer is not None
    total_loss = 0

    with torch.set_grad_enabled(training):
        for (image, behav, spikes) in dataloader:
            image, behav, spikes = image.to(device), behav.to(device), spikes.to(device)
            # Drop the singleton axis at position 1 before calling the model.
            image = torch.squeeze(image, dim=1)

            pred = model(image, behav)
            # The model output is used directly as the Poisson rate (log_input=False).
            loss = nn.functional.poisson_nll_loss(pred, spikes, reduction='mean', log_input=False)
            total_loss += loss.item()

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    return total_loss / len(dataloader)


def train_model(patience=5):
    """Train the module-level ``model`` with settings taken from ``args``.

    Seeds torch, builds the train/val loaders, and optimises with Adam on the
    Poisson negative log-likelihood for up to ``args.epochs`` epochs. The
    state_dict is written to ``args.best_train_path`` whenever the epoch's
    training loss improves and to ``args.best_val_path`` whenever the validation
    loss improves.

    Args:
        patience: stop once the validation loss has failed to improve for this
            many consecutive epochs (default 5, the previously hard-coded value).

    Returns:
        ``(train_loss_list, val_loss_list)`` with one mean loss per completed
        epoch.
    """
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    train_ds, val_ds = load_train_val_ds()

    train_dataloader = DataLoader(dataset=train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_dataloader = DataLoader(dataset=val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    best_train_loss = np.inf
    best_val_loss = np.inf
    train_loss_list = []
    val_loss_list = []

    # Consecutive epochs without a validation improvement.
    epochs_without_improvement = 0

    for epoch in range(args.epochs):

        print("Start epoch", epoch)

        model.train()
        print(sum([param.nelement() for param in model.parameters()]))

        epoch_train_loss = _run_epoch(train_dataloader, optimizer)
        train_loss_list.append(epoch_train_loss)

        if epoch_train_loss < best_train_loss:
            torch.save(model.state_dict(), args.best_train_path)
            best_train_loss = epoch_train_loss

        print("Epoch {} train loss: {}".format(epoch, epoch_train_loss))

        model.eval()
        epoch_val_loss = _run_epoch(val_dataloader)
        val_loss_list.append(epoch_val_loss)

        # Log before the early-stopping check so the validation loss of the
        # epoch that triggers the stop is not missing from the output.
        print("Epoch {} val loss: {}".format(epoch, epoch_val_loss))

        if epoch_val_loss < best_val_loss:
            epochs_without_improvement = 0
            torch.save(model.state_dict(), args.best_val_path)
            best_val_loss = epoch_val_loss
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print('stop trianing')
                break

        print("End epoch", epoch)

    return train_loss_list, val_loss_list

In [8]:
# Prefer the second GPU as before, but only when it exists: 'cuda:1' is not a
# valid device on a single-GPU machine even though CUDA is available.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Train one model per (recording, video type, shifter on/off) combination.
# for file_id, num_neurons in [("110421_J569LT", 52), ("101521_J559NC", 63), ("070921_J553RT", 108)]:
for file_id, num_neurons in [("070921_J553RT", 68), ("101521_J559NC", 49), ("110421_J569LT", 32)]:
    for vid_type in ["vid_mean"]:
        for shifter in [True, False]:
            print(file_id, vid_type, shifter)

            # train_model() and the dataset loaders read these from the global args.
            args.segment_num = 10
            args.vid_type = vid_type
            args.num_neurons = num_neurons
            args.shifter = shifter
            args.file_id = file_id

            # Checkpoint names encode the configuration so runs do not overwrite each other.
            path_fields = (args.shifter, args.segment_num, args.vid_type, args.file_id)
            args.best_train_path = "/hdd/yuchen/trainEffNetShifter{}_{}_{}_{}.pth".format(*path_fields)
            args.best_val_path = "/hdd/yuchen/valEffNetShifter{}_{}_{}_{}.pth".format(*path_fields)

            # A fresh model per configuration; train_model() uses the global `model`.
            model = Predictor(num_neurons=args.num_neurons, k1=7, k2=7, k3=7).to(device)

            train_loss_list, val_loss_list = train_model()


070921_J553RT vid_mean True
30120 7540
Start epoch 0
4162958
Epoch 0 train loss: 0.7149918049068774
Epoch 0 val loss: 0.6612052838007609
End epoch 0
Start epoch 1
4162958
Epoch 1 train loss: 0.628875284376791
Epoch 1 val loss: 0.6400157968203227
End epoch 1
Start epoch 2
4162958
Epoch 2 train loss: 0.613935808003959
Epoch 2 val loss: 0.6331848402818044
End epoch 2
Start epoch 3
4162958
Epoch 3 train loss: 0.6048753685870413
Epoch 3 val loss: 0.6299521227677664
End epoch 3
Start epoch 4
4162958
Epoch 4 train loss: 0.5973379601866512
Epoch 4 val loss: 0.6300125598907471
End epoch 4
Start epoch 5
4162958
Epoch 5 train loss: 0.5911980357210515
Epoch 5 val loss: 0.6264310300350189
End epoch 5
Start epoch 6
4162958
Epoch 6 train loss: 0.5852572791657206
Epoch 6 val loss: 0.6293766975402832
End epoch 6
Start epoch 7
4162958
Epoch 7 train loss: 0.5786201297226599
Epoch 7 val loss: 0.6309882203737894
End epoch 7
Start epoch 8
4162958
Epoch 8 train loss: 0.5724357300895756
Epoch 8 val loss: 0.62

eval

In [9]:
# Default window: 2 seconds at 48 ms per frame, i.e. int(2000 / 48) frames.
def smoothing_with_np_conv(nsp, size=int(2000/48)):
    """Smooth every column of ``nsp`` with a moving average of ``size`` samples.

    ``nsp`` is indexed as ``nsp[:, i]``; each column is convolved with a uniform
    kernel using ``mode="same"``. The result has one column per input column.
    """
    box = np.ones(size) / size
    smoothed_columns = [
        np.convolve(nsp[:, col], box, mode="same") for col in range(nsp.shape[1])
    ]
    # The rows built above are per-column results; transpose back to the input layout.
    return np.array(smoothed_columns).T

In [10]:
def evaluate_model(model, weights_path, dataset, device, batch_size=256, num_workers=2):
    """Load weights into ``model`` and collect its predictions over ``dataset``.

    Args:
        model: module called as ``model(image, behav)``. The caller is
            responsible for placing it on ``device``.
        weights_path: path to a state_dict saved with ``torch.save``.
        dataset: dataset yielding ``(image, behav, spikes)`` triples.
        device: device the inputs are moved to; also used as ``map_location``
            when loading the weights.
        batch_size: DataLoader batch size (default 256, as before).
        num_workers: DataLoader worker processes (default 2, as before).

    Returns:
        ``(predictions, ground_truth)`` as NumPy arrays concatenated along
        axis 0, in dataset order (the loader does not shuffle).
    """
    dl = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # map_location lets a checkpoint written from one device (for example
    # cuda:1) be loaded when that device is not present.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    ground_truth_all = []
    pred_all = []

    with torch.no_grad():
        for (image, behav, spikes) in dl:
            # Drop the singleton axis at position 1 before calling the model.
            image = torch.squeeze(image.to(device), dim=1)
            behav = behav.to(device)

            pred = model(image, behav)

            # Targets never leave the CPU; predictions are moved back for NumPy.
            ground_truth_all.append(spikes.numpy())
            pred_all.append(pred.cpu().numpy())

    return np.concatenate(pred_all, axis=0), np.concatenate(ground_truth_all, axis=0)

In [11]:
# Prefer the second GPU as before, but only when it exists: 'cuda:1' is not a
# valid device on a single-GPU machine even though CUDA is available.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Evaluate the best-validation checkpoint of every configuration on the test split.
# for file_id, num_neurons in [("110421_J569LT", 52), ("101521_J559NC", 63), ("070921_J553RT", 108)]:
for file_id, num_neurons in [("070921_J553RT", 68), ("101521_J559NC", 49), ("110421_J569LT", 32)]:

    for vid_type in ["vid_mean"]:
        for shifter in [True, False]:
            # Predictor.forward reads args.shifter, so set it before evaluating.
            args.shifter = shifter

            print(file_id, vid_type, shifter)

            args.segment_num = 10
            args.vid_type = vid_type
            args.num_neurons = num_neurons
            args.file_id = file_id

            # Same naming scheme as the training cell, so the matching checkpoint is found.
            path_fields = (args.shifter, args.segment_num, args.vid_type, args.file_id)
            args.best_train_path = "/hdd/yuchen/trainEffNetShifter{}_{}_{}_{}.pth".format(*path_fields)
            args.best_val_path = "/hdd/yuchen/valEffNetShifter{}_{}_{}_{}.pth".format(*path_fields)

            model = Predictor(num_neurons=args.num_neurons, k1=7, k2=7, k3=7).to(device)

            # Not needed for the metrics below; kept because it prints the
            # train/val sizes that appear in the logged output.
            train_ds, val_ds = load_train_val_ds()
            test_ds = load_test_ds()

            pred, label = evaluate_model(model, weights_path=args.best_val_path, dataset=test_ds, device=device)

            # Metrics are computed on smoothed traces. The earlier correlation
            # on the raw traces was overwritten before being used and has been
            # removed.
            pred = smoothing_with_np_conv(pred)
            label = smoothing_with_np_conv(label)
            # print("R2", "{:.6f}".format(r2_score(label.T, pred.T)))
            print("MSE", "{:.6f}".format(mean_squared_error(label, pred)))
            cor_array = cor_in_time(pred, label)
            print("mean corr, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
            # print("max corr", "{:.6f}".format(np.max(cor_array)))
            # print("min corr", "{:.6f}".format(np.min(cor_array)))

070921_J553RT vid_mean True
30120 7540
MSE 0.069431
mean corr, 0.542+-0.153
070921_J553RT vid_mean False
30120 7540
MSE 0.068346
mean corr, 0.521+-0.153
101521_J559NC vid_mean True
42410 10610
MSE 0.096501
mean corr, 0.510+-0.127
101521_J559NC vid_mean False
42410 10610
MSE 0.098369
mean corr, 0.468+-0.145
110421_J569LT vid_mean True
32940 8240
MSE 0.103219
mean corr, 0.393+-0.165
110421_J569LT vid_mean False
32940 8240
MSE 0.109311
mean corr, 0.349+-0.183
