In [None]:
# IF RNN-FF Example.

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Model definition.

class RNN_FF(nn.Module):
    """Single-layer tanh RNN followed by a bias-free linear read-out to one output.

    Args:
        size_x: number of input features per time step.
        size_h: number of hidden units in the recurrent layer.

    ``nn.RNN`` is built without ``batch_first``, so inputs use the seq-first
    layout ``(seq_len, batch, size_x)``. The read-out is applied at every time
    step, giving an output of shape ``(seq_len, batch, 1)``.
    """

    def __init__(self, size_x, size_h):
        super().__init__()
        self.size_x, self.size_h = size_x, size_h

        # Keep construction order (rnn, then linear) so seeded initialisation is unchanged.
        self.rnn = nn.RNN(size_x, size_h, nonlinearity='tanh')
        self.linear = nn.Linear(size_h, 1, bias=False)

    def forward(self, x):
        """Run the RNN over ``x`` and project every hidden state to a scalar."""
        hidden_seq = self.rnn(x)[0]  # per-step hidden states; the final hidden state is unused
        return self.linear(hidden_seq)

In [None]:
# Prepare data.

from neuron_models import IF
from neuron_models.framework import neuron_dataloader

# Data and model parameters.
batch_size = 1  # sequences per batch
size_x = 1      # input features per time step
size_h = 3      # hidden units in the RNN
n_sim = 300     # argument passed to IF.run — presumably the simulation length; confirm units

# Generate data by running the IF neuron model.
neuron_model = IF()
neuron_model.run(n_sim)

# Wrap current I and membrane potential V in a dataloader
# (presumably I is the network input and V the target — confirm against neuron_dataloader).
loader_shape = (1, batch_size, size_x)
dataloader_train = neuron_dataloader(
    neuron_model.I,
    neuron_model.V,
    batch_size=batch_size,
    out_dimension=loader_shape,
)

In [None]:
# Visualize data: simulated input current and the resulting membrane potential.

fig, ax = plt.subplots(figsize=(15,2))
ax.set(title='Input current vs Time', xlabel='Time', ylabel='Current')
ax.grid()
ax.plot(neuron_model.T, neuron_model.I, "orange")
plt.show()

fig, ax = plt.subplots(figsize=(15,2))
ax.set(title='Membrane potential vs Time', xlabel='Time', ylabel='Potential')
ax.grid()
ax.plot(neuron_model.T, neuron_model.V)
plt.show()

In [None]:
# Training.

from neuron_models.framework import Trainer
from tqdm import trange

# Seed torch so weight initialisation (and hence the loss curve) is reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Define model and training parameters.
# The Trainer is configured for float64 below; cast the model to the same dtype up
# front instead of relying on the Trainer to convert it implicitly.
DTYPE = torch.float64
model = RNN_FF(size_x, size_h).to(DTYPE)
train_params = {
    'epochs': 10,
    'lr': 0.01,
    'loss_fn': F.l1_loss,
}

# Train loop.
trainer = Trainer(model, dataloader_train, dtype=DTYPE)
LOSS_TRAIN = np.array([])

# NOTE(review): `epochs` is used as the outer loop count here AND forwarded to
# `trainer.run` via **train_params. If `run` also iterates over it internally,
# training performs epochs**2 passes — confirm against Trainer.run.
for epoch in trange(train_params['epochs']):
    model, loss_train = trainer.run(**train_params)
    LOSS_TRAIN = np.append(LOSS_TRAIN, loss_train)  # np.append flattens and appends this run's loss values.

In [None]:
# Visualize loss: one point per value collected from trainer.run, in collection order.

fig, ax = plt.subplots(figsize=(15,2))
ax.set(title='Loss vs Steps', xlabel='Steps', ylabel='Loss')
ax.grid()
ax.plot(LOSS_TRAIN)
plt.show()

In [None]:
# Evaluate model.

# Drive the trained network with a constant unit input for as many steps as the simulation.
size = len(neuron_model.T)
model.eval()  # RNN_FF has no dropout/batch-norm, but this makes inference mode explicit

# Build the input in the model's own dtype/device so nn.RNN does not reject a mismatch
# (the previous hard-coded torch.double on CPU only worked if the model happened to match).
ref_param = next(model.parameters())

with torch.no_grad():
    X = torch.ones(size, dtype=ref_param.dtype, device=ref_param.device)
    X = X.reshape(size, 1, 1)  # (seq_len, batch=1, features=1): nn.RNN here is seq-first
    Y = model(X)

# Host-side NumPy copies for plotting.
x_plot = X.reshape(size).cpu().numpy()
y_plot = Y.reshape(size).cpu().numpy()

# Visualize. Every trace is plotted against the simulation time axis neuron_model.T,
# so the network output and the training data line up (previously the network traces
# were plotted against sample index while labelled "Time").
fig, ax = plt.subplots(figsize=(15,2))
ax.set(title='Input current vs Time', xlabel='Time', ylabel='Current')
ax.grid()
ax.plot(neuron_model.T, x_plot, "orange")
plt.show()

fig, ax = plt.subplots(figsize=(15,2))
ax.set(title='Membrane potential vs Time (network output)', xlabel='Time', ylabel='Potential')
ax.grid()
ax.plot(neuron_model.T, y_plot)
plt.show()

fig, ax = plt.subplots(figsize=(15,2))
ax.set(title='Membrane potential vs Time (training data)', xlabel='Time', ylabel='Potential')
ax.grid()
ax.plot(neuron_model.T, neuron_model.V)
plt.show()

In [None]:
# Save and load model.

# Persist only the learned parameters (state_dict). Pickling the whole module with
# torch.save(model) binds the file to this notebook's class definition, and cannot be
# reloaded by torch.load's default weights_only=True (PyTorch >= 2.6).
model_path = 'mymodel.pt'
torch.save(model.state_dict(), model_path)

# Rebuild the architecture on the trained model's dtype/device, then load the weights into it.
ref_param = next(model.parameters())
saved_model = RNN_FF(size_x, size_h).to(device=ref_param.device, dtype=ref_param.dtype)
saved_model.load_state_dict(torch.load(model_path, map_location=ref_param.device))
saved_model.eval()