In [1]:
import torch
import numpy as np
from experiments.utils import pickle_read
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from torchvision.utils import make_grid

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar
import pytorch_lightning as pl
from neuralpredictors.measures.modules import Corr, PoissonLoss
from torch.nn import Parameter
from energy_model.lucas_gabor_filter import GaborFilter
from energy_model.utils import plot_f, create_grating
from energy_model.energy_model import EnergyModel
from datetime import timedelta

In [2]:
import sys
print(sys.path)

['/auto/budejovice1/mpicek/reCNN_visual_prosthesis', '/auto/budejovice1/mpicek/reCNN_visual_prosthesis', '/opt/conda/lib/python38.zip', '/opt/conda/lib/python3.8', '/opt/conda/lib/python3.8/lib-dynload', '', '/opt/conda/lib/python3.8/site-packages', '/opt/conda/lib/python3.8/site-packages/IPython/extensions', '/auto/vestec1-elixir/home/mpicek/.ipython', '/auto/budejovice1/mpicek/reCNN_visual_prosthesis/predict_neural_responses']


In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
poiss = PoissonLoss()
out = poiss(torch.ones(1)*1000, torch.ones(1)*1000)
print(out)


tensor(-5907.7554)




In [5]:

ENTITY = "csng-cuni"
PROJECT = "reCNN_visual_prosthesis"
ground_truth_positions_file_path = "data/antolik/position_dictionary.pickle"
ground_truth_orientations_file_path = "data/antolik/oris.pickle"

model = None

config = {
    # GENERAL
    "seed": 2,
    "batch_size": 10,
    "lr": 0.01,
    "max_epochs": 100,

    # CORE GENERAL CONFIG
    "core_hidden_channels": 8,
    "core_layers": 5,
    "core_input_kern": 7,
    "core_hidden_kern": 9,

    # ROTATION EQUIVARIANCE CORE CONFIG
    "num_rotations": 8,       
    "stride": 1,               
    "upsampling": 2,           
    "rot_eq_batch_norm": True, 
    "stack": -1 ,               
    "depth_separable": True,

    # READOUT CONFIG
    "readout_bias": False,
    "nonlinearity": "softplus",
    
    # REGULARIZATION
    "core_gamma_input": 0.00307424496692959,
    "core_gamma_hidden": 0.28463619129195233,
    "readout_gamma": 0.17,
    "input_regularizer": "LaplaceL2norm", # for RotEqCore - default 
    "use_avg_reg": True,

    "reg_readout_spatial_smoothness": 0.0027,
    "reg_group_sparsity": 0.1,
    "reg_spatial_sparsity": 0.45,

    # TRAINER
    "patience": 7,
    "train_on_val": False, # in case you want to quickly check that your model "compiles" correctly
    "test": True,
    "observed_val_metric": "val/corr",

    "test_average_batch": False,
    "compute_oracle_fraction": False,
    "conservative_oracle": True,
    "jackknife_oracle": True,
    "generate_oracle_figure": False,

    # ANTOLIK
    "region": "region1",
    "dataset_artifact_name": "Antolik_dataset:latest",

    # BOTTLENECK
    "bottleneck_kernel": 15,

    "fixed_sigma": False,
    "init_mu_range": 0.9,
    "init_sigma_range": 0.8,

}

In [6]:
# from Antolik_dataset import AntolikDataModule

path_train = "/storage/brno2/home/mpicek/reCNN_visual_prosthesis/data/antolik_reparametrized/one_trials.pickle"
path_test = "/storage/brno2/home/mpicek/reCNN_visual_prosthesis/data/antolik_reparametrized/ten_trials.pickle"

dataset_config = {
    "train_data_dir": path_train,
    "test_data_dir": path_test,
    "batch_size": config["batch_size"],
    "normalize": True,
    "val_size": 500,
    "brain_crop": None,
    "stimulus_crop": None,
    # "brain_crop": 0.8,
    # "stimulus_crop": "auto",
    # "stimulus_crop": [110, 110],
    # "ground_truth_positions_file_path": "data/antolik/position_dictionary.pickle",
}


In [7]:
import pickle
import pytorch_lightning as pl
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch
from neuralpredictors.data.samplers import SubsetSequentialSampler
from typing import Optional
import pathlib
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import math
from experiments.utils import pickle_read
from Antolik_dataset import AntolikDataModule


In [8]:
dm = AntolikDataModule(**dataset_config)


In [9]:
dm.prepare_data()
dm.setup()

Data loaded successfully!


