# DQN notebook

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import neat
import matplotlib.pyplot as plt
from matplotlib import cm

import pickle
import multimodal_mazes
from tqdm import tqdm

import itertools
import torch
import torch.nn as nn
import torch.optim as optim

## Ideas 
* Optuna: https://colab.research.google.com/github/araffin/tools-for-robotic-rl-icra2022/blob/main/notebooks/optuna_lab.ipynb#scrollTo=E0yEokTDxhrC 

## Working

In [None]:
# Build one DQN agent with 8 hidden units. wm_flags holds seven on/off
# switches; other cells in this notebook label them L0, L1, L2, S0, S1, B0, B1
# (in that order), so this setting enables only the last one (B1).
agnt = multimodal_mazes.AgentDQN(location=[5,5], channels=[1,1], sensor_noise_scale=0.05, n_hidden_units=8, wm_flags=np.array([0,0,0,0,0,0,1]))
agnt  # last expression of the cell: displays the agent

In [None]:
# Generate 1000 size-11, two-channel track mazes (no noise, no gaps) and
# train the agent built in the previous cell on them with 6 steps per trial.
maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0.0, gaps=0)
agnt.generate_policy(maze, n_steps=6)

In [None]:
# Heat-map of the gradient currently stored on the output_to_hidden weights.
# NOTE(review): .grad is presumably populated by the training call in the
# previous cell; run that cell first.
plt.imshow(agnt.output_to_hidden.weight.grad)
plt.colorbar()

## Fitness vs noise

In [None]:
# Fitness vs noise
# Train one agent per architecture at a fixed sensor noise (0.05), then score
# each trained agent on held-out mazes across a sweep of test-time noise levels.
noises = np.linspace(start=0.0, stop=0.5, num=21)

# Small set
wm_flags = np.array([
    [0,0,0,0,0,0,0],
    [0,0,0,0,0,0,0],
    [1,1,1,1,0,0,0],
    [1,1,1,1,1,1,1],
])
colors = [
    "xkcd:gray",
    [0.0, 0.0, 0.0, 0.5],
    [0.039, 0.73, 0.71, 1],
    list(np.array([24, 156, 196, 255]) / 255),
]

# Full set
# wm_flags = np.array(list(itertools.product([0,1], repeat=7)))
# wm_flags = np.vstack((wm_flags[0], wm_flags))
# colors = cm.get_cmap("plasma", len(wm_flags)).colors.tolist()

# Rows index the noise level, columns index the architecture.
results = np.zeros((len(noises), len(wm_flags)))

# Generate mazes
maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=100000, noise_scale=0.0, gaps=2)

maze_test = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze_test.generate(number=1000, noise_scale=0.0, gaps=2)

# Run
for b, wm_flag in enumerate(tqdm(wm_flags)):

    # Control architecture: the second entry repeats the all-zero flags of the
    # first but is given 34 hidden units instead of 8.
    n_hidden_units = 34 if b == 1 else 8

    # Train
    agnt = multimodal_mazes.AgentDQN(
        location=[5,5],
        channels=[1,1],
        sensor_noise_scale=0.05,
        n_hidden_units=n_hidden_units,
        wm_flags=wm_flag,
    )
    agnt.generate_policy(maze, n_steps=6)

    # Test
    for a, noise in enumerate(noises):
        results[a,b] = multimodal_mazes.eval_fitness(
            genome=None,
            config=None,
            channels=[1,1],
            sensor_noise_scale=noise,
            drop_connect_p=0.0,
            maze=maze_test,
            n_steps=6,
            agnt=agnt,
        )

# Plotting
# Dotted vertical line marks the noise level used during training.
plt.plot([0.05, 0.05], [0,1], ':', color='k', alpha=0.5, label='Training noise')
for b, wm_flag in enumerate(wm_flags):
    plt.plot(noises, results[:,b], color=colors[b], label=wm_flag)

plt.ylim([0, 1.05])
plt.ylabel('Fitness')
plt.xlabel('Sensor noise')
plt.legend()

In [None]:
# Fitness vs noise AUC 
# Area under each architecture's fitness-vs-noise curve (trapezoid rule over
# the noise axis), then the architecture order from lowest to highest AUC.
auc = np.trapz(y=results.T, x=noises, axis=1)
idxs = np.argsort(auc)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5*3,5), sharex=False, sharey=True)
# One stem per architecture, drawn at its rank position b.
for b, idx in enumerate(idxs): 
    # stem returns (markerline, stemlines, baseline); recolour the first two.
    ml, sl, _ = plt.stem(b, auc[idx])
    ml.set_color('k')
    sl.set_color('k')
# plt.xticks(range(len(wm_flags)), policies, rotation='vertical')
plt.ylabel('AUC');

