In [1]:
%load_ext autoreload
%autoreload 2

import os
import statistics

import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader

import generate_stim_train_data
import lesion
import michaels_load
import mRNN
import observer
import stim
import stim_model
import utils

# --- Experiment configuration ---------------------------------------------
# String switches; later cells branch on these and raise ValueError for any
# value they do not recognize.
OBSERVER_TYPE = "passthrough"
LESION_MODULE = "F5"
LESION_TYPE = "outputs"
LESION_PCT = 1.0
STIMULATION_TYPE = "1to1"
ACTIVATION_TYPE = "ReLU"

# Sizes shared by the observer / lesion / stimulus constructors and by the
# stimulation-model training call.
NUM_NEURONS_PER_MODULE = 100
BATCH_SIZE = 64

# Recovery flag, plus the tag it contributes to the mRNN checkpoint filename.
RECOVERY = True
if RECOVERY:
    RECOVERY_STR = "recovered"
else:
    RECOVERY_STR = "notrecovered"

# Underscore-joined identifier of this run; embedded in the data/model names.
RUN_TYPE = "_".join(map(str, (OBSERVER_TYPE, LESION_MODULE, LESION_TYPE,
                              LESION_PCT, STIMULATION_TYPE)))

# --- File locations -------------------------------------------------------
DATA_DIR = "ben"
MRNN_DIR = "mrnn"

_data_name = f"stim_data_brain_{RUN_TYPE}.hdf5"
_model_name = f"ben_model_{ACTIVATION_TYPE}_{RUN_TYPE}.pth"
_mrnn_name = f"mrnn_{LESION_MODULE}_{LESION_TYPE}_" f"{LESION_PCT}_{RECOVERY_STR}.pth"

DATA_PATH = os.path.join(DATA_DIR, _data_name)
MODEL_PATH = os.path.join(DATA_DIR, _model_name)
MRNN_MODEL_PATH = os.path.join(MRNN_DIR, _mrnn_name)

In [2]:
# Build the observer named by OBSERVER_TYPE.  "passthrough" is the only
# variant handled here; anything else is rejected before construction.
if OBSERVER_TYPE != "passthrough":
    raise ValueError(f"Unrecognized observer type: {OBSERVER_TYPE}")
observer_instance = observer.ObserverPassthrough(NUM_NEURONS_PER_MODULE)

In [3]:
# Build the lesion named by LESION_TYPE.  "none" yields no lesion object
# (None); unknown names are rejected before anything is constructed.
if LESION_TYPE not in ("outputs", "none"):
    raise ValueError(f"Unrecognized lesion type: {LESION_TYPE}")
lesion_instance = (
    lesion.LesionOutputs(NUM_NEURONS_PER_MODULE, LESION_MODULE, LESION_PCT)
    if LESION_TYPE == "outputs"
    else None
)

In [4]:
# Build the stimulation model named by STIMULATION_TYPE; unknown names are
# rejected before anything is constructed.
if STIMULATION_TYPE not in ("1to1", "gaussian"):
    raise ValueError(f"Unrecognized stimulation type: {STIMULATION_TYPE}")

if STIMULATION_TYPE == "gaussian":
    # NOTE: the num_stim_channels and sigma args can be added to the
    #       configuration above
    stimulus = stim.StimulusGaussian(35, NUM_NEURONS_PER_MODULE)
else:
    stimulus = stim.Stimulus1to1(NUM_NEURONS_PER_MODULE, NUM_NEURONS_PER_MODULE)

In [5]:
# Pick the activation named by ACTIVATION_TYPE.  The class itself (not an
# instance) is stored; it is later handed to StimModel as `activation_func`.
if ACTIVATION_TYPE not in ("ReLU", "ReTanh"):
    raise ValueError(f"Unrecognized activation type: {ACTIVATION_TYPE}")
activation = torch.nn.ReLU if ACTIVATION_TYPE == "ReLU" else utils.ReTanh

In [196]:
# For reference: typical losses for a healthy network
# NOTE: if preds are all 0s, a loss of 0.02 is typical
# NOTE: if RNN activations are all 0s, 0.22 is typical
#
# This cell only evaluates: it loads the default Michaels RNN and dataset,
# steps the RNN through every example one timestep at a time, and reports
# the per-batch MSE plus the mean over batches.  It deliberately uses no
# optimizer and reads the output width from `rnn` itself, so it runs on a
# fresh kernel without depending on `optimizer` / `mrnn` from later cells.
path = michaels_load.get_default_path()
rnn = mRNN.MichaelsRNN(init_data_path=path)
dataset = mRNN.MichaelsDataset(path)

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
loss_fn = torch.nn.MSELoss()

losses = []
# Nothing is trained here, so skip autograd bookkeeping entirely.
with torch.no_grad():
    for din, dout in loader:
        # Start every batch from a clean hidden state.
        rnn.reset_hidden()

        batch_size = din.shape[0]
        example_len = din.shape[1]

        # Collect the per-timestep outputs as (batch, time, output_dim).
        preds = torch.empty((batch_size, example_len, rnn.output_dim))
        for tidx in range(example_len):
            cur_in = din[:, tidx, :]
            # The RNN is fed the transposed slice for this timestep.
            pred = rnn(cur_in.T)
            preds[:, tidx, :] = pred[:, :]

        losses.append(loss_fn(preds, dout).item())

