In [11]:
# v1 was just vanilla.

# v2: 24.09.01. 
#   1. Doing the CV (10 folds)
#   2. Early stopping with patience parameter.
#   3. Taking n steps into account.

# v3: 24.10.15 - 29
#   1. No CV. No Early Stopping as well. 
#       Adding early stopping may be an option, but it is not strictly necessary.
#   2. Adding a tau param: h(t) = (1-tau) * h(t-1) + tau * f(W_in x(t) + W_rec h(t-1)) => This leaky integration is basically what a CT-RNN is about.

# v4: 24.11.15
#   1. Simplifying task...
#       - adding two layers for outputs first. (getting max for the decision)

# import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
import scipy.io as sio
import matplotlib.pyplot as plt

# seed random number generators for reproducibility (numpy drives the input noise,
# torch drives weight init and DataLoader shuffling)
torch.manual_seed(0)
np.random.seed(0)

## ---------------------------------- ##
## ------------ FIX down ------------ ##
## ---------------------------------- ##

# input
input_size  = 1   # Input feature size
hidden_size = 64  # Number of features in the hidden state
# Leak rates to sweep. NOTE: with tau = 0 the update h(t) = (1-tau)*h(t-1) + tau*f(...)
# never moves the hidden state away from its zero initialisation, so that run can
# only fit the readout bias.
taus        = np.linspace(0, 1, 10)
noise_sig   = 0.01  # std of the Gaussian noise added to each training batch
output_name = ['StimRating', 'Clicks']
output_size = len(output_name)
objfunc     = ['MSE']
actfunc     = 'relu'
# Use double quotes for the f-string: reusing the outer quote character inside the
# braces (f'...{''.join(...)}...') is a SyntaxError before Python 3.12.
targfolder  = f"act-{actfunc}_label-{''.join(output_name)}_loss-{''.join(objfunc)}"
os.makedirs(targfolder, exist_ok=True)

# params
learning_rate = 1e-3
n_epochs      = 5000
batch_size    = 32   # or 9 for condition. so far at the trial-level
device        = 'cpu' # mps


## ---------------------------------- ##
## ------------- FIX up ------------- ##
## ---------------------------------- ##

# Q. why no Cross-Validation?
# A. no differences across inputs.

# Input: n_batch, n_seqlen, n_dim
class PainRNN(nn.Module):
    """Leaky (CT-RNN style) Elman RNN with a linear readout.

    At every time step t:

        h_tilde(t) = act(i2h(x(t)) + h2h(h(t-1)))
        h(t)       = (1 - tau) * h(t-1) + tau * h_tilde(t)
        y(t)       = h2o(h(t))

    With tau = 1 this is a vanilla RNN. With tau = 0 the hidden state never
    leaves its initial value (zeros unless ``init_hidden`` is given).

    Args:
        input_size: number of input features per time step.
        hidden_size: number of hidden units.
        output_size: number of readout channels.
        tau: leak / integration rate, intended to lie in [0, 1]; stored as ``alpha``.
        actfunc: 'relu' or 'tanh'. When None (default), the notebook-level global
            ``actfunc`` is looked up at forward time, matching the previous behavior.
    """

    _ACTIVATIONS = {'relu': torch.relu, 'tanh': torch.tanh}

    def __init__(self, input_size, hidden_size, output_size, tau, actfunc=None):
        super().__init__()
        self.input_size  = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.alpha       = tau      # [0, 1] balances the previous state against the new candidate.
        self.actfunc     = actfunc  # None -> fall back to the global `actfunc`.
        self.i2h         = nn.Linear(input_size, hidden_size)
        self.h2h         = nn.Linear(hidden_size, hidden_size)
        self.h2o         = nn.Linear(hidden_size, output_size)

    def _activation(self):
        """Resolve the activation function, raising ValueError for unknown names."""
        # The bare name below is the notebook-level global, kept for backward compatibility.
        name = self.actfunc if self.actfunc is not None else actfunc
        try:
            return self._ACTIVATIONS[name]
        except KeyError:
            raise ValueError(
                f"Unsupported actfunc {name!r}; expected one of {sorted(self._ACTIVATIONS)}"
            ) from None

    def forward(self, input, init_hidden=None):
        """Run the RNN over a whole sequence.

        Args:
            input: tensor of shape (batch, seq_len, input_size).
            init_hidden: optional (batch, hidden_size) initial state; zeros when omitted.

        Returns:
            (output_units, hidden_units) with shapes (batch, seq_len, output_size)
            and (batch, seq_len, hidden_size).

        Raises:
            ValueError: if ``input`` has no time steps or the activation is unknown.
        """
        seq_len = input.size(1)
        if seq_len == 0:
            raise ValueError('input must contain at least one time step')
        act = self._activation()

        # Build / move the initial state on the input's device rather than a global one.
        if init_hidden is None:
            h_t = input.new_zeros(input.size(0), self.hidden_size)
        else:
            h_t = init_hidden.to(input.device)

        # The input projection does not depend on the hidden state: do it for all steps at once.
        x_proj = self.i2h(input)

        hidden_units = []
        for t in range(seq_len):
            candidate = act(x_proj[:, t, :] + self.h2h(h_t))
            h_t = h_t * (1 - self.alpha) + candidate * self.alpha
            hidden_units.append(h_t)

        hidden_units = torch.stack(hidden_units, dim=1)
        # The readout is linear and per-step, so apply it to the stacked sequence.
        output_units = self.h2o(hidden_units)
        return output_units, hidden_units
    