In [10]:
config.update(
        {
            "input_channels": dm.get_input_shape()[0],
            "input_size_x": dm.get_input_shape()[1],
            "input_size_y": dm.get_input_shape()[2],
            "num_neurons": dm.get_output_shape()[0],
            "mean_activity": dm.get_mean(),
            "filtered_neurons":dm.get_filtered_neurons(),
        }
    )

Loaded precomputed mean from /storage/brno2/home/mpicek/reCNN_visual_prosthesis/data/antolik_reparametrized/one_trials_mean.npy


In [11]:

# resolution = (dm.get_input_shape()[1], dm.get_input_shape()[2])
# xlim = [-dm.get_stimulus_visual_angle()/2, dm.get_stimulus_visual_angle()/2]
# ylim = [-dm.get_stimulus_visual_angle()/2, dm.get_stimulus_visual_angle()/2]

# pos_x, pos_y, orientations = dm.get_ground_truth(ground_truth_positions_file_path, ground_truth_orientations_file_path)

# model = EnergyModel(pos_x, pos_y, orientations, resolution, xlim, ylim, default_ori_shift=90, learning_rate=0.01, counter_clockwise_rotation=True, multivariate=True, **config)


In [12]:
config.update(
    {
        # "ground_truth_positions_file_path": "data/antolik/position_dictionary.pickle",
        # "ground_truth_orientations_file_path": "data/antolik/oris.pickle",
        "ground_truth_positions_file_path": "data/antolik/positions_reparametrized.pickle",
        "ground_truth_orientations_file_path": "data/antolik/oris_reparametrized.pickle",
        "init_to_ground_truth_positions": False,
        "init_to_ground_truth_orientations": False,
        "freeze_positions": False,
        "freeze_orientations": False,
        "orientation_shift": 87.4,
        "factor": 5.5,
        "sample": False,
        "filtered_neurons":None,
    }
)

In [13]:


early_stopping_monitor="val/corr"
early_stopping_mode="max"
model_checkpoint_monitor="val/corr"
model_checkpoint_mode="max"

use_wandb = True


In [14]:
type(dm)

Antolik_dataset.AntolikDataModule

In [20]:
from models import reCNN_bottleneck_CyclicGauss3d_no_scaling

config["positions_minus_x"] = False
config["positions_minus_y"] = True
config["do_not_sample"] = True

model_artifact_name = None
needs_ground_truth = False
model_needs_dataloader = True
model_class = reCNN_bottleneck_CyclicGauss3d_no_scaling
if needs_ground_truth:
    pos_x, pos_y, orientations = dm.get_ground_truth(config["ground_truth_positions_file_path"], config["ground_truth_orientations_file_path"])
    resolution = (dm.get_input_shape()[1], dm.get_input_shape()[2])
    xlim = [-dm.get_stimulus_visual_angle()/2, dm.get_stimulus_visual_angle()/2]
    ylim = [-dm.get_stimulus_visual_angle()/2, dm.get_stimulus_visual_angle()/2]
    # model = model_class(pos_x, pos_y, orientations, resolution, xlim, ylim, **config)
elif model_needs_dataloader:
    model = model_class(dm, **config)
else:
    model = model_class(**config)



