In [1]:
import torch
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
from fix_models.feature_extractors import get_video_feature_extractor, VideoFeatureExtractor
from fix_models.readouts import PoissonGaussianReadout, PoissonLinearReadout

# neural activity embedding model 
class NeuralEmbedder(nn.Module):
    def __init__(self, num_neurons, num_layers = 3, hidden_size = 16, embed_size = 8, device=torch.device("cpu")):
        super().__init__()
        
        self.device = device
        self.num_neurons = num_neurons
        self.embed_size = embed_size
        
        self.linear1 = nn.Linear(num_neurons, hidden_size, device=device)
        self.linear2 = nn.Linear(hidden_size, hidden_size, device=device)
        self.linear3 = nn.Linear(hidden_size, embed_size, device=device)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(self.linear1(x))
        x = x + self.act(self.linear2(x))
        x = self.linear3(x)
        
        return x

# video embedding model
class VideoEmbedder(nn.Module):
    def __init__(self, modality, layer, stim_shape, train_dataset, feat_ext_type = 'resnet3d', use_pool = False, pool_size = 2, pool_stride = 2, use_pretrained = True, freeze_weights=True, flatten_time = False, device=torch.device("cpu")):
        super().__init__()
        num_neurons = len(train_dataset[0][1])

        feat_ext = get_video_feature_extractor(layer=layer, mod_type=feat_ext_type, device=device, use_pretrained=use_pretrained, freeze_weights=freeze_weights)
        feat_ext = VideoFeatureExtractor(feat_ext, stim_shape, device=device)
        
        readout_input = feat_ext(train_dataset[0][0].unsqueeze(0).to(device))
        num_input  = readout_input.shape[1] * readout_input.shape[2]
        
        feat_to_embed = FeatToEmbed(use_pool = use_pool, pool_size = pool_size, pool_stride= pool_stride, device=device)
        neu_embed = NeuralEmbedder(num_input, device=device)
        
        self.model = nn.Sequential(
            feat_ext,
            feat_to_embed,
            neu_embed
        )
            
        print(f"readout input shape: {num_input}")

    def forward(self, x):
        return self.model(x)

class FeatToEmbed(nn.Module):
    def __init__(self, use_pool = False, pool_size = 2, pool_stride = 2, device=torch.device("cpu")):
        super().__init__()

        self.device = device
        self.use_pool = use_pool
        
        # pooling size
        self.pool = nn.AvgPool2d(pool_size, stride=pool_stride, padding=int(pool_size/2), count_include_pad=False)
        
    def forward(self, x):
        n_batch, n_channel, n_time, width, height = x.shape
        x = x.view(n_batch, n_channel * n_time, width, height)
        
        if self.use_pool:
            x = self.pool(x)

        grid = torch.zeros((x.shape[0], 1, 1, 2), device=self.device)
        grid = torch.clamp(grid, min=-1, max=1) # clamp to ensure within feature map

        x = torch.squeeze(torch.squeeze(F.grid_sample(x, grid, align_corners=False), -1), -1)        
        
        return x

In [2]:
# imports 
import torch
import wandb
import numpy as np
from torch.nn import PoissonNLLLoss
from fix_models.metrics import get_decoder_accuracy

from fix_models.datasets import get_datasets_and_loaders

In [3]:
# config

# all parameters
config = dict()
config["modality"] = "video" # or image

# paths
input_dir = f'../data/{config["modality"]}/'
stimulus_dir = f'../data/{config["modality"]}/stimuli/'
embedding_dir = f'../data/{config["modality"]}/embeddings/'
model_output_path = f'../data/{config["modality"]}/model_output/results'

# dataset and dataloader hyperparameters 
config["win_size"] = 240
config['pos'] = (400, 180)
config["feat_ext_type"] = 'resnet3d'
config["stim_size"] = 32 
config["stim_dur_ms"] = 200
config["stim_shape"] = (1, 3, 5, config["stim_size"], config["stim_size"])
config["first_frame_only"] = False
config["exp_var_thresholds"] = [0.25, 0.25, 0.25]
config["batch_size"] = 16