# Load the condition-averaged inputs and targets exported from MATLAB.
inputoutputs = sio.loadmat('InOutputs_condavg.mat')
Inputs, Outputs = inputoutputs['Inputs'], inputoutputs['Outputs']

# Append a trailing feature axis so the RNN sees (n, seq_len, 1) inputs
# (assumes Inputs is 2-D, n x seq_len -- TODO confirm against the .mat file).
Xnp = np.expand_dims(Inputs, axis = 2)
# Every row of Outputs holds one target matrix in its first entry; stack them
# along a new leading axis.
ynp = np.stack([row[0] for row in Outputs], axis = 0)

X = torch.from_numpy(Xnp).float()
y = torch.from_numpy(ynp).float()

In [12]:
# Combine inputs and outputs into a dataset
dataset    = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Loss functions; `_compute_loss` picks among them based on `objfunc` / `output_name`.
# They hold no state, so one instance serves every tau.
criterion1 = nn.MSELoss()              # For intensity.
criterion2 = nn.CrossEntropyLoss()     # For clicks.
criterion3 = nn.BCEWithLogitsLoss()    # For clicks.


def _compute_loss(outputs, targets):
    """Return the training loss selected by the global `objfunc` / `output_name`.

    Branches are checked in the same order as before. Previously an unmatched
    configuration left `loss` undefined (or silently reused the previous
    batch's loss).

    Raises:
        ValueError: if no loss is configured for the current settings.
    """
    if len(objfunc) == 1 and objfunc[0] == 'MSE' and 'ClickThr' not in output_name:
        return criterion1(outputs, targets)
    if len(objfunc) == 1 and objfunc[0] == 'CE':
        # NOTE(review): CrossEntropyLoss expects the class axis at dim 1, but outputs
        # are (batch, time, channel) here, so time is treated as the class axis. Confirm.
        return criterion2(outputs, targets)
    if len(objfunc) == 2 and objfunc[1] == 'CE':
        # NOTE(review): same axis caveat as above -- time is used as the class axis.
        return (criterion1(outputs[:, :, 0], targets[:, :, 0])
                + criterion2(outputs[:, :, 1], targets[:, :, 1]))
    if len(objfunc) == 2 and objfunc[1] == 'BCE' and len(output_name) == 3:
        # Two output units compete; their difference is used as the click logit.
        actionvals = outputs[:, :, 1] - outputs[:, :, 2]
        return (criterion1(outputs[:, :, 0], targets[:, :, 0])
                + criterion3(actionvals, targets[:, :, 1]))
    if len(objfunc) == 2 and objfunc[1] == 'BCE' and len(output_name) == 2:
        return (criterion1(outputs[:, :, 0], targets[:, :, 0])
                + criterion3(outputs[:, :, 1], targets[:, :, 1]))
    if len(objfunc) == 1 and objfunc[0] == 'MSE' and 'ClickThr' in output_name:
        # NOTE(review): the threshold is not differentiable, so channel 1 contributes
        # no gradient. Work on a copy so the caller's `outputs` is not mutated.
        thresholded = outputs.clone()
        thresholded[:, :, 1] = (outputs[:, :, 1] > 0).to(outputs.dtype)
        return criterion1(thresholded, targets)
    raise ValueError(
        f'No loss configured for objfunc={objfunc!r}, output_name={output_name!r}')