mame dataloader????
tak cobude :(((
nemame dataloader!!




In [21]:
config["core_gamma_hidden"] = 0.008931320307500908
config["bottleneck_kernel"] = 15
config["core_gamma_input"] = 0.2384005754453638
config["core_hidden_channels"] = 6
config["core_hidden_kern"] = 19
config["core_input_kern"] = 5
config["core_layers"] = 5
config["depth_separable"] = True
config["lr"] = 0.0005
config["num_rotations"] = 8
config["upsampling"] = 1

In [17]:
# trainer = pl.Trainer(
#     callbacks=[],
#     max_epochs=config["max_epochs"],
#     gpus=[0],
#     logger=False,
#     log_every_n_steps=100,
#     # deterministic=True,
#     enable_checkpointing=True,
#     # fast_dev_run=True,
#     # fast_dev_run=7
#     # limit_train_batches=1
# )

# trainer.fit(
#     model,
#     train_dataloaders=dm.train_dataloader(),
#     val_dataloaders=dm.val_dataloader(),
# )

In [22]:
from datetime import timedelta
import wandb
from Lurz_dataset import LurzDataModule

from models import reCNN_FullFactorized
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar
import pytorch_lightning as pl
from models import reCNN_bottleneck_CyclicGauss3d
from pprint import pprint
from Antolik_dataset import AntolikDataModule


pl.seed_everything(config["seed"], workers=True)    

# init wandb run
run = wandb.init(
    config=config,
    project=PROJECT,
    entity=ENTITY,
)

# Access all hyperparameter values through wandb.config
# config = dict(wandb.config)
# pprint(config)


# setup wandb logger
wandb_logger = WandbLogger(log_model=True)
wandb_logger.watch(model, log="parameters", log_freq=250)

# define callbacks for the training
# early_stop = EarlyStopping(
#     monitor=early_stopping_monitor,
#     patience=config["patience"],
#     mode=early_stopping_mode,
# )
checkpoint_callback = ModelCheckpoint(
    save_top_k=1, monitor=model_checkpoint_monitor, mode=model_checkpoint_mode
)

# class LitProgressBar(ProgressBar):
#     def get_metrics(self, trainer, model):
#         # don't show the version number
#         items = super().get_metrics(trainer, model)
#         items.pop("v_num", None)
#         return items

# bar = LitProgressBar()


# define the trainer
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    # max_epochs=config["max_epochs"],
    max_time=timedelta(hours=4),
    # max_epochs=1,
    gpus=[0],
    logger=wandb_logger,
    log_every_n_steps=250,
    # deterministic=True,
    enable_checkpointing=True,
)


trainer.fit(
    model,
    train_dataloaders=dm.train_dataloader(),
    val_dataloaders=dm.val_dataloader(),
    )



Global seed set to 2


[34m[1mwandb[0m: wandb version 0.13.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-78fdadac-cb84-e2a1-5ae2-1111d6ed43f1]

  | Name    | Type                                | Params
----------------------------------------------------------------
0 | loss    | PoissonLoss                         | 0     
1 | corr    | Corr                                | 0     
2 | core    | RotationEquivariant2dCoreBottleneck | 1.8 M 
3 | readout | Gaussian3dCyclicNoScale             | 35.0 K
4 | nonlin  | Softplus                            | 0     
----------------------------------------------------------------
259 K     Trainable params
1.6 M     Non-trainable params
1.8 M     Total params
7.335     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


                                                                      

Global seed set to 2
  rank_zero_warn(


Epoch 0:  16%|█▌        | 729/4500 [12:57<1:07:04,  1.07s/it, loss=-0.537, v_num=qnzf]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
import wandb
from Lurz_dataset import LurzDataModule

from models import reCNN_FullFactorized
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.progress import ProgressBar
import pytorch_lightning as pl
from models import reCNN_bottleneck_CyclicGauss3d
from pprint import pprint
from Antolik_dataset import AntolikDataModule

if use_wandb:

    pl.seed_everything(config["seed"], workers=True)    

    # init wandb run
    run = wandb.init(
        config=config,
        project=PROJECT,
        entity=ENTITY,
    )

    # Access all hyperparameter values through wandb.config
    config = dict(wandb.config)
    pprint(config)


    # setup wandb logger
    wandb_logger = WandbLogger(log_model=True)
    wandb_logger.watch(model, log="parameters", log_freq=250)

    # define callbacks for the training
    early_stop = EarlyStopping(
        monitor=early_stopping_monitor,
        patience=config["patience"],
        mode=early_stopping_mode,
    )
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1, monitor=model_checkpoint_monitor, mode=model_checkpoint_mode
    )

    class LitProgressBar(ProgressBar):
        def get_metrics(self, trainer, model):
            # don't show the version number
            items = super().get_metrics(trainer, model)
            items.pop("v_num", None)
            return items

    bar = LitProgressBar()


    # define the trainer
    trainer = pl.Trainer(
        callbacks=[early_stop, checkpoint_callback, bar],
        max_epochs=config["max_epochs"],
        # max_epochs=1,
        gpus=[0],
        logger=wandb_logger,
        log_every_n_steps=250,
        # deterministic=True,
        enable_checkpointing=True,
    )


    trainer.fit(
        model,
        train_dataloaders=dm.train_dataloader(),
        val_dataloaders=dm.val_dataloader(),
        )

    best_observed_val_metric = (
        checkpoint_callback.best_model_score.cpu().detach().numpy()
    )
    print(
        "Best model's "
        + config["observed_val_metric"]
        + ": "
        + str(best_observed_val_metric)
    )

    if model_artifact_name == None:
        model_artifact_name = model.__str__()
    
    print(model_artifact_name)
    print(model_artifact_name)

    # add best corr to metadata
    metadata = {**config, "best_model_score": best_observed_val_metric}

    # add model artifact
    best_model_artifact = wandb.Artifact(
        model_artifact_name, type="model", metadata=metadata
    )
    print(best_model_artifact)
    print(best_model_artifact)
    best_model_artifact.add_file(checkpoint_callback.best_model_path)
    run.log_artifact(best_model_artifact)

    # say to wandb that the best val/corr of the model is the best one
    # and not the last one!! (it is the default behavour!!)
    run.summary[config["observed_val_metric"]] = best_observed_val_metric

    print(checkpoint_callback.best_model_path)

    model = model_class.load_from_checkpoint(checkpoint_callback.best_model_path)

    if config["test"]:
        dm.model_performances(model, trainer)



        # result_artifact = wandb.Artifact(name="RESULT_" + model_artifact_name, type="result",
        #     metadata=results[0])
        # run.log_artifact(result_artifact)

else:
    pl.seed_everything(config["seed"], workers=True)
    pprint(config)


    # define callbacks for the training
    early_stop = EarlyStopping(
        monitor=early_stopping_monitor,
        patience=config["patience"],
        mode=early_stopping_mode,
    )
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1, monitor=model_checkpoint_monitor, mode=model_checkpoint_mode
    )

    class LitProgressBar(ProgressBar):
        def get_metrics(self, trainer, model):
            # don't show the version number
            items = super().get_metrics(trainer, model)
            items.pop("v_num", None)
            return items

    bar = LitProgressBar()

    # define the trainer
    trainer = pl.Trainer(
        callbacks=[early_stop, checkpoint_callback, bar],
        max_epochs=config["max_epochs"],
        gpus=[0],
        # logger=wandb_logger,
        log_every_n_steps=1,
        # deterministic=True,
        enable_checkpointing=True,
    )

    if config["train_on_val"]:
        trainer.fit(
            model,
            train_dataloaders=dm.val_dataloader(),
            val_dataloaders=dm.val_dataloader(),
        )

    else:
        trainer.fit(
            model,
            train_dataloaders=dm.train_dataloader(),
            val_dataloaders=dm.val_dataloader(),
        )

    best_observed_val_metric = (
        checkpoint_callback.best_model_score.cpu().detach().numpy()
    )
    print(
        "Best model's "
        + config["observed_val_metric"]
        + ": "
        + str(best_observed_val_metric)
    )

    # add best corr to metadata
    metadata = {**config, "best_model_score": best_observed_val_metric}

    print(checkpoint_callback.best_model_path)

    model = model_class.load_from_checkpoint(checkpoint_callback.best_model_path)

    if config["test"]:
        dm.model_performances(model, trainer)

        # result_artifact = wandb.Artifact(name="RESULT_" + model_artifact_name, type="result",
        #     metadata=results[0])
        # run.log_artifact(result_artifact)

Global seed set to 2
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcsng-cuni[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.13.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


{'batch_size': 10,
 'bottleneck_kernel': 15,
 'compute_oracle_fraction': False,
 'conservative_oracle': True,
 'core_gamma_hidden': 0.28463619129195233,
 'core_gamma_input': 0.00307424496692959,
 'core_hidden_channels': 8,
 'core_hidden_kern': 9,
 'core_input_kern': 7,
 'core_layers': 5,
 'dataset_artifact_name': 'Antolik_dataset:latest',
 'depth_separable': True,
 'do_not_sample': True,
 'factor': 5.5,
 'filtered_neurons': None,
 'fixed_sigma': False,
 'freeze_orientations': False,
 'freeze_positions': False,
 'generate_oracle_figure': False,
 'ground_truth_orientations_file_path': 'data/antolik/oris_reparametrized.pickle',
 'ground_truth_positions_file_path': 'data/antolik/positions_reparametrized.pickle',
 'init_mu_range': 0.9,
 'init_sigma_range': 0.8,
 'init_to_ground_truth_orientations': False,
 'init_to_ground_truth_positions': False,
 'input_channels': 1,
 'input_regularizer': 'LaplaceL2norm',
 'input_size_x': 110,
 'input_size_y': 110,
 'jackknife_oracle': True,
 'lr': 0.01,
 

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [GPU-fe783ca2-3d3f-1409-1ba8-3ae437d45687]

  | Name    | Type                                | Params
----------------------------------------------------------------
0 | loss    | PoissonLoss                         | 0     
1 | corr    | Corr                                | 0     
2 | core    | RotationEquivariant2dCoreBottleneck | 458 K 
3 | readout | Gaussian3dCyclicNoScale             | 35.0 K
4 | nonlin  | Softplus                            | 0     
----------------------------------------------------------------
135 K     Trainable params
358 K     Non-trainable params
493 K     Total params
1.973     Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


                                                                      

Global seed set to 2
  rank_zero_warn(


Epoch 0:  17%|█▋        | 83/500 [00:13<01:06,  6.26it/s, loss=-0.401] 

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


AttributeError: 'NoneType' object has no attribute 'cpu'