## Loading

In [None]:
import os

# One results directory per repeat: ../Results/test58/ ... ../Results/test67/
paths = ['../Results/test' + str(n) + '/' for n in range(58,68)]
print(paths)

# Pre-fill everything with NaN so that architectures missing from a repeat
# stay NaN and are ignored by the nan-aware reductions used later.
noises = np.linspace(start=0.0, stop=0.5, num=21) # noises,
results = np.full((len(noises), 129, len(paths)), np.nan) # noises x architectures x repeats
wm_flags = np.full((129, 7, len(paths)), np.nan) # architectures x flags x repeats
n_parameters = np.full((129, len(paths)), np.nan) # architectures x repeats
auc = np.full((129, len(paths)), np.nan) # architectures x repeats

for a, path in enumerate(tqdm(paths)):

    # Load data
    for f in os.listdir(path):
        if not f.endswith(".pickle"):
            continue

        # NOTE: pickle executes arbitrary code on load; only read result
        # files produced by your own runs.
        with open(path + f, 'rb') as file:
            agnt = pickle.load(file)

        # The file stem is the architecture index, e.g. "74.pickle" -> 74.
        idx = int(os.path.splitext(f)[0])

        results[:, idx, a] = agnt.results
        wm_flags[idx, :, a] = agnt.wm_flags
        n_parameters[idx, a] = agnt.n_parameters
        auc[idx, a] = np.trapz(y=agnt.results, x=noises)

In [None]:
# Graph features
# Turn every architecture's flag vector into a 3-node directed graph and
# record a few graph statistics for it.
import networkx as nx
from scipy.linalg import expm

Gs, communcability, n_cycles, cycle_c = [], [], [], []
# nanmax over the repeat axis collapses wm_flags to one flag vector per architecture.
for a, wm_flag in enumerate(np.nanmax(wm_flags, axis=2)):
    G = nx.DiGraph([(0,1), (1,2)]) # F0 and F1
    if wm_flag[0]: G.add_edge(0,0) # L0
    if wm_flag[1]: G.add_edge(1,1) # L1
    if wm_flag[2]: G.add_edge(2,2) # L2
    if wm_flag[3]: G.add_edge(0,2) # S0
    if wm_flag[4]: G.add_edge(2,0) # S1
    if wm_flag[5]: G.add_edge(1,0) # B0
    if wm_flag[6]: G.add_edge(2,1) # B1

    Gs.append(G)
    # Mean entry of the matrix exponential of the adjacency matrix.
    communcability.append(np.mean(expm(nx.adjacency_matrix(G).toarray())))

    # Enumerate the simple cycles once and reuse the list for both statistics.
    cycles = list(nx.simple_cycles(G))
    n_cycles.append(len(cycles))
    # Longest cycle length; 0 for an acyclic graph. default= replaces the
    # former bare `except`, which would also have hidden unrelated errors.
    cycle_c.append(max((len(c) for c in cycles), default=0))

In [None]:
# Sorting
# Order in which architectures are laid out along the x-axis of later plots.
idxs = np.argsort(np.nanmax(auc,axis=1) - np.nanstd(auc,axis=1)) # Sort by max minus std  

# Networks of interest
# interest holds architecture indices; i_cols and i_labels must supply one
# colour and one legend label per entry, in the same order.
# interest = [1,128, np.argmax(np.nanmean(results[2,:,:], axis=1)), np.argmax(np.nanmean(auc, axis=1))]
# i_cols = ['xkcd:dusty blue', 'xkcd:purple', 'xkcd:dark seafoam', 'xkcd:yellowish']
# i_labels = ['Feedforward (L)', 'Fully recurrent', 'Most accurate', 'Most robust']

interest = [74, 10, 66, 73]
i_cols = ['xkcd:dark seafoam', 'xkcd:wintergreen', 'xkcd:tree green', 'xkcd:pale lime']
i_labels = ['None', 'L0', 'S0', 'B1']

# interest = [74]

# Helper for looking up the index of an architecture from its flag vector:
# np.argwhere((np.nanmax(wm_flags, axis=2) == [1,0,0,1,0,0,0]).all(1))
# interest = [74, 10,66, 73]
# interest = [65, 33, 17]

print(interest)

In [None]:
# AUC vs parameters (scatters and curve fits)
# Left panel: fitness at noise index 2 (0.05, the training noise) against the
# parameter count. Right panel: AUC over the noise sweep against the same.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10*2,5), sharex=True, sharey=False)

