In [1]:
import numpy as np
from scipy.io import loadmat
from rnn_utils import RateUnitNetwork
import torch
from torch import nn, optim
import numpy.linalg as la
import copy

%matplotlib widget
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML


In [2]:
def animate_scatter(points, targ_points=None, keep_last=1, carray=None, interval=5,
                    xlim=(-0.6, 0.4), ylim=(-0.2, 0.4)):
    """Animate a 2-D trajectory as a scatter plot with a trailing window of points.

    Parameters
    ----------
    points : array of shape (T, 2)
        Trajectory to animate; row ``t`` is the (x, y) position at frame ``t``.
    targ_points : array of shape (2, N), optional
        Static target points, unpacked as ``x, y`` and drawn as faint 'x' markers.
        When given, the axes use the fixed ``xlim``/``ylim`` window; otherwise the
        limits are fitted to ``points`` with a 0.1 margin.
    keep_last : int
        Number of most recent points kept visible in each frame.
    carray : array of shape (T,), optional
        Per-point values mapped to colour; colour limits span its min/max.
    interval : int
        Delay between frames in milliseconds (passed to ``FuncAnimation``).
    xlim, ylim : pair of float
        Axis window used when ``targ_points`` is given. The defaults are the
        window that was previously hard-coded.

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    def update(i, data, scat):
        # Show the (at most keep_last) points that end just before frame i.
        window = slice(max(0, i - keep_last), i)
        scat.set_offsets(data[window])
        if carray is not None:
            scat.set_array(carray[window])
        # With blit=True FuncAnimation needs an iterable of the modified
        # artists, not a bare artist.
        return (scat,)

    fig, ax = plt.subplots(figsize=(8, 4))
    scat = ax.scatter(*points[0].T, s=25)

    if targ_points is not None:
        ax.scatter(*targ_points, marker='x', s=0.1)
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)
    else:
        mins = points.min(0)
        maxs = points.max(0)
        ax.set_xlim(mins[0] - 0.1, maxs[0] + 0.1)
        ax.set_ylim(mins[1] - 0.1, maxs[1] + 0.1)

    if carray is not None:
        scat.set_clim(vmin=carray.min(), vmax=carray.max())

    ax.grid(True)
    # Presumably closed so the notebook does not also render a static copy;
    # the animation is rendered separately (e.g. via to_html5_video).
    plt.close(fig)

    return animation.FuncAnimation(fig, update, frames=range(1, points.shape[0]),
                                   interval=interval, blit=True,
                                   fargs=(points, scat))

In [3]:
# MATLAB file with the handwriting target trajectories; loadmat returns a dict
# mapping the file's variable names to arrays.
handwriting_data_mat = ('./nn.3405-S2/DAC_handwriting_training/'
                        'DAC_handwriting_output_targets.mat')
handwriting_data = loadmat(handwriting_data_mat)

In [4]:
# Shared configuration for the network and the cue stimulus.
device = 'cpu'  # torch device for the model inputs and initial states

# Layer sizes: 2 cue channels in, 800 recurrent units, 2 output coordinates out.
input_dim, hidden_dim, output_dim = 2, 800, 2

# NOTE(review): noise_std is not referenced by any visible cell.
noise_std = 0.0

# Cue pulse: amplitude, onset (number of silent time steps first), and duration.
stimulus_amp = 2.0
t_init, t_stim = 200, 50

In [6]:
imdl = 20000  # presumably the training iteration at which the checkpoint was saved
ckp = torch.load(
    f'/media/druckmannuser/4C82D91382D901FE/handwriting_models/handwriting_mdl_new2_{imdl}.pt',
    map_location='cpu',  # load every stored tensor onto the CPU
)
model = RateUnitNetwork(input_dim, hidden_dim, output_dim)
model.load_state_dict(ckp)

In [None]:
# model = RateUnitNetwork(input_dim, hidden_dim, output_dim)
# mdl_path = '/media/druckmannuser/4C82D91382D901FE/handwriting_mdl_2_backup.pt'
# ckp = torch.load(mdl_path)
# #ckp
# #model.load_state_dict(ckp.drop())
# model = RateUnitNetwork(input_dim, 
#                         hidden_dim, 
#                         output_dim, 
#                         include_bias=True, 
#                         noise_amp = 0.001)

# model.i2h.weight.data = ckp['i2h.weight']
# model.i2h.bias.data = ckp['i2h.bias']
    
# model.h2h.weight.data = ckp['h2h.weight']
# model.h2h.bias.data = ckp['h2h.bias']

# model.h2o.weight.data = ckp['h2o.weight']
# model.h2o.bias.data = ckp['h2o.bias']

In [7]:
# For each pattern, build its target trajectory and the cue input that triggers it.
patterns = ['neuron', 'chaos']
targets, inputs = {}, {}
for j, pat in enumerate(patterns):
    target = handwriting_data[pat].astype(np.float32)
    # Transpose so the second (time) axis of the stored array comes first,
    # then add a leading batch axis of size 1.
    targets[pat] = torch.from_numpy(target.T).to(device=device).view(1, *target.T.shape)

    # Zeros for t_init steps, a cue of height stimulus_amp on channel j for
    # t_stim steps, then enough further zero steps to span the whole target.
    t_total = t_init + t_stim + target.shape[1]
    cue = torch.zeros((1, t_total, 2), device=device)
    cue[0, t_init:t_init + t_stim, j] = stimulus_amp
    inputs[pat] = cue


In [8]:
def run_network_for_input(inp):
    """Simulate the network from a random initial state and return its trajectory.

    Parameters
    ----------
    inp : str or torch.Tensor
        Either a pattern name, looked up in the module-level ``inputs`` dict, or
        an input tensor passed to ``model`` directly (presumably shaped
        (1, time, input_dim) like the tensors in ``inputs`` -- TODO confirm).

    Returns
    -------
    outputs, hiddens : numpy.ndarray
        Element 0 (the batch entry) of the model's two returned sequences.
    """
    # Random initial hidden state, uniform in [-1, 1).
    hidden = 2 * torch.rand((1, hidden_dim), device=device) - 1
    if isinstance(inp, str):
        # Look up the requested pattern itself -- not whatever global `pat`
        # happens to hold at call time.
        inp = inputs[inp]
    # Inference only: skip autograd bookkeeping over the whole sequence.
    with torch.no_grad():
        outputs, hiddens = model(inp.to(device=device), hidden)

    outputs = outputs.detach().cpu().numpy()[0]
    hiddens = hiddens.detach().cpu().numpy()[0]
    return outputs, hiddens


## Plot some different runs of the two patterns

In [9]:
# Three independent runs per pattern (each from a different random initial
# state), coloured by time step, over the faint target trace.
for pat in patterns:
    # The target is the same for every run of this pattern: convert it once.
    target_xy = targets[pat].detach().cpu().numpy()[0].T
    fig, ax = plt.subplots(1, 3, figsize=(12, 3))
    for i in range(3):
        outputs, _ = run_network_for_input(pat)
        outputs = outputs[t_init + t_stim:]  # drop the pre-cue and cue phases
        ax[i].scatter(*outputs.T, c=np.arange(outputs.shape[0]))
        ax[i].scatter(*target_xy, marker='x', s=0.1)
        ax[i].set_title('{}: run {}'.format(pat, i + 1))
plt.show()

FigureCanvasNbAgg()

FigureCanvasNbAgg()

## OK Cool story but I thought this was about dynamics?

In [10]:
# Animate one run of the 'chaos' pattern with a 25-point trail coloured by time step.
pat = 'chaos'
outputs, _ = run_network_for_input(pat)
target_xy = targets[pat].detach().cpu().numpy()[0].T
time_index = np.arange(outputs.shape[0])
ani = animate_scatter(outputs, targ_points=target_xy, keep_last=25, carray=time_index)
HTML(ani.to_html5_video())

FigureCanvasNbAgg()

## What happens when we perturb the network while it's writing?

In [11]:
# Overwrite a short window of the cue input with Gaussian noise and watch how
# the network's output responds.
pat = 'neuron'
pulse_time = 400   # first time step of the noise pulse
pulse_width = 10   # number of time steps the pulse lasts
pulse_std = 0.5    # standard deviation of the Gaussian noise

# clone() so the clean tensor stored in `inputs` is left untouched.
corrupted_input = inputs[pat].clone().to(device=device)
# Draw the noise on the same device as the tensor it is written into.
corrupted_input[0, pulse_time:pulse_time + pulse_width, :] = pulse_std * torch.randn(
    (1, pulse_width, 2), device=corrupted_input.device)

fig, ax = plt.subplots()
ax.plot(corrupted_input[0].detach().cpu().numpy())
ax.set_title('corrupted stimulus for pattern: {}'.format(pat))
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Input [au]')
ax.legend(['Dim 1', 'Dim 2'])
plt.show()

outputs, _ = run_network_for_input(corrupted_input)
outputs = outputs[t_init:]  # drop the silent pre-cue period, keep cue onward
ani = animate_scatter(outputs,
                      keep_last=outputs.shape[0],  # keep the whole trace visible
                      carray=np.arange(outputs.shape[0]),
                      targ_points=targets[pat].detach().cpu().numpy()[0].T)
HTML(ani.to_html5_video())

FigureCanvasNbAgg()

FigureCanvasNbAgg()

## Now let's look at the intrinsic dynamics of this network

In [12]:
# Drive the network with an all-zero input to expose its intrinsic dynamics.
dur = 10000  # number of time steps (annotated as ms)
# Build the input on `device`, like the tensors in `inputs`, so this also runs off-CPU.
outputs, _ = run_network_for_input(torch.zeros((1, dur, 2), device=device))
ani = animate_scatter(outputs,
                      keep_last=125,  # trail length shown in each frame
                      carray=np.arange(outputs.shape[0]))

HTML(ani.to_html5_video())

FigureCanvasNbAgg()

In [13]:
from matplotlib.gridspec import GridSpec
def animate_population(positions, activity, trajectory, keep_last=1, targ_points=None, carray=None, interval=5,
                       activity_clim=(-1, 1), xlim=(-0.6, 0.4), ylim=(-0.2, 0.4)):
    """Animate hidden-unit activity side by side with the 2-D output trajectory.

    The left panel draws one dot per hidden unit at ``positions``, coloured by
    that unit's value in the current row of ``activity``. The right panel draws
    the ``trajectory`` with a trail of the last ``keep_last`` points.

    Parameters
    ----------
    positions : array of shape (n_units, 2)
        Fixed display location of each hidden unit.
    activity : array of shape (T, n_units)
        Per-frame unit values; one animation frame per row.
    trajectory : array of shape (T, 2)
        Output positions, one row per frame.
    keep_last : int
        Number of most recent trajectory points kept visible.
    targ_points : array of shape (2, N), optional
        Static target points drawn as faint 'x' markers. When given, the right
        panel uses the fixed ``xlim``/``ylim`` window; otherwise its limits are
        fitted to ``trajectory`` with a 0.1 margin.
    carray : array of shape (T,), optional
        Values colouring the trajectory points.
    interval : int
        Delay between frames in milliseconds (passed to ``FuncAnimation``).
    activity_clim : pair of float
        Colour limits of the activity panel; the default (-1, 1) suits
        tanh-squashed activity.
    xlim, ylim : pair of float
        Right-panel window used when ``targ_points`` is given.

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    def update(i, activity, trajectory, scat0, scat1):
        scat0.set_array(activity[i])
        window = slice(max(0, i - keep_last), i)
        scat1.set_offsets(trajectory[window])
        if carray is not None:
            scat1.set_array(carray[window])
        # blit=True requires an iterable of the modified artists.
        return scat0, scat1

    fig = plt.figure(figsize=(9, 3))
    gs = GridSpec(1, 5)

    # Left 2/5 of the figure: hidden units coloured by activity.
    ax0 = fig.add_subplot(gs[:2])
    scat0 = ax0.scatter(*positions.T, s=10, c=activity[0])
    scat0.set_clim(vmin=activity_clim[0], vmax=activity_clim[1])
    ax0.axis('square')
    ax0.set_xticklabels([])
    ax0.set_yticklabels([])
    ax0.set_title('Hidden Unit Activity')

    # Right 3/5 of the figure: the output trajectory.
    ax1 = fig.add_subplot(gs[2:])
    scat1 = ax1.scatter(*trajectory[0].T, s=25)

    if targ_points is not None:
        ax1.scatter(*targ_points, marker='x', s=0.1)
        ax1.set_xlim(*xlim)
        ax1.set_ylim(*ylim)
    else:
        # Fit the window to the trajectory being animated.
        mins = trajectory.min(0)
        maxs = trajectory.max(0)
        ax1.set_xlim(mins[0] - 0.1, maxs[0] + 0.1)
        ax1.set_ylim(mins[1] - 0.1, maxs[1] + 0.1)

    if carray is not None:
        scat1.set_clim(vmin=carray.min(), vmax=carray.max())

    ax1.grid(True)
    plt.close(fig)

    return animation.FuncAnimation(fig, update, frames=range(activity.shape[0]),
                                   interval=interval, blit=True,
                                   fargs=(activity, trajectory, scat0, scat1))

