In [1]:
# imports 
import torch
import wandb
import numpy as np
from torch.nn import PoissonNLLLoss
from fix_models.metrics import corr_to_avg

from fix_models.datasets import get_datasets_and_loaders

In [2]:
# config

# all parameters
config = dict()
config["modality"] = "video" # or image

# paths
input_dir = f'../data/{config["modality"]}/'
stimulus_dir = f'../data/{config["modality"]}/stimuli/'
embedding_dir = f'../data/{config["modality"]}/embeddings/'
model_output_path = f'../data/{config["modality"]}/model_output/results'

# dataset and dataloader hyperparameters 
config["win_size"] = 240
config['pos'] = (400, 180)
config["feat_ext_type"] = 'resnet3d'
config["stim_size"] = 32 
config["stim_dur_ms"] = 200
config["stim_shape"] = (1, 3, 5, config["stim_size"], config["stim_size"])
config["first_frame_only"] = False
config["exp_var_thresholds"] = [0.25, 0.25, 0.25]
config["batch_size"] = 16

# model hyperparameters
config["layer"] = "layer1"
config["use_sigma"] = True
config["center_readout"] = False
config["use_pool"] = True
config["pool_size"] = 4
config["pool_stride"] = 2
config["use_pretrained"] = True
config["flatten_time"] = True

# training parameters 
config["lr"] = 0.001 
config["num_epochs"] = 20
config["l2_weight"] = 0

# logging
config["wandb"] = True

# save model
config["save"] = True

# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# session names
session_ids = ["082824", "082924", "083024"]

In [3]:
def train_model(full_model_fcn, model_name):
    # corr avgs
    corr_avgs = []

    config['model_name'] = model_name

    print(config['l2_weight'])
    
    for ses_idx, session_id in enumerate(session_ids):
        # set sess_corr_avg
        sess_corr_avg = -1
        sess_corrs = []

        # set session index 
        config["session_id"] = session_id

        # setup logging
        if config["wandb"]:
            wandb.init(
                project=f'{config["modality"]}-cs230',
                config=config,
            )
            wandb.define_metric("corr_to_avg", summary="max")
            wandb.define_metric("test_loss", summary="min")

        # load datasets and loaders 
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(input_dir, session_id, config["modality"], config["exp_var_thresholds"][ses_idx], config["stim_dur_ms"], config["stim_size"], config["win_size"], stimulus_dir, config["batch_size"], config["first_frame_only"], pos = config['pos'])

        full_model = full_model_fcn(train_dataset)

        for name, param in full_model.named_parameters():
            print(name)
            print(param.shape)
            print(param.requires_grad)
        # set which parameters to use regularization with and which not to
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_model.named_parameters():
            if 'mu' in name or 'sigma' in name or 'bias' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        # setup Adam optimizer
        optimizer = torch.optim.Adam([
            {'params': params_with_l2, 'weight_decay': config['l2_weight']},  # Apply L2 regularization (weight decay)
            {'params': params_without_l2, 'weight_decay': 0.0}  # No L2 regularization
        ], lr=config["lr"], weight_decay=config['l2_weight'])
    
        # using poisson loss 
        loss_func = PoissonNLLLoss(log_input=False, full=True)
            
        for epochs in range(config["num_epochs"]):
            epoch_loss = 0
            for i, (stimulus, targets) in enumerate(train_loader): 
                stimulus = stimulus.to(device)
                targets = targets.to(device)
                
                optimizer.zero_grad()
                preds = full_model(stimulus)
                loss = loss_func(preds, targets)
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item()
    
            # printing corr to avg and loss metrics 
            with torch.no_grad():
                corr_avg = corr_to_avg(full_model, test_loader, modality=config["modality"], device=device)
                test_loss = 0
                for i, (stimulus, targets) in enumerate(test_loader):
                    stimulus = stimulus.to(device)
                    targets = targets.to(device)
                    preds = full_model(stimulus) 
                    loss = loss_func(preds, targets)
                    test_loss += loss.item()
                    
            if config["wandb"]:
                wandb.log({"corr_to_avg": np.nanmean(corr_avg), "train_loss": epoch_loss / len(train_loader), "test_loss": test_loss / len(test_loader)})
            
            if np.nanmean(corr_avg) > sess_corr_avg:
                sess_corr_avg = np.nanmean(corr_avg)
                sess_corrs = corr_avg
                
            print('  epoch {} loss: {} corr: {}'.format(epochs + 1, epoch_loss / len(train_dataset), np.nanmean(corr_avg)))
            print(f' num. neurons : {len(corr_avg)}')
            
        if config["save"]:
            torch.save(full_model.state_dict(), f"{model_output_path}_{session_id}.pickle")
            
        corr_avgs.append(sess_corrs)
        
        if config["wandb"]:
            wandb.finish()
    
    if config["wandb"]:
        wandb.init(
            project=f'{config["modality"]}-cs230',
            config=config,
        )
        for sess_corr in corr_avgs:
            for corr in sess_corr:
                wandb.log({"corr": corr})
        wandb.finish()

In [5]:
#### train 4 encoding models: 
from fix_models.models import FullModel

for layer in ['layer1', 'layer2', 'layer3', 'layer4']:
    config['layer'] = layer
    print(f'layer: {layer}')

    #### 1) frozen pretrained CNN with poisson GLM readout
    cnn_frozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=True, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 2) unfrozen pretrained CNN with poisson GLM readout
    cnn_unfrozen_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = True, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 3) CNN trained from scratch with poisson GLM readout
    cnn_untrained_fcn = lambda train_dataset:FullModel(feat_ext_type = 'resnet3d', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)
    #### 4) poisson GLM trained from scratch
    poisson_glm_fcn = lambda train_dataset: FullModel(feat_ext_type = 'none', freeze_weights=False, use_pretrained = False, modality=config["modality"], layer=config["layer"], stim_shape=config["stim_shape"], train_dataset=train_dataset, use_pool = config['use_pool'], pool_size = config['pool_size'], pool_stride = config["pool_stride"], device=device)

    config['l2_weight'] = 0 #l2_weight
    train_model(cnn_frozen_fcn, "frozen pretrained")
    train_model(cnn_unfrozen_fcn, "unfrozen pretrained")
    train_model(cnn_untrained_fcn, "untrained")
    config['l2_weight'] = 100
    train_model(poisson_glm_fcn, "linear")

