In [None]:
# video reconstruction overall plan
# 1. train a generator model to reconstruct video from activations of ResNet3D

# 2. train a linear fxn to map neural activity to activations

# 3. use to generate videos 

In [45]:
# config

# all parameters
config = dict()
config["modality"] = "video" # or image

# paths
input_dir = f'../data/{config["modality"]}/'
stimulus_dir = f'../data/{config["modality"]}/stimuli/'
embedding_dir = f'../data/{config["modality"]}/embeddings/'
model_output_path = f'../data/{config["modality"]}/model_output/results'

# dataset and dataloader hyperparameters 
config["win_size"] = 240
config['pos'] = (400, 180)
config["feat_ext_type"] = 'resnet3d'
config["stim_size"] = 32 
config["stim_dur_ms"] = 200
config["stim_shape"] = (1, 3, 5, config["stim_size"], config["stim_size"])
config["first_frame_only"] = False
config["exp_var_thresholds"] = [0.25, 0.25, 0.25] #[0.25, 0.25, 0.25]
config["batch_size"] = 16

# model hyperparameters
config["layer"] = "layer2"
config["use_sigma"] = True
config["center_readout"] = False
config["use_pool"] = True
config["pool_size"] = 4
config["pool_stride"] = 2
config["use_pretrained"] = True
config["flatten_time"] = True

# training parameters 
config["lr"] = 0.001 
config["num_epochs"] = 20
config["l2_weight"] = 0

# logging
config["wandb"] = True

# save model
config["save"] = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# session names
session_ids = ["082824", "082924", "083024"]

In [46]:
import torch
from torchvision.models.video import r3d_18
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

In [75]:
feat_ext = create_feature_extractor(r3d_18(), return_nodes={  'layer2.0.conv1.1': 'layer'}).to(device)
for param in feat_ext.parameters():
    param.requires_grad = False

In [78]:
feat_ext(torch.zeros((1,3,5,32,32), device=device))['layer'].shape

torch.Size([1, 128, 3, 8, 8])

In [93]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, output_channels=3, vid_size=32):
        super(Generator, self).__init__()
        self.conv1 = nn.ConvTranspose3d(128, 128, 4, 2, 1, device=device)
        self.conv2 = nn.ConvTranspose3d(128, 64, 4, 2, 1, device=device)
        self.conv3 = nn.Conv3d(64, output_channels, 3, 1, 1, device=device)
        self.conv4 = nn.Conv3d(output_channels, output_channels, 3, 1, 1, device=device)
        self.conv5 = nn.Conv3d(12, 5, 3, 1, 1, device=device)
        self.sig = nn.Sigmoid()
        self.relu = nn.ReLU()
        
        self.vid_size = vid_size

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = x.permute(0, 2, 1, 3, 4)
        x = (self.conv5(x)).permute(0, 2, 1, 3, 4)
        x = self.sig(x)*4 - 2
        return x

# generator maps from torch.Size([1, 128, 3, 8, 8]) to torch.Size([1,3,5,32,32])
g = Generator()
print(g(torch.zeros((1,128,3,8,8), device=device)).shape)

torch.Size([1, 3, 5, 32, 32])


In [94]:
# train the generator 
from fix_models.datasets import get_datasets_and_loaders

# get the dataloaders for videos 
ses_idx = 0
session_id = session_ids[ses_idx]
datasets = []
loaders = []
for i in range(120, 520, 20):
    for j in range(120, 240, 20):
        train_dataset, _, train_loader, _ = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = (i, j), test_bs=True)
        loaders.append(train_loader)

In [None]:
from tqdm import tqdm
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(g.parameters(), lr=0.0002)

num_epochs = len(loaders)
for epoch, loader in enumerate(loaders):
    losses = 0
    for images, targets in tqdm(loader):
        optimizer.zero_grad()
        images = images.to(device)
        feats = feat_ext(images)
        reconstructed_images = g(feats['layer'])
        loss = criterion(reconstructed_images, images)
        losses += loss.item()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {losses:.4f}")

print("Training complete.")

100% 127/127 [00:06<00:00, 19.11it/s]


Epoch [1/120], Loss: 124.8757


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [2/120], Loss: 61.7440


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [3/120], Loss: 23.9900


100% 127/127 [00:06<00:00, 18.95it/s]


