In [2]:
import numpy as np
from scipy.io import loadmat
from rnn_utils import RateUnitNetwork
import torch
from torch import nn, optim
import copy

%matplotlib widget
import matplotlib.pyplot as plt
from IPython.display import HTML

from vis_utils import animate_trajectory, animate_activity


In [3]:
# MATLAB file with the target trajectories; later cells index the loaded
# dict by pattern name (e.g. 'neuron', 'chaos').
handwriting_data_mat = './nn.3405-S2/DAC_handwriting_training/DAC_handwriting_output_targets.mat'
handwriting_data = loadmat(file_name=handwriting_data_mat)

In [4]:
device = 'cpu'
input_dim = 2        # stimulus channels; each pattern is cued on its own channel
hidden_dim = 800     # size of the recurrent layer
output_dim = 2       # network outputs, plotted as (x, y) points
noise_std = 0.0      # NOTE(review): not referenced by the cells in this notebook
stimulus_amp = 2.0   # height of the cue pulse
t_init = 200         # time steps before the cue pulse
t_stim = 50          # duration of the cue pulse, in time steps

# The initial hidden state and the perturbation pulse are drawn at random;
# seed them so that Restart & Run All reproduces the same figures.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [5]:
imdl = 20000  # checkpoint index encoded in the file name (presumably the training step)
ckpt_path = './handwriting_models/handwriting_mdl_new2_{}.pt'.format(imdl)
# map_location places the stored tensors directly on `device`.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
ckp = torch.load(ckpt_path, map_location=device)
model = RateUnitNetwork(input_dim, hidden_dim, output_dim)
# Keep the parameters on the same device as the inputs/hidden states that
# run_network_for_input creates.
model.to(device)
model.load_state_dict(ckp)

In [None]:
# NOTE(review): disabled alternative loader, kept for reference only. It
# rebuilds the model with include_bias=True / noise_amp=0.001 and copies each
# i2h/h2h/h2o weight and bias tensor by hand from a backup checkpoint.
# The absolute path below is machine-specific; replace it with a relative
# path before re-enabling this cell.
# model = RateUnitNetwork(input_dim, hidden_dim, output_dim)
# mdl_path = '/media/druckmannuser/4C82D91382D901FE/handwriting_mdl_2_backup.pt'
# ckp = torch.load(mdl_path)
# #ckp
# #model.load_state_dict(ckp.drop())
# model = RateUnitNetwork(input_dim, 
#                         hidden_dim, 
#                         output_dim, 
#                         include_bias=True, 
#                         noise_amp = 0.001)

# model.i2h.weight.data = ckp['i2h.weight']
# model.i2h.bias.data = ckp['i2h.bias']
    
# model.h2h.weight.data = ckp['h2h.weight']
# model.h2h.bias.data = ckp['h2h.bias']

# model.h2o.weight.data = ckp['h2o.weight']
# model.h2o.bias.data = ckp['h2o.bias']

In [6]:
patterns = ['neuron', 'chaos']
targets = {}  # pattern name -> target tensor of shape (1, T, coords)
inputs = {}   # pattern name -> stimulus tensor of shape (1, t_init + t_stim + T, input_dim)

# Pattern j is cued by a pulse on input channel j, so every pattern needs
# its own channel.
assert len(patterns) <= input_dim, 'need one input channel per pattern'

for j, pat in enumerate(patterns):
    # Stored as (coords, T): axis 1 is time (its length extends t_total below).
    target = handwriting_data[pat].astype(np.float32)
    targets[pat] = torch.from_numpy(target.T).to(device=device).unsqueeze(0)
    # Silent lead-in, then the cue pulse, then the period the network should
    # spend writing the target.
    t_total = t_init + t_stim + target.shape[1]
    inputs[pat] = torch.zeros((1, t_total, input_dim), device=device)
    inputs[pat][0, t_init:t_init+t_stim, j] = stimulus_amp


In [7]:
def run_network_for_input(inp):
    """Run the global ``model`` once from a random initial hidden state.

    Parameters
    ----------
    inp : str or torch.Tensor
        Either a pattern name, looked up in the global ``inputs`` dict (e.g.
        ``'neuron'`` or ``'chaos'``), or an input tensor that is passed to the
        model directly. Presumably shaped (1, time, input_dim) like the
        tensors in ``inputs`` -- TODO confirm against RateUnitNetwork.

    Returns
    -------
    outputs, hiddens : np.ndarray
        The model's two return values, detached, copied to the CPU, with the
        leading size-1 (batch) axis removed.
    """
    # Initial hidden state drawn uniformly from [-1, 1).
    hidden = 2 * torch.rand((1, hidden_dim), device=device) - 1
    if isinstance(inp, str):
        # Look the stimulus up by the argument itself, not by whatever the
        # global `pat` happens to hold.
        inp = inputs[inp]
    # Inference only. If the model's parameters require grad (the nn.Module
    # default), autograd would otherwise record the entire run, which is
    # costly for long inputs.
    with torch.no_grad():
        outputs, hiddens = model(inp.to(device=device), hidden)

    outputs = outputs.detach().cpu().numpy()[0]
    hiddens = hiddens.detach().cpu().numpy()[0]
    return outputs, hiddens


