In [13]:
from model_trainer import run_training_without_logging
from models import reCNN_bottleneck_CyclicGauss3d_no_scaling
from model_trainer import Lurz_dataset_preparation_function, Antolik_dataset_preparation_function


ENTITY = "csng-cuni"
PROJECT = "reCNN_visual_prosthesis"

model = None

config = {
    # GENERAL
    "seed": 42,
    "batch_size": 10,
    "lr": 0.001,
    "max_epochs": 1,

    # CORE GENERAL CONFIG
    "core_hidden_channels": 8,
    "core_layers": 5,
    "core_input_kern": 7,
    "core_hidden_kern": 9,

    # ROTATION EQUIVARIANCE CORE CONFIG
    "num_rotations": 8,       
    "stride": 1,               
    "upsampling": 2,           
    "rot_eq_batch_norm": True, 
    "stack": -1 ,               
    "depth_separable": True,

    # READOUT CONFIG
    "readout_bias": False,
    "nonlinearity": "softplus",
    
    # REGULARIZATION
    "core_gamma_input": 0.00307424496692959,
    "core_gamma_hidden": 0.28463619129195233,
    "readout_gamma": 0.17,
    "input_regularizer": "LaplaceL2norm", # for RotEqCore - default 
    "use_avg_reg": True,

    "reg_readout_spatial_smoothness": 0.0027,
    "reg_group_sparsity": 0.1,
    "reg_spatial_sparsity": 0.45,

    # TRAINER
    "patience": 10,
    "train_on_val": True, # in case you want to quickly check that your model "compiles" correctly
    "test": True,
    "observed_val_metric": "val/corr",

    "test_average_batch": False,
    "compute_oracle_fraction": False,
    "conservative_oracle": True,
    "jackknife_oracle": True,
    "generate_oracle_figure": False,

    # ANTOLIK
    "region": "region1",
    "dataset_artifact_name": "Antolik_dataset:latest",

    # BOTTLENECK
    "bottleneck_kernel": 15,

    "fixed_sigma": False,
    "init_mu_range": 0.9,
    "init_sigma_range": 0.8,

}



In [14]:
from Antolik_dataset import AntolikDataModule

path_train = "/storage/brno2/home/mpicek/reCNN_visual_prosthesis/data/antolik/one_trials.pickle"
path_test = "/storage/brno2/home/mpicek/reCNN_visual_prosthesis/data/antolik/ten_trials.pickle"

dataset_config = {
    "train_data_dir": path_test,
    "test_data_dir": path_test,
    "batch_size": config["batch_size"],
    "normalize": True,
    "val_size": 500,
    "brain_crop": 0.8,
    "stimulus_crop": "auto",
    "ground_truth_positions_file_path": "data/antolik/position_dictionary.pickle",
}


dm = AntolikDataModule(**dataset_config)
dm.prepare_data()
dm.setup()

# update config for initialization of model (<- certain config parameters depend on data)
config.update(
    {
        "input_channels": dm.get_input_shape()[0],
        "input_size_x": dm.get_input_shape()[1],
        "input_size_y": dm.get_input_shape()[2],
        "num_neurons": dm.get_output_shape()[0],
        "mean_activity": dm.get_mean(),
    }
)

12px will be discarded from each side.
Data loaded successfully!
Loaded precomputed mean from /storage/brno2/home/mpicek/reCNN_visual_prosthesis/data/antolik/ten_trials_mean.npy


In [15]:
dm.get_output_shape()[0]

4714

In [16]:
from pprint import pprint
pprint(config)