Epoch [4/120], Loss: 20.2023


100% 127/127 [00:06<00:00, 19.02it/s]


Epoch [5/120], Loss: 17.3144


100% 127/127 [00:06<00:00, 19.03it/s]


Epoch [6/120], Loss: 15.2031


100% 127/127 [00:06<00:00, 19.01it/s]


Epoch [7/120], Loss: 15.4717


100% 127/127 [00:06<00:00, 19.07it/s]


Epoch [8/120], Loss: 14.3226


100% 127/127 [00:06<00:00, 19.03it/s]


Epoch [9/120], Loss: 13.8025


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [10/120], Loss: 13.0804


100% 127/127 [00:06<00:00, 19.08it/s]


Epoch [11/120], Loss: 12.2668


100% 127/127 [00:06<00:00, 18.93it/s]


Epoch [12/120], Loss: 12.0861


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [13/120], Loss: 12.2204


100% 127/127 [00:06<00:00, 19.01it/s]


Epoch [14/120], Loss: 12.0714


100% 127/127 [00:06<00:00, 19.03it/s]


Epoch [15/120], Loss: 11.6572


100% 127/127 [00:06<00:00, 19.02it/s]


Epoch [16/120], Loss: 11.0764


100% 127/127 [00:06<00:00, 18.95it/s]


Epoch [17/120], Loss: 10.7689


100% 127/127 [00:06<00:00, 19.08it/s]


Epoch [18/120], Loss: 10.3270


100% 127/127 [00:06<00:00, 19.07it/s]


Epoch [19/120], Loss: 11.1234


100% 127/127 [00:06<00:00, 19.02it/s]


Epoch [20/120], Loss: 10.8014


100% 127/127 [00:06<00:00, 19.11it/s]


Epoch [21/120], Loss: 10.3395


100% 127/127 [00:06<00:00, 19.00it/s]


Epoch [22/120], Loss: 10.2168


100% 127/127 [00:06<00:00, 19.11it/s]


Epoch [23/120], Loss: 9.6265


100% 127/127 [00:06<00:00, 18.99it/s]


Epoch [24/120], Loss: 9.6440


100% 127/127 [00:06<00:00, 19.15it/s]


Epoch [25/120], Loss: 10.2835


100% 127/127 [00:06<00:00, 19.02it/s]


Epoch [26/120], Loss: 9.9711


100% 127/127 [00:06<00:00, 18.64it/s]


Epoch [27/120], Loss: 9.9826


100% 127/127 [00:06<00:00, 18.89it/s]


Epoch [28/120], Loss: 9.8356


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [29/120], Loss: 9.1992


100% 127/127 [00:06<00:00, 18.84it/s]


Epoch [30/120], Loss: 8.9807


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [31/120], Loss: 9.7883


100% 127/127 [00:06<00:00, 18.71it/s]


Epoch [32/120], Loss: 9.3972


100% 127/127 [00:06<00:00, 18.57it/s]


Epoch [33/120], Loss: 9.5123


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [34/120], Loss: 9.2246


100% 127/127 [00:06<00:00, 19.03it/s]


Epoch [35/120], Loss: 8.9580


100% 127/127 [00:06<00:00, 18.99it/s]


Epoch [36/120], Loss: 8.5228


100% 127/127 [00:06<00:00, 19.05it/s]


Epoch [37/120], Loss: 9.6460


100% 127/127 [00:06<00:00, 18.82it/s]


Epoch [38/120], Loss: 9.4986


100% 127/127 [00:06<00:00, 19.03it/s]


Epoch [39/120], Loss: 8.9672


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [40/120], Loss: 8.5626


100% 127/127 [00:06<00:00, 19.10it/s]


Epoch [41/120], Loss: 8.6256


100% 127/127 [00:06<00:00, 18.95it/s]


Epoch [42/120], Loss: 8.2485


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [43/120], Loss: 9.0365


100% 127/127 [00:06<00:00, 19.12it/s]


Epoch [44/120], Loss: 8.7884


100% 127/127 [00:06<00:00, 18.94it/s]


Epoch [45/120], Loss: 8.7795


100% 127/127 [00:06<00:00, 18.87it/s]


Epoch [46/120], Loss: 8.3995


100% 127/127 [00:07<00:00, 18.13it/s]


