In [1]:
seed = 41

In [2]:
# !pip install -e ../../

In [3]:
# Import statements 
import torch
import numpy as np 
import deeplake
# !activeloop login -t eyJhbGciOiJIUzUxMiIsImlhdCI6MTY4NzAxNTQwOSwiZXhwIjoxNjkzNDk1MzgyfQ.eyJpZCI6Im5jaW1hc3pld3NraSJ9.WOnJ5R0wd8PDCRetum-H_Zfx44YrOVaRTfCCYulob1dwOsKZR2q9oREuN5QzKE3-J3qKuQNrc-TIFgITQ448lg
import os
os.environ["DEEPLAKE_DOWNLOAD_PATH"] = "/gpfs01/berens/data/data/dLGN_hmov2/deeplake/"
# del os.environ["DEEPLAKE_DOWNLOAD_PATH"] 
import matplotlib.pyplot as plt

from nnfabrik.utility.nn_helpers import set_random_seed
set_random_seed(seed)
from nnfabrik.builder import get_trainer # get_trainer basically just used to 

from sensorium.utility.scores import get_correlations # was using this for manual validation scoring
from sensorium.models.make_model import make_video_model # still using this, used to knit together video encoder model

from torchvision.transforms import Compose # needed to compose transforms together
from neuralpredictors.data.transforms import (AddBehaviorAsChannels,
                                              AddPupilCenterAsChannels,
                                              ChangeChannelsOrder, CutVideos,
                                              ExpandChannels, NeuroNormalizer,
                                              ScaleInputs, SelectInputChannel,
                                              Subsample, Subsequence, ToTensor) # this is for creating preprocessing pipelines

from dlgn_cnn.dataloading import deeplake_loader_dict
from dlgn_cnn.dataloading.deeplake_transforms import *



# Load Data

To allow for comparison to methods on V1 data, data is stored in local Deeplake datasets, which provide integration of pytorch dataloaders.  Dataloaders are necessary to construct models, and take in torchvision (and neuralpredictors/custom) transform chains to determine preprocessing.

### Question: How does preprocessing of data affect training of model?

We start with a minimal model, aiming to predict response tensor from video tensor.  We compare 4 conditions:

- Normalize video (z score) and standardize response (divide by std)
- Normalize video only
- Standardize response only
- No scaling of values

In all cases, we subsample the length of the sequences used as training data to consist of 60 frames (1 sec).  In principle, this should be aligned to the optogenetic trace.  We introduce a deeplake transform SubsequenceByOpto.

In [4]:
ds = deeplake.load('/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/',access_method='local')
ds.summary()

** Loaded local copy of dataset from /gpfs01/berens/data/data/dLGN_hmov2/deeplake/_gpfs01_berens_data_data_dLGN_hmov2_deeplake_dlgn_53x33_. Downloaded on: Fri Aug 25 20:10:39 2023
Dataset(path='/gpfs01/berens/data/data/dLGN_hmov2/deeplake/_gpfs01_berens_data_data_dLGN_hmov2_deeplake_dlgn_53x33_', tensors=['id', 'opto', 'responses', 'videos'])

  tensor     htype           shape           dtype  compression
  -------   -------         -------         -------  ------- 
    id       text          (288, 1)           str     None   
   opto     generic     (288, 122, 300)     float32   None   
 responses  generic     (288, 122, 300)     float32   None   
  videos    generic  (288, 1, 300, 33, 53)  float32   None   




In [6]:
ds.min_len

288

In [7]:
frames_per_sample = 60

# 4 lambda functions returning dictionary of compose torch transforms, keyed by data tensor names

norm_vid_stand_resp = lambda ds: {
        'videos': Compose(
            [
                NormalizeVideo(deeplake_dataset=ds),
                Tupelize('videos'), # necessary to input to neuralpredictors transform
                Subsequence(frames_per_sample, channel_first=['videos']),
                Detupelize('videos') # necessary to be return by dataloader correctly
            ]
        ), 
        'responses': Compose(
            [
                StandardizeResponse(deeplake_dataset=ds),
                Tupelize('responses'), # necessary to input to neuralpredictors transforms 
                Subsequence(frames_per_sample, channel_first=['videos']),
                Detupelize('responses'), # necessary to be return by dataloader correctly
            ]
        ),
        'id': None,
        } 

norm_vid = lambda ds: {
        'videos': Compose(
            [
                NormalizeVideo(deeplake_dataset=ds),
                Tupelize('videos'), # necessary to input to neuralpredictors transform
                Subsequence(frames_per_sample, channel_first=['videos']),
                # SubsequenceByOpto(deeplake_dataset=ds, frames_per_sample, channel_first=['videos']),
                Detupelize('videos') # necessary to be return by dataloader correctly
            ]
        ), 
        'responses': Compose(
            [
                Tupelize('responses'), # necessary to input to neuralpredictors transforms 
                # SubsequenceByOpto(deeplake_dataset=ds, frames_per_sample, channel_first=['videos']),
                Subsequence(frames_per_sample, channel_first=['videos']),
                Detupelize('responses'), # necessary to be return by dataloader correctly
            ]
        ),
        'id': None,
        }

stand_resp = lambda ds: {
        'videos': Compose(
            [
                Tupelize('videos'), # necessary to input to neuralpredictors transform
                Subsequence(frames_per_sample, channel_first=['videos']),
                Detupelize('videos') # necessary to be return by dataloader correctly
            ]
        ), 
        'responses': Compose(
            [
                StandardizeResponse(deeplake_dataset=ds),
                Tupelize('responses'), # necessary to input to neuralpredictors transforms 
                Subsequence(frames_per_sample, channel_first=['videos']),
                Detupelize('responses'), # necessary to be return by dataloader correctly
            ]
        ),
        'id': None,
        }

subseq_only = lambda ds: {
        'videos': Compose(
            [
                Tupelize('videos'), # necessary to input to neuralpredictors transform
                Subsequence(frames_per_sample, channel_first=['videos']),
                Detupelize('videos') # necessary to be return by dataloader correctly
            ]
        ), 
        'responses': Compose(
            [
                Tupelize('responses'), # necessary to input to neuralpredictors transforms 
                Subsequence(frames_per_sample, channel_first=['videos']),
                Detupelize('responses'), # necessary to be return by dataloader correctly
            ]
        ),
        'id': None,
        }


In [8]:
# paths for dataset
local_paths = [
    '/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/'
]

user_token = 'eyJhbGciOiJIUzUxMiIsImlhdCI6MTY4NzAxNTQwOSwiZXhwIjoxNjkzNDk1MzgyfQ.eyJpZCI6Im5jaW1hc3pld3NraSJ9.WOnJ5R0wd8PDCRetum-H_Zfx44YrOVaRTfCCYulob1dwOsKZR2q9oREuN5QzKE3-J3qKuQNrc-TIFgITQ448lg'
org_id = 'sinzlab'

print("Loading data..")

loaders_1 = deeplake_loader_dict(
    paths=local_paths,
    preproc = norm_vid_stand_resp,
    batch_size=4,
    scale=1,
    max_frame=None,
    frames=60, # @ 60 Hz = 1 sec, so as not to 
    offset=-1,
    include_behavior=False,
    include_pupil_centers=False,
    cuda='cuda:0',
    use_api=False,
    user_token=user_token,
    org_id=org_id
)

loaders_2 = deeplake_loader_dict(
    paths=local_paths,
    preproc = norm_vid,
    batch_size=4,
    scale=1,
    max_frame=None,
    frames=60, # @ 60 Hz = 1 sec, so as not to 
    offset=-1,
    include_behavior=False,
    include_pupil_centers=False,
    cuda='cuda:0',
    use_api=False,
    user_token=user_token,
    org_id=org_id
)

loaders_3 = deeplake_loader_dict(
    paths=local_paths,
    preproc = stand_resp,
    batch_size=4,
    scale=1,
    max_frame=None,
    frames=60, # @ 60 Hz = 1 sec, so as not to 
    offset=-1,
    include_behavior=False,
    include_pupil_centers=False,
    cuda='cuda:0',
    use_api=False,
    user_token=user_token,
    org_id=org_id
)

loaders_4 = deeplake_loader_dict(
    paths=local_paths,
    preproc = subseq_only,
    batch_size=4,
    scale=1,
    max_frame=None,
    frames=60, # @ 60 Hz = 1 sec, so as not to 
    offset=-1,
    include_behavior=False,
    include_pupil_centers=False,
    cuda='cuda:0',
    use_api=False,
    user_token=user_token,
    org_id=org_id
)

Loading data..
/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/ loaded successfully.





/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/ loaded successfully.





/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/ loaded successfully.





/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/ loaded successfully.





In [11]:
loaders_1['train']['/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/'].dataset.summary() #80% of 288 is 231

Dataset(path='/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/', index=Index([(165, 167, 166, 158, 157, 156, 227, 225, 226, 155, 153, 154, 184, 183, 185, 211, 212, 210, 150, 152, 151, 283, 284, 282, 10, 11, 9, 168, 169, 170, 79, 78, 80, 256, 257, 255, 70, 71, 69, 267, 269, 268, 195, 196, 197, 241, 242, 240, 37, 36, 38, 107, 105, 106, 194, 193, 192)]), tensors=['id', 'opto', 'responses', 'videos'])

  tensor     htype          shape           dtype  compression
  -------   -------        -------         -------  ------- 
    id       text          (57, 1)           str     None   
   opto     generic     (57, 122, 300)     float32   None   
 responses  generic     (57, 122, 300)     float32   None   
  videos    generic  (57, 1, 300, 33, 53)  float32   None   


In [30]:
rng = np.random.default_rng()

In [5]:
# Plotting predictions vs. observations
val_ds =  data_loaders['oracle']['/gpfs01/berens/data/data/dLGN_hmov2/deeplake/dlgn_53x33/'].dataset


    
# Let's test this!  Need to generate model predictions
# neurs = [30,40,45]
# scenes = [0,4,6,7]
# fake_pred = np.zeros_like(val_ds.responses)
# plot_obs_pred_psth(scenes, neurs, val_ds, fake_pred)

In [6]:
# core hyper_param ranges
# hidden_channels = [8, 16, 32]
input_kern_sizes = [3,5,7]
hidden_kern_sizes = [3,5,7]
depths = [2,3,4]
# GRU hyper_param ranges
# rec_channels_gru = [8, 16, 32]
input_kern_sizes_gru = [3,5,7]
hidden_kern_sizes_gru = [3,5,7]


In [13]:
make_core_dict = lambda hc, ik, hk, d, gamm_in, gamm_hidd: dict(
                    input_channels=1,
                    hidden_channels = hc,
                    input_kern = ik,
                    hidden_kern = hk,
                    layers = d,
                    gamma_input=gamm_in,
                    skip=0,
                    pad_input=False,
                    final_nonlinearity=False,
                    bias=True,
                    momentum=0.9,
                    batch_norm=True,
                    hidden_dilation=1,
                    laplace_padding=None,
                    input_regularizer="LaplaceL2norm",
                    stack=-1,
                    depth_separable=False,
                    linear=False,
                    attention_conv=False,
                    hidden_padding=None,
                    use_avg_reg=False,
                    final_batchnorm_scale=True,
                    gamma_hidden=gamm_hidd,
                )

make_gru_dict = lambda ic, rc, ik, rk: dict(
    input_channels = ic, # must be the last hidden channels from the core_dict
    rec_channels = rc, # determines the input channels to the readouts
    input_kern = ik,
    rec_kern = rk,
    gamma_rec=0,
)

In [24]:
# i will try to emulate the conv GRU model from the paper above, so downsizing:
core_dict = dict(
    input_channels=1,
    hidden_channels=32, # 64 -> 32
    input_kern=7, # 9 -> 7
    hidden_kern=3, # 7 -> 3
    layers=3, # 4 -> 3
    gamma_input=500,
    skip=0,
    pad_input=False,
    final_nonlinearity=False,
    bias=True,
    momentum=0.9,
    batch_norm=True,
    hidden_dilation=1,
    laplace_padding=None,
    input_regularizer="LaplaceL2norm",
    stack=-1,
    depth_separable=False,
    linear=False,
    attention_conv=False,
    hidden_padding=None,
    use_avg_reg=False,
    final_batchnorm_scale=True,
    gamma_hidden=500_000,
)

In [15]:
shifter_dict = dict(
    gamma_shifter=0,
    shift_layers=3,
    input_channels_shifter=2,
    hidden_channels_shifter=5,
)


readout_dict = dict(
    bias=True,
    init_mu_range=0.2,
    init_sigma=1.0,
    gamma_readout=0.0,
    gauss_type='full',
    # grid_mean_predictor={
    #     'type': 'cortex',
    #     'input_dimensions': 2,
    #     'hidden_layers': 1,
    #     'hidden_features': 30,
    #     'final_tanh': True
    # },
    grid_mean_predictor=None, # readout grid will not be predicted from some source_grid (cortical coordinates) but rather the mean will be treated as a parameter to be optimized directly
    share_features=False,
    share_grid=False,
    shared_match_ids=None,
    gamma_grid_dispersion=0.0,
)


In [9]:
# import wandb
# wandb.login()
# wandb API key:
# 2da1f9973215ee744c07a4bf6e31e168e10c02e0

### Training

In [12]:
trainer_fn = "sensorium.training.video_training_loop.standard_trainer"

make_trainer_dict = lambda hash: { # hash will be a code organizing where to save model dicts based on hyperparams
    'dataloaders' : data_loaders,
    'seed' : 111,
    'use_wandb' : True,
    'wandb_project': 'neuralpredictors for dLGN Sys ID',
    'wandb_entity': 'nwcimaszewski',
    'verbose': True,
    'lr_decay_steps': 10,
    'lr_init': 0.005, # increase from 0.005
    'device' : f"cuda",
    'detach_core' : False,
    'checkpoint_save_path' : '/gpfs01/berens/user/ncimaszewski/my-docker-folder/ncimaszewski/dl_for_sensorium/sensorium_2023/state_dicts/partial_grid_search/',
    'checkpoint_save_prefix': hash,
    'patience' : 10 # from 5 to 20 to 15 to 10.  This is number of epochs it will allow 
                 }

# Training models on different dataloaders

In [18]:
ik = 7 # kernel size on input
hk = 5 # on hidden layers
ikg = 7 # on input for gru
hkg = 5 # on hidden for gru
depth = 3
for i, data_loaders in enumerate([loaders_1, loaders_2, loaders_3, loaders_4]):
    core_dict = make_core_dict(32, ik, hk, 3, gamm_in=0, gamm_hidd=0) # no regularization
    gru_dict = make_gru_dict(32, 32, ikg, hkg) # input channels always same as hc
    
    # make full model
    gru_2d_model = make_video_model(data_loaders,
                                    seed,
                                    core_dict=core_dict.copy(),
                                    core_type='2D',
                                    readout_dict=readout_dict.copy(),
                                    readout_type='gaussian',
                                    use_gru=True,
                                    gru_dict=gru_dict.copy(),
                                    shifter_dict=shifter_dict,
                                    shifter_type='MLP',
                                    use_shifter=False,
                                    from_deeplake=True
                                   )
    hash = f'ik{ik}_hk{hk}_ikg{ikg}_hkg{hkg}_gi0_gh0_loader{i}'
    trainer_config = make_trainer_dict(hash)
    trainer = get_trainer(trainer_fn=trainer_fn, 
     trainer_config=trainer_config) # just calls nnfabrik's resolve_fn function, which just identifies the callable object indicated by path string, and feeds in config dict using partial(,**)
    
    # train
    validation_score, trainer_output, state_dict = trainer(gru_2d_model)

optim_step_count = 1