{'batch_size': 10,
 'bottleneck_kernel': 15,
 'compute_oracle_fraction': False,
 'conservative_oracle': True,
 'core_gamma_hidden': 0.28463619129195233,
 'core_gamma_input': 0.00307424496692959,
 'core_hidden_channels': 8,
 'core_hidden_kern': 9,
 'core_input_kern': 7,
 'core_layers': 5,
 'dataset_artifact_name': 'Antolik_dataset:latest',
 'depth_separable': True,
 'fixed_sigma': False,
 'generate_oracle_figure': False,
 'init_mu_range': 0.9,
 'init_sigma_range': 0.8,
 'input_channels': 1,
 'input_regularizer': 'LaplaceL2norm',
 'input_size_x': 86,
 'input_size_y': 86,
 'jackknife_oracle': True,
 'lr': 0.001,
 'max_epochs': 1,
 'mean_activity': tensor([0.2360, 0.1985, 0.2985,  ..., 1.5215, 1.7172, 1.6257]),
 'nonlinearity': 'softplus',
 'num_neurons': 4714,
 'num_rotations': 8,
 'observed_val_metric': 'val/corr',
 'patience': 10,
 'readout_bias': False,
 'readout_gamma': 0.17,
 'reg_group_sparsity': 0.1,
 'reg_readout_spatial_smoothness': 0.0027,
 'reg_spatial_sparsity': 0.45,
 'region

In [17]:
model = run_training_without_logging(config, Antolik_dataset_preparation_function, ENTITY, PROJECT, model_class=reCNN_bottleneck_CyclicGauss3d_no_scaling)

Global seed set to 42


{'batch_size': 10,
 'bottleneck_kernel': 15,
 'compute_oracle_fraction': False,
 'conservative_oracle': True,
 'core_gamma_hidden': 0.28463619129195233,
 'core_gamma_input': 0.00307424496692959,
 'core_hidden_channels': 8,
 'core_hidden_kern': 9,
 'core_input_kern': 7,
 'core_layers': 5,
 'dataset_artifact_name': 'Antolik_dataset:latest',
 'depth_separable': True,
 'fixed_sigma': False,
 'generate_oracle_figure': False,
 'init_mu_range': 0.9,
 'init_sigma_range': 0.8,
 'input_channels': 1,
 'input_regularizer': 'LaplaceL2norm',
 'input_size_x': 86,
 'input_size_y': 86,
 'jackknife_oracle': True,
 'lr': 0.001,
 'max_epochs': 1,
 'mean_activity': tensor([0.2360, 0.1985, 0.2985,  ..., 1.5215, 1.7172, 1.6257]),
 'nonlinearity': 'softplus',
 'num_neurons': 4714,
 'num_rotations': 8,
 'observed_val_metric': 'val/corr',
 'patience': 10,
 'readout_bias': False,
 'readout_gamma': 0.17,
 'reg_group_sparsity': 0.1,
 'reg_readout_spatial_smoothness': 0.0027,
 'reg_spatial_sparsity': 0.45,
 'region

  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Loaded precomputed mean from /storage/brno2/home/mpicek/reCNN_visual_prosthesis/data/antolik/ten_trials_mean.npy


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-78cbf6b9-71e3-3f13-3119-6e49ac6f7aef]

  | Name    | Type                                | Params
----------------------------------------------------------------
0 | loss    | PoissonLoss                         | 0     
1 | corr    | Corr                                | 0     
2 | core    | RotationEquivariant2dCoreBottleneck | 458 K 
3 | readout | Gaussian3dCyclicNoScale             | 33.0 K
4 | nonlin  | Softplus                            | 0     
----------------------------------------------------------------
133 K     Trainable params
358 K     Non-trainable params
491 K     Total params
1.965     Total estimated model params size (MB)


Validation sanity check: 100%|██████████| 2/2 [00:00<00:00, 13.00it/s]

  rank_zero_warn(
  rank_zero_warn(


                                                                      

Global seed set to 42
  rank_zero_warn(


Epoch 0: 100%|██████████| 100/100 [00:22<00:00,  4.51it/s, loss=1.22]
Best model's val/corr: 0.0064005638
/auto/budejovice1/mpicek/reCNN_visual_prosthesis/lightning_logs/version_16/checkpoints/epoch=0-step=49.ckpt


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-78cbf6b9-71e3-3f13-3119-6e49ac6f7aef]
  rank_zero_warn(
  rank_zero_warn(


Testing: 100%|██████████| 50/50 [00:03<00:00, 13.80it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-78cbf6b9-71e3-3f13-3119-6e49ac6f7aef]


Testing: 100%|██████████| 500/500 [00:36<00:00, 13.64it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-78cbf6b9-71e3-3f13-3119-6e49ac6f7aef]


Testing: 100%|██████████| 500/500 [00:08<00:00, 60.94it/s]
Validation dataset:
    Correlation: 0.0064 
Test dataset with averaged responses of repeated trials:
    Correlation: 0.0136 
    Fraction oracle conservative: 0.0314 
    Fraction oracle jackknife: 0.0564 
