In [1]:
%matplotlib inline

In [2]:
import matplotlib.pylab as plt
import torch
import numpy as np
import seaborn as sn
sn.set_context("poster")
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from torchvision import transforms, datasets

from cornn import coRNN
import numpy as np
import scipy
import scipy.stats as st
import scipy.special
import scipy.signal
import scipy.interpolate

import pandas as pd

from os.path import join
import random
from csv import DictWriter

from tqdm.notebook import tqdm
import pickle
# Choose the tensor constructors once: the CUDA-backed types when a GPU is
# available, their CPU counterparts otherwise.
use_cuda = torch.cuda.is_available()
print(use_cuda)
if use_cuda:
    FloatTensor = torch.cuda.FloatTensor
    DoubleTensor = torch.cuda.DoubleTensor
    IntTensor = torch.cuda.IntTensor
    LongTensor = torch.cuda.LongTensor
    ByteTensor = torch.cuda.ByteTensor
else:
    FloatTensor = torch.FloatTensor
    DoubleTensor = torch.DoubleTensor
    IntTensor = torch.IntTensor
    LongTensor = torch.LongTensor
    ByteTensor = torch.ByteTensor
# Tensor type handed to the batch generator and the training helper below.
ttype = FloatTensor

True


In [3]:
import torch
import numpy as np

def get_batch(batch_size, T, ttype):
    """Draw one random batch for the adding problem.

    Every sequence has ``T`` steps with two features per step: a uniform
    random value and a 0/1 marker. Exactly two steps per sequence are marked,
    one drawn from the first half of the sequence and one from the second
    half, and the target is the sum of the two marked values.

    Args:
        batch_size: number of sequences in the batch.
        T: sequence length.
        ttype: tensor type passed to ``Tensor.type`` for both outputs.

    Returns:
        ``(data, targets)``: ``data`` has shape ``(T, batch_size, 2)`` with
        the values in channel 0 and the markers in channel 1; ``targets`` has
        shape ``(batch_size,)``.
    """
    values = torch.rand(T, batch_size, requires_grad=False)
    markers = torch.zeros_like(values)
    midpoint = T // 2
    for column in range(batch_size):
        # Draw the first-half position before the second-half one so the
        # NumPy random stream is consumed in the same order for every sample.
        first_pos = np.random.randint(midpoint)
        second_pos = np.random.randint(midpoint, T)
        markers[first_pos, column] = 1
        markers[second_pos, column] = 1

    data = torch.stack((values, markers), dim=-1).type(ttype)
    targets = (values * markers).sum(dim=0).type(ttype)
    return data, targets


In [4]:
# Small sample batch (4 sequences of length 500); only the inputs are kept,
# for the model output-shape check further down.
inp, _ = get_batch(batch_size=4, T=500, ttype=ttype)

In [5]:
# Seed every random source the notebook draws from. get_batch takes its
# values from torch but its marker positions from NumPy, so seeding torch
# alone does not make a run repeatable.
SEED = 1111
random.seed(SEED)
np.random.seed(SEED)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fe40a4608f0>