In [14]:
def gen_positions(n,scale=1):
    """Sample ``n`` points uniformly over a disk of radius ``scale`` at the origin.

    The radius is the square root of a uniform variate, which makes the density
    uniform per unit area instead of clustering points near the centre.

    Returns an array of shape (n, 2) holding x, y coordinates.
    """
    radius = scale * np.sqrt(np.random.rand(n))
    angle = np.random.rand(n) * (2 * np.pi)
    return np.column_stack((radius * np.cos(angle), radius * np.sin(angle)))

In [15]:
# Place each hidden unit at a random spot in the unit disk; the layout is only for display.
positions = gen_positions(hidden_dim)

pat = 'chaos'
outputs, hiddens = run_network_for_input(pat)
ani = animate_population(
    positions,
    np.tanh(hiddens),  # squash hidden values into (-1, 1) for colouring
    outputs,
    targ_points=targets[pat].detach().cpu().numpy()[0].T,
    keep_last=25,
    carray=np.arange(outputs.shape[0]),
    interval=10,
)

HTML(ani.to_html5_video())

FigureCanvasNbAgg()

In [None]:
plt.close(fig='all')  # release every open figure before the next experiments

In [None]:
# #%matplotlib notebook
# #%matplotlib inline
# %matplotlib widget

# from matplotlib import pyplot as plt
# from matplotlib.animation import FuncAnimation
# from random import randrange
# from threading import Thread
# import time

