In [1]:
# imports 
import torch
import wandb
import numpy as np
from torch.nn import PoissonNLLLoss
from fix_models.metrics import corr_to_avg

from fix_models.datasets import get_datasets_and_loaders

In [2]:
# config

# all parameters
config = dict()
config["modality"] = "video" # or image

# paths
input_dir = f'../data/{config["modality"]}/'
stimulus_dir = f'../data/{config["modality"]}/stimuli/'
embedding_dir = f'../data/{config["modality"]}/embeddings/'
model_output_path = f'../data/{config["modality"]}/model_output/results'

# dataset and dataloader hyperparameters 
config["win_size"] = 240
config['pos'] = (400, 180)
config["feat_ext_type"] = 'resnet3d'
config["stim_size"] = 32 
config["stim_dur_ms"] = 200
config["stim_shape"] = (1, 3, 5, config["stim_size"], config["stim_size"])
config["first_frame_only"] = False
config["exp_var_thresholds"] = [0.25, 0.25, 0.25]
config["batch_size"] = 16

# model hyperparameters
config["layer"] = "layer1"
config["use_sigma"] = True
config["center_readout"] = False
config["use_pool"] = True
config["pool_size"] = 4
config["pool_stride"] = 2
config["use_pretrained"] = True
config["flatten_time"] = True

# training parameters 
config["lr"] = 0.001 
config["num_epochs"] = 20
config["l2_weight"] = 0

# logging
config["wandb"] = True

# save model
config["save"] = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# session names
session_ids = ["082824", "082924", "083024"]

In [3]:
def train_model(full_model_fcn, model_name):
    # corr avgs
    corr_avgs = []

    config['model_name'] = model_name

    print(config['l2_weight'])
    
    for ses_idx, session_id in enumerate(session_ids):
        # set sess_corr_avg
        sess_corr_avg = -1
        sess_corrs = []

        # set session index 
        config["session_id"] = session_id

        # setup logging
        if config["wandb"]:
            wandb.init(
                project=f'{config["modality"]}-cs230',
                config=config,
            )
            wandb.define_metric("corr_to_avg", summary="max")
            wandb.define_metric("test_loss", summary="min")

        # load datasets and loaders 
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'])

        full_model = full_model_fcn(train_dataset)

        for name, param in full_model.named_parameters():
            print(name)
            print(param.shape)
            print(param.requires_grad)
        # set which parameters to use regularization with and which not to
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_model.named_parameters():
            if 'mu' in name or 'sigma' in name or 'bias' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        # setup Adam optimizer
        optimizer = torch.optim.Adam([
            {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
            {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
        ], lr=config["lr"], weight_decay=config['l2_weight'])
    
        # using poisson loss 
        loss_func = PoissonNLLLoss(log_input=False, full=True)
            
        for epochs in range(config["num_epochs"]):
            epoch_loss = 0
            for i, (stimulus, targets) in enumerate(train_loader): 
                stimulus = stimulus.to(device)
                targets = targets.to(device)
                
                optimizer.zero_grad()
                preds = full_model(stimulus)
                loss = loss_func(preds, targets)
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item()
    
            # printing corr to avg and loss metrics 
            with torch.no_grad():
                corr_avg = corr_to_avg(full_model, test_loader, modality=config["modality"], device=device)
                test_loss = 0
                for i, (stimulus, targets) in enumerate(test_loader):
                    stimulus = stimulus.to(device)
                    targets = targets.to(device)
                    preds = full_model(stimulus) 
                    loss = loss_func(preds, targets)
                    test_loss += loss.item()
                    
            if config["wandb"]:
                wandb.log({"corr_to_avg": np.nanmean(corr_avg), "train_loss": epoch_loss / len(train_loader), "test_loss": test_loss / len(test_loader)})
            
            if np.nanmean(corr_avg) > sess_corr_avg:
                sess_corr_avg = np.nanmean(corr_avg)
                sess_corrs = corr_avg
                
            print('  epoch {} loss: {} corr: {}'.format(epochs + 1, epoch_loss / len(train_dataset), np.nanmean(corr_avg)))
            print(f' num. neurons : {len(corr_avg)}')
            
        if config["save"]:
            torch.save(full_model.state_dict(), f"{model_output_path}_{session_id}.pickle")
            
        corr_avgs.append(sess_corrs)
        
        if config["wandb"]:
            wandb.finish()
    
    if config["wandb"]:
        wandb.init(
            project=f'{config["modality"]}-cs230',
            config=config,
        )
        for sess_corr in corr_avgs:
            for corr in sess_corr:
                wandb.log({"corr": corr})
        wandb.finish()

In [None]:
#### train 4 encoding models: 
from fix_models.models import FullModel

for layer in ['layer1', 'layer2', 'layer3', 'layer4']:
    config['layer'] = layer
    print(f'layer: {layer}')

    #### 1) frozen pretrained CNN with poisson GLM readout
    cnn_frozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=True, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 2) unfrozen pretrained CNN with poisson GLM readout
    cnn_unfrozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 3) CNN trained from scratch with poisson GLM readout
    cnn_untrained_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 4) poisson GLM trained from scratch
    poisson_glm_fcn = lambda train_dataset: FullModel(feat_ext_type = 'none', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)

    config['l2_weight'] = 0 #l2_weight
    train_model(cnn_frozen_fcn, "frozen pretrained")
    train_model(cnn_unfrozen_fcn, "unfrozen pretrained")
    train_model(cnn_untrained_fcn, "untrained")
    config['l2_weight'] = 100
    train_model(poisson_glm_fcn, "linear")

