In [1]:
import torch
import torch.nn as nn

from torch.utils.data import Subset, DataLoader, ConcatDataset
from mouse_model.data_utils_new import MouseDatasetSegNewBehav
import numpy as np
from mouse_model.evaluation import cor_in_time
from sklearn.metrics import r2_score, mean_squared_error
import random, os
from kornia.geometry.transform import get_affine_matrix2d, warp_affine

the model

In [2]:
# useful for printing in nn.Sequential
class PrintLayer(nn.Module):
    
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x):
        print(x.shape)
        return x

def size_helper(in_length, kernel_size, padding=0, dilation=1, stride=1):
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
    res = in_length + 2 * padding - dilation * (kernel_size - 1) - 1
    res /= stride
    res += 1
    return np.floor(res)

# CNN, the last fully connected layer maps to output_dim
class VisualEncoder(nn.Module):
    
    def __init__(self, output_dim, input_shape=(60, 80), k1=3, k2=3, k3=3, c1=32, c2=64, c3=128, dropout_prob=0):
        
        super().__init__()
        
        self.input_shape = (60, 80)
        out_shape_0 = size_helper(in_length=input_shape[0], kernel_size=k1, stride=2)
        out_shape_0 = size_helper(in_length=out_shape_0, kernel_size=k2, stride=2)
        out_shape_0 = size_helper(in_length=out_shape_0, kernel_size=k3, stride=2)
        out_shape_1 = size_helper(in_length=input_shape[1], kernel_size=k1, stride=2)
        out_shape_1 = size_helper(in_length=out_shape_1, kernel_size=k2, stride=2)
        out_shape_1 = size_helper(in_length=out_shape_1, kernel_size=k3, stride=2)
        self.output_shape = (int(out_shape_0), int(out_shape_1)) # shape of the final feature map
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=k1, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=k2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=k3, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Flatten(),
            nn.Linear(self.output_shape[0]*self.output_shape[1]*c3, output_dim)
        )
        
    def forward(self, x):

        x = self.layers(x)

        return x

In [3]:
class VisualDecoder777(nn.Module):
    
    # input_dim is latent size
    def __init__(self, input_dim, c1=128, c2=64, c3=32, dropout_prob=0):
        
        super().__init__()
        
        self.input_dim = input_dim
        self.conv_layers = nn.Sequential(
            nn.Linear(input_dim, c1*3*5),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.Unflatten(1, (c1, 3, 5)),
#             PrintLayer(),
            nn.ConvTranspose2d(in_channels=c1, out_channels=c2, kernel_size=7, stride=2, output_padding=(0,1)),
#             PrintLayer(),
            nn.BatchNorm2d(c2),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.ConvTranspose2d(in_channels=c2, out_channels=c3, kernel_size=7, stride=2),
#             PrintLayer(),
            nn.BatchNorm2d(c3),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob),
            nn.ConvTranspose2d(in_channels=c3, out_channels=1, kernel_size=7, stride=2, output_padding=1),
#             PrintLayer(),
            nn.BatchNorm2d(1),
            nn.ReLU(),
        )
        
    def forward(self, x):

        x = self.conv_layers(x)

        return x

In [4]:
class Shifter(nn.Module):
    def __init__(self, input_dim=4, output_dim=3, hidden_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        self.layers = nn.Sequential(
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),
        )
        self.bias = nn.Parameter(torch.zeros(3))
    def forward(self, x):
        x = x.reshape(-1,self.input_dim )
        x = self.layers(x)
        x0 = (x[...,0] + self.bias[0]) * 80/4
        x1 = (x[...,1] + self.bias[1]) * 60/4
        x2 = (x[...,2] + self.bias[2]) * 180/4
        x = torch.stack([x0, x1, x2], dim=-1)
        x = x.reshape(-1,1,self.output_dim)
        return x
        