print(losses)
print(statistics.mean(losses))

[0.003970268648117781, 0.0032652635127305984, 0.003434328595176339, 0.004032896365970373, 0.0034519059117883444, 0.003314115572720766, 0.0033993504475802183, 0.004226842429488897]
0.0036368714354466647


In [200]:
# NOTE: even with the entire network killed except for 12 neurons in M1,
#       the loss can still get as good as 0.005 -- only ~50% worse than a
#       healthy network!  So: lesions are not intensely powerful after
#       recovery.

# Produce the mRNN for the configured lesion / stimulus / recovery setting,
# keyed to the MRNN_MODEL_PATH checkpoint.
mrnn = mRNN.generate(
    MRNN_MODEL_PATH,
    lesion=lesion_instance,
    stimulus=stimulus,
    recover_after_lesion=RECOVERY,
)
# TODO: double check that on reload the masks are set right
# TODO: double check that on reload the masks are set right

Mean loss: 0.04177322122268379
Mean loss: 0.012175323441624641
Mean loss: 0.006701853359118104
Mean loss: 0.005282561818603426
Mean loss: 0.004903594555798918
Mean loss: 0.004813128965906799
Mean loss: 0.004788718360941857
Mean loss: 0.0047788957599550486
Mean loss: 0.004740425036288798
Mean loss: 0.004757479066029191
Mean loss: 0.004753833403810859
Mean loss: 0.004779451293870807


In [6]:
# Load the mRNN back from MRNN_MODEL_PATH so the notebook can be resumed
# from this point without regenerating it.
mrnn = mRNN.load_from_file(
    MRNN_MODEL_PATH,
    stimulus=stimulus,
    lesion=lesion_instance,
)

In [7]:
# Produce the training data for the stimulation model (the "brain emulator
# network", or "BEN") from the mRNN as seen through the observer; the
# DATA_PATH argument names the HDF5 file for that data.
x = generate_stim_train_data.generate(
    DATA_PATH,
    mrnn,
    stim_max_power=4.5,
    observer_instance=observer_instance,
)

Generating example 0
Generating example 1
Generating example 2
Generating example 3
Generating example 4
Generating example 5
Generating example 6
Generating example 7
Generating example 8
Generating example 9
Generating example 10
Generating example 11
Generating example 12
Generating example 13
Generating example 14
Generating example 15
Generating example 16
Generating example 17
Generating example 18
Generating example 19
Generating example 20
Generating example 21
Generating example 22
Generating example 23
Generating example 24
Generating example 25
Generating example 26
Generating example 27
Generating example 28
Generating example 29
Generating example 30
Generating example 31
Generating example 32
Generating example 33
Generating example 34
Generating example 35
Generating example 36
Generating example 37
Generating example 38
Generating example 39
Generating example 40
Generating example 41
Generating example 42
Generating example 43
Generating example 44
Generating example 4

Generating example 361
Generating example 362
Generating example 363
Generating example 364
Generating example 365
Generating example 366
Generating example 367
Generating example 368
Generating example 369
Generating example 370
Generating example 371
Generating example 372
Generating example 373
Generating example 374
Generating example 375
Generating example 376
Generating example 377
Generating example 378
Generating example 379
Generating example 380
Generating example 381
Generating example 382
Generating example 383
Generating example 384
Generating example 385
Generating example 386
Generating example 387
Generating example 388
Generating example 389
Generating example 390
Generating example 391
Generating example 392
Generating example 393
Generating example 394
Generating example 395
Generating example 396
Generating example 397
Generating example 398
Generating example 399
Generating example 400
Generating example 401
Generating example 402
Generating example 403
Generating 

Generating example 718
Generating example 719
Generating example 720
Generating example 721
Generating example 722
Generating example 723
Generating example 724
Generating example 725
Generating example 726
Generating example 727
Generating example 728
Generating example 729
Generating example 730
Generating example 731
Generating example 732
Generating example 733
Generating example 734
Generating example 735
Generating example 736
Generating example 737
Generating example 738
Generating example 739
Generating example 740
Generating example 741
Generating example 742
Generating example 743
Generating example 744
Generating example 745
Generating example 746
Generating example 747
Generating example 748
Generating example 749
Generating example 750
Generating example 751
Generating example 752
Generating example 753
Generating example 754
Generating example 755
Generating example 756
Generating example 757
Generating example 758
Generating example 759
Generating example 760
Generating 