Epoch [47/120], Loss: 8.2795


100% 127/127 [00:06<00:00, 18.67it/s]


Epoch [48/120], Loss: 7.9499


100% 127/127 [00:06<00:00, 18.49it/s]


Epoch [49/120], Loss: 8.7250


100% 127/127 [00:06<00:00, 18.97it/s]


Epoch [50/120], Loss: 8.6119


100% 127/127 [00:06<00:00, 18.98it/s]


Epoch [51/120], Loss: 8.5587


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [52/120], Loss: 8.1315


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [53/120], Loss: 7.9699


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [54/120], Loss: 7.7346


100% 127/127 [00:06<00:00, 18.92it/s]


Epoch [55/120], Loss: 8.7149


100% 127/127 [00:06<00:00, 19.09it/s]


Epoch [56/120], Loss: 8.5332


100% 127/127 [00:06<00:00, 19.03it/s]


Epoch [57/120], Loss: 8.2330


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [58/120], Loss: 8.0371


100% 127/127 [00:06<00:00, 18.99it/s]


Epoch [59/120], Loss: 7.7519


100% 127/127 [00:06<00:00, 19.07it/s]


Epoch [60/120], Loss: 7.6563


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [61/120], Loss: 8.4350


100% 127/127 [00:06<00:00, 19.04it/s]


Epoch [62/120], Loss: 8.0485


100% 127/127 [00:06<00:00, 19.02it/s]


Epoch [63/120], Loss: 8.0177


100% 127/127 [00:06<00:00, 19.13it/s]


Epoch [64/120], Loss: 7.8148


100% 127/127 [00:06<00:00, 18.97it/s]


Epoch [65/120], Loss: 7.6561


100% 127/127 [00:06<00:00, 19.02it/s]


Epoch [66/120], Loss: 7.5442


100% 127/127 [00:06<00:00, 18.97it/s]


Epoch [67/120], Loss: 8.3485


100% 127/127 [00:06<00:00, 18.99it/s]


Epoch [68/120], Loss: 7.9975


100% 127/127 [00:06<00:00, 18.90it/s]


Epoch [69/120], Loss: 7.6934


100% 127/127 [00:06<00:00, 18.96it/s]


Epoch [70/120], Loss: 7.6391


100% 127/127 [00:07<00:00, 17.62it/s]


Epoch [71/120], Loss: 7.2874


100% 127/127 [00:06<00:00, 18.85it/s]


Epoch [72/120], Loss: 7.2850


100% 127/127 [00:07<00:00, 17.89it/s]


Epoch [73/120], Loss: 7.8879


100% 127/127 [00:06<00:00, 18.49it/s]


Epoch [74/120], Loss: 7.7394


100% 127/127 [00:06<00:00, 18.99it/s]


Epoch [75/120], Loss: 7.5192


100% 127/127 [00:06<00:00, 19.16it/s]


Epoch [76/120], Loss: 7.3895


100% 127/127 [00:06<00:00, 19.09it/s]


Epoch [77/120], Loss: 7.3065


100% 127/127 [00:06<00:00, 19.16it/s]


Epoch [78/120], Loss: 7.1048


100% 127/127 [00:06<00:00, 19.11it/s]


Epoch [79/120], Loss: 7.7027


100% 127/127 [00:06<00:00, 19.08it/s]


Epoch [80/120], Loss: 7.6697


100% 127/127 [00:06<00:00, 19.10it/s]


Epoch [81/120], Loss: 7.7022


100% 127/127 [00:06<00:00, 19.05it/s]


Epoch [82/120], Loss: 7.3015


100% 127/127 [00:06<00:00, 19.02it/s]


Epoch [83/120], Loss: 7.1612


100% 127/127 [00:06<00:00, 18.98it/s]


Epoch [84/120], Loss: 6.9685


100% 127/127 [00:06<00:00, 19.08it/s]


Epoch [85/120], Loss: 7.6103


100% 127/127 [00:06<00:00, 19.17it/s]


Epoch [86/120], Loss: 7.1986


100% 127/127 [00:06<00:00, 18.99it/s]


Epoch [87/120], Loss: 7.2786


100% 127/127 [00:06<00:00, 19.12it/s]


Epoch [88/120], Loss: 7.0209


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [89/120], Loss: 6.9980


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [90/120], Loss: 6.8611


