In [1]:
import torch
import torch.nn as nn
import numpy as np
import random
import os
import time
import torch.nn.init as init
from torch.nn import functional as F
from kornia.geometry.transform import get_affine_matrix2d, warp_affine
import matplotlib.pyplot as plt

The model: shifter network, visual and behavior encoders, and the recurrent (GRU) combiner

In [2]:
class Shifter(nn.Module):
    """MLP that converts per-frame behavioural inputs into three shift parameters.

    Each time step is processed independently by a batch-normalised tanh MLP.
    A learned bias is added to the three bounded outputs, which are then
    rescaled by 80/5.5, 60/5.5 and 180/4 respectively.
    ``LSTMPerNeuronCombiner`` uses the first two as translations and the third
    as a rotation angle.

    Args:
        input_dim: number of behavioural features per time step.
        output_dim: width of the last linear layer; ``forward`` only reads the
            first three outputs.
        hidden_dim: width of the two hidden layers.
        seq_len: number of time steps per sample, used to restore the
            ``(batch, seq_len, output_dim)`` layout.
    """

    def __init__(self, input_dim=4, output_dim=3, hidden_dim=256, seq_len=8):
        super().__init__()
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.output_dim = output_dim

        # The order of this stack must not change: the notebook freezes
        # layers[0], layers[2] and layers[5] by index, and the parameter names
        # in saved state dicts are derived from these positions.
        mlp_stack = [
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*mlp_stack)
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """Return shift parameters with shape ``(-1, seq_len, output_dim)``."""
        # Fold every time step into the batch axis so BatchNorm1d sees 2-D input.
        bounded = self.layers(x.reshape(-1, self.input_dim))

        # tanh keeps each raw output in [-1, 1]; add the learned offset, then rescale.
        first = (bounded[..., 0] + self.bias[0]) * 80/5.5
        second = (bounded[..., 1] + self.bias[1]) * 60/5.5
        third = (bounded[..., 2] + self.bias[2]) * 180/4

        shift_params = torch.stack([first, second, third], dim=-1)
        return shift_params.reshape(-1, self.seq_len, self.output_dim)

class PrintLayer(nn.Module):
    """Debugging aid: prints the shape of its input and passes it through unchanged."""

    def forward(self, x):
        print(x.shape)
        return x
    
def size_helper(in_length, kernel_size, padding=0, dilation=1, stride=1):
    """Length of one spatial axis after a ``Conv2d`` layer.

    Uses the output-shape formula from
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d

    The result comes from ``np.floor`` and is therefore float-valued; callers
    that need an integer must cast it.
    """
    # Extent of the (dilated) kernel along this axis.
    kernel_extent = dilation * (kernel_size - 1) + 1
    return np.floor((in_length + 2 * padding - kernel_extent) / stride + 1)

# CNN, the last fully connected layer maps to output_dim
class VisualEncoder(nn.Module):
    """Three-layer strided CNN that maps a single-channel frame to ``output_dim`` features.

    Each stage is Conv2d (stride 2, no padding) -> BatchNorm2d -> ReLU ->
    Dropout(0.5).  The final feature map is flattened and projected by one
    linear layer.

    Args:
        output_dim: size of the returned feature vector.
        input_shape: ``(height, width)`` of the input frames.
        k1, k2, k3: kernel sizes of the three convolutions.

    Attributes:
        input_shape: the ``(height, width)`` the encoder was built for.
        output_shape: ``(height, width)`` of the last convolutional feature map.
    """

    def __init__(self, output_dim, input_shape=(60, 80), k1=7, k2=7, k3=7):

        super().__init__()

        # Use the size the caller asked for; it was previously overwritten with
        # a hard-coded (60, 80).
        self.input_shape = tuple(input_shape)

        # Follow both spatial axes through the three stride-2 convolutions.
        feat_h, feat_w = self.input_shape
        for kernel_size in (k1, k2, k3):
            feat_h = size_helper(in_length=feat_h, kernel_size=kernel_size, stride=2)
            feat_w = size_helper(in_length=feat_w, kernel_size=kernel_size, stride=2)
        self.output_shape = (int(feat_h), int(feat_w))  # shape of the final feature map

        # Width of the flattened feature map fed to the linear layer.  With the
        # defaults this is 32 * 3 * 5 = 480, the value that used to be
        # hard-coded, so existing checkpoints still load.
        last_channels = 32
        flat_features = last_channels * self.output_shape[0] * self.output_shape[1]

        # Layer positions are kept as they were: state-dict keys and the
        # notebook's layers[1] / layers[5] / layers[9] lookups depend on them.
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=k1, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=k2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=64, out_channels=last_channels, kernel_size=k3, stride=2),
            nn.BatchNorm2d(last_channels),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Flatten(),
            nn.Linear(flat_features, output_dim)
        )

    def forward(self, x):
        """Encode a batch of frames; the first convolution expects one input channel."""
        return self.layers(x)

    