class AutoencoderPredictor777(nn.Module):
    
    def __init__(self, num_neurons, latent_size, c1=32, c2=64, c3=128, dropout_prob=0):

        super().__init__()
        
        self.encoder = VisualEncoder(output_dim=latent_size, k1=7, k2=7, k3=7, 
                                     c1=c1, c2=c2, c3=c3, dropout_prob=dropout_prob)
        self.decoder = VisualDecoder777(input_dim=latent_size, c1=c3, c2=c2, c3=c1, dropout_prob=dropout_prob)
        self.fc = nn.Linear(latent_size, num_neurons)
        self.softplus = nn.Softplus()
        self.shifter = Shifter()

    def forward(self, images, behav):
        if args.shifter:
            bs = images.size()[0]
            behav_shifter = torch.concat((behav[...,4].unsqueeze(-1),   # theta
                                          behav[...,3].unsqueeze(-1),   # phi
                                          behav[...,1].unsqueeze(-1),  # pitch
                                         behav[...,2].unsqueeze(-1),  # roll
                                         ), dim=-1)  
            shift_param = self.shifter(behav_shifter)  
            shift_param = shift_param.reshape(-1,3)
            scale_param = torch.ones_like(shift_param[..., 0:2]).to(shift_param.device)
            affine_mat = get_affine_matrix2d(
                                            translations=shift_param[..., 0:2] ,
                                             scale = scale_param, 
                                             center =torch.repeat_interleave(torch.tensor([[30,40]], dtype=torch.float), 
                                                                            bs*1, dim=0).to(shift_param.device), 
                                             angle=shift_param[..., 2])
            affine_mat = affine_mat[:, :2, :]
            images = warp_affine(images, affine_mat, dsize=(60,80))

        latent_vec = self.encoder(images)
        recon = self.decoder(latent_vec)
        pred = self.fc(latent_vec)
        pred = self.softplus(pred)
        
        return pred, recon

In [5]:
class Args:
    
    seed = 0
    batch_size = 256
    learning_rate = 0.0001
    epochs = 50
    alpha = 0.5
    file_id = None
    num_neurons = None
    behav_mode = "orig"
    best_val_path = None
    best_train_path = None
    vid_type = "vid_mean"
    segment_num = 10
    seq_len = 1
    vis_latent_dim = 256
    shifter=True
    
args=Args()

seed = args.seed
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

torch.cuda.empty_cache()
print(torch.cuda.is_available())

True


In [6]:
def load_train_val_ds():
    ds_list = [MouseDatasetSegNewBehav(
        file_id=args.file_id, seg_idx=i, segment_num=args.segment_num, data_split="train", vid_type=args.vid_type, 
        seq_len=args.seq_len, predict_offset=1, behav_mode=args.behav_mode, norm_mode="01") 
               for i in range(args.segment_num)]
    train_ds, val_ds = [], []
    for ds in ds_list:
        train_ratio = 0.8
        train_ds_len = int(len(ds) * train_ratio)
        train_ds.append(Subset(ds, np.arange(0, train_ds_len, 1)))
        val_ds.append(Subset(ds, np.arange(train_ds_len, len(ds), 1)))
    train_ds = ConcatDataset(train_ds)
    val_ds = ConcatDataset(val_ds)
    print(len(train_ds), len(val_ds))
    return train_ds, val_ds

In [7]:
def load_test_ds():
    test_ds = [MouseDatasetSegNewBehav(
        file_id=args.file_id, seg_idx=i, segment_num=args.segment_num, data_split="test", vid_type=args.vid_type, 
        seq_len=args.seq_len, predict_offset=1, behav_mode=args.behav_mode, norm_mode="01") 
               for i in range(args.segment_num)]
    test_ds = ConcatDataset(test_ds)
    return test_ds