100% 127/127 [00:06<00:00, 19.05it/s]


Epoch [91/120], Loss: 7.8229


100% 127/127 [00:06<00:00, 19.10it/s]


Epoch [92/120], Loss: 7.4048


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [93/120], Loss: 7.2225


100% 127/127 [00:06<00:00, 19.10it/s]


Epoch [94/120], Loss: 6.9729


100% 127/127 [00:06<00:00, 19.08it/s]


Epoch [95/120], Loss: 6.8771


100% 127/127 [00:06<00:00, 18.97it/s]


Epoch [96/120], Loss: 6.7062


100% 127/127 [00:06<00:00, 19.10it/s]


Epoch [97/120], Loss: 7.3221


100% 127/127 [00:06<00:00, 19.08it/s]


Epoch [98/120], Loss: 7.2344


100% 127/127 [00:06<00:00, 19.12it/s]


Epoch [99/120], Loss: 7.1014


100% 127/127 [00:06<00:00, 19.06it/s]


Epoch [100/120], Loss: 6.9036


100% 127/127 [00:06<00:00, 19.05it/s]


Epoch [101/120], Loss: 6.7169


100% 127/127 [00:06<00:00, 19.11it/s]


Epoch [102/120], Loss: 6.6135


100% 127/127 [00:06<00:00, 19.12it/s]


Epoch [103/120], Loss: 7.3205


100% 127/127 [00:06<00:00, 19.08it/s]


Epoch [104/120], Loss: 6.9657


100% 127/127 [00:06<00:00, 19.01it/s]


Epoch [105/120], Loss: 6.9066


 90% 114/127 [00:05<00:00, 21.03it/s]

In [None]:
torch.save(g.state_dict(), "./generator.pt")

In [None]:
import matplotlib.pyplot as plt 

# Step 2: Define a function for visualization
def visualize_reconstruction(generator, feat_ext, test_loader):
    model.eval()

    # Get a single test image
    for images, _ in test_loader:
        test_image = images[0].unsqueeze(0)  # Add batch dimension
        break

    # Pass the image through the model to get clean activations
    with torch.no_grad():
        acts = feat_ext(test_image.to(device))['layer']

    # Generate reconstruction from clean activations
    reconstruction = generator(acts)

    # Convert images to numpy arrays for visualization
    test_image_np = test_image[:, :, 0, :, :].squeeze(0).permute(1, 2, 0).cpu().numpy()
    reconstruction_np = clean_reconstruction[:, :, 0, :, :].squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

    MEAN = np.array([0.485, 0.456, 0.406])
    STD = np.array([0.229, 0.224, 0.225])

    # Clip and normalize for visualization
    test_image_np = test_image_np * STD + mean
    test_image_np = np.clip(test_image_np, 0, 1)
    reconstruction_np = reconstruction_np * STD + mean
    reconstruction_np = np.clip(reconstruction_np, 0, 1)

    # Plot the results
    plt.figure(figsize=(8, 4))

    plt.subplot(1, 2, 1)
    plt.imshow(test_image_np)
    plt.title("Original Image")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(reconstruction_np)
    plt.title("Reconstruction (No Noise)")
    plt.axis("off")
    
    plt.show()

# Step 3: Visualize the reconstruction
for _ in range(10):
  visualize_reconstruction(g, feat_ext, test_loader)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet3d
from torch.utils.data import DataLoader, Dataset
import numpy as np

from tqdm import tqdm

# Step 1: Load and preprocess CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

train_dataset, _, train_loader, _ = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'], test_bs=True)

# Step 2: Load pretrained ResNet18
resnet = resnet18(pretrained=True)
resnet.eval()  # Set to evaluation mode

# Select a particular layer for activations (e.g., layer4[1].conv2)
layer_to_hook = resnet.layer3[0].relu

# Define a hook to capture activations
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

layer_to_hook.register_forward_hook(get_activation('selected_layer'))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 3: Define a function to create noisy samples from Poisson processes
def poisson_sample(activations):
    rate_params = activations.view(activations.size(0), -1).cpu().numpy()  # Flatten the activations

    noisy_samples = np.random.poisson(rate_params).astype(np.float32)
    return torch.tensor(noisy_samples)

