In [64]:
# Define networks
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
import scipy.io as sio

class CTRNN(nn.Module):
    """Continuous-time RNN.

    Parameters:
        input_size: Number of input neurons
        hidden_size: Number of hidden neurons
        dt: discretization time step in ms. 
            If None, dt equals time constant tau

    Inputs:
        input: tensor of shape (seq_len, batch, input_size)
        hidden: tensor of shape (batch, hidden_size), initial hidden activity
            if None, hidden is initialized through self.init_hidden()
        
    Outputs:
        output: tensor of shape (seq_len, batch, hidden_size)
        hidden: tensor of shape (batch, hidden_size), final hidden activity
    """

    def __init__(self, input_size, hidden_size, dt=None, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.tau = 100
        if dt is None:
            alpha = 1
        else:
            alpha = dt / self.tau
        self.alpha = alpha

        self.input2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def init_hidden(self, input_shape):
        batch_size = input_shape[1]
        return torch.zeros(batch_size, self.hidden_size)

    def recurrence(self, input, hidden):
        """Run network for one time step.
        
        Inputs:
            input: tensor of shape (batch, input_size)
            hidden: tensor of shape (batch, hidden_size)
        
        Outputs:
            h_new: tensor of shape (batch, hidden_size),
                network activity at the next time step
        """
        h_new = torch.relu(self.input2h(input) + self.h2h(hidden))
        h_new = hidden * (1 - self.alpha) + h_new * self.alpha
        return h_new

    def forward(self, input, hidden=None):
        """Propogate input through the network."""
        
        # If hidden activity is not provided, initialize it
        if hidden is None:
            hidden = self.init_hidden(input.shape).to(input.device)

        # Loop through time
        output = []
        steps = range(input.size(0))
        for i in steps:
            hidden = self.recurrence(input[i], hidden)
            output.append(hidden)

        # Stack together output from all time steps
        output = torch.stack(output, dim=0)  # (seq_len, batch, hidden_size)
        return output, hidden


class RNNNet(nn.Module):
    """Recurrent network model.

    Parameters:
        input_size: int, input size
        hidden_size: int, hidden size
        output_size: int, output size
    
    Inputs:
        x: tensor of shape (Seq Len, Batch, Input size)

    Outputs:
        out: tensor of shape (Seq Len, Batch, Output size)
        rnn_output: tensor of shape (Seq Len, Batch, Hidden size)
    """
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super().__init__()

        # Continuous time RNN
        self.rnn = CTRNN(input_size, hidden_size, **kwargs)
        
        # Add an output layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        rnn_output, _ = self.rnn(x)
        out = self.fc(rnn_output)
        return out, rnn_output

In [91]:
# Load input and output data
inputoutputs = sio.loadmat('InOutputs')
Inputs       = inputoutputs['Inputs']
Outputs      = inputoutputs['Outputs']

Xnp = []
ynp = []
for cond_i in range(len(Outputs)):
    Xnp.append(np.tile(Inputs[cond_i, :], (Outputs[cond_i][0].shape[0], 1)))
    ynp.append(Outputs[cond_i][0])
Xnp = np.expand_dims(np.vstack(Xnp), axis=2)
ynp = np.vstack(ynp)
X = torch.from_numpy(Xnp).to(torch.float32)
y = torch.from_numpy(ynp).to(torch.float32)

input_size = 1
output_size = 2
batch_size = 16
hidden_size = 8
dt          = 100

dataset    = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle=True)

# Combine inputs and outputs into a dataset
import torch.optim as optim
import time

# Instantiate the network and print information
net = RNNNet(input_size=input_size, hidden_size=hidden_size,
             output_size=output_size, dt=dt)


def train_model(net):
    """Simple helper function to train the model.
    
    Args:
        net: a pytorch nn.Module module
        dataset: a dataset object that when called produce a (input, target output) pair
    
    Returns:
        net: network object after training
    """
    # Use Adam optimizer
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    running_loss = []
    start_time = time.time()
    # Loop over training batches
    print('Training network...')
    for i in range(1000):
        # Generate input and target, convert to pytorch tensor
        for xx, yy in dataloader:
            inputs = torch.permute(xx, (1, 0, 2))
            labels = torch.permute(yy, (1, 0, 2))

            # boiler plate pytorch training:
            optimizer.zero_grad()   # zero the gradient buffers
            output, _ = net(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()    # Does the update

        # Compute the running loss every 100 steps
        running_loss.append(loss.item())
        if i % 10 == 9:
            print(f'Epoch [{i+1}/1000], Loss: {loss.item():.4f}')
    return net

netres  = train_model(net)

Training network...
Epoch [10/1000], Loss: 0.2983
Epoch [20/1000], Loss: 0.5653
Epoch [30/1000], Loss: 0.5047
Epoch [40/1000], Loss: 0.4501
Epoch [50/1000], Loss: 0.3410
Epoch [60/1000], Loss: 0.4259
Epoch [70/1000], Loss: 0.2800
Epoch [80/1000], Loss: 0.4322
Epoch [90/1000], Loss: 0.4180
Epoch [100/1000], Loss: 0.4301
Epoch [110/1000], Loss: 0.3629
Epoch [120/1000], Loss: 0.3032
Epoch [130/1000], Loss: 0.4160
Epoch [140/1000], Loss: 0.3615
Epoch [150/1000], Loss: 0.3900
Epoch [160/1000], Loss: 0.4994
Epoch [170/1000], Loss: 0.4500
Epoch [180/1000], Loss: 0.2489
Epoch [190/1000], Loss: 0.3575
Epoch [200/1000], Loss: 0.6129
Epoch [210/1000], Loss: 0.3423
Epoch [220/1000], Loss: 0.3378


KeyboardInterrupt: 