# may consider adding an activation after linear
class BehavEncoder(nn.Module):
    """Batch-normalises a behavioural feature vector and projects it linearly.

    Args:
        behav_dim: number of input behavioural features.
        output_dim: number of output features.
    """

    def __init__(self, behav_dim, output_dim):
        super().__init__()
        normalise = nn.BatchNorm1d(behav_dim)
        project = nn.Linear(behav_dim, output_dim)
        # Position 0 is the batch norm and position 1 the projection; the
        # notebook freezes layers[0] by index.
        self.layers = nn.Sequential(normalise, project)

    def forward(self, x):
        """Return the projected features for input ``x``."""
        return self.layers(x)

class LSTMPerNeuronCombiner(nn.Module):
    """Predicts per-neuron responses from a short clip of frames plus behaviour.

    Pipeline:
      1. Optionally warp every frame with an affine transform whose translation
         and rotation come from ``Shifter`` (driven by behaviour columns
         4, 3, 1, 2).
      2. For each time step, encode the frame (``VisualEncoder``) and the
         behaviour (``BehavEncoder``) to ``num_neurons`` features each, and form
         their element-wise product.
      3. Batch-normalise the three feature sets, run the sequence through a
         recurrent layer, and map the last hidden state to ``num_neurons``
         outputs through a linear layer and Softplus.

    Despite the class and attribute names, the recurrent layer is an ``nn.GRU``.
    The attribute name ``lstm_net`` is kept because state-dict keys are derived
    from it.

    Args:
        num_neurons: number of output neurons (and per-encoder feature width).
        behav_dim: number of behavioural features per time step.
        k1, k2, k3: kernel sizes forwarded to ``VisualEncoder``.
        seq_len: number of time steps per sample.
        hidden_size: hidden width of the recurrent layer.
        use_shifter: whether to apply the shifter warp.  ``None`` (default)
            keeps the previous behaviour of reading the notebook-level
            ``args.shifter`` at call time; pass ``True``/``False`` to make the
            model independent of that global.
    """

    def __init__(self, num_neurons, behav_dim, k1, k2, k3, seq_len, hidden_size=512, use_shifter=None):

        super().__init__()

        self.seq_len = seq_len
        self.num_neurons = num_neurons
        self.use_shifter = use_shifter  # plain attribute, not part of the state dict
        self.shifter = Shifter(seq_len=seq_len)
        self.visual_encoder = VisualEncoder(output_dim=num_neurons, k1=k1, k2=k2, k3=k3)
        self.behav_encoder = BehavEncoder(behav_dim=behav_dim, output_dim=num_neurons)
        self.bn = nn.BatchNorm1d(3)  # apply bn to vis_feats, beh_feats, prod
        self.lstm_net = nn.GRU(input_size=num_neurons*3, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_neurons)
        self.softplus = nn.Softplus()  # we could also do relu or elu offset by 1

    def _shift_images(self, images, behav):
        """Warp every frame by the translation/rotation predicted from behaviour."""
        bs = images.size()[0]

        # Columns handed to the shifter, in this order.
        behav_shifter = torch.stack(
            (
                behav[..., 4],  # theta
                behav[..., 3],  # phi
                behav[..., 1],  # pitch
                behav[..., 2],  # roll
            ),
            dim=-1,
        )

        # One (shift_0, shift_1, angle) row per frame.
        shift_param = self.shifter(behav_shifter).reshape(-1, 3)
        translations = shift_param[..., 0:2]
        scale_param = torch.ones_like(translations)  # unit scale: no zoom

        # Build the centre directly on the target device instead of copying it
        # from the CPU on every forward pass.
        # NOTE(review): if kornia expects centres as (x, y), the centre of a
        # 60x80 (H, W) frame would be (40, 30) rather than (30, 40) - confirm.
        # Left unchanged because existing checkpoints were presumably trained
        # with this value.
        center = torch.repeat_interleave(
            torch.tensor([[30, 40]], dtype=torch.float, device=shift_param.device),
            bs * self.seq_len, dim=0)

        affine_mat = get_affine_matrix2d(translations=translations,
                                         scale=scale_param,
                                         center=center,
                                         angle=shift_param[..., 2])
        affine_mat = affine_mat[:, :2, :]  # keep the first two rows for warp_affine

        warped = warp_affine(images.reshape(-1, 1, 60, 80), affine_mat, dsize=(60, 80))
        return warped.reshape(bs, self.seq_len, 1, 60, 80)

    def forward(self, images, behav):
        """Return predicted responses of shape ``(batch, num_neurons)``.

        ``images`` is indexed as ``(batch, time, channel, height, width)`` and
        ``behav`` as ``(batch, time, feature)``.
        """
        # Fall back to the notebook-level `args` only when no explicit choice was made.
        use_shifter = args.shifter if self.use_shifter is None else self.use_shifter
        if use_shifter:
            images = self._shift_images(images, behav)

        # get visual behavioral features in time
        step_feats = []
        for t in range(self.seq_len):
            v = self.visual_encoder(images[:, t, :, :, :])
            b = self.behav_encoder(behav[:, t, :])
            # Stack visual, behavioural and product features on a new axis 1 so
            # BatchNorm1d(3) normalises each of the three separately.
            vis_beh_feat = torch.stack([v, b, v * b], dim=1)
            step_feats.append(self.bn(vis_beh_feat))
        vis_beh_feats = torch.stack(step_feats, dim=1)

        # flatten features to (batch_size, seq_len, num_neurons*3)
        vis_beh_feats = torch.flatten(vis_beh_feats, start_dim=2)

        # run the recurrent (GRU) layer and keep the output at the last time step
        output, _ = self.lstm_net(vis_beh_feats)
        output = output[:, -1, :]

        # fully connected layer and activation function
        pred_spikes = self.softplus(self.fc(output))

        return pred_spikes