layer: layer1
0


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33met22[0m. Use [1m`wandb login --relogin`[0m to force relogin


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.11971184977778682 corr: 0.1411561720009646
 num. neurons : 54
  epoch 2 loss: 0.11071030822801001 corr: 0.11844485913714636
 num. neurons : 54
  epoch 3 loss: 0.11010752183419686 corr: 0.10726472906333752
 num. neurons : 54
  epoch 4 loss: 0.10989189089080434 corr: 0.1492363559551482
 num. neurons : 54
  epoch 5 loss: 0.1097915001857428 corr: 0.20662128215548028


0,1
corr_to_avg,▃▂▁▃▇▅▃▅▄▅▆▆▆▇█▇▇▇█▇
test_loss,█▆▅▄▄▅▅▄▃▄▂▃▃▁▂▃▁▂▁▃
train_loss,█▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.72452


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.1.mu
torch.Size([95, 2])
True
model.1.sigma
torch.Size([95])
True
model.1.poisson_linear.linear.weight
torch.Size([95, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.11322425641314522 corr: 0.15243800999530488
 num. neurons : 95
  epoch 2 loss: 0.10636829025458291 corr: 0.17595817716659506
 num. neurons : 95
  epoch 3 loss: 0.1057932906749985 corr: 0.1899716343673713
 num. neurons : 95
  epoch 4 loss: 0.1054267930734844 corr: 0.203550364511327
 num. neurons : 95
  epoch 5 loss: 0.10532792881521255 corr: 0.1750038053284786
 nu

0,1
corr_to_avg,▁▂▃▄▂▅▂▅▄▇▇▆▇███▆▇▇▇
test_loss,█▅▅▅▄▃▂▃▂▁▂▂▂▁▂▁▂▁▁▁
train_loss,█▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
train_loss,1.63972


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.13499187587528177 corr: 0.12183234949544966
 num. neurons : 54
  epoch 2 loss: 0.11987168215605254 corr: 0.17570166846053645
 num. neurons : 54
  epoch 3 loss: 0.11860771409790391 corr: 0.18995573714770012
 num. neurons : 54
  epoch 4 loss: 0.11825785926727457 corr: 0.12308779536834902
 num. neurons : 54
  epoch 5 loss: 0.1180444964287867 corr: 0.1936493478178513

0,1
corr_to_avg,▁▃▄▁▄▃▄▇▅▆▆▇█▇▆██▇██
test_loss,█▅▄▄▄▄▃▄▃▂▃▂▁▂▁▂▂▁▂▁
train_loss,█▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.86021


0,1
corr,▅▅▅▄▆▆▇▆▇▅▆▆▇▇▅▆▇▆▅▄▁▇▅█▄▆▅▆▅▆▆▇▅▅▆▄▄▅▆▅

0,1
corr,0.3867


0


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.11542467305689683 corr: 0.028716704857319815
 num. neurons : 54
  epoch 2 loss: 0.11036782264709473 corr: 0.09549239197286756
 num. neurons : 54
  epoch 3 loss: 0.11006301291194963 corr: 0.08663872011856187
 num. neurons : 54
  epoch 4 loss: 0.10992128089622215 corr: 0.06881272590273341
 num. neurons : 54
  epoch 5 loss: 0.10977907121917348 corr: 0.14203829583079683
 n

0,1
corr_to_avg,▁▃▃▂▄▅▅▄▅▆▇▇▇▇▇▇█▇██
test_loss,█▇▇█▆▅▆▅▅▃▃▃▂▂▂▁▁▂▁▁
train_loss,█▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
train_loss,1.69962


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([95, 2])
True
model.1.sigma
torch.Size([95])
True
model.1.poisson_linear.linear.weight
torch.Size([95, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.11001958853287222 corr: 0.006370007557453737
 num. neurons : 95
  epoch 2 loss: 0.10601983095338831 corr: 0.13721229920008915
 num. neurons : 95
  epoch 3 loss: 0.10575780980874107 corr: 0.14341386831428793
 num. neurons : 95
  epoch 4 loss: 0.10533650826409226 corr: 0.14978830918196345
 num. neurons : 95
  epoch 5 loss: 0.10517595812912386 corr: 0.17451204892471914
 n

0,1
corr_to_avg,▁▄▄▅▅▆▅▆▆▇█▇▇▇██████
test_loss,█▇█▇▆▆▄▅▄▄▂▂▃▂▂▂▁▂▂▁
train_loss,█▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
train_loss,1.61145


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.12565484264834115 corr: 0.038344091413455954
 num. neurons : 54
  epoch 2 loss: 0.1188017301254814 corr: 0.038410125069592876
 num. neurons : 54
  epoch 3 loss: 0.11847253467725671 corr: 0.06268410421292353
 num. neurons : 54
  epoch 4 loss: 0.11815750223403497 corr: 0.10176583447333025
 num. neurons : 54
  epoch 5 loss: 0.11792382960205282 corr: 0.11098815524586206
 n

0,1
corr_to_avg,▁▁▂▃▃▄▄▆▅▇▇█▇▇▇█▇█▇▇
test_loss,█▇▅▅▅▅▄▃▄▂▂▃▃▂▃▂▂▃▂▁
train_loss,█▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁

0,1
train_loss,1.8232


0,1
corr,▃▄▁▆▅▆▅▆▇▇█▅▆▇▃▃▃▇▇▇▄▅▁▃▆▆▂▂▆▃▄▅▃▄▅▅▆▇▅▆

0,1
corr,0.37803


0


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.11468338495419349 corr: 0.04794707155730058
 num. neurons : 54
  epoch 2 loss: 0.11116113992384923 corr: 0.07616055395934379
 num. neurons : 54
  epoch 3 loss: 0.1109850842864425 corr: 0.04984696446636102
 num. neurons : 54
  epoch 4 loss: 0.11061539231995006 corr: 0.08878463360019728
 num. neurons : 54
  epoch 5 loss: 0.11040811491601261 corr: 0.07954158635585938
 num

0,1
corr_to_avg,▁▂▁▂▂▃▃▃▁▄▄▅▆▄▅▇▇▇▇█
test_loss,█▆▇▆▆▅▆▅▅▅▄▄▄▃▄▂▂▁▂▁
train_loss,█▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁

0,1
train_loss,1.71177


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([95, 2])
True
model.1.sigma
torch.Size([95])
True
model.1.poisson_linear.linear.weight
torch.Size([95, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.1102906804434292 corr: 0.03798632667568064
 num. neurons : 95
  epoch 2 loss: 0.10725807257347707 corr: 0.05447175859519283
 num. neurons : 95
  epoch 3 loss: 0.10664440892753801 corr: 0.1266031997082627
 num. neurons : 95
  epoch 4 loss: 0.10644053014785207 corr: 0.08954136092399452
 num. neurons : 95
  epoch 5 loss: 0.10592031447675215 corr: 0.11213714855891996
 num.

0,1
corr_to_avg,▁▁▃▂▃▄▄▄▅▅▅▇▆▇▇▇█▇▇█
test_loss,██▇▆▆▆▅▄▄▃▃▂▃▂▃▂▁▂▁▁
train_loss,█▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁

0,1
train_loss,1.63066


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.12465544596327652 corr: 0.012332993387948432
 num. neurons : 54
  epoch 2 loss: 0.1196177595467724 corr: 0.05073651028903067
 num. neurons : 54
  epoch 3 loss: 0.11914079276036667 corr: 0.05256499819715316
 num. neurons : 54
  epoch 4 loss: 0.11884576394268127 corr: 0.0590014210895141
 num. neurons : 54
  epoch 5 loss: 0.1187085795423682 corr: 0.15596740747634924
 num.

0,1
corr_to_avg,▁▂▂▂▄▃▄▅▅▅▅▆▇▇▇▇▇▇▇█
test_loss,█▇▇▅▅▆▄▄▄▄▃▄▃▃▂▂▂▄▃▁
train_loss,█▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
train_loss,1.83644


0,1
corr,▂▅▃▄▃▂▅▅▂▆█▅▆▆▄▂▂▄▂▄▄▇▄▇▇▅▃▅▇▂▁▇▇▅▄▄▆▄▆▅

0,1
corr,0.40119


100


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.14995239870047863 corr: -0.0038529157763796456
 num. neurons : 54
  epoch 2 loss: 0.1474048912378005 corr: -0.05671359284674722
 num. neurons : 54
  epoch 3 loss: 0.14456361252584576 corr: -0.006050159369621809
 num. neurons : 54
  epoch 4 loss: 0.14093821184134778 corr: 0.004737145754248771
 num. neurons : 54
  epoch 5 loss: 0.13896327324855476 corr: -0.00853658614774139
 num. neurons : 54
  epoch 6 loss: 0.13811771116138977 corr: -0.005399715796069087
 num. neurons : 54
  epoch 7 loss: 0.13568587785885658 corr: 0.0025135657519691416
 num. neurons : 54
  epoch 8 loss: 0.13338802190474522 corr: -0.011734927959059768
 num. neurons : 54
  epoch 9 loss: 0.13229606498906643 corr: -0.00608360599346543
 num. neurons : 54
  epoch 10 loss: 0.13045316372388674 corr: 0.03292582675118912
 num. neurons : 54
  epoch 11 loss: 0.13021904209513724 corr: -0.004803498

0,1
corr_to_avg,▅▁▅▅▄▅▅▄▅█▅▆▅▆▆▇▆▄▇█
test_loss,██▇▅▅▄▅▄▄▃▂▂▂▂▂▂▂▁▁▁
train_loss,█▇▇▆▅▅▄▄▃▃▃▂▂▂▂▁▂▁▁▁

0,1
train_loss,1.95324


readout input shape: 15360
model.1.linear.weight
torch.Size([95, 15360])
True
model.1.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.1360421685648214 corr: 0.018646517129874102
 num. neurons : 95
  epoch 2 loss: 0.13247152666770975 corr: 0.004803005988310479
 num. neurons : 95
  epoch 3 loss: 0.12975116574951492 corr: 0.03455693755503983
 num. neurons : 95
  epoch 4 loss: 0.1272927403450012 corr: 0.011582845075603119
 num. neurons : 95
  epoch 5 loss: 0.1255346815623538 corr: 0.037274761375642756
 num. neurons : 95
  epoch 6 loss: 0.12401995721287752 corr: -0.0024397353897709973
 num. neurons : 95
  epoch 7 loss: 0.12262563330964893 corr: 0.0495670665116475
 num. neurons : 95
  epoch 8 loss: 0.12117280105021612 corr: 0.038864721137023876
 num. neurons : 95
  epoch 9 loss: 0.12044504281738042 corr: 0.031562478900961664
 num. neurons : 95
  epoch 10 loss: 0.11966460749741 corr: 0.036550479648757184
 num. neurons : 95
  epoch 11 loss: 0.11815509047183691 corr: -0.010302736096494514
 

0,1
corr_to_avg,▅▅▆▅▆▄▇▆▆▆▄▄▄▆▅▆▆▅▁█
test_loss,█▇▇▆▅▄▆▄▅▃▅▃▃▁▁▁▂▁▂▁
train_loss,█▇▆▅▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁

0,1
train_loss,1.81957


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.1735950998214884 corr: 0.008377857280703456
 num. neurons : 54
  epoch 2 loss: 0.17024380894181151 corr: 0.038091137461167976
 num. neurons : 54
  epoch 3 loss: 0.16686548590131217 corr: 0.02248357195822742
 num. neurons : 54
  epoch 4 loss: 0.16317124988721765 corr: 0.007676881169460737
 num. neurons : 54
  epoch 5 loss: 0.16301541584407533 corr: 0.009644988463717044
 num. neurons : 54
  epoch 6 loss: 0.15681496476746282 corr: 0.009088872192952645
 num. neurons : 54
  epoch 7 loss: 0.15402935623805292 corr: 0.04721592851534323
 num. neurons : 54
  epoch 8 loss: 0.15326492883721177 corr: 0.002285645323471782
 num. neurons : 54
  epoch 9 loss: 0.15154677655914767 corr: -0.04276835105069057
 num. neurons : 54
  epoch 10 loss: 0.14810500440072885 corr: -0.025313869636037406
 num. neurons : 54
  epoch 11 loss: 0.14597644505505025 corr: 0.0113781353940744

0,1
corr_to_avg,▅▇▆▅▅▅▇▄▁▂▅▄▃▅█▂█▄▅▆
test_loss,▇▇▇█▇▃▄▄▄▃▃▂▂▂▁▂▁▂▂▁
train_loss,█▇▇▆▆▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁

0,1
train_loss,2.17557


0,1
corr,▁▂▃▆▄▄▅▆▅▂▇▅▂▃▅▇▆▄▆▄▅▅▃▃▂▂▅▄▆▆▆▂▆█▆▄▄▄▅▅

0,1
corr,-0.09925


layer: layer2
0


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▅▄▅▆▆▇▇▇▇▇▇███████
test_loss,█▆▅▅▄▃▄▃▂▂▂▂▂▂▂▂▁▁▁▂
train_loss,█▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
train_loss,1.65471


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod