In [None]:
import numpy as np
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import split_dataset
import scipy.io
from itertools import combinations

In [None]:
# Mount Google Drive so the dataset paths under /content/drive resolve (Colab only).
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


### **Load data**

In [None]:
# Every input lives under one Drive project folder: build each path from root_dir
# instead of repeating the absolute prefix.
root_dir = '/content/drive/My Drive/Response Inhibition Project/2023EvanIEEE/IOP/'
data_dir = root_dir + 'data/4444/'
ids_dir = root_dir + 'data/raw pupil + dict (animal IDs)/'

stim_pupil = np.load(data_dir + 'stim_pupil.npy')
spon_pupil = np.load(data_dir + 'spon_pupil.npy')

stim_latent_space = np.load(data_dir + 'stim_latent_space.npy')
spon_latent_space = np.load(data_dir + 'spon_latent_space.npy')

# stim_pupil = stim_pupil.T  # Shape: (600, 900)
# spon_pupil = spon_pupil.T  # Shape: (600, 900)

# The training code indexes inputs and targets with the same row indices, so trials
# must be aligned along axis 0.
assert stim_latent_space.shape[0] == stim_pupil.shape[0], "stim latent space and stim pupil must have the same number of trials"
assert spon_latent_space.shape[0] == spon_pupil.shape[0], "spon latent space and spon pupil must have the same number of trials"

# One regression target per latent dimension: slice the last axis of the latent space.
# NOTE(review): assumes the latent space has at least 6 entries on its last axis — confirm.
stim_labels_list = [stim_latent_space[:, :, i] for i in range(6)]
spon_labels_list = [spon_latent_space[:, :, i] for i in range(6)]

# NOTE(review): if re-enabled, these scalers are fit on ALL trials before the
# cross-validation split, so statistics of held-out animals leak into training.
# scaler_stim = StandardScaler()
# scaler_spon = StandardScaler()
# stim_pupil = scaler_stim.fit_transform(stim_pupil.reshape(-1, stim_pupil.shape[-1])).reshape(stim_pupil.shape)
# spon_pupil = scaler_spon.fit_transform(spon_pupil.reshape(-1, spon_pupil.shape[-1])).reshape(spon_pupil.shape)

# scaler_stim_labels = [StandardScaler().fit(label) for label in stim_labels_list]
# scaler_spon_labels = [StandardScaler().fit(label) for label in spon_labels_list]
# stim_labels_list = [scaler.transform(label) for scaler, label in zip(scaler_stim_labels, stim_labels_list)]
# spon_labels_list = [scaler.transform(label) for scaler, label in zip(scaler_spon_labels, spon_labels_list)]

# comment the following lines when not running the SUBJECT-CROSSVALIDATION#############
stim_animal_ids = np.load(ids_dir + 'stim_eeg_dict.npy')
spon_animal_ids = np.load(ids_dir + 'spon_eeg_dict.npy')
print("Shape of stim_dict:", stim_animal_ids.shape)
print("Shape of spon_dict:", spon_animal_ids.shape)
# Cross-validation selects trials by their position in the animal-ID array,
# so there must be exactly one ID per trial.
assert len(stim_animal_ids) == len(stim_pupil), "stim animal IDs must align 1:1 with stim trials"
assert len(spon_animal_ids) == len(spon_pupil), "spon animal IDs must align 1:1 with spon trials"
#######################################################################################

print("Shape of stim_pupil:", stim_pupil.shape)
print("Shape of spon_pupil:", spon_pupil.shape)
print("Shapes of stim_labels_list:", [label.shape for label in stim_labels_list])
print("Shapes of spon_labels_list:", [label.shape for label in spon_labels_list])

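### **Define functions**

Leave-one-animal-out scheme: for each held-out animal, a model is pretrained with k-fold CV on the trials of the remaining animals (best model = Mdl-0), then fine-tuned and evaluated with k-fold CV on the held-out animal's trials (Mdl-x).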

In [None]:
def build_model(input_shape, output_dim=10):
    """Build and compile a small fully connected regression network.

    Parameters
    ----------
    input_shape : tuple
        Shape of one input sample, e.g. ``(n_features,)``.
    output_dim : int, default 10
        Number of regression outputs. The default reproduces the original
        10-point output layer.

    Returns
    -------
    A compiled Keras ``Sequential`` model (Adam optimizer, MSE loss).
    """
    model = Sequential([
        # Explicit Input layer instead of ``input_shape=`` on the first Dense layer,
        # which Keras 3 flags with a warning.
        tf.keras.Input(shape=input_shape),
        Dense(64, activation='relu'),
        Dense(32, activation='relu'),
        Dense(16, activation='relu'),
        Dense(output_dim),  # linear output layer (regression)
    ])
    model.compile(optimizer='adam', loss='mse')
    return model

def compute_metric(y_true, y_pred):
    """Return the mean squared error of every trial as a 1-D float array.

    Row ``t`` of ``y_true`` is compared with row ``t`` of ``y_pred`` using
    sklearn's ``mean_squared_error``; one value is produced per row of ``y_true``.
    """
    n_trials = len(y_true)
    per_trial_errors = (
        mean_squared_error(y_true[t], y_pred[t]) for t in range(n_trials)
    )
    return np.fromiter(per_trial_errors, dtype=float, count=n_trials)

class PrintMSECallback(Callback):
    """Keras callback that prints the training and validation loss after each epoch.

    Losses missing from ``logs`` (e.g. ``val_loss`` when ``fit`` runs without
    validation data) are printed as ``n/a`` instead of raising ``TypeError``.
    """

    @staticmethod
    def _fmt(value):
        # Format a loss with 4 decimals; None means the metric was not logged.
        return 'n/a' if value is None else f'{value:.4f}'

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        train_loss = self._fmt(logs.get('loss'))
        val_loss = self._fmt(logs.get('val_loss'))
        print(f'Epoch {epoch + 1}, Train Loss: {train_loss}, Val Loss: {val_loss}')

def train_with_combinations(data, labels, animal_ids, k, m):
    """Pretrain models on combinations of animals and keep the one with the lowest validation loss.

    For every combination of ``k - 1`` animals drawn from ``animal_ids``, the trials of
    those animals are split with ``m``-fold ``KFold``. A fresh ``build_model`` network
    is trained on each fold's training trials and validated on its held-out trials.

    Parameters
    ----------
    data : np.ndarray
        Inputs, one trial per row (indexed along axis 0).
    labels : np.ndarray
        Targets aligned with ``data`` along axis 0.
    animal_ids : array-like
        Animal ID of each trial, aligned with ``data`` along axis 0.
    k : int
        Total number of animals; combinations of ``k - 1`` animals are enumerated.
    m : int
        Number of KFold splits per combination.

    Returns
    -------
    best_model, history_best, min_val_loss
        The model with the lowest (minimum-over-epochs) validation loss across all
        combinations and folds, its ``History.history`` dict, and that loss.

    Raises
    ------
    ValueError
        If fewer than ``k - 1`` distinct animals are present (no combination exists).
    RuntimeError
        If no fold produced a validation loss below ``inf`` (e.g. all NaN).
    """
    animal_ids = np.asarray(animal_ids)
    unique_animals = np.unique(animal_ids)
    if len(unique_animals) < k - 1:
        raise ValueError(
            f"need at least k-1={k - 1} distinct animals for pretraining, got {len(unique_animals)}"
        )

    best_model = None
    min_val_loss = float('inf')
    history_best = None

    # NOTE: when called from cross_validate_animals with k = total number of animals,
    # animal_ids already excludes the held-out animal, so there is exactly one combination.
    for train_animals in combinations(unique_animals, k - 1):
        train_indices = np.where(np.isin(animal_ids, train_animals))[0]
        X_train, y_train = data[train_indices], labels[train_indices]

        kf = KFold(n_splits=m)
        # KFold.split yields (train_idx, test_idx) pairs; enumerate supplies the fold number.
        for fold_idx, (train_index, test_index) in enumerate(kf.split(X_train)):
            print(f"Pretraining step on {train_animals} ...  fold {fold_idx}")
            X_train_fold, X_val_fold = X_train[train_index], X_train[test_index]
            y_train_fold, y_val_fold = y_train[train_index], y_train[test_index]

            model = build_model((X_train_fold.shape[1],))
            early_stopping = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
            print_mse_callback = PrintMSECallback()
            history = model.fit(
                X_train_fold, y_train_fold,
                batch_size=32, epochs=100,
                validation_data=(X_val_fold, y_val_fold),
                callbacks=[early_stopping, print_mse_callback],
                verbose=0, shuffle=True,
            )

            val_loss = min(history.history['val_loss'])
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                best_model = model
                history_best = history.history

    if best_model is None:
        raise RuntimeError("pretraining produced no finite validation loss")

    print("Best model-0 found! Proceeding to fine-tuning ...")

    return best_model, history_best, min_val_loss

def fine_tune_and_test(animal, best_model, X_train, y_train, m):
    """Fine-tune a pretrained model on one animal with ``m``-fold cross-validation.

    Every fold starts from the same pretrained weights (restored before each fold),
    is fine-tuned on that fold's training trials and evaluated on its held-out trials
    with per-trial MSE (``compute_metric``).

    Parameters
    ----------
    animal : hashable
        ID of the animal being fine-tuned on (used for logging only).
    best_model : keras.Model
        Pretrained model. It is modified in place; on return it holds the weights
        of the fold with the lowest mean per-trial MSE.
    X_train, y_train : np.ndarray
        Inputs and targets of ``animal``'s trials, aligned along axis 0.
    m : int
        Number of KFold splits.

    Returns
    -------
    min_mse : float
        Lowest mean per-trial MSE over the folds (``inf`` if no fold beat it).
    best_fold_history : dict or None
        ``History.history`` of the fold that achieved ``min_mse``.
    all_mse_vectors : list
        Per-trial MSE of every held-out trial, concatenated over folds.
    """
    kf = KFold(n_splits=m)
    min_mse = float('inf')
    best_fold_history = None
    best_fold_weights = None
    all_mse_vectors = []

    # Snapshot the pretrained weights so every fold starts from the same point.
    # Without this, a later fold would start from a model already trained on its own
    # validation trials (they were training trials of an earlier fold).
    pretrained_weights = best_model.get_weights()

    # KFold.split yields (train_idx, test_idx) pairs; enumerate supplies the fold number.
    for fold_idx, (train_index, test_index) in enumerate(kf.split(X_train)):
        print(f"Fine-tuning step on {animal} ... fold {fold_idx}")
        X_train_fold, X_val_fold = X_train[train_index], X_train[test_index]
        y_train_fold, y_val_fold = y_train[train_index], y_train[test_index]

        best_model.set_weights(pretrained_weights)
        # Recompile to also reset the optimizer state (same settings as build_model).
        best_model.compile(optimizer='adam', loss='mse')

        early_stopping = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
        print_mse_callback = PrintMSECallback()
        # verbose=0: PrintMSECallback already reports the per-epoch losses.
        history = best_model.fit(
            X_train_fold, y_train_fold,
            batch_size=32, epochs=100,
            validation_data=(X_val_fold, y_val_fold),
            callbacks=[early_stopping, print_mse_callback],
            verbose=0, shuffle=True,
        )

        y_pred = best_model.predict(X_val_fold)
        mse_vector = compute_metric(y_val_fold, y_pred)
        fold_mse = np.mean(mse_vector)

        if fold_mse < min_mse:
            min_mse = fold_mse
            best_fold_history = history.history
            best_fold_weights = best_model.get_weights()

        all_mse_vectors.extend(mse_vector)

    if best_fold_weights is not None:
        # Leave the caller's model holding the best fold's fine-tuned weights.
        best_model.set_weights(best_fold_weights)

    print("Best model-x found!")

    return min_mse, best_fold_history, all_mse_vectors

def cross_validate_animals(data, labels, animal_ids, k, m0, m1):
    """Leave-one-animal-out evaluation with pretraining and per-animal fine-tuning.

    For each distinct animal: pretrain on the trials of all other animals
    (``train_with_combinations``, ``m0`` folds), then fine-tune and evaluate on the
    held-out animal's trials (``fine_tune_and_test``, ``m1`` folds).

    Parameters
    ----------
    data, labels : np.ndarray
        Inputs and targets, one trial per row (aligned along axis 0).
    animal_ids : array-like
        Animal ID of each trial, aligned with ``data`` along axis 0.
        Lists are accepted and converted to an array.
    k : int
        Total number of animals, forwarded to ``train_with_combinations``.
    m0, m1 : int
        Number of folds for pretraining and fine-tuning respectively.

    Returns
    -------
    dict
        Keyed ``'animal_<id>'``. Each value holds the pretraining history and loss,
        the fine-tuning history and loss, the model weights after fine-tuning, and
        the per-trial MSE from the fine-tuning folds.
    """
    # Accept lists too: the ID array is indexed with an integer array below.
    animal_ids = np.asarray(animal_ids)
    results = {}

    for animal in np.unique(animal_ids):
        print(f"Processing animal {animal} as the test animal...")
        train_indices = np.where(animal_ids != animal)[0]
        test_indices = np.where(animal_ids == animal)[0]

        # Stage 1: pretrain on every other animal -> best pretrained model (Mdl-0).
        X_train, y_train = data[train_indices], labels[train_indices]
        best_model, history_best, val_loss_step1 = train_with_combinations(
            X_train, y_train, animal_ids[train_indices], k, m0)

        # Stage 2: fine-tune and evaluate on the held-out animal -> fine-tuned model (Mdl-x).
        X_train_ft, y_train_ft = data[test_indices], labels[test_indices]
        val_loss_step2, best_fold_history, mse_vector = fine_tune_and_test(
            animal, best_model, X_train_ft, y_train_ft, m1)

        results[f'animal_{animal}'] = {
            'History_BestMdl-0': history_best,
            'AverageLoss_BestMdl-0': val_loss_step1,
            'History_BestMdl-x': best_fold_history,
            'AverageLoss_BestMdl-x': val_loss_step2,
            # Weights as left on best_model once fine_tune_and_test has returned.
            'Weights_BestMdl-x': best_model.get_weights(),
            'TrialMSE_Mdl-x': mse_vector,
        }  # lumped result saving

    return results

### **Run and save**

In [None]:
# Run cross-validation and save results
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation and shuffling
# repeat across runs (some GPU kernels may still be nondeterministic).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

k = 12  # number of total animals
m0 = 5  # number of folds used in pretraining
m1 = 5  # number of folds used in fine-tuning

# NOTE(review): the .mat files are written to the current working directory, not to
# root_dir; in Colab that directory is usually not on Drive — consider prefixing root_dir.
for i, (stim_labels, spon_labels) in enumerate(zip(stim_labels_list, spon_labels_list), 1):
    stim_results = cross_validate_animals(stim_pupil, stim_labels, stim_animal_ids, k, m0, m1)
    scipy.io.savemat(f'stim_results_{i}.mat', {f'stim_results_{i}': stim_results})

    print(f"Results for stim_pupil -> stim_latent_space_{i} saved.")

    spon_results = cross_validate_animals(spon_pupil, spon_labels, spon_animal_ids, k, m0, m1)
    scipy.io.savemat(f'spon_results_{i}.mat', {f'spon_results_{i}': spon_results})

    print(f"Results for spon_pupil -> spon_latent_space_{i} saved.")