In [54]:
import torch
import torch.nn as nn
import torch.optim as optim

import torch.nn.functional as F

import os
import numpy as np

from torch.utils.data import Dataset, DataLoader

In [150]:
class SimpleRNN(nn.Module):
    """Single-layer recurrent network that carries its hidden state on the module.

    Every call to ``forward`` advances the network by one time step:

        h_t   = activation(x_t @ W_input + h_{t-1} @ W_rec + b_rec)
        out_t = activation(h_t @ W_output + b_out)

    The hidden state lives in ``self.h`` and persists between calls, so
    ``reset_hidden`` must be called before starting a new sequence.

    Args:
        n_input: Number of input features per time step.
        n_rec: Number of recurrent (hidden) units.
        n_output: Number of output features per time step.
        activation: Callable applied to both the hidden and the output
            pre-activations. Defaults to ``torch.tanh``.
    """

    def __init__(self, n_input, n_rec, n_output, activation=torch.tanh):
        super().__init__()

        self.n_input = n_input
        self.n_rec = n_rec
        self.n_output = n_output

        # Weights are drawn from a standard normal, biases start at zero.
        # The creation order is kept so a seeded run yields the same weights.
        self.W_input = torch.nn.Parameter(torch.randn(n_input, n_rec))
        self.W_rec = torch.nn.Parameter(torch.randn(n_rec, n_rec))
        self.b_rec = torch.nn.Parameter(torch.zeros(1, n_rec))

        self.W_output = torch.nn.Parameter(torch.randn(n_rec, n_output))
        self.b_out = torch.nn.Parameter(torch.zeros(1, n_output))

        self.activation = activation

        self.reset_hidden()

    def reset_hidden(self):
        """Set the hidden state back to zeros with shape ``(1, n_rec)``.

        The state is created with the dtype and device of ``W_rec`` so it
        stays usable after the module has been moved or cast.
        """
        self.h = torch.zeros(
            1, self.n_rec, dtype=self.W_rec.dtype, device=self.W_rec.device)

    def forward(self, X):
        """Advance the network by one time step.

        Args:
            X: Input for the current step, expected to have shape
                ``(1, n_input)``.

        Returns:
            Tuple ``(h, out)``: the updated hidden state and the output for
            this step. Both are also stored on the module as ``self.h`` and
            ``self.out``.
        """
        # ``self.h`` is a plain attribute, so ``.to()`` / ``.double()`` on the
        # module do not touch it. Align it with the weights here; this is a
        # no-op when dtype and device already match.
        h_prev = self.h.to(self.W_rec)

        self.h = self.activation(
            torch.mm(X, self.W_input)
            + torch.mm(h_prev, self.W_rec)
            + self.b_rec)

        self.out = self.activation(
            torch.mm(self.h, self.W_output) + self.b_out)

        return self.h, self.out

In [151]:
# --- Network dimensions (passed to SimpleRNN as n_input, n_rec, n_output) ---
NUM_INPUT = 4
NUM_UNITS = 10
NUM_OUTPUT = 4

# --- Task / data sizes ---
# NOTE(review): neither constant is referenced in the visible cells;
# presumably they describe the generated trials — confirm against the
# data-generation code.
TOTAL_STEPS = 50
BATCH_SIZE = 1000

# --- Optimisation ---
NUM_EPOCHS = 10  # number of passes over the dataloader in the training loop
LR = 1e-3        # learning rate

In [153]:
# Join the two segments of each array along axis 1. The training loop iterates
# over that axis of every sample as the time-step dimension.
# NOTE(review): x_stim / x_go / y_stim / y_go are not defined in any visible
# cell; presumably they are the stimulus- and go-period arrays — confirm.
x_all = np.concatenate([x_stim, x_go], axis=1)
y_all = np.concatenate([y_stim, y_go], axis=1)

In [158]:
# Mean-squared-error loss; 'mean' is the default reduction, spelled out here.
criterion = nn.MSELoss(reduction='mean')

In [159]:
# Build the recurrent network. Its weights are drawn with torch.randn, so the
# result differs between runs unless a seed is set beforehand.
net = SimpleRNN(n_input=NUM_INPUT, n_rec=NUM_UNITS, n_output=NUM_OUTPUT)

In [160]:
# Adam over all trainable parameters of the network. The learning rate comes
# from the LR constant in the configuration cell rather than a repeated literal.
optimizer = optim.Adam(net.parameters(), lr=LR)

In [173]:


# Factor applied to the MSE before back-propagation (same value as before,
# now named instead of inlined).
LOSS_SCALE = 100

# Defined up front so the final print also works when NUM_EPOCHS is 0.
train_losses = []

for epoch in range(NUM_EPOCHS):
    # The log is restarted every epoch, so after the loop it holds the
    # per-sample losses of the last epoch only.
    train_losses = []

    net.train()

    for data in dataloader:
        # Only the first sample of each batch is used.
        x_cur = data['x'][0]
        y_cur = data['y'][0]

        # Start every sequence from a zero hidden state and fresh gradients.
        net.reset_hidden()
        optimizer.zero_grad()

        # Feed the sequence one time step at a time; each step is given a
        # leading dimension of 1 because the network works on 2-D inputs.
        outs = []
        for timestep in x_cur:
            _, out = net(torch.unsqueeze(timestep, 0))
            outs.append(out)

        # Stack the per-step outputs along dim 0 and compare the whole
        # sequence with the target.
        out_tensor = torch.cat(outs)
        loss = LOSS_SCALE * criterion(out_tensor, y_cur)

        loss.backward()
        optimizer.step()

        # .item() already returns a plain Python float detached from the graph.
        train_losses.append(loss.item())

# Show every 100th loss of the last epoch.
print(train_losses[0::100])
        

[6.712124013574794e-05, 6.281041714828461e-05, 6.1519313021563e-05, 6.672004383290187e-05, 5.735182639909908e-05, 6.090680108172819e-05, 5.919143222854473e-05, 5.671001417795196e-05, 8.318993059219792e-05, 5.297124516800977e-05]


In [170]:
class TaskDataset(Dataset):
    """Dataset that pairs inputs ``x`` with targets ``y`` by index.

    Args:
        x: Indexable collection of inputs (NumPy array, list or tensor).
        y: Indexable collection of targets, indexed the same way as ``x``.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with float tensors under 'x' and 'y'."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # torch.as_tensor accepts ndarrays like torch.from_numpy does, and also
        # NumPy scalars (1-D data), lists and tensors, which from_numpy rejects.
        sample = {
            'x': torch.as_tensor(self.x[idx], dtype=torch.float),
            'y': torch.as_tensor(self.y[idx], dtype=torch.float),
        }

        return sample

In [98]:
# Wrap the concatenated inputs and targets in a Dataset.
td = TaskDataset(x=x_all, y=y_all)

In [99]:
# One sample per batch; the training loop reads element [0] of every batch.
dataloader = DataLoader(dataset=td, batch_size=1)

In [108]:
# Scratch check: what a (1, 5) zero tensor looks like.
torch.zeros((1, 5))

tensor([[0., 0., 0., 0., 0.]])