Generating example 1072
Generating example 1073
Generating example 1074
Generating example 1075
Generating example 1076
Generating example 1077
Generating example 1078
Generating example 1079
Generating example 1080
Generating example 1081
Generating example 1082
Generating example 1083
Generating example 1084
Generating example 1085
Generating example 1086
Generating example 1087
Generating example 1088
Generating example 1089
Generating example 1090
Generating example 1091
Generating example 1092
Generating example 1093
Generating example 1094
Generating example 1095
Generating example 1096
Generating example 1097
Generating example 1098
Generating example 1099
Generating example 1100
Generating example 1101
Generating example 1102
Generating example 1103
Generating example 1104
Generating example 1105
Generating example 1106
Generating example 1107
Generating example 1108
Generating example 1109
Generating example 1110
Generating example 1111
Generating example 1112
Generating examp

Generating example 1414
Generating example 1415
Generating example 1416
Generating example 1417
Generating example 1418
Generating example 1419
Generating example 1420
Generating example 1421
Generating example 1422
Generating example 1423
Generating example 1424
Generating example 1425
Generating example 1426
Generating example 1427
Generating example 1428
Generating example 1429
Generating example 1430
Generating example 1431
Generating example 1432
Generating example 1433
Generating example 1434
Generating example 1435
Generating example 1436
Generating example 1437
Generating example 1438
Generating example 1439
Generating example 1440
Generating example 1441
Generating example 1442
Generating example 1443
Generating example 1444
Generating example 1445
Generating example 1446
Generating example 1447
Generating example 1448
Generating example 1449
Generating example 1450
Generating example 1451
Generating example 1452
Generating example 1453
Generating example 1454
Generating examp

KeyboardInterrupt: 

In [6]:
# Open the generated stimulation dataset and read the problem dimensions off
# its first (input, output) pair: the input's first two axes give the
# example length and input width, the output's second axis the output width.
dataset = stim_model.StimDataset(DATA_PATH)
example_len, in_dim = dataset[0][0].shape[:2]
out_dim = dataset[0][1].shape[1]

In [7]:
# Stimulation model (BEN) sized to the dataset, using the configured
# activation class.
model = stim_model.StimModel(
    in_dim,
    out_dim,
    activation_func=activation,
)

# AdamW is the optimizer in use.  An earlier alternative, kept for reference:
#   Adam(model.parameters(), lr=0.006, weight_decay=4e-2)
optimizer = AdamW(params=model.parameters(), lr=0.005)

In [None]:
# Train the stimulation model on the generated dataset.  The positional
# arguments are the data, model, optimizer and the three dimensions probed
# from the dataset; the keywords set the batch size, the stop threshold and
# the path handed over as `model_save_path`.
stim_model.train_model(
    dataset, model, optimizer,
    example_len, in_dim, out_dim,
    batch_size=BATCH_SIZE,
    train_stop_thresh=0.0001,
    model_save_path=MODEL_PATH,
)

Epoch: 0
Min loss: 0.1420021802186966
Min loss: 0.10729675740003586
Min loss: 0.09725897014141083
Min loss: 0.09184914082288742
Min loss: 0.08833830803632736
Min loss: 0.08311958611011505
Min loss: 0.07805612683296204
Min loss: 0.07453465461730957
Min loss: 0.0712384283542633
Min loss: 0.06910833716392517
Min loss: 0.06677868217229843
Min loss: 0.06580938398838043
Min loss: 0.06550144404172897
Min loss: 0.06541876494884491
Min loss: 0.06480361521244049
Min loss: 0.0646900162100792
Min loss: 0.06448712944984436
Min loss: 0.06402163952589035
Min loss: 0.06395656615495682
Min loss: 0.06326495856046677
Min loss: 0.06230289116501808
Min loss: 0.06199323758482933
Epoch: 1
Min loss: 0.06138768792152405
Min loss: 0.061078138649463654
Min loss: 0.060713693499565125
Min loss: 0.060648564249277115
Min loss: 0.06024439260363579
Min loss: 0.05980215594172478
Min loss: 0.059331055730581284
Min loss: 0.05923217162489891
Min loss: 0.05914818495512009
Min loss: 0.05836954712867737
Min loss: 0.058216493

In [None]:
# Sanity check: run the stimulation model over the first dataset example,
# one timestep at a time, and overlay its prediction on the target for a
# single output channel.
model.reset()

example_in, example_out = dataset[0]
example_len = example_in.shape[0]
out_dim = example_out.shape[1]
preds = torch.empty((example_len, out_dim))

# Inference only: no_grad avoids building an autograd graph per timestep.
with torch.no_grad():
    for tidx in range(example_len):
        cur_in = example_in[tidx, :]
        # Prepend a batch axis of size 1 before calling the model.
        cur_in = cur_in.reshape((1,) + cur_in.shape)
        pred = model(cur_in)
        preds[tidx, :] = pred[:]

# Output channel to visualize.
plot_channel = 8

fig, ax = plt.subplots()
ax.plot(preds[:, plot_channel].detach().numpy(), label="prediction")
ax.plot(example_out[:, plot_channel], label="target")
ax.set_title(f"Example 0, output channel {plot_channel}")
ax.set_xlabel("timestep")
ax.set_ylabel("output value")
ax.legend()
print(preds.shape)
plt.show()