for a in range(2):
    plt.sca(ax[a])

    x = n_parameters.reshape(-1) # networks,

    if a == 0:
        y = results[2,:,:].reshape(-1) # networks,
        plt.ylabel('Test accuracy')

    if a == 1:
        y = auc.reshape(-1) # networks,
        plt.ylabel('Robustness to noise')

    # if a == 2:
    #     x = np.repeat(communcability, repeats=len(paths)) # networks
    #     plt.xlabel('Mean communcability')

    # Architecture index of every flattened entry (same layout as x and y).
    z = np.repeat(np.arange(start=0, stop=129), repeats=len(paths)) # networks

    # Drop the (architecture, repeat) slots for which no result was loaded.
    keep = ~np.isnan(y)
    x, z, y = x[keep], z[keep], y[keep]

    idx = np.argsort(x)

    # Quadratic trend line through all networks.
    curve = np.poly1d(np.polyfit(x[idx],y[idx],deg=2))
    plt.plot(x[idx], curve(x[idx]), color='xkcd:grey', alpha=0.5)

    plt.scatter(x, y, s=10, color='k', marker='.', alpha=0.2, label='All networks')

    # Overlay the highlighted architectures in their own colours.
    for b, i in enumerate(interest):
        chosen = z == i
        plt.scatter(x[chosen], y[chosen], color=i_cols[b], marker='.', alpha=0.75, label=i_labels[b])

    if a == 0:
        plt.legend(loc='upper left')

    plt.xlabel('Number of parameters')

In [None]:
# Fitness vs noise AUC 
# Top panel: per-architecture AUC range (min to max over repeats), with the
# architectures ordered by idxs. Bottom panel: which of the seven flags each
# architecture has switched on.
from matplotlib.patches import Rectangle

wm_f_labels = ['L0', 'L1', 'L2', 'S0', 'S1', 'B0', 'B1']

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(5*3,5), sharex=True, sharey=False)

plt.sca(ax[0])
# Passing 2 x 129 arrays draws one vertical segment (min -> max) per architecture.
plt.plot([range(129), range(129)], [np.nanmin(auc,axis=1)[idxs], np.nanmax(auc,axis=1)[idxs]], color='k', alpha=0.25);
plt.ylabel('AUC');
plt.xticks([])

plt.sca(ax[1])
for i in range(7):
    # The flag value is used as the marker alpha, so a 0 flag is invisible and a 1 flag is black.
    plt.scatter(range(129), np.ones(129) * i, c=[(0, 0, 0, alpha) for alpha in np.nanmax(wm_flags, axis=2)[idxs,i]], s=5)
plt.yticks(range(7), wm_f_labels)
plt.xlabel('Architectures')

# Highlight the architectures of interest in both panels.
for a, i in enumerate(interest):
    plt.sca(ax[0])
    # np.where(idxs == i)[0] is the x position of architecture i in the sorted order.
    plt.plot([np.where(idxs == i)[0], np.where(idxs == i)[0]], [np.nanmin(auc[i]), np.nanmax(auc[i])], color=i_cols[a])  

    # Translucent column behind that architecture's flag markers.
    ax[1].add_patch(Rectangle((np.where(idxs == i)[0][0] -0.25, -0.95), 0.5, 7.05, color=i_cols[a], alpha=0.5))

In [None]:
# Fitness vs noise 
# For every architecture, plot the curve of its repeat with the highest AUC.
plt.plot([0.05, 0.05], [0,1], ':', color='k', alpha=0.5, label='Training noise')

# Empty line: only there to create a single legend entry for the grey curves.
plt.plot([], [], 'k', alpha=0.1, label='All architectures')
# Paired fancy indexing picks (architecture j, best repeat of j) for all 129 architectures.
plt.plot(noises, results[:,range(129), np.nanargmax(auc,axis=1)], c='k', alpha=0.1);

for a, i in enumerate(interest):
    plt.plot(noises, results[:, i, np.nanargmax(auc[i])], color=i_cols[a], label=i_labels[a])

plt.ylim([0, 1.05])
plt.ylabel('Fitness')
plt.xlabel('Sensor noise')
plt.legend()

In [None]:
# Fitness vs noise 
# Same axes as the previous cell, but showing every repeat of each
# architecture of interest.
plt.plot([0.05, 0.05], [0,1], ':', color='k', alpha=0.5, label='Training noise')

for a, i in enumerate(interest):
    # results[:, i] is noises x repeats, so this draws one line per repeat
    # (each of them carries the same legend label).
    plt.plot(noises, results[:, i], color=i_cols[a], label=i_labels[a], alpha=0.5)

plt.ylim([0, 1.05])
plt.ylabel('Fitness')
plt.xlabel('Sensor noise')

## Plotting architectures

In [None]:
# Graph features
# Rebuild the list of 3-node directed graphs, one per architecture.
import numpy as np
import networkx as nx