In [8]:
def train_model():
    
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    
    train_ds, val_ds = load_train_val_ds()

    train_dataloader = DataLoader(dataset=train_ds, batch_size=args.batch_size, shuffle=True, num_workers=8)
    val_dataloader = DataLoader(dataset=val_ds, batch_size=args.batch_size, shuffle=False, num_workers=8)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    best_train_spike_loss = np.inf
    best_val_spike_loss = np.inf
    train_loss_list = []
    val_loss_list = []

    ct = 0
    
    # start training
    for epoch in range(args.epochs):

        print("Start epoch", epoch)

        model.train()

        epoch_train_loss, epoch_train_spike_loss, epoch_train_recon_loss = 0, 0, 0

        for (image, behav, spikes) in train_dataloader:

            image, behav, spikes = image.to(device), behav.to(device), spikes.to(device)
            image = torch.squeeze(image, axis=1)
            
            pred, recon = model(image, behav)
            
            recon_loss = torch.nn.functional.mse_loss(recon, image, reduction='mean')
            spike_loss = nn.functional.poisson_nll_loss(pred, spikes, reduction='mean', log_input=False)
            total_loss = spike_loss + args.alpha * recon_loss

            epoch_train_loss += total_loss.item()
            epoch_train_spike_loss += spike_loss.item()
            epoch_train_recon_loss += recon_loss.item()

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        epoch_train_loss = epoch_train_loss / len(train_dataloader)
        epoch_train_spike_loss = epoch_train_spike_loss / len(train_dataloader)
        epoch_train_recon_loss = epoch_train_recon_loss / len(train_dataloader)

        train_loss_list.append([epoch_train_loss, epoch_train_spike_loss, epoch_train_recon_loss])
        
        print("Epoch {} train loss: {}, spike loss: {}, recon loss: {}".format(
            epoch, epoch_train_loss, epoch_train_spike_loss, epoch_train_recon_loss))

        if epoch_train_spike_loss < best_train_spike_loss:

            print("save train model at epoch", epoch)
            torch.save(model.state_dict(), args.best_train_path)
            best_train_spike_loss = epoch_train_spike_loss

        model.eval()

        epoch_val_loss, epoch_val_spike_loss, epoch_val_recon_loss = 0, 0, 0

        with torch.no_grad():      

            for (image, behav, spikes) in val_dataloader:

                image, behav, spikes = image.to(device), behav.to(device), spikes.to(device)
                image = torch.squeeze(image, axis=1)

                pred, recon = model(image, behav)
                
                recon_loss = torch.nn.functional.mse_loss(recon, image, reduction='mean')
                spike_loss = nn.functional.poisson_nll_loss(pred, spikes, reduction='mean', log_input=False)
                total_loss = spike_loss + args.alpha * recon_loss

                epoch_val_loss += total_loss.item()
                epoch_val_spike_loss += spike_loss.item()
                epoch_val_recon_loss += recon_loss.item()
                
        epoch_val_loss = epoch_val_loss / len(val_dataloader)
        epoch_val_spike_loss = epoch_val_spike_loss / len(val_dataloader)
        epoch_val_recon_loss = epoch_val_recon_loss / len(val_dataloader)

        val_loss_list.append([epoch_val_loss, epoch_val_spike_loss, epoch_val_recon_loss])
        
        print("Epoch {} val spike loss: {}".format(epoch, epoch_val_spike_loss))

        if epoch_val_spike_loss < best_val_spike_loss:

            print("save val model at epoch", epoch)
            torch.save(model.state_dict(), args.best_val_path)
            best_val_spike_loss = epoch_val_spike_loss
            ct = 0
            
        else:
            ct += 1
            if ct > 5:
                print('stop training')
                break

        print("End epoch", epoch)
        
    return train_loss_list, val_loss_list

In [None]:
# 64 128 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
for file_id, num_neurons in [("070921_J553RT", 68), ("110421_J569LT", 32), ("101521_J559NC", 49)]:
    for shifter in [True]:
        print(file_id)
        args.vid_type = 'vid_mean'
        args.file_id = file_id
        args.num_neurons = num_neurons
        args.shifter=shifter
        args.best_train_path = "/hdd/yuchen/trainAEshifter{}_{}_{}_{}.pth".format(args.shifter, 1, args.vid_type, file_id)
        args.best_val_path = "/hdd/yuchen/valAEshifter{}_{}_{}_{}.pth".format(args.shifter, 1, args.vid_type, file_id)
    
        model = AutoencoderPredictor777(num_neurons=args.num_neurons, 
                                        latent_size=args.vis_latent_dim,
                                        c1=64, c2=128, c3=256).to(device)

        train_loss_list, val_loss_list = train_model()

    # np.save("/hdd/aiwenxu/net_weights_final/vis_autoencoder/loss_777_c1_64_{}.npy".format(file_id), 
    #         np.array([train_loss_list, val_loss_list]))

