# Task 1 - EWC Notebook: Force Perturbation Dataset

## Overview

This notebook implements Elastic Weight Consolidation (EWC) for continual learning on a force perturbation dataset. The goal is to assess whether a neural decoder can adapt to a force perturbation while retaining what it learned on baseline trials, using EWC to mitigate catastrophic forgetting.

## Key Objectives

1. **Target Variable**: The target variable is velocity (`vel`). It is not normalized, and outliers are not removed.
   
2. **EWC Implementation**: 
   - Implement EWC to observe how well the model retains learned knowledge when trained sequentially on different tasks.
   - Compare the performance of models trained with EWC against those without it, specifically evaluating the model's ability to retain knowledge from the baseline task while adapting to force perturbations.

## Notebook Structure

1. **Imports and Setup**: All necessary libraries and modules are imported, and initial configurations such as device settings and paths are established.

2. **Data Loading and Preprocessing**:
   - Load pre-processed data for both baseline and force perturbation epochs.
   - Add Gaussian noise (σ = 0.075) to the baseline training inputs to simulate realistic recording conditions.
   - Prepare data for training, validation, and testing.

3. **Baseline Model**:
   - Train a basic RNN model on baseline data.
   - Evaluate the model's performance on both baseline and perturbation data.

4. **Full Data Training**:
   - Train the RNN model on the combined dataset (baseline and perturbation).
   - Save and evaluate the model's performance.

5. **EWC Implementation**:
   - Define and compute the Fisher Information Matrix.
   - Apply EWC during training to retain knowledge from the baseline task while training on perturbation data.
   - Compare the results with models trained without EWC.

6. **Catastrophic Forgetting Evaluation**:
   - Evaluate the impact of sequential training on catastrophic forgetting by training models on one task and then testing on another.
   - Analyze the model's performance in retaining baseline task knowledge after training on perturbation data and vice versa.

7. **Visualization and Results**:
   - Visualize training and validation losses across epochs.
   - Present the model's performance metrics (Explained Variance, R²) for both baseline and perturbation tasks.

## Purpose

The purpose of this notebook is to validate the effectiveness of EWC in mitigating catastrophic forgetting in a force adaptation task. By comparing models trained with and without EWC, we aim to demonstrate the benefits of continual learning approaches in neural decoding tasks, particularly in scenarios where the model must adapt to changing conditions while retaining previously learned information.


### Imports

In [3]:
import pandas as pd
import numpy as np
import xarray as xr

import os
import sys
from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import seaborn as sns


# Add the project root (one level up, CL Control) to sys.path so that `src` and `Models` can be imported
parent_dir = os.path.abspath('..')
sys.path.append(parent_dir)

from src.helpers import *
from src.visualize import *
from src.trainer import *
from src.regularizers import *
from Models.models import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from sklearn.metrics import *
from copy import deepcopy
import torch.utils.data as data
from torch.utils.data import Dataset

import pickle
import math

In [4]:
import sys
# Make PyalData importable: change this to the path where PyalData has been cloned
sys.path.append("c:\\Users\\nerea\\OneDrive\\Documentos\\EPFL_MASTER\\PDM\\Project\\PyalData")

In [5]:
from pyaldata import *

In [6]:
name = 'Chewie'
date = '1007'
fold = 0
target_variable = 'vel'

In [7]:
#@title Plotting setup: seaborn context and color palette (run this cell!)
sns.set_context("notebook")

# initialize a color palette for plotting
palette = sns.xkcd_palette(["windows blue",
                            "red",
                            "medium green",
                            "dusty purple",
                            "orange",
                            "amber",
                            "clay",
                            "pink",
                            "greyish"])

In [8]:
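# Convert an array or tensor to a tensor on `device` with `dtype` (both are defined further below);
# torch.as_tensor avoids the copy-construct warning when the input is already a tensor
to_t_eval = lambda array: torch.as_tensor(array, device=device, dtype=dtype)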

## Load pre-processed data

In [10]:
data_path = '../Data/Processed_Data/Tidy_'+name+'_'+date+'.pkl'

with open(data_path, 'rb') as file:
    tidy_df = pickle.load(file)

In [11]:
baseline_df = tidy_df.loc[tidy_df['epoch'] == 'BL']

In [12]:
baseline_df.id.nunique()

170

In [13]:
force_df =  tidy_df.loc[tidy_df['epoch'] == 'AD']

We need to consider only the trials for which the monkey has already adapted to the perturbation.

In [14]:
ids_to_keep = force_df.id.unique()[50:]

The baseline subset contains 170 trials, whereas the perturbation subset contains 201. For now, we discard the first 50 perturbation trials, during which the monkey is assumed to still be adapting; this leaves 151 perturbation trials, close to the baseline count.

In [15]:
force_df = force_df.loc[force_df.id.isin(ids_to_keep)]
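The trial filter above relies on `unique()` returning ids in order of first appearance, so it drops the *earliest* trials only if the rows of `force_df` are in chronological order. The sketch below illustrates this on a hypothetical toy frame (`toy_force_df` and `n_adaptation_trials` are illustrative names, not part of the notebook); it assumes, as the notebook does, that trial ids appear in the order the trials were recorded.

```python
import pandas as pd

# Hypothetical stand-in for force_df: 7 trials, 3 time-bin rows each, in chronological order
toy_force_df = pd.DataFrame({
    "id": [t for t in range(1, 8) for _ in range(3)],
    "epoch": "AD",
})

n_adaptation_trials = 2  # the notebook discards 50

# Series.unique() keeps the order of first appearance, so slicing drops the earliest trials
ids_to_keep = toy_force_df["id"].unique()[n_adaptation_trials:]
kept_df = toy_force_df.loc[toy_force_df["id"].isin(ids_to_keep)]

print(kept_df["id"].nunique())  # 5
print(kept_df["id"].min())      # 3
```

If the row order is not guaranteed, sorting by a trial index before taking `unique()` makes the selection explicit.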

## Get train-val-test split

In [16]:
xx_train_base, yy_train_base, xx_val_base, yy_val_base,\
      xx_test_base, yy_test_base, info_train_base, info_val_base,\
          info_test_base, list_mins_base, \
            list_maxs_base= get_dataset(baseline_df, fold, target_variable= target_variable, no_outliers = False, force_data = True)

Train trials 109
Test trials  34
Val trials 27
We are testing the optimization method on fold  0


In [17]:
xx_train_force, yy_train_force, xx_val_force, yy_val_force,\
      xx_test_force, yy_test_force, info_train_force, info_val_force,\
          info_test_force,  list_mins_force, \
            list_maxs_force = get_dataset(force_df, fold, target_variable= target_variable, no_outliers = False, force_data = True)

Train trials 97
Test trials  30
Val trials 24
We are testing the optimization method on fold  0


In [18]:
xx_train_all, yy_train_all, xx_val_all, yy_val_all, \
    xx_test_all, yy_test_all, info_train_all, \
    info_val_all, info_test_all,  list_mins_all,\
          list_maxs_all = get_dataset(tidy_df, fold, target_variable= target_variable, no_outliers = False, force_data = True)

Train trials 211
Test trials  66
Val trials 53
We are testing the optimization method on fold  0


In [19]:
# Put tensors on the GPU if one is available, in float32
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
path_to_models = '../Models/Models_Force'

# Set the seeds for reproducibility (NumPy is used for the input noise below)
seed_value = 42
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)  # If using CUDA
np.random.seed(seed_value)

num_dim_output = yy_train_base.shape[2]
num_features = xx_train_base.shape[2]

In [20]:
def weight_reset(m):
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()
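In this notebook, `model_base = model`, `model_all = model` and `model_force = model` are aliases of one object, not copies, so `apply(weight_reset)` on any of them re-initializes all of them; the trained weights survive only because each model is saved to disk before the next reset. The sketch below shows the difference between an alias and `copy.deepcopy` on a small hypothetical `nn.Sequential` stand-in (not `Causal_Simple_RNN`).

```python
import copy

import torch
import torch.nn as nn

def weight_reset(m):
    # Re-initialize any module that defines reset_parameters (Linear, RNN, ...)
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.SELU(), nn.Linear(3, 2))

alias = model                 # same object as `model`
clone = copy.deepcopy(model)  # independent copy of the current weights
before = model[0].weight.detach().clone()

alias.apply(weight_reset)     # also resets `model`, because it is the same object

print(alias is model)                        # True
print(torch.equal(model[0].weight, before))  # False: new random weights
print(torch.equal(clone[0].weight, before))  # True: the copy is untouched
```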

## Baseline Model

#### Adding noise to the training data

In [21]:
# Add Gaussian noise to input features
xx_train_base = torch.tensor(xx_train_base, dtype=torch.float32) + torch.tensor(np.random.normal(loc=0, scale=0.075, size=xx_train_base.shape), dtype=torch.float32)
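Because the cell above overwrites `xx_train_base`, running it twice adds the noise twice, and without a fixed NumPy seed each run draws different noise. A minimal sketch of an alternative, assuming the clean array is kept separately: `add_input_noise` is a hypothetical helper that always starts from the clean inputs and takes a seeded `numpy.random.Generator`. The noise level 0.075 is the notebook's; the array shape is made up.

```python
import numpy as np
import torch

def add_input_noise(x_clean, scale=0.075, rng=None):
    """Return a float32 tensor equal to x_clean plus i.i.d. Gaussian noise N(0, scale^2)."""
    rng = np.random.default_rng() if rng is None else rng
    x = torch.as_tensor(np.asarray(x_clean), dtype=torch.float32)
    noise = rng.normal(loc=0.0, scale=scale, size=x.shape)
    return x + torch.as_tensor(noise, dtype=torch.float32)

# Hypothetical clean inputs with shape (samples, time bins, features)
xx_clean = np.zeros((10, 19, 130))

# Noise is always added to the clean array, so re-running does not stack noise,
# and a seeded Generator makes the augmentation reproducible.
xx_noisy_a = add_input_noise(xx_clean, rng=np.random.default_rng(42))
xx_noisy_b = add_input_noise(xx_clean, rng=np.random.default_rng(42))

print(xx_noisy_a.shape)                     # torch.Size([10, 19, 130])
print(torch.equal(xx_noisy_a, xx_noisy_b))  # True
print(float(xx_noisy_a.std()))              # close to 0.075
```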

In [22]:
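# Define hyperparameters

# Regularization hyperparameters (used by Regularizer_RNN)
alpha_reg = 1e-5
l1_ratio_reg = 0.5

lr = 1e-5
loss_function = huber_loss
delta = 8  # threshold of the Huber loss

# Model hyperparameters (Causal_Simple_RNN)
hidden_units = 300
num_layers = 1
input_size = 49  # not used below: the RNN input size is num_features
dropout = 0.2

# Other training hyperparameters

# With a StepLR scheduler, the learning rate is multiplied by lr_gamma every
# lr_step_size epochs; since lr_gamma > 1, the learning rate grows during training
lr_gamma = 1.37
lr_step_size = 10

seq_length_LSTM = 19  # sequence length fed to the RNN
batch_size_train = 25
batch_size_val = 25

torch.manual_seed(42)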

<torch._C.Generator at 0x7fa7d41a39d0>
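`train_model_EWC` below uses `torch.optim.lr_scheduler.StepLR`, which multiplies the learning rate by `gamma` every `step_size` epochs. The sketch below, which assumes `train_model` from `src.trainer` schedules the learning rate the same way (its code is not shown here), traces the schedule produced by `lr = 1e-5`, `lr_step_size = 10` and `lr_gamma = 1.37` on a dummy parameter: the learning rate increases by 37% every 10 epochs instead of decaying.

```python
import torch
from torch import optim
from torch.optim import lr_scheduler

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter; only the schedule matters
optimizer = optim.Adam([param], lr=1e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=1.37)

lrs = []
for epoch in range(30):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # no gradients: Adam skips the parameter
    scheduler.step()   # one scheduler step per epoch, as in train_model_EWC

print(lrs[0], lrs[10], lrs[20])  # approximately 1e-05, 1.37e-05, 1.88e-05
```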

In [23]:
model =  Causal_Simple_RNN(num_features=num_features, 
                hidden_units= hidden_units, 
                num_layers = num_layers, 
                out_dims = num_dim_output,
                dropout = dropout).to(device)

In [24]:
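# Baseline model: `model_base` is an alias of `model` (not a copy), so its trained
# weights are saved to disk below before `model` is reset for the next experiment
model_base = model
model_base.apply(weight_reset)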

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [22]:
train_losses, val_losses = \
    train_model(model_base, 
                xx_train_base,yy_train_base,
                xx_val_base, 
                yy_val_base,
                lr= lr,
                lr_step_size=lr_step_size,
                lr_gamma= lr_gamma,
                sequence_length_LSTM=seq_length_LSTM,
                batch_size_train = batch_size_train,
                batch_size_val = batch_size_val,
                num_epochs=1000, 
                delta = 8,                 
                regularizer= Regularizer_RNN, 
                l1_ratio = l1_ratio_reg,
                alpha = alpha_reg,     
                early_stop = 5)



Epoch 000 Train 12.0800 Val 11.4261
Epoch 001 Train 10.0309 Val 9.8023
Epoch 002 Train 8.8359 Val 9.1065
Epoch 003 Train 8.1913 Val 8.6166
Epoch 004 Train 7.6851 Val 8.1048
Epoch 005 Train 6.9805 Val 7.1412
Epoch 006 Train 6.2320 Val 6.5102
Epoch 007 Train 5.7184 Val 6.0479
Epoch 008 Train 5.3147 Val 5.6570
Epoch 009 Train 4.9600 Val 5.3139
Epoch 010 Train 4.6026 Val 4.9216
Epoch 011 Train 4.2328 Val 4.5756
Epoch 012 Train 3.9203 Val 4.2808
Epoch 013 Train 3.6519 Val 4.0531
Epoch 014 Train 3.4266 Val 3.8309
Epoch 015 Train 3.2438 Val 3.6965
Epoch 016 Train 3.0871 Val 3.5425
Epoch 017 Train 2.9633 Val 3.5127
Epoch 018 Train 2.8599 Val 3.3889
Epoch 019 Train 2.7697 Val 3.3080
Epoch 020 Train 2.7007 Val 3.2303
Epoch 021 Train 2.6314 Val 3.2159
Epoch 022 Train 2.5648 Val 3.1310
Epoch 023 Train 2.5249 Val 3.1101
Epoch 024 Train 2.4908 Val 3.0777
Epoch 025 Train 2.4637 Val 3.0411
Epoch 026 Train 2.4312 Val 3.0475
Epoch 027 Train 2.4136 Val 2.9796
Epoch 028 Train 2.3882 Val 3.0089
Epoch 029 T

In [23]:
experiment_name = 'RNN'+ name+ '_' +date+'_Baseline'
path_base_model = os.path.join(path_to_models,experiment_name)
os.makedirs(path_base_model, exist_ok=True)
path_base_fold = os.path.join(path_base_model,'fold_{}.pth'.format(fold))
torch.save(model_base, path_base_fold)

In [24]:
model_baselineonly = torch.load(path_base_fold)
model_baselineonly.eval()

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [25]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_base, yy_train_base, xx_val_base, yy_val_base, xx_test_base, yy_test_base, model_baselineonly, metric = 'ev')

Train EV: 0.94 
Val EV: 0.90 
Test EV: 0.87 




In [26]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_base, yy_train_base, xx_val_base, yy_val_base, xx_test_base, yy_test_base, model_baselineonly, metric = 'r2')

Train R2: 0.94 
Val R2: 0.90 
Test R2: 0.87 
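Explained variance and R² coincide above. The two differ only in how a constant offset is treated: explained variance ignores the mean error, whereas R² penalizes it, so identical scores suggest the decoder's errors have a mean close to zero. The sketch below assumes `eval_model` computes these metrics with scikit-learn's `explained_variance_score` and `r2_score` (the notebook imports `sklearn.metrics`, but the internals of `eval_model` are not shown); the toy predictions have a constant offset of +1.

```python
import numpy as np
from sklearn.metrics import explained_variance_score, r2_score

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = y_true + 1.0  # correct shape, constant offset of +1

ev = explained_variance_score(y_true, y_pred)  # 1 - Var(y_true - y_pred) / Var(y_true)
r2 = r2_score(y_true, y_pred)                  # 1 - SS_res / SS_tot

print(ev)  # 1.0: the offset does not change the variance of the error
print(r2)  # about 0.2: 1 - 4 / 5
```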




## Testing the baseline model on force data

In [27]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_force, yy_train_force, xx_val_force, yy_val_force, xx_test_force, yy_test_force, model_baselineonly, metric = 'ev')

Train EV: 0.49 
Val EV: 0.49 
Test EV: 0.53 


## Now we use all data for training

In [28]:
model_all = model
model_all.apply(weight_reset)

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [29]:
train_losses, val_losses = \
    train_model(model_all, xx_train_all,yy_train_all,
                xx_val_all, 
                yy_val_all,
                lr= lr,
                lr_step_size=lr_step_size,
                lr_gamma= lr_gamma,
                sequence_length_LSTM=seq_length_LSTM,
                batch_size_train = batch_size_train,
                batch_size_val = batch_size_val,
                num_epochs=1000, 
                delta = 8,                 
                regularizer=None, #Regularizer_LSTM,
                l1_ratio = l1_ratio_reg,
                alpha = alpha_reg,
                early_stop = 5)


Epoch 000 Train 9.8331 Val 7.7173
Epoch 001 Train 6.7892 Val 5.2663
Epoch 002 Train 4.9698 Val 4.2868
Epoch 003 Train 4.1628 Val 3.7672
Epoch 004 Train 3.6566 Val 3.4615
Epoch 005 Train 3.3029 Val 3.1952
Epoch 006 Train 3.0489 Val 3.0291
Epoch 007 Train 2.8524 Val 2.9213
Epoch 008 Train 2.6980 Val 2.8095
Epoch 009 Train 2.5699 Val 2.7320
Epoch 010 Train 2.4502 Val 2.5915
Epoch 011 Train 2.3226 Val 2.5498
Epoch 012 Train 2.2147 Val 2.4524
Epoch 013 Train 2.1173 Val 2.3842
Epoch 014 Train 2.0324 Val 2.3439
Epoch 015 Train 1.9547 Val 2.3226
Epoch 016 Train 1.8863 Val 2.2668
Epoch 017 Train 1.8216 Val 2.2186
Epoch 018 Train 1.7641 Val 2.2092
Epoch 019 Train 1.7130 Val 2.1909
Epoch 020 Train 1.6580 Val 2.1957
Epoch 021 Train 1.5973 Val 2.1836
Epoch 022 Train 1.5454 Val 2.1528
Epoch 023 Train 1.4947 Val 2.1575
Epoch 024 Train 1.4505 Val 2.1684
Epoch 025 Train 1.4082 Val 2.1647
Epoch 026 Train 1.3700 Val 2.1296
Epoch 027 Train 1.3327 Val 2.1513
Epoch 028 Train 1.3017 Val 2.1607
Epoch 029 Trai

In [30]:
experiment_name = 'RNN'+name+'_'+date+'_Alldata'
path_to_save_model = os.path.join(path_to_models,experiment_name)
os.makedirs(path_to_save_model, exist_ok=True)
path_to_save_model_fold = os.path.join(path_to_save_model,'fold_{}.pth'.format(fold))
torch.save(model_all, path_to_save_model_fold)

In [31]:
model_all_data = torch.load(path_to_save_model_fold)
model_all_data.eval() 

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [32]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_all, yy_train_all, xx_val_all, yy_val_all, xx_test_all, yy_test_all, model_all_data, metric = 'ev')

Train EV: 0.89 
Val EV: 0.86 
Test EV: 0.86 


## Implementing EWC

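Defining the EWC loss. For each parameter $\theta_i$, we multiply the squared difference between its current value and its optimal value for the previous task, $\theta^*_{A,i}$, by the importance of that parameter, $F_i$, taken from the diagonal of the Fisher information matrix, and sum over all parameters. While training on the new task B, this penalty is added to the task loss with strength $\lambda$:

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \lambda \sum_i F_i \left(\theta_i - \theta^*_{A,i}\right)^2$$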

In [33]:
def get_ewc_loss(model, fisher, p_old):
    loss = 0
    for n, p in model.named_parameters():
        _loss = fisher[n] * (p - p_old[n]) ** 2
        loss += _loss.sum()
    return loss
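A quick way to check the penalty is to evaluate it on a tiny model whose numbers can be computed by hand. The sketch below uses a hypothetical `nn.Linear(3, 1)` whose previous-task optimum is set to zero and hand-picked Fisher values (not ones computed from data); `get_ewc_loss` repeats the computation above, written without in-place addition.

```python
import torch
import torch.nn as nn

def get_ewc_loss(model, fisher, p_old):
    # Same computation as above: sum_i F_i * (theta_i - theta*_i)^2
    loss = 0.0
    for n, p in model.named_parameters():
        loss = loss + (fisher[n] * (p - p_old[n]) ** 2).sum()
    return loss

model = nn.Linear(3, 1)
with torch.no_grad():
    model.weight.zero_()  # pretend the previous-task optimum is all zeros
    model.bias.zero_()

# Detached snapshot of the previous-task parameters (theta*_A)
p_old = {n: p.detach().clone() for n, p in model.named_parameters()}
# Hypothetical Fisher diagonal: each weight twice as important as the bias
fisher = {"weight": torch.full((1, 3), 2.0), "bias": torch.ones(1)}

penalty_at_anchor = float(get_ewc_loss(model, fisher, p_old))
print(penalty_at_anchor)  # 0.0

with torch.no_grad():
    model.weight.fill_(0.5)
    model.bias.fill_(-1.0)

penalty = get_ewc_loss(model, fisher, p_old)
print(float(penalty))  # 3 * 2.0 * 0.5**2 + 1.0 * 1.0**2 = 2.5
penalty.backward()     # gradient: 2 * F_i * (theta_i - theta*_i)
print(model.weight.grad.tolist(), model.bias.grad.tolist())  # [[2.0, 2.0, 2.0]] [-2.0]
```

The gradient of the penalty pulls each parameter back towards its previous-task value, more strongly for parameters with a larger $F_i$.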

In [34]:
x = to_t_eval(xx_train_base) 
y = to_t_eval(yy_train_base)



In [35]:
# x and y are tensors; pair them along the first dimension, giving one (input, target) sample per entry
dataset = list(zip(x, y))

In [37]:
import copy

In [38]:
model_pre_EWC = copy.deepcopy(model_baselineonly) 
# Flatten the parameters of the copied model
for module in model_pre_EWC.modules():
    if isinstance(module, nn.RNNBase):
        module.flatten_parameters()
model_pre_EWC.train()

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

Save the optimal parameters for the baseline task.

In [39]:
params = {n: p for n, p in model_pre_EWC.named_parameters() if p.requires_grad}

# Detached copies of the baseline-optimal parameters (theta*_A)
p_old = {n: p.detach().clone() for n, p in params.items()}

#### Defining the Fisher matrix. 

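Only the diagonal of the Fisher information matrix is estimated. Each diagonal element $F_i$ is approximated by the squared gradient of the loss with respect to parameter $\theta_i$, averaged over the $N$ samples of a single pass through the training dataset:

$$F_i \approx \frac{1}{N} \sum_{n=1}^{N} \left( \frac{\partial \mathcal{L}(x_n, y_n)}{\partial \theta_i} \right)^2$$

For each input and label in the dataset, the gradients are reset to zero, the model does a forward pass, and the loss is computed and back-propagated; the square of each parameter's gradient, divided by the dataset size, is then added to the corresponding diagonal entry. Because the decoder is a regression model, the Huber loss is used here in place of the negative log-likelihood.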

In [1]:
def get_fisher_diag(model, dataset, params, empirical=True):
    # Empirical diagonal Fisher: mean of the squared per-sample gradients of the loss
    if not empirical:
        # An argmax-based "true Fisher" label only makes sense for classification
        raise NotImplementedError("Only the empirical Fisher is supported for this regression model")

    # One zero-initialized accumulator per parameter
    fisher = {n: torch.zeros_like(p) for n, p in params.items()}

    # cuDNN RNNs only support backward in training mode
    model.train()
    for x_in, gt_label in dataset:
        model.zero_grad()
        output = model(x_in).view(-1)
        label = gt_label.view(-1)

        h_loss = huber_loss(output, label)
        h_loss.backward()

        # Accumulate squared gradients, normalized by the dataset size
        for n, p in model.named_parameters():
            if n in fisher and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2 / len(dataset)

    return fisher
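The same estimator can be checked on a case small enough to compute by hand. In this sketch, `fisher_diagonal` is a hypothetical standalone version of the function above; it uses PyTorch's `F.huber_loss` with `delta=8` as a stand-in for the notebook's `huber_loss` (whose implementation in `src.helpers` is not shown). With a zero-initialized `nn.Linear(2, 1)` and targets of 1, the error is -1 for every sample, so the Huber gradient is $-x$ for the weights and $-1$ for the bias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fisher_diagonal(model, dataset, delta=8.0):
    """Empirical diagonal Fisher: mean over samples of squared loss gradients."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    model.train()
    for x, y in dataset:
        model.zero_grad()
        output = model(x).view(-1)
        loss = F.huber_loss(output, y.view(-1), delta=delta)
        loss.backward()
        for n, p in model.named_parameters():
            if n in fisher and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2 / len(dataset)
    return fisher

model = nn.Linear(2, 1)
with torch.no_grad():
    model.weight.zero_()
    model.bias.zero_()

# Two samples: per-sample weight gradients are [-1, 0] and [0, -2]
dataset = [
    (torch.tensor([1.0, 0.0]), torch.tensor([1.0])),
    (torch.tensor([0.0, 2.0]), torch.tensor([1.0])),
]
fisher = fisher_diagonal(model, dataset)
print(fisher["weight"].tolist())  # [[0.5, 2.0]]
print(fisher["bias"].tolist())    # [1.0]
```

Note that the empirical Fisher vanishes wherever the per-sample errors are zero, so it is only informative when the model does not fit the previous task perfectly.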

In [41]:
fisher_matrix = get_fisher_diag(model_pre_EWC, dataset, params)

In [42]:
def train_model_EWC(model, X,Y,
                X_val, 
                Y_val,
                lr=lr,
                lr_step_size=lr_step_size,
                lr_gamma=lr_gamma,
                sequence_length_LSTM=seq_length_LSTM,
                batch_size_train=batch_size_train,
                batch_size_val=batch_size_val,
                num_epochs=1000,
                delta=8,
                regularizer=None,  # any non-None value enables the EWC penalty
                l1_ratio=l1_ratio_reg,  # unused
                alpha=alpha_reg,  # unused
                early_stop=5,
                lambda_ewc=0.2):

    # Set up the optimizer with the specified learning rate
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Set up a learning rate scheduler
    scheduler = lr_scheduler.StepLR(optimizer, 
                                    step_size=lr_step_size, 
                                    gamma=lr_gamma)
    
    
        
    # Keep track of the best model's parameters and loss
    best_model_wts = deepcopy(model.state_dict())
    best_loss = 1e8

    # Enable anomaly detection for debugging
    torch.autograd.set_detect_anomaly(True)

    # Track the train and validation loss
    train_losses = []
    val_losses = []
    # Counters for early stopping
    not_increased = 0
    end_train = 0
    
    # Build fixed-length sequence datasets for the RNN
    train_dataset = SequenceDataset(Y, X, sequence_length=sequence_length_LSTM)
    val_dataset = SequenceDataset(Y_val, X_val, sequence_length=sequence_length_LSTM)
    loader_train = data.DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)
    loader_val = data.DataLoader(val_dataset, batch_size=batch_size_val, shuffle=True)

    # Loop through epochs
    for epoch in np.arange(num_epochs):
        for phase in ['train', 'val']:
            # set model to train/validation as appropriate
            if phase == 'train':
                model.train()
                loader = loader_train
            else:
                model.eval()
                loader = loader_val

            # Initialize variables to track loss and batch size
            running_loss = 0
            running_size = 0        

            # Iterate over batches in the loader
            for X_, y_ in loader:
                X_ = X_.to(device)
                y_ = y_.to(device)
                if phase == "train":
                    with torch.set_grad_enabled(True):
                        optimizer.zero_grad()

                        output_t = model(X_)
                        output_t = torch.squeeze(output_t)


                        loss_t = huber_loss(output_t, y_, delta = delta)

                        # Add the EWC penalty in the training phase, using the globals
                        # fisher_matrix and p_old computed above. The logged training loss
                        # therefore includes the penalty, while the validation loss does not.
                        if regularizer is not None:
                            ewc_loss = get_ewc_loss(model, fisher_matrix, p_old)
                            loss_t = loss_t + lambda_ewc * ewc_loss
                        # Compute gradients and perform an optimization step
                        loss_t.backward()
                        optimizer.step()
                else:
                    # Only compute the loss in the validation phase (no gradients needed)
                    with torch.no_grad():
                        output_t = model(X_)
                        output_t = torch.squeeze(output_t)
                        loss_t = huber_loss(output_t, y_, delta = delta)

                # Ensure the loss is finite
                assert torch.isfinite(loss_t)
                running_loss += loss_t.item()
                running_size += 1

            # compute the train/validation loss and update the best
            # model parameters if this is the lowest validation loss yet
            running_loss /= running_size
            if phase == "train":
                train_losses.append(running_loss)
            else:
                val_losses.append(running_loss)
                # Update best model parameters if validation loss improves
                if running_loss < best_loss:
                    best_loss = running_loss
                    best_epoch = epoch
                    best_model_wts = deepcopy(model.state_dict())
                    not_increased = 0
                else:
                    # Perform early stopping if validation loss doesn't improve
                    if epoch > 10:
                        not_increased += 1
                        # print('Not increased : {}/5'.format(not_increased))
                        if not_increased == early_stop:
                            print('Decrease LR')
                            for g in optimizer.param_groups:
                                g['lr'] = g['lr'] / 2
                            not_increased = 0
                            end_train += 1
                        
                        if end_train == 2:
                            model.load_state_dict(best_model_wts)
                            print(best_epoch)
                            return np.array(train_losses), np.array(val_losses), best_epoch

        # Update learning rate with the scheduler
        scheduler.step()
        print("Epoch {:03} Train {:.4f} Val {:.4f}".format(epoch, train_losses[-1], val_losses[-1]))

    # load best model weights
    model.load_state_dict(best_model_wts)
    print(best_epoch)

    return np.array(train_losses), np.array(val_losses), best_epoch
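To see how $\lambda$ trades off the new task against the old one, the sketch below runs the same penalized objective on a hypothetical one-parameter regression: the "previous task" optimum is $w = 2$ (task A: $y = 2x$), the new task is $y = -2x$, and the Fisher value is set to 1 by hand. With the inputs $x \in \{1, 2, 3\}$ the objective is $\tfrac{14}{3}(w+2)^2 + \lambda (w-2)^2$, whose minimizer is $w = (2\lambda - 28/3)/(\lambda + 14/3)$: $-2$ for $\lambda = 0$ and $8/11 \approx 0.727$ for $\lambda = 10$. Plain SGD and MSE are used here instead of the notebook's Adam and Huber loss, to keep the fixed point easy to verify.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_task_b(lambda_ewc, steps=200, lr=0.01):
    model = nn.Linear(1, 1, bias=False)
    with torch.no_grad():
        model.weight.fill_(2.0)  # pretend task-A optimum: y = 2x
    p_old = {n: p.detach().clone() for n, p in model.named_parameters()}
    fisher = {"weight": torch.ones(1, 1)}  # hypothetical Fisher value
    x = torch.tensor([[1.0], [2.0], [3.0]])
    y = -2.0 * x  # task B: y = -2x
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        task_loss = F.mse_loss(model(x), y)
        penalty = sum((fisher[n] * (p - p_old[n]) ** 2).sum() for n, p in model.named_parameters())
        (task_loss + lambda_ewc * penalty).backward()
        optimizer.step()
    return model.weight.item()

w_plain = train_task_b(lambda_ewc=0.0)
w_ewc = train_task_b(lambda_ewc=10.0)
print(w_plain)  # close to -2.0: the task-A solution is overwritten
print(w_ewc)    # close to 8/11: pulled back towards the task-A value 2.0
```

A larger $\lambda$ (or a larger $F_i$) moves the compromise closer to the previous solution, at the cost of a higher loss on the new task.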

**Note:** the grid search over λ (the strength of the EWC regularization) is not included here. The best value found was λ = 1900.

In [46]:
train_losses, val_losses, best_epoch = \
    train_model_EWC(model_pre_EWC, xx_train_force,yy_train_force,
                xx_val_force, 
                yy_val_force,
                lr= 1e-5,
                lr_step_size=lr_step_size,
                lr_gamma= lr_gamma,
                sequence_length_LSTM=seq_length_LSTM,
                batch_size_train = batch_size_train,
                batch_size_val = batch_size_val,
                num_epochs=1000, 
                delta = 8,                 
                regularizer= True,
                l1_ratio = 0.5,
                alpha = 1e-5,     
                early_stop = 5,
                lambda_ewc = 1900
                )

Epoch 000 Train 4.7855 Val 4.3033
Epoch 001 Train 4.2282 Val 4.0862
Epoch 002 Train 4.0822 Val 3.9902
Epoch 003 Train 4.0116 Val 3.9420
Epoch 004 Train 3.9586 Val 3.9015
Epoch 005 Train 3.9242 Val 3.8929
Epoch 006 Train 3.9145 Val 3.8723
Epoch 007 Train 3.8887 Val 3.7913
Epoch 008 Train 3.8580 Val 3.8048
Epoch 009 Train 3.8530 Val 3.8045
Epoch 010 Train 3.8525 Val 3.8050
Epoch 011 Train 3.8487 Val 3.7751
Epoch 012 Train 3.8276 Val 3.7280
Epoch 013 Train 3.8093 Val 3.7748
Epoch 014 Train 3.8203 Val 3.7291
Epoch 015 Train 3.8050 Val 3.7287
Epoch 016 Train 3.8049 Val 3.7542
Epoch 017 Train 3.7935 Val 3.7012
Epoch 018 Train 3.7919 Val 3.6921
Epoch 019 Train 3.7831 Val 3.6760
Epoch 020 Train 3.8054 Val 3.6759
Epoch 021 Train 3.7998 Val 3.6991
Epoch 022 Train 3.8320 Val 3.7129
Epoch 023 Train 3.7888 Val 3.7168
Epoch 024 Train 3.8162 Val 3.7269
Decrease LR
Epoch 025 Train 3.7919 Val 3.6831
Epoch 026 Train 3.7535 Val 3.6957
Epoch 027 Train 3.7554 Val 3.6736
Epoch 028 Train 3.7482 Val 3.6548
Ep

In [47]:
experiment_name = 'RNN_'+name+'_'+date+'_EWC'
path_to_save_model = os.path.join(path_to_models,experiment_name)
os.makedirs(path_to_save_model, exist_ok=True)
path_to_save_model_fold = os.path.join(path_to_save_model,'fold_{}.pth'.format(fold))
torch.save(model_pre_EWC, path_to_save_model_fold) 

In [48]:
model_EWC = torch.load(path_to_save_model_fold)
model_EWC.eval() 

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [50]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_force, yy_train_force, xx_val_force, yy_val_force, xx_test_force, yy_test_force, model_EWC, metric = 'ev')

Train EV: 0.78 
Val EV: 0.72 
Test EV: 0.77 


### Testing the performance of the model on Baseline data after EWC

In [51]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_base, yy_train_base, xx_val_base, yy_val_base, xx_test_base, yy_test_base, model_EWC, metric = 'ev')

Train EV: 0.81 
Val EV: 0.74 
Test EV: 0.76 




## Training a model only on force data and testing it on baseline data

In [52]:
model_force = model
model_force.apply(weight_reset)

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [53]:
train_losses, val_losses = \
    train_model(model_force, xx_train_force,yy_train_force,
                xx_val_force, 
                yy_val_force,
                lr= lr,
                lr_step_size=lr_step_size,
                lr_gamma= lr_gamma,
                sequence_length_LSTM=seq_length_LSTM,
                batch_size_train = batch_size_train,
                batch_size_val = batch_size_val,
                num_epochs=1000, 
                delta = 8,                 
                regularizer= None, #Regularizer_LSTM,
                l1_ratio = l1_ratio_reg,
                alpha = alpha_reg,
                early_stop = 5)

Epoch 000 Train 10.8152 Val 9.5241
Epoch 001 Train 7.9973 Val 7.7428
Epoch 002 Train 6.7816 Val 6.9417
Epoch 003 Train 6.0715 Val 6.3801
Epoch 004 Train 5.4451 Val 5.6979
Epoch 005 Train 4.6607 Val 4.8218
Epoch 006 Train 4.0079 Val 4.3118
Epoch 007 Train 3.6178 Val 4.0043
Epoch 008 Train 3.3361 Val 3.7720
Epoch 009 Train 3.1123 Val 3.5694
Epoch 010 Train 2.9006 Val 3.3917
Epoch 011 Train 2.6978 Val 3.2287
Epoch 012 Train 2.5276 Val 3.0855
Epoch 013 Train 2.3827 Val 2.9819
Epoch 014 Train 2.2610 Val 2.9036
Epoch 015 Train 2.1457 Val 2.8290
Epoch 016 Train 2.0404 Val 2.7663
Epoch 017 Train 1.9433 Val 2.7044
Epoch 018 Train 1.8553 Val 2.6716
Epoch 019 Train 1.7664 Val 2.6054
Epoch 020 Train 1.6805 Val 2.5461
Epoch 021 Train 1.5875 Val 2.5195
Epoch 022 Train 1.5064 Val 2.4624
Epoch 023 Train 1.4382 Val 2.4002
Epoch 024 Train 1.3795 Val 2.3787
Epoch 025 Train 1.3268 Val 2.3504
Epoch 026 Train 1.2764 Val 2.3139
Epoch 027 Train 1.2309 Val 2.3050
Epoch 028 Train 1.1889 Val 2.3064
Epoch 029 Tra

In [54]:
experiment_name = 'RNN'+name+'_'+date+'_Force'
path_to_save_model = os.path.join(path_to_models,experiment_name)
os.makedirs(path_to_save_model, exist_ok=True)
path_to_save_model_fold = os.path.join(path_to_save_model,'fold_{}.pth'.format(fold))
torch.save(model_force, path_to_save_model_fold)

In [55]:
model_force = torch.load(path_to_save_model_fold)
model_force.eval()

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [56]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_force, yy_train_force, xx_val_force, yy_val_force, xx_test_force, yy_test_force, model_force, metric = 'ev')

Train EV: 0.94 
Val EV: 0.80 
Test EV: 0.85 


### Testing on baseline data for comparison with the EWC model

In [57]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_base, yy_train_base, xx_val_base, yy_val_base, xx_test_base, yy_test_base, model_force, metric = 'ev')

Train EV: 0.31 
Val EV: 0.30 
Test EV: 0.34 




## Checking Catastrophic Forgetting

Here we take a model trained on one task, fine-tune it on the other task without any regularization, and check how much of its performance on the first task it retains.

#### Fine-tuning the baseline model on force data

In [62]:
model_force_after_base  = copy.deepcopy(model_baselineonly)

# Flatten the parameters of the copied model
for module in model_force_after_base.modules():
    if isinstance(module, nn.RNNBase):
        module.flatten_parameters()
model_force_after_base.train()

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [63]:
train_losses, val_losses = \
    train_model(model_force_after_base, xx_train_force,yy_train_force,
                xx_val_force, 
                yy_val_force,
                lr= lr,
                lr_step_size=lr_step_size,
                lr_gamma= lr_gamma,
                sequence_length_LSTM=seq_length_LSTM,
                batch_size_train = batch_size_train,
                batch_size_val = batch_size_val,
                num_epochs=1000, 
                delta = 8,                 
                regularizer= None, #Regularizer_LSTM,
                l1_ratio = l1_ratio_reg,
                alpha = alpha_reg,
                early_stop = 5)

Epoch 000 Train 4.3486 Val 3.9733
Epoch 001 Train 3.3735 Val 3.5937
Epoch 002 Train 3.0583 Val 3.3521
Epoch 003 Train 2.8624 Val 3.2325
Epoch 004 Train 2.7140 Val 3.1133
Epoch 005 Train 2.6029 Val 3.0435
Epoch 006 Train 2.5040 Val 2.9437
Epoch 007 Train 2.4278 Val 2.9223
Epoch 008 Train 2.3542 Val 2.8661
Epoch 009 Train 2.2968 Val 2.8334
Epoch 010 Train 2.2390 Val 2.8041
Epoch 011 Train 2.1671 Val 2.8061
Epoch 012 Train 2.1069 Val 2.7706
Epoch 013 Train 2.0429 Val 2.7413
Epoch 014 Train 1.9925 Val 2.7095
Epoch 015 Train 1.9326 Val 2.7246
Epoch 016 Train 1.8732 Val 2.6905
Epoch 017 Train 1.8249 Val 2.7031
Epoch 018 Train 1.7688 Val 2.6909
Epoch 019 Train 1.7192 Val 2.7448
Epoch 020 Train 1.7080 Val 2.6905
Decrease LR
Epoch 021 Train 1.6609 Val 2.7125
Epoch 022 Train 1.6213 Val 2.6710
Epoch 023 Train 1.5970 Val 2.6612
Epoch 024 Train 1.5805 Val 2.6922
Epoch 025 Train 1.5576 Val 2.6559
Epoch 026 Train 1.5387 Val 2.6354
Epoch 027 Train 1.5256 Val 2.6242
Epoch 028 Train 1.4992 Val 2.6468
Ep

In [64]:
experiment_name = 'RNN'+name+'_'+date+'_Force_after_Baseline'
path_to_save_model = os.path.join(path_to_models,experiment_name)
os.makedirs(path_to_save_model, exist_ok=True)
path_to_save_model_fold = os.path.join(path_to_save_model,'fold_{}.pth'.format(fold))
torch.save(model_force_after_base, path_to_save_model_fold) 

In [65]:
model_force_after_base = torch.load(path_to_save_model_fold)
model_force_after_base.eval()

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [66]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_base, yy_train_base, xx_val_base, yy_val_base, xx_test_base, yy_test_base, model_force_after_base, metric = 'ev')

Train EV: 0.35 
Val EV: 0.29 
Test EV: 0.35 




In [67]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_force, yy_train_force, xx_val_force, yy_val_force, xx_test_force, yy_test_force, model_force_after_base, metric = 'ev')

Train EV: 0.94 
Val EV: 0.82 
Test EV: 0.83 


#### Fine-tuning the force model on baseline data

In [68]:
model_base_after_force = copy.deepcopy(model_force)

# Flatten the parameters of the copied model
for module in model_base_after_force.modules():
    if isinstance(module, nn.RNNBase):
        module.flatten_parameters()
model_base_after_force.train()

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [69]:
train_losses, val_losses = \
    train_model(model_base_after_force, xx_train_base,yy_train_base,
                xx_val_base, 
                yy_val_base,
                lr= lr,
                lr_step_size=lr_step_size,
                lr_gamma= lr_gamma,
                sequence_length_LSTM=seq_length_LSTM,
                batch_size_train = batch_size_train,
                batch_size_val = batch_size_val,
                num_epochs=1000, 
                delta = 8,                 
                regularizer= None, #Regularizer_LSTM,
                l1_ratio = l1_ratio_reg,
                alpha = alpha_reg,
                early_stop = 5)



Epoch 000 Train 5.5312 Val 4.0114
Epoch 001 Train 3.6615 Val 3.2763
Epoch 002 Train 3.1011 Val 2.9771
Epoch 003 Train 2.7750 Val 2.8055
Epoch 004 Train 2.5446 Val 2.6957
Epoch 005 Train 2.3689 Val 2.5986
Epoch 006 Train 2.2267 Val 2.5345
Epoch 007 Train 2.1121 Val 2.4759
Epoch 008 Train 2.0128 Val 2.4268
Epoch 009 Train 1.9257 Val 2.3824
Epoch 010 Train 1.8391 Val 2.3439
Epoch 011 Train 1.7502 Val 2.3045
Epoch 012 Train 1.6709 Val 2.2750
Epoch 013 Train 1.6032 Val 2.2529
Epoch 014 Train 1.5475 Val 2.2731
Epoch 015 Train 1.4923 Val 2.2213
Epoch 016 Train 1.4458 Val 2.2017
Epoch 017 Train 1.4001 Val 2.1867
Epoch 018 Train 1.3626 Val 2.1812
Epoch 019 Train 1.3274 Val 2.1654
Epoch 020 Train 1.2929 Val 2.1706
Epoch 021 Train 1.2541 Val 2.1468
Epoch 022 Train 1.2196 Val 2.1432
Epoch 023 Train 1.1853 Val 2.1396
Epoch 024 Train 1.1564 Val 2.1377
Epoch 025 Train 1.1279 Val 2.1163
Epoch 026 Train 1.1018 Val 2.1215
Epoch 027 Train 1.0804 Val 2.1015
Epoch 028 Train 1.0562 Val 2.1011
Epoch 029 Trai

In [70]:
experiment_name = 'RNN'+name+'_'+date+'_Baseline_after_Force'
path_to_save_model = os.path.join(path_to_models,experiment_name)
os.makedirs(path_to_save_model, exist_ok=True)
path_to_save_model_fold = os.path.join(path_to_save_model,'fold_{}.pth'.format(fold))
torch.save(model_base_after_force, path_to_save_model_fold) 

In [71]:
model_base_after_force = torch.load(path_to_save_model_fold)
model_base_after_force.eval() 

Causal_Simple_RNN(
  (linear): Linear(in_features=300, out_features=2, bias=True)
  (rnn): RNN(130, 300, batch_first=True)
  (selu): SELU()
  (dropout): Dropout(p=0.2, inplace=False)
)

In [72]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_force, yy_train_force, xx_val_force, yy_val_force, xx_test_force, yy_test_force, model_base_after_force, metric = 'ev')

Train EV: 0.59 
Val EV: 0.57 
Test EV: 0.58 


In [73]:
y_hat, y_true, train_score, v_score, test_score = eval_model(xx_train_base, yy_train_base, xx_val_base, yy_val_base, xx_test_base, yy_test_base, model_base_after_force, metric = 'ev')

Train EV: 0.97 
Val EV: 0.87 
Test EV: 0.85 
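The test-set explained variances printed in this notebook can be collected into one table to compare forgetting across models. The sketch below only re-uses the numbers reported above (fold 0, a single training run each), so small differences should not be over-interpreted; the "BL drop" column is the loss of baseline test EV relative to the baseline-only model. `model_name` is used as the loop variable so that the notebook's global `name` is not overwritten.

```python
# Test-set explained variance printed in the cells above (fold 0, single run).
# Tuple order: (EV on baseline 'BL' test trials, EV on force 'AD' test trials)
test_ev = {
    "Baseline only": (0.87, 0.53),
    "Force only": (0.34, 0.85),
    "Force after baseline (no EWC)": (0.35, 0.83),
    "EWC, baseline then force": (0.76, 0.77),
    "Baseline after force (no EWC)": (0.85, 0.58),
}

# Forgetting of the baseline task, measured against the baseline-only model
reference_bl_ev = test_ev["Baseline only"][0]
bl_drop = {model_name: reference_bl_ev - bl for model_name, (bl, _) in test_ev.items()}

print(f"{'model':<32}{'BL EV':>7}{'AD EV':>7}{'BL drop':>9}")
for model_name, (bl, ad) in test_ev.items():
    print(f"{model_name:<32}{bl:>7.2f}{ad:>7.2f}{bl_drop[model_name]:>9.2f}")
```

On these numbers, EWC limits the drop in baseline test EV to 0.11, against 0.52 for fine-tuning on force data without regularization, while reaching a lower force test EV (0.77 versus 0.83).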