# Edge switched on by each flag, in flag order: L0, L1, L2, S0, S1, B0, B1.
flag_edges = [(0,0), (1,1), (2,2), (0,2), (2,0), (1,0), (2,1)]

Gs = []
for a, wm_flag in enumerate(np.nanmax(wm_flags, axis=2)):
    G = nx.DiGraph([(0,1), (1,2)]) # F0 and F1
    G.add_edges_from(edge for flag, edge in zip(wm_flag, flag_edges) if flag)
    Gs.append(G)

In [None]:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(nrows=8, ncols=16, figsize=(30,15), sharex=True, sharey=True)

# One flag vector per architecture (collapsed over repeats), computed once.
flags_per_architecture = np.nanmax(wm_flags,axis=2)

# 8 x 16 = 128 panels showing architectures 1..128; index 0 is skipped.
for a, panel in enumerate(ax.ravel()):
    plt.sca(panel)

    wm_flag = flags_per_architecture[a+1]
    multimodal_mazes.plot_dqn_architecture(wm_flag)

In [None]:
# Draw architecture 73 on its own axes, then recolour every artist that the
# plotting helper added (lines, patches and collections) to a single colour.
fig, ax = plt.subplots(nrows=1, ncols=1)
multimodal_mazes.plot_dqn_architecture(np.nanmax(wm_flags,axis=2)[73])
new_color = 'k'

for line in ax.get_lines(): 
    line.set_color(new_color)

for patch in ax.patches: 
    patch.set_facecolor(new_color)
    patch.set_edgecolor(new_color)

for path_collection in ax.collections:
    path_collection.set_facecolor(new_color)
    path_collection.set_edgecolor(new_color)


## Exact Shapley Values

In [None]:
# One score per architecture: mean over repeats of the fitness at noise index 2.
scores = np.nanmean(results[2,:,:],1) # test accuracy 
# scores = np.nanmean(auc, axis=1) # robustness to noise
# NOTE(review): this rebinds wm_flags from the 3-D array (architectures x
# flags x repeats) to a 2-D one. Cells that index or reduce wm_flags along
# axis 2 (e.g. the random-forest data preparation) will fail after this cell
# has run, until the loading cell is re-run.
wm_flags = np.nanmax(wm_flags, axis=2)

In [None]:
from itertools import combinations
def calculate_shapley_values(scores, wm_flags):
    """Estimate the contribution of each binary hyperparameter to the score.

    For every hyperparameter ``i`` and every coalition of the other
    hyperparameters, the rows in which the whole coalition is switched on are
    split into those with ``i`` off and those with ``i`` on. The element-wise
    score differences (on minus off) are pooled, and the value reported for
    ``i`` is their plain mean.

    Note that this is an unweighted mean of the pooled differences; the
    classical Shapley weighting by coalition size is not applied. Rows of the
    two groups are paired by their position in ``wm_flags``, so both groups
    must have the same length for each coalition.

    Args:
        scores: 1-D array with one score per row of ``wm_flags``.
        wm_flags: 2-D array (configurations x hyperparameters) of 0/1 flags.

    Returns:
        1-D array with one value per hyperparameter (column of ``wm_flags``).
    """
    n_hyperparams = wm_flags.shape[1]
    shapley_values = np.zeros(n_hyperparams)

    # Iterate over each hyperparameter
    for i in range(n_hyperparams):
        marginal_contributions = []

        # Iterate over all possible coalitions of the remaining hyperparameters
        for coalition_size in range(n_hyperparams):
            for coalition in combinations(range(n_hyperparams), coalition_size):
                if i in coalition:
                    continue

                members = list(coalition)
                coalition_on = np.all(wm_flags[:, members] == 1, axis=1)

                # Rows with the coalition on, split by the state of hyperparameter i
                idx_coalition = np.where(coalition_on & (wm_flags[:, i] == 0))[0]
                idx_coalition_with_i = np.where(coalition_on & (wm_flags[:, i] == 1))[0]
                assert len(np.intersect1d(idx_coalition, idx_coalition_with_i)) == 0, "Indexing error"

                # Marginal contribution of switching i on within this coalition
                marginal_contributions.extend(scores[idx_coalition_with_i] - scores[idx_coalition])

        # Average of the pooled marginal contributions for this hyperparameter
        shapley_values[i] = np.mean(marginal_contributions)

    # Return only after every hyperparameter has been processed. Previously the
    # return sat inside the loop, so only hyperparameter 0 was ever computed.
    return shapley_values

# Calculate 
# One value per flag column of wm_flags, plotted against the column index.
shapley_values = calculate_shapley_values(scores, wm_flags)
plt.scatter(range(len(shapley_values)), shapley_values)

## Random forest model