# class LiveGraph:
#     def __init__(self):
#         self.x_data, self.y_data = [], []
#         self.figure = plt.figure()
#         self.line, = plt.plot(self.x_data, self.y_data)
#         self.animation = FuncAnimation(self.figure, self.update, interval=1000)
#         self.th = Thread(target=self.thread_f, daemon=True)
#         self.th.start()

#     def update(self, frame):
#         self.line.set_data(self.x_data, self.y_data)
#         self.figure.gca().relim()
#         self.figure.gca().autoscale_view()
#         return self.line,

#     def show(self):
#         plt.show()

#     def thread_f(self):
#         x = 0
#         while True:
#             self.x_data.append(x)
#             x += 1
#             self.y_data.append(randrange(0, 100))   
#             time.sleep(1)  

# g = LiveGraph()
# g.show()

In [None]:
import numpy.linalg as la
# Eigenvalue spectrum of the recurrent weight matrix, plotted in the complex plane.
# eigvals skips the eigenvectors that eig would also compute (they were unused).
l = la.eigvals(model.h2h.weight.detach().cpu().numpy())
fig, ax = plt.subplots()
ax.scatter(l.real, l.imag)
ax.axis('square')
ax.set_xlabel('Real part')
ax.set_ylabel('Imaginary part')
ax.set_title('Eigenvalues of h2h.weight');

