In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt

import multiprocessing as mp
import pickle 
import warnings 
warnings.filterwarnings('ignore')

from imports import*
from utils import *
from logistic_regression import *
from rnn import *

import torch
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cpu


In [2]:
# Number of simulated agents
num_of_agents = 40

# Number of blocks (simulation files) per agent
num_of_block = 3

# Number of trials per block
num_of_trials = 200

# Cross-validation folds: each row is a cyclic rotation of the block indices,
# unpacked downstream as (train, val, test).
array = np.arange(num_of_block)
cv = np.array([np.roll(array, shift) for shift in range(num_of_block)])


def bce(y_hat, y_true):
    """
    Summed binary cross-entropy between predictions and targets.

    A small epsilon (1e-7) is added inside both logs so that predictions of
    exactly 0 or 1 do not produce log(0).

    Args:
        y_hat: predicted probabilities (scalar or array).
        y_true: targets of the same shape as `y_hat`.

    Returns:
        The negated sum of the element-wise log-likelihood terms.
    """
    eps = 1e-7
    positive_term = y_true * np.log(y_hat + eps)
    negative_term = (1 - y_true) * np.log(1 - y_hat + eps)
    return -np.sum(positive_term + negative_term)

In [3]:
# upload data
def upload_data(num_of_block, num_of_agents, model, data_dir='../data'):
    """
    Load every agent's CSV for every block of one generative model.

    Files are read from
    f'{data_dir}/{model}/{model}_agent_{agent}_sim_{sim}.csv'
    with `sim` in 1..num_of_block and `agent` in 1..num_of_agents.

    Args:
        num_of_block: number of blocks (simulation indices) to load.
        num_of_agents: number of agents per block.
        model: generative-model name (sub-directory and file prefix).
        data_dir: root data directory (default '../data', the original path).

    Returns:
        list of length `num_of_block`; element b is a list of `num_of_agents`
        DataFrames (agent order 1..num_of_agents) for simulation b+1.
        Previously exactly three blocks were returned regardless of
        `num_of_block`; now every requested block is returned.

    Raises:
        FileNotFoundError: if any expected CSV is missing.
    """
    return [
        [
            pd.read_csv(f'{data_dir}/{model}/{model}_agent_{agent}_sim_{sim}.csv')
            for agent in range(1, num_of_agents + 1)
        ]
        for sim in range(1, num_of_block + 1)
    ]

In [None]:
def compute_weight_changes_between_states(network_state_t, network_state_t_plus_1):
    """
    Flatten and concatenate the change in every weight parameter between two networks.

    Only parameters whose name contains 'weight' are compared (biases are
    skipped). For each such parameter of `network_state_t`, in
    named_parameters() order, the difference (t+1 minus t) is flattened and
    all differences are joined end to end.

    Args:
        network_state_t: torch.nn.Module holding the weights at step t.
        network_state_t_plus_1: torch.nn.Module holding the weights at step t+1;
            it must expose every weight-parameter name that `network_state_t` has.

    Returns:
        numpy.ndarray: a single 1-D vector (not a matrix) of all weight changes.
        Shape (0,) when `network_state_t` has no weight parameters.

    Raises:
        KeyError: if a weight parameter of `network_state_t` is missing from
            `network_state_t_plus_1`.
    """
    params_t = dict(network_state_t.named_parameters())
    params_t_plus_1 = dict(network_state_t_plus_1.named_parameters())

    weight_changes_list = [
        (params_t_plus_1[name].detach() - param.detach()).flatten().cpu().numpy()
        for name, param in params_t.items()
        if 'weight' in name
    ]

    # np.concatenate raises on an empty list, so return an explicit empty vector.
    if not weight_changes_list:
        return np.empty(0)
    # Concatenate (not stack): weight matrices generally differ in size
    # (e.g. a GRU's input->hidden vs hidden->hidden weights).
    return np.concatenate(weight_changes_list)

def _weight_deltas_from_state_dicts(state_t, state_t_plus_1):
    """Flatten and concatenate (t+1 minus t) for every 'weight' entry of two state dicts."""
    deltas = [
        (state_t_plus_1[name] - tensor).detach().flatten().cpu().numpy()
        for name, tensor in state_t.items()
        if 'weight' in name
    ]
    return np.concatenate(deltas) if deltas else np.empty(0)


def calculate_weight_changes_rank(net, network_states_per_epoch, test_loss, window_size=5):
    """
    Rank of the matrix of epoch-to-epoch weight changes around the optimal epoch.

    The optimal epoch is the argmin of `test_loss`. For every consecutive pair
    of stored states t -> t+1 with t in
    [max(0, opt - window_size), min(len(states) - 1, opt + window_size)),
    the flattened weight change is one ROW of a (steps x n_weights) matrix,
    whose numerical rank is returned.

    Args:
        net: the trained network. It is no longer used, because the state dicts
            are diffed directly, and is kept only for call compatibility.
        network_states_per_epoch: sequence of state dicts (name -> tensor), one
            per stored epoch. Assumes index i aligns with test_loss[i]
            (TODO confirm against train_model).
        test_loss: per-epoch loss curve; whichever curve the caller passes
            (test or validation) is the one used to pick the optimal epoch.
        window_size: number of steps to include on each side of the optimum.

    Returns:
        tuple: (weight_changes_rank, optimal_epoch_idx) as Python ints. The rank
        is 0 when the window contains no step (e.g. a single stored state or
        window_size=0).
    """
    optimal_epoch_idx = int(np.argmin(test_loss))

    start_epoch = max(0, optimal_epoch_idx - window_size)
    end_epoch = min(len(network_states_per_epoch) - 1, optimal_epoch_idx + window_size)

    weight_change_rows = [
        _weight_deltas_from_state_dicts(network_states_per_epoch[t],
                                        network_states_per_epoch[t + 1])
        for t in range(start_epoch, end_epoch)
    ]

    # Empty window (or no weight tensors): nothing changed, so rank 0.
    # This also avoids np.vstack / matrix_rank failing on empty input.
    if not weight_change_rows or weight_change_rows[0].size == 0:
        return 0, optimal_epoch_idx

    # vstack, not hstack: hstack of 1-D vectors yields one long vector, whose
    # "rank" can only be 0 or 1. One row per step gives the intended matrix.
    weight_changes_matrix = np.vstack(weight_change_rows)
    return int(np.linalg.matrix_rank(weight_changes_matrix)), optimal_epoch_idx

In [5]:
# Number of agents to fit (one RNN per agent and cross-validation fold).
N = num_of_agents

# Network input/output dimensions and optimiser settings
INPUT_SIZE = 4  # 3 for the action (one-hot format) and 1 for the reward of the chosen action
OUTPUT_SIZE = 3  # probabilities of choosing each action in the next trial
LEARNING_RATE = 0.001
# Backward-compatible alias for the original (misspelled) name used by other cells.
LERANING_RATE = LEARNING_RATE

# GRU architecture and number of training epochs passed to train_model
hidden_size = 5
num_layers = 1
epochs = 1000

In [6]:
def _full_batch_loader(block_data):
    """Wrap one agent's data for one block in an unshuffled, single-batch DataLoader."""
    dataset = behavior_dataset(block_data)
    return DataLoader(dataset, shuffle=False, batch_size=len(dataset))


def _save_results(results, model, results_dir):
    """Pickle each list in `results` to f'{results_dir}/{model}_{key}.pickle' (dir created if missing)."""
    import os

    os.makedirs(results_dir, exist_ok=True)
    for key, values in results.items():
        with open(f'{results_dir}/{model}_{key}.pickle', 'wb') as handle:
            pickle.dump(values, handle, protocol=pickle.HIGHEST_PROTOCOL)


def train_rnn_for_model(all_blocks, model, select_epoch_on='test', results_dir='../results'):
    """
    Fit one GRU per (agent, cross-validation fold) and pickle the learning curves.

    For each of the first N agents and each (train, val, test) row of `cv`, a
    fresh GRU_NN is trained with train_model. The rank of the weight changes
    around the optimal epoch is then computed with
    calculate_weight_changes_rank.

    Args:
        all_blocks: list indexed by block; all_blocks[b][n] is agent n's data
            for block b (the structure returned by upload_data).
        model: generative-model name, used in the output file names.
        select_epoch_on: 'test' (default, the original behaviour) or 'val'.
            Chooses which loss curve picks the optimal epoch.
        results_dir: directory for the pickles (default '../results').

    Returns:
        (loss_train, loss_val, loss_test, ll_train, ll_val, ll_test, ranks,
        optimal_epochs): lists with one entry per fit, in agent-major order.

    Raises:
        ValueError: if `select_epoch_on` is not 'test' or 'val'.
    """
    if select_epoch_on not in ('test', 'val'):
        raise ValueError(f"select_epoch_on must be 'test' or 'val', got {select_epoch_on!r}")

    result_keys = ('loss_train', 'loss_val', 'loss_test',
                   'll_train', 'll_val', 'll_test',
                   'ranks', 'optimal_epochs')
    results = {key: [] for key in result_keys}

    for n in tqdm(range(N)):
        # Each cv row assigns the block indices to (train, val, test).
        for train, val, test in cv:
            train_loader, val_loader, test_loader = [
                _full_batch_loader(all_blocks[block][n]) for block in (train, val, test)
            ]

            rnn = GRU_NN(INPUT_SIZE, hidden_size, num_layers, OUTPUT_SIZE)
            (rnn, train_loss, train_ll, val_loss, val_ll,
             test_loss, test_ll, network_states_per_epoch) = train_model(
                rnn, train_loader, val_loader, test_loader,
                epochs=epochs, lr=LERANING_RATE)

            # NOTE(review): selecting on the test curve lets the held-out test
            # split drive model selection; select_epoch_on='val' avoids that.
            selection_loss = val_loss if select_epoch_on == 'val' else test_loss
            rank, optimal_epoch = calculate_weight_changes_rank(
                rnn, network_states_per_epoch, selection_loss)

            fold_values = (train_loss, val_loss, test_loss,
                           train_ll, val_ll, test_ll,
                           rank, optimal_epoch)
            for key, value in zip(result_keys, fold_values):
                results[key].append(value)

        print('Done agent', n)

    _save_results(results, model, results_dir)
    return tuple(results[key] for key in result_keys)