In [None]:
# Prepare data
# Flatten (architectures x flags x repeats) into one 7-flag row per
# (architecture, repeat) pair, matching the flattened layout of the targets.
# This needs the 3-D wm_flags produced by the loading cell.
X = np.transpose(wm_flags, (0, 2, 1)).reshape(-1,7)
# y = results[2,:,:].reshape(-1)
y = auc.reshape(-1)

# Keep only the pairs for which a result was actually loaded.
valid = ~np.isnan(y)
X, y = X[valid], y[valid]

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, KFold
import shap

# 10-fold cross-validated mean squared error of a default random forest that
# predicts y from the 7 flags, followed by a fit on the full data set.
model = RandomForestRegressor()
kf = KFold(n_splits=10, shuffle=True)
# The scorer returns negated MSE, hence the sign flip.
# NOTE: this rebinds the notebook-level name `scores`.
scores = cross_val_score(model, X, y, cv=kf, scoring="neg_mean_squared_error") * -1
print(np.mean(scores))
model.fit(X,y)

In [None]:
# SHAP values of the fitted forest for every row of X (X is also passed to the
# explainer as its data argument), shown as a beeswarm plot.
explainer = shap.TreeExplainer(model, X)
shap_values = explainer(X)
shap.plots.beeswarm(shap_values)

In [None]:
import seaborn as sns

In [None]:
# Importance of a flag = mean absolute SHAP value over all samples;
# the ranking lists flag indices from least to most important.
feature_importance = np.mean(np.abs(shap_values.values), axis=0)
feature_importance_ranking = np.argsort(feature_importance)

In [None]:
# Swarm plot of the per-sample SHAP values, one swarm per flag, drawn in
# order of increasing importance and coloured by the flag's value.
labels = ("L0", "L1", "L2", "S0", "S1", "B0", "B1")
# NOTE: rebinds the notebook-level name `colors`.
colors = ('blue', 'orange')

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10*2,5), sharex=True, sharey=False)
plt.hlines(y=0.0, xmin=-0.5, xmax=6.5, color='k', alpha=0.25)

for i in feature_importance_ranking:
        # NOTE(review): the tick labels below assume the swarms land at
        # positions 0..6 in drawing order - confirm with the installed
        # seaborn version.
        sns.swarmplot(
        x=np.full(shap_values.values[:, i].shape, i),
        y=shap_values.values[:, i],
        hue=X[:, i],
        s=1.5,
        palette=colors,
        alpha=0.7,
        legend=False,
        );
        print(i)
plt.ylim([-0.1,0.1])
plt.ylabel("Shapley value")
plt.xticks(ticks=range(7), labels=[labels[i] for i in feature_importance_ranking])
plt.xlabel("Weight matrix")

plt.show()

In [None]:
import pandas as pd

# Create a DataFrame for easier plotting with seaborn
# Names of the seven flag columns of X, in column order.
feature_names = ("L0", "L1", "L2", "S0", "S1", "B0", "B1")
num_features = len(feature_names)

# Plain samples x features array. Indexing the Explanation object itself would
# store Explanation objects rather than numbers in the DataFrame.
shap_matrix = shap_values.values
n_samples = shap_matrix.shape[0]

# Long format: one row per (feature, sample) pair.
data = []
for i in range(num_features):
    for j in range(n_samples):
        data.append([shap_matrix[j, i], feature_names[i], X[j, i]])
df = pd.DataFrame(data, columns=['SHAP Value', 'Feature', 'Feature Value'])

plt.figure(figsize=(10, num_features * 0.5))

# Split violins: one half per flag value (off / on) for each feature.
sns.violinplot(x='Feature', y='SHAP Value', hue='Feature Value', data=df,
                palette=[colors[0], colors[1]], split=True, inner="quart")

In [None]:
# Split violin of the SHAP values of a single flag. `i` is whatever flag index
# the previously executed cell left behind. The sample count is taken from X
# instead of being hard-coded.
sns.violinplot(x=np.ones(X.shape[0])*i,y=shap_values.values[:,i], hue=X[:,i], split=True, inner=None)

## Exploring function

### Loading networks 

In [None]:
import os
import copy 
paths = ['../Results/test' + str(n) + '/' for n in range(58,68)]
print(paths)

# interest = [1, 128, 74, 36]
# i_cols = ['xkcd:dusty blue', 'xkcd:purple', 'xkcd:dark seafoam', 'xkcd:yellowish']
# i_labels = ['Feedforward (L)', 'Fully recurrent', 'Most accurate', 'Most robust']

# interest = [74, 10, 66, 73]
# i_cols = ['xkcd:dark seafoam', 'xkcd:wintergreen', 'xkcd:tree green', 'xkcd:pale lime']
# i_labels = ['None', 'L0', 'S0', 'B1']

