In [170]:
# v1 was just vanila.

# v2: 24.09.01. 
#   1. Doing the CV (10 folds)
#   2. Early stopping with patience parameter.
#   3. Takings n steps into account. 

# v3: 24.10.15 - 29
#   1. No CV. No Early Stopping as well. 
#       Adding early stopping may be an option, but it it not super necessary.
#   2. Adding tau param. h(t) = f((1-tau) * h(t-1) + tau * x(t)) => This is what basically CT-RNN is about.

# import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
import scipy.io as sio
import matplotlib.pyplot as plt

# seed random number generator for reproducibility
torch.manual_seed(0)
np.random.seed(0)

## ---------------------------------- ##
## ------------ FIX down ------------ ##
## ---------------------------------- ##

# input
input_size  = 1   # Input feature size
hidden_size = 16  # Number of features in the hidden state
output_size = 2   # Output feature size (stim rating / click profiles)
output_name = ['stim_rating', 'clicks'] # 'stim_rating' 'clicks'
taus        = np.linspace(0, 1, 10)
noise_sig   = 0.5

# params
learning_rate = 1e-2 # 
n_epochs      = 100
batch_size    = 9   # or 9 for condition. so far at the trial-level
device        = 'cpu' # mps
targfolder    = 'condavg_label-RatingClicks_loss-MSEBCE'
if not (os.path.isdir(targfolder)):
    os.mkdir(targfolder)

## ---------------------------------- ##
## ------------- FIX up ------------- ##
## ---------------------------------- ##

# Q. why no Cross-Validation?
# A. no differences across inputs.

# Input: n_batch, n_seqlen, n_dim
class PainRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, tau):
        super().__init__()
        self.input_size  = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.alpha       = tau # [0, 1] balances the effect of the input and hidden.
                               # having it as one is just same with vanila RNN
        self.i2h         = nn.Linear(input_size, hidden_size)
        self.h2h         = nn.Linear(hidden_size, hidden_size)
        self.h2o         = nn.Linear(hidden_size, output_size)
        
    def forward(self, input, init_hidden=None):
        hidden_units  = []
        output_units  = []
        seq_len       = input.size(1) # sequence length
        
        for i in range(seq_len):
             if (i == 0) and (init_hidden is None):
                 h_t = torch.zeros(input.size(0), self.hidden_size)
             elif (i == 0) and (init_hidden is not None):
                 h_t = init_hidden
                 
             h_t           = h_t.to(device)    
             x_t           = input[:, i, :]
             hidden_unit   = torch.relu(self.i2h(x_t) + self.h2h(h_t))
             hidden_unit   = h_t * (1 - self.alpha) + hidden_unit * self.alpha
             output_unit   = torch.sigmoid(self.h2o(hidden_unit))
            
             hidden_units.append(hidden_unit)
             output_units.append(output_unit)
        
        hidden_units = torch.stack(hidden_units, dim = 1)
        output_units = torch.stack(output_units, dim = 1)
        return output_units, hidden_units
# Load input and output data
inputoutputs = sio.loadmat('InOutputs_condavg')
Inputs       = inputoutputs['Inputs']
Outputs      = inputoutputs['Outputs']

Xnp = []
ynp = []
for cond_i in range(len(Outputs)):
    ynp.append(Outputs[cond_i][0])

Xnp = np.expand_dims(Inputs, axis = 2)
ynp = np.stack(ynp, axis = 0)
    
X = torch.from_numpy(Xnp).to(torch.float32)
y = torch.from_numpy(ynp).to(torch.float32)

