<a href="https://colab.research.google.com/github/nicholasdcrotty/CMBG_BRM_CNNOculomotorAnalysis/blob/main/analysisForUser/CMBG_BHR_CNNBasedAnalysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>



# Using CNNs to analyze oculomotor timecourse data


Before running this code, make sure to select a fast runtime (either your own local runtime or one provided by Colab), as the model fitting procedure and SHAP analysis take quite some time on the CPU alone.

In order for the CNNs to be applied correctly, your data should be in the form of a .csv file for each feature, with trials as rows and samples as columns. Your conditions file should be in the form of a .csv file, with trials as rows. **Importantly,** the order of the trials in the feature dataframes should match that of the conditions dataframe, so that trials are correctly assigned to their corresponding labels.




**Before running this code**, you need to load your data files into the Colab environment. To do so, click the file icon on the left sidebar, click the "Upload to session storage" icon (the page icon with the upwards arrow) and upload your features and conditions files.

# Load in necessary libraries

In [None]:
# libraries related to neural network
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# libraries related to input/output arrays
import numpy as np
import pandas as pd

#libraries related to importing data into script
from google.colab import files

#libraries related to graphing NN results
import matplotlib.pyplot as plt
import seaborn as sns

#SHAP values
!pip install shap
import shap

Collecting shap
  Downloading shap-0.47.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting slicer==0.0.8 (from shap)
  Downloading slicer-0.0.8-py3-none-any.whl.metadata (4.0 kB)