# interest = [73]

# NOTE: cells below index i_cols / i_labels by position in `interest`, so
# those two lists need one entry per element of `interest`.
interest = list(np.arange(0,129))

# agents[k] is a loaded network and idxs[k] is its architecture index.
agents, idxs = [], []
for a, path in enumerate(tqdm(paths)):

    # Load data 
    for f in interest:
        try:
            # NOTE: pickle executes arbitrary code on load; only read result
            # files produced by your own runs.
            with open(path + str(f) + ".pickle", 'rb') as file:
                agnt = pickle.load(file)
        except FileNotFoundError:
            # This repeat has no saved network for architecture f: skip it.
            # Any other failure (e.g. a corrupt file) is no longer hidden.
            continue

        agents.append(agnt)
        idxs.append(f)

# Number of loaded networks per architecture index.
print(np.unique(idxs, return_counts=True))

### Sensitivity (Jacobian)

In [None]:
# Functions
from torch.autograd.functional import jacobian
from scipy.optimize import curve_fit
import copy 

def forward(*x):
    """Run the notebook-level ``agnt`` on one recorded state and return its outputs.

    Written as a plain function of positional inputs so that it can be handed
    to ``torch.autograd.functional.jacobian``. ``x[0]`` is stored on the agent
    as ``channel_inputs``; ``x[1]``, ``x[2]`` and ``x[3]`` are passed on to
    ``agnt.forward`` with ``tensor_input=True``. The first of the four values
    returned by ``agnt.forward`` is stored as ``agnt.outputs`` and returned.
    """
    agnt.channel_inputs = x[0]
    outputs, _, _, _ = agnt.forward(x[1], x[2], x[3], tensor_input=True)
    agnt.outputs = outputs
    return agnt.outputs

def gaussian(x, amp, mean, stddev):
    """Gaussian bump ``amp * exp(-(x - mean)^2 / (2 * stddev^2))``.

    Used as the model function for ``curve_fit``; ``x`` may be a scalar or a
    NumPy array.
    """
    squared_distance = (x - mean) ** 2
    return amp * np.exp(-squared_distance / (2 * stddev ** 2))

In [None]:
# Calculating sensitivity
# For every network: record the states visited while solving the mazes,
# measure the Frobenius norm of the Jacobian of `forward` at each state, and
# fit a Gaussian to that norm as a function of the left/right input imbalance.
sensor_noise_scale = 0.05 
n_steps = 6 

maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0, gaps=2)

# One entry per network (fitted amp, mean, stddev), or NaNs when the fit failed.
popts = []
for a, i in enumerate(tqdm(interest)): # for each architecture
    for v in np.where(np.array(idxs) == i)[0]: # for each network
        # `forward` reads the notebook-level `agnt`, so rebind it per network.
        agnt = copy.deepcopy(agents[v])
        all_states = []        

        # Collect states
        for mz_n, mz in enumerate(maze.mazes):
            # Run trial
            _, _, trial_states = multimodal_mazes.maze_trial(
                mz=mz,
                mz_start_loc=maze.start_locations[mz_n],
                mz_goal_loc=maze.goal_locations[mz_n],
                channels=[1,1],
                sensor_noise_scale=sensor_noise_scale,
                drop_connect_p=0.0,
                n_steps=n_steps,
                agnt=agnt,
                genome=None,
                config=None,
                record_states=True,
            )

            all_states.extend(trial_states)

        # Calculate jacobian norms    
        xs, ys = [], []
        for state in all_states:
            jm = jacobian(forward, state)
            # Normalised difference between the first two and the next two
            # entries of state[0] (NaN when all four are zero).
            x = (torch.sum(state[0][[0,1]]) - torch.sum(state[0][[2,3]])) / torch.sum(state[0][:4])
            y = torch.norm(jm[0], p='fro')

            xs.append(x)
            ys.append(y)

        # Remove nan values
        xs = np.array(xs)
        ys = np.array(ys)
        y = ys[np.isnan(xs) == False]
        x = xs[np.isnan(xs) == False]
        idx = np.argsort(x)

        # Fit curve
        # Only the fit itself is guarded. Previously the plot call was inside
        # a bare `except`, so a plotting error appended a second (NaN) entry
        # for the same network and misaligned popts with the networks.
        try:
            popt, _ = curve_fit(gaussian, x[idx], y[idx], bounds=((-np.inf, -np.inf, 0), (np.inf, np.inf, np.inf)))
        except (RuntimeError, ValueError):
            popts.append(np.array([np.nan, np.nan, np.nan]))
            continue

        popts.append(popt)
        y_fit = gaussian(x[idx], *popt)

        # Plot curve (i_cols needs one colour per entry of `interest`)
        plt.plot(x[idx], y_fit, color=i_cols[a], alpha=0.75)