[34m[1mwandb[0m: Currently logged in as: [33mnwcimaszewski[0m. Use [1m`wandb login --relogin`[0m to force relogin


  transform = lambda x: (x - self._inputs_mean) / self._inputs_std


AssertionError: model prediction is too short (60 vs 280)

In [None]:
%debug

> [0;32m/gpfs01/berens/user/ncimaszewski/my-docker-folder/ncimaszewski/dl_for_sensorium/sensorium_2023/sensorium/utility/scores.py[0m(53)[0;36mmodel_predictions[0;34m()[0m
[0;32m     51 [0;31m                    [0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;34m-[0m[0mresp[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m [0;34m:[0m[0;34m,[0m [0;34m:[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m                )
[0m[0;32m---> 53 [0;31m                assert (
[0m[0;32m     54 [0;31m                    [0mout[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m1[0m[0;34m][0m [0;34m==[0m [0mresp[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m                ), f"model prediction is too short ({out.shape[1]} vs {resp.shape[-1]})"
[0m


ipdb>  out.shape


torch.Size([4, 60, 60])


ipdb>  resp.shape


torch.Size([4, 60, 280])


ipdb>  u


> [0;32m/gpfs01/berens/user/ncimaszewski/my-docker-folder/ncimaszewski/dl_for_sensorium/sensorium_2023/sensorium/utility/scores.py[0m(87)[0;36mget_correlations[0;34m()[0m
[0;32m     85 [0;31m    [0;31m# print(device, 'score.get_correlations')[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     86 [0;31m    [0;32mfor[0m [0mk[0m[0;34m,[0m [0mv[0m [0;32min[0m [0mdl[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 87 [0;31m        target, output = model_predictions(
[0m[0;32m     88 [0;31m            [0mdataloader[0m[0;34m=[0m[0mv[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     89 [0;31m            [0mmodel[0m[0;34m=[0m[0mmodel[0m[0;34m,[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  target.shape


*** NameError: name 'target' is not defined


ipdb>  d


> [0;32m/gpfs01/berens/user/ncimaszewski/my-docker-folder/ncimaszewski/dl_for_sensorium/sensorium_2023/sensorium/utility/scores.py[0m(53)[0;36mmodel_predictions[0;34m()[0m
[0;32m     51 [0;31m                    [0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0;34m-[0m[0mresp[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m [0;34m:[0m[0;34m,[0m [0;34m:[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     52 [0;31m                )
[0m[0;32m---> 53 [0;31m                assert (
[0m[0;32m     54 [0;31m                    [0mout[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m1[0m[0;34m][0m [0;34m==[0m [0mresp[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     55 [0;31m                ), f"model prediction is too short ({out.shape[1]} vs {resp.shape[-1]})"
[0m


ipdb>  out


tensor([[[1.5291, 1.5403, 1.5159,  ..., 1.3960, 1.8160, 1.7571],
         [1.5473, 1.5417, 1.5160,  ..., 1.4157, 1.8287, 1.7740],
         [1.5615, 1.5452, 1.5195,  ..., 1.4295, 1.8422, 1.7889],
         ...,
         [1.5497, 1.5509, 1.5317,  ..., 1.4365, 1.8169, 1.7528],
         [1.5463, 1.5403, 1.5363,  ..., 1.4388, 1.8143, 1.7503],
         [1.5478, 1.5259, 1.5327,  ..., 1.4428, 1.8135, 1.7564]],

        [[1.5216, 1.5636, 1.5105,  ..., 1.3770, 1.8197, 1.7517],
         [1.5292, 1.5745, 1.5135,  ..., 1.3811, 1.8318, 1.7618],
         [1.5335, 1.5811, 1.5144,  ..., 1.3830, 1.8389, 1.7678],
         ...,
         [1.5645, 1.6055, 1.5329,  ..., 1.4095, 1.8509, 1.7859],
         [1.5645, 1.6054, 1.5334,  ..., 1.4099, 1.8513, 1.7858],
         [1.5638, 1.6053, 1.5338,  ..., 1.4097, 1.8516, 1.7858]],

        [[1.5166, 1.5587, 1.5122,  ..., 1.3660, 1.7947, 1.7327],
         [1.5206, 1.5715, 1.5199,  ..., 1.3680, 1.7978, 1.7339],
         [1.5232, 1.5820, 1.5262,  ..., 1.3702, 1.8004, 1.

ipdb>  out.shape


torch.Size([4, 60, 60])


In [15]:
# partial GRID SEARCH
# core hyper_param ranges
# hidden_channels = [8, 16, 32]
input_kern_sizes = [7]
hidden_kern_sizes = [7]
depths = [2,3,4]
# GRU hyper_param ranges
# rec_channels_gru = [8, 16, 32]
input_kern_sizes_gru = [7]
hidden_kern_sizes_gru = [7]

# gammas_in = [100, 500]
# gammas_hidd = [5000, 50000]

gammas_in = [0]
gammas_hidd = [0]

for ik in input_kern_sizes: # 4 * 
    for hk in hidden_kern_sizes: # 4 *
        for ikg in input_kern_sizes_gru: #4 *
            for hkg in hidden_kern_sizes_gru:
                for gi in gammas_in:
                    for gh in gammas_hidd:
                        core_dict = make_core_dict(32, ik, hk, 3, gamm_in=gi, gamm_hidd=gh)
                        gru_dict = make_gru_dict(32, 32, ikg, hkg) # input channels always same as hc
        
                        # make full model
                        gru_2d_model = make_video_model(data_loaders,
                                                        seed,
                                                        core_dict=core_dict.copy(),
                                                        core_type='2D',
                                                        readout_dict=readout_dict.copy(),
                                                        readout_type='gaussian',
                                                        use_gru=True,
                                                        gru_dict=gru_dict.copy(),
                                                        shifter_dict=shifter_dict,
                                                        shifter_type='MLP',
                                                        use_shifter=False,
                                                        from_deeplake=True
                                                       )
                        hash = f'ik{ik}_hk{hk}_ikg{ikg}_hkg{hkg}_gi{gi}_gh{gh}'
                        trainer_config = make_trainer_dict(hash)
                        trainer = get_trainer(trainer_fn=trainer_fn, 
                         trainer_config=trainer_config) # just calls nnfabrik's resolve_fn function, which just identifies the callable object indicated by path string, and feeds in config dict using partial(,**)
                        
                        # train
                        validation_score, trainer_output, state_dict = trainer(gru_2d_model)

  xavier_normal(m.weight.data)
  init.constant(m.bias.data, 0.0)


optim_step_count = 1


  transform = lambda x: (x - self._inputs_mean) / self._inputs_std
Epoch 1: 100%|██████████| 29/29 [00:02<00:00, 12.23it/s]


Epoch 1, Batch 28, Train loss -3044064.0, Validation loss -1512640.625
EPOCH=1  validation_correlation=0.008808107115328312


Epoch 2: 100%|██████████| 29/29 [00:02<00:00, 12.29it/s]


Epoch 2, Batch 28, Train loss -3210791.75, Validation loss -1594852.625
EPOCH=2  validation_correlation=0.01962447538971901


Epoch 3: 100%|██████████| 29/29 [00:02<00:00, 12.09it/s]


Epoch 3, Batch 28, Train loss -3092348.75, Validation loss -1536211.75
EPOCH=3  validation_correlation=0.03594963625073433


Epoch 4: 100%|██████████| 29/29 [00:02<00:00, 12.18it/s]


Epoch 4, Batch 28, Train loss -3343928.75, Validation loss -1662413.125
EPOCH=4  validation_correlation=0.031828779727220535


Epoch 5: 100%|██████████| 29/29 [00:02<00:00, 12.30it/s]


Epoch 5, Batch 28, Train loss -2811462.75, Validation loss -1396373.5
EPOCH=5  validation_correlation=0.04337739199399948


Epoch 6: 100%|██████████| 29/29 [00:02<00:00, 12.10it/s]


Epoch 6, Batch 28, Train loss -3257239.75, Validation loss -1618342.0
EPOCH=6  validation_correlation=0.05243707075715065


Epoch 7: 100%|██████████| 29/29 [00:02<00:00, 12.35it/s]


Epoch 7, Batch 28, Train loss -3085271.0, Validation loss -1531985.0
EPOCH=7  validation_correlation=0.02249753102660179


Epoch 8: 100%|██████████| 29/29 [00:02<00:00, 12.05it/s]


Epoch 8, Batch 28, Train loss -3253544.5, Validation loss -1616982.25
EPOCH=8  validation_correlation=0.03376353904604912


Epoch 9: 100%|██████████| 29/29 [00:02<00:00, 11.93it/s]


Epoch 9, Batch 28, Train loss -2870209.5, Validation loss -1427568.0
EPOCH=9  validation_correlation=0.025380603969097137


Epoch 10: 100%|██████████| 29/29 [00:02<00:00, 12.17it/s]


Epoch 10, Batch 28, Train loss -3017347.5, Validation loss -1498947.875
EPOCH=10  validation_correlation=0.04662341624498367


Epoch 11: 100%|██████████| 29/29 [00:02<00:00, 11.88it/s]


Epoch 11, Batch 28, Train loss -2877307.75, Validation loss -1429251.625
EPOCH=11  validation_correlation=0.04269106313586235


Epoch 12: 100%|██████████| 29/29 [00:02<00:00, 12.34it/s]


Epoch 12, Batch 28, Train loss -2902257.0, Validation loss -1443786.875
EPOCH=12  validation_correlation=0.03886541724205017


Epoch 13: 100%|██████████| 29/29 [00:02<00:00, 12.27it/s]


Epoch 13, Batch 28, Train loss -2826824.25, Validation loss -1403337.125
EPOCH=13  validation_correlation=0.05335870012640953


Epoch 14: 100%|██████████| 29/29 [00:02<00:00, 12.52it/s]


Epoch 14, Batch 28, Train loss -3127803.75, Validation loss -1554732.875
EPOCH=14  validation_correlation=0.02662484720349312


Epoch 15: 100%|██████████| 29/29 [00:02<00:00, 12.36it/s]


Epoch 15, Batch 28, Train loss -3338384.5, Validation loss -1658915.25
EPOCH=15  validation_correlation=0.0388505682349205
Epoch 00015: reducing learning rate of group 0 to 1.5000e-03.


Epoch 16: 100%|██████████| 29/29 [00:02<00:00, 12.21it/s]


Epoch 16, Batch 28, Train loss -3140321.0, Validation loss -1559852.0
EPOCH=16  validation_correlation=0.036591194570064545


Epoch 17: 100%|██████████| 29/29 [00:02<00:00, 12.20it/s]


Epoch 17, Batch 28, Train loss -3180350.75, Validation loss -1580313.625
EPOCH=17  validation_correlation=0.01849479228258133


Epoch 18: 100%|██████████| 29/29 [00:02<00:00, 12.27it/s]


Epoch 18, Batch 28, Train loss -3114967.0, Validation loss -1547323.5
EPOCH=18  validation_correlation=0.04276779294013977


Epoch 19: 100%|██████████| 29/29 [00:02<00:00, 12.26it/s]


Epoch 19, Batch 28, Train loss -2970227.5, Validation loss -1475239.75
EPOCH=19  validation_correlation=0.0428583137691021


Epoch 20: 100%|██████████| 29/29 [00:02<00:00, 12.01it/s]


Epoch 20, Batch 28, Train loss -3230021.0, Validation loss -1604081.25
EPOCH=20  validation_correlation=0.042417123913764954


Epoch 21: 100%|██████████| 29/29 [00:02<00:00, 12.27it/s]


Epoch 21, Batch 28, Train loss -2885177.75, Validation loss -1432795.75
EPOCH=21  validation_correlation=0.05652221292257309


Epoch 22: 100%|██████████| 29/29 [00:02<00:00, 12.32it/s]


Epoch 22, Batch 28, Train loss -3376944.25, Validation loss -1677310.625
EPOCH=22  validation_correlation=0.03637035936117172


Epoch 23: 100%|██████████| 29/29 [00:02<00:00, 12.28it/s]


Epoch 23, Batch 28, Train loss -2859561.75, Validation loss -1420891.625
EPOCH=23  validation_correlation=0.036441802978515625


Epoch 24: 100%|██████████| 29/29 [00:02<00:00, 12.49it/s]


Epoch 24, Batch 28, Train loss -3437706.25, Validation loss -1708107.75
EPOCH=24  validation_correlation=0.04175708070397377


Epoch 25: 100%|██████████| 29/29 [00:02<00:00, 12.37it/s]


Epoch 25, Batch 28, Train loss -3050816.0, Validation loss -1513748.875
EPOCH=25  validation_correlation=0.03841095790266991


Epoch 26: 100%|██████████| 29/29 [00:02<00:00, 12.20it/s]


Epoch 26, Batch 28, Train loss -3122812.0, Validation loss -1550765.125
EPOCH=26  validation_correlation=0.025681883096694946
Epoch 00026: reducing learning rate of group 0 to 4.5000e-04.


Epoch 27: 100%|██████████| 29/29 [00:02<00:00, 12.10it/s]


Epoch 27, Batch 28, Train loss -3389281.5, Validation loss -1683545.375
EPOCH=27  validation_correlation=0.02927527390420437


Epoch 28: 100%|██████████| 29/29 [00:02<00:00, 12.14it/s]


Epoch 28, Batch 28, Train loss -3094254.25, Validation loss -1536211.875
EPOCH=28  validation_correlation=0.04775271937251091


Epoch 29: 100%|██████████| 29/29 [00:02<00:00, 12.02it/s]


Epoch 29, Batch 28, Train loss -3250261.5, Validation loss -1614681.125
EPOCH=29  validation_correlation=0.05569639429450035


Epoch 30: 100%|██████████| 29/29 [00:02<00:00, 11.26it/s]


Epoch 30, Batch 28, Train loss -2894189.75, Validation loss -1438153.125
EPOCH=30  validation_correlation=0.026785975322127342


Epoch 31: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 31, Batch 28, Train loss -3092906.5, Validation loss -1536771.625
EPOCH=31  validation_correlation=0.03263973072171211


Epoch 32: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 32, Batch 28, Train loss -2948677.25, Validation loss -1464750.875
EPOCH=32  validation_correlation=0.039328522980213165


Epoch 33: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 33, Batch 28, Train loss -3211344.75, Validation loss -1595071.875
EPOCH=33  validation_correlation=0.04527928680181503


Epoch 34: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 34, Batch 28, Train loss -3199791.25, Validation loss -1589487.0
EPOCH=34  validation_correlation=0.01916675642132759


Epoch 35: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 35, Batch 28, Train loss -2956297.0, Validation loss -1467966.625
EPOCH=35  validation_correlation=0.02425498329102993


Epoch 36: 100%|██████████| 29/29 [00:04<00:00,  6.33it/s]


Epoch 36, Batch 28, Train loss -2763345.75, Validation loss -1371156.375
EPOCH=36  validation_correlation=0.033394038677215576


Epoch 37: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 37, Batch 28, Train loss -3048525.0, Validation loss -1514442.5
EPOCH=37  validation_correlation=0.02662312611937523
Epoch 00037: reducing learning rate of group 0 to 1.3500e-04.


Epoch 38: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 38, Batch 28, Train loss -3448114.25, Validation loss -1713943.375
EPOCH=38  validation_correlation=0.047305576503276825


Epoch 39: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 39, Batch 28, Train loss -3028672.75, Validation loss -1505140.25
EPOCH=39  validation_correlation=0.04390624538064003


Epoch 40: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 40, Batch 28, Train loss -2823775.5, Validation loss -1402256.5
EPOCH=40  validation_correlation=0.04237544909119606


Epoch 41: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 41, Batch 28, Train loss -2991563.0, Validation loss -1485120.625
EPOCH=41  validation_correlation=0.03403312712907791


Epoch 42: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 42, Batch 28, Train loss -2931841.0, Validation loss -1457008.625
EPOCH=42  validation_correlation=0.032695066183805466


Epoch 43: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 43, Batch 28, Train loss -3182621.75, Validation loss -1582086.25
EPOCH=43  validation_correlation=0.03475270792841911


Epoch 44: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 44, Batch 28, Train loss -3541676.5, Validation loss -1760919.5
EPOCH=44  validation_correlation=0.031675852835178375


Epoch 45: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 45, Batch 28, Train loss -2921083.0, Validation loss -1451099.0
EPOCH=45  validation_correlation=0.01824856922030449


Epoch 46: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 46, Batch 28, Train loss -3142729.25, Validation loss -1561455.75
EPOCH=46  validation_correlation=0.025677485391497612


Epoch 47: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 47, Batch 28, Train loss -3232164.0, Validation loss -1605556.625
EPOCH=47  validation_correlation=0.0222874004393816


Epoch 48: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 48, Batch 28, Train loss -3168244.25, Validation loss -1573290.375
EPOCH=48  validation_correlation=0.028540775179862976
Epoch 00048: reducing learning rate of group 0 to 1.0000e-04.


Epoch 49: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 49, Batch 28, Train loss -2933458.5, Validation loss -1457000.125
EPOCH=49  validation_correlation=0.03746483474969864


Epoch 50: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 50, Batch 28, Train loss -3061574.75, Validation loss -1521404.0
EPOCH=50  validation_correlation=0.014930603094398975


Epoch 51: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 51, Batch 28, Train loss -3026483.75, Validation loss -1502895.0
EPOCH=51  validation_correlation=0.039563123136758804


Epoch 52: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 52, Batch 28, Train loss -3068256.25, Validation loss -1524225.875
EPOCH=52  validation_correlation=0.035954419523477554


Epoch 53: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 53, Batch 28, Train loss -3147438.75, Validation loss -1563356.0
EPOCH=53  validation_correlation=0.030214203521609306


Epoch 54: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 54, Batch 28, Train loss -3555701.0, Validation loss -1765507.875
EPOCH=54  validation_correlation=0.040519606322050095


Epoch 55: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 55, Batch 28, Train loss -3303584.5, Validation loss -1641575.375
EPOCH=55  validation_correlation=0.050430748611688614


Epoch 56: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 56, Batch 28, Train loss -3098847.0, Validation loss -1539161.75
EPOCH=56  validation_correlation=0.023500891402363777


Epoch 57: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 57, Batch 28, Train loss -3153553.0, Validation loss -1566108.25
EPOCH=57  validation_correlation=0.038743067532777786


Epoch 58: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 58, Batch 28, Train loss -2949076.0, Validation loss -1465129.5
EPOCH=58  validation_correlation=0.02767900750041008


Epoch 59: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 59, Batch 28, Train loss -2992764.25, Validation loss -1486467.125
EPOCH=59  validation_correlation=0.030466726049780846


Epoch 60: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 60, Batch 28, Train loss -3417065.5, Validation loss -1697830.625
EPOCH=60  validation_correlation=0.017960738390684128


Epoch 61: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 61, Batch 28, Train loss -3181587.25, Validation loss -1579938.0
EPOCH=61  validation_correlation=0.0504816398024559


Epoch 62: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 62, Batch 28, Train loss -2996190.5, Validation loss -1488393.375
EPOCH=62  validation_correlation=0.030433228239417076


Epoch 63: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 63, Batch 28, Train loss -2811869.75, Validation loss -1395914.375
EPOCH=63  validation_correlation=0.029975982382893562


Epoch 64: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 64, Batch 28, Train loss -2780306.25, Validation loss -1380478.5
EPOCH=64  validation_correlation=0.04155348613858223


Epoch 65: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 65, Batch 28, Train loss -2966024.5, Validation loss -1473172.75
EPOCH=65  validation_correlation=0.03333694487810135


Epoch 66: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 66, Batch 28, Train loss -3156216.25, Validation loss -1567588.25
EPOCH=66  validation_correlation=0.02635342627763748


Epoch 67: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 67, Batch 28, Train loss -2827353.75, Validation loss -1405374.75
EPOCH=67  validation_correlation=0.036450523883104324


Epoch 68: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 68, Batch 28, Train loss -3002681.0, Validation loss -1492063.5
EPOCH=68  validation_correlation=0.014659143052995205


Epoch 69: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 69, Batch 28, Train loss -3240902.25, Validation loss -1608973.125
EPOCH=69  validation_correlation=0.03248017653822899


Epoch 70: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 70, Batch 28, Train loss -3185794.5, Validation loss -1582790.0
EPOCH=70  validation_correlation=0.020398210734128952


Epoch 71: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 71, Batch 28, Train loss -3056599.0, Validation loss -1518737.875
EPOCH=71  validation_correlation=0.03298470005393028


Epoch 72: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 72, Batch 28, Train loss -3223293.75, Validation loss -1600569.625
EPOCH=72  validation_correlation=0.032263148576021194


Epoch 73: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 73, Batch 28, Train loss -3098748.5, Validation loss -1539061.375
EPOCH=73  validation_correlation=0.031812816858291626


Epoch 74: 100%|██████████| 29/29 [00:04<00:00,  6.33it/s]


Epoch 74, Batch 28, Train loss -2816709.25, Validation loss -1398896.125
EPOCH=74  validation_correlation=0.03766736388206482


Epoch 75: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


Epoch 75, Batch 28, Train loss -3158151.0, Validation loss -1567696.625
EPOCH=75  validation_correlation=0.035313840955495834


Epoch 76: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 76, Batch 28, Train loss -2993480.5, Validation loss -1486586.0
EPOCH=76  validation_correlation=0.01822786033153534


Epoch 77: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 77, Batch 28, Train loss -2896496.5, Validation loss -1439504.875
EPOCH=77  validation_correlation=0.022129526361823082


Epoch 78: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 78, Batch 28, Train loss -3030134.0, Validation loss -1503983.25
EPOCH=78  validation_correlation=0.03400972858071327


Epoch 79: 100%|██████████| 29/29 [00:04<00:00,  6.33it/s]


Epoch 79, Batch 28, Train loss -3009569.5, Validation loss -1494234.125
EPOCH=79  validation_correlation=0.02742766961455345


Epoch 80: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 80, Batch 28, Train loss -3068736.75, Validation loss -1523749.625
EPOCH=80  validation_correlation=0.026747742667794228


Epoch 81: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


Epoch 81, Batch 28, Train loss -3027382.75, Validation loss -1503898.375
EPOCH=81  validation_correlation=0.023135459050536156


Epoch 82: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 82, Batch 28, Train loss -2895511.75, Validation loss -1437991.875
EPOCH=82  validation_correlation=0.02694825455546379


Epoch 83: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


Epoch 83, Batch 28, Train loss -3188386.25, Validation loss -1583335.25
EPOCH=83  validation_correlation=0.026987960562109947


Epoch 84: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


Epoch 84, Batch 28, Train loss -3275089.25, Validation loss -1626495.875
EPOCH=84  validation_correlation=0.0338171124458313


Epoch 85: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 85, Batch 28, Train loss -2974709.25, Validation loss -1477882.75
EPOCH=85  validation_correlation=0.02417650818824768


Epoch 86: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 86, Batch 28, Train loss -2876969.0, Validation loss -1428878.25
EPOCH=86  validation_correlation=0.04217561334371567


Epoch 87: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


Epoch 87, Batch 28, Train loss -3187351.0, Validation loss -1582804.625
EPOCH=87  validation_correlation=0.03466757386922836


Epoch 88: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 88, Batch 28, Train loss -3252884.25, Validation loss -1615555.0
EPOCH=88  validation_correlation=0.039778899401426315


Epoch 89: 100%|██████████| 29/29 [00:04<00:00,  6.36it/s]


Epoch 89, Batch 28, Train loss -3217136.75, Validation loss -1598443.25
EPOCH=89  validation_correlation=0.03733843192458153


Epoch 90: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 90, Batch 28, Train loss -3461868.75, Validation loss -1717322.625
EPOCH=90  validation_correlation=0.04346952587366104


Epoch 91: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


Epoch 91, Batch 28, Train loss -3353525.75, Validation loss -1666388.875
EPOCH=91  validation_correlation=0.025743409991264343


Epoch 92: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 92, Batch 28, Train loss -3187686.0, Validation loss -1582736.125
EPOCH=92  validation_correlation=0.038959257304668427


Epoch 93: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 93, Batch 28, Train loss -3166756.0, Validation loss -1572414.0
EPOCH=93  validation_correlation=0.04294988512992859


Epoch 94: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


Epoch 94, Batch 28, Train loss -3307984.0, Validation loss -1643933.75
EPOCH=94  validation_correlation=0.04549966752529144


Epoch 95: 100%|██████████| 29/29 [00:04<00:00,  6.36it/s]


Epoch 95, Batch 28, Train loss -3049295.0, Validation loss -1514471.0
EPOCH=95  validation_correlation=0.041933294385671616


Epoch 96: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


Epoch 96, Batch 28, Train loss -2928837.0, Validation loss -1454926.5
EPOCH=96  validation_correlation=0.04245451092720032


Epoch 97: 100%|██████████| 29/29 [00:04<00:00,  6.40it/s]


Epoch 97, Batch 28, Train loss -2910065.0, Validation loss -1445490.0
EPOCH=97  validation_correlation=0.03646809607744217


Epoch 98: 100%|██████████| 29/29 [00:04<00:00,  6.40it/s]


Epoch 98, Batch 28, Train loss -3245052.5, Validation loss -1611110.875
EPOCH=98  validation_correlation=0.03151732683181763


Epoch 99: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


Epoch 99, Batch 28, Train loss -2998210.75, Validation loss -1488791.625
EPOCH=99  validation_correlation=0.032652597874403


Epoch 100: 100%|██████████| 29/29 [00:04<00:00,  6.36it/s]


Epoch 100, Batch 28, Train loss -3099435.75, Validation loss -1539420.375
EPOCH=100  validation_correlation=0.035057492554187775


Epoch 101: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 101, Batch 28, Train loss -3096302.5, Validation loss -1538156.25
EPOCH=101  validation_correlation=0.04813479259610176


Epoch 102: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


Epoch 102, Batch 28, Train loss -3184502.25, Validation loss -1581078.875
EPOCH=102  validation_correlation=0.023285534232854843


Epoch 103: 100%|██████████| 29/29 [00:04<00:00,  6.40it/s]


Epoch 103, Batch 28, Train loss -3341838.5, Validation loss -1659910.125
EPOCH=103  validation_correlation=0.05187623202800751


Epoch 104: 100%|██████████| 29/29 [00:04<00:00,  6.43it/s]


Epoch 104, Batch 28, Train loss -3403906.75, Validation loss -1690631.625
EPOCH=104  validation_correlation=0.019444040954113007


 FINAL validation_correlation [ 1.66832134e-01  6.68299152e-03 -5.60565665e-02 -9.24981311e-02
  1.40516207e-01 -1.27410337e-01  4.42668386e-02  5.32032028e-02
  1.21511891e-01 -1.06642015e-01  1.01399925e-02 -5.54291643e-02
  4.30341139e-02  1.25531822e-01  1.11362629e-01 -1.89343691e-01
 -1.58071369e-02  4.00059437e-03 -6.68150112e-02 -5.61446808e-02
 -3.20960246e-02  1.26764163e-01  2.00023241e-02  1.60980567e-01
  3.03074345e-02 -1.04253581e-02 -8.42160136e-02  1.34250328e-01
 -1.14237182e-02  2.01769397e-01 -4.45366614e-02 -2.51471791e-02
  2.33806465e-02 -2.66440008e-02  7.36148059e-02 -2.81053837e-02
 -1.15646392e-01  7.67625421e-02 -2.85811145e-02  3.20972241e-02
  2.05981374e-01 -1.74530894e-02 -9.20342728e-02  8.20232034e-02
  8.21525007e-02  2.22058401e-01  9.14444476e-02  1.07301921e-01
 -1.12932362e-01 -2.69530900e-02  2.40163915e-02  2.6098197

VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Batch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch Train loss,█▃▃▅▃▂▆▃▅▅▁▃▅▄▇▅▄▄▄▄▃▂▄▃▅▅▂▄▄▅▅▅▆▃▁▂▅▄▅▂
Epoch validation loss,▅▅▃▃▇▄▄▆▂▁▂▇▆▆▅█▄▇▄▅▄▅▆▄█▇▃▃█▇▅▇▆▃▁▄▆▃▅▂
mean_validation_correlation,▁▅█▅▆▄▅▆▅▆▄▄▆▃▄▆▅▃▄▆▄▃▄█▆▅▅▅▆▃▄▄▃▆▇▆▆▅▇▃

0,1
Batch,3016.0
Epoch,104.0
Epoch Train loss,-90115296.0
Epoch validation loss,-1690631.625
mean_validation_correlation,0.01944


  xavier_normal(m.weight.data)
  init.constant(m.bias.data, 0.0)


optim_step_count = 1


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666959635913372, max=1.0)…

  transform = lambda x: (x - self._inputs_mean) / self._inputs_std
Epoch 1: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


Epoch 1, Batch 28, Train loss -3439569.25, Validation loss -1703889.75
EPOCH=1  validation_correlation=0.0278420839458704


Epoch 2: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 2, Batch 28, Train loss -3340961.25, Validation loss -1657936.0
EPOCH=2  validation_correlation=0.022460544481873512


Epoch 3: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


Epoch 3, Batch 28, Train loss -3377478.75, Validation loss -1678723.5
EPOCH=3  validation_correlation=0.041703153401613235


Epoch 4: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 4, Batch 28, Train loss -3244992.0, Validation loss -1610637.25
EPOCH=4  validation_correlation=0.039412569254636765


Epoch 5: 100%|██████████| 29/29 [00:04<00:00,  6.43it/s]


Epoch 5, Batch 28, Train loss -3502093.25, Validation loss -1739558.625
EPOCH=5  validation_correlation=0.054330114275217056


Epoch 6: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 6, Batch 28, Train loss -3197887.5, Validation loss -1588576.75
EPOCH=6  validation_correlation=0.03691743686795235


Epoch 7: 100%|██████████| 29/29 [00:04<00:00,  6.43it/s]


Epoch 7, Batch 28, Train loss -3034916.75, Validation loss -1506637.25
EPOCH=7  validation_correlation=0.0076507641933858395


Epoch 8: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 8, Batch 28, Train loss -3186060.25, Validation loss -1582855.25
EPOCH=8  validation_correlation=0.03806640952825546


Epoch 9: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


Epoch 9, Batch 28, Train loss -3221868.75, Validation loss -1598241.125
EPOCH=9  validation_correlation=0.03841795399785042


Epoch 10: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 10, Batch 28, Train loss -3515027.0, Validation loss -1743911.75
EPOCH=10  validation_correlation=0.052620287984609604


Epoch 11: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


Epoch 11, Batch 28, Train loss -3279408.0, Validation loss -1630477.75
EPOCH=11  validation_correlation=0.052787721157073975


Epoch 12: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 12, Batch 28, Train loss -3008468.25, Validation loss -1497388.0
EPOCH=12  validation_correlation=0.03122696466743946


Epoch 13: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 13, Batch 28, Train loss -3311950.0, Validation loss -1646036.875
EPOCH=13  validation_correlation=0.05460822209715843


Epoch 14: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 14, Batch 28, Train loss -3430743.75, Validation loss -1703082.25
EPOCH=14  validation_correlation=0.02444327063858509


Epoch 15: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 15, Batch 28, Train loss -3137927.5, Validation loss -1559010.125
EPOCH=15  validation_correlation=0.06276834011077881
Epoch 00015: reducing learning rate of group 0 to 1.5000e-03.


Epoch 16: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 16, Batch 28, Train loss -3323784.0, Validation loss -1651088.5
EPOCH=16  validation_correlation=0.04006921872496605


Epoch 17: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 17, Batch 28, Train loss -3182595.0, Validation loss -1580337.625
EPOCH=17  validation_correlation=0.033001258969306946


Epoch 18: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 18, Batch 28, Train loss -3341862.5, Validation loss -1662933.25
EPOCH=18  validation_correlation=0.042165786027908325


Epoch 19: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 19, Batch 28, Train loss -3179163.0, Validation loss -1579903.125
EPOCH=19  validation_correlation=0.04851388558745384


Epoch 20: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


Epoch 20, Batch 28, Train loss -3250663.0, Validation loss -1615869.75
EPOCH=20  validation_correlation=0.04386761412024498


Epoch 21: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 21, Batch 28, Train loss -3301047.0, Validation loss -1642591.375
EPOCH=21  validation_correlation=0.065159872174263


Epoch 22: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 22, Batch 28, Train loss -3460659.0, Validation loss -1720474.75
EPOCH=22  validation_correlation=0.04422236979007721


Epoch 23: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 23, Batch 28, Train loss -3069699.25, Validation loss -1525197.25
EPOCH=23  validation_correlation=0.04048560559749603


Epoch 24: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 24, Batch 28, Train loss -3452952.75, Validation loss -1713042.125
EPOCH=24  validation_correlation=0.05993000790476799


Epoch 25: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 25, Batch 28, Train loss -3155766.0, Validation loss -1565450.25
EPOCH=25  validation_correlation=0.042081911116838455


Epoch 26: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 26, Batch 28, Train loss -3043852.75, Validation loss -1511909.375
EPOCH=26  validation_correlation=0.04252323508262634
Epoch 00026: reducing learning rate of group 0 to 4.5000e-04.


Epoch 27: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 27, Batch 28, Train loss -3126546.25, Validation loss -1550144.0
EPOCH=27  validation_correlation=0.044109415262937546


Epoch 28: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 28, Batch 28, Train loss -3176333.5, Validation loss -1580014.25
EPOCH=28  validation_correlation=0.05393802002072334


Epoch 29: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 29, Batch 28, Train loss -3535140.0, Validation loss -1755765.625
EPOCH=29  validation_correlation=0.0531134307384491


Epoch 30: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 30, Batch 28, Train loss -3169869.75, Validation loss -1574169.0
EPOCH=30  validation_correlation=0.021127749234437943


Epoch 31: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 31, Batch 28, Train loss -3202601.75, Validation loss -1593327.25
EPOCH=31  validation_correlation=0.03668339550495148


Epoch 32: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 32, Batch 28, Train loss -3297326.25, Validation loss -1638602.75
EPOCH=32  validation_correlation=0.03975650668144226


Epoch 33: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 33, Batch 28, Train loss -3350585.0, Validation loss -1664971.75
EPOCH=33  validation_correlation=0.049861013889312744


Epoch 34: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 34, Batch 28, Train loss -3314286.25, Validation loss -1647569.625
EPOCH=34  validation_correlation=0.04301892966032028


Epoch 35: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 35, Batch 28, Train loss -3244744.25, Validation loss -1610747.125
EPOCH=35  validation_correlation=0.033434294164180756


Epoch 36: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 36, Batch 28, Train loss -3353304.75, Validation loss -1664921.625
EPOCH=36  validation_correlation=0.036853816360235214


Epoch 37: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 37, Batch 28, Train loss -3220924.5, Validation loss -1600529.375
EPOCH=37  validation_correlation=0.04208189249038696
Epoch 00037: reducing learning rate of group 0 to 1.3500e-04.


Epoch 38: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 38, Batch 28, Train loss -3296304.0, Validation loss -1637483.125
EPOCH=38  validation_correlation=0.04267651587724686


Epoch 39: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 39, Batch 28, Train loss -3094927.0, Validation loss -1539333.375
EPOCH=39  validation_correlation=0.03835664317011833


Epoch 40: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 40, Batch 28, Train loss -3361474.0, Validation loss -1671702.375
EPOCH=40  validation_correlation=0.056432269513607025


Epoch 41: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 41, Batch 28, Train loss -3014586.25, Validation loss -1497297.875
EPOCH=41  validation_correlation=0.03279043734073639


Epoch 42: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 42, Batch 28, Train loss -3300423.5, Validation loss -1639955.625
EPOCH=42  validation_correlation=0.04796835407614708


Epoch 43: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


Epoch 43, Batch 28, Train loss -3222298.25, Validation loss -1600282.375
EPOCH=43  validation_correlation=0.043182894587516785


Epoch 44: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 44, Batch 28, Train loss -3293942.25, Validation loss -1637774.0
EPOCH=44  validation_correlation=0.05177898332476616


Epoch 45: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 45, Batch 28, Train loss -3093645.25, Validation loss -1537337.375
EPOCH=45  validation_correlation=0.020483853295445442


Epoch 46: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 46, Batch 28, Train loss -3341136.75, Validation loss -1660543.5
EPOCH=46  validation_correlation=0.041552767157554626


Epoch 47: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 47, Batch 28, Train loss -3018900.75, Validation loss -1499079.625
EPOCH=47  validation_correlation=0.042467232793569565


Epoch 48: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 48, Batch 28, Train loss -3307231.25, Validation loss -1639681.25
EPOCH=48  validation_correlation=0.025174468755722046
Epoch 00048: reducing learning rate of group 0 to 1.0000e-04.


Epoch 49: 100%|██████████| 29/29 [00:04<00:00,  6.57it/s]


Epoch 49, Batch 28, Train loss -3379015.25, Validation loss -1677818.625
EPOCH=49  validation_correlation=0.04382961615920067


Epoch 50: 100%|██████████| 29/29 [00:04<00:00,  6.00it/s]


Epoch 50, Batch 28, Train loss -3141555.75, Validation loss -1559597.625
EPOCH=50  validation_correlation=0.029163477942347527


Epoch 51: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 51, Batch 28, Train loss -3244919.75, Validation loss -1611336.375
EPOCH=51  validation_correlation=0.05093991011381149


Epoch 52: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 52, Batch 28, Train loss -3222883.0, Validation loss -1601123.5
EPOCH=52  validation_correlation=0.044554561376571655


Epoch 53: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 53, Batch 28, Train loss -3288882.0, Validation loss -1631340.625
EPOCH=53  validation_correlation=0.023455748334527016


Epoch 54: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 54, Batch 28, Train loss -3261736.0, Validation loss -1620669.0
EPOCH=54  validation_correlation=0.047550030052661896


Epoch 55: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 55, Batch 28, Train loss -3168474.5, Validation loss -1574137.75
EPOCH=55  validation_correlation=0.052927155047655106


Epoch 56: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 56, Batch 28, Train loss -3096640.75, Validation loss -1539524.375
EPOCH=56  validation_correlation=0.018190084025263786


Epoch 57: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 57, Batch 28, Train loss -3460684.0, Validation loss -1719352.375
EPOCH=57  validation_correlation=0.06056036800146103


Epoch 58: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 58, Batch 28, Train loss -3146854.25, Validation loss -1564127.0
EPOCH=58  validation_correlation=0.03025047667324543


Epoch 59: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 59, Batch 28, Train loss -3270496.75, Validation loss -1623829.25
EPOCH=59  validation_correlation=0.020787183195352554


Epoch 60: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


Epoch 60, Batch 28, Train loss -3105217.25, Validation loss -1542597.5
EPOCH=60  validation_correlation=0.02320898324251175


Epoch 61: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 61, Batch 28, Train loss -3071776.0, Validation loss -1524790.75
EPOCH=61  validation_correlation=0.058055274188518524


Epoch 62: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 62, Batch 28, Train loss -3424630.75, Validation loss -1700248.5
EPOCH=62  validation_correlation=0.03368370607495308


Epoch 63: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 63, Batch 28, Train loss -3143411.0, Validation loss -1562129.625
EPOCH=63  validation_correlation=0.05561711639165878


Epoch 64: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 64, Batch 28, Train loss -3230914.5, Validation loss -1604567.25
EPOCH=64  validation_correlation=0.04035442695021629


Epoch 65: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 65, Batch 28, Train loss -3475048.5, Validation loss -1725236.25
EPOCH=65  validation_correlation=0.04025490954518318


Epoch 66: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 66, Batch 28, Train loss -3128525.25, Validation loss -1553614.0
EPOCH=66  validation_correlation=0.032991405576467514


Epoch 67: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 67, Batch 28, Train loss -3386192.5, Validation loss -1681635.125
EPOCH=67  validation_correlation=0.04864727333188057


Epoch 68: 100%|██████████| 29/29 [00:04<00:00,  6.57it/s]


Epoch 68, Batch 28, Train loss -3448964.75, Validation loss -1713073.125
EPOCH=68  validation_correlation=0.028281858190894127


Epoch 69: 100%|██████████| 29/29 [00:04<00:00,  6.57it/s]


Epoch 69, Batch 28, Train loss -3063742.25, Validation loss -1521546.0
EPOCH=69  validation_correlation=0.04602326080203056


Epoch 70: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 70, Batch 28, Train loss -3519694.0, Validation loss -1745833.625
EPOCH=70  validation_correlation=0.012734415009617805


Epoch 71: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 71, Batch 28, Train loss -3300732.0, Validation loss -1639869.75
EPOCH=71  validation_correlation=0.041756246238946915


Epoch 72: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 72, Batch 28, Train loss -3098113.75, Validation loss -1539978.625
EPOCH=72  validation_correlation=0.045264508575201035


Epoch 73: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 73, Batch 28, Train loss -3251158.75, Validation loss -1615179.375
EPOCH=73  validation_correlation=0.04368836432695389


Epoch 74: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 74, Batch 28, Train loss -3418134.25, Validation loss -1696043.875
EPOCH=74  validation_correlation=0.03606795519590378


Epoch 75: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


Epoch 75, Batch 28, Train loss -3124338.5, Validation loss -1551466.0
EPOCH=75  validation_correlation=0.038160573691129684


Epoch 76: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


Epoch 76, Batch 28, Train loss -3247763.75, Validation loss -1613526.125
EPOCH=76  validation_correlation=0.03620816022157669


Epoch 77: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 77, Batch 28, Train loss -3288693.0, Validation loss -1632626.625
EPOCH=77  validation_correlation=0.014843636192381382


Epoch 78: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 78, Batch 28, Train loss -3391775.75, Validation loss -1684147.0
EPOCH=78  validation_correlation=0.0368538536131382


Epoch 79: 100%|██████████| 29/29 [00:04<00:00,  6.63it/s]


Epoch 79, Batch 28, Train loss -3389679.75, Validation loss -1684310.0
EPOCH=79  validation_correlation=0.03297024965286255


Epoch 80: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 80, Batch 28, Train loss -3135263.0, Validation loss -1558049.625
EPOCH=80  validation_correlation=0.03140003979206085


Epoch 81: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 81, Batch 28, Train loss -3433463.75, Validation loss -1706926.0
EPOCH=81  validation_correlation=0.04329119622707367


Epoch 82: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


Epoch 82, Batch 28, Train loss -3395727.5, Validation loss -1687375.75
EPOCH=82  validation_correlation=0.03540707007050514


Epoch 83: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 83, Batch 28, Train loss -3278833.75, Validation loss -1628689.25
EPOCH=83  validation_correlation=0.03661656379699707


Epoch 84: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 84, Batch 28, Train loss -3214049.25, Validation loss -1596255.375
EPOCH=84  validation_correlation=0.03890460357069969


Epoch 85: 100%|██████████| 29/29 [00:04<00:00,  6.57it/s]


Epoch 85, Batch 28, Train loss -3413046.5, Validation loss -1695673.25
EPOCH=85  validation_correlation=0.01668686419725418


Epoch 86: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 86, Batch 28, Train loss -3261855.5, Validation loss -1620003.5
EPOCH=86  validation_correlation=0.05522824078798294


Epoch 87: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 87, Batch 28, Train loss -3213668.5, Validation loss -1596868.625
EPOCH=87  validation_correlation=0.05195092409849167


Epoch 88: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 88, Batch 28, Train loss -3296766.25, Validation loss -1636984.25
EPOCH=88  validation_correlation=0.04994964599609375


Epoch 89: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 89, Batch 28, Train loss -3256301.0, Validation loss -1616347.625
EPOCH=89  validation_correlation=0.057220276445150375


Epoch 90: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 90, Batch 28, Train loss -3123776.5, Validation loss -1551215.5
EPOCH=90  validation_correlation=0.05286543071269989


Epoch 91: 100%|██████████| 29/29 [00:04<00:00,  6.57it/s]


Epoch 91, Batch 28, Train loss -3439552.5, Validation loss -1709293.625
EPOCH=91  validation_correlation=0.03402131423354149


Epoch 92: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 92, Batch 28, Train loss -3181380.0, Validation loss -1581373.5
EPOCH=92  validation_correlation=0.04567335546016693


Epoch 93: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 93, Batch 28, Train loss -3236293.5, Validation loss -1608032.625
EPOCH=93  validation_correlation=0.03165417164564133


Epoch 94: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 94, Batch 28, Train loss -2951704.75, Validation loss -1467813.875
EPOCH=94  validation_correlation=0.04873427748680115


Epoch 95: 100%|██████████| 29/29 [00:04<00:00,  6.62it/s]


Epoch 95, Batch 28, Train loss -3141803.25, Validation loss -1559092.875
EPOCH=95  validation_correlation=0.03614036366343498


Epoch 96: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 96, Batch 28, Train loss -3172622.25, Validation loss -1576449.625
EPOCH=96  validation_correlation=0.044487763196229935


Epoch 97: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 97, Batch 28, Train loss -3189840.25, Validation loss -1584270.625
EPOCH=97  validation_correlation=0.03255162015557289


Epoch 98: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 98, Batch 28, Train loss -3355044.25, Validation loss -1666089.5
EPOCH=98  validation_correlation=0.04201219603419304


Epoch 99: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 99, Batch 28, Train loss -3137185.25, Validation loss -1559827.875
EPOCH=99  validation_correlation=0.03821636363863945


Epoch 100: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 100, Batch 28, Train loss -3354254.75, Validation loss -1667215.75
EPOCH=100  validation_correlation=0.039885684847831726


Epoch 101: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 101, Batch 28, Train loss -3182430.5, Validation loss -1580622.375
EPOCH=101  validation_correlation=0.03221922367811203


Epoch 102: 100%|██████████| 29/29 [00:04<00:00,  6.64it/s]


Epoch 102, Batch 28, Train loss -3314475.75, Validation loss -1647364.625
EPOCH=102  validation_correlation=0.03874073177576065


Epoch 103: 100%|██████████| 29/29 [00:04<00:00,  6.62it/s]


Epoch 103, Batch 28, Train loss -3351800.5, Validation loss -1663072.0
EPOCH=103  validation_correlation=0.04050995036959648


Epoch 104: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 104, Batch 28, Train loss -3130479.0, Validation loss -1557235.5
EPOCH=104  validation_correlation=0.03007916547358036


 FINAL validation_correlation [-0.12324396 -0.01034652  0.10979459  0.00974759  0.16141316  0.06712832
 -0.09768895  0.11968698  0.10609881  0.13156985 -0.01900141 -0.02486833
 -0.01780079  0.02623915  0.11653848  0.14020735  0.01643892 -0.00758991
  0.03110812  0.06584491  0.04387206  0.31534022  0.0156926   0.17774148
  0.00300121  0.02090228 -0.07311064  0.12477145 -0.02999055  0.17314008
 -0.03327333 -0.01824508 -0.04818261 -0.07246035  0.03212944  0.05549524
 -0.06953572 -0.03327978 -0.02278681  0.03962059  0.17329049 -0.00528075
 -0.08601236  0.05667795  0.13465856  0.1883326   0.08949111  0.08901332
 -0.04395938 -0.01311084  0.03409529  0.05137739 -0.01326226  0.12950176
 -0.02644549 -0.05678316  0.01251276  0.01748434  0.06170031 -0.14783497
  0.07916028 -0.0475851  -0.08250668  0.19346857  0.14067174  0.10017065
  0.07543784  0.04718336  0.23457548  0.

VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Batch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch Train loss,█▃▃▅▇▁▅▅▂▄▆▃▆▂▄▁▆▅█▄▃▄▄▅▆▂▅▇▁▅▅▃▄▂▆▅▂▄▅▅
Epoch validation loss,▂▂▆▆▄▂▃▆▁▁▇▆▄▅▅▃▅▇▄▅▄▇▄█▅▂█▇▂▄▇▂▂▄▇▅▆▃▆▇
mean_validation_correlation,▃▅▄▅▇▂▅▆▆█▆▂▅▄▅▇▅▂▃▇▂▂▂█▅▆▆▆▄▁▄▄▁▆▇▄▆▅▄▃

0,1
Batch,3016.0
Epoch,104.0
Epoch Train loss,-89171824.0
Epoch validation loss,-1557235.5
mean_validation_correlation,0.03008


  xavier_normal(m.weight.data)
  init.constant(m.bias.data, 0.0)


optim_step_count = 1


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670123611887296, max=1.0…

  transform = lambda x: (x - self._inputs_mean) / self._inputs_std
Epoch 1: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 1, Batch 28, Train loss -3481668.75, Validation loss -1726877.875
EPOCH=1  validation_correlation=-0.004414687864482403


Epoch 2: 100%|██████████| 29/29 [00:04<00:00,  6.70it/s]


Epoch 2, Batch 28, Train loss -3107222.5, Validation loss -1545780.125
EPOCH=2  validation_correlation=-0.009702609851956367


Epoch 3: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 3, Batch 28, Train loss -3342583.5, Validation loss -1660898.25
EPOCH=3  validation_correlation=0.01256908942013979


Epoch 4: 100%|██████████| 29/29 [00:04<00:00,  6.62it/s]


Epoch 4, Batch 28, Train loss -3079332.0, Validation loss -1530865.875
EPOCH=4  validation_correlation=0.007977079600095749


Epoch 5: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 5, Batch 28, Train loss -2955248.0, Validation loss -1467841.0
EPOCH=5  validation_correlation=0.016413116827607155


Epoch 6: 100%|██████████| 29/29 [00:04<00:00,  6.62it/s]


Epoch 6, Batch 28, Train loss -3400124.25, Validation loss -1689665.75
EPOCH=6  validation_correlation=0.03921574726700783


Epoch 7: 100%|██████████| 29/29 [00:04<00:00,  6.64it/s]


Epoch 7, Batch 28, Train loss -3087152.5, Validation loss -1535572.25
EPOCH=7  validation_correlation=0.010788113810122013


Epoch 8: 100%|██████████| 29/29 [00:04<00:00,  6.68it/s]


Epoch 8, Batch 28, Train loss -3244109.0, Validation loss -1611993.25
EPOCH=8  validation_correlation=0.03042674995958805


Epoch 9: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 9, Batch 28, Train loss -3057185.5, Validation loss -1516955.625
EPOCH=9  validation_correlation=0.026948045939207077


Epoch 10: 100%|██████████| 29/29 [00:04<00:00,  6.64it/s]


Epoch 10, Batch 28, Train loss -3066373.0, Validation loss -1523403.25
EPOCH=10  validation_correlation=0.03920690715312958


Epoch 11: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 11, Batch 28, Train loss -3121471.25, Validation loss -1548833.25
EPOCH=11  validation_correlation=0.025445405393838882


Epoch 12: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 12, Batch 28, Train loss -3120323.0, Validation loss -1552266.125
EPOCH=12  validation_correlation=0.03482404351234436


Epoch 13: 100%|██████████| 29/29 [00:04<00:00,  6.63it/s]


Epoch 13, Batch 28, Train loss -3159781.0, Validation loss -1570664.375
EPOCH=13  validation_correlation=0.048914723098278046


Epoch 14: 100%|██████████| 29/29 [00:04<00:00,  6.65it/s]


Epoch 14, Batch 28, Train loss -3100554.0, Validation loss -1535723.25
EPOCH=14  validation_correlation=0.044782985001802444


Epoch 15: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 15, Batch 28, Train loss -3161189.0, Validation loss -1570749.25
EPOCH=15  validation_correlation=0.036331869661808014


Epoch 16: 100%|██████████| 29/29 [00:04<00:00,  6.71it/s]


Epoch 16, Batch 28, Train loss -2958768.5, Validation loss -1470628.25
EPOCH=16  validation_correlation=0.03749469667673111


Epoch 17: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 17, Batch 28, Train loss -3357117.25, Validation loss -1668009.375
EPOCH=17  validation_correlation=0.032972000539302826


Epoch 18: 100%|██████████| 29/29 [00:04<00:00,  6.68it/s]


Epoch 18, Batch 28, Train loss -3225683.0, Validation loss -1601044.0
EPOCH=18  validation_correlation=0.022181198000907898


Epoch 19: 100%|██████████| 29/29 [00:04<00:00,  6.72it/s]


Epoch 19, Batch 28, Train loss -2932247.5, Validation loss -1457856.125
EPOCH=19  validation_correlation=0.04729345068335533
Epoch 00019: reducing learning rate of group 0 to 1.5000e-03.


Epoch 20: 100%|██████████| 29/29 [00:04<00:00,  6.66it/s]


Epoch 20, Batch 28, Train loss -3288524.75, Validation loss -1632179.125
EPOCH=20  validation_correlation=0.04134918004274368


Epoch 21: 100%|██████████| 29/29 [00:04<00:00,  6.71it/s]


Epoch 21, Batch 28, Train loss -3226076.5, Validation loss -1603786.5
EPOCH=21  validation_correlation=0.04862973839044571


Epoch 22: 100%|██████████| 29/29 [00:04<00:00,  6.63it/s]


Epoch 22, Batch 28, Train loss -3387937.0, Validation loss -1681489.375
EPOCH=22  validation_correlation=0.04181628301739693


Epoch 23: 100%|██████████| 29/29 [00:04<00:00,  6.69it/s]


Epoch 23, Batch 28, Train loss -3222193.25, Validation loss -1599226.875
EPOCH=23  validation_correlation=0.03436165675520897


Epoch 24: 100%|██████████| 29/29 [00:04<00:00,  6.64it/s]


Epoch 24, Batch 28, Train loss -3104062.5, Validation loss -1541302.125
EPOCH=24  validation_correlation=0.04461948201060295


Epoch 25: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 25, Batch 28, Train loss -3227079.25, Validation loss -1606013.75
EPOCH=25  validation_correlation=0.045370183885097504


Epoch 26: 100%|██████████| 29/29 [00:04<00:00,  6.68it/s]


Epoch 26, Batch 28, Train loss -2870736.0, Validation loss -1423931.125
EPOCH=26  validation_correlation=0.04448501765727997


Epoch 27: 100%|██████████| 29/29 [00:04<00:00,  6.70it/s]


Epoch 27, Batch 28, Train loss -3102543.75, Validation loss -1540844.875
EPOCH=27  validation_correlation=0.014917566440999508


Epoch 28: 100%|██████████| 29/29 [00:04<00:00,  6.65it/s]


Epoch 28, Batch 28, Train loss -3412236.0, Validation loss -1689342.375
EPOCH=28  validation_correlation=0.0370485782623291


Epoch 29: 100%|██████████| 29/29 [00:04<00:00,  6.63it/s]


Epoch 29, Batch 28, Train loss -3129158.75, Validation loss -1555516.0
EPOCH=29  validation_correlation=0.03153882175683975


Epoch 30: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 30, Batch 28, Train loss -3235275.75, Validation loss -1609198.0
EPOCH=30  validation_correlation=0.02913506329059601
Epoch 00030: reducing learning rate of group 0 to 4.5000e-04.


Epoch 31: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 31, Batch 28, Train loss -2906531.5, Validation loss -1443658.375
EPOCH=31  validation_correlation=0.0372404120862484


Epoch 32: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 32, Batch 28, Train loss -3370958.0, Validation loss -1673788.375
EPOCH=32  validation_correlation=0.037050507962703705


Epoch 33: 100%|██████████| 29/29 [00:04<00:00,  6.64it/s]


Epoch 33, Batch 28, Train loss -3083844.0, Validation loss -1532395.875
EPOCH=33  validation_correlation=0.042611803859472275


Epoch 34: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 34, Batch 28, Train loss -3066732.75, Validation loss -1523358.5
EPOCH=34  validation_correlation=0.04771524295210838


Epoch 35: 100%|██████████| 29/29 [00:04<00:00,  6.72it/s]


Epoch 35, Batch 28, Train loss -2927389.0, Validation loss -1454065.125
EPOCH=35  validation_correlation=0.02211104705929756


Epoch 36: 100%|██████████| 29/29 [00:04<00:00,  6.63it/s]


Epoch 36, Batch 28, Train loss -2879871.0, Validation loss -1426738.75
EPOCH=36  validation_correlation=0.046513404697179794


Epoch 37: 100%|██████████| 29/29 [00:04<00:00,  6.71it/s]


Epoch 37, Batch 28, Train loss -3038385.0, Validation loss -1511059.5
EPOCH=37  validation_correlation=0.042417801916599274


Epoch 38: 100%|██████████| 29/29 [00:04<00:00,  6.73it/s]


Epoch 38, Batch 28, Train loss -3296910.25, Validation loss -1636133.625
EPOCH=38  validation_correlation=0.05183594673871994


Epoch 39: 100%|██████████| 29/29 [00:04<00:00,  6.62it/s]


Epoch 39, Batch 28, Train loss -3171422.75, Validation loss -1574555.0
EPOCH=39  validation_correlation=0.046008385717868805


Epoch 40: 100%|██████████| 29/29 [00:04<00:00,  6.69it/s]


Epoch 40, Batch 28, Train loss -3435248.75, Validation loss -1706533.375
EPOCH=40  validation_correlation=0.024405431002378464


Epoch 41: 100%|██████████| 29/29 [00:04<00:00,  6.74it/s]


Epoch 41, Batch 28, Train loss -3264903.25, Validation loss -1622079.875
EPOCH=41  validation_correlation=0.07561150193214417


Epoch 42: 100%|██████████| 29/29 [00:04<00:00,  6.74it/s]


Epoch 42, Batch 28, Train loss -3149420.25, Validation loss -1566733.5
EPOCH=42  validation_correlation=0.04246925190091133


Epoch 43: 100%|██████████| 29/29 [00:04<00:00,  6.68it/s]


Epoch 43, Batch 28, Train loss -3300991.5, Validation loss -1639596.5
EPOCH=43  validation_correlation=0.02924143709242344


Epoch 44: 100%|██████████| 29/29 [00:04<00:00,  6.72it/s]


Epoch 44, Batch 28, Train loss -3389992.5, Validation loss -1685367.0
EPOCH=44  validation_correlation=0.022572167217731476


Epoch 45: 100%|██████████| 29/29 [00:04<00:00,  6.75it/s]


Epoch 45, Batch 28, Train loss -3048321.25, Validation loss -1515746.25
EPOCH=45  validation_correlation=0.022505493834614754


Epoch 46: 100%|██████████| 29/29 [00:04<00:00,  6.68it/s]


Epoch 46, Batch 28, Train loss -3464027.5, Validation loss -1718071.75
EPOCH=46  validation_correlation=0.017077770084142685


Epoch 47: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


Epoch 47, Batch 28, Train loss -3131484.75, Validation loss -1555632.0
EPOCH=47  validation_correlation=0.04964998736977577


Epoch 48: 100%|██████████| 29/29 [00:04<00:00,  6.70it/s]


Epoch 48, Batch 28, Train loss -3083097.0, Validation loss -1527997.5
EPOCH=48  validation_correlation=0.017531147226691246
Epoch 00048: reducing learning rate of group 0 to 1.3500e-04.


Epoch 49: 100%|██████████| 29/29 [00:04<00:00,  6.74it/s]


Epoch 49, Batch 28, Train loss -3327678.25, Validation loss -1655321.75
EPOCH=49  validation_correlation=0.03235159069299698


Epoch 50: 100%|██████████| 29/29 [00:04<00:00,  7.08it/s]


Epoch 50, Batch 28, Train loss -3460684.0, Validation loss -1717954.125
EPOCH=50  validation_correlation=0.027184495702385902


Epoch 51: 100%|██████████| 29/29 [00:02<00:00, 10.21it/s]


Epoch 51, Batch 28, Train loss -3221016.75, Validation loss -1598137.375
EPOCH=51  validation_correlation=0.01251702755689621


Epoch 52: 100%|██████████| 29/29 [00:04<00:00,  6.23it/s]


Epoch 52, Batch 28, Train loss -2933929.0, Validation loss -1456785.25
EPOCH=52  validation_correlation=0.03579362481832504


Epoch 53: 100%|██████████| 29/29 [00:04<00:00,  6.23it/s]


Epoch 53, Batch 28, Train loss -3205948.75, Validation loss -1590542.625
EPOCH=53  validation_correlation=0.056117944419384


Epoch 54: 100%|██████████| 29/29 [00:04<00:00,  6.24it/s]


Epoch 54, Batch 28, Train loss -3260215.75, Validation loss -1620029.5
EPOCH=54  validation_correlation=0.03782026469707489


Epoch 55: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 55, Batch 28, Train loss -3262662.0, Validation loss -1619512.5
EPOCH=55  validation_correlation=0.0501377210021019


Epoch 56: 100%|██████████| 29/29 [00:04<00:00,  6.24it/s]


Epoch 56, Batch 28, Train loss -3308963.25, Validation loss -1644090.5
EPOCH=56  validation_correlation=0.06595472246408463


Epoch 57: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 57, Batch 28, Train loss -3003918.75, Validation loss -1492020.75
EPOCH=57  validation_correlation=0.046591538935899734


Epoch 58: 100%|██████████| 29/29 [00:04<00:00,  6.23it/s]


Epoch 58, Batch 28, Train loss -3047896.75, Validation loss -1514417.125
EPOCH=58  validation_correlation=0.0388677753508091


Epoch 59: 100%|██████████| 29/29 [00:04<00:00,  6.24it/s]


Epoch 59, Batch 28, Train loss -3051762.0, Validation loss -1517372.625
EPOCH=59  validation_correlation=0.041647955775260925
Epoch 00059: reducing learning rate of group 0 to 1.0000e-04.


Epoch 60: 100%|██████████| 29/29 [00:04<00:00,  6.24it/s]


Epoch 60, Batch 28, Train loss -2775354.0, Validation loss -1377517.75
EPOCH=60  validation_correlation=0.04190308228135109


Epoch 61: 100%|██████████| 29/29 [00:04<00:00,  6.21it/s]


Epoch 61, Batch 28, Train loss -3090572.0, Validation loss -1532684.125
EPOCH=61  validation_correlation=0.040835656225681305


Epoch 62: 100%|██████████| 29/29 [00:04<00:00,  6.20it/s]


Epoch 62, Batch 28, Train loss -3257712.25, Validation loss -1616952.375
EPOCH=62  validation_correlation=0.043106380850076675


Epoch 63: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 63, Batch 28, Train loss -3082514.5, Validation loss -1531800.25
EPOCH=63  validation_correlation=0.03475373238325119


Epoch 64: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 64, Batch 28, Train loss -3125635.5, Validation loss -1551007.5
EPOCH=64  validation_correlation=0.049058135598897934


Epoch 65: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 65, Batch 28, Train loss -2932201.0, Validation loss -1455440.75
EPOCH=65  validation_correlation=0.047313518822193146


Epoch 66: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 66, Batch 28, Train loss -3219299.25, Validation loss -1598541.875
EPOCH=66  validation_correlation=0.03138110786676407


Epoch 67: 100%|██████████| 29/29 [00:04<00:00,  6.23it/s]


Epoch 67, Batch 28, Train loss -3254031.0, Validation loss -1615299.375
EPOCH=67  validation_correlation=0.02477203868329525


Epoch 68: 100%|██████████| 29/29 [00:04<00:00,  6.23it/s]


Epoch 68, Batch 28, Train loss -3243888.5, Validation loss -1612840.375
EPOCH=68  validation_correlation=0.05597924068570137


Epoch 69: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 69, Batch 28, Train loss -2984408.75, Validation loss -1481629.625
EPOCH=69  validation_correlation=0.02617751993238926


Epoch 70: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 70, Batch 28, Train loss -3061291.5, Validation loss -1521170.125
EPOCH=70  validation_correlation=0.030167467892169952


Epoch 71: 100%|██████████| 29/29 [00:04<00:00,  6.22it/s]


Epoch 71, Batch 28, Train loss -3328571.75, Validation loss -1654738.25
EPOCH=71  validation_correlation=0.03904784843325615


Epoch 72: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 72, Batch 28, Train loss -3145894.25, Validation loss -1563278.75
EPOCH=72  validation_correlation=0.038086675107479095


Epoch 73: 100%|██████████| 29/29 [00:04<00:00,  6.24it/s]


Epoch 73, Batch 28, Train loss -2959834.25, Validation loss -1470052.875
EPOCH=73  validation_correlation=0.0217787716537714


Epoch 74: 100%|██████████| 29/29 [00:04<00:00,  6.24it/s]


Epoch 74, Batch 28, Train loss -2929695.0, Validation loss -1454893.125
EPOCH=74  validation_correlation=0.042485084384679794


Epoch 75: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 75, Batch 28, Train loss -2876291.5, Validation loss -1430199.25
EPOCH=75  validation_correlation=0.05183401703834534


Epoch 76: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 76, Batch 28, Train loss -2954357.25, Validation loss -1468093.0
EPOCH=76  validation_correlation=0.037771593779325485


Epoch 77: 100%|██████████| 29/29 [00:03<00:00,  7.35it/s]


Epoch 77, Batch 28, Train loss -3264995.0, Validation loss -1622697.125
EPOCH=77  validation_correlation=0.030177820473909378


Epoch 78: 100%|██████████| 29/29 [00:03<00:00,  9.03it/s]


Epoch 78, Batch 28, Train loss -3198825.25, Validation loss -1588746.875
EPOCH=78  validation_correlation=0.045456551015377045


Epoch 79: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 79, Batch 28, Train loss -3111859.0, Validation loss -1544870.625
EPOCH=79  validation_correlation=0.04709427431225777


Epoch 80: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 80, Batch 28, Train loss -3088232.0, Validation loss -1534289.0
EPOCH=80  validation_correlation=0.04975977540016174


Epoch 81: 100%|██████████| 29/29 [00:04<00:00,  6.22it/s]


Epoch 81, Batch 28, Train loss -3194653.75, Validation loss -1588549.875
EPOCH=81  validation_correlation=0.045044220983982086


Epoch 82: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 82, Batch 28, Train loss -3172740.75, Validation loss -1575532.125
EPOCH=82  validation_correlation=0.02291071228682995


Epoch 83: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 83, Batch 28, Train loss -3206128.5, Validation loss -1592272.875
EPOCH=83  validation_correlation=0.021866926923394203


Epoch 84: 100%|██████████| 29/29 [00:04<00:00,  6.22it/s]


Epoch 84, Batch 28, Train loss -3096672.75, Validation loss -1536309.0
EPOCH=84  validation_correlation=0.04403827339410782


Epoch 85: 100%|██████████| 29/29 [00:04<00:00,  6.22it/s]


Epoch 85, Batch 28, Train loss -3027536.5, Validation loss -1505390.5
EPOCH=85  validation_correlation=0.030566226691007614


Epoch 86: 100%|██████████| 29/29 [00:04<00:00,  6.24it/s]


Epoch 86, Batch 28, Train loss -3030193.25, Validation loss -1505772.5
EPOCH=86  validation_correlation=0.04483076184988022


Epoch 87: 100%|██████████| 29/29 [00:04<00:00,  6.23it/s]


Epoch 87, Batch 28, Train loss -3032591.25, Validation loss -1506583.75
EPOCH=87  validation_correlation=0.04405275732278824


Epoch 88: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 88, Batch 28, Train loss -2891003.25, Validation loss -1436087.75
EPOCH=88  validation_correlation=0.03964191675186157


Epoch 89: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 89, Batch 28, Train loss -3092169.25, Validation loss -1535504.75
EPOCH=89  validation_correlation=0.052129581570625305


Epoch 90: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 90, Batch 28, Train loss -3237095.0, Validation loss -1609181.75
EPOCH=90  validation_correlation=0.030606338754296303


Epoch 91: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 91, Batch 28, Train loss -3091170.5, Validation loss -1536083.375
EPOCH=91  validation_correlation=0.04965883865952492


Epoch 92: 100%|██████████| 29/29 [00:04<00:00,  6.23it/s]


Epoch 92, Batch 28, Train loss -3111557.5, Validation loss -1548318.0
EPOCH=92  validation_correlation=0.04940066859126091


Epoch 93: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 93, Batch 28, Train loss -3157066.0, Validation loss -1567264.0
EPOCH=93  validation_correlation=0.05357063561677933


Epoch 94: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 94, Batch 28, Train loss -2843709.0, Validation loss -1413340.875
EPOCH=94  validation_correlation=0.04461215063929558


Epoch 95: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 95, Batch 28, Train loss -3205591.75, Validation loss -1591519.75
EPOCH=95  validation_correlation=0.04446060582995415


Epoch 96: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 96, Batch 28, Train loss -3067810.0, Validation loss -1523278.5
EPOCH=96  validation_correlation=0.03260447457432747


Epoch 97: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 97, Batch 28, Train loss -3139947.75, Validation loss -1559964.625
EPOCH=97  validation_correlation=0.036291904747486115


Epoch 98: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 98, Batch 28, Train loss -3157748.75, Validation loss -1568417.375
EPOCH=98  validation_correlation=0.04878944158554077


Epoch 99: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 99, Batch 28, Train loss -3508351.5, Validation loss -1742319.625
EPOCH=99  validation_correlation=0.039415426552295685


Epoch 100: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


Epoch 100, Batch 28, Train loss -3228987.0, Validation loss -1606873.125
EPOCH=100  validation_correlation=0.046554163098335266


Epoch 101: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


Epoch 101, Batch 28, Train loss -3285456.0, Validation loss -1632806.25
EPOCH=101  validation_correlation=0.04957132413983345


Epoch 102: 100%|██████████| 29/29 [00:04<00:00,  6.25it/s]


Epoch 102, Batch 28, Train loss -3125347.75, Validation loss -1552652.0
EPOCH=102  validation_correlation=0.029397454112768173


Epoch 103: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 103, Batch 28, Train loss -3058045.0, Validation loss -1519443.5
EPOCH=103  validation_correlation=0.05284101888537407


Epoch 104: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 104, Batch 28, Train loss -3056635.0, Validation loss -1518443.25
EPOCH=104  validation_correlation=0.042033568024635315


Epoch 105: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 105, Batch 28, Train loss -3127389.75, Validation loss -1552400.5
EPOCH=105  validation_correlation=0.047075167298316956


Epoch 106: 100%|██████████| 29/29 [00:04<00:00,  6.33it/s]


Epoch 106, Batch 28, Train loss -3315509.0, Validation loss -1649287.125
EPOCH=106  validation_correlation=0.04998590052127838


Epoch 107: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 107, Batch 28, Train loss -3137019.75, Validation loss -1558682.75
EPOCH=107  validation_correlation=0.02969784289598465


Epoch 108: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


Epoch 108, Batch 28, Train loss -3133399.5, Validation loss -1557461.875
EPOCH=108  validation_correlation=0.037484344094991684


Epoch 109: 100%|██████████| 29/29 [00:04<00:00,  6.28it/s]


Epoch 109, Batch 28, Train loss -2870120.25, Validation loss -1426609.375
EPOCH=109  validation_correlation=0.0343363955616951


Epoch 110: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 110, Batch 28, Train loss -3128287.0, Validation loss -1554519.625
EPOCH=110  validation_correlation=0.05637826770544052


Epoch 111: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 111, Batch 28, Train loss -2799486.0, Validation loss -1390194.5
EPOCH=111  validation_correlation=0.04275144264101982


Epoch 112: 100%|██████████| 29/29 [00:04<00:00,  6.29it/s]


Epoch 112, Batch 28, Train loss -3111270.5, Validation loss -1544834.375
EPOCH=112  validation_correlation=0.03362998366355896


Epoch 113: 100%|██████████| 29/29 [00:04<00:00,  6.33it/s]


Epoch 113, Batch 28, Train loss -3026580.75, Validation loss -1506094.375
EPOCH=113  validation_correlation=0.05970250070095062


Epoch 114: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 114, Batch 28, Train loss -3196848.0, Validation loss -1589147.5
EPOCH=114  validation_correlation=0.04624162241816521


Epoch 115: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 115, Batch 28, Train loss -2970816.0, Validation loss -1476718.25
EPOCH=115  validation_correlation=0.03612462058663368


Epoch 116: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 116, Batch 28, Train loss -3584385.25, Validation loss -1781200.25
EPOCH=116  validation_correlation=0.01956058293581009


Epoch 117: 100%|██████████| 29/29 [00:04<00:00,  6.33it/s]


Epoch 117, Batch 28, Train loss -2877126.0, Validation loss -1427773.875
EPOCH=117  validation_correlation=0.06216846778988838


Epoch 118: 100%|██████████| 29/29 [00:04<00:00,  6.32it/s]


Epoch 118, Batch 28, Train loss -3019131.75, Validation loss -1498893.5
EPOCH=118  validation_correlation=0.030085766687989235


Epoch 119: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 119, Batch 28, Train loss -3091865.75, Validation loss -1536264.75
EPOCH=119  validation_correlation=0.03461724892258644


Epoch 120: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


Epoch 120, Batch 28, Train loss -3194247.75, Validation loss -1585775.25
EPOCH=120  validation_correlation=0.05703132972121239


Epoch 121: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 121, Batch 28, Train loss -3076192.5, Validation loss -1527706.875
EPOCH=121  validation_correlation=0.038802556693553925


Epoch 122: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 122, Batch 28, Train loss -3238487.75, Validation loss -1610189.375
EPOCH=122  validation_correlation=0.03193923458456993


Epoch 123: 100%|██████████| 29/29 [00:04<00:00,  6.33it/s]


Epoch 123, Batch 28, Train loss -3435406.25, Validation loss -1708437.375
EPOCH=123  validation_correlation=0.039305802434682846


Epoch 124: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 124, Batch 28, Train loss -3248990.75, Validation loss -1615436.0
EPOCH=124  validation_correlation=0.04170682653784752


Epoch 125: 100%|██████████| 29/29 [00:04<00:00,  6.36it/s]


Epoch 125, Batch 28, Train loss -3184756.25, Validation loss -1584540.0
EPOCH=125  validation_correlation=0.03185176104307175


Epoch 126: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 126, Batch 28, Train loss -3148112.75, Validation loss -1565136.25
EPOCH=126  validation_correlation=0.02371549978852272


Epoch 127: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


Epoch 127, Batch 28, Train loss -2884005.75, Validation loss -1433668.625
EPOCH=127  validation_correlation=0.040688805282115936


Epoch 128: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


Epoch 128, Batch 28, Train loss -3177653.75, Validation loss -1580184.25
EPOCH=128  validation_correlation=0.041169650852680206


 FINAL validation_correlation [-0.06718403 -0.0644016   0.0400522  -0.0663787  -0.01024811  0.00616032
 -0.0008841   0.11076792  0.08815111 -0.11182246 -0.08437144  0.01417047
 -0.00736432  0.23409225  0.00103776 -0.07435224  0.02613295  0.0634546
 -0.08594895  0.12699148 -0.1008964  -0.0063795  -0.20199959  0.02086805
  0.06189115  0.08961201  0.26452672  0.22063406  0.00995078  0.19829
 -0.04413905  0.04691803  0.03633422 -0.0198384   0.2625047  -0.00298086
  0.16035302  0.10718294  0.13476226  0.03611291 -0.05419668  0.15522555
  0.07852341 -0.02301615  0.11732852  0.08506277  0.2182685   0.23364437
  0.08072647 -0.0280943   0.19055094  0.04398108 -0.04842104  0.1659322
  0.05171755  0.11828978  0.0479669   0.05554968 -0.05912356 -0.03054831
  0.05956043  0.00963165  0.16569598  0.07793544  0.2770273   0.0965395
 -0.06584385 -0.09317604  0.32329857  0.188

0,1
Batch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch Train loss,█▅▃▄▇▃▁▃▄▁▃▅▅▂▃▅█▅▄▃▃▃▃▇▃▅▃▅▃▂▂▅▅▄▄▄▅▅▂▃
Epoch validation loss,▅▅▃▅▅▃▃▅▇▃▅▇▁▂▁▁▆▃█▃▃▃▄▆▄▄▅▇▄▇▄▄▅▄▄▄▆▄▃▄
mean_validation_correlation,▁▃▅▆▆▄▆▆▆▅▆▆▄▄▃▄▅█▆▆▅▇▅▅▆▄▆▆▆▆▆▅▆▅▇▆▅▇▆▆

0,1
Batch,3712.0
Epoch,128.0
Epoch Train loss,-90377760.0
Epoch validation loss,-1580184.25
mean_validation_correlation,0.04117


  xavier_normal(m.weight.data)
  init.constant(m.bias.data, 0.0)


optim_step_count = 1


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016672157868742942, max=1.0…

  transform = lambda x: (x - self._inputs_mean) / self._inputs_std
Epoch 1: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


Epoch 1, Batch 28, Train loss -2853773.5, Validation loss -1413654.375
EPOCH=1  validation_correlation=-0.00882685650140047


Epoch 2: 100%|██████████| 29/29 [00:04<00:00,  6.36it/s]


Epoch 2, Batch 28, Train loss -2851835.75, Validation loss -1414200.0
EPOCH=2  validation_correlation=0.012248002924025059


Epoch 3: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


Epoch 3, Batch 28, Train loss -3149626.25, Validation loss -1563867.375
EPOCH=3  validation_correlation=0.017127439379692078


Epoch 4: 100%|██████████| 29/29 [00:04<00:00,  6.41it/s]


Epoch 4, Batch 28, Train loss -2890010.5, Validation loss -1438407.875
EPOCH=4  validation_correlation=0.029246674850583076


Epoch 5: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


Epoch 5, Batch 28, Train loss -3102130.25, Validation loss -1541625.75
EPOCH=5  validation_correlation=0.040390871465206146


Epoch 6: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


Epoch 6, Batch 28, Train loss -2707915.25, Validation loss -1345800.5
EPOCH=6  validation_correlation=0.05778616666793823


Epoch 7: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


Epoch 7, Batch 28, Train loss -2821273.0, Validation loss -1400662.875
EPOCH=7  validation_correlation=0.02689231000840664


Epoch 8: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


Epoch 8, Batch 28, Train loss -2883300.0, Validation loss -1436315.5
EPOCH=8  validation_correlation=0.051225971430540085


Epoch 9: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


Epoch 9, Batch 28, Train loss -2922783.5, Validation loss -1455240.0
EPOCH=9  validation_correlation=0.024364352226257324


Epoch 10: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


Epoch 10, Batch 28, Train loss -3116741.25, Validation loss -1550526.0
EPOCH=10  validation_correlation=0.048957906663417816


Epoch 11: 100%|██████████| 29/29 [00:04<00:00,  6.43it/s]


Epoch 11, Batch 28, Train loss -2774018.0, Validation loss -1382140.25
EPOCH=11  validation_correlation=0.04089833050966263


Epoch 12: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


Epoch 12, Batch 28, Train loss -2693212.5, Validation loss -1333883.0
EPOCH=12  validation_correlation=0.0267281923443079


Epoch 13: 100%|██████████| 29/29 [00:04<00:00,  6.36it/s]


Epoch 13, Batch 28, Train loss -2898748.5, Validation loss -1438151.0
EPOCH=13  validation_correlation=0.04443557187914848


Epoch 14: 100%|██████████| 29/29 [00:04<00:00,  6.41it/s]


Epoch 14, Batch 28, Train loss -2884427.0, Validation loss -1434766.5
EPOCH=14  validation_correlation=0.021150890737771988


Epoch 15: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


Epoch 15, Batch 28, Train loss -3082063.25, Validation loss -1533830.625
EPOCH=15  validation_correlation=0.043102655559778214
Epoch 00015: reducing learning rate of group 0 to 1.5000e-03.


Epoch 16: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 16, Batch 28, Train loss -2856428.5, Validation loss -1422569.625
EPOCH=16  validation_correlation=0.032587651163339615


Epoch 17: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


Epoch 17, Batch 28, Train loss -2822857.0, Validation loss -1404809.375
EPOCH=17  validation_correlation=0.0157190952450037


Epoch 18: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 18, Batch 28, Train loss -2827265.25, Validation loss -1405046.375
EPOCH=18  validation_correlation=0.021850839257240295


Epoch 19: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 19, Batch 28, Train loss -3004453.75, Validation loss -1496524.625
EPOCH=19  validation_correlation=0.008108135312795639


Epoch 20: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


Epoch 20, Batch 28, Train loss -2936190.0, Validation loss -1459060.375
EPOCH=20  validation_correlation=0.051972389221191406


Epoch 21: 100%|██████████| 29/29 [00:04<00:00,  6.41it/s]


Epoch 21, Batch 28, Train loss -2940224.25, Validation loss -1463087.5
EPOCH=21  validation_correlation=0.03433777764439583


Epoch 22: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


Epoch 22, Batch 28, Train loss -3136565.75, Validation loss -1559625.625
EPOCH=22  validation_correlation=0.030601177364587784


Epoch 23: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


Epoch 23, Batch 28, Train loss -2692394.75, Validation loss -1338388.5
EPOCH=23  validation_correlation=0.027738897129893303


Epoch 24: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


Epoch 24, Batch 28, Train loss -3012980.25, Validation loss -1497723.0
EPOCH=24  validation_correlation=0.05442200228571892


Epoch 25: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


Epoch 25, Batch 28, Train loss -2804080.5, Validation loss -1393024.375
EPOCH=25  validation_correlation=0.022718697786331177


Epoch 26: 100%|██████████| 29/29 [00:04<00:00,  6.41it/s]


Epoch 26, Batch 28, Train loss -2957077.0, Validation loss -1477254.75
EPOCH=26  validation_correlation=0.0373738631606102
Epoch 00026: reducing learning rate of group 0 to 4.5000e-04.


Epoch 27: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 27, Batch 28, Train loss -3103529.0, Validation loss -1543274.25
EPOCH=27  validation_correlation=0.026891009882092476


Epoch 28: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 28, Batch 28, Train loss -2860429.0, Validation loss -1422579.25
EPOCH=28  validation_correlation=0.027627797797322273


Epoch 29: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 29, Batch 28, Train loss -3117735.25, Validation loss -1550222.25
EPOCH=29  validation_correlation=0.017810411751270294


Epoch 30: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 30, Batch 28, Train loss -2746062.75, Validation loss -1364601.875
EPOCH=30  validation_correlation=0.020006002858281136


Epoch 31: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 31, Batch 28, Train loss -3012595.75, Validation loss -1496953.75
EPOCH=31  validation_correlation=0.0357472226023674


Epoch 32: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 32, Batch 28, Train loss -3059789.5, Validation loss -1523442.25
EPOCH=32  validation_correlation=0.012261069379746914


Epoch 33: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 33, Batch 28, Train loss -2944938.5, Validation loss -1465746.875
EPOCH=33  validation_correlation=0.03866678476333618


Epoch 34: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 34, Batch 28, Train loss -2686568.25, Validation loss -1335936.375
EPOCH=34  validation_correlation=0.017929071560502052


Epoch 35: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 35, Batch 28, Train loss -2864869.0, Validation loss -1424901.125
EPOCH=35  validation_correlation=0.01594049483537674


Epoch 36: 100%|██████████| 29/29 [00:04<00:00,  6.41it/s]


Epoch 36, Batch 28, Train loss -2906056.75, Validation loss -1444075.5
EPOCH=36  validation_correlation=0.020943183451890945


Epoch 37: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 37, Batch 28, Train loss -2915918.25, Validation loss -1450069.375
EPOCH=37  validation_correlation=0.010617817752063274
Epoch 00037: reducing learning rate of group 0 to 1.3500e-04.


Epoch 38: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 38, Batch 28, Train loss -2909516.25, Validation loss -1446899.375
EPOCH=38  validation_correlation=0.01960379257798195


Epoch 39: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 39, Batch 28, Train loss -2862437.75, Validation loss -1424725.5
EPOCH=39  validation_correlation=0.02968895062804222


Epoch 40: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 40, Batch 28, Train loss -2905298.75, Validation loss -1443143.0
EPOCH=40  validation_correlation=0.04204612225294113


Epoch 41: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 41, Batch 28, Train loss -2914109.75, Validation loss -1449376.875
EPOCH=41  validation_correlation=0.008627152070403099


Epoch 42: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


Epoch 42, Batch 28, Train loss -2874377.25, Validation loss -1429191.25
EPOCH=42  validation_correlation=0.016915762796998024


Epoch 43: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 43, Batch 28, Train loss -2952537.75, Validation loss -1466888.875
EPOCH=43  validation_correlation=0.01864621788263321


Epoch 44: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


Epoch 44, Batch 28, Train loss -2956734.5, Validation loss -1469647.875
EPOCH=44  validation_correlation=0.033875469118356705


Epoch 45: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 45, Batch 28, Train loss -2856418.25, Validation loss -1418929.5
EPOCH=45  validation_correlation=0.02599361352622509


Epoch 46: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 46, Batch 28, Train loss -3067417.25, Validation loss -1524550.375
EPOCH=46  validation_correlation=0.03660871833562851


Epoch 47: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 47, Batch 28, Train loss -2711500.25, Validation loss -1348155.75
EPOCH=47  validation_correlation=0.04529443010687828


Epoch 48: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 48, Batch 28, Train loss -2896174.25, Validation loss -1440414.875
EPOCH=48  validation_correlation=0.02948862500488758
Epoch 00048: reducing learning rate of group 0 to 1.0000e-04.


Epoch 49: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 49, Batch 28, Train loss -2897039.5, Validation loss -1440590.875
EPOCH=49  validation_correlation=0.03903203085064888


Epoch 50: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 50, Batch 28, Train loss -3039655.25, Validation loss -1510624.0
EPOCH=50  validation_correlation=0.011102539487183094


Epoch 51: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 51, Batch 28, Train loss -3042971.5, Validation loss -1512673.75
EPOCH=51  validation_correlation=0.03597823902964592


Epoch 52: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


Epoch 52, Batch 28, Train loss -2920336.5, Validation loss -1451843.125
EPOCH=52  validation_correlation=0.03940833732485771


Epoch 53: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 53, Batch 28, Train loss -2918719.25, Validation loss -1450813.625
EPOCH=53  validation_correlation=0.03637843206524849


Epoch 54: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 54, Batch 28, Train loss -3004635.75, Validation loss -1495347.5
EPOCH=54  validation_correlation=0.0451173335313797


Epoch 55: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


Epoch 55, Batch 28, Train loss -2795328.5, Validation loss -1389799.5
EPOCH=55  validation_correlation=0.04086482152342796


Epoch 56: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 56, Batch 28, Train loss -2985752.25, Validation loss -1483210.5
EPOCH=56  validation_correlation=0.016456395387649536


Epoch 57: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


Epoch 57, Batch 28, Train loss -2987322.5, Validation loss -1484929.375
EPOCH=57  validation_correlation=0.03340781480073929


Epoch 58: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 58, Batch 28, Train loss -2962356.25, Validation loss -1471677.5
EPOCH=58  validation_correlation=0.014660246670246124


Epoch 59: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 59, Batch 28, Train loss -2837937.25, Validation loss -1411364.875
EPOCH=59  validation_correlation=0.00996240321546793


Epoch 60: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 60, Batch 28, Train loss -2834994.0, Validation loss -1410899.625
EPOCH=60  validation_correlation=0.008720438927412033


Epoch 61: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 61, Batch 28, Train loss -2744494.5, Validation loss -1364872.25
EPOCH=61  validation_correlation=0.02419888973236084


Epoch 62: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


Epoch 62, Batch 28, Train loss -2985194.0, Validation loss -1484616.75
EPOCH=62  validation_correlation=0.026016870513558388


Epoch 63: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


Epoch 63, Batch 28, Train loss -2882247.75, Validation loss -1432267.75
EPOCH=63  validation_correlation=0.04722100496292114


Epoch 64: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 64, Batch 28, Train loss -2936093.0, Validation loss -1458406.375
EPOCH=64  validation_correlation=0.02268921583890915


Epoch 65: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 65, Batch 28, Train loss -2950534.5, Validation loss -1465684.625
EPOCH=65  validation_correlation=0.030361903831362724


Epoch 66: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 66, Batch 28, Train loss -2853965.75, Validation loss -1417430.25
EPOCH=66  validation_correlation=0.02069445326924324


Epoch 67: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 67, Batch 28, Train loss -2701563.75, Validation loss -1343800.0
EPOCH=67  validation_correlation=0.024065112695097923


Epoch 68: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 68, Batch 28, Train loss -2823086.25, Validation loss -1402883.25
EPOCH=68  validation_correlation=0.016673246398568153


Epoch 69: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 69, Batch 28, Train loss -2820968.0, Validation loss -1401529.25
EPOCH=69  validation_correlation=0.03731439635157585


Epoch 70: 100%|██████████| 29/29 [00:04<00:00,  6.48it/s]


Epoch 70, Batch 28, Train loss -2931107.75, Validation loss -1457559.25
EPOCH=70  validation_correlation=0.020803334191441536


Epoch 71: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


Epoch 71, Batch 28, Train loss -3163300.25, Validation loss -1573596.375
EPOCH=71  validation_correlation=0.016361327841877937


Epoch 72: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 72, Batch 28, Train loss -2886504.5, Validation loss -1436443.125
EPOCH=72  validation_correlation=0.03852907568216324


Epoch 73: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 73, Batch 28, Train loss -2961440.25, Validation loss -1472873.0
EPOCH=73  validation_correlation=0.01893671043217182


Epoch 74: 100%|██████████| 29/29 [00:04<00:00,  6.62it/s]


Epoch 74, Batch 28, Train loss -3061026.75, Validation loss -1521177.75
EPOCH=74  validation_correlation=0.039369601756334305


Epoch 75: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


Epoch 75, Batch 28, Train loss -2697719.0, Validation loss -1339516.125
EPOCH=75  validation_correlation=0.025863591581583023


Epoch 76: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 76, Batch 28, Train loss -3034713.0, Validation loss -1508839.5
EPOCH=76  validation_correlation=0.039037611335515976


Epoch 77: 100%|██████████| 29/29 [00:04<00:00,  6.54it/s]


Epoch 77, Batch 28, Train loss -2811866.0, Validation loss -1398240.5
EPOCH=77  validation_correlation=0.021379608660936356


Epoch 78: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 78, Batch 28, Train loss -2994979.5, Validation loss -1488016.625
EPOCH=78  validation_correlation=0.038022808730602264


Epoch 79: 100%|██████████| 29/29 [00:04<00:00,  6.57it/s]


Epoch 79, Batch 28, Train loss -2800017.75, Validation loss -1390969.625
EPOCH=79  validation_correlation=0.027208887040615082


Epoch 80: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


Epoch 80, Batch 28, Train loss -2854867.5, Validation loss -1419614.0
EPOCH=80  validation_correlation=0.0287049300968647


Epoch 81: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


Epoch 81, Batch 28, Train loss -2952405.0, Validation loss -1469165.75
EPOCH=81  validation_correlation=0.03328492492437363


Epoch 82: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 82, Batch 28, Train loss -3040653.75, Validation loss -1510571.25
EPOCH=82  validation_correlation=0.02822674624621868


Epoch 83: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 83, Batch 28, Train loss -2852506.5, Validation loss -1419479.0
EPOCH=83  validation_correlation=0.012475614435970783


Epoch 84: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 84, Batch 28, Train loss -2810434.5, Validation loss -1396622.375
EPOCH=84  validation_correlation=0.017344722524285316


Epoch 85: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 85, Batch 28, Train loss -2834721.5, Validation loss -1409039.25
EPOCH=85  validation_correlation=0.008249226026237011


Epoch 86: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 86, Batch 28, Train loss -3029155.25, Validation loss -1504625.375
EPOCH=86  validation_correlation=0.025249527767300606


Epoch 87: 100%|██████████| 29/29 [00:04<00:00,  6.55it/s]


Epoch 87, Batch 28, Train loss -2914074.25, Validation loss -1448565.25
EPOCH=87  validation_correlation=0.03838242217898369


Epoch 88: 100%|██████████| 29/29 [00:04<00:00,  6.52it/s]


Epoch 88, Batch 28, Train loss -2909847.0, Validation loss -1445745.75
EPOCH=88  validation_correlation=0.03641900047659874


Epoch 89: 100%|██████████| 29/29 [00:04<00:00,  6.57it/s]


Epoch 89, Batch 28, Train loss -2919557.0, Validation loss -1452673.875
EPOCH=89  validation_correlation=0.04856790602207184


Epoch 90: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 90, Batch 28, Train loss -2928295.5, Validation loss -1456224.5
EPOCH=90  validation_correlation=0.04725898802280426


Epoch 91: 100%|██████████| 29/29 [00:04<00:00,  6.50it/s]


Epoch 91, Batch 28, Train loss -2949105.25, Validation loss -1466507.75
EPOCH=91  validation_correlation=0.027119746431708336


Epoch 92: 100%|██████████| 29/29 [00:04<00:00,  6.62it/s]


Epoch 92, Batch 28, Train loss -2838528.75, Validation loss -1409763.125
EPOCH=92  validation_correlation=0.03636354207992554


Epoch 93: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 93, Batch 28, Train loss -2964272.25, Validation loss -1474067.25
EPOCH=93  validation_correlation=0.037174999713897705


Epoch 94: 100%|██████████| 29/29 [00:04<00:00,  6.60it/s]


Epoch 94, Batch 28, Train loss -2766421.75, Validation loss -1375484.375
EPOCH=94  validation_correlation=0.03125428408384323


Epoch 95: 100%|██████████| 29/29 [00:04<00:00,  6.53it/s]


Epoch 95, Batch 28, Train loss -2749034.0, Validation loss -1366065.5
EPOCH=95  validation_correlation=0.026393814012408257


Epoch 96: 100%|██████████| 29/29 [00:04<00:00,  6.56it/s]


Epoch 96, Batch 28, Train loss -2861884.0, Validation loss -1422369.75
EPOCH=96  validation_correlation=0.03799669072031975


Epoch 97: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 97, Batch 28, Train loss -3041930.25, Validation loss -1510889.125
EPOCH=97  validation_correlation=0.020800543949007988


Epoch 98: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 98, Batch 28, Train loss -2828641.5, Validation loss -1406668.25
EPOCH=98  validation_correlation=0.03934235870838165


Epoch 99: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


Epoch 99, Batch 28, Train loss -2906706.25, Validation loss -1443948.5
EPOCH=99  validation_correlation=0.028277523815631866


Epoch 100: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 100, Batch 28, Train loss -2989634.75, Validation loss -1485940.875
EPOCH=100  validation_correlation=0.028271351009607315


Epoch 101: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


Epoch 101, Batch 28, Train loss -2802277.75, Validation loss -1392846.125
EPOCH=101  validation_correlation=0.03436623141169548


Epoch 102: 100%|██████████| 29/29 [00:04<00:00,  6.59it/s]


Epoch 102, Batch 28, Train loss -2773465.0, Validation loss -1378426.875
EPOCH=102  validation_correlation=0.02686525695025921


Epoch 103: 100%|██████████| 29/29 [00:04<00:00,  6.46it/s]


Epoch 103, Batch 28, Train loss -2870636.25, Validation loss -1426348.25
EPOCH=103  validation_correlation=0.04084144905209541


Epoch 104: 100%|██████████| 29/29 [00:04<00:00,  6.58it/s]


Epoch 104, Batch 28, Train loss -2952007.25, Validation loss -1467035.125
EPOCH=104  validation_correlation=0.015016558580100536


 FINAL validation_correlation [-2.23527819e-01  1.01493828e-01  1.85111705e-02  1.34007499e-01
 -9.52667221e-02  9.68854949e-02 -3.75080705e-02  9.45146978e-02
  4.76465337e-02  5.40753938e-02  5.38817942e-02  3.51026617e-02
  7.58177340e-02 -5.10618128e-02 -9.38764960e-02  1.10812232e-01
 -2.31051687e-02 -6.61231205e-02  4.97053750e-02 -4.30180132e-02
  1.71934580e-03  6.96269423e-02 -8.42740759e-02  7.10252151e-02
 -2.10814551e-02  1.28345266e-01 -6.87920004e-02  1.07648484e-01
 -5.93202747e-02  9.37053338e-02  4.92502637e-02 -3.29988264e-02
 -5.97144887e-02 -1.50160849e-01  3.17727923e-02 -2.98195649e-02
 -4.21201549e-02  1.54503763e-01  2.51672640e-02  3.68089345e-03
 -2.43673399e-02  1.02621003e-03  7.34241307e-02  6.10230565e-02
  2.69292183e-02  1.27484962e-01  9.76711437e-02  1.66094095e-01
  4.33492549e-02  2.07817201e-02 -9.13386196e-02  9.8948016

VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Batch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Epoch Train loss,▇▂▄▅▅▁▄▂▅▂▂▃▂▅▂▄█▂▅▃▃▁▄▄▂▅▅▅▅▃▄▃▂▄▃▂▆▅▃▂
Epoch validation loss,▆▁█▅▇▅▅▃▁▃▂▇▂▅▅▅▄▆▅▃▅▄▆▇▄█▆▅▂▆▆▃▆▅▄▄▆▆▆▄
mean_validation_correlation,▁▄█▇▆▄▅▃▅█▅▄▃▄▃▆▄▅▅▆▆▄▃▄▄▄▆▆▆▄▅▅▃▆▇▆▆▆▆▄

0,1
Batch,3016.0
Epoch,104.0
Epoch Train loss,-90379920.0
Epoch validation loss,-1467035.125
mean_validation_correlation,0.01502


In [None]:
# full grid search - too late to start this rn! :P
for hc in hidden_channels: # 4 *
    for ik in input_kern_sizes: # 4 * 
        for hk in hidden_kern_sizes: # 4 *
            for d in depths: # 3 *
                for rc in rec_channels_gru: # 4 *
                    for ikg in input_kern_sizes_gru: #4 *
                        for hkg in hidden_kern_sizes_gru: #4 *
                            core_dict = make_core_dict(hc, ik, hk, d)
                            gru_dict = make_gru_dict(hc, rc, ikg, hkg) # input channels always same as hc

                            # make full model
                            gru_2d_model = make_video_model(data_loaders,
                                             seed,
                                             core_dict=core_dict.copy(),
                                             core_type='2D',
                                             readout_dict=readout_dict.copy(),
                                             readout_type='gaussian',               
                                             use_gru=True,
                                             gru_dict=gru_dict.copy(),
                                             use_shifter=False,
                                             deeplake_ds=True
                                             )
                            # train
                            validation_score, trainer_output, state_dict = trainer(gru_2d_model)

# We have a model running!
Finally

I need to clean things up now.

Constructing a repository - What did I actually make?  Stuff that happens between pickle format of data and deeplake dataset, then multiple deeplake datasets to multiple dataloaders, then training multiple dataloaders with models.

Most of the last two steps was already done by the sensorium 2023 repository.  I need to fetch the new commits, and probably rewrite nicely the edits i've made to their shit.

I think a repo would be focused on the sort of data fusion and construction process.  I supposed I wouldn't wanna make it conflict with the sensorium structure too much...but I'm on the fence about whether I should install that repo, or fork my own and make changes.

How am I computing firing rate?  What parameters of that process do we want to make
I need to append behavioral, optogenetic, and pupil data as well.  The shifter is i think maybe the most important first.

The behavior we can just append as channels like they do, but the optogenetics idk.  I wonder if it should really affect the features computed by the core, or rather affect the readout (or recurrence) modules.  For the GRU recurrence I think we could assign an optimizable weight to integrate it into the gate update equations.  For the readout I am less sure how to do that...I mean honestly maybe a really dumb way to do it would be to have two columns for the weight vectors (so like 2 different readouts) and then just switch which one you select depending on whether opto is on or off... this could be done in the forward pass of the readout module I think.

These two might be worth a commit to neural predictors.

But first I have to clean my workspace


### Testing

In [13]:
val_correlation = get_correlations(
        gru_2d_model, data_loaders["oracle"], device='cuda', as_dict=False, per_neuron=False
    )

print(val_correlation)

0.08788537


## Other models example configs
### Rotation Equivariant 2D core, GRU and Gaussian readout

Here everything is the same except for the core. The core is rotation equivariant, inspired by Ecker et all 2018 [4]. Please note that the final amount of channels then is number_of_rotations * number_of_channels, as each channel is rotated number_of_rotations times.

---
[4] Ecker, A. S., Sinz, F. H., Froudarakis, E., Fahey, P. G., Cadena, S. A., Walker, E. Y., ... & Bethge, M. (2018). A rotation-equivariant convolutional neural network model of primary visual cortex. arXiv preprint arXiv:1809.10504.

In [31]:
equivar_2D_core_dict = dict(
    input_channels=4,
    hidden_channels=8,
    input_kern=9,
    hidden_kern=7,
    layers=4,
    num_rotations=8,
    gamma_input=500,
    skip=0,
    pad_input=False,
    final_nonlinearity=False,
    bias=True,
    momentum=0.9,
    batch_norm=True,
    hidden_dilation=1,
    laplace_padding=None,
    input_regularizer="LaplaceL2norm",
    stack=-1,
    depth_separable=False,
    linear=False,
    attention_conv=False,
    hidden_padding=None,
    use_avg_reg=False,
    final_batchnorm_scale=True,
    gamma_hidden=500_000,
)

In [34]:
gru_2d_model_equivariant = make_video_model(data_loaders,
                 seed,
                 core_dict=equivar_2D_core_dict,
                 core_type='2D_equivariant',
                 readout_dict=readout_dict.copy(),
                 readout_type='gaussian',               
                 use_gru=True,
                 gru_dict=gru_dict,
                 use_shifter=True,
                 shifter_dict=shifter_dict,
                 shifter_type='MLP',
                 deeplake_ds=True
                 )

/gpfs01/berens/data/data/sensorium23/_gpfs01_berens_data_data_sensorium23_hub_sinzlab_sensorium2023_dynamic29156-11-10-Video-8744edeac3b4d1ce16b680916b5267ce_train loaded successfully.

** Loaded local copy of dataset. Downloaded on: Wed Jun 21 16:42:54 2023




In [35]:
gru_2d_model_equivariant

VideoFiringRateEncoder(
  (core): RotationEquivariant2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (hermite_conv): HermiteConv2D(
          (rotate_hermite): RotateHermite(
            (Rs): ParameterList(
                (0): Parameter containing: [torch.float32 of size 45x45]
                (1): Parameter containing: [torch.float32 of size 45x45]
                (2): Parameter containing: [torch.float32 of size 45x45]
                (3): Parameter containing: [torch.float32 of size 45x45]
                (4): Parameter containing: [torch.float32 of size 45x45]
                (5): Parameter containing: [torch.float32 of size 45x45]
                (6): Parameter containing: [torch.float32 of size 45x45]
                (7): Parameter containing: [torch.float32 of size 45x45]
            )
          )
        )
        (norm): RotationEquivariantBatchNorm2D(
          (batch_n

In [7]:
trainer_fn = "sensorium.training.video_training_loop.standard_trainer"

trainer_config = {
    'dataloaders' : data_loaders,
    'seed' : 111,
    'use_wandb' : False,
    'verbose': True,
    'lr_decay_steps': 4,
    'lr_init': 0.005,
    'device' : f"cuda",
    'detach_core' : False,
    'checkpoint_save_path' : '/gpfs01/berens/user/ncimaszewski/my-docker-folder/ncimaszewski/dl_for_sensorium/sensorium_2023/state_dicts/gru_2d_equiv/',
                 }

trainer = get_trainer(trainer_fn=trainer_fn, 
                 trainer_config=trainer_config)

In [38]:
#17 -10 
validation_score, trainer_output, state_dict = trainer(gru_2d_model_equivariant)

within standard_trainer from sensorium.training.video_training_loop.py
stop_function is get_correlations
device is cuda
optim_step_count = 1


Epoch 1: 100%|██████████| 23/23 [00:57<00:00,  2.50s/it]


Epoch 1, Batch 22, Train loss 13209597.0, Validation loss 5387805.0
EPOCH=1  validation_correlation=0.004065733402967453


Epoch 2: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 2, Batch 22, Train loss 13088991.0, Validation loss 5438223.0
EPOCH=2  validation_correlation=0.012634330429136753


Epoch 3: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 3, Batch 22, Train loss 12111847.0, Validation loss 5077576.5
EPOCH=3  validation_correlation=0.013424641452729702


Epoch 4: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 4, Batch 22, Train loss 12147579.0, Validation loss 5112736.0
EPOCH=4  validation_correlation=0.013618553057312965


Epoch 5: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 5, Batch 22, Train loss 13352884.0, Validation loss 5580119.0
EPOCH=5  validation_correlation=0.017037495970726013


Epoch 6: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 6, Batch 22, Train loss 11914002.0, Validation loss 4995057.5
EPOCH=6  validation_correlation=0.0178961381316185


Epoch 7: 100%|██████████| 23/23 [00:27<00:00,  1.21s/it]


Epoch 7, Batch 22, Train loss 11803503.0, Validation loss 4959281.0
EPOCH=7  validation_correlation=0.012931333854794502


Epoch 8: 100%|██████████| 23/23 [00:28<00:00,  1.22s/it]


Epoch 8, Batch 22, Train loss 12688781.0, Validation loss 5302756.5
EPOCH=8  validation_correlation=0.02378493919968605


Epoch 9: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 9, Batch 22, Train loss 12920349.0, Validation loss 5382246.0
EPOCH=9  validation_correlation=0.02342301420867443


Epoch 10: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 10, Batch 22, Train loss 11831710.0, Validation loss 4944355.0
EPOCH=10  validation_correlation=0.020561395213007927


Epoch 11: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 11, Batch 22, Train loss 11966526.0, Validation loss 4991894.0
EPOCH=11  validation_correlation=0.02772560901939869


Epoch 12: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 12, Batch 22, Train loss 12382682.0, Validation loss 5153821.5
EPOCH=12  validation_correlation=0.019120562821626663


Epoch 13: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 13, Batch 22, Train loss 13212104.0, Validation loss 5459616.5
EPOCH=13  validation_correlation=0.01883845031261444


Epoch 14: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 14, Batch 22, Train loss 12631647.0, Validation loss 5241691.5
EPOCH=14  validation_correlation=0.03291740268468857


Epoch 15: 100%|██████████| 23/23 [00:28<00:00,  1.25s/it]


Epoch 15, Batch 22, Train loss 12500183.0, Validation loss 5190298.0
EPOCH=15  validation_correlation=0.0267714262008667


Epoch 16: 100%|██████████| 23/23 [00:27<00:00,  1.20s/it]


Epoch 16, Batch 22, Train loss 12782656.0, Validation loss 5284122.0
EPOCH=16  validation_correlation=0.02071988955140114


Epoch 17: 100%|██████████| 23/23 [00:28<00:00,  1.22s/it]


Epoch 17, Batch 22, Train loss 11933547.0, Validation loss 4965740.5
EPOCH=17  validation_correlation=0.025659622624516487


Epoch 18: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 18, Batch 22, Train loss 12498919.0, Validation loss 5178886.0
EPOCH=18  validation_correlation=0.01827332004904747


Epoch 19: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 19, Batch 22, Train loss 12088159.0, Validation loss 5031596.0
EPOCH=19  validation_correlation=0.027467690408229828
Epoch 00019: reducing learning rate of group 0 to 1.5000e-03.


Epoch 20: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 20, Batch 22, Train loss 12793813.0, Validation loss 5321807.0
EPOCH=20  validation_correlation=0.02427278831601143


Epoch 21: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 21, Batch 22, Train loss 11834389.0, Validation loss 4938628.0
EPOCH=21  validation_correlation=0.027632595971226692


Epoch 22: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 22, Batch 22, Train loss 12084474.0, Validation loss 5038585.0
EPOCH=22  validation_correlation=0.02050679922103882


Epoch 23: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 23, Batch 22, Train loss 12109378.0, Validation loss 5051768.5
EPOCH=23  validation_correlation=0.027415843680500984


Epoch 24: 100%|██████████| 23/23 [00:27<00:00,  1.21s/it]


Epoch 24, Batch 22, Train loss 12355611.0, Validation loss 5158060.5
EPOCH=24  validation_correlation=0.0237624142318964


Epoch 25: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 25, Batch 22, Train loss 11859044.0, Validation loss 4958921.0
EPOCH=25  validation_correlation=0.024255532771348953
Epoch 00025: reducing learning rate of group 0 to 4.5000e-04.


Epoch 26: 100%|██████████| 23/23 [00:28<00:00,  1.25s/it]


Epoch 26, Batch 22, Train loss 11602933.0, Validation loss 4859801.5
EPOCH=26  validation_correlation=0.025443311780691147


Epoch 27: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 27, Batch 22, Train loss 11738531.0, Validation loss 4910580.0
EPOCH=27  validation_correlation=0.021015914157032967


Epoch 28: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 28, Batch 22, Train loss 12351721.0, Validation loss 5160145.5
EPOCH=28  validation_correlation=0.02181219309568405


Epoch 29: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 29, Batch 22, Train loss 12061876.0, Validation loss 5046434.5
EPOCH=29  validation_correlation=0.02140902914106846


Epoch 30: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 30, Batch 22, Train loss 12408690.0, Validation loss 5184251.5
EPOCH=30  validation_correlation=0.021754754707217216


Epoch 31: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 31, Batch 22, Train loss 11870884.0, Validation loss 4969548.0
EPOCH=31  validation_correlation=0.017477726563811302


Epoch 32: 100%|██████████| 23/23 [00:28<00:00,  1.22s/it]


Epoch 32, Batch 22, Train loss 12466618.0, Validation loss 5214204.0
EPOCH=32  validation_correlation=0.020536569878458977


Epoch 33: 100%|██████████| 23/23 [00:28<00:00,  1.24s/it]


Epoch 33, Batch 22, Train loss 11907720.0, Validation loss 4986232.0
EPOCH=33  validation_correlation=0.024456635117530823


Epoch 34: 100%|██████████| 23/23 [00:29<00:00,  1.30s/it]


Epoch 34, Batch 22, Train loss 12180672.0, Validation loss 5098323.5
EPOCH=34  validation_correlation=0.007961125113070011


Epoch 35: 100%|██████████| 23/23 [00:28<00:00,  1.23s/it]


Epoch 35, Batch 22, Train loss 12941226.0, Validation loss 5404648.0
EPOCH=35  validation_correlation=0.021621011197566986


 FINAL validation_correlation 0.023504063487052917 




### Full 3D core and Gaussian readout

In [9]:
full_3D_core_dict = dict(
    input_channels=4,
    hidden_channels=[64]*2,
    input_kernel=9,
    hidden_kernel=7,
    num_frames=9,
    stride=1,
    layers=2,
    gamma_input_spatial=0,
    gamma_input_temporal=0,
    bias=True, 
    hidden_nonlinearities='elu',
    x_shift=0, 
    y_shift=0,
    batch_norm=True,
    laplace_padding=None,
    input_regularizer='LaplaceL2norm',
    padding=True,
    final_nonlin=True,
    independent_bn_bias=True,
    # pad_time=False,
    momentum=0.7,
    spatial_input_kernel=None,
)

In [10]:
full_3d_model = make_video_model(data_loaders,
                 seed,
                 core_dict=full_3D_core_dict,
                 core_type='3D',
                 readout_dict=readout_dict.copy(),
                 readout_type='gaussian',               
                 use_gru=False,
                 gru_dict=None,
                 use_shifter=True,
                 shifter_dict=shifter_dict,
                 shifter_type='MLP',
                 deeplake_ds=True
                 )

RuntimeError: running_mean should contain 64 elements not 4096

In [None]:
%debug

> [0;32m/.pyenv/versions/miniconda3-3.9-4.12.0/lib/python3.9/site-packages/torch/nn/functional.py[0m(2450)[0;36mbatch_norm[0;34m()[0m
[0;32m   2448 [0;31m        [0m_verify_batch_size[0m[0;34m([0m[0minput[0m[0;34m.[0m[0msize[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   2449 [0;31m[0;34m[0m[0m
[0m[0;32m-> 2450 [0;31m    return torch.batch_norm(
[0m[0;32m   2451 [0;31m        [0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m,[0m [0mrunning_mean[0m[0;34m,[0m [0mrunning_var[0m[0;34m,[0m [0mtraining[0m[0;34m,[0m [0mmomentum[0m[0;34m,[0m [0meps[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mbackends[0m[0;34m.[0m[0mcudnn[0m[0;34m.[0m[0menabled[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   2452 [0;31m    )
[0m


ipdb>  running_mean.shape


torch.Size([64, 64])


ipdb>  print(self)


*** NameError: name 'self' is not defined


In [15]:
full_3d_model

VideoFiringRateEncoder(
  (core): Basic3dCore(
    (_input_weight_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (temporal_regulazirer): TimeLaplaceL23d(
      (laplace): Laplace1d()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv3d(4, 64, kernel_size=(9, 9, 9), stride=(1, 1, 1), padding=(0, 4, 4))
        (norm): BatchNorm3d([64, 64], eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer1): Sequential(
        (conv_1): Conv3d(64, 64, kernel_size=(9, 7, 7), stride=(1, 1, 1), padding=(0, 3, 3))
        (norm): BatchNorm3d([64, 64], eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
    )
  ) [Basic3dCore regularizers: gamma_input_spatial = 0|gamma_input_temporal = 0]
  
  (readout): MultipleFullGaussian2d(
    (dynamic26872-19-13-Video-b580985c0d83307660a6109ec863aaed): full FullGaussian2d (64 x 22 x 50 -> 10968) wit

### Factorised 3D core and Gaussian readout

In [46]:
factorised_3D_core_dict = dict(
    input_channels=4,
    hidden_channels=[64]*3,
    spatial_input_kernel=(11,11),
    temporal_input_kernel=11,
    spatial_hidden_kernel=(7,7),
    temporal_hidden_kernel=7,
    stride=1,
    layers=3,
    gamma_input_spatial=0,
    gamma_input_temporal=0,
    bias=True, 
    hidden_nonlinearities='elu',
    x_shift=0, 
    y_shift=0,
    batch_norm=True,
    laplace_padding=None,
    input_regularizer='LaplaceL2norm',
    padding=True,
    final_nonlin=True,
    independent_bn_bias=True,
    # pad_time=False,
    momentum=0.7
)

In [47]:
factorised_3d_model = make_video_model(
    data_loaders,
    seed,
    core_dict=factorised_3D_core_dict,
    core_type='3D_factorised',
    readout_dict=readout_dict.copy(),
    readout_type='gaussian',               
    use_gru=False,
    gru_dict=None,
    use_shifter=True,
    shifter_dict=shifter_dict,
    shifter_type='MLP',
    deeplake_ds=True
)

RuntimeError: running_mean should contain 64 elements not 262144

In [None]:
factorised_3d_model

### Factorised 3D core and Factorised readouts 

In [7]:
factorised_readout_dict = dict(
    bias=True,
    gamma_readout=0.0,
    spatial_and_feature_reg_weight=0.0,
    positive_weights=False,
    normalize=False,
    init_noise=0.001,
)

In [None]:
factorised_3d_core_factorised_readout = make_video_model(
    data_loaders,
    seed,
    core_dict=factorised_3D_core_dict,
    core_type='3D_factorised',
    readout_dict=factorised_readout_dict,
    readout_type='factorised',               
    use_gru=False,
    gru_dict=None,
    use_shifter=True,
    shifter_dict=shifter_dict,
    shifter_type='MLP',
)

In [11]:
factorised_3d_core_factorised_readout

VideoFiringRateEncoder(
  (core): Factorized3dCore(
    (_input_weight_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (temporal_regularizer): TimeLaplaceL23d(
      (laplace): Laplace1d()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv_spatial): Conv3d(4, 64, kernel_size=(1, 11, 11), stride=(1, 1, 1), padding=(0, 5, 5))
        (conv_temporal): Conv3d(64, 64, kernel_size=(11, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(64, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer1): Sequential(
        (conv_spatial_1): Conv3d(64, 64, kernel_size=(1, 7, 7), stride=(1, 1, 1), padding=(0, 3, 3))
        (conv_temporal_1): Conv3d(64, 64, kernel_size=(7, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(64, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer2): Sequential(
        (conv_spatial_2): Conv3d(64