# Step 4: Create custom dataset of image and samples from Poisson processes
class ImagePoissonDataset(Dataset):
    def __init__(self, data_loader, model, layer_name):
        self.data_loader = data_loader
        self.model = model.to(device)
        self.layer_name = layer_name
        self.images = []
        self.samples = []
        self.create_dataset()

    def create_dataset(self):
        for images, _ in tqdm(self.data_loader):
            if len(self.images) <= len(train_dataset)*1/2:
              with torch.no_grad():
                  _ = self.model(images.to(device))
                  activations = activation[self.layer_name]
                  self.images.extend(images)
                  self.samples.extend(poisson_sample(activations))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.samples[idx], self.images[idx]

poisson_dataset = ImagePoissonDataset(train_loader, resnet, 'selected_layer')
poisson_loader = DataLoader(poisson_dataset, batch_size=32, shuffle=True)


In [1]:
import torch
import torch.nn as nn
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn.functional as F
from fix_models.feature_extractors import get_video_feature_extractor, VideoFeatureExtractor
from fix_models.readouts import PoissonGaussianReadout, PoissonLinearReadout

# neural activity embedding model 
class NeuralEmbedder(nn.Module):
    def __init__(self, num_neurons, num_layers = 3, hidden_size = 16, embed_size = 16, device=torch.device("cpu")):
        super().__init__()
        
        self.device = device
        self.num_neurons = num_neurons
        self.embed_size = num_neurons #embed_size #um_neurons #embed_size
        
        self.linear1 = nn.Linear(num_neurons, hidden_size, device=device)
        self.linear2 = nn.Linear(hidden_size, hidden_size, device=device)
        self.linear3 = nn.Linear(hidden_size, embed_size, device=device)

        self.linear = nn.Linear(num_neurons, embed_size, device=device)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(self.linear1(x))
        x = x + self.act(self.linear2(x))
        x = self.linear3(x)
        x = (self.linear(x))
        
        return x

# video embedding model
class VideoEmbedder(nn.Module):
    def __init__(self, modality, layer, stim_shape, train_dataset, feat_ext_type = 'resnet3d', use_pool = False, pool_size = 2, pool_stride = 2, use_pretrained = True, freeze_weights=True, flatten_time = False, device=torch.device("cpu")):
        super().__init__()
        num_neurons = len(train_dataset[0][1])

        feat_ext = get_video_feature_extractor(layer=layer, mod_type=feat_ext_type, device=device, use_pretrained=use_pretrained, freeze_weights=freeze_weights)
        feat_ext = VideoFeatureExtractor(feat_ext, stim_shape, device=device)
        
        readout_input = feat_ext(train_dataset[0][0].unsqueeze(0).to(device))
        num_input  = readout_input.shape[1] * readout_input.shape[2]
        
        feat_to_embed = FeatToEmbed(use_pool = use_pool, pool_size = pool_size, pool_stride= pool_stride, device=device)
        #neu_embed = NeuralEmbedder(num_neurons, device=device) #num_input, device=device)
        neu_embed = nn.Linear(num_input, num_neurons, device=device)
        
        self.model = nn.Sequential(
            feat_ext,
            feat_to_embed,
            neu_embed
        )
            
        print(f"readout input shape: {num_input}")

        self.act = nn.ReLU()
    def forward(self, x):
        return self.act(self.model(x)) + 1

class FeatToEmbed(nn.Module):
    def __init__(self, use_pool = False, pool_size = 2, pool_stride = 2, device=torch.device("cpu")):
        super().__init__()

        self.device = device
        self.use_pool = use_pool
        
        # pooling size
        self.pool = nn.AvgPool2d(pool_size, stride=pool_stride, padding=int(pool_size/2), count_include_pad=False)
        
    def forward(self, x):
        n_batch, n_channel, n_time, width, height = x.shape
        x = x.view(n_batch, n_channel * n_time, width, height)
        
        if self.use_pool:
            x = self.pool(x)

        grid = torch.zeros((x.shape[0], 1, 1, 2), device=self.device)
        grid = torch.clamp(grid, min=-1, max=1) # clamp to ensure within feature map

        x = torch.squeeze(torch.squeeze(F.grid_sample(x, grid, align_corners=False), -1), -1)        
        
        return x