plt.xticks(ticks=[-1, 0, 1], labels=("-1.0", "0.0", "1.0"))
plt.xlabel('(L-R)/(L+R)')
plt.ylabel("Sensitivity")

# Empty lines: only there to create one legend entry per architecture.
for a, i in enumerate(interest):
    plt.plot([], [], color=i_cols[a], alpha=0.75, label=i_labels[a])
plt.legend()

In [None]:
# Mean fitted (amp, mean, stddev) per architecture.
# NOTE(review): hard-codes 4 architectures x 10 networks, in the order in
# which popts was filled; adjust if `interest` or the number of loaded
# networks differs. Also rebinds the notebook-level name `labels`.
labels = np.repeat([0,1,2,3], repeats=10)
for l in [0, 1,2,3]:
    print(np.round(np.nanmean(np.array(popts)[labels==l], axis=0),2))

In [None]:
# Fitted Gaussian width (third fit parameter, stddev) of every network,
# grouped and coloured by the architecture label assigned in the previous cell.
for a, popt in enumerate(popts):
    plt.scatter(labels[a], popt[2], color=i_cols[labels[a]], alpha=0.75)

plt.xlabel('Architectures')

### Ablations

In [None]:
# Ablation study: for every loaded network, zero one parameter tensor at a
# time and re-evaluate the fitness on freshly generated mazes.
sensor_noise_scale = 0.05 
n_steps = 6 

maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0, gaps=2)

# networks x (no ablation, L0, L1, L2, S0, S1, B0, B1); columns for flags a
# network does not have stay NaN.
fitness = np.zeros((len(agents), 8)) * np.nan

for a in tqdm(range((len(agents)))): # for each network
    # Output columns to fill, in order: 0 (baseline), then 1 + index of each active flag.
    wm_idx = [0] + list(np.where(agents[a].wm_flags)[0] + 1)
    # np.inf matches no parameter index, so the first pass is the un-ablated
    # baseline; the remaining passes zero parameters 2, 3, ... one at a time.
    # NOTE(review): assumes parameters 0 and 1 are always present and that the
    # remaining ones appear in the same order as the active flags - confirm
    # against AgentDQN.
    for b, wm in enumerate([np.inf] + list(np.arange(2, len(list(agents[a].parameters()))))): # for each weight matrix
        agnt = copy.deepcopy(agents[a])
        for c, param in enumerate(agnt.parameters()):
            if c == wm: 
                param.data = torch.zeros(*param.shape)

        fitness[a, wm_idx.pop(0)] = multimodal_mazes.eval_fitness(genome=None, config=None, channels=[1,1], sensor_noise_scale=sensor_noise_scale, drop_connect_p=0.0, maze=maze, n_steps=n_steps, agnt=agnt)

# Change in fitness relative to each network's own un-ablated baseline.
fitness_norm = fitness - fitness[:,0][:,None]

# np.save appends the ".npy" extension to these names.
np.save("./fitness", fitness)
np.save("./fitness_norm", fitness_norm)

In [None]:
# Fitness of every network under every ablation (column 0 = no ablation).
# The x positions are tiled once per row of `fitness`, so the plot works for
# any number of loaded networks instead of exactly 40.
plt.scatter(x=np.tile(np.arange(0,8), reps=fitness.shape[0]), y=fitness.reshape(-1), color='k', marker='.', alpha=0.5)
# plt.ylim([-0.05, 1.05])

# Overlay the networks of each architecture of interest in its own colour.
for a, i in enumerate(interest):
    selected = fitness[np.array(idxs) == i]
    plt.scatter(x=np.tile(np.arange(0,8), reps=selected.shape[0]), y=selected.reshape(-1), color=i_cols[a], marker='.', alpha=0.5)

plt.xticks(ticks=np.arange(8), labels=("N", "L0", "L1", "L2", "S0", "S1", "B0", "B1"))
plt.xlabel('Ablation')
plt.ylabel("Fitness")

In [None]:
# Mean and standard deviation (over all networks that have the weight matrix)
# of the fitness change caused by each ablation.
plt.errorbar(x=range(8),y=np.nanmean(fitness_norm, axis=0), yerr=np.nanstd(fitness_norm, axis=0), color='k', ls='None', marker='o')
plt.xticks(ticks=np.arange(8), labels=("N", "L0", "L1", "L2", "S0", "S1", "B0", "B1"));
plt.xlabel('Ablation')
plt.ylabel("Change in fitness")

