In [1]:
%load_ext autoreload
%autoreload 2

import os
import statistics

import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader

import config
import cpn_model
import generate_cpn_train_data
import lesion
import michaels_load
import mRNN
import observer
import stim
import utils

OBSERVER_TYPE = "gaussian"
LESION_PCT = 1.0
STIMULATION_TYPE = "gaussian"
ACTIVATION_TYPE = "ReTanh"
BEN_ACTIVATION_TYPE = "ReLU"
BATCH_SIZE = 64

observer_instance, stimulus, lesion_instance, activation, recovery_mode, recovery_str, \
        run_type_str, _ = config.get(observer_type=OBSERVER_TYPE,
                                              stimulation_type=STIMULATION_TYPE,
                                              lesion_pct=LESION_PCT,
                                              activation_type=ACTIVATION_TYPE,
                                              batch_size=BATCH_SIZE)

CPN_DATA_DIR = "cpn"
BEN_DATA_DIR = "ben"
MRNN_DIR = "mrnn"
CPN_DATA_PATH = os.path.join(CPN_DATA_DIR, f"cpn_train_data_{run_type_str}.hdf5")
BEN_MODEL_PATH = os.path.join(BEN_DATA_DIR, f"ben_model_{BEN_ACTIVATION_TYPE}_{run_type_str}.pth")
MRNN_MODEL_PATH = os.path.join(MRNN_DIR, f"mrnn_{str(lesion_instance)}_{recovery_str}.pth")
MRNN_HEALTHY_MODEL_PATH = os.path.join(MRNN_DIR, f"mrnn_{str(lesion_instance)}_{recovery_str}.pth_pre")
HEALTHY_DATA_PATH = os.path.join(CPN_DATA_DIR, f"cpn_healthy_train_data_{run_type_str}.hdf5")


In [2]:
_ = generate_cpn_train_data.generate(HEALTHY_DATA_PATH, MRNN_HEALTHY_MODEL_PATH, observer_instance)

1622412493.5638661 Generating example 0
1622412504.2174082 Generating example 100
1622412514.8929935 Generating example 200
1622412525.5372133 Generating example 300
1622412536.2411902 Generating example 400
1622412546.916648 Generating example 500
1622412557.544474 Generating example 600
1622412568.218761 Generating example 700
1622412578.8795252 Generating example 800
1622412589.5417264 Generating example 900
1622412600.2217777 Generating example 1000
1622412610.9145374 Generating example 1100
1622412621.6252654 Generating example 1200
1622412632.2902086 Generating example 1300
1622412642.9382613 Generating example 1400
1622412653.5609176 Generating example 1500
1622412664.1990726 Generating example 1600
1622412674.812902 Generating example 1700
1622412685.4653907 Generating example 1800
1622412696.1164498 Generating example 1900