def _plot_results(out_t, h_t, loss_epochs):
    """Draw the 3x3 diagnostic figure for one trained model and return the Figure.

    Uses the globals `Inputs`, `y`, `output_size` and `output_name`.
    """
    fig, axes = plt.subplots(3, 3, figsize = (9, 9))
    axes[0, 0].plot(loss_epochs)
    axes[0, 0].set_xlabel('Epochs')
    axes[0, 0].set_ylabel('Loss')

    axes[0, 1].plot(np.mean(h_t, 0))
    axes[0, 1].set_xlabel('time')
    axes[0, 1].set_ylabel('Hidden Unit activations')

    axes[2, 0].plot(Inputs.T)
    axes[2, 0].set_xlabel('time')
    axes[2, 0].set_ylabel('Training Input')

    # Separate loop variables so nothing shadows the tau index of the sweep.
    for out_idx in range(output_size):
        axes[1, out_idx].plot(out_t[:, :, out_idx].T)
        axes[1, out_idx].set_xlabel('time')
        axes[1, out_idx].set_ylabel(f'Outputs {output_name[out_idx]}')

    targets_np = y.numpy()
    for targ_idx in range(targets_np.shape[2]):
        axes[2, targ_idx+1].plot(targets_np[:, :, targ_idx].T)
        axes[2, targ_idx+1].set_xlabel('time')
        axes[2, targ_idx+1].set_ylabel(f'Training {output_name[targ_idx]}')

    if len(output_name) == 3:
        actionval = (out_t[:, :, 1] - out_t[:, :, 2]) > 0
        axes[0, 2].plot(actionval.T)
        axes[0, 2].set_xlabel('time')
        axes[0, 2].set_ylabel('Action (output 1 > output 2)')
    return fig


# Train one model per tau and save its activations, weights and diagnostic figure.
for tau_idx, tau in enumerate(taus):
    addstr      = 'tau_%.2f' % tau
    pain_rnn    = PainRNN(input_size, hidden_size, output_size, tau).to(device)
    optimizer   = optim.Adam(pain_rnn.parameters(), lr=learning_rate)
    loss_epochs = []
    for epoch in range(n_epochs):
        batch_losses = []
        for inputs, targets in dataloader:
            # Fresh input noise per batch (numpy RNG, as before); added out of place.
            add_noise = np.random.normal(0, noise_sig, inputs.shape)
            inputs    = (inputs + torch.from_numpy(add_noise).to(torch.float32)).to(device)
            targets   = targets.to(device)
            outputs, _ = pain_rnn(inputs)  # Capture both outputs and hidden states
            loss = _compute_loss(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        # Record the mean over all batches rather than only the last batch's loss.
        epoch_loss = float(np.mean(batch_losses))
        loss_epochs.append(epoch_loss)

        if ((epoch+1) % 10) == 0:
            print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {epoch_loss:.4f} in {tau_idx+1} of {len(taus)}')

    # Evaluate on the clean (noise-free) inputs without building an autograd graph;
    # `.cpu()` keeps the numpy conversion valid for non-CPU devices.
    with torch.no_grad():
        out_t, h_t = pain_rnn(X.to(device))
    out_t = out_t.cpu().numpy()
    h_t   = h_t.cpu().numpy()

    fig  = _plot_results(out_t, h_t, loss_epochs)
    stem = os.path.join(targfolder, 'HiddenLayers_' + addstr)
    sio.savemat(stem + '.mat',
                {'out_t': out_t, 'h_t': h_t, 'loss_epochs': loss_epochs,
                 'i2h': pain_rnn.i2h.weight.detach().cpu().numpy(),
                 'h2h': pain_rnn.h2h.weight.detach().cpu().numpy(),
                 'h2o': pain_rnn.h2o.weight.detach().cpu().numpy()})

    fig.savefig(stem + '.png')
    plt.close(fig)
    

Epoch [10/5000], Loss: 0.1578 in 1 of 10
Epoch [20/5000], Loss: 0.1545 in 1 of 10
Epoch [30/5000], Loss: 0.1514 in 1 of 10
Epoch [40/5000], Loss: 0.1485 in 1 of 10
Epoch [50/5000], Loss: 0.1460 in 1 of 10
Epoch [60/5000], Loss: 0.1436 in 1 of 10
Epoch [70/5000], Loss: 0.1415 in 1 of 10
Epoch [80/5000], Loss: 0.1397 in 1 of 10
Epoch [90/5000], Loss: 0.1380 in 1 of 10
Epoch [100/5000], Loss: 0.1365 in 1 of 10
Epoch [110/5000], Loss: 0.1352 in 1 of 10
Epoch [120/5000], Loss: 0.1340 in 1 of 10
Epoch [130/5000], Loss: 0.1330 in 1 of 10
Epoch [140/5000], Loss: 0.1321 in 1 of 10
Epoch [150/5000], Loss: 0.1313 in 1 of 10
Epoch [160/5000], Loss: 0.1306 in 1 of 10
Epoch [170/5000], Loss: 0.1301 in 1 of 10
Epoch [180/5000], Loss: 0.1296 in 1 of 10
Epoch [190/5000], Loss: 0.1291 in 1 of 10
Epoch [200/5000], Loss: 0.1288 in 1 of 10
Epoch [210/5000], Loss: 0.1285 in 1 of 10
Epoch [220/5000], Loss: 0.1282 in 1 of 10
Epoch [230/5000], Loss: 0.1280 in 1 of 10
Epoch [240/5000], Loss: 0.1278 in 1 of 10
E

In [17]:
y.shape[2]

2