In [6]:
def train(model, ttype, seq_length, optimizer, loss_func,
          epoch, perf_file, loss_buffer_size=20, batch_size=1, test_size=10,
          device='cuda', prog_bar=None, n_batches=20000):
    """Train ``model`` on freshly generated adding-problem batches.

    Each iteration draws a new batch with ``get_batch``, takes one optimizer
    step, and keeps the most recent ``loss_buffer_size`` losses. Their running
    mean is shown on the progress bar and periodically appended to a CSV file.

    Args:
        model: module called as ``model(sig)`` on a batch from ``get_batch``.
        ttype: tensor type forwarded to ``get_batch``.
        seq_length: sequence length of the generated batches.
        optimizer: optimizer stepped once per batch.
        loss_func: callable ``loss_func(output, target)`` returning the loss.
        epoch: epoch label shown on the progress bar and written to the log.
        perf_file: path of the CSV log; rows are appended, and the header is
            written only when the file is empty.
        loss_buffer_size: number of most recent losses averaged for reporting;
            must be a multiple of ``batch_size``.
        batch_size: number of sequences per batch.
        test_size: unused; kept for backward compatibility.
        device: unused; kept for backward compatibility.
        prog_bar: optional object with ``set_description`` (e.g. a tqdm bar).
        n_batches: number of training iterations (previously fixed at 20000).
    """
    assert loss_buffer_size % batch_size == 0, \
        "loss_buffer_size must be a multiple of batch_size"

    # Create the log directory up front so a missing folder cannot abort the
    # run at the first log write.
    log_dir = os.path.dirname(perf_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    losses = []
    for batch_idx in range(n_batches):
        model.train()
        sig, target = get_batch(batch_size, seq_length, ttype=ttype)
        # Add a trailing dimension of size 1 to the target before the loss.
        target = target.unsqueeze(1)

        optimizer.zero_grad()
        out = model(sig)
        loss = loss_func(out, target)
        loss.backward()
        optimizer.step()

        # Keep only the most recent `loss_buffer_size` losses.
        losses.append(loss.detach().cpu().numpy())
        losses = losses[-loss_buffer_size:]

        if prog_bar is not None:
            # Use the `epoch` argument here rather than a notebook-level loop
            # variable. The second field is the batch index rescaled by
            # batch_size / 50.
            description = "{}:{} Loss: {:.8f}".format(
                epoch, int(batch_idx / (50 / batch_size)), np.mean(losses))
            prog_bar.set_description(description)

        # Log whenever the number of samples seen so far is a multiple of
        # the buffer size, skipping the very first batch.
        if (batch_idx * batch_size) % loss_buffer_size == 0 and batch_idx != 0:
            loss_track = {
                'avg_loss': np.mean(losses),
                'epoch': epoch,
                'batch_idx': batch_idx,
            }
            # newline='' lets the csv module control line endings itself.
            with open(perf_file, 'a+', newline='') as fp:
                csv_writer = DictWriter(fp, fieldnames=list(loss_track.keys()))
                # Append mode starts at the end of the file, so position 0
                # means the file is empty and still needs its header.
                if fp.tell() == 0:
                    csv_writer.writeheader()
                csv_writer.writerow(loss_track)
                fp.flush()
def test_norm(model, device, seq_length, loss_func, batch_size=100):
    """Evaluate ``model`` on one freshly generated batch and return the loss.

    The model is switched to eval mode and left there. ``device`` is accepted
    but not used; the tensor type of the batch comes from the notebook-level
    ``ttype``.
    """
    model.eval()
    with torch.no_grad():
        inputs, expected = get_batch(batch_size, seq_length, ttype=ttype)
        prediction = model(inputs)
        return loss_func(prediction, expected.unsqueeze(1))

In [7]:
from torch import nn
import torch
from torch.autograd import Variable
# Device used by the model code below: the GPU when CUDA is available,
# otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class coRNNCell(nn.Module):
    """One update step of the coRNN recurrence.

    The cell carries two hidden states, ``hy`` and ``hz``. ``hz`` is advanced
    first from a tanh of an affine map of ``[x, hz, hy]`` minus two linear
    terms scaled by ``gamma`` and ``epsilon``; ``hy`` is then advanced using
    the already-updated ``hz``. Both updates are scaled by the step ``dt``.
    """

    def __init__(self, n_inp, n_hid, dt, gamma, epsilon):
        super().__init__()
        self.dt = dt
        self.gamma = gamma
        self.epsilon = epsilon
        # A single affine map over the concatenation [x, hz, hy].
        self.i2h = nn.Linear(n_inp + 2 * n_hid, n_hid)

    def forward(self, x, hy, hz):
        """Return the updated ``(hy, hz)`` for input ``x``."""
        drive = torch.tanh(self.i2h(torch.cat((x, hz, hy), 1)))
        hz_next = hz + self.dt * (drive - self.gamma * hy - self.epsilon * hz)
        # hy is integrated with the *new* hz, not the one passed in.
        hy_next = hy + self.dt * hz_next
        return hy_next, hz_next

class coRNN(nn.Module):
    """Recurrent model: a ``coRNNCell`` unrolled over time plus a linear readout.

    ``forward`` iterates over the first dimension of the input as time, uses
    the second dimension as the batch size of the hidden states, and applies
    the readout to the final ``hy`` state only.

    Args:
        n_inp: number of input features per time step.
        n_hid: size of each hidden state.
        n_out: number of output features.
        dt, gamma, epsilon: forwarded unchanged to ``coRNNCell``.
    """

    def __init__(self, n_inp, n_hid, n_out, dt, gamma, epsilon):
        super().__init__()
        self.n_hid = n_hid
        self.cell = coRNNCell(n_inp, n_hid, dt, gamma, epsilon)
        self.readout = nn.Linear(n_hid, n_out)

    def forward(self, x):
        """Run the cell over ``x`` and return the readout of the last state."""
        # Allocate the zero initial states from the input itself so they
        # always share its device and dtype, instead of relying on a
        # module-level `device` global and the deprecated Variable wrapper.
        hy = x.new_zeros(x.size(1), self.n_hid)
        hz = x.new_zeros(x.size(1), self.n_hid)

        for t in range(x.size(0)):
            hy, hz = self.cell(x[t], hy, hz)

        return self.readout(hy)

# T = 500

In [8]:
# Hyper-parameters of the coRNN used for the T = 500 run.
cornn_params = dict(n_inp=2,
                    n_hid=128,
                    n_out=1,
                    dt=6e-2,
                    gamma=66,
                    epsilon=15)
# Move the model to the notebook's `device` instead of hard-coding CUDA, so
# the cell also runs on machines without a GPU.
model = coRNN(**cornn_params).to(device)

# Total number of parameter elements in the model.
tot_weights = sum(p.numel() for p in model.parameters())
print("Total Weights:", tot_weights)
print(model)

Total Weights: 33281
coRNN(
  (cell): coRNNCell(
    (i2h): Linear(in_features=258, out_features=128, bias=True)
  )
  (readout): Linear(in_features=128, out_features=1, bias=True)
)


In [32]:
# Sanity check: print the shape of the sample input and of the model output.
print(inp.size(), model(inp).size())

torch.Size([500, 4, 2]) torch.Size([4, 1])


In [33]:
# Training run on sequences of length 500; losses are appended to the CSV
# named below.
seq_length = 500
epochs = 1

loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)

progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')
for e in progress_bar:
    train(model=model, ttype=ttype, seq_length=seq_length,
          optimizer=optimizer, loss_func=loss_func,
          epoch=e, perf_file=join('perf', 'adding500_cornn_2.csv'),
          loss_buffer_size=100, batch_size=50,
          prog_bar=progress_bar)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

KeyboardInterrupt: 

# T = 2000

In [None]:
# Fresh coRNN, with the same hyper-parameters, for the T = 2000 run.
cornn_params = dict(n_inp=2,
                    n_hid=128,
                    n_out=1,
                    dt=6e-2,
                    gamma=66,
                    epsilon=15)
# Move the model to the notebook's `device` instead of hard-coding CUDA, so
# the cell also runs on machines without a GPU.
model = coRNN(**cornn_params).to(device)

# Total number of parameter elements in the model.
tot_weights = sum(p.numel() for p in model.parameters())
print("Total Weights:", tot_weights)
print(model)

In [None]:
# Training run on sequences of length 2000; losses are appended to the CSV
# named below.
seq_length = 2000
epochs = 1

loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)

progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')
for e in progress_bar:
    train(model=model, ttype=ttype, seq_length=seq_length,
          optimizer=optimizer, loss_func=loss_func,
          epoch=e, perf_file=join('perf', 'adding2000_cornn_1.csv'),
          loss_buffer_size=100, batch_size=50,
          prog_bar=progress_bar)

# T = 5000

In [None]:
# Fresh coRNN, with the same hyper-parameters, for the T = 5000 run.
cornn_params = dict(n_inp=2,
                    n_hid=128,
                    n_out=1,
                    dt=6e-2,
                    gamma=66,
                    epsilon=15)
# Move the model to the notebook's `device` instead of hard-coding CUDA, so
# the cell also runs on machines without a GPU.
model = coRNN(**cornn_params).to(device)

# Total number of parameter elements in the model.
tot_weights = sum(p.numel() for p in model.parameters())
print("Total Weights:", tot_weights)
print(model)

In [None]:
# Training run on sequences of length 5000; losses are appended to the CSV
# named below.
seq_length = 5000
epochs = 1

loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-2)

progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')
for e in progress_bar:
    train(model=model, ttype=ttype, seq_length=seq_length,
          optimizer=optimizer, loss_func=loss_func,
          epoch=e, perf_file=join('perf', 'adding5000_cornn_1.csv'),
          loss_buffer_size=100, batch_size=50,
          prog_bar=progress_bar)

# Data Plot

In [None]:
# Load the loss logs of the three runs, one CSV per sequence length.
# NOTE(review): the T = 2000 training cell in this notebook writes
# 'adding2000_cornn_1.csv', while 'adding2000_cornn_2.csv' is read here --
# confirm which run is meant to be plotted.
dat = pd.read_csv(join('perf', 'adding500_cornn_2.csv'))
dat2 = pd.read_csv(join('perf', 'adding2000_cornn_2.csv'))
dat3 = pd.read_csv(join('perf', 'adding5000_cornn_1.csv'))

In [None]:
# Rescale the logged batch index into the x-axis unit used by the plot.
# NOTE(review): dat3 is scaled by 10/5/100 (i.e. 1/50) while the other two
# use 1/100, although the training cells shown in this notebook use the same
# batch_size for all three runs -- confirm this factor is intended.
dat['training_step'] = dat['batch_idx'] / 100
dat2['training_step'] = dat2['batch_idx'] / 100
dat3['training_step'] = dat3['batch_idx'] * 10 / 5 / 100


In [None]:

# Overlay the averaged training loss of the three runs on a single set of axes.
fig, ax = plt.subplots(figsize=(15, 10))
sn.lineplot(data=dat, x=dat.training_step, y=dat.avg_loss, ax=ax)
sn.lineplot(data=dat2, x=dat2.training_step, y=dat2.avg_loss, ax=ax)
sn.lineplot(data=dat3, x=dat3.training_step, y=dat3.avg_loss, ax=ax)
# Legend labels are given in the order in which the curves were drawn.
ax.legend(['T=500', 'T=2000', 'T=5000'])
ax.set_xlabel('Training steps (hundreds)')
ax.set_ylabel('Average Loss')
ax.set_xlim(0, 150)
ax.set_ylim(0, .2)
ax.grid(True)