(array([[-0.02150223, -0.01471637, -0.15769812, ...,  0.02501846,
          0.00500783,  1.        ],
        [-0.02150223, -0.01471637, -0.15769812, ...,  0.02501846,
          0.00500783,  1.        ],
        [-0.02150223, -0.01471637, -0.15769812, ...,  0.02501846,
          0.00500783,  1.        ],
        ...,
        [-0.02150223, -0.01471637, -0.15769812, ...,  0.02501846,
          0.00500783,  0.        ],
        [-0.02150223, -0.01471637, -0.15769812, ...,  0.02501846,
          0.00500783,  0.        ],
        [-0.02150223, -0.01471637, -0.15769812, ...,  0.02501846,
          0.00500783,  0.        ]]),
 array([[0.00252967, 0.01227746, 0.08356892, ..., 0.0050797 , 0.08490833,
         0.00923976],
        [0.01334654, 0.00640347, 0.08222893, ..., 0.00523181, 0.09038559,
         0.0135254 ],
        [0.02351345, 0.00341785, 0.0815514 , ..., 0.00550089, 0.09614731,
         0.01774631],
        ...,
        [0.23316082, 0.07880791, 0.0666663 , ..., 0.12709798, 0.21532099

In [70]:
train_args = cpn_model.prep_new(HEALTHY_DATA_PATH, MRNN_MODEL_PATH, BEN_MODEL_PATH, observer_instance,
                                lesion_instance, stimulus, activation, batch_size=BATCH_SIZE)

Loading dataset; this may take awhile...


In [76]:
cpn_model.train_model(*train_args)

Epoch: 0
tensor(0.0904, grad_fn=<MseLossBackward>)
> [0;32m/home/mbryan/Projects/coproc/coproc-poc/cpn_model.py[0m(219)[0;36mtrain_model[0;34m()[0m
[0;32m    217 [0;31m[0;34m[0m[0m
[0m[0;32m    218 [0;31m            [0mloss[0m[0;34m.[0m[0mbackward[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 219 [0;31m            [0moptimizer[0m[0;34m.[0m[0mstep[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    220 [0;31m[0;34m[0m[0m
[0m[0;32m    221 [0;31m        [0;32mif[0m [0mtrain_stop_thresh[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m [0;32mand[0m [0mtrain_stop_thresh[0m [0;34m>=[0m [0mmin_loss[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> print(ben.x)
tensor([[-1.2966, -1.0709, -1.5994,  ..., -1.3686, -1.9508, -1.6224],
        [-1.2852, -1.0615, -1.5826,  ..., -1.3638, -1.9483, -1.6126],
        [-1.2932, -1.0685, -1.5944,  ..., -1.3677, -1.9503, -1.6194],
        ...,
        [-1.2861, -1.0628, -1.58



None
ipdb> f['ben_in_grad'] = ben_in.grad.detach().numpy()
ipdb> f['model_fc_W_grad'] = model.fc.weight.grad.detach().numpy()
ipdb> f.close()
ipdb> quit


BdbQuit: 

In [44]:
op = AdamW(train_args[1].parameters(), lr=0.01)
train_args = list(train_args)
train_args[5] = op
print(train_args[3:6])

[StimModel(
  (activation_func): ReTanh()
  (fc): Linear(in_features=190, out_features=105, bias=True)
), <observer.ObserverGaussian1d object at 0x7fe1a22c6fa0>, AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.01
    weight_decay: 0.01
)]


In [35]:
import stim_model
ovd = train_args[4].out_dim * 3
sd = stimulus.num_stim_channels
ben = stim_model.load_from_file(BEN_MODEL_PATH, ovd + sd, ovd)

In [37]:
train_args[3] = ben

In [72]:
print(cpn_model.bi.shape)
print(cpn_model.bi.grad)

torch.Size([64, 140])
tensor([[ 9.4913e-07,  3.4652e-07,  2.3553e-07,  ..., -1.8996e-11,
          8.4051e-10,  3.4139e-10],
        [ 9.4260e-07,  2.9936e-07,  1.8435e-07,  ..., -2.1296e-10,
          8.4464e-10,  1.6565e-10],
        [ 1.0112e-06,  2.6495e-07,  2.3141e-07,  ..., -4.0874e-10,
          8.7424e-10,  1.5835e-11],
        ...,
        [ 9.3517e-07,  2.9555e-07,  1.6778e-07,  ..., -1.6756e-10,
          8.3928e-10,  2.1279e-10],
        [ 9.4260e-07,  2.9936e-07,  1.8435e-07,  ..., -2.1296e-10,
          8.4464e-10,  1.6565e-10],
        [ 9.6922e-07,  3.2272e-07,  2.4639e-07,  ..., -2.7137e-10,
          9.0416e-10,  1.3642e-10]])


In [73]:
print(cpn_model.bi[0,-10:])
print(cpn_model.bi.grad[0,-10:])

tensor([ 3.4879,  2.7927,  3.0463,  2.5524, -2.0597, -2.4432, -1.1065, -0.8758,
        -2.5194, -1.8867], grad_fn=<SliceBackward>)
tensor([-1.5640e-10,  1.0171e-10, -4.2748e-10, -9.8202e-11, -4.9592e-10,
         4.2175e-10, -7.5766e-11, -1.8996e-11,  8.4051e-10,  3.4139e-10])


In [74]:
print(train_args[1].fc.bias.grad)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