In [2]:
# imports 
import torch
import wandb
import numpy as np
from torch.nn import PoissonNLLLoss
from fix_models.metrics import get_decoder_accuracy

from fix_models.datasets import get_datasets_and_loaders

In [3]:
# config

# all parameters
config = dict()
config["modality"] = "video" # or image

# paths
input_dir = f'../data/{config["modality"]}/'
stimulus_dir = f'../data/{config["modality"]}/stimuli/'
embedding_dir = f'../data/{config["modality"]}/embeddings/'
model_output_path = f'../data/{config["modality"]}/model_output/results'

# dataset and dataloader hyperparameters 
config["win_size"] = 240
config['pos'] = (400, 180)
config["feat_ext_type"] = 'resnet3d'
config["stim_size"] = 32 
config["stim_dur_ms"] = 200
config["stim_shape"] = (1, 3, 5, config["stim_size"], config["stim_size"])
config["first_frame_only"] = False
config["exp_var_thresholds"] = [0.25, 0.25, 0.25] #[0.25, 0.25, 0.25]
config["batch_size"] = 16

# model hyperparameters
config["layer"] = "layer2"
config["use_sigma"] = True
config["center_readout"] = False
config["use_pool"] = True
config["pool_size"] = 4
config["pool_stride"] = 2
config["use_pretrained"] = True
config["flatten_time"] = True

# training parameters 
config["lr"] = 0.001 
config["num_epochs"] = 20
config["l2_weight"] = 0

# logging
config["wandb"] = True

# save model
config["save"] = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# session names
session_ids = ["082824", "082924", "083024"]

In [4]:
import torch
import torch.nn.functional as F

# note - this loss function was written by chatgpt with some edits
def triplet_loss(vid_embed, neu_embed, alpha):
    """
    Compute the triplet loss for given video and neural embeddings.
    
    Args:
        vid_embed (torch.Tensor): Tensor of shape (batch_size, embed_size) for video embeddings.
        neu_embed (torch.Tensor): Tensor of shape (batch_size, embed_size) for neural embeddings.
        alpha (float): Margin value for the triplet loss.
    
    Returns:
        torch.Tensor: Scalar loss value.
    """
    # Compute pairwise distances
    #vid_embed = F.normalize(vid_embed, p=2, dim=1)
    #neu_embed = F.normalize(neu_embed, p=2, dim=1)

    vid_embed_norm = vid_embed.unsqueeze(1)  # Shape: (batch_size, 1, embed_size)
    neu_embed_norm = neu_embed.unsqueeze(0)  # Shape: (1, batch_size, embed_size)
    pairwise_dist = torch.sum((vid_embed_norm - neu_embed_norm) ** 2, dim=2)  # Shape: (batch_size, batch_size)

    # Find the "challenging negatives"
    # Set diagonal to a large value to exclude positives
    pairwise_dist.fill_diagonal_(float('inf'))  
    challenging_negatives_idx = torch.argmin(pairwise_dist, dim=1)  # Shape: (batch_size,)
    shuffled_neu_embed = neu_embed[challenging_negatives_idx]  # Shape: (batch_size, embed_size)

    # Compute distances for positives and negatives
    pos_dist = torch.sum((vid_embed - neu_embed) ** 2, dim=1)  # Shape: (batch_size,)
    neg_dist = torch.sum((vid_embed - shuffled_neu_embed) ** 2, dim=1)  # Shape: (batch_size,)

    # Compute triplet loss
    loss = F.relu(pos_dist - neg_dist + alpha)  # Shape: (batch_size,)
    return loss.mean()  # Scalar loss value


