In [1]:
import numpy as np
import argparse
import time
import os
from itertools import chain
import matplotlib.pyplot as plt
import multiprocessing as mp

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from datasets import inputs
from sr_model.models.models import AnalyticSR, STDP_SR

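## Collect the parameter range of interest
Each kernel is rescaled so that its area stays fixed while its time constant $\tau$ is shifted:
$$\text{area} = A\,\tau$$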

In [2]:
def get_kernel_params(tau_0, A_0, tau_offset):
    """Shift a kernel's time constant while keeping its area fixed.

    Using the convention area = A * tau, this returns the amplitude and
    time constant whose time constant is `tau_0 + tau_offset` and whose
    area equals `A_0 * tau_0`.

    Returns:
        (new_A, new_tau)
    """
    shifted_tau = tau_0 + tau_offset
    return (A_0*tau_0)/shifted_tau, shifted_tau
    
def get_symm_kernel_params(tau_0, A_0, tau_offset):
    """Shift the time constant while assigning half the area to one side.

    The total area `A_0 * tau_0` is split in half, e.g. so that the
    positive- and negative-lag sides of a symmetric kernel each get the
    returned pair. The time constant becomes `tau_0 + tau_offset`.

    Returns:
        (new_A, new_tau) for one side of the kernel.
    """
    shifted_tau = tau_0 + tau_offset
    half_area = (A_0*tau_0)/2
    return half_area/shifted_tau, shifted_tau

def get_stdp_kernel(
    A_pos, tau_pos, A_neg, tau_neg, kernel_len
    ):
    """ Returns plasticity kernel for plotting or debugging.

    The first `kernel_len // 2` entries hold the negative-lag side,
    A_neg * exp(t / tau_neg) for t = -half_len, ..., -1. The remaining
    entries hold the positive-lag side, A_pos * exp(-t / tau_pos) for
    t = 0, 1, ...

    For odd `kernel_len` the sides have `half_len` and `half_len + 1`
    entries, so t = 0 sits at the centre index. For even `kernel_len` both
    sides have `half_len` entries. `kernel_len == 0` yields an empty array.

    Args:
        A_pos, tau_pos: amplitude and time constant of the positive-lag side.
        A_neg, tau_neg: amplitude and time constant of the negative-lag side.
        kernel_len: total number of samples in the returned kernel.

    Returns:
        np.ndarray of shape (kernel_len,).
    """
    half_len = kernel_len//2
    kernel = np.zeros(kernel_len)
    # Negative lags: t = -half_len .. -1
    kernel[:half_len] = A_neg * np.exp(
        np.arange(-half_len, 0)/tau_neg
        )
    # Positive lags start at index half_len. Sizing this slice from
    # kernel_len (instead of half_len + 1) keeps it from overwriting the
    # last negative-lag sample when kernel_len is even; for odd lengths it
    # is the same slice as before.
    num_pos = kernel_len - half_len
    kernel[half_len:] = A_pos * np.exp(
        -1*np.arange(num_pos)/tau_pos
        )
    return kernel

In [4]:
# Kernel parameter sets to sweep. Order: [A_pos, tau_pos, A_neg, tau_neg]
A_0 = 0.5
tau_0 = 1.15

# Positive-lag-only kernels (A_neg = 0): shift tau by each offset while
# keeping the area A*tau fixed.
set0, set1, set2, set3 = (
    get_kernel_params(tau_0, A_0, offset) for offset in (-0.5, -0.25, 0.5, 0.75)
)
# Symmetric kernel: half of the area on each side, tau unchanged.
set4 = get_symm_kernel_params(tau_0, A_0, -0)

params = [[A_0, tau_0, 0, 1]]
params += [[A, tau, 0, 1] for A, tau in (set0, set1, set2, set3)]
params.append([set4[0], set4[1], set4[0], set4[1]])

errs = []

In [5]:
# Show the kernel parameter sets; each row is [A_pos, tau_pos, A_neg, tau_neg].
params

[[0.5, 1.15, 0, 1],
 [0.8846153846153847, 0.6499999999999999, 0, 1],
 [0.638888888888889, 0.8999999999999999, 0, 1],
 [0.34848484848484845, 1.65, 0, 1],
 [0.3026315789473684, 1.9, 0, 1],
 [0.25, 1.15, 0.25, 1.15]]

## Initialize the dataset

In [10]:
device = 'cpu'

# Dataset config: num_spatial_states = 25*25 (presumably a 25x25 grid).
# NOTE(review): eval_model builds STDP_SR with num_states=692, which differs
# from 25*25; presumably the walk adds non-spatial states — confirm.
dset = inputs.RBYCacheWalk(num_spatial_states=25*25, downsample_factor=None, skip_frame=0.7)

## Evaluate the model for each parameter set

In [20]:
def eval_model(args, num_states=692, gamma=0.4, alt_gamma=0.6):
    """Run STDP_SR over the whole dataset with one fixed plasticity kernel.

    Args:
        args: kernel parameters, indexed as [A_pos, tau_pos, A_neg, tau_neg].
            Passed as a single argument so this works with `Pool.map`.
        num_states: number of states for `STDP_SR` (692, as used before).
            NOTE(review): differs from the dataset's 25*25 spatial states;
            presumably includes extra non-spatial states — confirm.
        gamma: discount factor the network is constructed with.
        alt_gamma: gamma passed explicitly to the second `get_M_hat` call.

    Returns:
        [args, M0, M1, err], where M0 is `get_M_hat()` called without a
        gamma argument, M1 is `get_M_hat(gamma=alt_gamma)`, and err is the
        mean absolute difference between `get_T()` and `get_real_T()`.

    Reads the module-level `dset` and `device`.
    """
    param = args

    net = STDP_SR(num_states=num_states, gamma=gamma)
    net.ca3.reset_trainable_ideal()
    net.ca3.set_differentiability(False)

    # Overwrite the kernel parameters with this run's values.
    nn.init.constant_(net.ca3.A_pos, param[0])
    nn.init.constant_(net.ca3.tau_pos, param[1])
    nn.init.constant_(net.ca3.A_neg, param[2])
    nn.init.constant_(net.ca3.tau_neg, param[3])

    # Keep the network on the same device as its inputs.
    net.to(device)

    with torch.no_grad():
        # Transpose, then insert a singleton dim at axis 1 (presumably a
        # batch dim of size 1 — confirm against STDP_SR.forward).
        dg_inputs = torch.from_numpy(dset.dg_inputs.T).float().to(device).unsqueeze(1)
        dg_modes = torch.from_numpy(dset.dg_modes.T).float().to(device).unsqueeze(1)
        net(dg_inputs, dg_modes, reset=True)
        # .cpu() before .numpy() so this also works when `device` is a GPU.
        est_T = net.ca3.get_T().detach().cpu().numpy()
        real_T = net.ca3.get_real_T()
        err = np.mean(np.abs(est_T - real_T))
        M0 = net.ca3.get_M_hat()
        M1 = net.ca3.get_M_hat(gamma=alt_gamma)

    return [param, M0, M1, err]

In [21]:
# Each element is one parameter set, passed to `eval_model` by `Pool.map` in `main`.
args_list = params

In [22]:
def main(processes=1):
    """Evaluate every parameter set in `args_list` with `eval_model`.

    Args:
        processes: number of worker processes (default 1, as before).

    Returns:
        List of `eval_model` results, in the same order as `args_list`.
    """
    # The context manager terminates the pool on exit, including when the
    # map is interrupted, so worker processes are not leaked.
    with mp.Pool(processes) as pool:
        return pool.map(eval_model, args_list)

In [23]:
# Run the full sweep; yields one [param, M0, M1, err] entry per parameter set.
# NOTE(review): the saved output below is from an interrupted run
# (KeyboardInterrupt), so `result` may be undefined for the plotting cells.
result = main()

Traceback (most recent call last):
Process ForkPoolWorker-3:
  File "/home/chingf/anaconda3/envs/aronov/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()


KeyboardInterrupt: 

  File "/home/chingf/anaconda3/envs/aronov/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/chingf/anaconda3/envs/aronov/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
  File "/home/chingf/anaconda3/envs/aronov/lib/python3.7/multiprocessing/pool.py", line 44, in mapstar
    return list(map(*args))
  File "<ipython-input-20-1108e07693b5>", line 16, in eval_model
    _, outputs = net(dg_inputs, dg_modes, reset=True)
  File "/home/chingf/anaconda3/envs/aronov/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chingf/Code/sr-project/src/sr_model/models/models.py", line 81, in forward
    self.update()
  File "/home/chingf/Code/sr-project/src/sr_model/models/models.py", line 99, in update
    self.ca3.update()
  File "/home/chingf/Code/sr-project/src/sr_model/models/ca3.py", line 12

## Plot Place Cells

In [7]:
from sr_model.plotting import SpatialPlot

In [None]:
# Use the dataset's states. The bare name `input` is the Python builtin
# function, which has no `sorted_states` attribute.
# NOTE(review): assumes RBYCacheWalk exposes `sorted_states` — confirm.
plotter = SpatialPlot(
    None, 16, None, dset.sorted_states[:16], None
    )

In [None]:
# One heatmap of M0 per parameter set, titled with its parameters and T error.
for param, M0, M1, err in result:
    # NOTE(review): `_format_into_heamap` is spelled as in the original call;
    # confirm it matches the SpatialPlot method name.
    heatmap = SpatialPlot._format_into_heamap(M0)
    print(param)
    fig, ax = plt.subplots()
    ax.imshow(heatmap)
    ax.set_title('params={}, err={:.4g}'.format(param, float(err)))
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)