In [None]:
# Same statistic as the previous cell, but per architecture of interest.
# Each architecture is shifted sideways a little so the error bars do not overlap.
# NOTE(review): four offsets are generated, so `interest` may hold at most
# four architectures here.
offsets = np.linspace(start=-0.25, stop=0.25, num=4)
for a, i in enumerate(interest):
    plt.errorbar(
        x=range(8) + np.repeat(offsets[a], repeats=8),
        y=np.nanmean(fitness_norm[np.array(idxs) == i], axis=0), 
        yerr=np.nanstd(fitness_norm[np.array(idxs) == i], axis=0), 
        color=i_cols[a], ls='None', marker='o')

plt.xticks(ticks=np.arange(8), labels=("N", "L0", "L1", "L2", "S0", "S1", "B0", "B1"));
plt.xlabel('Ablation')
plt.ylabel("Change in fitness")

### Memory

In [None]:
# Memory probe: present one input on the first time step only, then feed
# zeros and track the agent's first output over the following steps.
L_input = [0.1,0.1]
# Two alternative inputs for the second channel slot; both are tried per network.
R_inputs = [[0.1, 0.0], [0.0, 0.1]]

for a, i in enumerate(interest): # for each architecture 
    L_outputs = []
    for i_num, v in enumerate(np.where(np.array(idxs) == i)[0]): # for each network
        agnt = copy.deepcopy(agents[v])

        for b in range(2):
            l_outputs = [] 
            
            # Reset agent
            # (in-place multiply by zero clears the state but keeps shapes and types)
            agnt.prev_input *= 0.0
            agnt.hidden *= 0.0
            agnt.prev_output *= 0.0
            agnt.outputs *= 0.0
            agnt.channel_inputs *= 0.0    

            for t in range(6):

                if t == 0:
                    # Input pulse on the first step only
                    agnt.channel_inputs[0] = np.copy(L_input)
                    agnt.channel_inputs[1] = np.copy(R_inputs[b])
                else: 
                    agnt.channel_inputs *= 0.0

                agnt.policy()
                # Output 0 is what the y-axis below labels p(Left).
                l_outputs.append(np.copy(agnt.outputs[0]))
            L_outputs.append(l_outputs)

    # Mean trace over all rows collected for this architecture.
    x = range(np.array(L_outputs).shape[1])
    y = np.nanmean(np.array(L_outputs),axis=0)
    # NOTE(review): L_outputs holds two rows per network, but the SEM divides
    # by sqrt(number of networks) - confirm this is intended.
    sem = np.nanstd(np.array(L_outputs),axis=0) / np.sqrt(i_num + 1)
    plt.plot(x, y, color=i_cols[a], label=i_labels[a])
    plt.fill_between(x, y-sem, y+sem, color=i_cols[a], alpha=0.25, edgecolor=None)

plt.ylim([-0.05,1.05])
plt.legend()
plt.xlabel('Time since input')
plt.ylabel("p(Left)")

### Noise suppression

In [None]:
# Noise suppression: drive every network with pure sensor noise for 6 steps
# and record the spread (standard deviation) of its outputs at each step.
for a, i in enumerate(interest): # for each architecture 
    L_outputs = []
    for i_num, v in enumerate(np.where(np.array(idxs) == i)[0]): # for each network
        agnt = copy.deepcopy(agents[v])

        # 100 independent noise trials per network
        for _ in range(100):
            l_outputs = [] 
            
            # Reset agent
            agnt.prev_input *= 0.0
            agnt.hidden *= 0.0
            agnt.prev_output *= 0.0
            agnt.outputs *= 0.0
            agnt.channel_inputs *= 0.0    

            for t in range(6):
                # Zero-mean Gaussian noise (scale 0.05), clipped to [0, 1] so
                # that negative draws become 0.
                agnt.channel_inputs = np.random.normal(
                    loc=0.0, scale=0.05, size=agnt.channel_inputs.shape
                )
                agnt.channel_inputs = np.clip(agnt.channel_inputs, a_min=0.0, a_max=1.0)
                agnt.policy()
                l_outputs.append(np.std(np.array(agnt.outputs)))
            L_outputs.append(l_outputs)

    # Mean trace over all rows collected for this architecture.
    x = range(np.array(L_outputs).shape[1])
    y = np.nanmean(np.array(L_outputs),axis=0)
    # NOTE(review): L_outputs holds 100 rows per network, but the SEM divides
    # by sqrt(number of networks) - confirm this is intended.
    sem = np.nanstd(np.array(L_outputs),axis=0) / np.sqrt(i_num + 1)
    plt.plot(x, y, color=i_cols[a], label=i_labels[a])
    plt.fill_between(x, y-sem, y+sem, color=i_cols[a], alpha=0.25, edgecolor=None)

plt.ylim([-0.05,0.45])
plt.legend()
plt.xlabel('Time with noise input')
plt.ylabel("Std(p(outputs))")