# model hyperparameters
config["layer"] = "layer2"
config["use_sigma"] = True
config["center_readout"] = False
config["use_pool"] = True
config["pool_size"] = 4
config["pool_stride"] = 2
config["use_pretrained"] = True
config["flatten_time"] = True

# training parameters 
config["lr"] = 0.001 
config["num_epochs"] = 20
config["l2_weight"] = 0

# logging
config["wandb"] = False

# save model
config["save"] = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# session names
session_ids = ["082824", "082924", "083024"]

In [4]:
import torch
import torch.nn.functional as F

# note - this loss function was written by chatgpt with some edits
def triplet_loss(vid_embed, neu_embed, alpha):
    """
    Compute the triplet loss for given video and neural embeddings.
    
    Args:
        vid_embed (torch.Tensor): Tensor of shape (batch_size, embed_size) for video embeddings.
        neu_embed (torch.Tensor): Tensor of shape (batch_size, embed_size) for neural embeddings.
        alpha (float): Margin value for the triplet loss.
    
    Returns:
        torch.Tensor: Scalar loss value.
    """
    # Compute pairwise distances
    #vid_embed = F.normalize(vid_embed, p=2, dim=1)
    #neu_embed = F.normalize(neu_embed, p=2, dim=1)

    vid_embed_norm = vid_embed.unsqueeze(1)  # Shape: (batch_size, 1, embed_size)
    neu_embed_norm = neu_embed.unsqueeze(0)  # Shape: (1, batch_size, embed_size)
    pairwise_dist = torch.sum((vid_embed_norm - neu_embed_norm) ** 2, dim=2)  # Shape: (batch_size, batch_size)

    # Find the "challenging negatives"
    # Set diagonal to a large value to exclude positives
    pairwise_dist.fill_diagonal_(float('inf'))  
    challenging_negatives_idx = torch.argmin(pairwise_dist, dim=1)  # Shape: (batch_size,)
    shuffled_neu_embed = neu_embed[challenging_negatives_idx]  # Shape: (batch_size, embed_size)

    # Compute distances for positives and negatives
    pos_dist = torch.sum((vid_embed - neu_embed) ** 2, dim=1)  # Shape: (batch_size,)
    neg_dist = torch.sum((vid_embed - shuffled_neu_embed) ** 2, dim=1)  # Shape: (batch_size,)

    # Compute triplet loss
    loss = F.relu(pos_dist - neg_dist + alpha)  # Shape: (batch_size,)
    return loss.mean()  # Scalar loss value