In [None]:
%matplotlib widget
from ipywidgets import *
from IPython import display

# Experimental: step the network one time step at a time from a random initial
# state and redraw the single output point after each step.
# Random initial hidden state, uniform in [-1, 1).
hidden = 2*torch.rand((1, hidden_dim), device=device) - 1
# NOTE(review): `hidden` is rebound to the model's second return value. In
# run_network_for_input that value is indexed [0] to give a (time, hidden_dim)
# array, so here it is presumably (1, 1, hidden_dim) rather than the
# (1, hidden_dim) initial state -- confirm before feeding it back into model().
output, hidden = model(torch.zeros(1,1,2).to(device=device), 
                       hidden.to(device))

# output[0] is the single time step; .T presumably gives rows x and y.
output = output.detach().cpu().numpy()[0].T
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
scat = ax.scatter(output[0], output[1])
# NOTE(review): this window is far wider than the [-0.6, 0.4] x [-0.2, 0.4]
# window used for the handwriting targets elsewhere in this notebook.
ax.set_xlim([-10, 10])
ax.set_ylim([-10, 10])

def update():
    """Run one zero-input step of the model and move the scatter point to its output.

    NOTE(review): the state is read from `model.hiddens`, which no visible cell
    sets -- confirm RateUnitNetwork exposes it. The `hidden` returned here is a
    local and is discarded, so unless the model updates `model.hiddens` itself,
    every call restarts from the same state.
    """
    output, hidden = model(torch.zeros(1,1,2).to(device=device), 
                           model.hiddens)
    output = output.detach().cpu().numpy()[0]
    #print(output.shape)
    scat.set_offsets(output)
    #fig.canvas.draw_idle()
    fig.canvas.draw()

