In [1]:
# imports 
import torch
import wandb
import numpy as np
from torch.nn import PoissonNLLLoss
from fix_models.metrics import corr_to_avg

from fix_models.datasets import get_datasets_and_loaders

In [2]:
# config

# all parameters
config = dict()
config["modality"] = "video" # or image

# paths
input_dir = f'../data/{config["modality"]}/'
stimulus_dir = f'../data/{config["modality"]}/stimuli/'
embedding_dir = f'../data/{config["modality"]}/embeddings/'
model_output_path = f'../data/{config["modality"]}/model_output/results'

# dataset and dataloader hyperparameters 
config["win_size"] = 240
config['pos'] = (400, 180)
config["feat_ext_type"] = 'resnet3d'
config["stim_size"] = 32 
config["stim_dur_ms"] = 200
config["stim_shape"] = (1, 3, 5, config["stim_size"], config["stim_size"])
config["first_frame_only"] = False
config["exp_var_thresholds"] = [0.25, 0.25, 0.25]
config["batch_size"] = 16

# model hyperparameters
config["layer"] = "layer1"
config["use_sigma"] = True
config["center_readout"] = False
config["use_pool"] = True
config["pool_size"] = 4
config["pool_stride"] = 2
config["use_pretrained"] = True
config["flatten_time"] = True

# training parameters 
config["lr"] = 0.001 
config["num_epochs"] = 40
config["l2_weight"] = 0

# logging
config["wandb"] = True

# save model
config["save"] = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# session names
session_ids = ["082824", "082924", "083024"]

In [3]:
def train_model(full_model_fcn, model_name):
    # corr avgs
    corr_avgs = []

    config['model_name'] = model_name

    print(config['l2_weight'])
    
    for ses_idx, session_id in enumerate(session_ids):
        # set sess_corr_avg
        sess_corr_avg = -1
        sess_corrs = []

        # set session index 
        config["session_id"] = session_id

        # setup logging
        if config["wandb"]:
            wandb.init(
                project=f'{config["modality"]}-cs230',
                config=config,
            )
            wandb.define_metric("corr_to_avg", summary="max")
            wandb.define_metric("test_loss", summary="min")

        # load datasets and loaders 
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'])

        full_model = full_model_fcn(train_dataset)

        for name, param in full_model.named_parameters():
            print(name)
            print(param.shape)
            print(param.requires_grad)
        # set which parameters to use regularization with and which not to
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_model.named_parameters():
            if 'mu' in name or 'sigma' in name or 'bias' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        # setup Adam optimizer
        optimizer = torch.optim.Adam([
            {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
            {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
        ], lr=config["lr"], weight_decay=config['l2_weight'])
    
        # using poisson loss 
        loss_func = PoissonNLLLoss(log_input=False, full=True)
            
        for epochs in range(config["num_epochs"]):
            epoch_loss = 0
            for i, (stimulus, targets) in enumerate(train_loader): 
                stimulus = stimulus.to(device)
                targets = targets.to(device)
                
                optimizer.zero_grad()
                preds = full_model(stimulus)
                loss = loss_func(preds, targets)
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item()
    
            # printing corr to avg and loss metrics 
            with torch.no_grad():
                corr_avg = corr_to_avg(full_model, test_loader, modality=config["modality"], device=device)
                test_loss = 0
                for i, (stimulus, targets) in enumerate(test_loader):
                    stimulus = stimulus.to(device)
                    targets = targets.to(device)
                    preds = full_model(stimulus) 
                    loss = loss_func(preds, targets)
                    test_loss += loss.item()
                    
            if config["wandb"]:
                wandb.log({"corr_to_avg": np.nanmean(corr_avg), "train_loss": epoch_loss / len(train_loader), "test_loss": test_loss / len(test_loader)})
            
            if np.nanmean(corr_avg) > sess_corr_avg:
                sess_corr_avg = np.nanmean(corr_avg)
                sess_corrs = corr_avg
                
            print('  epoch {} loss: {} corr: {}'.format(epochs + 1, epoch_loss / len(train_dataset), np.nanmean(corr_avg)))
            print(f' num. neurons : {len(corr_avg)}')
            
        if config["save"]:
            torch.save(full_model.state_dict(), f"{model_output_path}_{session_id}.pickle")
            
        corr_avgs.append(sess_corrs)
        
        if config["wandb"]:
            wandb.finish()
    
    if config["wandb"]:
        wandb.init(
            project=f'{config["modality"]}-cs230',
            config=config,
        )
        for sess_corr in corr_avgs:
            for corr in sess_corr:
                wandb.log({"corr": corr})
        wandb.finish()

In [4]:
#### train 4 encoding models: 
from fix_models.models import FullModel

"""
#### 1) poisson GLM trained from scratch
poisson_glm_fcn = lambda train_dataset: FullModel(feat_ext_type = 'none', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
#### 2) frozen pretrained CNN with poisson GLM readout
cnn_frozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=True, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
#### 3) unfrozen pretrained CNN with poisson GLM readout
cnn_unfrozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
#### 4) CNN trained from scratch with poisson GLM readout
cnn_untrained_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
"""

'\n#### 1) poisson GLM trained from scratch\npoisson_glm_fcn = lambda train_dataset: FullModel(feat_ext_type = \'none\', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config[\'use_pool\'], pool_size = config[\'pool_size\'], pool_stride = config["pool_stride"], device=device)\n#### 2) frozen pretrained CNN with poisson GLM readout\ncnn_frozen_fcn = lambda train_dataset:FullModel(feat_ext_type = \'resnet3d\', freeze_weights=True, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config[\'use_pool\'], pool_size = config[\'pool_size\'], pool_stride = config["pool_stride"], device=device)\n#### 3) unfrozen pretrained CNN with poisson GLM readout\ncnn_unfrozen_fcn = lambda train_dataset:FullModel(feat_ext_type = \'resnet3d\', freeze_weights=False, use_pretrained = True, mo

In [22]:
for layer in ['layer1', 'layer2', 'layer3', 'layer4']:
    config['layer'] = layer
    print(f'layer: {layer}')

    #### 1) frozen pretrained CNN with poisson GLM readout
    cnn_frozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=True, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 2) unfrozen pretrained CNN with poisson GLM readout
    cnn_unfrozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 3) CNN trained from scratch with poisson GLM readout
    cnn_untrained_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 4) poisson GLM trained from scratch
    poisson_glm_fcn = lambda train_dataset: FullModel(feat_ext_type = 'none', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)

    config['l2_weight'] = 0 #l2_weight
    train_model(cnn_frozen_fcn, "frozen pretrained")
    train_model(cnn_unfrozen_fcn, "unfrozen pretrained")
    train_model(cnn_untrained_fcn, "untrained")
    config['l2_weight'] = 100
    train_model(poisson_glm_fcn, "linear")

layer: layer1
0
readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.11927233177938579 corr: 0.10107337113395437
 num. neurons : 54
  epoch 2 loss: 0.11063482060844515 corr: 0.12282421543545964
 num. neurons : 54


KeyboardInterrupt: 

In [None]:
    train_model(cnn_frozen_fcn, "frozen pretrained")