layer: layer1
0


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33met22[0m. Use [1m`wandb login --relogin`[0m to force relogin


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.11971184977778682 corr: 0.1411561720009646
 num. neurons : 54
  epoch 2 loss: 0.11071030822801001 corr: 0.11844485913714636
 num. neurons : 54
  epoch 3 loss: 0.11010752183419686 corr: 0.10726472906333752
 num. neurons : 54
  epoch 4 loss: 0.10989189089080434 corr: 0.1492363559551482
 num. neurons : 54
  epoch 5 loss: 0.1097915001857428 corr: 0.20662128215548028


0,1
corr_to_avg,▃▂▁▃▇▅▃▅▄▅▆▆▆▇█▇▇▇█▇
test_loss,█▆▅▄▄▅▅▄▃▄▂▃▃▁▂▃▁▂▁▃
train_loss,█▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.72452


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.1.mu
torch.Size([95, 2])
True
model.1.sigma
torch.Size([95])
True
model.1.poisson_linear.linear.weight
torch.Size([95, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.11322425641314522 corr: 0.15243800999530488
 num. neurons : 95
  epoch 2 loss: 0.10636829025458291 corr: 0.17595817716659506
 num. neurons : 95
  epoch 3 loss: 0.1057932906749985 corr: 0.1899716343673713
 num. neurons : 95
  epoch 4 loss: 0.1054267930734844 corr: 0.203550364511327
 num. neurons : 95
  epoch 5 loss: 0.10532792881521255 corr: 0.1750038053284786
 nu

0,1
corr_to_avg,▁▂▃▄▂▅▂▅▄▇▇▆▇███▆▇▇▇
test_loss,█▅▅▅▄▃▂▃▂▁▂▂▂▁▂▁▂▁▁▁
train_loss,█▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
train_loss,1.63972


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.13499187587528177 corr: 0.12183234949544966
 num. neurons : 54
  epoch 2 loss: 0.11987168215605254 corr: 0.17570166846053645
 num. neurons : 54
  epoch 3 loss: 0.11860771409790391 corr: 0.18995573714770012
 num. neurons : 54
  epoch 4 loss: 0.11825785926727457 corr: 0.12308779536834902
 num. neurons : 54
  epoch 5 loss: 0.1180444964287867 corr: 0.1936493478178513

0,1
corr_to_avg,▁▃▄▁▄▃▄▇▅▆▆▇█▇▆██▇██
test_loss,█▅▄▄▄▄▃▄▃▂▃▂▁▂▁▂▂▁▂▁
train_loss,█▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.86021


0,1
corr,▅▅▅▄▆▆▇▆▇▅▆▆▇▇▅▆▇▆▅▄▁▇▅█▄▆▅▆▅▆▆▇▅▅▆▄▄▅▆▅

0,1
corr,0.3867


0


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.11542467305689683 corr: 0.028716704857319815
 num. neurons : 54
  epoch 2 loss: 0.11036782264709473 corr: 0.09549239197286756
 num. neurons : 54
  epoch 3 loss: 0.11006301291194963 corr: 0.08663872011856187
 num. neurons : 54
  epoch 4 loss: 0.10992128089622215 corr: 0.06881272590273341
 num. neurons : 54
  epoch 5 loss: 0.10977907121917348 corr: 0.14203829583079683
 n

0,1
corr_to_avg,▁▃▃▂▄▅▅▄▅▆▇▇▇▇▇▇█▇██
test_loss,█▇▇█▆▅▆▅▅▃▃▃▂▂▂▁▁▂▁▁
train_loss,█▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
train_loss,1.69962


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([95, 2])
True
model.1.sigma
torch.Size([95])
True
model.1.poisson_linear.linear.weight
torch.Size([95, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.11001958853287222 corr: 0.006370007557453737
 num. neurons : 95
  epoch 2 loss: 0.10601983095338831 corr: 0.13721229920008915
 num. neurons : 95
  epoch 3 loss: 0.10575780980874107 corr: 0.14341386831428793
 num. neurons : 95
  epoch 4 loss: 0.10533650826409226 corr: 0.14978830918196345
 num. neurons : 95
  epoch 5 loss: 0.10517595812912386 corr: 0.17451204892471914
 n

0,1
corr_to_avg,▁▄▄▅▅▆▅▆▆▇█▇▇▇██████
test_loss,█▇█▇▆▆▄▅▄▄▂▂▃▂▂▂▁▂▂▁
train_loss,█▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
train_loss,1.61145


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.12565484264834115 corr: 0.038344091413455954
 num. neurons : 54
  epoch 2 loss: 0.1188017301254814 corr: 0.038410125069592876
 num. neurons : 54
  epoch 3 loss: 0.11847253467725671 corr: 0.06268410421292353
 num. neurons : 54
  epoch 4 loss: 0.11815750223403497 corr: 0.10176583447333025
 num. neurons : 54
  epoch 5 loss: 0.11792382960205282 corr: 0.11098815524586206
 n

0,1
corr_to_avg,▁▁▂▃▃▄▄▆▅▇▇█▇▇▇█▇█▇▇
test_loss,█▇▅▅▅▅▄▃▄▂▂▃▃▂▃▂▂▃▂▁
train_loss,█▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁

0,1
train_loss,1.8232


0,1
corr,▃▄▁▆▅▆▅▆▇▇█▅▆▇▃▃▃▇▇▇▄▅▁▃▆▆▂▂▆▃▄▅▃▄▅▅▆▇▅▆

0,1
corr,0.37803


0


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.11468338495419349 corr: 0.04794707155730058
 num. neurons : 54
  epoch 2 loss: 0.11116113992384923 corr: 0.07616055395934379
 num. neurons : 54
  epoch 3 loss: 0.1109850842864425 corr: 0.04984696446636102
 num. neurons : 54
  epoch 4 loss: 0.11061539231995006 corr: 0.08878463360019728
 num. neurons : 54
  epoch 5 loss: 0.11040811491601261 corr: 0.07954158635585938
 num

0,1
corr_to_avg,▁▂▁▂▂▃▃▃▁▄▄▅▆▄▅▇▇▇▇█
test_loss,█▆▇▆▆▅▆▅▅▅▄▄▄▃▄▂▂▁▂▁
train_loss,█▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁

0,1
train_loss,1.71177


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([95, 2])
True
model.1.sigma
torch.Size([95])
True
model.1.poisson_linear.linear.weight
torch.Size([95, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.1102906804434292 corr: 0.03798632667568064
 num. neurons : 95
  epoch 2 loss: 0.10725807257347707 corr: 0.05447175859519283
 num. neurons : 95
  epoch 3 loss: 0.10664440892753801 corr: 0.1266031997082627
 num. neurons : 95
  epoch 4 loss: 0.10644053014785207 corr: 0.08954136092399452
 num. neurons : 95
  epoch 5 loss: 0.10592031447675215 corr: 0.11213714855891996
 num.

0,1
corr_to_avg,▁▁▃▂▃▄▄▄▅▅▅▇▆▇▇▇█▇▇█
test_loss,██▇▆▆▆▅▄▄▃▃▂▃▂▃▂▁▂▁▁
train_loss,█▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁

0,1
train_loss,1.63066


readout input shape: 320
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.1.mu
torch.Size([54, 2])
True
model.1.sigma
torch.Size([54])
True
model.1.poisson_linear.linear.weight
torch.Size([54, 320])
True
model.1.poisson_linear.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.12465544596327652 corr: 0.012332993387948432
 num. neurons : 54
  epoch 2 loss: 0.1196177595467724 corr: 0.05073651028903067
 num. neurons : 54
  epoch 3 loss: 0.11914079276036667 corr: 0.05256499819715316
 num. neurons : 54
  epoch 4 loss: 0.11884576394268127 corr: 0.0590014210895141
 num. neurons : 54
  epoch 5 loss: 0.1187085795423682 corr: 0.15596740747634924
 num.

0,1
corr_to_avg,▁▂▂▂▄▃▄▅▅▅▅▆▇▇▇▇▇▇▇█
test_loss,█▇▇▅▅▆▄▄▄▄▃▄▃▃▂▂▂▄▃▁
train_loss,█▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
train_loss,1.83644


0,1
corr,▂▅▃▄▃▂▅▅▂▆█▅▆▆▄▂▂▄▂▄▄▇▄▇▇▅▃▅▇▂▁▇▇▅▄▄▆▄▆▅

0,1
corr,0.40119


100


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.14995239870047863 corr: -0.0038529157763796456
 num. neurons : 54
  epoch 2 loss: 0.1474048912378005 corr: -0.05671359284674722
 num. neurons : 54
  epoch 3 loss: 0.14456361252584576 corr: -0.006050159369621809
 num. neurons : 54
  epoch 4 loss: 0.14093821184134778 corr: 0.004737145754248771
 num. neurons : 54
  epoch 5 loss: 0.13896327324855476 corr: -0.00853658614774139
 num. neurons : 54
  epoch 6 loss: 0.13811771116138977 corr: -0.005399715796069087
 num. neurons : 54
  epoch 7 loss: 0.13568587785885658 corr: 0.0025135657519691416
 num. neurons : 54
  epoch 8 loss: 0.13338802190474522 corr: -0.011734927959059768
 num. neurons : 54
  epoch 9 loss: 0.13229606498906643 corr: -0.00608360599346543
 num. neurons : 54
  epoch 10 loss: 0.13045316372388674 corr: 0.03292582675118912
 num. neurons : 54
  epoch 11 loss: 0.13021904209513724 corr: -0.004803498

0,1
corr_to_avg,▅▁▅▅▄▅▅▄▅█▅▆▅▆▆▇▆▄▇█
test_loss,██▇▅▅▄▅▄▄▃▂▂▂▂▂▂▂▁▁▁
train_loss,█▇▇▆▅▅▄▄▃▃▃▂▂▂▂▁▂▁▁▁

0,1
train_loss,1.95324


readout input shape: 15360
model.1.linear.weight
torch.Size([95, 15360])
True
model.1.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.1360421685648214 corr: 0.018646517129874102
 num. neurons : 95
  epoch 2 loss: 0.13247152666770975 corr: 0.004803005988310479
 num. neurons : 95
  epoch 3 loss: 0.12975116574951492 corr: 0.03455693755503983
 num. neurons : 95
  epoch 4 loss: 0.1272927403450012 corr: 0.011582845075603119
 num. neurons : 95
  epoch 5 loss: 0.1255346815623538 corr: 0.037274761375642756
 num. neurons : 95
  epoch 6 loss: 0.12401995721287752 corr: -0.0024397353897709973
 num. neurons : 95
  epoch 7 loss: 0.12262563330964893 corr: 0.0495670665116475
 num. neurons : 95
  epoch 8 loss: 0.12117280105021612 corr: 0.038864721137023876
 num. neurons : 95
  epoch 9 loss: 0.12044504281738042 corr: 0.031562478900961664
 num. neurons : 95
  epoch 10 loss: 0.11966460749741 corr: 0.036550479648757184
 num. neurons : 95
  epoch 11 loss: 0.11815509047183691 corr: -0.010302736096494514
 

0,1
corr_to_avg,▅▅▆▅▆▄▇▆▆▆▄▄▄▆▅▆▆▅▁█
test_loss,█▇▇▆▅▄▆▄▅▃▅▃▃▁▁▁▂▁▂▁
train_loss,█▇▆▅▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁

0,1
train_loss,1.81957


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.1735950998214884 corr: 0.008377857280703456
 num. neurons : 54
  epoch 2 loss: 0.17024380894181151 corr: 0.038091137461167976
 num. neurons : 54
  epoch 3 loss: 0.16686548590131217 corr: 0.02248357195822742
 num. neurons : 54
  epoch 4 loss: 0.16317124988721765 corr: 0.007676881169460737
 num. neurons : 54
  epoch 5 loss: 0.16301541584407533 corr: 0.009644988463717044
 num. neurons : 54
  epoch 6 loss: 0.15681496476746282 corr: 0.009088872192952645
 num. neurons : 54
  epoch 7 loss: 0.15402935623805292 corr: 0.04721592851534323
 num. neurons : 54
  epoch 8 loss: 0.15326492883721177 corr: 0.002285645323471782
 num. neurons : 54
  epoch 9 loss: 0.15154677655914767 corr: -0.04276835105069057
 num. neurons : 54
  epoch 10 loss: 0.14810500440072885 corr: -0.025313869636037406
 num. neurons : 54
  epoch 11 loss: 0.14597644505505025 corr: 0.0113781353940744

0,1
corr_to_avg,▅▇▆▅▅▅▇▄▁▂▅▄▃▅█▂█▄▅▆
test_loss,▇▇▇█▇▃▄▄▄▃▃▂▂▂▁▂▁▂▂▁
train_loss,█▇▇▆▆▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁

0,1
train_loss,2.17557


0,1
corr,▁▂▃▆▄▄▅▆▅▂▇▅▂▃▅▇▆▄▆▄▅▅▃▃▂▂▅▄▆▆▆▂▆█▆▄▄▄▅▅

0,1
corr,-0.09925


layer: layer2
0


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▅▄▅▆▆▇▇▇▇▇▇███████
test_loss,█▆▅▅▄▃▄▃▂▂▂▂▂▂▂▂▁▁▁▂
train_loss,█▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
train_loss,1.65471


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▅▆▆▇▆▇▇▇██████████
test_loss,█▆▆▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁
train_loss,█▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
train_loss,1.57636


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▅▆▅▆▇▇▇▇██████████
test_loss,█▆▅▄▄▃▃▂▂▃▃▂▁▂▂▂▂▂▁▂
train_loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,1.78744


0,1
corr,▅▅▂▄▃▇▆▅▇▅▅▁█▆▅▄▅▄▆▅▅▃▄▁█▇▄▄▂▆▄▇▃▂▆▄█▆▄▅

0,1
corr,0.5348


0


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▃▃▅▇▇▇▇██▇▇▇▇▇▇▆▆▅▆
test_loss,▅▅▄▄▃▃▂▂▁▁▁▂▃▂▃▄▅▅█▆
train_loss,█▇▇▇▆▆▆▆▆▅▅▅▄▄▃▃▂▂▁▁

0,1
train_loss,1.32807


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▃▃▅▆▆▇██████▇█▇▇▇▇▇
test_loss,█▇▅▄▃▃▁▃▂▁▂▃▃▃▂▃▄▅▆▇
train_loss,█▇▇▇▆▆▆▆▅▅▅▄▄▃▃▃▂▂▁▁

0,1
train_loss,1.27167


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▃▄▄▆▇▆▇███▇█▇█▇▇▇▆▇
test_loss,███▄▃▃▁▂▂▁▁▂▁▂▁▄▅▄█▇
train_loss,█▇▆▆▆▆▆▆▅▅▅▄▄▄▃▃▂▂▁▁

0,1
train_loss,1.46023


0,1
corr,▃▄▂▃▄▁▆▁▆▅▄▄▅▄▄▃▇▅▅▃▆█▁▅▃▂▂▅▃▂▄▂▄▆▄▃▆▇▄▅

0,1
corr,0.34092


0


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▃▄▄▃▄▄▅▅▇▆▆▇▇▇▇▇███
test_loss,█▆▆▅▆▅▄▃▄▃▃▃▂▂▂▁▂▁▁▂
train_loss,█▇▇▆▆▆▅▅▄▄▄▄▃▃▃▂▂▂▁▁

0,1
train_loss,1.6635


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▂▁▂▃▃▄▅▅▆▆▆▆▇▇▇█████
test_loss,█▇▇▆▅▅▄▃▃▂▃▂▂▃▃▁▂▁▁▁
train_loss,█▇▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁

0,1
train_loss,1.57236


readout input shape: 384
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▃▂▂▄▄▃▅▅▆▆▇▇██▇████
test_loss,▆▇▇▇█▅▅▄▄▃▃▄▂▂▃▃▁▁▃▂
train_loss,█▆▆▆▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁

0,1
train_loss,1.79724


0,1
corr,▅▃▃▄▆▄▁▅▅▅▅▅▂▅▅▅▄▄▄▅▄▄▅▅▂▃█▅▄▇▅▅▄▆▅▅▄▄▅▄

0,1
corr,0.34331


100


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.15095129648844402 corr: -0.04376878406043991
 num. neurons : 54
  epoch 2 loss: 0.1471584129333496 corr: -0.01585685344952155
 num. neurons : 54
  epoch 3 loss: 0.1442125361054032 corr: -0.05351647689135093
 num. neurons : 54
  epoch 4 loss: 0.1426661091674993 corr: -0.020611609962141102
 num. neurons : 54
  epoch 5 loss: 0.13935285532916034 corr: -0.01872228107421703
 num. neurons : 54
  epoch 6 loss: 0.13714045642334738 corr: -0.0018577465969192941
 num. neurons : 54
  epoch 7 loss: 0.13523867660098607 corr: -0.01448592535369925
 num. neurons : 54
  epoch 8 loss: 0.13352622991726723 corr: 0.012582612287564044
 num. neurons : 54
  epoch 9 loss: 0.1324846354825997 corr: -0.010809339018037743
 num. neurons : 54
  epoch 10 loss: 0.13032600391058274 corr: 0.01701858310361484
 num. neurons : 54
  epoch 11 loss: 0.12894863140435867 corr: 0.033893517565767

0,1
corr_to_avg,▂▃▁▃▃▄▄▅▄▆▇▄▄▇▄▂▅▂▇█
test_loss,▆█▅▄▅▄▃▂▃▂▂▃▂▁▃▂▂▂▁▁
train_loss,█▇▆▆▅▅▄▄▃▃▃▂▂▂▂▁▂▁▁▁

0,1
train_loss,1.95362


readout input shape: 15360
model.1.linear.weight
torch.Size([95, 15360])
True
model.1.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.13565370618360828 corr: 0.0324399311345626
 num. neurons : 95
  epoch 2 loss: 0.1324557623938116 corr: 0.028031179329707688
 num. neurons : 95
  epoch 3 loss: 0.13006829347910057 corr: -0.007691426779435869
 num. neurons : 95
  epoch 4 loss: 0.12744276436211552 corr: 0.04695674283970989
 num. neurons : 95
  epoch 5 loss: 0.1253991852256016 corr: 0.0003027810009921688
 num. neurons : 95
  epoch 6 loss: 0.12399224167718938 corr: 0.041051640443041326
 num. neurons : 95
  epoch 7 loss: 0.12272318439333851 corr: 0.03753911142120452
 num. neurons : 95
  epoch 8 loss: 0.12199876008857607 corr: 0.008150738098785835
 num. neurons : 95
  epoch 9 loss: 0.12103514196985055 corr: 0.05884900033730749
 num. neurons : 95
  epoch 10 loss: 0.11940210595804983 corr: 0.0022065438512842103
 num. neurons : 95
  epoch 11 loss: 0.11887161631858786 corr: 0.017151867307564496

0,1
corr_to_avg,▅▅▁▇▂▆▆▃█▂▄▁▂▂▅▄▇▆▆▅
test_loss,▇▇█▅▆▄▆▃▄▃▂▃▂▂▁▂▂▂▁▁
train_loss,█▇▆▅▅▄▄▃▃▃▂▂▂▂▁▁▁▁▁▁

0,1
train_loss,1.82638


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.17384846832441248 corr: 0.037618810233018614
 num. neurons : 54
  epoch 2 loss: 0.17024091872475056 corr: 0.05913861282537757
 num. neurons : 54
  epoch 3 loss: 0.1660075079791823 corr: 0.013819172536978022
 num. neurons : 54
  epoch 4 loss: 0.16322767247959058 corr: 0.03229347311244887
 num. neurons : 54
  epoch 5 loss: 0.1585797097900851 corr: 0.02364683851780048
 num. neurons : 54
  epoch 6 loss: 0.1574810328690902 corr: 0.02160693269783046
 num. neurons : 54
  epoch 7 loss: 0.155250028899632 corr: -0.031193891053875163
 num. neurons : 54
  epoch 8 loss: 0.15302252018462267 corr: 0.005225175930782751
 num. neurons : 54
  epoch 9 loss: 0.14931785351635612 corr: -0.028252520483481747
 num. neurons : 54
  epoch 10 loss: 0.14869031831027135 corr: 0.04469949589308829
 num. neurons : 54
  epoch 11 loss: 0.14767485291109533 corr: 0.016137908852156962
 nu

0,1
corr_to_avg,▆█▄▆▅▅▁▄▁▇▅▄▅▃▆▄▅▆▅▃
test_loss,▅▄▄▃▃█▃▂▂▂▂▂▂▁▃▁▁▁▁▂
train_loss,█▇▇▆▅▅▅▄▃▃▃▃▂▂▂▂▁▁▁▁

0,1
train_loss,2.17794


0,1
corr,▄▅▄▆▅▅▄▅▆▄▅▄▃▄▄▆▄▆▃▄▄▂▄▂▅▆▄▄▅█▆▅▂▃▁▅▃▅▄▄

0,1
corr,-0.02874


layer: layer3
0


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▅▅▆▆▇▆▇▇▇▇████████
test_loss,█▇▅▅▄▄▃▃▃▂▂▂▂▁▂▂▂▂▁▁
train_loss,█▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
train_loss,1.62542


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▄▆▆▆▆▇▇▇▇▇████████
test_loss,█▆▅▅▄▄▃▂▂▂▂▁▂▂▂▁▁▂▂▁
train_loss,█▅▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁

0,1
train_loss,1.54189


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▄▅▅▆▇▇▇▇▇▇████████
test_loss,█▇▅▄▄▃▃▃▂▂▁▂▂▁▁▂▁▁▁▁
train_loss,█▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
train_loss,1.76303


0,1
corr,▂▅▆▂▁▆▆▅▇▄▅▅▃▃█▅▄▅▆▄▄▄▅▆▄▆▆▆▃▅▂▆▁▂▃▆▁▄█▇

0,1
corr,0.58738


0


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▂▄▅▆▆▆▇▇██▇▇▇▇▇▇▆▇▆
test_loss,▃▃▂▂▂▂▁▂▁▁▁▃▄▄▄▅▆▇▇█
train_loss,█▇▇▇▇▆▆▅▅▄▄▃▃▂▂▂▁▁▁▁

0,1
train_loss,0.98294


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▂▄▄▅▆▇█▇████▇▆▆▆▆▆▅
test_loss,▃▂▂▂▂▁▁▁▁▁▁▁▂▃▄▄▆▇▇█
train_loss,█▇▇▇▇▆▆▆▅▅▄▄▃▃▃▂▂▁▁▁

0,1
train_loss,0.93987


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▂▃▁▂▄▅▇███▇▇▇▆▆▆▆▆▆▆
test_loss,▃▃▃▂▂▁▂▁▂▁▁▂▃▃▅▆▇███
train_loss,█▇▇▇▇▇▆▆▅▅▄▄▃▃▂▂▂▁▁▁

0,1
train_loss,1.08441


0,1
corr,▃▇▆▆▇▃▇▅▅▄▇▅▅▆▁▄▅▅▃▅▅▇▄▄▆█▇█▆▄█▃▆▅▅▅▃▄▄▅

0,1
corr,0.41376


0


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▁▃▁▁▄▂▅▂▅▅▇▇▆▇▇▇▇█▇
test_loss,▇▆█▆▅▅▅▄▄▄▄▃▂▃▃▃▃▁▁▂
train_loss,█▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▃▃▂▁

0,1
train_loss,1.56229


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▁▁▁▁▃▂▃▄▅▄▅▆▇▆▇▆▇███
test_loss,█▇█▆▅▄▄▄▅▄▄▄▄▂▂▄▁▂▁▁
train_loss,█▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▂▂▁

0,1
train_loss,1.50701


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

0,1
corr_to_avg,▂▂▁▃▃▂▃▄▅▅▇▆▇▇▇▇▇███
test_loss,█▇▇▆▅▆▅▄▅▄▃▇▃▄▃▂▁▁▁▁
train_loss,█▇▇▆▆▆▆▅▅▅▅▄▄▄▄▃▃▂▂▁

0,1
train_loss,1.73359


0,1
corr,▂▅▃▂▄▄▄▂▁▆▁▃▂▅▆▃█▆▂▂▇▆▃▃▃▂▅▁▄▅▃▆▂▆▅▃▄▁▆▃

0,1
corr,0.35175


100


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.1500046951976823 corr: -0.034138652008373335
 num. neurons : 54
  epoch 2 loss: 0.1467015728538419 corr: -0.015194406691384874
 num. neurons : 54
  epoch 3 loss: 0.14345594041141463 corr: 0.05671906499720175
 num. neurons : 54
  epoch 4 loss: 0.14153017732832168 corr: 0.014475090005036327
 num. neurons : 54
  epoch 5 loss: 0.13898173608897646 corr: -0.023752447634120547
 num. neurons : 54
  epoch 6 loss: 0.13687474951332 corr: 0.025623046425159897
 num. neurons : 54
  epoch 7 loss: 0.13580946910528488 corr: 0.03688330270790436
 num. neurons : 54
  epoch 8 loss: 0.13373428921640654 corr: -0.012540351824497975
 num. neurons : 54
  epoch 9 loss: 0.13200486919026316 corr: 0.020551060340449907
 num. neurons : 54
  epoch 10 loss: 0.1303183192971312 corr: -0.02153832318894552
 num. neurons : 54
  epoch 11 loss: 0.1294022735548608 corr: -0.023810434995751768

0,1
corr_to_avg,▁▂█▅▂▆▆▃▅▂▂▂▅▂▄▆▆▅▄▆
test_loss,▇▆▆█▅▄▄▄▄▃▃▂▃▂▃▁▂▂▃▁
train_loss,█▇▆▆▅▅▄▄▃▃▃▃▂▂▂▂▁▂▁▁

0,1
train_loss,1.95529


readout input shape: 15360
model.1.linear.weight
torch.Size([95, 15360])
True
model.1.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.1356720928122236 corr: -0.00871890425344479
 num. neurons : 95
  epoch 2 loss: 0.13222199410044086 corr: 0.051097064910031485
 num. neurons : 95
  epoch 3 loss: 0.13011004375537652 corr: 0.006658844540434523
 num. neurons : 95
  epoch 4 loss: 0.12733195489613797 corr: 0.0035515747816946145
 num. neurons : 95
  epoch 5 loss: 0.12572843279514012 corr: 0.033560702125966374
 num. neurons : 95
  epoch 6 loss: 0.12428738016108568 corr: 0.018923160000162895
 num. neurons : 95
  epoch 7 loss: 0.12296373912801294 corr: 0.054402094365688325
 num. neurons : 95
  epoch 8 loss: 0.12114642590128315 corr: 0.000865825247332276
 num. neurons : 95
  epoch 9 loss: 0.12010606109159778 corr: 0.02355149186447344
 num. neurons : 95
  epoch 10 loss: 0.11949041295426054 corr: 0.025246680704302055
 num. neurons : 95
  epoch 11 loss: 0.11876468508655488 corr: 0.013826542144249

0,1
corr_to_avg,▄█▅▅▇▆█▄▆▆▅▄▅▄▁▆▄▄▆▅
test_loss,█▇▆▆▆▄▃▄▃▄▂▃▂▂▂▁▂▂▁▁
train_loss,█▇▆▅▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁

0,1
train_loss,1.82191


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.17357887667475702 corr: 0.031306166722312644
 num. neurons : 54
  epoch 2 loss: 0.1695348153414722 corr: 0.018359569424531787
 num. neurons : 54
  epoch 3 loss: 0.16852589254370812 corr: 0.034733457076344024
 num. neurons : 54
  epoch 4 loss: 0.16272584473864632 corr: 0.038152976118665104
 num. neurons : 54
  epoch 5 loss: 0.1600164913348424 corr: -0.0016681243109702179
 num. neurons : 54
  epoch 6 loss: 0.15894620219690142 corr: 0.04949945792580136
 num. neurons : 54
  epoch 7 loss: 0.1536534147761835 corr: -0.031127063375139083
 num. neurons : 54
  epoch 8 loss: 0.15215785474827995 corr: 0.008791475715468653
 num. neurons : 54
  epoch 9 loss: 0.14964813739722455 corr: 0.022655608332574215
 num. neurons : 54
  epoch 10 loss: 0.1476116798150402 corr: 0.044790653566095184
 num. neurons : 54
  epoch 11 loss: 0.14635187847591016 corr: 0.0202892210622686

0,1
corr_to_avg,▇▆▇▇▄█▂▅▆█▆█▁█▄▃▂▂▅▂
test_loss,█▆▅▆▅▅▄▅▄▄▃▂▂▂▂▃▂▂▁▂
train_loss,█▇▇▆▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁

0,1
train_loss,2.17812


0,1
corr,▄▄▃▄▄▆█▆▃▃▆▆▇▄▅▇█▅█▅▅▃▅▄▄█▃▄▄▆▅▄▅▁▅▂▇▅▇▆

0,1
corr,0.14791


layer: layer4
0


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▄▅▆▇▇▇▇▇▇▇█████████
test_loss,█▆▅▄▃▃▃▃▂▂▂▂▂▂▂▁▂▂▂▁
train_loss,█▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
train_loss,1.63601


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▄▅▅▆▇▇▇▇██▇▇███████
test_loss,█▆▄▄▂▄▃▂▃▂▂▂▁▁▁▁▂▁▂▂
train_loss,█▅▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
train_loss,1.55041


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
False
model.0.core.stem.1.weight
torch.Size([64])
False
model.0.core.stem.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
False
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
False
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
False
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
False
mod

0,1
corr_to_avg,▁▃▅▆▇▇▇▇▇▇██████████
test_loss,█▆▄▃▃▄▂▂▃▂▂▂▂▂▁▁▂▃▃▁
train_loss,█▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
train_loss,1.77475


0,1
corr,▅▃▃▅▆▆▆▆▅▄▅▄▆▃▅▅▆▅▄▅▅▂█▂▅▆▄▁▄▃▅▆▅▄▅▅▅▇▇▆

0,1
corr,0.51483


0


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

  return pearsonr(pred, avg, axis=0).statistic


  epoch 6 loss: 0.0895608749507386 corr: -0.010644853994121494
 num. neurons : 54
  epoch 7 loss: 0.08440246599691885 corr: 0.0008983314501377844
 num. neurons : 54
  epoch 8 loss: 0.07936011891306183 corr: -0.011547445227205792
 num. neurons : 54
  epoch 9 loss: 0.07542194054450518 corr: 0.02556702493865053
 num. neurons : 54
  epoch 10 loss: 0.07164517249590086 corr: -0.05226258769513504
 num. neurons : 54
  epoch 11 loss: 0.06880115691526437 corr: 0.020837026659881585
 num. neurons : 54
  epoch 12 loss: 0.06694665296578113 corr: -0.04646924055621995
 num. neurons : 54
  epoch 13 loss: 0.06577398011713852 corr: 0.031074572455984176
 num. neurons : 54
  epoch 14 loss: 0.06447304955235234 corr: -0.014251785820937325
 num. neurons : 54
  epoch 15 loss: 0.06368020257832091 corr: -0.01843536130635414
 num. neurons : 54
  epoch 16 loss: 0.06318839700133712 corr: 0.007186064089907325
 num. neurons : 54
  epoch 17 loss: 0.06313854364701259 corr: -0.033630985613485244
 num. neurons : 54
  epo

0,1
corr_to_avg,▅▅▄█▆▄▅▄▇▁▇▁█▄▄▆▃▄▄▆
test_loss,▁▁▁▁▁▂▂▃▃▄▅▅▅▆▇▆▆▆▇█
train_loss,█▇▇▆▆▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁

0,1
train_loss,0.9955


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

  return pearsonr(pred, avg, axis=0).statistic


  epoch 3 loss: 0.09783259299413072 corr: 0.0571565876950959
 num. neurons : 95
  epoch 4 loss: 0.09429870575510395 corr: -0.009612136723438684
 num. neurons : 95
  epoch 5 loss: 0.09038013263522643 corr: -0.019106387471859814
 num. neurons : 95
  epoch 6 loss: 0.08581235745814458 corr: 0.015302924693616644
 num. neurons : 95
  epoch 7 loss: 0.08139744535166556 corr: 0.017152775237711387
 num. neurons : 95
  epoch 8 loss: 0.0771651716756571 corr: 0.0016408419088392993
 num. neurons : 95
  epoch 9 loss: 0.07304532309477242 corr: -0.012590957977489629
 num. neurons : 95
  epoch 10 loss: 0.06947488722376799 corr: -0.0002237118084066802
 num. neurons : 95
  epoch 11 loss: 0.06641681475165002 corr: 0.012271729249938497
 num. neurons : 95
  epoch 12 loss: 0.06374733051704487 corr: -0.052477326790497755
 num. neurons : 95
  epoch 13 loss: 0.061803755335782834 corr: -0.0115277164354361
 num. neurons : 95
  epoch 14 loss: 0.060299044098529514 corr: -0.009490332355161364
 num. neurons : 95
  epo

0,1
corr_to_avg,▅▄█▄▃▅▅▄▄▄▅▁▄▄▆▅▅▇▅▅
test_loss,▁▁▁▁▁▁▂▂▂▃▄▅▅▆▇▆▇▇▇█
train_loss,█▇▇▆▆▅▄▄▃▃▂▂▂▂▁▁▁▁▁▁

0,1
train_loss,0.90042


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

  return pearsonr(pred, avg, axis=0).statistic


  epoch 1 loss: 0.12162270176802045 corr: -0.0008541779102918
 num. neurons : 54
  epoch 2 loss: 0.11334553596290108 corr: -0.04779066109847837
 num. neurons : 54
  epoch 3 loss: 0.1106378796783928 corr: 0.024949134435992535
 num. neurons : 54
  epoch 4 loss: 0.10707813749931085 corr: 0.00878770621258881
 num. neurons : 54
  epoch 5 loss: 0.10205316770997001 corr: 0.005478017692818689
 num. neurons : 54
  epoch 6 loss: 0.09671314786574221 corr: 0.029528884576374807
 num. neurons : 54
  epoch 7 loss: 0.09090571908053827 corr: 0.026629333705906033
 num. neurons : 54
  epoch 8 loss: 0.0852006359129959 corr: 0.015940397026633798
 num. neurons : 54
  epoch 9 loss: 0.0802400352263091 corr: -0.009235411798752955
 num. neurons : 54
  epoch 10 loss: 0.07674710477678355 corr: -0.016111960805095346
 num. neurons : 54
  epoch 11 loss: 0.07388133762574556 corr: -0.007990536428526646
 num. neurons : 54
  epoch 12 loss: 0.0719981153347689 corr: -0.016147644505241376
 num. neurons : 54
  epoch 13 loss

0,1
corr_to_avg,▅▁▇▅▅▇▇▆▄▄▄▄▄▄▁▄█▄▄▅
test_loss,▂▁▁▁▂▂▂▂▃▄▄▅▅▅▆▇▇█▆▇
train_loss,█▇▇▆▅▅▄▃▃▂▂▂▁▁▁▁▁▁▁▁

0,1
train_loss,1.08492


0,1
corr,▄▃▄▄▅▄▄▄▄▂▃▄ ▄▄▄▄▄▄▄▄█▁▅▃▃ ▄▃▃▃▄▃▅▄▅▂▃▃▄

0,1
corr,-0.1203


0


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

  return pearsonr(pred, avg, axis=0).statistic


  epoch 7 loss: 0.10741927217554163 corr: 0.03094047970684141
 num. neurons : 54
  epoch 8 loss: 0.10660050286187066 corr: -0.02708719081486849
 num. neurons : 54
  epoch 9 loss: 0.10549159314897326 corr: -0.014648200766754257
 num. neurons : 54
  epoch 10 loss: 0.10440043655442603 corr: 0.03185780822478451
 num. neurons : 54
  epoch 11 loss: 0.10239568545494551 corr: 0.020513681288178404
 num. neurons : 54
  epoch 12 loss: 0.1004190104684712 corr: -0.0006438589165938195
 num. neurons : 54
  epoch 13 loss: 0.09780482922071292 corr: -0.022887788569414844
 num. neurons : 54
  epoch 14 loss: 0.09506120252020565 corr: -0.02996855050061921
 num. neurons : 54
  epoch 15 loss: 0.09135977491920377 corr: 0.015423672220983838
 num. neurons : 54
  epoch 16 loss: 0.08806018729268769 corr: -0.03552319802434823
 num. neurons : 54
  epoch 17 loss: 0.0847677774782534 corr: -0.00018892632050413892
 num. neurons : 54
  epoch 18 loss: 0.08164231783078041 corr: -0.038410106988290255
 num. neurons : 54
  e

0,1
corr_to_avg,▂▇█▅▇▄▆▂▃▆▅▄▂▂▅▁▄▁▆▃
test_loss,▃▁▁▂▁▂▁▁▁▁▂▂▂▁▂▃▃▄▆█
train_loss,█▅▅▅▅▅▅▅▄▄▄▄▄▃▃▂▂▂▁▁

0,1
train_loss,1.20574


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

  return pearsonr(pred, avg, axis=0).statistic


  epoch 3 loss: 0.10499563772640927 corr: -0.0039136614206282515
 num. neurons : 95
  epoch 4 loss: 0.10447430373486424 corr: -0.017961336662141426
 num. neurons : 95
  epoch 5 loss: 0.10370527256221672 corr: 0.0028239910522236573
 num. neurons : 95
  epoch 6 loss: 0.10310225118517252 corr: 0.0006697366240751614
 num. neurons : 95
  epoch 7 loss: 0.10259285747068715 corr: -0.0025499812889001774
 num. neurons : 95
  epoch 8 loss: 0.10171740404598376 corr: -0.00872647185369193
 num. neurons : 95
  epoch 9 loss: 0.10063786962269489 corr: 0.0013663745595091548
 num. neurons : 95
  epoch 10 loss: 0.09927076575643728 corr: 0.019990014226103588
 num. neurons : 95
  epoch 11 loss: 0.09769157258627927 corr: -0.0054011524467143
 num. neurons : 95
  epoch 12 loss: 0.09526856045448343 corr: -0.003251896944362709
 num. neurons : 95
  epoch 13 loss: 0.09259408094496004 corr: -0.004265935846955718
 num. neurons : 95
  epoch 14 loss: 0.0893483720524773 corr: 0.004253396552748678
 num. neurons : 95
  e

0,1
corr_to_avg,▆▆▅▃▆▅▅▄▅█▅▅▅▆▄▆▄▅▅▁
test_loss,▂▁▁▁▁▁▁▁▁▁▂▁▂▂▂▄▄▅▇█
train_loss,█▆▅▅▅▅▅▅▅▅▄▄▄▃▃▃▂▂▁▁

0,1
train_loss,1.12958


readout input shape: 512
model.0.stim
torch.Size([1, 3, 5, 32, 32])
True
model.0.core.stem.0.weight
torch.Size([64, 3, 3, 7, 7])
True
model.0.core.stem.1.weight
torch.Size([64])
True
model.0.core.stem.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.0.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.0.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.0.conv2.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv1.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv1.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv1.1.bias
torch.Size([64])
True
model.0.core.layer1.1.conv2.0.weight
torch.Size([64, 64, 3, 3, 3])
True
model.0.core.layer1.1.conv2.1.weight
torch.Size([64])
True
model.0.core.layer1.1.conv2.1.bias
torch.Size([64])
True
model.0.core.layer

  return pearsonr(pred, avg, axis=0).statistic


  epoch 4 loss: 0.11747955085962561 corr: 0.001005722887469293
 num. neurons : 54
  epoch 5 loss: 0.11712187380524156 corr: 0.020409460434081844
 num. neurons : 54
  epoch 6 loss: 0.11663937526354244 corr: -0.016212074575959377
 num. neurons : 54
  epoch 7 loss: 0.11613883116442013 corr: 0.003147070193053982
 num. neurons : 54
  epoch 8 loss: 0.11513951899419343 corr: 0.008514657878610736
 num. neurons : 54
  epoch 9 loss: 0.11424807367223284 corr: 0.0026674289003730434
 num. neurons : 54
  epoch 10 loss: 0.11303223787204503 corr: -0.01715585812744317
 num. neurons : 54
  epoch 11 loss: 0.11149459125303862 corr: -0.027166665252279846
 num. neurons : 54
  epoch 12 loss: 0.10994801292825612 corr: 0.04852765797838665
 num. neurons : 54
  epoch 13 loss: 0.10780296641953137 corr: 0.03656120775359657
 num. neurons : 54
  epoch 14 loss: 0.10534149265543169 corr: 0.002161111869659526
 num. neurons : 54
  epoch 15 loss: 0.10225943304737162 corr: -0.012742634174778632
 num. neurons : 54
  epoch 

0,1
corr_to_avg,▁▆▇▅▆▃▅▅▅▃▃█▇▅▄▄▅▃▄▅
test_loss,▅▃▁▃▄▄▃▂▂▃▂▂▄▄▃▃▄▆█▇
train_loss,█▄▄▄▄▄▄▄▄▄▄▄▃▃▃▂▂▂▁▁

0,1
train_loss,1.3625


0,1
corr,█▄▆▅▂▂▂▃▁▂▃▄▃▃▁▂▄▅▄▂▃▄▅▄▂▃▄▄▃▄▂▄▁▅▅▅▁▅▄▆

0,1
corr,0.24321


100


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.14979745270293435 corr: 0.005912113176879469
 num. neurons : 54
  epoch 2 loss: 0.1472646884565 corr: 0.021935849121347516
 num. neurons : 54
  epoch 3 loss: 0.14500016883567526 corr: -0.011521064073509652
 num. neurons : 54
  epoch 4 loss: 0.14138739791917213 corr: 0.02287697537438306
 num. neurons : 54
  epoch 5 loss: 0.13928212931126724 corr: -0.021969425315065492
 num. neurons : 54
  epoch 6 loss: 0.13749420395603887 corr: -0.03009972983220406
 num. neurons : 54
  epoch 7 loss: 0.13461551089345672 corr: -0.011036492438976505
 num. neurons : 54
  epoch 8 loss: 0.13360802526827212 corr: -0.009236659575651819
 num. neurons : 54
  epoch 9 loss: 0.1317056327395969 corr: -0.05956108157421558
 num. neurons : 54
  epoch 10 loss: 0.12953671549573356 corr: -0.03227361937382774
 num. neurons : 54
  epoch 11 loss: 0.12923480122177688 corr: 0.0767717052981632

0,1
corr_to_avg,▄▅▃▅▃▃▃▄▁▂█▃▇▄▇▅▃▄▃▃
test_loss,██▅▅▄▄▅▄▃▂▂▂▁▂▁▂▂▁▁▁
train_loss,█▇▇▆▅▅▄▄▃▃▃▂▃▂▂▂▂▁▁▁

0,1
train_loss,1.94638


readout input shape: 15360
model.1.linear.weight
torch.Size([95, 15360])
True
model.1.linear.bias
torch.Size([95])
True
  epoch 1 loss: 0.13575671642862688 corr: 0.03246384019122222
 num. neurons : 95
  epoch 2 loss: 0.1323393753061744 corr: 0.025991631937607763
 num. neurons : 95
  epoch 3 loss: 0.13003556959292029 corr: 0.020403643781231533
 num. neurons : 95
  epoch 4 loss: 0.12764299334031748 corr: 0.015149203949906525
 num. neurons : 95
  epoch 5 loss: 0.12601739979539242 corr: 0.03708560879294564
 num. neurons : 95
  epoch 6 loss: 0.12378430609927751 corr: -0.00820102141181545
 num. neurons : 95
  epoch 7 loss: 0.1224942006365791 corr: -0.022102132384725304
 num. neurons : 95
  epoch 8 loss: 0.12133856549937064 corr: -0.009591811853176131
 num. neurons : 95
  epoch 9 loss: 0.12060847881576778 corr: 0.013455774599181312
 num. neurons : 95
  epoch 10 loss: 0.11923388810682047 corr: 0.04836671479336574
 num. neurons : 95
  epoch 11 loss: 0.11854477727600417 corr: -0.0073992389267023

0,1
corr_to_avg,▅▅▄▄▆▂▁▂▄▇▂▃▄▄▂▆▁▂█▃
test_loss,█▇▅▅▃▄▄▃▃▂▂▂▂▂▂▂▂▁▁▂
train_loss,█▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
train_loss,1.8182


readout input shape: 15360
model.1.linear.weight
torch.Size([54, 15360])
True
model.1.linear.bias
torch.Size([54])
True
  epoch 1 loss: 0.17413089025729298 corr: 0.025932934084958784
 num. neurons : 54
  epoch 2 loss: 0.16959022830414794 corr: 0.01310625218891367
 num. neurons : 54
  epoch 3 loss: 0.16672834857121546 corr: 0.022653607951812756
 num. neurons : 54
  epoch 4 loss: 0.16244988327229457 corr: 0.01532168308053673
 num. neurons : 54
  epoch 5 loss: 0.1616777854359245 corr: -0.002087453361536395
 num. neurons : 54
  epoch 6 loss: 0.1574547870453206 corr: 0.018594856395161976
 num. neurons : 54
  epoch 7 loss: 0.15461721490100092 corr: 0.04646225354692409
 num. neurons : 54
  epoch 8 loss: 0.15285483777470152 corr: -0.012330388392126452
 num. neurons : 54
  epoch 9 loss: 0.15005918837482143 corr: 0.001142336464616035
 num. neurons : 54
  epoch 10 loss: 0.14807774328825848 corr: 0.028766792452613134
 num. neurons : 54
  epoch 11 loss: 0.1467887496355898 corr: 0.013163213430183897

0,1
corr_to_avg,▆▅▆▅▄▆█▃▄▇▅▁▆▆▃▃▄▁▄▄
test_loss,█▇▇▅▅▄▄▅▄▄▂▃▂▂▂▃▂▃▁▂
train_loss,█▇▇▆▆▅▄▄▄▃▃▂▂▂▂▂▁▂▁▁

0,1
train_loss,2.18435


0,1
corr,▁▄▃▄▅▇▇▃▄█▅▅█▅▅▂▄▆▅▅▅▄▄▄▄▇▆▅▅▄▇▄▇▃▃▃▆▄▆▃

0,1
corr,-0.12494
