<a href="https://colab.research.google.com/github/nicholasdcrotty/CDVMBG_BRM_CNNOculomotorAnalysis/blob/main/analysisForReplication/CDVMBG_BRM_CNNModelGeneralizability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



# Using CNNs to analyze oculomotor timecourse data


Before running this code, make sure to select a fast runtime (either your own local runtime or one provided by Colab), as the model fitting procedure and SHAP analysis take quite some time on the CPU alone.




**Before running this code**, you need to load your data files into the Colab environment. To do so, click the file icon on the left sidebar, click the "Upload to session storage" icon (the page icon with the upwards arrow) and upload the matching features and conditions files. There should be two feature `.csv` files and one conditions `.csv` file.

# **--------Section 1: Loading in data and preprocessing--------**

This section needs to be run regardless of whether you are newly fitting the networks to data or replicating the exact manuscript results. **Prior to running this code**, make sure you have loaded the oculomotor data `.csv` files (for both the x- and y-position) and the conditions `.csv` file.

## Load in necessary libraries

In [None]:
# libraries related to neural network
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# libraries related to input/output arrays
import numpy as np
import pandas as pd

#libraries related to importing data into script
#from google.colab import drive
#drive.mount('/content/drive')
from google.colab import files

#libraries related to graphing NN results
import matplotlib.pyplot as plt
import seaborn as sns

#SHAP values
!pip install shap
import shap

Collecting shap
  Downloading shap-0.47.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting slicer==0.0.8 (from shap)
  Downloading slicer-0.0.8-py3-none-any.whl.metadata (4.0 kB)
Downloading shap-0.47.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m38.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading slicer-0.0.8-py3-none-any.whl (15 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.47.2 slicer-0.0.8


## Pre-define values of experimental details based on your data

In [None]:
# --- Experiment-specific settings: edit these to match your own dataset ---
desiredLabel = "distractorLocation"  # column in the conditions file holding the value to predict
inputFeatures = 2                    # number of feature channels per trial (here: x- and y-position)
outputClasses = 6                    # number of distinct values the predicted label can take
removeNAs = True                     # drop trials whose label is NA (e.g., no distractor shown)?
removeRowLabels = False              # strip a leading row-label column from the feature files?
numericallyEncode = True             # convert the raw label values into class indices?

# Bundle the settings into one dictionary that the later cells read from.
details = dict(
    desiredLabel=desiredLabel,
    inputFeatures=inputFeatures,
    outputClasses=outputClasses,
    removeNAs=removeNAs,
    removeRowLabels=removeRowLabels,
    numericallyEncode=numericallyEncode,
)
print(f"CNN predicting {desiredLabel} from user's dataset, with {inputFeatures} input features and {outputClasses} output classes")

CNN predicting distractorLocation from user's dataset, with 2 input features and 6 output classes


## Load in data

In [None]:
#upload = files.upload() #The GUI is actually much faster at this

# One .csv per input feature (rows = trials, columns = samples), plus one conditions file.
feature1 = pd.read_csv('DVCMG_testPhase_xData.csv')
feature2 = pd.read_csv('DVCMG_testPhase_yData.csv')
#repeat above for as many features as you are uploading, or remove feature2 if using one feature
conditions = pd.read_csv('DVCMG_testPhase_conditionLabels.csv')

# Arrange the features as a 3D numpy array: (trials, features, samples)
if details["inputFeatures"] == 1:
  # a single feature still needs an explicit (length-1) feature axis
  data = feature1.values[:, np.newaxis, :]
else:
  data = np.stack([feature1.values, feature2.values], axis=1) #add more features if needed

# Drop the first column along the sample axis when the files carry row labels
if details["removeRowLabels"] == True:
  data = data[:, :, 1:]

print(data.shape)
conditions.head()

(52800, 2, 600)


Unnamed: 0,targetLocation,distractorLocation
0,60,180.0
1,180,60.0
2,120,
3,180,300.0
4,300,


## Extract desired labels from conditions datafile

In [None]:
label = details['desiredLabel']

#remove NA values if necessary (e.g., if search object is absent on some trials)
if details["removeNAs"] == True:
  hasLabel = conditions[label].notna()
  #the same boolean mask filters the feature array and the conditions table, keeping rows aligned
  data = data[hasLabel.to_numpy(), :, :]
  conditions = conditions[hasLabel]

#extract label information from conditions file - this will serve as the comparison during supervised learning
#`labels` is assigned whether or not numerical encoding is requested, so later cells never hit an undefined name
labels = conditions[label].astype(int)

#numerical encoding of labels: map each raw label value onto a class index
if details["numericallyEncode"] == True:
  locationToClass = {0:0, 60:1, 120:2, 180:3, 240:4, 300:5}

  #fail loudly on values missing from the map; Series.map would otherwise turn them into NaN
  unmapped = sorted(set(labels.unique()) - set(locationToClass))
  if unmapped:
    raise ValueError(f"Label values {unmapped} in column '{label}' have no entry in the class mapping")

  labels = labels.map(locationToClass)

## Pre-process data

In [None]:
# Convert the numpy array to a torch tensor and cast to float in one step
# (float is more easily handled by the training + test loops).
data = torch.from_numpy(data).float()

## Check dimensions of data and labels

In [None]:
# Line 1: data shape  -> trials, features, samples
# Line 2: label shape -> 1-dimensional, same length as the number of trials in `data`
print(data.shape, labels.shape, sep="\n")
#1-dimensional tensor of identical length to rows of test data

torch.Size([26400, 2, 600])
(26400,)


## Determine size of flattened output passed to first linear transform -- different based on the length of your trial and number of features

In [None]:
# Throwaway network with the same layers as the first portion of the CNN plus a Flatten.
# Pushing one trial through it shows how many values reach the first linear layer.
sizeChecker = nn.Sequential(
    nn.Conv1d(details["inputFeatures"], 64, 3),
    nn.ReLU(),
    nn.Dropout(p=0.25),
    nn.MaxPool1d(5),
    nn.Flatten(),
)
tmp = sizeChecker(data[:1])   # a single trial, keeping the batch dimension
flattenOutput = tmp.shape[1]  # flattened length per trial
print(flattenOutput)

7616


## Initialize the neural network and load in weights from the CNN predicting distractor location using Massa et al. (2024) data


In [None]:
class NeuralNetwork(nn.Module):
    """1D CNN that turns a batch of trials into one logit per output class.

    Layer sizes are read from notebook-level state when the object is built:
    `details["inputFeatures"]` (convolution input channels),
    `details["outputClasses"]` (number of logits) and `flattenOutput`
    (input width of the first linear layer).
    """

    def __init__(self):
        super().__init__()
        # Flattening layer, applied to the (permuted) convolution output
        self.flatten = nn.Flatten()

        # Convolution stage: Conv1d -> ReLU -> Dropout -> MaxPool1d
        self.cnnL1 = nn.Sequential(
            nn.Conv1d(details["inputFeatures"], 64, kernel_size=3),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.MaxPool1d(5),
        )

        # Three stacked linear transforms; the last emits one logit per class
        self.lin = nn.Linear(flattenOutput, 64)
        self.lin2 = nn.Linear(64, 32)
        self.lin3 = nn.Linear(32, details["outputClasses"])

    def forward(self, x):
        """Return the class logits for input batch `x`."""
        # Convolve, then swap the last two axes so flattening runs in the intended order
        hidden = self.flatten(self.cnnL1(x).permute(0, 2, 1))

        # NOTE(review): no activation is applied between the linear layers; kept as-is
        # because this architecture must match the weights loaded into it.
        for layer in (self.lin, self.lin2, self.lin3):
            hidden = layer(hidden)
        return hidden

model = NeuralNetwork()
# map_location="cpu" lets the checkpoint load on a runtime without a GPU even if it was
# saved from one; the model and data in this notebook are never moved off the CPU.
massaWeights = torch.load("model_state_MassaDistLoc.pth", map_location="cpu")
model.load_state_dict(massaWeights)

<All keys matched successfully>

## Initialize Early Stopping class

In [None]:
#EarlyStopping class used in pytorchtools library
#source: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
class EarlyStopping:
    """Track validation loss across calls and flag when training should stop.

    Each call compares the new loss with the best one seen so far. An improvement
    checkpoints the model and resets the counter; otherwise the counter goes up,
    and once it reaches `patience` the `early_stop` attribute becomes True.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated.
                            Default: 7
            verbose (bool): If True, report every checkpoint save via `trace_func`.
                            Default: False
            delta (float): Margin added to the best score when testing for improvement.
                            Default: 0
            path (str): File the model's state_dict is saved to.
                            Default: 'checkpoint.pth'
            trace_func (function): Callable used to emit messages.
                            Default: print
        """
        # configuration
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        # running state
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Register one validation loss; checkpoint `model` if it counts as an improvement."""
        score = -val_loss  # negated so that a higher score means a lower loss
        first_call = self.best_score is None

        if not first_call and score < self.best_score + self.delta:
            # no improvement: count it, and stop once patience is exhausted
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # first observation or an improvement: record it and checkpoint the model
        self.best_score = score
        self.save_checkpoint(val_loss, model)
        if not first_call:
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Write the model's state_dict to `self.path` and remember `val_loss` as the minimum.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

## Define a Custom Dataset object that stores the samples and the labels together for each trial

In [None]:
class CustomDataset(Dataset):
    """Dataset that pairs each trial's samples with that trial's label.

    `labelsObject` is indexed positionally through `.iloc`; `dataObject` is
    indexed directly. Either item can be passed through an optional transform.
    """

    def __init__(self, labelsObject, dataObject, transform=None, target_transform=None):
        self.labels, self.dataframe = labelsObject, dataObject
        self.transform = transform                # optional conversion applied to the data
        self.target_transform = target_transform  # optional conversion applied to the label

    def __len__(self):
        """Number of items, taken from the labels object."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return `(data, label)` at position `idx`, after any configured transforms."""
        sample = self.dataframe[idx]
        target = self.labels.iloc[idx]
        sample = self.transform(sample) if self.transform else sample
        target = self.target_transform(target) if self.target_transform else target
        return sample, target

## Initialize the CustomDataset

In [None]:
# Pair every trial's data with its label (arguments: labels first, then data)
preppedData = CustomDataset(labels, data)
labels.head()

Unnamed: 0,distractorLocation
0,3
1,1
3,5
5,3
6,0


## Create DataLoader objects containing the data/labels


In [None]:
batch_size = 64
# Every trial is used for evaluation. Batches come in dataset order (no shuffling) and the
# final, possibly smaller, batch is kept.
allForVal_dataloader = DataLoader(dataset=preppedData, batch_size=batch_size, shuffle=False, drop_last=False)

## Hyperparameters

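Since we're not doing any additional training, we do not include any optimizer or training loop. Only the loss function defined below is passed to the evaluation function; the learning rate and epoch count are defined but not used in this notebook.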


In [None]:
# Hyperparameters
learning_rate = 0.001
epochs = 2500  # arbitrary large number to give time for early stopping to occur
loss_fn = nn.CrossEntropyLoss()  # loss passed to the evaluation function



## Full Implementation
Since we're using pre-trained weights and not performing any additional training, we only need the `test_model` function.




In [None]:
def test_model(dataloader, model, loss_fn, trialLevelDF, predictionDF, logitDF, epNum = 0, earlyStop="none", lossChecker = "none", mode = "loss"):
    size = len(dataloader.dataset)
    model.eval() #set the model to evaluation mode - important for batch normalization and dropout layers

    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    #evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    with torch.no_grad():
      batchCount = 0
      for X, y in dataloader:
          pred = model(X) #apply model to batch
          predSize = len(pred) #rows of pred
          for p in (range(predSize)): #iterate through current batch to get trial-level accuracies

              #determine whether there was a correct or incorrect prediction on the current dataframe
              trialAcc = (pred[p].argmax().item() == y[p].item())
              prediction = pred[p].argmax().item()
              logits = pred[p]

              #save trial-level accuracy to initialized dataframe
              trialLevelDF.iloc[(batch_size*batchCount)+p,epNum] = trialAcc
              predictionDF.iloc[(batch_size*batchCount)+p,epNum] = prediction
              logitDF.iloc[(batch_size*batchCount)+p,:] = logits

          #update loss and accuracy metrics for reporting
          test_loss += loss_fn(pred, y).item()
          correct += (pred.argmax(1) == y).type(torch.float).sum().item() #same calculation as above, but performed for entire batch at once
          batchCount+=1

    #compute overall loss and accuracy
    test_loss /= num_batches
    correct /= size

    #if we want to save the CNN's lowest loss value, record that minimum loss and the epoch where it occurred
    if lossChecker != "none" and test_loss < lossChecker[0]:
      lossChecker[0] = test_loss
      lossChecker[1] = epNum

    #if we want to implement early stopping, apply early stopping
    if earlyStop != "none":
      earlyStop(test_loss, model)

    #print the relevant performance metrics -- loss, accuracy, or both
    if mode == "loss":
      print(f"Epoch {epNum+1} complete! \n Current loss: {test_loss:>8f}    Lowest loss: {lossChecker[0]:>8f} \n")
    elif mode == "accuracy":
      print(f"Accuracy: {(100*correct):>8f}% \n")
    elif mode =="both":
      print(f"Epoch {epNum+1} complete! \n Current loss: {test_loss:>8f}    Lowest loss: {lossChecker[0]:>8f} \n Accuracy: {(100*correct):>8f}% \n")

# **--------Section 2: Applying model to data -------**

## Apply CNN to data, return prediction accuracy

In [None]:
#check of prediction accuracy for model
numTrials = len(allForVal_dataloader.dataset)

#one row per trial; a single column because the model is evaluated once
trialLevelAcc = pd.DataFrame(index=range(numTrials), columns = range(1))
trialLevelPredictions = pd.DataFrame(index=range(numTrials), columns = range(1))
#one column per logit, taken from the configured number of output classes rather than hard-coded
trialLevelLogits = pd.DataFrame(index=range(numTrials), columns = range(details["outputClasses"]))

test_model(allForVal_dataloader, model, loss_fn, trialLevelAcc,
          trialLevelPredictions, trialLevelLogits, mode = "accuracy")


Accuracy: 19.867424% 



## Download relevant model performance metrics as .csv files via Colab's browser download
This is done in multiple code cells because the resulting files are quite large, and some may not download if all are run in one cell.

In [None]:
# Write the trial-level accuracy to disk, then hand the same file to the browser download
accFile = 'trialLevelAcc_TransferLearning_MassaDistOnDoyleDist.csv'
trialLevelAcc.to_csv(accFile)
files.download(accFile)

Mounted at /content/drive


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
#each filename is defined once so the file written is always the file downloaded

#trial-level predictions from the pretrained model
predictionsFile = "trialLevelPredictions_TransferLearning_MassaDistOnDoyleDist.csv"
trialLevelPredictions.to_csv(predictionsFile)
files.download(predictionsFile)

#trial-level logits (CNN output values) from the pretrained model
logitsFile = "trialLevelOutputs_TransferLearning_MassaDistOnDoyleDist.csv"
trialLevelLogits.to_csv(logitsFile)
files.download(logitsFile)

# **--------Section 3: SHAP analysis -------**
We have made the SHAP analysis its own section as it takes a substantial amount of time + processing to run.

In [None]:
shapDataset = allForVal_dataloader.dataset
numTrials = len(shapDataset)
numSamples = len(shapDataset[0][0][0])  # length of the first feature row of the first trial

# One row per trial, one column per sample in the timecourse
trialLevelSHAP = pd.DataFrame(index=range(numTrials), columns = range(numSamples))
print(trialLevelSHAP.shape)

# The first 100 trials serve as the background set for the explanation model
shapBatch = shapDataset[0:100]
explainer = shap.DeepExplainer(model, shapBatch[0])

loopDur = numTrials
for s in range(loopDur):
  # SHAP values for a single trial (sliced so the batch dimension is kept)
  trialInput = shapDataset[s:s+1][0]
  shap_values = explainer.shap_values(trialInput)

  # magnitude of each attribution, regardless of sign
  abs_shap_values = np.abs(shap_values)

  # collapse axes 0, 1 and 3, leaving one importance value per sample
  SHAP = abs_shap_values.mean(axis=(0,1,3))

  # store this trial's SHAP importances in its row of the preallocated frame
  trialLevelSHAP.iloc[s,:] = SHAP

  # progress report every 100 trials
  if s % 100 ==0:
    print(s)

print("Done!")

## Download SHAP values as .csv file

In [None]:
# Write the trial-level SHAP values to disk, then hand the same file to the browser download
shapFile = "shapValues_TransferLearning_MassaDistOnDoyleDist.csv"
trialLevelSHAP = pd.DataFrame(trialLevelSHAP)
trialLevelSHAP.to_csv(shapFile)
files.download(shapFile)

## Graph heatmap of SHAP values for last trial as a quick visualization

In [None]:
# Uses `shap_values` / `abs_shap_values` left over from the final iteration of the SHAP loop,
# i.e. the last trial that was explained.
df = pd.DataFrame({"SHAP values for each sample in timecourse": abs_shap_values.mean(axis=(0,1,3))},
                  index=range(len(shap_values[0][0])))
g = sns.heatmap(df, fmt="g", cmap='viridis')

# Tick every 50 samples. The sample count comes from the last axis of `data`, so this also
# works when there is only one input feature (indexing data[1,1,:] would fail in that case).
sampleTicks = range(0, data.shape[-1], 50)
g.set_yticks(sampleTicks)
g.set_yticklabels(sampleTicks)
plt.show()