def train_model(full_vid, full_neu, model_name):
    # corr avgs
    corr_avgs = []

    config['model_name'] = model_name

    print(config['l2_weight'])
    
    for ses_idx, session_id in enumerate(session_ids):
        # set sess_corr_avg
        sess_corr_avg = -1
        sess_corrs = []

        # set session index 
        config["session_id"] = session_id

        # setup logging
        if config["wandb"]:
            wandb.init(
                project=f'{config["modality"]}-cs230-decode',
                config=config,
            )
            wandb.define_metric("decode_acc", summary="max")
            wandb.define_metric("test_loss", summary="min")

        # load datasets and loaders 
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'], test_bs=True)
        _, _, _, test_loader_single = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'], test_bs=False)

        full_vid_embedder = full_vid(train_dataset)
        full_neu_embedder = full_neu(len(train_dataset[0][1]))

        # set which parameters to use regularization with and which not to
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_vid_embedder.named_parameters():
            if 'bias' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        # setup Adam optimizer
        vid_optimizer = torch.optim.Adam([
            {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
            {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
        ], lr=config["lr"], weight_decay=config['l2_weight'])
        
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_neu_embedder.named_parameters():
            if 'bias' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        neu_optimizer = torch.optim.Adam([
            {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
            {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
        ], lr=config["lr"], weight_decay=config['l2_weight'])
    
        # using triplet loss   
        alpha = 0.1
        loss_func = PoissonNLLLoss(log_input=False, full=True)

        for epochs in range(config["num_epochs"]):
            epoch_loss = 0
            for i, (stimulus, targets) in enumerate(train_loader): 
                vids = stimulus.to(device)
                neus = targets.to(device)

                vid_optimizer.zero_grad()
                neu_optimizer.zero_grad()
                
                vid_embed = full_vid_embedder(vids)
                neu_embed = full_neu_embedder(neus)

                #loss = triplet_loss(vid_embed, neu_embed, alpha) #+ triplet_loss(neu_embed, vid_embed, alpha)
                loss = loss_func(vid_embed, neu_embed)
                loss.backward()

                vid_optimizer.step()
                neu_optimizer.step()
                
                epoch_loss += loss.item()
    
            # printing corr to avg and loss metrics 
            with torch.no_grad():
                decode_acc = get_decoder_accuracy(full_vid_embedder, full_neu_embedder, test_loader_single, modality=config["modality"], device=device)
                test_loss = 0
                for i, (stimulus, targets) in enumerate(test_loader):
                    vids = stimulus.to(device)
                    neus = targets.to(device)
                    vid_embed = full_vid_embedder(vids)
                    neu_embed = full_neu_embedder(neus)
                    loss = triplet_loss(vid_embed, neu_embed, alpha) #+ triplet_loss(neu_embed, vid_embed, alpha)
                    test_loss += loss.item()
                    
            if config["wandb"]:
                wandb.log({"decode_acc": np.nanmean(decode_acc), "train_loss": epoch_loss / len(train_loader), "test_loss": test_loss / len(test_loader)})
            
            if np.nanmean(decode_acc) > sess_corr_avg:
                sess_corr_avg = np.nanmean(decode_acc)
                sess_corrs = decode_acc
                
            print('  epoch {} loss: {} decode acc: {}'.format(epochs + 1, epoch_loss / len(train_dataset), np.nanmean(decode_acc)))
            #print(f' num. neurons : {len(decode_acc)}')
            
        #if config["save"]:
        #    torch.save(full_model.state_dict(), f"{model_output_path}_{session_id}.pickle")
            
        corr_avgs.append(sess_corrs)
        
        if config["wandb"]:
            wandb.finish()
    
    if config["wandb"]:
        wandb.init(
            project=f'{config["modality"]}-cs230-decode',
            config=config,
        )
        for corr in corr_avgs:
            wandb.log({"decode_accs": corr})
        wandb.finish()

In [5]:
full_vid_fcn = lambda train_dataset: VideoEmbedder(feat_ext_type = 'resnet3d', freeze_weights=True, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
full_neu_fcn = lambda num_neurons: NeuralEmbedder(num_neurons, device=device)

train_model(full_vid_fcn, full_neu_fcn, "frozen pretrained")

0


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33met22[0m. Use [1m`wandb login --relogin`[0m to force relogin


readout input shape: 384
  epoch 1 loss: 0.12235652758751386 decode acc: 0.4
  epoch 2 loss: 0.11666988143214473 decode acc: 0.425
  epoch 3 loss: 0.11594631830851237 decode acc: 0.525
  epoch 4 loss: 0.11545191882569113 decode acc: 0.625
  epoch 5 loss: 0.11518091713940656 decode acc: 0.65
  epoch 6 loss: 0.11494378407796224 decode acc: 0.675
  epoch 7 loss: 0.11477671752741307 decode acc: 0.7
  epoch 8 loss: 0.11464894265304378 decode acc: 0.675
  epoch 9 loss: 0.114555534845517 decode acc: 0.675
  epoch 10 loss: 0.11443922148810493 decode acc: 0.675
  epoch 11 loss: 0.11430847003136152 decode acc: 0.7
  epoch 12 loss: 0.11425545280362352 decode acc: 0.725
  epoch 13 loss: 0.11421989635184959 decode acc: 0.725
  epoch 14 loss: 0.11407451235217812 decode acc: 0.725
  epoch 15 loss: 0.1139945031389778 decode acc: 0.775
  epoch 16 loss: 0.11393759845215597 decode acc: 0.725
  epoch 17 loss: 0.11387442035439574 decode acc: 0.775
  epoch 18 loss: 0.11379451810577769 decode acc: 0.75
  epo

0,1
decode_acc,▁▁▃▅▆▆▇▆▆▆▇▇▇▇█▇██▇█
test_loss,█▆▆▄▄▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
train_loss,█▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.81438


readout input shape: 384
  epoch 1 loss: 0.11582509591317301 decode acc: 0.425
  epoch 2 loss: 0.11172449326640024 decode acc: 0.475
  epoch 3 loss: 0.11096722617823417 decode acc: 0.5
  epoch 4 loss: 0.11060457728920182 decode acc: 0.625
  epoch 5 loss: 0.11036484790722113 decode acc: 0.675
  epoch 6 loss: 0.11010989206623657 decode acc: 0.7
  epoch 7 loss: 0.10996435649732021 decode acc: 0.725
  epoch 8 loss: 0.10985431664901255 decode acc: 0.725
  epoch 9 loss: 0.10970041146453138 decode acc: 0.75
  epoch 10 loss: 0.10965823971164164 decode acc: 0.7
  epoch 11 loss: 0.10951418021586554 decode acc: 0.7
  epoch 12 loss: 0.10951287640326934 decode acc: 0.65
  epoch 13 loss: 0.10947642725799721 decode acc: 0.725
  epoch 14 loss: 0.10932655565401646 decode acc: 0.725
  epoch 15 loss: 0.10929544646078379 decode acc: 0.75
  epoch 16 loss: 0.1092259022577895 decode acc: 0.725
  epoch 17 loss: 0.10922847686637759 decode acc: 0.675
  epoch 18 loss: 0.1091297878020721 decode acc: 0.675
  epoch

0,1
decode_acc,▁▂▃▅▆▇▇▇█▇▇▆▇▇█▇▆▆▆▇
test_loss,█▆▄▄▄▃▃▃▃▃▂▁▂▃▃▂▂▂▁▂
train_loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.73599


readout input shape: 384
  epoch 1 loss: 0.13410460462375562 decode acc: 0.225
  epoch 2 loss: 0.12483198423774874 decode acc: 0.325
  epoch 3 loss: 0.12399968047425508 decode acc: 0.45
  epoch 4 loss: 0.12350221407974095 decode acc: 0.5
  epoch 5 loss: 0.12310866862138818 decode acc: 0.5
  epoch 6 loss: 0.12284618314643189 decode acc: 0.575
  epoch 7 loss: 0.12260553814609598 decode acc: 0.625
  epoch 8 loss: 0.12241629599044042 decode acc: 0.625
  epoch 9 loss: 0.12229219753338155 decode acc: 0.65
  epoch 10 loss: 0.12215043415299748 decode acc: 0.65
  epoch 11 loss: 0.12201044838726785 decode acc: 0.7
  epoch 12 loss: 0.12192451414431527 decode acc: 0.675
  epoch 13 loss: 0.12184440264579144 decode acc: 0.675
  epoch 14 loss: 0.12175369236251371 decode acc: 0.675
  epoch 15 loss: 0.12169908646894115 decode acc: 0.65
  epoch 16 loss: 0.1216364912787897 decode acc: 0.65
  epoch 17 loss: 0.12159374374780596 decode acc: 0.65
  epoch 18 loss: 0.12153085048154275 decode acc: 0.65
  epoch 

0,1
decode_acc,▁▂▄▅▅▆▇▇▇▇████▇▇▇▇▇▇
test_loss,█▆▅▅▄▄▃▃▃▂▂▂▂▁▁▂▁▁▁▂
train_loss,█▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.94133


0,1
decode_accs,█▆▁

0,1
decode_accs,0.7