# Combine inputs and outputs into a dataset
dataset    = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model instantiation
for i, tau in enumerate(taus):
    addstr      = 'tau_%.2f' % tau
    pain_rnn    = PainRNN(input_size, hidden_size, output_size, tau).to(device)
    optimizer   = optim.Adam(pain_rnn.parameters(), lr=learning_rate)
    criterion1  = nn.MSELoss()              # For intensity.
    criterion2  = nn.BCELoss()     # For clicks.
    loss_epochs = []
    for epoch in range(n_epochs):
        for inputs, targets in dataloader:
            add_noise = np.random.normal(0, noise_sig, inputs.shape)
            inputs += torch.from_numpy(add_noise).to(torch.float32)
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs, _ = pain_rnn(inputs)  # Capture both outputs and hidden states
            loss1 = criterion1(outputs[:, :, 0], targets[:, :, 0])
            loss2 = criterion2(outputs[:, :, 1], targets[:, :, 1])
            loss  = loss1 + loss2
            # loss = criterion1(outputs, targets)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_epochs.append(loss.item())
        
        if ((epoch+1) % 10) == 0:
            print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f} in {i+1} of {len(taus)}')
            
    out_t, h_t = pain_rnn(torch.from_numpy(np.expand_dims(Inputs, axis = 2)).to(torch.float32))
    out_t      = out_t.detach().numpy()
    h_t        = h_t.detach().numpy()

    fig, axes = plt.subplots(3, 3, figsize = (9, 9))
    axes[0, 0].plot(loss_epochs)
    axes[0, 0].set_xlabel('Epochs')
    axes[0, 0].set_ylabel('Loss')

    axes[0, 1].plot(np.mean(h_t, 0))
    axes[0, 1].set_xlabel('time')
    axes[0, 1].set_ylabel('Hidden Unit activations')
    
    axes[2, 0].plot(Inputs.T)
    axes[2, 0].set_xlabel('time')
    axes[2, 0].set_ylabel('Training Input')
    
    for i in range(output_size):
        axes[1, i].plot(out_t[:, :, i].T)
        axes[1, i].set_xlabel('time')
        axes[1, i].set_ylabel(f'Outputs {output_name[i]}')
        
        axes[2, i+1].imshow(y[:, :, i])
        axes[2, i+1].set_xlabel('time')
        axes[2, i+1].set_ylabel(f'Training {output_name[i]}')
        
    
    sio.savemat(os.path.join(targfolder, ('HiddenLayers_' + addstr + '.mat')), 
                {'out_t': out_t, 'h_t': h_t, 'loss_epochs': loss_epochs, 
                 'i2h':pain_rnn.i2h.weight.detach().numpy(), 
                 'h2h':pain_rnn.h2h.weight.detach().numpy(), 
                 'h2o':pain_rnn.h2o.weight.detach().numpy()})

    plt.savefig(os.path.join(targfolder, ('HiddenLayers_' + addstr + '.png')))
    plt.close()
    

Epoch [10/100], Loss: 0.8419 in 1 of 10
Epoch [20/100], Loss: 0.8076 in 1 of 10
Epoch [30/100], Loss: 0.7778 in 1 of 10
Epoch [40/100], Loss: 0.7527 in 1 of 10
Epoch [50/100], Loss: 0.7319 in 1 of 10
Epoch [60/100], Loss: 0.7150 in 1 of 10
Epoch [70/100], Loss: 0.7016 in 1 of 10
Epoch [80/100], Loss: 0.6910 in 1 of 10
Epoch [90/100], Loss: 0.6828 in 1 of 10
Epoch [100/100], Loss: 0.6765 in 1 of 10
Epoch [10/100], Loss: 0.7082 in 2 of 10
Epoch [20/100], Loss: 0.7142 in 2 of 10
Epoch [30/100], Loss: 0.6881 in 2 of 10
Epoch [40/100], Loss: 0.6629 in 2 of 10
Epoch [50/100], Loss: 0.6488 in 2 of 10
Epoch [60/100], Loss: 0.6319 in 2 of 10
Epoch [70/100], Loss: 0.6198 in 2 of 10
Epoch [80/100], Loss: 0.6083 in 2 of 10
Epoch [90/100], Loss: 0.5990 in 2 of 10
Epoch [100/100], Loss: 0.5864 in 2 of 10
Epoch [10/100], Loss: 0.8018 in 3 of 10
Epoch [20/100], Loss: 0.7262 in 3 of 10
Epoch [30/100], Loss: 0.7129 in 3 of 10
Epoch [40/100], Loss: 0.6894 in 3 of 10
Epoch [50/100], Loss: 0.6643 in 3 of 1