Set the run arguments and the random seed

In [3]:
class Args:
    """Plain namespace of run settings.

    Fields left as ``None`` are filled in by the per-mouse cells further down.
    Several fields are not read in the cells shown in this notebook; they are
    presumably carried over from the training configuration.
    """

    # --- reproducibility ---
    seed = 0

    # --- recording / model selection (assigned per mouse) ---
    file_id = None
    num_neurons = None
    behav_mode = None
    behav_dim = None
    seq_len = None

    # --- model ---
    hidden_size = 512
    shifter = True  # read by LSTMPerNeuronCombiner.forward to enable the image warp

    # --- optimisation settings ---
    epochs = 50
    batch_size = 256
    learning_rate = 2e-4
    l1_reg_w = 1

    # --- data / checkpoint bookkeeping ---
    vid_type = "vid_mean"
    segment_num = 10
    best_val_path = None
    best_train_path = None


args = Args()

def set_random_seed(seed: int, deterministic: bool = True):
    """Seed the Python, NumPy and PyTorch random number generators.

    Adapted from the nnfabrik package.

    Args:
        seed: value passed to every generator.
        deterministic: when true, also turn off cuDNN benchmarking and switch
            cuDNN to deterministic mode.
    """
    # torch.manual_seed sets both CPU and CUDA seeds for PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not deterministic:
        return

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed every generator once through the shared helper.  The previous extra
# calls to random.seed / np.random.seed / torch.manual_seed /
# torch.cuda.manual_seed repeated what set_random_seed already does
# (torch.manual_seed seeds the CUDA generators as well).
seed = args.seed
set_random_seed(seed)

# NOTE: PYTHONHASHSEED is read when an interpreter starts, so setting it here
# does not change hashing in this running kernel; it is only inherited by
# processes started from it.
os.environ['PYTHONHASHSEED'] = str(seed)

# Release cached GPU memory left over from earlier runs and report whether CUDA is usable.
torch.cuda.empty_cache()
print(torch.cuda.is_available())

True


Gradient ascent function: optimizes the image and behavior inputs to maximize one neuron's predicted response