def train_model(full_vid, full_neu, model_name):
    # corr avgs
    corr_avgs = []

    config['model_name'] = model_name

    print(config['l2_weight'])
    
    for ses_idx, session_id in enumerate(session_ids):
        # set sess_corr_avg
        sess_corr_avg = -1
        sess_corrs = []

        # set session index 
        config["session_id"] = session_id

        # setup logging
        if config["wandb"]:
            wandb.init(
                project=f'{config["modality"]}-cs230-decode',
                config=config,
            )
            wandb.define_metric("decode_acc", summary="max")
            wandb.define_metric("test_loss", summary="min")

        # load datasets and loaders 
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'], test_bs=True)
        _, _, _, test_loader_single = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'], test_bs=False)

        full_vid_embedder = full_vid(train_dataset)
        full_neu_embedder = full_neu(len(train_dataset[0][1]))

        # set which parameters to use regularization with and which not to
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_vid_embedder.named_parameters():
            if 'bias' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        # setup Adam optimizer
        vid_optimizer = torch.optim.Adam([
            {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
            {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
        ], lr=config["lr"], weight_decay=config['l2_weight'])
        
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_neu_embedder.named_parameters():
            if 'bias' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        neu_optimizer = torch.optim.Adam([
            {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
            {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
        ], lr=config["lr"], weight_decay=config['l2_weight'])
    
        # using triplet loss   
        alpha = 0.1
        
        for epochs in range(config["num_epochs"]):
            epoch_loss = 0
            for i, (stimulus, targets) in enumerate(train_loader): 
                vids = stimulus.to(device)
                neus = targets.to(device)

                vid_optimizer.zero_grad()
                neu_optimizer.zero_grad()
                
                vid_embed = full_vid_embedder(vids)
                neu_embed = full_neu_embedder(neus)

                loss = triplet_loss(vid_embed, neu_embed, alpha) + triplet_loss(neu_embed, vid_embed, alpha)
                loss.backward()
                vid_optimizer.step()
                neu_optimizer.step()
                
                epoch_loss += loss.item()
    
            # printing corr to avg and loss metrics 
            with torch.no_grad():
                decode_acc = get_decoder_accuracy(full_vid_embedder, full_neu_embedder, test_loader_single, modality=config["modality"], device=device)
                test_loss = 0
                for i, (stimulus, targets) in enumerate(test_loader):
                    vids = stimulus.to(device)
                    neus = targets.to(device)
                    vid_embed = full_vid_embedder(vids)
                    neu_embed = full_neu_embedder(neus)
                    loss = triplet_loss(vid_embed, neu_embed, alpha) + triplet_loss(neu_embed, vid_embed, alpha)
                    test_loss += loss.item()
                    
            if config["wandb"]:
                wandb.log({"decode_acc": np.nanmean(decode_acc), "train_loss": epoch_loss / len(train_loader), "test_loss": test_loss / len(test_loader)})
            
            if np.nanmean(decode_acc) > sess_corr_avg:
                sess_corr_avg = np.nanmean(decode_acc)
                sess_corrs = decode_acc
                
            print('  epoch {} loss: {} decode acc: {}'.format(epochs + 1, epoch_loss / len(train_dataset), np.nanmean(decode_acc)))
            #print(f' num. neurons : {len(decode_acc)}')
            
        if config["save"]:
            torch.save(full_model.state_dict(), f"{model_output_path}_{session_id}.pickle")
            
        corr_avgs.append(sess_corrs)
        
        if config["wandb"]:
            wandb.finish()
    
    if config["wandb"]:
        wandb.init(
            project=f'{config["modality"]}-cs230-decode',
            config=config,
        )
        for corr in corr_avgs:
            wandb.log({"decode_accs": corr})
        wandb.finish()

In [5]:
full_vid_fcn = lambda train_dataset: VideoEmbedder(feat_ext_type = 'resnet3d', freeze_weights=True, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
full_neu_fcn = lambda num_neurons: NeuralEmbedder(num_neurons, device=device)

train_model(full_vid_fcn, full_neu_fcn, "frozen pretrained")


0
readout input shape: 384
  epoch 1 loss: 0.018660290329544634 decode acc: 1.0
  epoch 2 loss: 0.012940233651502632 decode acc: 1.0
  epoch 3 loss: 0.012748880680696463 decode acc: 1.0
  epoch 4 loss: 0.012685570223831837 decode acc: 1.0
  epoch 5 loss: 0.01265096715203038 decode acc: 1.0
  epoch 6 loss: 0.01262990712383647 decode acc: 1.0
  epoch 7 loss: 0.01261837148372038 decode acc: 1.0
  epoch 8 loss: 0.012607208711129648 decode acc: 1.0
  epoch 9 loss: 0.012594342040426938 decode acc: 1.0
  epoch 10 loss: 0.012589485203778302 decode acc: 1.0
  epoch 11 loss: 0.012583673898084664 decode acc: 1.0
  epoch 12 loss: 0.012580772034915876 decode acc: 1.0
  epoch 13 loss: 0.012577647831704881 decode acc: 1.0
  epoch 14 loss: 0.012577152693713152 decode acc: 1.0
  epoch 15 loss: 0.01257584739614416 decode acc: 1.0


KeyboardInterrupt: 