Downloading shap-0.47.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/1.0 MB[0m [31m10.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading slicer-0.0.8-py3-none-any.whl (15 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.47.2 slicer-0.0.8


# Pre-define values of experimental details based on your data

In [None]:
# --- Experiment settings: edit these to describe your own dataset ---

# column name (in the conditions file) of the values you are trying to predict
desiredLabel = "targetLocation"

# number of features in your input data
inputFeatures = 2

# number of unique values your predicted label could take
outputClasses = 6

# are there NAs in your data/labels (e.g., a distractor not appearing on some trials) that should be removed?
removeNAs = False

# does your .csv file have row labels that need to be removed?
removeRowLabels = True

# do your labels need to be numerically encoded?
numericallyEncode = True

# bundle the settings so later cells can look them up by name
details = dict(
    desiredLabel=desiredLabel,
    inputFeatures=inputFeatures,
    outputClasses=outputClasses,
    removeNAs=removeNAs,
    removeRowLabels=removeRowLabels,
    numericallyEncode=numericallyEncode,
)
print(f"CNN predicting {desiredLabel} from user's dataset, with {inputFeatures} input features and {outputClasses} output classes")

CNN predicting targetLocation from user's dataset, with 2 input features and 6 output classes


# Load in data
Make sure to add objects below if you are using more than 2 features, formatted like the following: `feature3 = pd.read_csv("yourFeatureFile.csv")`. Then, make sure to add those objects to the `np.stack` call when formatting into a 3D numpy array.

In [None]:
#upload = files.upload() #The GUI is actually much faster at this

#one .csv per feature, plus the conditions file holding the labels
feature1 = pd.read_csv('CMBG_BRM_Section4_XPos.csv')
feature2 = pd.read_csv('CMBG_BRM_Section4_YPos.csv')
#repeat above for as many features as you are uploading, or remove feature2 if using one feature
conditions = pd.read_csv('CMBG_BRM_Section4_Conditions.csv')


#format data into 3d numpy array: the new axis 1 indexes the features
if (details["inputFeatures"] == 1):
  arr2d = feature1.values
  data = arr2d[:, np.newaxis,:]
else:
  data = np.stack((feature1, feature2), axis=1) #add more features if needed

#remove row labels if needed (drops the first column of every feature)
if details["removeRowLabels"]:
  data = data[:,:,1:]

#fail early with a clear message instead of a cryptic Conv1d/indexing error (or silently misaligned labels) later on
if data.shape[1] != details["inputFeatures"]:
  raise ValueError(f"inputFeatures is {details['inputFeatures']} but {data.shape[1]} feature file(s) were stacked; "
                   "add or remove features in the np.stack call so the two agree")
if data.shape[0] != len(conditions):
  raise ValueError(f"feature files contain {data.shape[0]} trials but the conditions file contains {len(conditions)}; "
                   "trials must match one-to-one and be in the same order")

print(data.shape)

(34560, 2, 600)


# Extract desired labels from conditions datafile
If `numericallyEncode = True`, each unique value within your labels will be assigned to its own number, ranging from 0 to `outputClasses - 1`. This numerical encoding allows for the proper comparison of the CNN's output to the label during fitting.

In [None]:
label = details['desiredLabel']

#remove NA values if necessary (e.g., if search object is absent on some trials)
#the same mask is applied to data and conditions so trials stay aligned
if details["removeNAs"]:
  keepTrial = conditions[label].notna().to_numpy()
  data = data[keepTrial, :, :]
  conditions = conditions[keepTrial]

#extract label information from conditions file - this will serve as the comparison during supervised learning
#numerical encoding of labels
if details["numericallyEncode"]:
  oldLabels = conditions[label].astype(str)
  uniques = oldLabels.unique()

  #codes follow order of first appearance: first unique value -> 0, second -> 1, ...
  #a single map() replaces the chained replace() calls, which emit a pandas downcasting FutureWarning
  labelCodes = {value: code for code, value in enumerate(uniques)}
  newLabels = oldLabels.map(labelCodes)

  #check that labels encoded properly
  for i in range(len(uniques)):
    print(uniques[i], ",", newLabels[oldLabels==uniques[i]].unique())
  labels = newLabels.astype(int)

  #more distinct labels than output units would only fail later, inside the loss function
  if len(uniques) > details["outputClasses"]:
    raise ValueError(f"found {len(uniques)} unique values in '{label}' but outputClasses is {details['outputClasses']}")
else:
  labels = conditions[label].astype(int)

180 , [0]
120 , [1]
60 , [2]
300 , [3]
240 , [4]
0 , [5]


  newLabels = newLabels.replace(uniques[i], i)


# Training-validation split (67% training, 33% validation)

In [None]:
N, R, C = data.shape

#find number of training/test trials (if data is not divisible by 3, add extras to training set)
n_test = N // 3 #same as floor(N/3)
n_train = N - n_test

#every third trial (0, 3, 6, ...) goes to the validation set.
#The `< 3 * n_test` bound keeps exactly n_test validation trials: when N is not
#divisible by 3, the leftover trial goes to training as intended, instead of
#overflowing the validation array (an IndexError in the loop-based version).
trialIndex = np.arange(N)
isTest = (trialIndex % 3 == 0) & (trialIndex < 3 * n_test)

#copy the selected trials into float arrays
test_data = data[isTest].astype(np.float64)
training_data = data[~isTest].astype(np.float64)

#labels are split with the same mask so each trial keeps its own label;
#fresh 0..n-1 indices match the positional lookups used later
labelValues = labels.to_numpy()
training_labels = pd.Series(labelValues[~isTest])
test_labels = pd.Series(labelValues[isTest])

assert training_data.shape == (n_train, R, C) and test_data.shape == (n_test, R, C)
assert len(training_labels) == n_train and len(test_labels) == n_test

print(training_data.shape, test_data.shape)


# Pre-process data

In [None]:
#transform numpy data to tensor
training_data = torch.from_numpy(training_data)
test_data = torch.from_numpy(test_data)

#change to float type - more easily handled by training + test loops
training_data = training_data.float()
test_data = test_data.float()

# Check dimensions of test data and labels

In [None]:
print(test_data.shape)
#trials, features, samples

print(test_labels.shape)
#1-dimensional pandas Series with one label per validation trial

#actually enforce the check rather than relying on reading the printout
assert test_data.shape[0] == len(test_labels), "validation data and labels must contain the same number of trials"

# Determine size of flattened output passed to first linear transform -- different based on the length of your trial and number of features

In [None]:
#Throwaway copy of the CNN's convolutional stage plus a flatten step.
#Its only job is to reveal how many values reach the first linear layer.
sizeChecker = nn.Sequential(
    nn.Conv1d(details["inputFeatures"], 64, 3),
    nn.ReLU(),
    nn.Dropout(0.25),
    nn.MaxPool1d(5),
    nn.Flatten(),
)

#push one validation trial through and read off the flattened width
singleTrial = test_data[:1]
flattenOutput = sizeChecker(singleTrial).shape[1]
print(flattenOutput)

# Initialize the neural network


In [None]:
class NeuralNetwork(nn.Module):
    """1-D convolutional classifier for timecourse data.

    One convolution stage (Conv1d -> ReLU -> Dropout -> MaxPool1d) followed by
    three linear transforms. The number of input channels and output logits is
    read from the notebook-level ``details`` dict, and the width of the first
    linear layer from the notebook-level ``flattenOutput``.
    """

    def __init__(self):
        super().__init__()

        #applied to the convolution output inside forward()
        self.flatten = nn.Flatten()

        #convolution followed by its regularization/pooling steps
        convStage = [
            nn.Conv1d(details["inputFeatures"], 64, 3),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.MaxPool1d(5),
        ]
        self.cnnL1 = nn.Sequential(*convStage)

        #linear transforms; the last one emits one logit per output class
        self.lin = nn.Linear(flattenOutput, 64)
        self.lin2 = nn.Linear(64, 32)
        self.lin3 = nn.Linear(32, details["outputClasses"])

    def forward(self, x):
        """Return the class logits for a batch ``x`` of shape (trials, features, samples)."""
        #convolve, then swap the last two axes so flattening runs in the intended order
        convolved = self.cnnL1(x).permute(0, 2, 1)

        #collapse everything but the batch axis and apply the three linear transforms
        hidden = self.lin2(self.lin(self.flatten(convolved)))
        return self.lin3(hidden)

model = NeuralNetwork()

# Initialize Early Stopping class

In [None]:
#EarlyStopping class used in pytorchtools library
#source: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the validation loss and the model.
    Every improvement writes the model's state_dict to ``path``; after
    ``patience`` epochs without improvement in a row, ``early_stop`` becomes True.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pth'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss, updating the checkpoint and stop flag.

        Args:
            val_loss (float): validation loss of the epoch that just finished.
            model (torch.nn.Module): model whose state_dict is saved on improvement.
        """
        #higher score = better, so the comparisons below read as "did the score go up?"
        score = -val_loss

        #NaN is the only value that differs from itself. A NaN loss must never count as
        #an improvement: every comparison with NaN is False, so it would otherwise fall
        #through to the "improved" branch, overwrite the checkpoint, and leave best_score
        #as NaN - after which early stopping could never trigger again.
        lossIsNaN = bool(val_loss != val_loss)

        if self.best_score is None and not lossIsNaN:
            #first usable epoch: it is the best by definition
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif lossIsNaN or score < self.best_score + self.delta:
            #no (sufficient) improvement this epoch
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            #improvement: keep these weights and restart the patience count
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model's state_dict to ``self.path`` and records ``val_loss`` as the new minimum.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Define a Custom Dataset object that stores the samples and the labels together for each trial

In [None]:
class CustomDataset(Dataset):
    """Pairs each trial's data with its label so a DataLoader can serve them together.

    ``dataObject`` is indexed directly (``dataObject[idx]``) and ``labelsObject``
    is indexed by position (``labelsObject.iloc[idx]``). The optional transforms
    are applied to the data and the label, respectively, on every lookup.
    """

    def __init__(self, labelsObject, dataObject, transform=None, target_transform=None):
        self.labels, self.dataframe = labelsObject, dataObject
        #optional callables for converting data / labels to a different type
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        #one entry per label
        return len(self.labels)

    def __getitem__(self, idx):
        sample, target = self.dataframe[idx], self.labels.iloc[idx]
        sample = self.transform(sample) if self.transform else sample
        target = self.target_transform(target) if self.target_transform else target
        return sample, target

# Create the CustomDatasets for training and validation data

In [None]:
#wrap each split's labels and data (in that argument order) into a dataset
train = CustomDataset(training_labels, training_data)
test = CustomDataset(test_labels, test_data)

# Create DataLoader objects containing the training and test data/labels
After our data and labels have been stored in our Custom Dataset objects, we need to place these objects into an object called a Dataloader. Dataloaders are PyTorch objects designed to divide input files into a series of minibatches and feed them into a CNN. This batching process allows the network to update its parameters more frequently than if it was fit to all of the data at once. The size of each minibatch, controlled by ``batch_size``, is the number of trials propagated through the CNN before the parameters are updated. In the current study, we chose a minibatch size of 64 trials.

In [None]:
#number of trials propagated through the CNN before its parameters are updated
batch_size = 64

#both loaders share the same settings; the final, smaller batch is kept
loaderOptions = {"batch_size": batch_size, "drop_last": False}
train_dataloader = DataLoader(train, **loaderOptions)
test_dataloader = DataLoader(test, **loaderOptions)

## Hyperparameters & Optimizer




In [None]:
#Hyperparameters
epochs = 2500 #arbitrary large number to give time for early stopping to occur
learning_rate = 0.001

#loss compared against the integer class labels
loss_fn = nn.CrossEntropyLoss()

#Adam optimizer over every parameter of the model
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


## Full Implementation
Here, we create our respective loops for the training/validation process with one function for training and one function for validation. The validation function also stores the trial-level accuracies of the CNN, which will be used to graph the CNN's classification accuracy later on.



In [None]:
def train_model(dataloader, model, loss_fn, optimizer):
    """Run one training epoch: one optimizer step per minibatch.

    Args:
        dataloader: yields ``(X, y)`` minibatches of data and labels.
        model: network being fit; switched to training mode here.
        loss_fn: callable returning the loss for ``(predictions, labels)``.
        optimizer: optimizer holding the model's parameters.

    Prints the current loss and the number of trials processed so far every 20 batches.
    """
    size = len(dataloader.dataset)
    model.train() #set the model to training mode - important for batch normalization and dropout layers

    #running count of trials processed; counting directly (rather than batch * batch_size)
    #keeps the progress report correct for any dataloader, without reading a notebook global
    seen = 0
    for batch, (X, y) in enumerate(dataloader):

        #compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        #backpropagation; gradients are cleared after the step so the next batch starts from zero
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        seen += len(X)

        #print current loss every 20 batches
        if batch % 20 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")


def test_model(dataloader, model, loss_fn, trialLevelDF, epNum = 0, earlyStop="none", lossChecker = "none", mode = "loss"):
    size = len(dataloader.dataset)
    model.eval() #set the model to evaluation mode - important for batch normalization and dropout layers

    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    #evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    with torch.no_grad():
      batchCount = 0
      for X, y in dataloader:
          pred = model(X) #apply model to batch
          predSize = len(pred)
          for p in (range(predSize)): #iterate through current batch to get trial-level accuracies

              #determine whether there was a correct or incorrect prediction on the current dataframe
              trialAcc = (pred[p].argmax().item() == y[p].item())

              #save trial-level accuracy to initialized dataframe
              trialLevelDF.iloc[(batch_size*batchCount)+p,epNum] = trialAcc

          #update loss and accuracy metrics for reporting
          test_loss += loss_fn(pred, y).item()
          correct += (pred.argmax(1) == y).type(torch.float).sum().item() #same calculation as above, but performed for entire batch at once
          batchCount+=1

    #compute overall loss and accuracy
    test_loss /= num_batches
    correct /= size

    #if we want to save the CNN's lowest loss value, record that minimum loss and the epoch where it occurred
    if lossChecker != "none" and test_loss < lossChecker[0]:
      lossChecker[0] = test_loss
      lossChecker[1] = epNum

    #if we want to implement early stopping, apply early stopping
    if earlyStop != "none":
      earlyStop(test_loss, model)

    #print the relevant performance metrics -- loss, accuracy, or both
    if mode == "loss":
      print(f"Epoch {epNum+1} complete! \n Current loss: {test_loss:>8f}    Lowest loss: {lossChecker[0]:>8f} \n")
    elif mode == "accuracy":
      print(f"Accuracy: {(100*correct):>8f}% \n")
    elif mode =="both":
      print(f"Epoch {epNum+1} complete! \n Current loss: {test_loss:>8f}    Lowest loss: {lossChecker[0]:>8f} \n Accuracy: {(100*correct):>8f}% \n")

# Implement training and test procedures

Here, we initialize a dataframe to store the trial-level validation accuracies and implement the training and validation functions we created earlier. We also save the trial-level accuracies as a .csv file to store the NN's current performance (as the fitting process is stochastic).


In [None]:
#trial-by-epoch table of validation outcomes: one row per validation trial, one column per epoch
nValidationTrials = len(test_dataloader.dataset)
performanceSummary = pd.DataFrame(index=range(nValidationTrials), columns=range(epochs))

#[lowest validation loss so far, epoch index where it occurred]
lossCheck = [np.inf, 0]

#ends fitting once validation loss has gone 10 epochs in a row without improving
early_stopping = EarlyStopping(patience=10, verbose=True)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")

    #fit on the training set, then score the validation set for this epoch
    train_model(train_dataloader, model, loss_fn, optimizer)
    test_model(test_dataloader, model, loss_fn, performanceSummary,
               epNum=t, earlyStop=early_stopping, lossChecker=lossCheck, mode="loss")

    #leave the loop as soon as the early stopping criterion is met
    if early_stopping.early_stop:
        print(f"Early stopping @ epoch {t+1}")
        break

print("Done!")

# Quick summary of model results
Once our CNN has been fit to the data, we can take a quick look at its accuracy in predicting target location from eye position using the trial level accuracies. Here, we graph the accuracy across all run epochs, and identify the accuracy that occurred when the network demonstrated minimum loss. This accuracy value is the one we use as a metric of model performance. The associated parameter weights that gave rise to this performance have been saved as `checkpoint.pth`.

In [None]:
#epoch (0-indexed) with the lowest validation loss, as recorded during fitting
bestEpoch = lossCheck[1]
performanceSummary = performanceSummary.iloc[:,0:t+1] #in case early stopping happens

#percentage of validation trials classified correctly in each epoch
epAcc = (100*(performanceSummary.sum() / len(test_dataloader.dataset)))
bestPerformance = epAcc[bestEpoch]

#plot accuracy across epochs and label accuracy from iteration with lowest validation loss
plt.plot(range(t+1), epAcc)
plt.xlabel("Epoch (starting at 0)")
plt.ylabel("Validation accuracy (%)")
plt.arrow(x = bestEpoch, y = bestPerformance - 0.1*bestPerformance, dx = 0, dy = 0.1*bestPerformance, width = .2)
plt.text(x = bestEpoch-0.33*bestEpoch, y = bestPerformance - 0.2*bestPerformance, s = f"Accuracy from epoch {bestEpoch} (starting \n at 0) with the lowest loss -- \n we use this in analysis!")
plt.text(x = bestEpoch-0.33*bestEpoch, y = bestPerformance - 0.3*bestPerformance, s = "Note: Overfitting may have \n happened in later epochs, \n causing a severe accuracy \n decrease")

#save graph of performance across epochs (must happen before plt.show(), which clears the figure)
#rename the file as desired; a bare ".png" would create a hidden file with no name
accuracyPlotFile = "validationAccuracyByEpoch.png"
plt.savefig(accuracyPlotFile)
files.download(accuracyPlotFile)

#view graph and accuracy values across epochs
plt.show()
print(epAcc)

# View accuracy from epoch with lowest loss

In [None]:
#load in best weights (the checkpoint written by early stopping at the lowest validation loss)
#weights_only=True restricts unpickling to tensors/plain containers, which is all a state_dict holds
bestWeights = torch.load("checkpoint.pth", weights_only=True)
model.load_state_dict(bestWeights)

#check of prediction accuracy for model: a single column holding correct/incorrect per validation trial
trialLevelAcc = pd.DataFrame(index=range(len(test_dataloader.dataset)), columns = range(1))
test_model(test_dataloader, model, loss_fn, trialLevelAcc, mode = "accuracy")

# Download trial-level accuracy as a .csv file to your computer

In [None]:
#rename the file as desired; a bare '.csv' is a hidden, nameless file that the SHAP export would also overwrite
trialAccuracyFile = 'trialLevelAccuracy.csv'
trialLevelAcc.to_csv(trialAccuracyFile)
files.download(trialAccuracyFile)

# Save model parameter weights locally and download to computer

In [None]:
#.pth file is type used for storing parameter weights
#rename the file as desired; a bare ".pth" would create a hidden file with no name
modelWeightsFile = "bestModelWeights.pth"
torch.save(model.state_dict(), modelWeightsFile)
files.download(modelWeightsFile)

# Calculate SHAP values for each trial of validation set
Fair warning, this takes a substantial amount of time to run.

In [None]:
#one row per validation trial, one column per sample of the timecourse
#(the column count is the length of the first feature of the first validation trial)
trialLevelSHAP = pd.DataFrame(index=range(len(test_dataloader.dataset)), columns = range(len(test_dataloader.dataset[0][0][0])))
print(trialLevelSHAP.shape)

#slicing the dataset returns a (data, labels) pair, so shapBatch[0] is the data of those trials
shapBatch = test_dataloader.dataset[0:100] #first 100 trials used for explanation model
explainer = shap.DeepExplainer(model, shapBatch[0])

loopDur = len(test_dataloader.dataset)
for s in range(loopDur):

  #calculate SHAP values relative to explainer set
  #the s:s+1 slice keeps the leading trial axis, so a batch containing one trial is explained
  shap_values = explainer.shap_values(test_dataloader.dataset[s:s+1][0])

  #compute global feature importance by taking absolute value
  abs_shap_values = np.abs(shap_values)

  #average across classes and features to get one global feature importance value per sample
  #NOTE(review): axis=(0,1,3) assumes shap_values is a single array laid out as
  #(trials, features, samples, classes); older shap releases returned a list with one
  #array per class instead - confirm against the installed shap version
  SHAP = abs_shap_values.mean(axis=(0,1,3))

  #save this trial's per-sample SHAP values to initialized dataframe
  trialLevelSHAP.iloc[s,:] = SHAP
  #progress report every 100 trials
  if s % 100 ==0:
    print(s)

#shap_values and abs_shap_values still hold the last trial's values here;
#the heatmap cell at the end of the notebook reads them
print("Done!")

# Download SHAP values as .csv file

In [None]:
#trialLevelSHAP is already a DataFrame, so it can be written out directly
#rename the file as desired; a bare '.csv' is a hidden, nameless file that would collide with the accuracy export
shapValuesFile = 'trialLevelSHAP.csv'
trialLevelSHAP.to_csv(shapValuesFile)
files.download(shapValuesFile)

# Graph heatmap of SHAP values for last trial as a quick visualization

In [None]:
#per-sample SHAP importance of the last trial processed by the SHAP loop
lastTrialSHAP = abs_shap_values.mean(axis=(0,1,3))

#take the timecourse length from the vector being plotted; indexing test_data[1,1,:]
#raised an IndexError with a single feature or a single validation trial
nSamples = len(lastTrialSHAP)

df = pd.DataFrame({"SHAP values for each sample in timecourse": lastTrialSHAP},
                  index=range(nSamples))
g = sns.heatmap(df, fmt="g", cmap='viridis')

#label every 50th sample on the y-axis
tickPositions = range(0, nSamples, 50)
g.set_yticks(tickPositions)
g.set_yticklabels(tickPositions)
plt.show()