In [7]:
def plot_results(model, results_dir='../results'):
    """
    Plot the fit-averaged loss and log-likelihood curves for one model.

    Reads the per-epoch curves pickled by train_rnn_for_model from
    f'{results_dir}/{model}_{key}.pickle' for key in
    {loss, ll} x {train, val, test}. Each curve is averaged element-wise across
    fits, and the results are drawn as two panels.

    Only the per-epoch curves are averaged. The ranks/optimal_epochs pickles
    hold one scalar per fit, so averaging them with zip(*values) raised
    TypeError before anything was plotted. They are not needed here.

    Args:
        model: generative-model name used in the pickle file names.
        results_dir: directory holding the pickles (default '../results').

    Note:
        pickle.load can execute arbitrary code; only use trusted, locally
        produced result files.
    """
    splits = (('train', 'Train', 'o'), ('val', 'Validation', 's'), ('test', 'Test', '^'))
    panels = (('loss', 'Loss', 'Loss Over Epochs'),
              ('ll', 'Log-Likelihood', 'Log-Likelihood Over Epochs'))

    avg_curves = {}
    for prefix, _, _ in panels:
        for split, _, _ in splits:
            key = f'{prefix}_{split}'
            with open(f'{results_dir}/{model}_{key}.pickle', 'rb') as handle:
                curves = pickle.load(handle)
            # zip truncates to the shortest curve if fits ran for different lengths.
            avg_curves[key] = [sum(values) / len(values) for values in zip(*curves)]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (prefix, ylabel, title) in zip(axes, panels):
        for split, split_label, marker in splits:
            curve = avg_curves[f'{prefix}_{split}']
            # x-range from the averaged curve itself, so lengths always match.
            ax.plot(range(len(curve)), curve, label=f'{split_label} {ylabel}', marker=marker)
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()

    fig.suptitle(model)
    plt.tight_layout()
    plt.show()

In [8]:
def _plot_hist_by_model(values_per_model, title, xlabel):
    """Overlay one histogram (with KDE) per model on a new figure, show it, and return (fig, ax)."""
    fig, ax = plt.subplots(figsize=(10, 5))
    for model_name, values in values_per_model.items():
        sns.histplot(values, kde=True, label=model_name, ax=ax)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Frequency")
    ax.legend()
    plt.show()
    return fig, ax


models = ["none", "exploration_only", "noise_only", "both"]
optimal_epochs_per_model = {}
ranks_per_model = {}
test_losses_per_model = {}

# Load data, train, save and plot learning curves for each generative model.
for model in models:
    all_blocks = upload_data(num_of_block, num_of_agents, model)
    (loss_train, loss_val, loss_test,
     ll_train, ll_val, ll_test,
     ranks, optimal_epochs) = train_rnn_for_model(all_blocks, model)
    plot_results(model)
    optimal_epochs_per_model[model] = optimal_epochs
    ranks_per_model[model] = ranks
    test_losses_per_model[model] = loss_test

# Distribution of optimal epochs, one histogram per model
fig, ax = _plot_hist_by_model(optimal_epochs_per_model, "Optimal Epochs", "Epochs")

# Distribution of weight-change ranks, one histogram per model
fig, ax = _plot_hist_by_model(
    ranks_per_model, "Rank of Weight Changes Matrix Around Optimal Epoch", "Rank")

#plot test losses by epoch, grouped by model
# fig, ax = plt.subplots(figsize=(10, 5))
# for model, test_losses in test_losses_per_model.items():
#     for test_loss in test_losses:
#         ax.plot(range(len(test_loss)), test_loss, label=model)
# ax.set_title("Test Loss Over Epochs")





  0%|          | 0/40 [01:09<?, ?it/s]


ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 60 and the array at index 1 has size 75