## Plot several runs of each of the two patterns (each run starts from a random initial state)

In [8]:
n_runs = 3  # independent runs per pattern, each from a fresh random hidden state
for pat in patterns:
    # Target as (coords, T), so it can be unpacked into scatter(x, y).
    target_xy = targets[pat].detach().cpu().numpy()[0].T
    # squeeze=False keeps `axes` 2-D even when n_runs == 1.
    fig, axes = plt.subplots(1, n_runs, figsize=(12, 3), squeeze=False)
    fig.suptitle('Pattern: {}'.format(pat))
    for i in range(n_runs):
        outputs, _ = run_network_for_input(pat)
        # Drop both the pre-cue lead-in and the cue pulse itself; what remains
        # lines up with the target period.
        outputs = outputs[t_init+t_stim:]
        ax = axes[0, i]
        ax.scatter(*outputs.T, c=np.arange(outputs.shape[0]))
        ax.scatter(*target_xy, marker='x', s=0.1)
        ax.set_title('run {}'.format(i + 1))
plt.show()

FigureCanvasNbAgg()

FigureCanvasNbAgg()

## OK, cool story — but I thought this was about dynamics? Let's animate the output over time

In [9]:
pat = 'chaos'
outputs, _ = run_network_for_input(pat)

# Animate the whole run, including the lead-in and the cue. keep_last=25
# presumably limits the trail to the 25 most recent points, and carray passes
# the per-point time index.
target_xy = targets[pat].detach().cpu().numpy()[0].T
time_idx = np.arange(outputs.shape[0])
ani = animate_trajectory(outputs, targ_points=target_xy, keep_last=25, carray=time_idx)
HTML(ani.to_html5_video())

FigureCanvasNbAgg()

## What happens when we perturb the network while it's writing?

In [10]:
pat = 'neuron'
pulse_time = 400   # first time step of the perturbation
pulse_width = 10   # number of perturbed time steps
pulse_std = 0.5    # std of the Gaussian noise written into the input

# Clone so the clean stimulus stored in `inputs` is left untouched.
corrupted_input = inputs[pat].clone().to(device=device)
assert pulse_time + pulse_width <= corrupted_input.shape[1], \
    'perturbation window runs past the end of the stimulus'
# The noise *replaces* (rather than adds to) the input on every channel
# inside the window.
corrupted_input[0, pulse_time:pulse_time+pulse_width, :] = \
    pulse_std * torch.randn((pulse_width, input_dim), device=device)

fig, ax = plt.subplots()
ax.plot(corrupted_input[0].detach().cpu().numpy())
ax.set_title('corrupted stimulus for pattern: {}'.format(pat))
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Input [au]')
ax.legend(['Dim {}'.format(k + 1) for k in range(input_dim)])
plt.show()

outputs, _ = run_network_for_input(corrupted_input)
outputs = outputs[t_init:]  # drop the pre-cue lead-in; keep cue + writing
ani = animate_trajectory(outputs,
                         keep_last=outputs.shape[0],
                         carray=np.arange(outputs.shape[0]),
                         targ_points=targets[pat].detach().cpu().numpy()[0].T
                         )
HTML(ani.to_html5_video())

FigureCanvasNbAgg()

FigureCanvasNbAgg()

## Now let's look at the intrinsic dynamics of this network (no stimulus at all)

In [12]:
dur = 10000  # time steps to simulate (elsewhere the axes treat one step as 1 ms)
# All-zero input: no cue is ever given, so the network evolves freely from
# its random initial hidden state.
outputs, _ = run_network_for_input(torch.zeros((1, dur, input_dim), device=device))
ani = animate_trajectory(outputs,
                         keep_last=1000,
                         carray=np.arange(outputs.shape[0]),
                         )

HTML(ani.to_html5_video())

FigureCanvasNbAgg()

In [15]:
pat = 'chaos'
outputs, hiddens = run_network_for_input(pat)

# NOTE(review): tanh is applied here, which suggests `hiddens` holds
# pre-activation states -- confirm against RateUnitNetwork.
rates = np.tanh(hiddens)
target_xy = targets[pat].detach().cpu().numpy()[0].T
time_idx = np.arange(outputs.shape[0])
ani = animate_activity(rates, outputs,
                       targ_points=target_xy,
                       keep_last=25,
                       carray=time_idx,
                       interval=10)

HTML(ani.to_html5_video())

FigureCanvasNbAgg()

In [None]:
# Optional diagnostic (disabled): eigenvalues of the recurrent weight matrix
# h2h, scattered in the complex plane (real part vs imaginary part).
# import numpy.linalg as la
# l,v = la.eig(model.h2h.weight.cpu().detach().numpy())
# _=plt.scatter(np.real(l), np.imag(l))
# _=plt.axis('square')