import time
# NOTE(review): interact() on a zero-argument function presumably calls it once
# and displays an (empty) widget, so this loop creates 250 widgets. Also,
# fig.canvas.draw() inside a blocking loop may not refresh the widget-backend
# canvas until the cell finishes -- confirm this actually animates.
for i in range(250):
    _=interact(update)
    #display.display(fig)
    #display.clear_output(wait=True)
    time.sleep(0.01)



In [None]:
from IPython.display import clear_output
from matplotlib import pyplot as plt
import collections
%matplotlib inline

def live_plot(data_dict, figsize=(7,5), title=''):
    """Redraw a line plot of several series, replacing the cell's previous output.

    ``data_dict`` maps a legend label to a sequence of y-values; the x-axis is
    the sample index, labelled 'epoch'. Meant to be called repeatedly in a loop.
    """
    # wait=True defers clearing until new output arrives, which avoids flicker.
    clear_output(wait=True)
    fig, ax = plt.subplots(figsize=figsize)
    for series_label, series in data_dict.items():
        ax.plot(series, label=series_label)
    ax.set_title(title)
    ax.grid(True)
    ax.set_xlabel('epoch')
    ax.legend(loc='center left')  # the plot evolves to the right
    plt.show()

In [None]:
# Demo of live_plot: grow three series of uniform random numbers, redrawing each step.
data = collections.defaultdict(list)
for i in range(100):
    for series_name in ('foo', 'bar', 'baz'):
        data[series_name].append(np.random.random())
    live_plot(data)