In [1]:
import os
import json
import torch
from torch import nn
import argparse
import pandas as pd
import numpy as np
from src.utils.basic_funcs import set_seed
from src.models import neural_network as net
from src.analysis import ann as ann
from torch.utils.data import Dataset, DataLoader
from src.utils import basic_funcs as basic
import math
import copy
from tqdm.auto import tqdm
from scipy import stats
import math

# Helpers

In [2]:
def numpy_to_python(obj):
    """Convert numpy objects to Python native types for JSON serialization."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, dict):
        return {k: numpy_to_python(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [numpy_to_python(item) for item in obj]
    return obj
    
def filter_participant_data(df, participant, task_section):
    """
    Filter participant data by task section.
    """
    return df.loc[
        (df['participant'] == participant) & (df['task_section'] == task_section),
        ['index', 'feature_idx', 'feat_val', 'noisy_feedback_value', 'stimID','test_trial']
    ].reset_index(drop=True)

def adjust_indices(participant_data, offset):
    """
    Adjust the indices of participant data by the specified offset.
    """
    participant_data['index'] -= offset
    return participant_data.reset_index(drop=True)

def create_inputs_matrix(participant_data, n_stim_per_task):
    """
    Create an inputs matrix with one-hot encoded stimulus IDs.
    """
    length = participant_data.shape[0]
    inputs = np.zeros((length, n_stim_per_task * 2))
    for index, row in participant_data.iterrows():
        inputs[index, int(row['stimID'])] = 1
    return inputs

def process_raw_inputs_and_labels(participant_data, n_stim_per_task, task_idx):
    """
    Process raw inputs and labels for a given task.
    """
    unique_inputs = participant_data['stimID'].unique().astype(int)
    raw_inputs = np.full((n_stim_per_task, n_stim_per_task * 2), np.nan, dtype=np.float32)
    raw_labels = np.full((4, n_stim_per_task), np.nan, dtype=np.float32)

    for idx, stim_id in enumerate(unique_inputs):
        feat1 = participant_data.loc[
            (participant_data['stimID'] == stim_id) & (participant_data['feature_idx'] == 0), 'feat_val'
        ].unique()
        feat2 = participant_data.loc[
            (participant_data['stimID'] == stim_id) & (participant_data['feature_idx'] == 1), 'feat_val'
        ].unique()
        raw_labels[0, idx] = np.cos(feat1)[0]
        raw_labels[1, idx] = np.sin(feat1)[0]
        raw_labels[2, idx] = np.cos(feat2)[0]
        raw_labels[3, idx] = np.sin(feat2)[0]

        input_skeleton = np.zeros((n_stim_per_task * 2))
        input_skeleton[stim_id] = 1
        raw_inputs[idx, :] = input_skeleton

    return raw_inputs, raw_labels

def assemble_dataset(participant_data, inputs, label_cos, label_sin):
    """
    Assemble the dataset dictionary for a task.
    """
    return {
        'index': participant_data['index'].values,
        'stim_index': participant_data['stimID'].values,
        'input': inputs,
        'feature_probe': participant_data['feature_idx'].values,
        'test_stim': participant_data['test_trial'].values,
        'label_x': label_cos,
        'label_y': label_sin,
    }
    
def get_datasets(df, participant, task_parameters):
    """
    Main function to get datasets and process raw inputs and labels.
    """
    # Filter data for each task section
    participant_training_A1 = filter_participant_data(df, participant, 'A1')
    participant_training_B = filter_participant_data(df, participant, 'B')
    participant_training_A2 = filter_participant_data(df, participant, 'A2')

    # Adjust indices for B and C
    A_length = len(participant_training_A1)
    B_length = len(participant_training_B)
    participant_training_B = adjust_indices(participant_training_B, A_length)
    participant_training_A2 = adjust_indices(participant_training_A2, A_length + B_length)

    # Create inputs matrices
    A1_inputs = create_inputs_matrix(participant_training_A1, task_parameters['nStim_perTask'])
    B_inputs = create_inputs_matrix(participant_training_B, task_parameters['nStim_perTask'])
    A2_inputs = create_inputs_matrix(participant_training_A2, task_parameters['nStim_perTask'])

    # Process raw inputs and labels
    raw_inputs = np.full((3, task_parameters['nStim_perTask'], task_parameters['nStim_perTask'] * 2), np.nan, dtype=np.float32)
    raw_labels = np.full((3, 4, task_parameters['nStim_perTask']), np.nan, dtype=np.float32)

    raw_inputs[0], raw_labels[0] = process_raw_inputs_and_labels(participant_training_A1, task_parameters['nStim_perTask'], 0)
    raw_inputs[1], raw_labels[1] = process_raw_inputs_and_labels(participant_training_B, task_parameters['nStim_perTask'], 1)
    raw_inputs[2], raw_labels[2] = process_raw_inputs_and_labels(participant_training_A2, task_parameters['nStim_perTask'], 2)

    # Assemble datasets
    dataset_A1 = assemble_dataset(participant_training_A1, A1_inputs, np.cos(participant_training_A1['feat_val'].values), np.sin(participant_training_A1['feat_val'].values))
    dataset_B = assemble_dataset(participant_training_B, B_inputs, np.cos(participant_training_B['feat_val'].values), np.sin(participant_training_B['feat_val'].values))
    dataset_A2 = assemble_dataset(participant_training_A2, A2_inputs, np.cos(participant_training_A2['feat_val'].values), np.sin(participant_training_A2['feat_val'].values))

    return dataset_A1, dataset_B, dataset_A2, raw_inputs, raw_labels

class CreateParticipantDataset(Dataset):
    """PyTorch Dataset for participant data."""
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset['index'])

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {key: self.dataset[key][idx] for key in self.dataset}
        if self.transform:
            sample = self.transform(sample)
        return sample

def compute_accuracy(predictions, ground_truth):
    """Compute accuracy between predictions and ground truth in radians."""
    predictions = np.asarray(predictions)
    ground_truth = np.asarray(ground_truth)
    wrapped_difference = basic.wrap_to_pi(predictions - ground_truth)
    normalized_error = np.abs(wrapped_difference) / np.pi
    return 1 - normalized_error

def batch_to_torch(numpy_version):
    """Convert numpy batch to torch tensor."""
    return numpy_version.type(torch.FloatTensor)

# Data

In [3]:
# Set random seed
set_seed(2024)
condition_name = "rich_50" # Condition to run (e.g., rich_10, rich_50, rich_200)
base_folder='./'


In [4]:
# Setup paths
data_folder = os.path.join(base_folder, 'data')
config_path = os.path.join(base_folder, 'src', 'models', 'ann_experiments.json')

In [5]:
# Load settings and find specified condition
settings = json.load(open(config_path, 'r'))

condition = next((c for c in settings['conditions'] if c['name'] == condition_name), None)
if not condition:
    raise ValueError(f"Condition '{condition_name}' not found in settings")

condition

{'name': 'rich_50', 'gamma': 0.001, 'dim_hidden': 50}

In [6]:
# Load participant data
df = pd.read_csv(os.path.join(data_folder, 'participants', 'trial_df.csv'))

df.loc[df['task_section']=='B','test_trial']=0

df = df.loc[(df['task_section']=='A1') | 
                 (df['task_section']=='B') | 
                 (df['task_section']=='A2'), :] # remove debrief trials from analysis

participants = df['participant'].unique()

df


Unnamed: 0,participant,index,task_section,feature_idx,stimID,taskID,feat_val,noisy_feedback_value,resp_reactiontime,dial_resp,...,block,regime,accuracy,study,A_rule,B_rule,rule_applied,test_stim_B,test_stim_A,test_trial
0,study1_same_sub9,0,A1,0.0,2.0,0.0,0.900836,0.910773,12.212,4.886953,...,0,study1_same,0.268820,1,-1.901668,-1.901668,,5.0,9.0,0
1,study1_same_sub9,1,A1,1.0,2.0,0.0,5.282353,5.264300,8.746,3.202714,...,0,study1_same,0.338030,1,-1.901668,-1.901668,2.291941,5.0,9.0,0
2,study1_same_sub9,2,A1,0.0,0.0,0.0,5.849861,5.946899,10.087,3.227958,...,0,study1_same,0.165422,1,-1.901668,-1.901668,,5.0,9.0,0
3,study1_same_sub9,3,A1,1.0,0.0,0.0,3.948193,4.032603,8.283,2.462548,...,0,study1_same,0.527105,1,-1.901668,-1.901668,2.798835,5.0,9.0,0
4,study1_same_sub9,4,A1,0.0,11.0,0.0,4.051682,4.148304,4.658,3.951195,...,0,study1_same,0.968014,1,-1.901668,-1.901668,,5.0,9.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
117357,study1_far_sub57,355,A2,1.0,0.0,0.0,1.375241,,3.314,4.466989,...,29,study1_far,0.015866,1,1.788678,-1.352915,-1.351399,7.0,0.0,1
117358,study1_far_sub57,356,A2,0.0,9.0,0.0,4.865064,4.963278,2.577,2.791894,...,29,study1_far,0.340090,1,1.788678,-1.352915,,7.0,0.0,0
117359,study1_far_sub57,357,A2,1.0,9.0,0.0,0.370556,,1.249,3.222413,...,29,study1_far,0.092226,1,1.788678,-1.352915,-1.740865,7.0,0.0,0
117360,study1_far_sub57,358,A2,0.0,5.0,0.0,5.273465,5.240881,2.183,0.693492,...,29,study1_far,0.457851,1,1.788678,-1.352915,,7.0,0.0,0


In [7]:
participants.shape

(305,)

# Setup

In [8]:
# Setup parameters
task_parameters = {
        "nStim_perTask": 6,
        "schedules": ['same', 'near', 'far'],
        "schedule_names": ['same rule', 'near rule', 'far rule']
    }

# Network parameters
dim_input = task_parameters['nStim_perTask'] * 2
dim_hidden = condition['dim_hidden']
dim_output = 4  # 2 dimensions for each feature
network_params = [dim_input, dim_hidden, dim_output]


# Training parameters - convert to list format expected by neural_network.py
training_params = [
    participants,  # participants list
    settings['n_phase'],  # n_phase
    settings['n_epochs'],  # n_epochs
    settings['n_epochs'] * (task_parameters['nStim_perTask']*2) * 10,  # n_train_trials
    settings['shuffle'],  # shuffle
    settings['batch_size'],  # batch_size
    condition['gamma'],  # gamma
    settings['learning_rate'],  # learning rate
]


In [9]:
# Setup simulation folder if not existent yet
sim_folder = os.path.join(data_folder, 'any_network', condition_name)
os.makedirs(sim_folder, exist_ok=True)

In [10]:
# Save settings of the run
settings_to_save = {
    "condition": condition,
    "training_params": {
        "participants": ann.numpy_to_python(participants),
        "n_phase": settings['n_phase'],
        "n_epochs": settings['n_epochs'],
        "n_train_trials": settings['n_epochs'] * (task_parameters['nStim_perTask']*2) * 10,
        "shuffle": settings['shuffle'],
        "batch_size": settings['batch_size'],
        "gamma": condition['gamma'],
        "lr": settings['learning_rate'],
    },
    "network_params": network_params,
    "task_parameters": task_parameters
}

# Convert numpy arrays to Python native types before saving
settings_to_save = ann.numpy_to_python(settings_to_save)
with open(os.path.join(sim_folder, 'settings.json'), 'w') as f:
    json.dump(settings_to_save, f, indent=4)

In [11]:
# Unpack parameters
dim_input, dim_hidden, dim_output = network_params
participants, n_phase, n_epochs, n_train_trials, shuffle, batch_size, gamma, lr = training_params

# add these params
do_test = 1
dosave=1

results = []

# Train

In [12]:
# for each participant
for idx_p, participant in tqdm(enumerate(participants[0:10])): # test for first participant only
    print(f'Starting participant {participant}')

    # Get participant data
    dataset_A1, dataset_B, dataset_A2, raw_inputs, raw_labels = basic.get_datasets(df, participant, task_parameters)
    
    # Order inputs by feature
    A_inputs = raw_inputs[0]
    B_inputs = raw_inputs[1] 
    A_labels_feat1 = raw_labels[0, 0:2].T
    B_labels_feat1 = raw_labels[1, 0:2].T
    ordered_indices_A = basic.get_clockwise_order(A_labels_feat1)
    ordered_indices_B = basic.get_clockwise_order(B_labels_feat1)
    ordered_inputs = np.concatenate((A_inputs[ordered_indices_A], B_inputs[ordered_indices_B]), axis=0)


0it [00:00, ?it/s]

Starting participant study1_same_sub9
Starting participant study2_near_sub43
Starting participant study1_same_sub43
Starting participant study2_near_sub57
Starting participant study1_near_sub21
Starting participant study1_near_sub35
Starting participant study2_same_sub34
Starting participant study1_near_sub34
Starting participant study2_same_sub20
Starting participant study1_same_sub56


### Inputs

In [13]:
pd.DataFrame(A_inputs)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
2,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0


### Labels

In [14]:
pd.DataFrame(A_labels_feat1)


Unnamed: 0,0,1
0,-0.937025,-0.349262
1,0.209979,-0.977706
2,0.765025,0.644001
3,0.592507,0.805566
4,-0.493809,-0.86957
5,-0.301487,0.95347


In [15]:
ordered_indices_A

array([3, 2, 1, 4, 0, 5])

In [16]:
pd.DataFrame(ordered_inputs)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
1,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
3,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
7,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
8,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


In [17]:
# Create data loaders
trainloader_A1 = DataLoader(CreateParticipantDataset(dataset_A1), batch_size=batch_size, shuffle=shuffle)
trainloader_B = DataLoader(CreateParticipantDataset(dataset_B), batch_size=batch_size, shuffle=shuffle)
trainloader_A2 = DataLoader(CreateParticipantDataset(dataset_A2), batch_size=batch_size, shuffle=shuffle)


In [18]:
# Run a complete learning cycle
"""
Runs a complete learning cycle:
A: n_epochs of training on task A stimuli
B: n_epochs of training on task B stimuli
"""
n_train_trials = n_epochs * dim_input * 10
n_phase = 3  # A, B, A

# Preallocate results matrices
results = {
    "indexes": np.full((n_phase, n_train_trials), np.nan, dtype=np.float32),
    "inputs": np.full((n_phase, n_train_trials, dim_input), np.nan, dtype=np.float32),
    "labels": np.full((n_phase, n_train_trials, 2), np.nan, dtype=np.float32),
    "test_stim": np.full((n_phase, n_train_trials), np.nan, dtype=np.float32),
    "probes": np.full((n_phase, n_train_trials), np.nan, dtype=np.float32),
    "losses": np.full((n_phase, n_train_trials), np.nan, dtype=np.float32),
    "accuracy": np.full((n_phase, n_train_trials), np.nan, dtype=np.float32),
    "predictions": np.full((n_phase, n_train_trials, dim_output), np.nan, dtype=np.float32),
    "hiddens": np.full((n_phase, n_train_trials, dim_hidden), np.nan, dtype=np.float32),
    "embeddings": np.full((n_phase, n_train_trials, dim_hidden, dim_input), np.nan, dtype=np.float32),
    "readouts": np.full((n_phase, n_train_trials, dim_output, dim_hidden), np.nan, dtype=np.float32),
}



## Network

In [19]:
class simpleLinearNet(nn.Module):
    """A simple linear neural network with one hidden layer.
    
    Architecture:
    input -> hidden layer -> output
    All layers are fully connected with no bias terms.
    """
    def __init__(self, dim_input, dim_hidden, dim_output):
        super(simpleLinearNet, self).__init__()
        self.in_hid = nn.Linear(dim_input, dim_hidden, bias=False)
        self.hid_out = nn.Linear(dim_hidden, dim_output, bias=False)
        
    def forward(self, x):
        """Forward pass through the network."""
        hid = self.in_hid(x)
        out = self.hid_out(hid)
        return out, hid

def ex_initializer_(model, gamma=1e-3,mean=0.0):
    """
    In-place Re-initialization of weights

    Args:
        model: torch.nn.Module
        PyTorch neural net model
        
        gamma: float
        Initialization scale

    Returns:
        Nothing
    """
    for name, param in model.named_parameters():
        if "weight" in name:  
            n_out, n_in = param.shape
                
            if "hid_out" in name:  # Output layer weights
                std = 1e-3
            else:  # Hidden layer weights
                std = gamma
                
            nn.init.normal_(param, mean=mean, std=std)

def ordered_sweep(network, ranked_inputs):
    """Run network on ordered inputs for interpretable results."""
    preds, hids = network(ranked_inputs)
    return preds.detach().numpy().copy(), hids.detach().numpy().copy()

In [20]:
# Define the network
network = simpleLinearNet(dim_input, dim_hidden, dim_output)

# Initialize weights
ex_initializer_(network, gamma)

optimizer = torch.optim.SGD(network.parameters(), lr=lr)
loss_function = nn.MSELoss()

In [21]:
network

simpleLinearNet(
  (in_hid): Linear(in_features=12, out_features=50, bias=False)
  (hid_out): Linear(in_features=50, out_features=4, bias=False)
)

In [22]:
# Initial pass of the network
initial_preds, initial_hiddens = ordered_sweep(network, torch.from_numpy(ordered_inputs).float())

results["preds_pre_training"] = initial_preds
results["hiddens_pre_training"] = initial_hiddens

In [23]:
pd.DataFrame(initial_preds)

Unnamed: 0,0,1,2,3
0,1.050709e-06,8.42633e-06,-5.203451e-06,-5.267597e-06
1,-2.99046e-07,8.942328e-06,-6.44604e-06,-2.22167e-06
2,5.744795e-06,5.501563e-06,-3.022177e-06,-3.02439e-06
3,-3.562974e-06,1.052336e-05,1.010967e-06,-6.78174e-06
4,1.171294e-05,2.890949e-06,-4.761073e-06,-3.515678e-06
5,-6.708689e-06,3.999088e-06,-2.433872e-06,-1.128135e-05
6,-6.846979e-06,1.498318e-05,-2.232268e-06,5.640319e-06
7,5.477206e-06,3.443527e-06,1.22248e-06,-5.853476e-07
8,-6.309034e-06,7.198665e-06,3.653147e-07,2.291165e-06
9,-5.884042e-07,-4.763183e-07,-3.460113e-06,1.25746e-05


In [24]:
pd.DataFrame(initial_hiddens)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,-0.001099,-0.001867,0.000839,-1.5e-05,0.000547,-0.000598,0.002483,0.001643,-0.00128,8.2e-05,...,0.002132,-0.000674,-0.001196,-0.000292,-9.5e-05,0.000382,-0.000453,-0.001808,-0.001155,0.000997
1,5.6e-05,0.00122,0.000316,0.000388,-0.003395,0.00051,-0.000424,-0.000273,0.000444,0.000615,...,0.000365,-0.000855,-0.001904,-0.00078,3.3e-05,-0.000516,7.3e-05,-0.001038,-0.000938,8.5e-05
2,-0.001874,-0.000335,-0.000608,2.2e-05,-0.000312,-0.001386,-8.4e-05,-0.00183,0.000106,0.00043,...,-0.000458,-0.000353,0.000877,0.000956,-0.001002,-0.000706,0.001458,-0.000365,0.000422,0.000775
3,0.000107,0.00047,0.00199,-0.001695,0.000276,0.001973,-0.001517,0.000486,-0.001194,0.000456,...,0.00014,0.001528,0.000449,0.000501,0.000195,0.000872,-0.000598,0.0005,-1.1e-05,3.5e-05
4,-0.002238,0.000742,-0.000426,-0.000222,0.001374,0.000875,-0.00012,0.001852,0.000364,0.000578,...,0.000176,-0.00022,0.000232,-0.001109,-0.000736,0.000733,-0.000893,-0.003072,0.000361,-0.000922
5,-0.002044,0.000922,0.000228,-0.000266,-0.000422,-0.000686,-0.00184,-0.000806,-0.000595,0.000505,...,-0.001072,0.000289,5.8e-05,0.001678,0.001661,0.000411,-4.9e-05,0.000489,-0.001203,-0.000291
6,2.9e-05,0.001023,-0.000884,-0.00158,-0.002166,5.1e-05,2.6e-05,0.000749,-0.00098,0.000246,...,0.000643,-0.000167,-0.001585,-0.0007,-0.000126,0.000181,-0.000192,-0.000947,0.000155,0.000455
7,-0.000696,0.000968,-0.000193,0.001268,-0.001029,-0.001047,-0.001812,-1.1e-05,-0.001542,-0.001274,...,0.000634,2e-06,-0.001137,-0.000761,0.001435,-0.000357,-0.000146,0.001629,0.001632,0.000655
8,-0.001891,-0.000977,-0.000832,0.000799,-0.000965,-0.001495,-0.0016,-0.000468,0.001563,0.001208,...,0.001349,-0.000393,-0.001336,0.001304,-0.001281,4e-06,0.002308,0.001004,-0.000228,-0.000332
9,-0.00169,-0.000787,0.001698,0.001076,-0.000976,0.000247,0.000619,-0.000171,0.000168,-0.000395,...,-0.000113,-0.000959,0.000938,-0.000533,-0.000664,-0.001332,0.002711,0.001121,-0.001026,0.001627


### Training function

In [25]:
def train_participant_schedule(network, trainloader, n_epochs, loss_function, optimizer, do_update, do_test):
    """
    Train the network on x-y coordinates 
    Returns:
        tuple: Various metrics including indexes, inputs, labels, probes, losses, accuracy, predictions, hiddens, embeddings, readouts.
    """
    # Initialize storage lists
    metrics = {
        "indexes": [],
        "losses": [],
        "accuracy": [],
        "predictions": [],
        "hiddens": [],
        "embeddings": [],
        "readouts": [],
        "probes": [],
        "test_stim":[],
        "labels": [],
        "inputs": [],
    }

    for epoch in range(n_epochs):
        for batch_idx, data in enumerate(trainloader):
            # Reset gradients
            optimizer.zero_grad()

            # Extract batch data
            index = data['stim_index']
            input = batch_to_torch(data['input'])
            label_x = batch_to_torch(data['label_x'])
            label_y = batch_to_torch(data['label_y'])
            feature_probe = batch_to_torch(data['feature_probe'])
            test_stim = batch_to_torch(data['test_stim'])
            
                    
            joined_label = torch.cat((label_x.unsqueeze(1), label_y.unsqueeze(1)), dim=1)
            radians_label = math.atan2(label_x, label_y)

            # Forward pass
            out, hid = network(input)

            # Calculate loss based on feature probe
            if feature_probe == 0:
                loss = loss_function(out[:, :2], joined_label)
                pred_rads = math.atan2(out[:, 0].detach().numpy(),out[:, 1].detach().numpy())
                accuracy = compute_accuracy(pred_rads, radians_label)
                
            elif feature_probe == 1:
                loss = loss_function(out[:, 2:4], joined_label)
                pred_rads = math.atan2(out[:, 2].detach().numpy(),out[:, 3].detach().numpy())
                accuracy = compute_accuracy(pred_rads, radians_label)
                
            else:
                raise ValueError("Undefined loss setting for feature_probe.")

            # Update network if required
            if do_update == 1 and do_test==1 and test_stim.numpy() == 0:
              loss.backward()
              optimizer.step()
            elif do_update == 1 and do_test ==0:
              loss.backward()
              optimizer.step()
            elif do_update == 2 and feature_probe == 0:  # In C, only update for feature 0 
              loss.backward()
              optimizer.step()

            # Store metrics
            metrics["indexes"].append(index)
            metrics["inputs"].append(input.numpy())
            metrics["labels"].append(joined_label.numpy())
            metrics["probes"].append(feature_probe.numpy())
            metrics["test_stim"].append(test_stim.numpy())
            metrics["losses"].append(loss.item())
            metrics["accuracy"].append(accuracy)
            metrics["predictions"].append(np.expand_dims(out.detach().numpy(), axis=1))
            metrics["hiddens"].append(hid.detach().numpy())
            metrics["embeddings"].append(network.in_hid.weight.detach().numpy())
            metrics["readouts"].append(network.hid_out.weight.detach().numpy())

    # Convert lists to arrays where applicable
    metrics = {key: np.squeeze(value) for key, value in metrics.items()}
    
    return (
        metrics["indexes"],
        metrics["inputs"],
        metrics["labels"],
        metrics["probes"],
        metrics["test_stim"],
        metrics["losses"],
        metrics["accuracy"],
        metrics["predictions"],
        metrics["hiddens"],
        metrics["embeddings"],
        metrics["readouts"],
    )

# Training Phases


In [26]:
# Training Phases
phases = [
    (0, trainloader_A1, 1),
    (1, trainloader_B, 1),
    (2, trainloader_A2, 2),
]

## Train Phase A1

In [27]:
phase = 0
loader =  trainloader_A1
do_update = 1 # Controls how updates are applied (0 = no update, 1 = standard, 2 = conditional on feature_probe).

(
    results["indexes"][phase, :],
    results["inputs"][phase, :, :],
    results["labels"][phase, :, :],
    results["probes"][phase, :],
    results["test_stim"][phase, :],
    results["losses"][phase, :],
    results["accuracy"][phase, :],
    results["predictions"][phase, :, :],
    results["hiddens"][phase, :, :],
    results["embeddings"][phase, :, :, :],
    results["readouts"][phase, :, :, :],
) = train_participant_schedule(
    network, loader, n_epochs, loss_function, optimizer, do_update, do_test
)


  pred_rads = math.atan2(out[:, 0].detach().numpy(),out[:, 1].detach().numpy())
  pred_rads = math.atan2(out[:, 2].detach().numpy(),out[:, 3].detach().numpy())


In [28]:
# Post-phase ordered sweep
post_preds, post_hiddens = ordered_sweep(network, torch.from_numpy(ordered_inputs).float())
results[f"preds_post_phase_{phase}"] = post_preds
results[f"hiddens_post_phase_{phase}"] = post_hiddens

In [29]:
pd.DataFrame(results[f"preds_post_phase_{phase}"])

Unnamed: 0,0,1,2,3
0,0.591986,0.805874,-0.708005,0.706234
1,0.765126,0.643932,-0.52282,0.852373
2,0.207461,-0.976264,0.989735,0.054029
3,-0.493007,-0.870024,0.785908,-0.618343
4,-0.937956,-0.348717,0.206081,-0.978041
5,-0.303848,0.954864,-0.986702,-0.15462
6,-0.000436,0.001425,-0.00141,-0.000171
7,0.000244,0.000852,-0.000821,0.000353
8,0.000267,0.000858,-0.000881,0.000359
9,0.001867,0.000288,-9.5e-05,0.001851


In [30]:
pd.DataFrame(results[f"hiddens_post_phase_{phase}"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,-0.026242,0.047492,0.056127,0.169849,-0.32081,-0.141765,0.187153,0.119936,-0.043548,0.145444,...,0.099333,-0.131288,-0.145031,-0.126503,0.019012,-0.065352,-0.010684,0.045975,-0.239439,0.131282
1,0.053478,-0.005502,0.023143,0.175263,-0.344569,-0.133468,0.216624,0.115962,-0.033178,0.123241,...,0.117171,-0.155559,-0.122068,-0.157197,-0.003166,-0.099921,0.012425,0.078904,-0.17273,0.151565
2,0.253063,-0.211153,-0.13725,-0.070964,0.125848,0.093585,-0.00992,-0.095244,0.055665,-0.148952,...,0.005656,0.001018,0.159754,-0.009441,-0.081132,-0.078125,0.079472,0.070946,0.344571,-0.009555
3,0.065177,-0.076661,-0.068891,-0.165286,0.304497,0.14379,-0.165612,-0.116422,0.044693,-0.153498,...,-0.08494,0.116924,0.153393,0.108467,-0.029987,0.047814,0.02107,-0.029686,0.267424,-0.116578
4,-0.162844,0.086907,0.025392,-0.165466,0.338214,0.111631,-0.24253,-0.100265,0.018321,-0.07778,...,-0.13341,0.174098,0.075042,0.183528,0.034075,0.138752,-0.044348,-0.121218,0.060543,-0.168103
5,-0.285882,0.218819,0.139732,0.066274,-0.09041,-0.095301,-0.022227,0.064916,-0.051473,0.150787,...,-0.020644,0.019011,-0.152721,0.043673,0.088119,0.084347,-0.083162,-0.087035,-0.348607,-0.009081
6,2.9e-05,0.001023,-0.000884,-0.00158,-0.002166,5.1e-05,2.6e-05,0.000749,-0.00098,0.000246,...,0.000643,-0.000167,-0.001585,-0.0007,-0.000126,0.000181,-0.000192,-0.000947,0.000155,0.000455
7,-0.000696,0.000968,-0.000193,0.001268,-0.001029,-0.001047,-0.001812,-1.1e-05,-0.001542,-0.001274,...,0.000634,2e-06,-0.001137,-0.000761,0.001435,-0.000357,-0.000146,0.001629,0.001632,0.000655
8,-0.001891,-0.000977,-0.000832,0.000799,-0.000965,-0.001495,-0.0016,-0.000468,0.001563,0.001208,...,0.001349,-0.000393,-0.001336,0.001304,-0.001281,4e-06,0.002308,0.001004,-0.000228,-0.000332
9,-0.00169,-0.000787,0.001698,0.001076,-0.000976,0.000247,0.000619,-0.000171,0.000168,-0.000395,...,-0.000113,-0.000959,0.000938,-0.000533,-0.000664,-0.001332,0.002711,0.001121,-0.001026,0.001627


## Train Phase B

In [31]:
phase = 1
loader =  trainloader_B
do_update = 1 # Controls how updates are applied (0 = no update, 1 = standard, 2 = conditional on feature_probe).

(
    results["indexes"][phase, :],
    results["inputs"][phase, :, :],
    results["labels"][phase, :, :],
    results["probes"][phase, :],
    results["test_stim"][phase, :],
    results["losses"][phase, :],
    results["accuracy"][phase, :],
    results["predictions"][phase, :, :],
    results["hiddens"][phase, :, :],
    results["embeddings"][phase, :, :, :],
    results["readouts"][phase, :, :, :],
) = train_participant_schedule(
    network, loader, n_epochs, loss_function, optimizer, do_update, do_test
)

  pred_rads = math.atan2(out[:, 0].detach().numpy(),out[:, 1].detach().numpy())
  pred_rads = math.atan2(out[:, 2].detach().numpy(),out[:, 3].detach().numpy())


In [32]:
# Post-phase ordered sweep
post_preds, post_hiddens = ordered_sweep(network, torch.from_numpy(ordered_inputs).float())
results[f"preds_post_phase_{phase}"] = post_preds
results[f"hiddens_post_phase_{phase}"] = post_hiddens

In [33]:
pd.DataFrame(results[f"preds_post_phase_{phase}"])

Unnamed: 0,0,1,2,3
0,0.712918,0.983872,-0.867005,0.851413
1,0.956137,0.763754,-0.61269,1.059303
2,0.359081,-1.262264,1.299389,0.165426
3,-0.576346,-1.073404,0.976168,-0.729589
4,-1.210592,-0.373422,0.188087,-1.253103
5,-0.485371,1.242888,-1.303052,-0.29552
6,0.620901,0.783886,-0.682487,0.7309
7,0.869502,-0.493961,0.618106,0.786075
8,0.374456,-0.927277,0.972703,0.231929
9,-0.425104,-0.905161,0.831583,-0.555375


In [34]:
pd.DataFrame(results[f"hiddens_post_phase_{phase}"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,-0.026242,0.047492,0.056127,0.169849,-0.32081,-0.141765,0.187153,0.119936,-0.043548,0.145444,...,0.099333,-0.131288,-0.145031,-0.126503,0.019012,-0.065352,-0.010684,0.045975,-0.239439,0.131282
1,0.053478,-0.005502,0.023143,0.175263,-0.344569,-0.133468,0.216624,0.115962,-0.033178,0.123241,...,0.117171,-0.155559,-0.122068,-0.157197,-0.003166,-0.099921,0.012425,0.078904,-0.17273,0.151565
2,0.253063,-0.211153,-0.13725,-0.070964,0.125848,0.093585,-0.00992,-0.095244,0.055665,-0.148952,...,0.005656,0.001018,0.159754,-0.009441,-0.081132,-0.078125,0.079472,0.070946,0.344571,-0.009555
3,0.065177,-0.076661,-0.068891,-0.165286,0.304497,0.14379,-0.165612,-0.116422,0.044693,-0.153498,...,-0.08494,0.116924,0.153393,0.108467,-0.029987,0.047814,0.02107,-0.029686,0.267424,-0.116578
4,-0.162844,0.086907,0.025392,-0.165466,0.338214,0.111631,-0.24253,-0.100265,0.018321,-0.07778,...,-0.13341,0.174098,0.075042,0.183528,0.034075,0.138752,-0.044348,-0.121218,0.060543,-0.168103
5,-0.285882,0.218819,0.139732,0.066274,-0.09041,-0.095301,-0.022227,0.064916,-0.051473,0.150787,...,-0.020644,0.019011,-0.152721,0.043673,0.088119,0.084347,-0.083162,-0.087035,-0.348607,-0.009081
6,-0.009257,0.034482,0.040479,0.138709,-0.271529,-0.115428,0.157772,0.099575,-0.035336,0.116823,...,0.084029,-0.111819,-0.118279,-0.110028,0.012869,-0.058124,-0.00582,0.042946,-0.187886,0.111367
7,0.23258,-0.163411,-0.090455,0.050791,-0.119294,-0.00773,0.128359,0.008533,0.018492,-0.042694,...,0.078448,-0.096656,0.045336,-0.112088,-0.063393,-0.118488,0.068057,0.108272,0.161432,0.088632
8,0.209406,-0.166806,-0.106365,-0.043915,0.064477,0.065333,0.013725,-0.055161,0.040831,-0.108672,...,0.017375,-0.015222,0.113067,-0.027912,-0.066549,-0.067324,0.06591,0.067289,0.258903,0.007319
9,0.061666,-0.072648,-0.061012,-0.129341,0.244079,0.115487,-0.1288,-0.098136,0.039544,-0.127874,...,-0.066442,0.089598,0.129702,0.084278,-0.028798,0.030307,0.024174,-0.018599,0.226109,-0.090189


## Train Phase A2

In [35]:
phase = 2
loader =  trainloader_A2
do_update = 2  # Controls how updates are applied (0 = no update, 1 = standard, 2 = conditional on feature_probe).

(
    results["indexes"][phase, :],
    results["inputs"][phase, :, :],
    results["labels"][phase, :, :],
    results["probes"][phase, :],
    results["test_stim"][phase, :],
    results["losses"][phase, :],
    results["accuracy"][phase, :],
    results["predictions"][phase, :, :],
    results["hiddens"][phase, :, :],
    results["embeddings"][phase, :, :, :],
    results["readouts"][phase, :, :, :],
) = train_participant_schedule(
    network, loader, n_epochs, loss_function, optimizer, do_update, do_test
)

  pred_rads = math.atan2(out[:, 0].detach().numpy(),out[:, 1].detach().numpy())
  pred_rads = math.atan2(out[:, 2].detach().numpy(),out[:, 3].detach().numpy())


In [36]:
# Post-phase ordered sweep
post_preds, post_hiddens = ordered_sweep(network, torch.from_numpy(ordered_inputs).float())
results[f"preds_post_phase_{phase}"] = post_preds
results[f"hiddens_post_phase_{phase}"] = post_hiddens

In [37]:
pd.DataFrame(results[f"preds_post_phase_{phase}"])

Unnamed: 0,0,1,2,3
0,0.592506,0.805566,-0.79923,0.784924
1,0.765026,0.644001,-0.58147,0.960422
2,0.209978,-0.977704,1.149105,0.106495
3,-0.493809,-0.86957,0.891572,-0.680684
4,-0.937027,-0.349262,0.212223,-1.117588
5,-0.301488,0.953472,-1.143899,-0.216881
6,0.55769,0.697243,-0.682487,0.7309
7,0.737688,-0.417046,0.618106,0.786075
8,0.298254,-0.805157,0.972703,0.231929
9,-0.391848,-0.799949,0.831583,-0.555375


In [38]:
pd.DataFrame(results[f"hiddens_post_phase_{phase}"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,40,41,42,43,44,45,46,47,48,49
0,-0.024378,0.043548,0.051762,0.156744,-0.295575,-0.130796,0.17262,0.110429,-0.040127,0.13414,...,0.091636,-0.121076,-0.133633,-0.116465,0.017496,-0.060324,-0.009823,0.042268,-0.220852,0.121037
1,0.043951,-0.000946,0.023658,0.160758,-0.315924,-0.122995,0.19718,0.107227,-0.031101,0.114889,...,0.106355,-0.141591,-0.114012,-0.142846,-0.00139,-0.089413,0.009833,0.070369,-0.16365,0.137983
2,0.217532,-0.183142,-0.119936,-0.067509,0.121353,0.085648,-0.016244,-0.087335,0.049594,-0.133389,...,0.000739,0.006385,0.142959,-0.002286,-0.070292,-0.064504,0.068527,0.058529,0.304921,-0.013641
3,0.057664,-0.068557,-0.062182,-0.153158,0.281499,0.132804,-0.15384,-0.106732,0.040673,-0.140958,...,-0.078965,0.108805,0.140567,0.100796,-0.02686,0.04563,0.018503,-0.028467,0.244483,-0.108218
4,-0.137959,0.07064,0.017947,-0.150383,0.307503,0.102897,-0.217968,-0.093491,0.01837,-0.074473,...,-0.119193,0.156127,0.07256,0.164414,0.027765,0.121435,-0.03707,-0.106616,0.065955,-0.151077
5,-0.245785,0.187639,0.120795,0.064628,-0.089696,-0.088183,-0.012497,0.057576,-0.045164,0.135055,...,-0.013708,0.011125,-0.135602,0.033782,0.076052,0.067974,-0.07086,-0.07226,-0.30674,-0.002646
6,-0.009257,0.034482,0.040479,0.138709,-0.271529,-0.115428,0.157772,0.099575,-0.035336,0.116823,...,0.084029,-0.111819,-0.118279,-0.110028,0.012869,-0.058124,-0.00582,0.042946,-0.187886,0.111367
7,0.23258,-0.163411,-0.090455,0.050791,-0.119294,-0.00773,0.128359,0.008533,0.018492,-0.042694,...,0.078448,-0.096656,0.045336,-0.112088,-0.063393,-0.118488,0.068057,0.108272,0.161432,0.088632
8,0.209406,-0.166806,-0.106365,-0.043915,0.064477,0.065333,0.013725,-0.055161,0.040831,-0.108672,...,0.017375,-0.015222,0.113067,-0.027912,-0.066549,-0.067324,0.06591,0.067289,0.258903,0.007319
9,0.061666,-0.072648,-0.061012,-0.129341,0.244079,0.115487,-0.1288,-0.098136,0.039544,-0.127874,...,-0.066442,0.089598,0.129702,0.084278,-0.028798,0.030307,0.024174,-0.018599,0.226109,-0.090189


In [39]:
# dataframe ob indexes in results
pd.DataFrame(results['indexes'])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11990,11991,11992,11993,11994,11995,11996,11997,11998,11999
0,3.0,3.0,10.0,10.0,0.0,0.0,6.0,6.0,4.0,4.0,...,0.0,0.0,3.0,3.0,4.0,4.0,10.0,10.0,9.0,9.0
1,2.0,2.0,1.0,1.0,11.0,11.0,8.0,8.0,5.0,5.0,...,1.0,1.0,8.0,8.0,7.0,7.0,11.0,11.0,5.0,5.0
2,3.0,3.0,4.0,4.0,9.0,9.0,10.0,10.0,0.0,0.0,...,3.0,3.0,9.0,9.0,0.0,0.0,10.0,10.0,6.0,6.0


## Results

In [45]:
pd.DataFrame(participants).T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,295,296,297,298,299,300,301,302,303,304
0,study1_same_sub9,study2_near_sub43,study1_same_sub43,study2_near_sub57,study1_near_sub21,study1_near_sub35,study2_same_sub34,study1_near_sub34,study2_same_sub20,study1_same_sub56,...,study2_same_sub38,study2_near_sub72,study1_same_sub66,study2_near_sub66,study1_same_sub4,study2_far_sub46,study2_far_sub52,study2_same_sub1,study1_far_sub43,study1_far_sub57


In [46]:
participant_results['participant'] = participant

NameError: name 'participant_results' is not defined

In [None]:
participant_results

## save and clean

In [None]:
# Save results if requested
if dosave:
    file_path = f"{sim_folder}/sim_{participant}.npz"
    np.savez_compressed(file_path, **participant_results)

# Cleanup
del participant_results