In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy import stats
import seaborn as sns
import matplotlib.pyplot as plt

import multiprocessing as mp
import pickle 
import warnings 
warnings.filterwarnings('ignore')

from imports import*
from utils import *
from logistic_regression import *
from rnn import *

import torch
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cpu


In [2]:
# num of agent
num_of_agents = 40

# num of block
num_of_block = 3

# num of trials 
num_of_trials = 200

# for cross valdation 
array = np.arange(num_of_block)
cv = [np.roll(array,i) for i in range(num_of_block)]
cv = np.array(cv)


def bce(y_hat,y_true):
    """
    Summed (not averaged) binary cross-entropy.

    A small epsilon (1e-7) is added inside each log so that predictions of
    exactly 0 or 1 do not evaluate log(0).

    Args:
        y_hat: Predicted probabilities (array-like of floats).
        y_true: Binary targets, same shape as ``y_hat``.

    Returns:
        The negative log-likelihood summed over all elements.
    """
    eps = 1e-7
    positive_term = y_true * np.log(y_hat + eps)
    negative_term = (1 - y_true) * np.log(1 - y_hat + eps)
    return -np.sum(positive_term + negative_term)

In [3]:
# upload data
def upload_data(num_of_block, num_of_agents, model, data_dir='../data'):
    """
    Read every agent's CSV for every simulated block of one generative model.

    Files are read from
    ``{data_dir}/{model}/{model}_agent_{agent}_sim_{sim}.csv`` with 1-based
    ``agent`` and ``sim`` indices.

    Args:
        num_of_block: Number of simulated blocks (``sim`` runs 1..num_of_block).
        num_of_agents: Number of agents (``agent`` runs 1..num_of_agents).
        model: Model name; used both as the sub-directory and the file prefix.
        data_dir: Root data directory (defaults to the original ``../data``).

    Returns:
        list: ``blocks[b][a]`` is the DataFrame for simulation ``b + 1`` and
        agent ``a + 1``. There is exactly one entry per block, so any
        ``num_of_block`` is supported (the previous version hard-coded three).
    """
    all_blocks = []
    for sim in range(1, num_of_block + 1):
        data_per_block = [
            pd.read_csv(f'{data_dir}/{model}/{model}_agent_{agent}_sim_{sim}.csv')
            for agent in range(1, num_of_agents + 1)
        ]
        all_blocks.append(data_per_block)
    return all_blocks

In [9]:
def compute_weight_changes_between_states(network_state_t, network_state_t_plus_1):
    """
    Flatten the change of every weight tensor between two network snapshots.

    Args:
        network_state_t: ``torch.nn.Module`` holding the parameters at step t.
        network_state_t_plus_1: ``torch.nn.Module`` with the same parameter
            names, holding the parameters at step t+1.

    Returns:
        numpy.ndarray: A 1-D vector (not a matrix) containing
        ``param_{t+1} - param_t`` for every parameter whose name contains
        ``'weight'``. Biases and other parameters are skipped. Pieces are
        concatenated in the ``named_parameters()`` order of
        ``network_state_t``. Returns an empty float32 vector when no weight
        parameters exist.

    Raises:
        KeyError: If a weight parameter of ``network_state_t`` is missing from
            ``network_state_t_plus_1``.
    """
    params_t_plus_1 = dict(network_state_t_plus_1.named_parameters())

    weight_changes_list = []
    for name, param_t in network_state_t.named_parameters():
        # Only weight matrices are compared, not biases.
        if 'weight' not in name:
            continue
        # detach() instead of .data: same values, but without bypassing autograd bookkeeping.
        weight_difference = params_t_plus_1[name].detach() - param_t.detach()
        weight_changes_list.append(weight_difference.flatten().cpu().numpy())

    if not weight_changes_list:
        # np.concatenate([]) would raise ValueError.
        return np.zeros(0, dtype=np.float32)
    # Join the flattened pieces end-to-end into one 1-D vector.
    return np.concatenate(weight_changes_list)

def _weight_delta_vector(state_t, state_t_plus_1):
    """Flatten ``state_t_plus_1[k] - state_t[k]`` for every state-dict key containing 'weight'."""
    deltas = [
        (state_t_plus_1[name] - tensor).detach().flatten().cpu().numpy()
        for name, tensor in state_t.items()
        if 'weight' in name
    ]
    return np.concatenate(deltas) if deltas else np.zeros(0, dtype=np.float32)