In [4]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (much more slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

def gradient_ascent(inp_img, inp_beh, model, idx, lr=0.1, weight_decay=2.5e-4, laplace_reg_w=0.002, n_steps=100):
    """Optimise an image/behaviour input pair in place to maximise one model output.

    Runs Adam on the two input tensors (the model itself is not updated here)
    to minimise ``sum(-model(inp_img, inp_beh)[:, idx] + laplace_reg_w * lap_reg)``,
    where ``lap_reg`` is the L2 norm of the Laplacian-filtered images.

    Args:
        inp_img: leaf tensor with ``requires_grad=True``.  After squeezing
            axis 1 it must be a valid single-channel ``conv2d`` input, i.e.
            ``(N, 1, H, W)``.
        inp_beh: leaf tensor with ``requires_grad=True`` passed to the model as
            its second argument.
        model: callable ``model(inp_img, inp_beh)`` whose output is indexed as
            ``out[:, idx]``.
        idx: index of the output unit to maximise.
        lr: Adam learning rate.
        weight_decay: Adam weight decay applied to both inputs.
        laplace_reg_w: weight of the Laplacian smoothness penalty.
        n_steps: number of optimisation steps (default 100, the previously
            hard-coded value).

    Returns:
        ``(inp_img, inp_beh, losses)`` - the same two tensor objects, updated in
        place, and the list of per-step loss values.
    """
    optimizer = torch.optim.Adam([inp_img, inp_beh], lr=lr, weight_decay=weight_decay)
    # `verbose` is omitted: False is the default and the argument is deprecated
    # in newer PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=20, factor=0.2, min_lr=0.00001)

    # 3x3 Laplacian kernel shaped (out_channels=1, in_channels=1, 3, 3), created
    # on the same device/dtype as the images instead of a hard-coded 'cuda'.
    laplacian_filter_2d = torch.tensor(
        [[0, -1, 0], [-1, 4, -1], [0, -1, 0]],
        dtype=inp_img.dtype, device=inp_img.device,
    ).reshape(1, 1, 3, 3)

    losses = []
    for _ in range(n_steps):
        optimizer.zero_grad()
        out = model(inp_img, inp_beh)
        lap_reg = torch.norm(
            nn.functional.conv2d(torch.squeeze(inp_img, dim=1), laplacian_filter_2d), p=2)
        # NOTE: lap_reg is a scalar that is added to every batch entry before
        # the sum, so its effective weight grows with the batch size.
        loss = torch.sum(-out[:, idx] + laplace_reg_w * lap_reg)
        loss_value = loss.item()
        losses.append(loss_value)
        # A fresh graph is built on every step, so it does not need retaining.
        loss.backward()
        optimizer.step()
        scheduler.step(loss_value)

    return inp_img, inp_beh, losses

Mouse 1: 070921_J553RT

In [5]:
# Settings for this recording.
args.file_id = "070921_J553RT"
args.num_neurons = 68
args.shifter = True

args.behav_mode = "all_prod"
args.behav_dim = 66

args.seq_len = 1

model = LSTMPerNeuronCombiner(num_neurons=args.num_neurons,
                              behav_dim=args.behav_dim,
                              k1=7, k2=7, k3=7,
                              seq_len=args.seq_len,
                              hidden_size=args.hidden_size).to(device)

# Build the checkpoint path from the settings above so the two cannot drift
# apart; this evaluates to
# weights_cnn_gru_shifter/val_070921_J553RT_all_prod_seq_1.pth
weights_path = "weights_cnn_gru_shifter/val_{}_{}_seq_{}.pth".format(
    args.file_id, args.behav_mode, args.seq_len)

# map_location lets a checkpoint saved on another device load onto `device`.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(weights_path, map_location=device))

# freeze layer param: only the inputs are optimised during gradient ascent
for param in model.parameters():
    param.requires_grad = False

# freeze batchnorm statistics: put every batch-norm layer in eval mode.  This
# covers the same eight layers that were previously picked out by index
# (shifter 0/2/5, visual encoder 1/5/9, behaviour encoder 0, model.bn).
# NOTE(review): the rest of the model stays in training mode, so the Dropout
# layers in VisualEncoder remain active during gradient ascent - confirm that
# this is intended.
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        module.eval()

BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [None]:
# Number of independent restarts per neuron and inputs optimised per restart.
n_restarts = 100
n_samples = 64

# Take the folder, neuron count and behaviour width from `args` (set in the
# setup cell above: 070921_J553RT, 68 neurons, 66 features) so the results
# always match the model that is currently loaded.
out_dir = os.path.join("gradient_ascent_cnn_gru_shifter", args.file_id)
# Create the folder up front so np.save cannot fail after the long optimisation.
os.makedirs(out_dir, exist_ok=True)

for neuron_idx in range(args.num_neurons):

    img_list, beh_list = [], []

    print("neuron index", neuron_idx)

    for restart in range(n_restarts):

        # Fresh starting point: noisy grey images and all-ones behaviour.
        inp_img = torch.normal(0.5, .25, (n_samples, 1, 1, 60, 80), requires_grad=True, device=device, dtype=torch.float)
        inp_beh = torch.ones((n_samples, 1, args.behav_dim), requires_grad=True, device=device, dtype=torch.float)
        inp_img, inp_beh, _losses = gradient_ascent(inp_img=inp_img, inp_beh=inp_beh,
                                                    model=model, idx=neuron_idx,
                                                    lr=0.1, weight_decay=0.02, laplace_reg_w=0.01)

        img_list.append(np.squeeze(inp_img.detach().cpu().numpy()))
        beh_list.append(np.squeeze(inp_beh.detach().cpu().numpy()))

    # Average over every optimised sample from every restart.
    mean_img = np.mean(np.concatenate(img_list, axis=0), axis=0)
    mean_beh = np.mean(np.concatenate(beh_list, axis=0), axis=0)

    print(mean_img.shape, mean_beh.shape)

    np.save(os.path.join(out_dir, "img_{}.npy".format(neuron_idx)), mean_img)
    np.save(os.path.join(out_dir, "beh_{}.npy".format(neuron_idx)), mean_beh)

neuron index 0
(60, 80) (66,)
neuron index 1
(60, 80) (66,)
neuron index 2
(60, 80) (66,)
neuron index 3
(60, 80) (66,)
neuron index 4
(60, 80) (66,)
neuron index 5
(60, 80) (66,)
neuron index 6
(60, 80) (66,)
neuron index 7
(60, 80) (66,)
neuron index 8
(60, 80) (66,)
neuron index 9
(60, 80) (66,)
neuron index 10
(60, 80) (66,)
neuron index 11
(60, 80) (66,)
neuron index 12
(60, 80) (66,)
neuron index 13
(60, 80) (66,)
neuron index 14
(60, 80) (66,)
neuron index 15
(60, 80) (66,)
neuron index 16
(60, 80) (66,)
neuron index 17
(60, 80) (66,)
neuron index 18
(60, 80) (66,)
neuron index 19
(60, 80) (66,)
neuron index 20
(60, 80) (66,)
neuron index 21
(60, 80) (66,)
neuron index 22
(60, 80) (66,)
neuron index 23
(60, 80) (66,)
neuron index 24
(60, 80) (66,)
neuron index 25
(60, 80) (66,)
neuron index 26
(60, 80) (66,)
neuron index 27
(60, 80) (66,)
neuron index 28
(60, 80) (66,)
neuron index 29
(60, 80) (66,)
neuron index 30
(60, 80) (66,)
neuron index 31
(60, 80) (66,)
neuron index 32
(6

Mouse 2: 110421_J569LT

In [5]:
# Settings for this recording.
args.file_id = "110421_J569LT"
args.num_neurons = 32
args.shifter = True

args.behav_mode = "all_prod"
args.behav_dim = 66

args.seq_len = 1

model = LSTMPerNeuronCombiner(num_neurons=args.num_neurons,
                              behav_dim=args.behav_dim,
                              k1=7, k2=7, k3=7,
                              seq_len=args.seq_len,
                              hidden_size=args.hidden_size).to(device)

# Build the checkpoint path from the settings above so the two cannot drift
# apart; this evaluates to
# weights_cnn_gru_shifter/val_110421_J569LT_all_prod_seq_1.pth
weights_path = "weights_cnn_gru_shifter/val_{}_{}_seq_{}.pth".format(
    args.file_id, args.behav_mode, args.seq_len)

# map_location lets a checkpoint saved on another device load onto `device`.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(weights_path, map_location=device))

# freeze layer param: only the inputs are optimised during gradient ascent
for param in model.parameters():
    param.requires_grad = False

# freeze batchnorm statistics: put every batch-norm layer in eval mode.  This
# covers the same eight layers that were previously picked out by index
# (shifter 0/2/5, visual encoder 1/5/9, behaviour encoder 0, model.bn).
# NOTE(review): the rest of the model stays in training mode, so the Dropout
# layers in VisualEncoder remain active during gradient ascent - confirm that
# this is intended.
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        module.eval()

BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [6]:
# Number of independent restarts per neuron and inputs optimised per restart.
n_restarts = 100
n_samples = 64

# Take the folder, neuron count and behaviour width from `args` (set in the
# setup cell above: 110421_J569LT, 32 neurons, 66 features) so the results
# always match the model that is currently loaded.
out_dir = os.path.join("gradient_ascent_cnn_gru_shifter", args.file_id)
# Create the folder up front so np.save cannot fail after the long optimisation.
os.makedirs(out_dir, exist_ok=True)

for neuron_idx in range(args.num_neurons):

    img_list, beh_list = [], []

    print("neuron index", neuron_idx)

    for restart in range(n_restarts):

        # Fresh starting point: noisy grey images and all-ones behaviour.
        inp_img = torch.normal(0.5, .25, (n_samples, 1, 1, 60, 80), requires_grad=True, device=device, dtype=torch.float)
        inp_beh = torch.ones((n_samples, 1, args.behav_dim), requires_grad=True, device=device, dtype=torch.float)
        inp_img, inp_beh, _losses = gradient_ascent(inp_img=inp_img, inp_beh=inp_beh,
                                                    model=model, idx=neuron_idx,
                                                    lr=0.1, weight_decay=0.02, laplace_reg_w=0.01)

        img_list.append(np.squeeze(inp_img.detach().cpu().numpy()))
        beh_list.append(np.squeeze(inp_beh.detach().cpu().numpy()))

    # Average over every optimised sample from every restart.
    mean_img = np.mean(np.concatenate(img_list, axis=0), axis=0)
    mean_beh = np.mean(np.concatenate(beh_list, axis=0), axis=0)

    print(mean_img.shape, mean_beh.shape)

    np.save(os.path.join(out_dir, "img_{}.npy".format(neuron_idx)), mean_img)
    np.save(os.path.join(out_dir, "beh_{}.npy".format(neuron_idx)), mean_beh)

neuron index 0
(60, 80) (66,)
neuron index 1
(60, 80) (66,)
neuron index 2
(60, 80) (66,)
neuron index 3
(60, 80) (66,)
neuron index 4
(60, 80) (66,)
neuron index 5
(60, 80) (66,)
neuron index 6
(60, 80) (66,)
neuron index 7
(60, 80) (66,)
neuron index 8
(60, 80) (66,)
neuron index 9
(60, 80) (66,)
neuron index 10
(60, 80) (66,)
neuron index 11
(60, 80) (66,)
neuron index 12
(60, 80) (66,)
neuron index 13
(60, 80) (66,)
neuron index 14
(60, 80) (66,)
neuron index 15
(60, 80) (66,)
neuron index 16
(60, 80) (66,)
neuron index 17
(60, 80) (66,)
neuron index 18
(60, 80) (66,)
neuron index 19
(60, 80) (66,)
neuron index 20
(60, 80) (66,)
neuron index 21
(60, 80) (66,)
neuron index 22
(60, 80) (66,)
neuron index 23
(60, 80) (66,)
neuron index 24
(60, 80) (66,)
neuron index 25
(60, 80) (66,)
neuron index 26
(60, 80) (66,)
neuron index 27
(60, 80) (66,)
neuron index 28
(60, 80) (66,)
neuron index 29
(60, 80) (66,)
neuron index 30
(60, 80) (66,)
neuron index 31
(60, 80) (66,)


Mouse 3: 101521_J559NC

In [5]:
# Settings for this recording.
args.file_id = "101521_J559NC"
args.num_neurons = 49
args.shifter = True

args.behav_mode = "all_prod"
args.behav_dim = 66

args.seq_len = 1

model = LSTMPerNeuronCombiner(num_neurons=args.num_neurons,
                              behav_dim=args.behav_dim,
                              k1=7, k2=7, k3=7,
                              seq_len=args.seq_len,
                              hidden_size=args.hidden_size).to(device)

# Build the checkpoint path from the settings above so the two cannot drift
# apart; this evaluates to
# weights_cnn_gru_shifter/val_101521_J559NC_all_prod_seq_1.pth
weights_path = "weights_cnn_gru_shifter/val_{}_{}_seq_{}.pth".format(
    args.file_id, args.behav_mode, args.seq_len)

# map_location lets a checkpoint saved on another device load onto `device`.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(weights_path, map_location=device))

# freeze layer param: only the inputs are optimised during gradient ascent
for param in model.parameters():
    param.requires_grad = False

# freeze batchnorm statistics: put every batch-norm layer in eval mode.  This
# covers the same eight layers that were previously picked out by index
# (shifter 0/2/5, visual encoder 1/5/9, behaviour encoder 0, model.bn).
# NOTE(review): the rest of the model stays in training mode, so the Dropout
# layers in VisualEncoder remain active during gradient ascent - confirm that
# this is intended.
for module in model.modules():
    if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        module.eval()

BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [None]:
# Number of independent restarts per neuron and inputs optimised per restart.
n_restarts = 100
n_samples = 64

# Take the folder, neuron count and behaviour width from `args` (set in the
# setup cell above: 101521_J559NC, 49 neurons, 66 features) so the results
# always match the model that is currently loaded.
out_dir = os.path.join("gradient_ascent_cnn_gru_shifter", args.file_id)
# Create the folder up front so np.save cannot fail after the long optimisation.
os.makedirs(out_dir, exist_ok=True)

for neuron_idx in range(args.num_neurons):

    img_list, beh_list = [], []

    print("neuron index", neuron_idx)

    for restart in range(n_restarts):

        # Fresh starting point: noisy grey images and all-ones behaviour.
        inp_img = torch.normal(0.5, .25, (n_samples, 1, 1, 60, 80), requires_grad=True, device=device, dtype=torch.float)
        inp_beh = torch.ones((n_samples, 1, args.behav_dim), requires_grad=True, device=device, dtype=torch.float)
        inp_img, inp_beh, _losses = gradient_ascent(inp_img=inp_img, inp_beh=inp_beh,
                                                    model=model, idx=neuron_idx,
                                                    lr=0.1, weight_decay=0.02, laplace_reg_w=0.01)

        img_list.append(np.squeeze(inp_img.detach().cpu().numpy()))
        beh_list.append(np.squeeze(inp_beh.detach().cpu().numpy()))

    # Average over every optimised sample from every restart.
    mean_img = np.mean(np.concatenate(img_list, axis=0), axis=0)
    mean_beh = np.mean(np.concatenate(beh_list, axis=0), axis=0)

    print(mean_img.shape, mean_beh.shape)

    np.save(os.path.join(out_dir, "img_{}.npy".format(neuron_idx)), mean_img)
    np.save(os.path.join(out_dir, "beh_{}.npy".format(neuron_idx)), mean_beh)

neuron index 0
(60, 80) (66,)
neuron index 1
(60, 80) (66,)
neuron index 2
(60, 80) (66,)
neuron index 3
(60, 80) (66,)
neuron index 4
(60, 80) (66,)
neuron index 5
(60, 80) (66,)
neuron index 6
(60, 80) (66,)
neuron index 7
(60, 80) (66,)
neuron index 8
(60, 80) (66,)
neuron index 9
(60, 80) (66,)
neuron index 10
(60, 80) (66,)
neuron index 11
(60, 80) (66,)
neuron index 12
(60, 80) (66,)
neuron index 13
(60, 80) (66,)
neuron index 14
(60, 80) (66,)
neuron index 15
(60, 80) (66,)
neuron index 16
(60, 80) (66,)
neuron index 17
(60, 80) (66,)
neuron index 18
(60, 80) (66,)
neuron index 19
(60, 80) (66,)
neuron index 20
(60, 80) (66,)
neuron index 21
(60, 80) (66,)
neuron index 22
(60, 80) (66,)
neuron index 23
(60, 80) (66,)
neuron index 24
(60, 80) (66,)
neuron index 25
(60, 80) (66,)
neuron index 26
(60, 80) (66,)
neuron index 27
(60, 80) (66,)
neuron index 28
(60, 80) (66,)
neuron index 29
(60, 80) (66,)
neuron index 30
(60, 80) (66,)
neuron index 31
(60, 80) (66,)
neuron index 32
(6