070921_J553RT


AcceleratorError: CUDA error: invalid device ordinal
GPU device may be out of range, do you have enough GPUs?
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


eval

In [9]:
# default is smoothing with 2 second, 48 ms per frame
def smoothing_with_np_conv(nsp, size=int(2000/48)):
    np_conv_res = []
    for i in range(nsp.shape[1]):
        np_conv_res.append(np.convolve(nsp[:, i], np.ones(size)/size, mode="same"))        
    np_conv_res = np.transpose(np.array(np_conv_res))
    return np_conv_res

In [10]:
def evaluate_model(model, weights_path, dataset, device):

    dl = DataLoader(dataset=dataset, batch_size=256, shuffle=False, num_workers=4)
    
    model.load_state_dict(torch.load(weights_path))

    ground_truth_all = []
    pred_all = []
    
    model.eval()
    
    with torch.no_grad():      
        
        for (image, behav, spikes) in dl:
            
            image = image.to(device)
            behav = behav.to(device)
            
            image = torch.squeeze(image, axis=1)
            
            pred, recon = model(image, behav)
            
            ground_truth_all.append(spikes.numpy())
            pred_all.append(pred.cpu().numpy())
    
    return np.concatenate(pred_all, axis=0), np.concatenate(ground_truth_all, axis=0)

In [None]:
# 64 128 256
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    
for file_id, num_neurons in [("070921_J553RT", 68), ("110421_J569LT", 32), ("101521_J559NC", 49)]:
    
    for shifter in [True, False]:
        print(file_id)
        args.vid_type = 'vid_mean'
        args.file_id = file_id
        args.num_neurons = num_neurons
        args.shifter=shifter
        args.best_train_path = "/home/herbelinluke/Downloads/trainAEshifter{}_{}_{}_{}.pth".format(args.shifter, 1, args.vid_type, file_id)
        args.best_val_path = "/home/herbelinluke/Downloads/valAEshifter{}_{}_{}_{}.pth".format(args.shifter, 1, args.vid_type, file_id)
        model = AutoencoderPredictor777(num_neurons=args.num_neurons, 
                                    latent_size=args.vis_latent_dim,
                                    c1=64, c2=128, c3=256).to(device)
    
        train_ds, val_ds = load_train_val_ds()
        test_ds = load_test_ds()
    
        pred, label = evaluate_model(model, weights_path=args.best_val_path, dataset=test_ds, device=device)
        cor_array = cor_in_time(pred, label)
    #     print("best val model on test dataset, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
        pred = smoothing_with_np_conv(pred)
        label = smoothing_with_np_conv(label)
        print("R2", "{:.6f}".format(r2_score(label.T, pred.T)))
        print("MSE", "{:.6f}".format(mean_squared_error(label, pred)))
        cor_array = cor_in_time(pred, label)
        print("mean corr, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
        print("max corr", "{:.6f}".format(np.max(cor_array)))
        print("min corr", "{:.6f}".format(np.min(cor_array)))

070921_J553RT
30120 7540
R2 0.777193
MSE 0.071034
mean corr, 0.555+-0.140
max corr 0.816352
min corr 0.177165
070921_J553RT
30120 7540
R2 0.765927
MSE 0.076842
mean corr, 0.509+-0.144
max corr 0.773564
min corr 0.115169
110421_J569LT
32940 8240
R2 0.224810
MSE 0.112353
mean corr, 0.370+-0.145
max corr 0.651283
min corr 0.048037
110421_J569LT
32940 8240
R2 0.210233
MSE 0.114998
mean corr, 0.375+-0.148
max corr 0.661694
min corr 0.110536
101521_J559NC
42410 10610
R2 0.727099
MSE 0.097395
mean corr, 0.521+-0.144
max corr 0.822660
min corr 0.242436
101521_J559NC
42410 10610
R2 0.663554
MSE 0.117084
mean corr, 0.459+-0.140
max corr 0.749090
min corr 0.166779