def calculate_weight_changes_rank(net, network_states_per_epoch, test_loss, window_size=5):
    """
    Compute the rank of the matrix of epoch-to-epoch weight updates around the best epoch.

    The epoch with the lowest value in ``test_loss`` is taken as optimal.
    Consider every consecutive pair of saved states (t, t+1) with t in
    ``[optimal - window_size, optimal + window_size)``, clipped to the
    available states. For each pair, the change of every weight tensor
    (keys containing ``'weight'``; biases are skipped) is flattened into one
    row. The rows are stacked into an ``(n_steps, n_weights)`` matrix, and its
    numerical rank is returned.

    Args:
        net: Trained network. The computation no longer needs it, because
            weight deltas are taken straight from the state dicts. It is kept
            so existing calls keep working.
        network_states_per_epoch: Sequence of state dicts (``name -> tensor``),
            one per epoch. Presumably aligned index-for-index with
            ``test_loss``; confirm against ``train_model``.
        test_loss: Per-epoch loss used to pick the optimal epoch. Despite the
            name, any per-epoch selection loss can be passed (validation loss
            avoids leaking the test split into model selection).
        window_size: Number of epochs considered on each side of the optimal epoch.

    Returns:
        tuple: ``(weight_changes_rank, optimal_epoch_idx)``. The rank is 0
        when fewer than two states fall inside the window.
    """
    optimal_epoch_idx = np.argmin(test_loss)

    start_epoch = max(0, optimal_epoch_idx - window_size)
    end_epoch = min(len(network_states_per_epoch) - 1, optimal_epoch_idx + window_size)

    # NOTE(review): if train_model stored net.state_dict() without copying, every
    # entry aliases the live parameters and all deltas are zero; confirm it deep-copies.
    weight_change_rows = [
        _weight_delta_vector(network_states_per_epoch[t], network_states_per_epoch[t + 1])
        for t in range(start_epoch, end_epoch)
    ]
    if not weight_change_rows or weight_change_rows[0].size == 0:
        return 0, optimal_epoch_idx

    # One ROW per update. np.vstack (not np.hstack) is essential: hstack would
    # glue the 1-D deltas into a single vector, and matrix_rank of a 1-D array
    # is always 0 or 1.
    weight_changes_matrix = np.vstack(weight_change_rows)
    weight_changes_rank = np.linalg.matrix_rank(weight_changes_matrix)

    return weight_changes_rank, optimal_epoch_idx

In [5]:
N = num_of_agents  # number of agents iterated over in train_rnn_for_model

INPUT_SIZE = 4  # 3 for the action (one-hot format) and 1 for the reward of the chosen action
OUTPUT_SIZE = 3  # probabilities of choosing each action in the next trial
LEARNING_RATE = 0.001
LERANING_RATE = LEARNING_RATE  # legacy misspelled name, kept so existing references keep working

hidden_size = 5  # hidden size passed to GRU_NN
num_layers = 1  # number of recurrent layers passed to GRU_NN
epochs = 1000  # training epochs passed to train_model for every fit

In [6]:
def train_rnn_for_model(all_blocks, model, results_dir='../results'):
    """
    Fit one GRU per agent and cross-validation fold, then pickle all results.

    For each of the ``N`` agents and each ``(train, val, test)`` block triple
    in ``cv``, a fresh ``GRU_NN`` is trained with ``train_model``. The epoch
    with the lowest *validation* loss is used as the centre of the
    weight-change rank analysis.

    Args:
        all_blocks: Nested list ``all_blocks[block][agent]`` of per-agent
            DataFrames, as returned by ``upload_data``.
        model: Generative-model name; used as the prefix of the pickle files.
        results_dir: Output directory for the pickles (created if missing).

    Returns:
        tuple: Eight lists ``(loss_train, loss_val, loss_test, ll_train,
        ll_val, ll_test, ranks, optimal_epochs)``. Each list has one entry per
        (agent, fold) pair, in agent-major order.

    Side effects:
        Writes ``{results_dir}/{model}_{name}.pickle`` for each of the eight
        lists, only after all agents have finished.
    """
    import os

    # Keys double as the pickle file suffixes read back by plot_results.
    results = {
        'loss_train': [], 'loss_val': [], 'loss_test': [],
        'll_train': [], 'll_val': [], 'll_test': [],
        'ranks': [], 'optimal_epochs': [],
    }

    for n in tqdm(range(N)):
        for train, val, test in cv:
            train_data = behavior_dataset(all_blocks[train][n])
            val_data = behavior_dataset(all_blocks[val][n])
            test_data = behavior_dataset(all_blocks[test][n])

            # batch_size=len(dataset): each split is served as a single batch;
            # shuffle=False keeps the sample order.
            train_loader = DataLoader(train_data, shuffle=False, batch_size=len(train_data))
            val_loader = DataLoader(val_data, shuffle=False, batch_size=len(val_data))
            test_loader = DataLoader(test_data, shuffle=False, batch_size=len(test_data))

            rnn = GRU_NN(INPUT_SIZE, hidden_size, num_layers, OUTPUT_SIZE)
            (rnn, train_loss, train_ll, val_loss, val_ll,
             test_loss, test_ll, network_states_per_epoch) = train_model(
                rnn, train_loader, val_loader, test_loader,
                epochs=epochs, lr=LERANING_RATE)

            # Select the optimal epoch on the validation split. Selecting on
            # test_loss would leak the held-out block into model selection.
            rank, optimal_epoch = calculate_weight_changes_rank(
                rnn, network_states_per_epoch, val_loss)

            results['loss_train'].append(train_loss)
            results['loss_val'].append(val_loss)
            results['loss_test'].append(test_loss)
            results['ll_train'].append(train_ll)
            results['ll_val'].append(val_ll)
            results['ll_test'].append(test_ll)
            results['ranks'].append(rank)
            results['optimal_epochs'].append(optimal_epoch)

        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(f'Done agent {n}')

    os.makedirs(results_dir, exist_ok=True)
    for name, values in results.items():
        with open(f'{results_dir}/{model}_{name}.pickle', 'wb') as handle:
            pickle.dump(values, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return (results['loss_train'], results['loss_val'], results['loss_test'],
            results['ll_train'], results['ll_val'], results['ll_test'],
            results['ranks'], results['optimal_epochs'])

In [13]:
def plot_results(model, results_dir='../results'):
    """
    Plot per-epoch loss and log-likelihood curves, averaged over all fits of one model.

    The function reads the six per-split pickles
    (``{results_dir}/{model}_{loss|ll}_{train|val|test}.pickle``) written by
    ``train_rnn_for_model``. Each holds a list of per-fit curves (one per
    agent x fold). The plotted curve is their element-wise mean. Fits of
    unequal length are truncated to the shortest one.

    Args:
        model: Model name prefix of the pickle files; also shown as the figure title.
        results_dir: Directory containing the pickles.
    """
    metric_names = ["loss_train", "loss_val", "loss_test", "ll_train", "ll_val", "ll_test"]

    results = {}
    for key in metric_names:
        # pickle.load can execute arbitrary code: only load files this notebook produced.
        with open(f"{results_dir}/{model}_{key}.pickle", 'rb') as handle:
            results[key] = pickle.load(handle)

    # Element-wise mean across fits; zip() stops at the shortest fit.
    avg_results = {
        key: [sum(epoch_values) / len(epoch_values) for epoch_values in zip(*per_fit)]
        for key, per_fit in results.items()
    }

    splits = [("train", "Train", "o"), ("val", "Validation", "s"), ("test", "Test", "^")]
    panels = [("loss", "Loss"), ("ll", "Log-Likelihood")]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (prefix, quantity) in zip(axes, panels):
        for split, split_label, marker in splits:
            curve = avg_results[f"{prefix}_{split}"]
            # x must come from the averaged curve itself; the first fit's length
            # can exceed it after zip() truncation.
            ax.plot(range(len(curve)), curve, label=f"{split_label} {quantity}",
                    marker=marker, markevery=max(1, len(curve) // 20))
        ax.set_title(f"{quantity} Over Epochs")
        ax.set_xlabel("Epochs")
        ax.set_ylabel(quantity)
        ax.legend()

    # Without this, the per-model figures are indistinguishable.
    fig.suptitle(f"Model: {model}")
    plt.tight_layout()
    plt.show()

In [None]:
models = ["none", "exploration_only", "noise_only", "both"]
optimal_epochs_per_model, ranks_per_model, test_losses_per_model = {}, {}, {}

# Train every generative model in turn. Plotting is done in a later cell, so a
# plotting error cannot interrupt training before the pickles are written.
for model in models:
    all_blocks = upload_data(num_of_block, num_of_agents, model)
    (loss_train, loss_val, loss_test,
     ll_train, ll_val, ll_test,
     ranks, optimal_epochs) = train_rnn_for_model(all_blocks, model)
    test_losses_per_model[model] = loss_test
    ranks_per_model[model] = ranks
    optimal_epochs_per_model[model] = optimal_epochs

  2%|▎         | 1/40 [03:29<2:16:05, 209.36s/it]

Done agent 0


  5%|▌         | 2/40 [06:58<2:12:38, 209.42s/it]

Done agent 1


  8%|▊         | 3/40 [10:31<2:10:10, 211.11s/it]

Done agent 2


 10%|█         | 4/40 [13:58<2:05:31, 209.21s/it]

Done agent 3


 12%|█▎        | 5/40 [17:22<2:00:57, 207.37s/it]

Done agent 4


 15%|█▌        | 6/40 [20:41<1:55:52, 204.48s/it]

Done agent 5


 18%|█▊        | 7/40 [24:04<1:52:17, 204.18s/it]

Done agent 6


 20%|██        | 8/40 [27:21<1:47:36, 201.77s/it]

Done agent 7


 22%|██▎       | 9/40 [30:37<1:43:16, 199.89s/it]

Done agent 8


 25%|██▌       | 10/40 [33:54<1:39:36, 199.21s/it]

Done agent 9


 28%|██▊       | 11/40 [37:18<1:36:59, 200.67s/it]

Done agent 10


 30%|███       | 12/40 [40:37<1:33:19, 199.97s/it]

Done agent 11


 32%|███▎      | 13/40 [43:56<1:29:50, 199.64s/it]

Done agent 12


 35%|███▌      | 14/40 [47:14<1:26:18, 199.17s/it]

Done agent 13


 38%|███▊      | 15/40 [50:32<1:22:56, 199.05s/it]

Done agent 14


 40%|████      | 16/40 [53:50<1:19:28, 198.68s/it]

Done agent 15


 42%|████▎     | 17/40 [57:08<1:16:01, 198.32s/it]

Done agent 16


 45%|████▌     | 18/40 [1:00:26<1:12:44, 198.39s/it]

Done agent 17


 48%|████▊     | 19/40 [1:03:44<1:09:25, 198.33s/it]

Done agent 18


 50%|█████     | 20/40 [1:07:02<1:06:02, 198.13s/it]

Done agent 19


 52%|█████▎    | 21/40 [1:10:27<1:03:22, 200.16s/it]

Done agent 20


 55%|█████▌    | 22/40 [1:13:45<59:52, 199.60s/it]  

Done agent 21


 57%|█████▊    | 23/40 [1:17:03<56:21, 198.93s/it]

Done agent 22


 60%|██████    | 24/40 [1:20:19<52:50, 198.14s/it]

Done agent 23


 62%|██████▎   | 25/40 [1:23:31<49:05, 196.40s/it]

Done agent 24


 65%|██████▌   | 26/40 [1:26:46<45:40, 195.75s/it]

Done agent 25


 68%|██████▊   | 27/40 [1:30:09<42:56, 198.21s/it]

Done agent 26


 70%|███████   | 28/40 [1:33:32<39:55, 199.62s/it]

Done agent 27


 72%|███████▎  | 29/40 [1:36:55<36:46, 200.57s/it]

Done agent 28


 75%|███████▌  | 30/40 [1:40:18<33:32, 201.26s/it]

Done agent 29


 78%|███████▊  | 31/40 [1:43:41<30:16, 201.84s/it]

Done agent 30


 80%|████████  | 32/40 [1:47:04<26:57, 202.19s/it]

Done agent 31


 82%|████████▎ | 33/40 [1:50:27<23:36, 202.35s/it]

Done agent 32


 85%|████████▌ | 34/40 [1:53:42<20:00, 200.01s/it]

Done agent 33


 88%|████████▊ | 35/40 [1:56:55<16:30, 198.08s/it]

Done agent 34


 90%|█████████ | 36/40 [2:00:09<13:07, 196.79s/it]

Done agent 35


 92%|█████████▎| 37/40 [2:03:23<09:48, 196.03s/it]

Done agent 36


 95%|█████████▌| 38/40 [2:06:37<06:30, 195.41s/it]

Done agent 37


 98%|█████████▊| 39/40 [2:09:51<03:15, 195.05s/it]

Done agent 38


100%|██████████| 40/40 [2:13:05<00:00, 199.63s/it]

Done agent 39



  2%|▎         | 1/40 [03:13<2:05:43, 193.41s/it]

Done agent 0


  5%|▌         | 2/40 [06:26<2:02:27, 193.35s/it]

Done agent 1


  8%|▊         | 3/40 [09:40<1:59:25, 193.67s/it]

Done agent 2


 10%|█         | 4/40 [12:54<1:56:10, 193.62s/it]

Done agent 3


 12%|█▎        | 5/40 [16:07<1:52:54, 193.54s/it]

Done agent 4


 15%|█▌        | 6/40 [19:21<1:49:47, 193.75s/it]

Done agent 5


 18%|█▊        | 7/40 [22:35<1:46:33, 193.74s/it]

Done agent 6


 20%|██        | 8/40 [25:49<1:43:24, 193.88s/it]

Done agent 7


 22%|██▎       | 9/40 [29:03<1:40:10, 193.89s/it]

Done agent 8


 25%|██▌       | 10/40 [32:17<1:36:57, 193.91s/it]

Done agent 9


 28%|██▊       | 11/40 [35:31<1:33:39, 193.77s/it]

Done agent 10


 30%|███       | 12/40 [38:44<1:30:18, 193.51s/it]

Done agent 11


 32%|███▎      | 13/40 [41:57<1:27:05, 193.54s/it]

Done agent 12


 35%|███▌      | 14/40 [45:11<1:23:53, 193.59s/it]

Done agent 13


 38%|███▊      | 15/40 [48:24<1:20:36, 193.45s/it]

Done agent 14


 40%|████      | 16/40 [51:44<1:18:10, 195.45s/it]

Done agent 15


In [None]:
# Plotting lives in its own cell (a plotting bug used to abort the training
# loop), so every model's results are pickled before anything is plotted.
for model in models:
    plot_results(model)


def _load_pickled_per_model(metric, model_names, results_dir='../results'):
    """Load ``{results_dir}/{model}_{metric}.pickle`` for every model into a dict keyed by model."""
    loaded = {}
    for name in model_names:
        with open(f'{results_dir}/{name}_{metric}.pickle', 'rb') as handle:
            loaded[name] = pickle.load(handle)
    return loaded


# Read the pickles rather than the training cell's in-memory dicts, so this
# cell still works after a kernel restart without hours of re-training.
saved_optimal_epochs = _load_pickled_per_model('optimal_epochs', models)
saved_ranks = _load_pickled_per_model('ranks', models)

# Optimal epochs by model; step outlines keep overlapping histograms readable.
fig, ax = plt.subplots(figsize=(10, 5))
for model, optimal_epochs in saved_optimal_epochs.items():
    sns.histplot(optimal_epochs, kde=True, label=model, element="step", ax=ax)
ax.set_title("Optimal Epochs")
ax.set_xlabel("Epochs")
ax.set_ylabel("Frequency")
ax.legend()
plt.show()

# Ranks by model; ranks are integers, so use one bin per value (discrete=True).
fig, ax = plt.subplots(figsize=(10, 5))
for model, ranks in saved_ranks.items():
    sns.histplot(ranks, kde=True, discrete=True, label=model, element="step", ax=ax)
ax.set_title("Rank of Weight Changes Matrix Around Optimal Epoch")
ax.set_xlabel("Rank")
ax.set_ylabel("Frequency")
ax.legend()
plt.show()

#plot test losses by epoch, grouped by model
# fig, ax = plt.subplots(figsize=(10, 5))
# for model, test_losses in test_losses_per_model.items():
#     for test_loss in test_losses:
#         ax.plot(range(len(test_loss)), test_loss, label=model)
# ax.set_title("Test Loss Over Epochs")
