# DQN notebook

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import neat
import matplotlib.pyplot as plt

from matplotlib import cm

import pickle
import multimodal_mazes
from tqdm import tqdm

import itertools
import torch
import torch.nn as nn
import torch.optim as optim

import copy

## WIP

In [None]:
# Create an HMaze (size=11, n_channels=2) and generate 1000 maze instances to visualise.
# NOTE(review): maze layout/generation semantics live in multimodal_mazes (not visible here).
maze = multimodal_mazes.HMaze(size=11, n_channels=2)
maze.generate(number=1000)

In [None]:
# Show maze `n`: the last channel (wall colormap, values below 0.1 transparent) drawn over
# an equal blend of channel 0 (ultramarine) and channel 1 (magenta).
n = 0

# Colormaps 
from matplotlib import colors

# Copy before customising: `cm.binary` is matplotlib's shared colormap object, so calling
# set_under on it directly would change "binary" for every later plot in the session.
cmap_wall = copy.copy(cm.binary)
cmap_wall.set_under('k', alpha=0)

cmap_ch0 = colors.LinearSegmentedColormap.from_list(
    "", ["white", "xkcd:ultramarine"]
)

cmap_ch1 = colors.LinearSegmentedColormap.from_list(
    "", ["white", "xkcd:magenta"]
)

fig, ax = plt.subplots()

# With clim starting at 0.1, values under 0.1 use the transparent "under" colour set above.
plt.imshow(1 - maze.mazes[n][:,:,-1], clim=[0.1,1.0], cmap=cmap_wall, alpha=0.25, zorder=1)

# Adjust axes 
plt.axis("off")

plt.imshow((cmap_ch0(maze.mazes[n][:,:,0]) + cmap_ch1(maze.mazes[n][:,:,1]))/2, zorder=0)


## Ideas 
* Optuna for hyperparameter search — example notebook: https://colab.research.google.com/github/araffin/tools-for-robotic-rl-icra2022/blob/main/notebooks/optuna_lab.ipynb#scrollTo=E0yEokTDxhrC 

## Fitness vs noise

In [None]:
# Fitness vs noise
# Train one DQN per connectivity pattern in `wm_flags` (training sensor noise 0.05) and
# evaluate each on a held-out set of TrackMazes at every sensor-noise level in `noises`.
noises = np.linspace(start=0.0, stop=0.5, num=21)

# Small set
wm_flags = np.array([[0,0,0,0,0,0,0], [0,0,0,0,0,0,0], [1,1,1,1,0,0,0],[1,1,1,1,1,1,1]])
# Named `line_colors` rather than `colors` so this list does not shadow the
# `matplotlib.colors` module imported by the maze-plotting cell.
line_colors = ["xkcd:gray", [0.0, 0.0, 0.0, 0.5], [0.039, 0.73, 0.71, 1], list(np.array([24, 156, 196, 255]) / 255)]

# Full set 
# wm_flags = np.array(list(itertools.product([0,1], repeat=7)))
# wm_flags = np.vstack((wm_flags[0], wm_flags))
# line_colors = plt.get_cmap("plasma", len(wm_flags)).colors.tolist()

results = np.zeros((len(noises), len(wm_flags)))  # noises x architectures

# Generate mazes: a large training set and a separate 1000-maze test set.
maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=100000, noise_scale=0.0, gaps=2)

maze_test = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze_test.generate(number=1000, noise_scale=0.0, gaps=2)

# Run
for b, wm_flag in enumerate(tqdm(wm_flags)):

    # Control architecture: index 1 repeats the all-zero flags of index 0 but with 34
    # hidden units instead of 8 (presumably a larger feedforward control — confirm).
    n_hidden_units = 34 if b == 1 else 8

    # Train
    agnt = multimodal_mazes.AgentDQN(location=[5,5], channels=[1,1], sensor_noise_scale=0.05, n_hidden_units=n_hidden_units, wm_flags=wm_flag)
    agnt.generate_policy(maze, n_steps=6)

    # Test
    for a, noise in enumerate(noises):
        results[a,b] = multimodal_mazes.eval_fitness(genome=None, config=None, channels=[1,1], sensor_noise_scale=noise, drop_connect_p=0.0, maze=maze_test, n_steps=6, agnt=agnt)

# Plotting
plt.plot([0.05, 0.05], [0,1], ':', color='k', alpha=0.5, label='Training noise')
for b, wm_flag in enumerate(wm_flags):
    plt.plot(noises, results[:,b], color=line_colors[b], label=wm_flag)

plt.ylim([0, 1.05])
plt.ylabel('Fitness')
plt.xlabel('Sensor noise')
plt.legend()

In [None]:
# Fitness vs noise AUC 
# Area under each architecture's fitness-vs-noise curve (trapezoidal rule), shown as a
# stem plot in ascending order of AUC.
# NOTE(review): np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
auc = np.trapz(y=results.T, x=noises, axis=1)
idxs = np.argsort(auc)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5*3,5), sharex=False, sharey=True)
for b, idx in enumerate(idxs): 
    ml, sl, _ = plt.stem(b, auc[idx])
    ml.set_color('k')
    sl.set_color('k')
# plt.xticks(range(len(wm_flags)), policies, rotation='vertical')
plt.ylabel('AUC');

## Loading results

In [None]:
# def calculate_w_norms(agnt):
#     """"""
#     tmp_f, tmp_o = [], []
#     w_norms = np.zeros(9)
#     for a, p in enumerate(agnt.parameters()):

#         if a <= 1: 
#             tmp_f.append(torch.norm(p.data, p="fro"))
#         else: 
#             tmp_o.append(torch.norm(p.data, p="fro"))

#     w_norms[:2] = np.array(tmp_f) / np.array(tmp_f)[0]
#     w_norms[2:][agnt.wm_flags == 1] = np.array(tmp_o) / np.array(tmp_f)[0]

#     return w_norms

In [None]:
import os

# Load trained agents: each repeat directory holds <architecture index>.pickle files plus
# an .ini experiment config that records the maze gap setting.
paths = ['../Results/test' + str(n) + '/' for n in range(901, 951)]

print(paths)

n_archs = 129  # architecture indices expected per repeat directory

noises = np.linspace(start=0.0, stop=0.5, num=21) # noises,
results = np.full((len(noises), n_archs, len(paths)), np.nan) # noises x architectures x repeats
wm_flags = np.full((n_archs, 7, len(paths)), np.nan) # architectures x flags x repeats
n_parameters = np.full((n_archs, len(paths)), np.nan) # architectures x repeats
auc = np.full((n_archs, len(paths)), np.nan) # architectures x repeats 

training_fs = np.full((n_archs, 100, len(paths)), np.nan)  # architectures x data points x repeats
input_sens = [[] for _ in range(n_archs)] # architectures 
input_sens_test_mean = np.full((n_archs, len(paths)), np.nan) # architectures x repeats
input_sens_overall_mean = np.full((n_archs, len(paths)), np.nan) # architectures x repeats
memories = np.full((n_archs, len(noises), 14, len(paths)), np.nan) # architectures x noises x time steps x repeats
# w_norms = np.zeros((129, 9, len(paths))) * np.nan # architectures x weight matricies x repeats

gaps = []
for a, path in enumerate(tqdm(paths)):
    # Record exactly one gap per directory (NaN when unknown) so `gaps` stays aligned with
    # the repeat axis of the arrays above; later cells select repeats with `gaps == b`.
    gap = np.nan
    try:
        for f in os.listdir(path):
            file_path = os.path.join(path, f)
            if f.endswith(".pickle"):
                # NOTE: pickle can execute arbitrary code — only load trusted result files.
                with open(file_path, 'rb') as file:
                    agnt = pickle.load(file)
                idx = int(os.path.splitext(f)[0])  # file stem = architecture index

                results[:, idx, a] = agnt.results
                wm_flags[idx, :, a] = agnt.wm_flags
                n_parameters[idx, a] = agnt.n_parameters
                auc[idx, a] = np.trapz(y=agnt.results, x=noises)

                training_fs[idx, :, a] = agnt.training_fitness
                input_sens[idx].append(agnt.input_sensitivity)
                # Entry 2 of input_sensitivity is treated as the "test" condition.
                input_sens_test_mean[idx, a] = np.nanmean(agnt.input_sensitivity[2][1])
                input_sens_overall_mean[idx, a] = np.nanmean([np.nanmean(sens[1]) for sens in agnt.input_sensitivity])
                memories[idx, :, :, a] = agnt.memory
                # w_norms[idx, :, a] = calculate_w_norms(agnt)

            elif f.endswith(".ini"):
                exp_config = multimodal_mazes.load_exp_config(file_path)
                gap = exp_config["maze_gaps"]
    except Exception as err:
        # Best effort: keep going with the other repeats, but report what was skipped
        # instead of silently hiding it.
        print(f"Could not fully load {path}: {err!r}")
    gaps.append(gap)

gaps = np.array(gaps)
# assert len(np.unique(gaps)) == 1, "Loaded experiments with different gap lengths"
print(gaps)

In [None]:
# Sample efficiency 

# Function as time to threshold: the threshold is 90% of the median test fitness at
# noises[2] (= 0.05, the training noise) over all architectures and repeats.
se_threshold = np.nanmedian(results[2]) * 0.9
# NaN (not 0) for runs that were never loaded, matching the other per-run arrays, so
# nan-aware statistics downstream ignore them instead of counting them as 0.
sample_efficiency = np.full_like(auc, np.nan) # architectures x repeats

n_points = training_fs.shape[1]
for a in range(auc.shape[0]):
    for b in range(auc.shape[1]):
        if np.isnan(training_fs[a, 0, b]):
            continue  # run not loaded
        crossings = np.flatnonzero(training_fs[a, :, b] >= se_threshold)
        # Earlier first crossing -> higher score; a run that reaches the threshold at any
        # point (even once) counts, one that never does scores 0.
        sample_efficiency[a, b] = n_points - crossings[0] if crossings.size else 0

### Comparison across hyperparameters

In [None]:
# WIP Kendall's tau 
# Rank agreement of the 129 architectures between hyperparameter configurations: each
# configuration (value of exp_keys) owns n_repeats consecutive columns, and architectures
# are ranked by their median score within a configuration.
# NOTE(review): `function` (architectures x runs) is not defined anywhere in this notebook
# as seen here — confirm where it comes from before running these cells.
from scipy.stats import kendalltau

n_exps = 11
n_repeats = 25 

exp_keys = np.repeat(np.arange(0, n_exps), repeats=n_repeats)
k_taus = np.zeros((n_exps, n_exps)) * np.nan

xs, ys = [], []
for a in range(n_exps):
    for b in range(n_exps):
        if a != b: 
            x = np.zeros(129)
            y = np.zeros(129)

            # Rank of each architecture (0 = lowest median) in configurations a and b.
            x[np.argsort(np.nanmedian(function[:, exp_keys==a], axis=1))] = np.arange(129)
            y[np.argsort(np.nanmedian(function[:, exp_keys==b], axis=1))] = np.arange(129)

            k_taus[a,b], _ = kendalltau(x, y)

            # Keep every rank pair for the pooled scatter / linear-fit cell.
            xs.extend(np.copy(x))
            ys.extend(np.copy(y))

xs = np.array(xs)
ys = np.array(ys)

In [None]:
# WIP: Kendall's tau figures 
# Configuration pairs with the lowest and highest tau. k_taus is filled for both (a, b) and
# (b, a) and kendalltau is symmetric, so np.where(...)[0] lists both configurations.
interest = []
interest.append(np.where(k_taus == np.nanmin(k_taus))[0])
interest.append(np.where(k_taus == np.nanmax(k_taus))[0])

# Image
plt.imshow(k_taus)
plt.xlabel('Configuration')
plt.ylabel('Configuration')
plt.colorbar()

# Scatters: architecture rankings in the two configurations of each extreme pair.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10,5), sharex=False, sharey=False)
for a, i in enumerate(interest):
    plt.sca(ax[a])
    plt.plot([0, 129], [0, 129], 'k', alpha=0.25)

    x = np.zeros(129)
    y = np.zeros(129)

    x[np.argsort(np.nanmedian(function[:, exp_keys==interest[a][0]], axis=1))] = np.arange(129)
    y[np.argsort(np.nanmedian(function[:, exp_keys==interest[a][1]], axis=1))] = np.arange(129)
    
    plt.scatter(x, y, s=3, c='k')

    if a == 0:
        plt.xlabel('Ranking')
        plt.ylabel('Ranking')
    else: 
        plt.xticks([])
        plt.yticks([])

In [None]:
# Pooled rank-vs-rank scatter over all configuration pairs, with a linear fit.
# NOTE(review): relies on xs/ys from the Kendall's tau cell; other cells also assign these names.
plt.plot([0, 129], [0, 129], 'k', alpha=1.0, label='y=x')

idx = np.argsort(xs)
curve = np.poly1d(np.polyfit(xs[idx],ys[idx],deg=1))
plt.plot(xs[idx], curve(xs[idx]), color='xkcd:turquoise', alpha=1.0, label='Linear fit')

plt.scatter(xs, ys, s=1, c='k', alpha=0.2)

In [None]:
# Comparing hyperparameters 
# Displays architecture indices ordered by median `function` score (ascending) as cell output.

exp_keys = np.repeat(np.arange(0,n_exps), repeats=n_repeats)
np.argsort(np.nanmedian(function, axis=1))

In [None]:
# Median score per hyperparameter configuration for two hand-picked architectures, with a
# grey line joining the pair within each configuration.
interest = [128, 41]

i_cols = ['xkcd:purple', 'xkcd:dark seafoam']
i_labels = ['Fully recurrent', 'Fastest learner']

tmp = []
for a, i in enumerate(interest):
    tmp.append([np.nanmedian(function[i, exp_keys==e]) for e in np.unique(exp_keys)])

tmp = np.array(tmp)
plt.plot([range(n_exps), range(n_exps)], [tmp[0], tmp[1]], c='xkcd:grey', zorder=0);

for a, i in enumerate(interest):
    plt.scatter(range(n_exps), tmp[a], c=i_cols[a], label=i_labels[a])

plt.ylabel('Fitness')
# Tick labels assume exactly 11 configurations in this order (n_exps == 11).
plt.xticks(range(11), ['Baseline', 'n=4', 'n=16', 'sn=0.025', 'sn=0.1', 'lr=0.0001', 'lr=0.01', 'g=0.95', 'g=0.99', 'eps=50000', 'eps=200000'])
plt.xticks(rotation=90)
plt.legend();

In [None]:
# Fitness and robustness to noise (Averaging per experiment)
# Per-configuration medians (architectures x configurations) passed to the ranking plot.
interest = [1,128, np.argmax(np.nanmedian(function, axis=1)), np.argmax(np.nanmedian(auc, axis=1))]
i_cols = ['xkcd:dusty blue', 'xkcd:purple', 'xkcd:dark seafoam', 'xkcd:yellowish']
i_labels = ['Feedforward (L)', 'Fully recurrent', 'Fastest learner', 'Most robust']

tmp_f = np.array([np.nanmedian(function[:, exp_keys == e], axis=1) for e in range(11)]).T
multimodal_mazes.plot_dqn_rankings(y=tmp_f, y_label='Fitness over training', interest=interest, i_cols=i_cols, sig_test=interest[1])

tmp_r = np.array([np.nanmedian(auc[:, exp_keys == e], axis=1) for e in range(11)]).T
multimodal_mazes.plot_dqn_rankings(y=tmp_r, y_label='Robustness', interest=interest, i_cols=i_cols, sig_test=interest[1])

### Fitness, sample efficiency and robustness to noise

In [None]:
# Networks of interest
# Indices: 1 and 128 are fixed references; the other two are the architectures with the
# highest median sample efficiency and the highest median robustness (AUC).
interest = [1,128, np.argmax(np.nanmedian(sample_efficiency, axis=1)), np.argmax(np.nanmedian(auc, axis=1))]
i_cols = ['xkcd:soft blue', 'xkcd:purpley', 'xkcd:magenta', 'xkcd:bright orange']
i_labels = ['Feedforward (L)', 'Fully connected', 'Most sample efficient', 'Most robust']

# np.argwhere((np.nanmax(wm_flags, axis=2) == [1,0,0,1,0,0,0]).all(1))

print(interest)

# Draw each architecture of interest from its per-architecture flags (NaN-max over repeats).
fig, ax = plt.subplots(nrows=1, ncols=len(i_cols), figsize=(5*len(i_cols),5), sharex=True, sharey=True)
for a, i in enumerate(interest):
    plt.sca(ax[a])
    multimodal_mazes.plot_dqn_architecture(np.nanmax(wm_flags, axis=2)[interest[a]], ax=ax[a], color=i_cols[a])

In [None]:
# Select architectures 
# results[2] is fitness at noises[2] (0.05, the training noise).
# NOTE(review): sig_test presumably picks the reference architecture for significance tests — confirm in multimodal_mazes.
multimodal_mazes.plot_dqn_examples(y=results[2], y_label='Fitness', interest=interest, i_cols=i_cols, sig_test=interest[1])
multimodal_mazes.plot_dqn_examples(y=sample_efficiency, y_label='Sample efficiency', interest=interest, i_cols=i_cols, sig_test=interest[1])
multimodal_mazes.plot_dqn_examples(y=auc, y_label='Robustness', interest=interest, i_cols=i_cols, sig_test=interest[1])

In [None]:
# All architectures 
# Ranking plots over all architectures for the three function measures (architectures x repeats each).
multimodal_mazes.plot_dqn_rankings(y=results[2], y_label='Fitness', interest=interest, i_cols=i_cols, sig_test=interest[1])
multimodal_mazes.plot_dqn_rankings(y=sample_efficiency, y_label='Sample efficiency', interest=interest, i_cols=i_cols, sig_test=interest[1])
multimodal_mazes.plot_dqn_rankings(y=auc, y_label='Robustness', interest=interest, i_cols=i_cols, sig_test=interest[1])

In [None]:
# Parameters vs function (average)
# One panel per measure: mean parameter count vs median fitness / sample efficiency /
# robustness, with the architectures of interest highlighted.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(10*3,5), sharex=True, sharey=False)

for a in range(3):
    plt.sca(ax[a])

    x = np.nanmean(n_parameters, axis=1)
    plt.xlabel('Number of parameters')

    if a == 0: 
        y = np.nanmedian(results[2], axis=1)
        plt.ylabel('Fitness')

    elif a == 1:
        y = np.nanmedian(sample_efficiency, axis=1)
        plt.ylabel('Sample efficiency')

    elif a == 2:
        y = np.nanmedian(auc, axis=1)
        plt.ylabel('Robustness')

    plt.scatter(x, y, s=90, c='k', alpha=0.25)

    for b, i in enumerate(interest):
        plt.scatter(x[i], y[i], color=i_cols[b], s=90, alpha=1.0, label=i_labels[b])

In [None]:
# Fitness vs noise 
# Median (over repeats) fitness-vs-noise curve for every architecture in faint black, with
# the architectures of interest in colour; the dotted line marks the training noise 0.05.
plt.plot([0.05, 0.05], [0,1], ':', color='k', alpha=0.5, label='Training noise')

plt.plot([], [], 'k', alpha=0.1, label='All architectures')
for a, i in enumerate(range(129)):
    plt.plot(noises, np.nanmedian(results[:,a], axis=1), color='k', alpha=0.1)

for a, i in enumerate(interest):
    plt.plot(noises, np.nanmedian(results[:,i], axis=1), color=i_cols[a], label=i_labels[a])

plt.ylim([0, 1.05])
plt.ylabel('Fitness')
plt.xlabel('Sensor noise')
plt.legend()

### Sample efficiency 

In [None]:
# print(np.argsort(np.nanmedian(sample_efficiency, axis=1)))

# interest = [87, 128, np.argmax(np.nanmedian(sample_efficiency, axis=1))]

# i_cols = ['k', 'xkcd:purple', 'xkcd:dark seafoam green']
# i_labels = ['Min', 'Fully recurrent', 'Max']

# print(interest)

# fig, ax = plt.subplots(nrows=1, ncols=len(interest), figsize=(5*len(interest),5), sharex=True, sharey=True)
# for a, i in enumerate(interest):
#     plt.sca(ax[a])
#     multimodal_mazes.plot_dqn_architecture(np.nanmax(wm_flags, axis=2)[interest[a]], ax=ax[a], color=i_cols[a])

In [None]:
# Plot median training curve + interquartile range
# Captures the middle 50% of the data

# All networks: pool architectures and repeats per training data point.
x = range(training_fs.shape[1])
l, m, u = np.nanquantile(np.transpose(training_fs, (1,0,2)).reshape(training_fs.shape[1], -1), [0.25, 0.5, 0.75], axis=1)
plt.plot(x, m, color='xkcd:grey', label='All networks')
plt.fill_between(x, l, u, color='xkcd:grey', alpha=0.25, edgecolor=None)

# Interest: quantiles over repeats for each selected architecture.
for a, i in enumerate(interest):
    x = range(training_fs.shape[1])
    l, m, u = np.nanquantile(training_fs[i], [0.25, 0.5, 0.75], axis=1)
    plt.plot(x, m, color=i_cols[a], label=i_labels[a])
    plt.fill_between(x, l, u, color=i_cols[a], alpha=0.25, edgecolor=None)

# NOTE(review): tick labels assume the 100 data points span 200000 training episodes — confirm.
plt.xticks([0, 50, 100], labels=["0", "100000", "200000"])
plt.xlabel('Training episodes')
# plt.ylim([0, 1.05])
plt.ylabel('Fitness')
plt.legend()

### Input sensitivity 

In [None]:
# np.argsort(np.nanmedian(input_sens_test_mean, axis=1))
# interest = [np.argmin(np.nanmedian(input_sens_test_mean, axis=1)), 128, 17]

# i_cols = ['k', 'xkcd:purple', 'xkcd:dark seafoam green']
# i_labels = ['Min', 'Fully recurrent', 'Max']
# print(interest)

# fig, ax = plt.subplots(nrows=1, ncols=len(interest), figsize=(5*len(interest),5), sharex=True, sharey=True)
# for a, i in enumerate(interest):
#     plt.sca(ax[a])
#     multimodal_mazes.plot_dqn_architecture(np.nanmax(wm_flags, axis=2)[interest[a]], ax=ax[a], color=i_cols[a])

In [None]:
# input_sens: lists of architectures, networks, noise levels then x and y
from scipy.optimize import curve_fit

def gaussian(x, amp, mean, stddev):
    """Unnormalised Gaussian bump used as the curve_fit model.

    Returns amp * exp(-(x - mean)**2 / (2 * stddev**2)); works elementwise on arrays.
    """
    squared_distance = (x - mean) ** 2
    spread = 2 * stddev ** 2
    return amp * np.exp(-squared_distance / spread)

In [None]:
# Fit curves (per network)
# Input sens: lists of architectures, networks, noise levels then x and y
# Fits a Gaussian to each network's sensitivity curve at noise-level entry 2 for every
# architecture of interest. Uses its own variable names so the xs/ys pooled by the
# Kendall's tau cell are not overwritten.

fit_failures = 0
for a, i in enumerate(interest):
    for network_sens in input_sens[i]:

        sens_x = np.asarray(network_sens[2][0])
        sens_y = np.asarray(network_sens[2][1])
        order = np.argsort(sens_x)

        if sens_x.size == 0:
            fit_failures += 1
            continue
        try:
            popt, _ = curve_fit(gaussian, sens_x[order], sens_y[order], bounds=((-np.inf, -np.inf, 0), (np.inf, np.inf, np.inf)))
        except (RuntimeError, ValueError):
            # curve_fit raises RuntimeError when it does not converge and ValueError on
            # NaN/inf or otherwise invalid input: skip that network but count it.
            fit_failures += 1
            continue
        y_fit = gaussian(sens_x[order], *popt)
        plt.plot(sens_x[order], y_fit, color=i_cols[a], alpha=0.75)

if fit_failures:
    print(f"Gaussian fit failed for {fit_failures} network(s); they are not plotted.")

plt.xticks(ticks=[-1, 0, 1], labels=("-1.0", "0.0", "1.0"))
plt.xlabel('Input state (L:R)')
plt.ylabel("Sensitivity")

for a, i in enumerate(interest):
    plt.plot([], [], i_cols[a], label=i_labels[a])
plt.legend()

In [None]:
# Fit curves (per architecture)
# Pools the sensitivity points (noise-level entry 2) of all networks of an architecture and
# fits one Gaussian per architecture of interest. Uses its own variable names so the xs/ys
# pooled by the Kendall's tau cell are not overwritten.

fit_failures = 0
for a, i in enumerate(interest):

    pooled_x, pooled_y = [], []

    for network_sens in input_sens[i]:

        pooled_x.extend(network_sens[2][0])
        pooled_y.extend(network_sens[2][1])

    pooled_x = np.array(pooled_x)
    pooled_y = np.array(pooled_y)
    order = np.argsort(pooled_x)

    if pooled_x.size == 0:
        fit_failures += 1
        continue
    try:
        popt, _ = curve_fit(gaussian, pooled_x[order], pooled_y[order], bounds=((-np.inf, -np.inf, 0), (np.inf, np.inf, np.inf)))
    except (RuntimeError, ValueError):
        # Non-convergence (RuntimeError) or invalid data such as NaN (ValueError): skip
        # this architecture but report it below instead of hiding it.
        fit_failures += 1
        continue
    y_fit = gaussian(pooled_x[order], *popt)
    plt.plot(pooled_x[order], y_fit, color=i_cols[a], alpha=1.0)

if fit_failures:
    print(f"Gaussian fit failed for {fit_failures} architecture(s); they are not plotted.")

plt.xticks(ticks=[-1, 0, 1], labels=("-1.0", "0.0", "1.0"))
plt.xlabel('Input state (L:R)')
plt.ylabel("Sensitivity")

for a, i in enumerate(interest):
    plt.plot([], [], i_cols[a], label=i_labels[a])
plt.legend()

In [None]:
# All architectures 
# Ranking plot of the mean "test" input sensitivity (architectures x repeats).
multimodal_mazes.plot_dqn_rankings(y=input_sens_test_mean, y_label="Input sensitivity", interest=interest, i_cols=i_cols, sig_test=interest[1])

### Memory capacity

In [None]:
# memories: architectures x noises x time steps x repeats

# Memory capacity at noises[2]: input-output influence summed over time steps.
memory_capacity_test = np.nansum(memories[:,2], axis=1) # architectures x repeats
# nansum returns 0 for an all-NaN slice; put NaN back for runs that were never loaded so
# they are not treated as zero-memory networks by later nan-aware statistics and models.
memory_capacity_test[np.all(np.isnan(memories[:, 2]), axis=1)] = np.nan

# np.argsort(np.nanmedian(memory_capacity_test, axis=1))
# interest = [1, 128, np.argmax(np.nanmedian(memory_capacity_test, axis=1))]

# i_cols = ['k', 'xkcd:purple', 'xkcd:dark seafoam green']
# i_labels = ['Feedforward (L)', 'Fully recurrent', 'Max']
# print(interest)

# fig, ax = plt.subplots(nrows=1, ncols=len(interest), figsize=(5*len(interest),5), sharex=True, sharey=True)
# for a, i in enumerate(interest):
#     plt.sca(ax[a])
#     multimodal_mazes.plot_dqn_architecture(np.nanmax(wm_flags, axis=2)[interest[a]], ax=ax[a], color=i_cols[a])

In [None]:
# Plot 
# Input-output influence over time at noises[2]: median and interquartile range.
# Arrays are reversed and ticks labelled 0, -1, ..., so memories[..., k] appears at tick -k
# (presumably the influence of the input k steps back — confirm against AgentDQN.memory).

# All networks
x = range(memories.shape[2])
l, m, u = np.nanquantile(np.transpose(memories[:,2], (1,0,2)).reshape(memories.shape[2],-1), [0.25, 0.5, 0.75], axis=1)
plt.plot(x, m[::-1], color='xkcd:grey', label="All networks")
plt.fill_between(x, l[::-1], u[::-1], color='xkcd:grey', alpha=0.25, edgecolor=None)

# Interest
for a, i in enumerate(interest):
    x = range(memories[i,2].shape[0])
    l, m, u = np.nanquantile(memories[i,2], [0.25, 0.5, 0.75], axis=1)
    plt.plot(x, m[::-1], color=i_cols[a], label=i_labels[a])
    plt.fill_between(x, l[::-1], u[::-1], color=i_cols[a], alpha=0.25, edgecolor=None)

plt.hlines(y=1.0, xmin=0, xmax=5, color='k', linestyles='--')
plt.ylim(-0.95, 5)
plt.xlabel('Time')
plt.ylabel('Input-output influence')
plt.xticks(np.arange(memories.shape[2]), labels=np.arange(memories.shape[2])[::-1]*-1)
plt.legend()

In [None]:
# Ranking plot of memory capacity at noises[2] (architectures x repeats).
multimodal_mazes.plot_dqn_rankings(y=memory_capacity_test, y_label="Memory capacity", interest=interest, i_cols=i_cols, sig_test=interest[1])

### Capacities vs parameters 

In [None]:
# Parameters vs function (average)
# Mean parameter count vs median input sensitivity and median memory capacity.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10*2,5), sharex=True, sharey=False)

for a in range(2):
    plt.sca(ax[a])

    x = np.nanmean(n_parameters, axis=1)
    plt.xlabel('Number of parameters')

    if a == 0: 
        y = np.nanmedian(input_sens_test_mean, axis=1)
        plt.ylabel('Input sensitivity')

    elif a == 1:
        y = np.nanmedian(memory_capacity_test, axis=1)
        plt.ylabel('Memory')

    plt.scatter(x, y, s=90, c='k', alpha=0.25)

    for b, i in enumerate(interest):
        plt.scatter(x[i], y[i], color=i_cols[b], s=90, alpha=1.0, label=i_labels[b])

### Comparison across tasks 

#### Fitness

In [None]:
def find_removal_settings(idx_a, idx_b, wm_flags):
    """Return architecture indices on every pruning path from architecture A down to B.

    Starting from A's connectivity flags, every setting obtained by switching off some of
    A's active flags while keeping all of B's active flags is collected (A itself included).
    Each setting is mapped back to all architecture indices whose per-architecture flags
    (NaN-max over repeats) equal it.

    Args:
        idx_a: Index of the larger architecture A.
        idx_b: Index of the smaller architecture B; its active flags must be a subset of A's.
        wm_flags: Array (architectures x flags x repeats) of 0/1 flags, NaN for missing runs.

    Returns:
        List of int architecture indices ordered by number of active flags (fewest first,
        ties kept in generation order). Several indices may share one setting.

    Raises:
        ValueError: If B has more active flags than A, or B's flags are not a subset of A's.
    """
    wm_flags = np.nanmax(wm_flags, axis=2)
    A = wm_flags[idx_a].astype(int)
    B = wm_flags[idx_b].astype(int)

    # Ensure A has more or equal 1s compared to B
    if np.sum(A) < np.sum(B):
        raise ValueError("A should have more parameters than B")
    # Every setting must keep B's connections, which is only possible if A contains them.
    if not np.array_equal(A & B, B):
        raise ValueError("B's active flags must be a subset of A's")

    # A itself, then every removal of one or more of A's active flags that still contains B.
    # Starting the list with A avoids the empty-array vstack failure when A == B.
    removal_settings = [A]
    active = np.flatnonzero(A == 1)
    for n_removed in range(1, len(active) + 1):
        for combination in itertools.combinations(active, n_removed):
            new_setting = A.copy()
            new_setting[list(combination)] = 0
            if np.array_equal(new_setting & B, B):
                removal_settings.append(new_setting)

    removal_settings = np.array(removal_settings)
    # Stable sort so settings with the same number of flags keep their generation order.
    removal_settings = removal_settings[np.argsort(removal_settings.sum(axis=1), kind="stable")]

    idxs = []
    for setting in removal_settings:
        idxs.extend(np.flatnonzero((wm_flags == setting).all(axis=1)))

    return [int(i) for i in idxs]

In [None]:
# Networks of interest
from matplotlib.colors import LinearSegmentedColormap

# np.argwhere((np.nanmax(wm_flags, axis=2) == [1,0,0,1,0,0,0]).all(1)

# Architectures on the pruning paths from 128 down to the architecture with the best
# median fitness at noises[2]; reversed so the most connected comes first.
interest = find_removal_settings(idx_a=128, idx_b=np.argmax(np.nanmedian(results[2,:,:], axis=1)), wm_flags=wm_flags)[::-1]

# N-entry colormap: integer arguments i_cols(a) index its lookup table directly.
i_cols = LinearSegmentedColormap.from_list(
     "", ['xkcd:purple', 'xkcd:dark seafoam'], N=len(interest)
)
# Only the two end points are labelled; intermediate architectures get empty labels.
i_labels = ['Fully recurrent', "Most accurate"]
for _ in range(len(interest) - 2):
     i_labels.insert(1, "")
print(interest)

# NOTE(review): ax[a] assumes len(interest) > 1 (plt.subplots returns a bare Axes otherwise).
fig, ax = plt.subplots(nrows=1, ncols=len(interest), figsize=(len(interest)*5,5), sharex=True, sharey=True)

for a, i in enumerate(interest):
     plt.sca(ax[a])
     multimodal_mazes.plot_dqn_architecture(np.nanmax(wm_flags, axis=2)[interest[a]], ax=ax[a], color=i_cols(a))

In [None]:
# Fitness plot
# Median and interquartile range of fitness at noises[2] per maze gap setting (0-2),
# one offset marker per architecture of interest.
# Assumes `gaps` has one entry per repeat directory (aligned with results' last axis).
fig, ax = plt.subplots(figsize=(15,5))

offsets = np.linspace(start=-0.35, stop=0.35, num=len(interest))

for a, i in enumerate(interest):
    for b in range(3):
        l, m, u = np.nanquantile(results[2][i][gaps==b], [0.25, 0.5, 0.75])
        plt.scatter(offsets[a] + b, m, color=i_cols(a))
        plt.plot([offsets[a] + b, offsets[a] + b], [l, u], color=i_cols(a), alpha=0.5)

plt.xticks([0, 1, 2])
plt.xlabel('Gap')
plt.ylabel('Fitness')

for a, i in enumerate(interest):
    plt.scatter([], [], color=i_cols(a), label=i_labels[a])
plt.legend(loc='lower left')

#### Why? Comparing the first and last architectures of interest

In [None]:
# Memory capacity at noises[2], computed the same way as in the memory-capacity section:
# NaN-aware sum over time steps, NaN for runs that were never loaded.
memory_capacity_test = np.nansum(memories[:, 2], axis=1) # architectures x repeats
memory_capacity_test[np.all(np.isnan(memories[:, 2]), axis=1)] = np.nan

In [None]:
from scipy import stats

# Mann-Whitney U tests (asymptotic, NaNs omitted) between the first and last architectures
# of interest, restricted to repeats trained with maze gap `gap`; prints each p-value.
# NOTE(review): `training_fs_n_auc` is not defined anywhere in this notebook as seen here.
gap = 2

_, p = stats.mannwhitneyu(
    x=training_fs_n_auc[interest[0]][gaps == gap], y=training_fs_n_auc[interest[-1]][gaps == gap], method="asymptotic", nan_policy="omit"
)
print(p)

_, p = stats.mannwhitneyu(
    x=input_sens_test_mean[interest[0]][gaps == gap], y=input_sens_test_mean[interest[-1]][gaps == gap], method="asymptotic", nan_policy="omit"
)
print(p)

_, p = stats.mannwhitneyu(
    x=memory_capacity_test[interest[0]][gaps == gap], y=memory_capacity_test[interest[-1]][gaps == gap], method="asymptotic", nan_policy="omit"
)
print(p)

In [None]:
# Median and interquartile range of memory capacity per maze gap for each architecture of
# interest (uses `offsets` from the fitness-by-gap cell).
for a, i in enumerate(interest):
    for b in range(3):
        # l, m, u = np.nanquantile(input_sens_test_mean[i][gaps==b], [0.25, 0.5, 0.75])
        # l, m, u = np.nanquantile(memory_capacity_test[i][gaps==b] - np.nanmedian(memory_capacity_test[0][gaps == b]), [0.25, 0.5, 0.75])
        
        l, m, u = np.nanquantile(memory_capacity_test[i][gaps==b], [0.25, 0.5, 0.75])
        plt.scatter(offsets[a] + b, m, color=i_cols(a))
        plt.plot([offsets[a] + b, offsets[a] + b], [l, u], color=i_cols(a), alpha=0.5)

plt.xlabel('Gap')
plt.ylabel('Memory')

for a, i in enumerate(interest):
    plt.scatter([], [], color=i_cols(a), label=i_labels[a])
plt.legend(loc='upper left')

## Plotting architectures

In [None]:
# Graph features 
import networkx as nx

# Directed graph per architecture over nodes 0-2 (presumably the network's layers), always
# containing F0 (0->1) and F1 (1->2). Each of the 7 flags, in order, switches on one
# optional edge; a flag counts as "on" when truthy.
optional_edges = [
    (0, 0),  # L0
    (1, 1),  # L1
    (2, 2),  # L2
    (0, 2),  # S0
    (2, 0),  # S1
    (1, 0),  # B0
    (2, 1),  # B1
]

Gs = []
for flag_row in np.nanmax(wm_flags, axis=2):
    graph = nx.DiGraph([(0, 1), (1, 2)])  # F0 and F1
    graph.add_edges_from(edge for edge, enabled in zip(optional_edges, flag_row) if enabled)
    Gs.append(graph)

In [None]:
# All Nodes and edges 
# Draw every possible architecture (all 2**7 flag combinations) in an 8 x 16 grid.
# Stored in its own variable so the loaded `wm_flags` (architectures x flags x repeats)
# used by the Shapley and random-forest cells is not overwritten.
all_wm_flags = np.array(list(itertools.product([0, 1], repeat=7)))
fig, ax = plt.subplots(nrows=8, ncols=16, figsize=(30,15), sharex=True, sharey=True)

for axis, flags in zip(ax.ravel(), all_wm_flags):
    plt.sca(axis)

    multimodal_mazes.plot_dqn_architecture(flags)

In [None]:
# Single unit-unit matrices
# Weight matrix layout for the fully connected case (all 7 flags on), 8 input, 8 hidden
# and 5 output units.
multimodal_mazes.plot_dqn_weight_matrix(n_input_units=8, n_hidden_units=8, n_output_units=5, wm_flag=np.ones(7))
plt.xticks([])
plt.yticks([])

plt.xlabel('Units')
plt.ylabel('Units')

In [None]:
# All unit-unit matrices 
# Weight matrix layout for every flag combination in an 8 x 16 grid. Uses its own variable
# so the loaded `wm_flags` (architectures x flags x repeats) is not overwritten.
all_wm_flags = np.array(list(itertools.product([0, 1], repeat=7)))
fig, ax = plt.subplots(nrows=8, ncols=16, figsize=(30,15), sharex=True, sharey=True)

for axis, flags in zip(ax.ravel(), all_wm_flags):
    plt.sca(axis)

    multimodal_mazes.plot_dqn_weight_matrix(n_input_units=8, n_hidden_units=8, n_output_units=5, wm_flag=flags)
    plt.axis('off')

## Exact Shapley Values

In [None]:
# scores = np.nanmean(results[2,:,:],1) # test accuracy 
scores = np.nanmean(auc, axis=1) # robustness to noise
# Collapse repeats: per-architecture flags (architectures x flags).
# NOTE(review): this rebinds the global `wm_flags` to 2-D; re-running this cell fails
# (axis=2 no longer exists), and cells that expect the 3-D array must handle 2-D input.
wm_flags = np.nanmax(wm_flags, axis=2)

In [None]:
from itertools import combinations
def calculate_shapley_values(scores, wm_flags):
    """Exact Shapley value of each connectivity flag for per-architecture scores.

    The value v(S) of a coalition S of flags is the mean score of the architecture(s) whose
    flags are on exactly for S (duplicated flag rows are averaged, NaN-aware). Flag i gets

        phi_i = sum over S not containing i of |S|! (n - |S| - 1)! / n! * (v(S + {i}) - v(S)),

    so the values sum to v(all flags) - v(no flags).

    Args:
        scores: 1-D array with one score per architecture (row of wm_flags).
        wm_flags: 2-D 0/1 array (architectures x flags). Rows containing NaN are ignored.

    Returns:
        1-D array with one Shapley value per flag.

    Raises:
        ValueError: If a flag combination needed for the exact value has no architecture.
    """
    from math import comb

    scores = np.asarray(scores, dtype=float)
    flags = np.asarray(wm_flags, dtype=float)
    n_hyperparams = flags.shape[1]

    # v(S): mean score over all architectures whose flag vector is exactly S. Grouping
    # by flag vector also handles duplicated architectures (e.g. two all-zero rows).
    grouped = {}
    for row, score in zip(flags, scores):
        if np.isnan(row).any():
            continue
        grouped.setdefault(tuple(int(v) for v in row), []).append(score)
    coalition_value = {key: np.nanmean(vals) for key, vals in grouped.items()}

    def value(members):
        key = tuple(int(j in members) for j in range(n_hyperparams))
        try:
            return coalition_value[key]
        except KeyError:
            raise ValueError(
                f"No architecture with flags {key}; exact Shapley values need every flag combination"
            ) from None

    shapley_values = np.zeros(n_hyperparams)

    # Iterate over each hyperparameter
    for i in range(n_hyperparams):
        others = [j for j in range(n_hyperparams) if j != i]

        # Iterate over all coalitions of the other hyperparameters, by size
        for coalition_size in range(n_hyperparams):
            # Shapley weight |S|!(n-|S|-1)!/n!, written as 1 / (n * C(n-1, |S|)).
            weight = 1.0 / (n_hyperparams * comb(n_hyperparams - 1, coalition_size))
            for coalition in combinations(others, coalition_size):
                members = set(coalition)
                shapley_values[i] += weight * (value(members | {i}) - value(members))

    # Return only after every hyperparameter has been processed.
    return shapley_values

# Calculate 
# Expects `scores` (one per architecture) and the 2-D per-architecture `wm_flags`.
shapley_values = calculate_shapley_values(scores, wm_flags)
plt.scatter(range(len(shapley_values)), shapley_values)

## Random forest models

### Data viz

In [None]:
# Capacities vs function (average)
# 3 x 2 grid: columns = capacity (input sensitivity, memory), rows = function (fitness,
# sample efficiency, robustness). Per-architecture medians over repeats, restricted to
# architectures whose median fitness at noises[2] is >= 0.75, with a quadratic fit.
fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(5*2,5*3), sharex="col", sharey="row")

for a in range(2):
    for b in range(3):
        plt.sca(ax[b,a])

        # Data 
        if a == 0: 
            x = np.nanmedian(input_sens_test_mean, axis=1)
            plt.xlabel('Input sensitivity')

        elif a == 1:
            x = np.nanmedian(memory_capacity_test, axis=1)
            plt.xlabel('Memory')

        if b == 0: 
            y = np.nanmedian(results[2], axis=1)
            plt.ylabel('Fitness')

        elif b == 1:
            y = np.nanmedian(sample_efficiency, axis=1)
            plt.ylabel('Sample efficiency')

        elif b == 2:
            y = np.nanmedian(auc, axis=1)
            plt.ylabel('Robustness')

        # Plot 
        x = x[np.nanmedian(results[2], axis=1) >= 0.75]
        y = y[np.nanmedian(results[2], axis=1) >= 0.75]

        # NOTE(review): np.polyfit does not skip NaN — any NaN median here breaks the fit.
        idx = np.argsort(x)
        curve = np.poly1d(np.polyfit(x[idx],y[idx],deg=2))
        ax[b,a].plot(x[idx], curve(x[idx]), color='k', alpha=1.0)

        plt.scatter(x, y, s=90, c='k', alpha=0.25)
        if a != 0:     
            ax[b,a].get_yaxis().set_visible(False)
        if b != 2: 
            ax[b,a].get_xaxis().set_visible(False)

In [None]:
# Capacities vs function (average)
# Three panels: normalised input sensitivity vs memory, coloured by one function measure each.
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(5*3,5), sharex="all", sharey="all")

cb_labels = ['Fitness', 'Sample efficiency', 'Robustness']

def min_max_norm(x): 
    """Linearly rescale `x` to [0, 1] using its NaN-ignoring minimum and maximum.

    NaN entries stay NaN. A constant input (max == min) maps to zeros instead of dividing
    by zero. numpy's nanmin/nanmax replace the builtin min/max, whose result depends on
    where NaNs sit in the array.
    """
    x = np.asarray(x, dtype=float)
    lo = np.nanmin(x)
    span = np.nanmax(x) - lo
    if span == 0:
        return np.where(np.isnan(x), np.nan, 0.0)
    return (x - lo) / span

# Architectures with median fitness at noises[2] >= 0.6 only; capacities are min-max
# normalised per panel and points coloured by the panel's function measure.
for a in range(3):
    plt.sca(ax[a])

    if a == 0: 
        c = np.nanmedian(results[2], axis=1) # fitness
    elif a == 1: 
        c = np.nanmedian(sample_efficiency, axis=1) # sample efficiency
    elif a == 2: 
        c = np.nanmedian(auc, axis=1) # robustness

    x = np.nanmedian(input_sens_test_mean, axis=1)
    y = np.nanmedian(memory_capacity_test, axis=1)

    x = x[np.where(np.nanmedian(results[2], axis=1) >= 0.6)[0]]
    y = y[np.where(np.nanmedian(results[2], axis=1) >= 0.6)[0]]
    c = c[np.where(np.nanmedian(results[2], axis=1) >= 0.6)[0]]

    plt.scatter(min_max_norm(x), min_max_norm(y), c=c, s=90, alpha=0.75)

    plt.xlabel('Input sensitivity')
    plt.ylabel('Memory')
    plt.colorbar(location="top", shrink=0.75, label=cb_labels[a])

    if a != 0:     
        ax[a].get_yaxis().set_visible(False)
        ax[a].get_xaxis().set_visible(False)

### Model fits

In [None]:
# Model fits
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, KFold

# One row per run (architecture x repeat): 7 structural flags, then the two capacities.
# Accept both the loaded (architectures x flags x repeats) `wm_flags` and the 2-D
# per-architecture array that the Shapley cell rebinds it to.
if wm_flags.ndim == 2:
    flags_per_run = np.repeat(wm_flags[:, None, :], auc.shape[1], axis=1).reshape(-1, 7)
else:
    flags_per_run = np.transpose(wm_flags, (0, 2, 1)).reshape(-1, 7)

X = np.hstack((flags_per_run, # structure
               input_sens_test_mean.reshape(-1)[:,None], # sensitivity 
               memory_capacity_test.reshape(-1)[:, None])) # memory 

y = np.hstack((input_sens_test_mean.reshape(-1)[:,None], # sensitivity
                memory_capacity_test.reshape(-1)[:, None], # memory
                results[2].reshape(-1)[:, None], # fitness
                sample_efficiency.reshape(-1)[:, None], # sample efficiency 
                auc.reshape(-1)[:, None])) # robustness

# Feature subsets: row 0 none, rows 1-9 a single feature (7 flags, sensitivity, memory),
# row 10 both capacities, row 11 all flags, row 12 everything.
x_masks = np.zeros((13, X.shape[1])) # no features
for i in range(9):
    x_masks[i+1, i] = 1
x_masks[10, 7:] = 1 # capacity 
x_masks[11, :7] = 1 # structure
x_masks[12] = 1 # all features 

cv_means = np.full((x_masks.shape[0], y.shape[1]), np.nan) # masks x predictions
cv_stds = np.full_like(cv_means, np.nan) # masks x predictions

for a in tqdm(range(x_masks.shape[0])):
    features = np.flatnonzero(x_masks[a])

    for b in range(y.shape[1]):
        # Drop runs with a missing target or selected feature (e.g. repeats that failed
        # to load): RandomForestRegressor rejects NaN targets.
        valid = ~np.isnan(y[:, b]) & ~np.isnan(X[:, features]).any(axis=1)
        y_b = y[valid, b]

        if a == 0: 
            # Baseline: in-sample RMSE of always predicting the mean (no CV, so no std).
            cv_means[a,b] = np.sqrt(np.mean((y_b - np.mean(y_b)) ** 2))

        else: 
            RF_model = RandomForestRegressor()
            kf = KFold(n_splits=10, shuffle=True)
            scores = cross_val_score(RF_model, X[valid][:, features], y_b, cv=kf, scoring="neg_root_mean_squared_error") * -1
            
            cv_means[a,b] = np.mean(scores)
            cv_stds[a,b] = np.std(scores)

In [None]:
# Errorbar plot
# One panel per target: cross-validated RMSE (mean +/- std over folds) for each feature
# subset; y_labels follow the row order of x_masks.
fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(5*5,5), sharex="none", sharey="all")
x_labels = ['Sensitivity (RMSE)', 'Memory (RMSE)', 'Fitness (RMSE)', 'Sample efficiency (RMSE)', 'Robustness (RMSE)']
y_labels = ["No features", "$W_{ii}$", "$W_{hh}$", "$W_{oo}$", "$W_{io}$", "$W_{oi}$", "$W_{hi}$", "$W_{oh}$", "Input sensitivity", "Memory", "Capacity", "Structure", "All features"]

for a in range(len(x_labels)):
    plt.sca(ax[a])

    plt.errorbar(x=cv_means[:,a], y=range(x_masks.shape[0]), xerr=cv_stds[:,a], fmt='.', c='xkcd:grey')

    plt.xlabel(x_labels[a])
    
    if a == 0: 
        plt.yticks(range(len(y_labels)), y_labels)
    else:  
        ax[a].get_yaxis().set_visible(False)

In [None]:
# Structure-capacity-function plot 
# Nodes: the 7 structural flags (x=0), the two capacities (x=1) and the target (x=2).
# Edge colour = relative RMSE reduction of a single-feature model over the no-feature
# baseline. Rows of cv_means follow x_masks: 1-7 flags, 8 input sensitivity, 9 memory.
y_pred = 4
cv_means_norm = (cv_means[0] - cv_means) / cv_means[0] # models x predictions

import matplotlib
# pyplot.get_cmap works across matplotlib versions; matplotlib.cm.get_cmap (what
# plt.cm.get_cmap refers to) was removed in matplotlib 3.9.
cmap = plt.get_cmap('PiYG')
norm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
fig, ax = plt.subplots()

# Nodes 
plt.scatter(x=np.zeros(7), y=range(7), c='xkcd:light grey', zorder=1)
plt.scatter(x=np.ones(2), y=[2.5, 3.5],c='xkcd:light grey', zorder=1)
plt.scatter(x=2, y=3, c='xkcd:light grey', zorder=1)

# Edges 
for i in range(7):

    # X -> C: flag i predicting input sensitivity (column 0, y=2.5) and memory (column 1,
    # y=3.5). Indices are always in range and NaN maps to the colormap's "bad" colour,
    # so no exception handling is needed.
    for a, o in enumerate([2.5, 3.5]):
        plt.plot([0,1], [i,o], c=cmap(norm(cv_means_norm[i+1,a])), zorder=0)

    # X -> Y 
    plt.plot([0,1], [i, i], c=cmap(norm(cv_means_norm[i+1,y_pred])), zorder=0)
    plt.plot([1,2], [i, 3], c=cmap(norm(cv_means_norm[i+1,y_pred])), zorder=0)

# C -> Y 
for a, i in enumerate([2.5, 3.5]):
    plt.plot([1, 2], [i,3], c=cmap(norm(cv_means_norm[a+8,y_pred])), zorder=0)

plt.yticks(range(7), labels=["$W_{ii}$", "$W_{hh}$", "$W_{oo}$", "$W_{io}$", "$W_{oi}$", "$W_{hi}$", "$W_{oh}$"])

plt.xticks([0,1,2], labels=["", 'Ic, Mc', x_labels[y_pred]], rotation=90)

ax.spines[['left', 'bottom']].set_visible(False)
ax.tick_params(left=False, bottom=False)

cb = fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap),
             ax=ax, orientation='vertical', label='Relative RMSE reduction', shrink=0.5)
cb.outline.set_visible(False)

### Single noise level

In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, KFold
import shap
import seaborn as sns

# Structure
# One row per trained network, one binary column per weight-matrix flag
X = np.transpose(wm_flags, (0, 2, 1)).reshape(-1,7)

# Capacity (for test fitness)
# X = np.zeros((len(auc.reshape(-1)), 2)) * np.nan
# X[:,0] = input_sens_test_mean.reshape(-1) # input sensitivity
# X[:,1] = memory_capacity_test.reshape(-1) # memory capacity

# Capacity (for robustness to noise)
# NOTE(review): column 0 is never filled in this variant and stays NaN — confirm intent.
# X = np.zeros((len(training_fs_n_auc.reshape(-1)), 3)) * np.nan
# X[:,1] = input_sens_overall_mean.reshape(-1) # input sensitivity
# X[:,2] = np.nanmean(np.nansum(memories[:,:,1:], axis=2), axis=1).reshape(-1) # memory capacity

# Function (uncomment exactly one target)
# y = input_sens_test_mean.reshape(-1) # input sensitivity
y = memory_capacity_test.reshape(-1) # memory capacity

# y = results[2].reshape(-1) # fitness
# y = sample_efficiency.reshape(-1) # sample efficiency
# y = auc.reshape(-1) # robustness to noise

# Check alignment *before* masking. Afterwards both arrays are filtered with the same mask,
# so their lengths are equal by construction and the check could never fire; a length
# mismatch would instead surface as an opaque boolean-index IndexError.
assert len(X) == len(y), "Mismatch"

# Remove nan values
valid = ~np.isnan(y)
X = X[valid]
y = y[valid]

In [None]:
# Fit random forest model
# Random forest regressor from the selected features X to the selected target y
# (no random_state is passed, so results vary slightly between runs).
RF_model = RandomForestRegressor()
# Optional held-out estimate: 10-fold CV mean absolute error, compared against a
# "predict the mean" baseline (disabled).
# kf = KFold(n_splits=10, shuffle=True)
# scores = cross_val_score(RF_model, X, y, cv=kf, scoring="neg_mean_absolute_error") * -1
# print("Mean guess " + str(np.mean(abs(np.mean(y) - y))) + " +/- " + str(np.std(abs(np.mean(y) - y))))
# print(str(np.round(np.mean(scores),3)) + " +/- " + str(np.round(np.std(scores),3)))
RF_model.fit(X,y)
# In-sample (training-set) mean absolute error; typically optimistic compared with the CV estimate above
print(np.mean(abs(RF_model.predict(X) - y)))

In [None]:
# Model fit 
# In-sample predicted vs observed target for every network; points on y=x are predicted exactly
plt.scatter(RF_model.predict(X), y, c='k',marker='.', linewidths=0, alpha=0.1, label='All networks')
plt.plot([0,1], [0,1], c='xkcd:turquoise', label='y=x')

# plt.xlim([0, 0.55])
# plt.ylim([0, 0.55])
# NOTE(review): the labels say "robustness", but `y` is whichever target is uncommented in
# the data-preparation cell (currently memory capacity) — keep them in sync.
plt.xlabel('Predicted robustness')
plt.ylabel('Robustness')
plt.legend()

In [None]:
# Shap analysis
# Per-network SHAP attributions of the random forest (networks x features)
explainer = shap.TreeExplainer(model=RF_model, data=X)
shap_values = explainer.shap_values(X, check_additivity=True)

# # feature_importance = np.mean(np.abs(shap_values), axis=0)
# Importance of each binary structural feature = median SHAP value over the networks that
# have it (X[:, i] == 1). A boolean mask selects the same rows as the previous
# np.where-tuple index, and looping over X.shape[1] (7 for the structure matrix) instead of
# a hard-coded 7 avoids an IndexError when X has a different number of columns.
feature_importance = np.array([np.nanmedian(shap_values[X[:, i] == 1, i]) for i in range(X.shape[1])])
feature_ranking = np.argsort(feature_importance) # ascending

In [None]:
# Shap plot 
# One split violin per structural feature, panels ordered left-to-right by ascending
# median SHAP value (feature_ranking). Each violin is split by the feature's value
# (hue = X[:, i]); presumably grey = absent (0) and white = present (1), following the
# palette order — confirm.
labels = ("$W_{ii}$", "$W_{hh}$", "$W_{oo}$", "$W_{io}$", "$W_{oi}$", "$W_{hi}$", "$W_{oh}$")
# NOTE: rebinding `colors` shadows the `matplotlib.colors` module imported at the top of the notebook
colors = ((0.5, 0.5, 0.5), (1, 1, 1))

fig, ax = plt.subplots(nrows=1, ncols=7, figsize=(10*2,5), sharex=True, sharey=True)
for a, i in enumerate(feature_ranking):
    plt.sca(ax[a])
    # Zero-SHAP reference line
    plt.hlines(y=0.0, xmin=-0.5, xmax=0.5, color='xkcd:grey', zorder=0)
    # NOTE(review): seaborn's `cut` is numeric (in bandwidth units), so True acts as 1;
    # use cut=0 if the violins should stop at the data range — confirm intent.
    sns.violinplot(
        x=np.zeros(shap_values.shape[0]),
        y=shap_values[:,i], 
        hue=X[:,i], 
        palette=colors, split=True, inner=None, cut=True)
    
    ax[a].set_xticks([])
    plt.xlabel(labels[i])

    if a == 0: 
        # Only the first panel keeps the y label and the hue legend
        plt.ylabel('Shap value')
        # plt.ylim([-0.1, 0.1])
        # plt.ylim([-0.02, 0.02])
    else:
        ax[a].get_legend().remove()
        ax[a].spines['left'].set_visible(False)           
        ax[a].tick_params(left=False, bottom=False)

    ax[a].spines['bottom'].set_visible(False)

In [None]:
# a = np.array([4, 0, 6, 5, 2, 1, 3])
# b = np.array([0, 2, 3, 5, 4, 1, 6])

# a = np.array([np.where(a == i)[0][0] for i in range(7)])
# b = np.array([np.where(b == i)[0][0] for i in range(7)])

# cols = ["xkcd:teal blue", "xkcd:teal blue", "xkcd:teal blue", "xkcd:topaz", "xkcd:topaz", "xkcd:orange","xkcd:orange"]
# labels = ("$W_{ii}$", "$W_{hh}$", "$W_{oo}$", "$W_{io}$", "$W_{oi}$", "$W_{hi}$", "$W_{oh}$")
# alphas = [1.0, 0.75, 0.5, 1.0, 0.5, 1.0, 0.5]

# for i in range(7):
#     plt.plot([0, 1], [a[i], b[i]], c=cols[i], alpha=alphas[i], label=labels[i]);

# plt.xticks([0,1], ["Fitness", "Robustness"])
# plt.ylabel('Ranking')
# plt.legend()

### Multiple noise levels

In [None]:
from sklearn.ensemble import RandomForestRegressor
import shap

# Fit one random forest per sensor-noise level (first axis of `results`), predicting
# accuracy from capacity features, and collect its SHAP values (networks x features).
# The feature matrix does not depend on the noise level, so it is built once here
# rather than on every iteration.

# Structure
# X_all = np.transpose(wm_flags, (0, 2, 1)).reshape(-1,7)

# Capacity (for test fitness)
X_all = np.zeros((len(auc.reshape(-1)), 3)) * np.nan
X_all[:,0] = training_fs_n_auc.reshape(-1) # training capacity
X_all[:,1] = input_sens_test_mean.reshape(-1) # input sensitivity
X_all[:,2] = np.sum(memories[:,2, 1:], axis=1).reshape(-1) # memory capacity

shap_values = []
for a in tqdm(range(results.shape[0])): # for each level of noise

    y = results[a,:,:].reshape(-1) # accuracy

    # Check alignment before masking (afterwards the lengths are equal by construction)
    assert len(X_all) == len(y), "Mismatch"

    # Drop networks with a NaN target at this noise level
    valid = ~np.isnan(y)
    X = X_all[valid]
    y = y[valid]

    # Fit random forest model
    RF_model = RandomForestRegressor()
    RF_model.fit(X,y)

    # Shap analysis
    explainer = shap.TreeExplainer(model=RF_model, data=X)
    s_v = explainer.shap_values(X, check_additivity=True) # networks x features

    shap_values.append(np.copy(s_v))

In [None]:
# Summarise the per-noise SHAP values into one number per (noise level, feature).
# Entries stay NaN when nothing qualifies (np.nanmean of an empty selection).
shap_summary = np.zeros((len(shap_values), shap_values[0].shape[1])) * np.nan # noises x features
for a in range(shap_summary.shape[0]):
    for b in range(shap_summary.shape[1]):

        if X.shape[1] == 3: 
            # Continuous capacity features: mean of the positive SHAP values only
            shap_summary[a,b] = np.nanmean(shap_values[a][shap_values[a][:,b] > 0, b])
        elif X.shape[1] == 7: 
            # Binary structural features: mean SHAP value over networks that have the feature.
            # NOTE(review): `X` is the matrix left over from the last noise level of the loop
            # above; its rows line up with shap_values[a] only if every noise level kept the
            # same (non-NaN) networks — confirm.
            shap_summary[a,b] = np.nanmean(shap_values[a][X[:,b] == 1, b])

In [None]:
# Summary plot
# Summarised SHAP value of each feature as a function of sensor noise
if X.shape[1] == 3:
    labels = ("Tc", "Ic", "Mc")
    cols = ["k", "k", "k"]
    alphas = [1.0, 0.75, 0.5]

elif X.shape[1] == 7:
    labels = ("L0", "L1", "L2", "S0", "S1", "B0", "B1")
    cols = ["xkcd:teal blue", "xkcd:teal blue", "xkcd:teal blue", "xkcd:topaz", "xkcd:topaz", "xkcd:orange","xkcd:orange"]
    alphas = [1.0, 0.75, 0.5, 1.0, 0.5, 1.0, 0.5]

else:
    raise ValueError(f"No plot styling defined for {X.shape[1]} features")

# y-limits: the non-NaN data range padded by 10% of its span on each side.
# np.min/np.max return NaN if any summary entry is NaN (plt.ylim then fails), and
# multiplying a positive minimum by 1.1 moves it above the data, clipping the lowest curve.
lo, hi = np.nanmin(shap_summary), np.nanmax(shap_summary)
span = hi - lo
pad = 0.1 * span if span > 0 else max(0.1 * abs(hi), 0.01)
ylim = [lo - pad, hi + pad]
plt.vlines(x=0.05, ymin=ylim[0], ymax=ylim[1], linestyles=":", color='k', alpha=0.5, label='Training noise')
for a in range(shap_summary.shape[1]):
    plt.plot(noises, shap_summary[:,a], color=cols[a], alpha=alphas[a], label=labels[a])

plt.xlabel('Sensor noise')
plt.ylabel('Shap value')
plt.ylim(ylim)

plt.legend(loc="upper center")

## Causal inference

In [None]:
from dowhy import CausalModel
import pandas as pd 

def min_max_norm(x):
    """Rescale ``x`` linearly onto [0, 1], ignoring NaNs.

    Args:
        x: array-like of numbers (any shape). NaN entries are allowed and stay NaN.

    Returns:
        Float ndarray of the same shape in which the smallest non-NaN value maps
        to 0 and the largest to 1. If all non-NaN values are equal, they map to 0
        instead of producing 0/0.
    """
    x = np.asarray(x, dtype=float)
    # np.nanmin/np.nanmax rather than the builtins: builtin min/max iterate in Python and,
    # when NaNs are present, return NaN or a number depending on element order.
    x_min, x_max = np.nanmin(x), np.nanmax(x)
    span = x_max - x_min
    if span == 0:
        return np.where(np.isnan(x), np.nan, 0.0)
    return (x - x_min) / span

# Remove stateless networks
# Set to NaN every row (architecture) of results[2] whose median over its networks is below
# 0.5 (presumably chance level — confirm). results[2] is a view, so `results` itself is modified.
results[2][np.nanmedian(results[2], axis=1) < 0.5] = np.nan

In [None]:
# Structure
# One row per trained network: its 7 binary weight-matrix flags
X = np.transpose(wm_flags, (0, 2, 1)).reshape(-1,7)
# Architecture index of each row. Derived from `results` (architectures x networks) rather
# than a hard-coded 129, so it stays aligned with y for any number of architectures.
Z = np.repeat(np.arange(results.shape[1]), repeats=results.shape[2])

# Capacity (for test accuracy)
C = np.zeros((X.shape[0], 3)) * np.nan
C[:,0] = min_max_norm(training_fs_n_auc.reshape(-1)) # training capacity
C[:,1] = min_max_norm(input_sens_test_mean.reshape(-1)) # input sensitivity
C[:,2] = min_max_norm(np.sum(memories[:,2, 1:], axis=1).reshape(-1)) # memory capacity

# # Capacity (for robustness to noise)
# C = np.zeros((X.shape[0], 3)) * np.nan
# C[:,0] = min_max_norm(training_fs_n_auc.reshape(-1)) # training capacity
# C[:,1] = min_max_norm(input_sens_overall_mean.reshape(-1)) # input sensitivity
# C[:,2] = min_max_norm(np.nanmean(np.sum(memories[:,:,1:], axis=2), axis=1).reshape(-1)) # memory capacity

# Function
y = results[2,:,:].reshape(-1) # test accuracy
# y = auc.reshape(-1) # robustness to noise

# All per-network arrays must be aligned before masking (afterwards they are equal by construction)
assert len(X) == len(C) == len(Z) == len(y), "Mismatch"

# Remove nan values
valid = ~np.isnan(y)
X = X[valid]
C = C[valid]
Z = Z[valid]
y = y[valid]

# Format data: structural flags, normalised capacities and the outcome as DataFrame columns
data = pd.DataFrame(X, columns= ["L0", "L1", "L2", "S0", "S1", "B0", "B1"])
for a, l in enumerate(['Tc', 'Ic', 'Mc']):
    data[l] = C[:,a]
data['y'] = y

In [None]:
# Scatter plots
# Normalised memory capacity (C[:, 2]) against test accuracy for every network
plt.scatter(C[:,2], y, s=1, c='k', alpha=0.1)

# Architectures to highlight. np.nanargmax skips architectures whose median is NaN (e.g.
# rows masked as stateless earlier); np.argmax would return the index of the first NaN
# instead, silently picking a masked architecture as "most accurate" / "most robust".
interest = [1,128, np.nanargmax(np.nanmedian(results[2,:,:], axis=1)), np.nanargmax(np.nanmedian(auc, axis=1))]
i_cols = ['xkcd:dusty blue', 'xkcd:purple', 'xkcd:dark seafoam', 'xkcd:yellowish']
i_labels = ['Feedforward (L)', 'Fully recurrent', 'Most accurate', 'Most robust']

# for a, i in enumerate(interest):
#     plt.scatter(C[Z==i, 1], y[Z==i], c=i_cols[a], s=2)

In [None]:
import warnings
warnings.filterwarnings('ignore', module='dowhy')
warnings.filterwarnings('ignore', module='statsmodel')

# CI[a, b] = linear-regression estimate of the causal effect of column a (treatment) on
# column b (outcome) under the graph below. An estimate is kept only if its 95% confidence
# interval excludes zero and the data-subset refuter does not report a statistically
# significant change; every other entry stays NaN.
columns = data.columns.to_list()
CI = np.zeros((len(columns), len(columns))) * np.nan
for a, t in enumerate(tqdm(columns)):
    for b, o in enumerate(columns):
        model = CausalModel(
            data=data,
            treatment=t,
            outcome=o,
            graph="digraph { "
                "L0 -> y; L1 -> y; L2 -> y; S0 -> y; S1 -> y; B0 -> y; B1 -> y; Tc -> y; Ic -> y; Mc -> y;"
                "L0 -> Tc; L1 -> Tc; L2 -> Tc; S0 -> Tc; S1 -> Tc; B0 -> Tc; B1 -> Tc;"
                "L0 -> Ic; L1 -> Ic; L2 -> Ic; S0 -> Ic; S1 -> Ic; B0 -> Ic; B1 -> Ic;"
                "L0 -> Mc; L1 -> Mc; L2 -> Mc; S0 -> Mc; S1 -> Mc; B0 -> Mc; B1 -> Mc}"
        )
        # model.view_model(layout="dot")

        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)

        estimate = model.estimate_effect(identified_estimand,
                method_name="backdoor.linear_regression")

        try:
            ci = estimate.get_confidence_intervals(confidence_level=0.95)
            r = model.refute_estimate(identified_estimand, estimate, method_name="data_subset_refuter")

            # np.prod(ci) > 0  <=>  both bounds share a sign, i.e. the interval excludes zero
            if (np.prod(ci) > 0) & (r.refutation_result['is_statistically_significant'] == False):
                CI[a,b] = estimate.value
        except Exception:
            # Best effort: leave CI[a, b] as NaN when the interval or refutation fails.
            # A bare `except:` would also swallow KeyboardInterrupt, making this
            # long-running loop impossible to interrupt.
            pass

# Runtime: ~12 min for a single experiment, ~40 min for 11.

In [None]:
# Range of the retained causal effects (NaN entries ignored), used to choose the colour scale below
ci_lo, ci_hi = np.nanmin(CI), np.nanmax(CI)
print(ci_lo, ci_hi)

In [None]:
# Plot
# Causal graph: structural features (x=0) -> capacities Tc/Ic/Mc (x=1) -> fitness (x=2).
# Each edge is coloured by the retained causal effect in CI; NaN entries (no retained effect)
# map to the colormap's "bad" colour, which is transparent by default.
import matplotlib
# plt.cm.get_cmap (i.e. matplotlib.cm.get_cmap) was removed in Matplotlib 3.9;
# plt.get_cmap is the supported equivalent.
cmap = plt.get_cmap('PiYG')
# norm = matplotlib.colors.Normalize(vmin=np.nanmax(abs(CI))*-1, vmax=np.nanmax(abs(CI)))
norm = matplotlib.colors.Normalize(vmin=-0.05, vmax=0.05)
fig, ax = plt.subplots()

# Nodes
plt.scatter(x=np.zeros(7), y=range(7), c='xkcd:light grey', zorder=1)
plt.scatter(x=np.ones(3), y=[2,3,4],c='xkcd:light grey', zorder=1)
plt.scatter(x=2, y=3, c='xkcd:light grey', zorder=1)

# Edges
for i in range(7):

    # X -> C: capacity nodes at y = 2, 3, 4 are CI columns 7, 8, 9 (Tc, Ic, Mc)
    for o in [2,3,4]:
        try:
            plt.plot([0,1], [i,o], c=cmap(norm(CI[i,o+5])), zorder=0)
        except IndexError:
            # CI has no column for this capacity: skip the edge
            pass

    # X -> Y (direct effect, last column of CI). Rows 2-4 would pass straight through the
    # capacity nodes at x=1, so their paths are nudged vertically.
    if (i > 2) & (i < 5):
        a = 0.25
    elif (i ==2):
        a = -0.25
    else:
        a = 0.0
    plt.plot([0,1], [i, i+a], c=cmap(norm(CI[i,-1])), zorder=0)
    plt.plot([1,2], [i+a, 3], c=cmap(norm(CI[i,-1])), zorder=0)

# C -> Y: CI rows 7, 8, 9 (Tc, Ic, Mc) against the outcome column
for i in [2,3,4]:
    plt.plot([1, 2], [i,3], c=cmap(norm(CI[i+5,-1])), zorder=0)

# plt.yticks(range(7), labels=["L0", "L1", "L2", "S0", "S1", "B0", "B1"])
plt.yticks(range(7), labels=["$W_{ii}$", "$W_{hh}$", "$W_{oo}$", "$W_{io}$", "$W_{oi}$", "$W_{hi}$", "$W_{oh}$"])

plt.xticks([0,1,2], labels=["", 'Tc, Ic, Mc', 'Fitness'], rotation=90)

ax.spines[['left', 'bottom']].set_visible(False)
ax.tick_params(left=False, bottom=False)

cb = fig.colorbar(matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap),
             ax=ax, orientation='vertical', label='Causal effect', shrink=0.5)
cb.outline.set_visible(False)

## Exploring function

### Loading networks 

In [None]:
# Result directories to load networks from
paths = ['../Results/test' + str(n) + '/' for n in range(826,851)]
print(paths)

# Architecture indices to load: 1 = feedforward (L), 128 = fully recurrent
interest = [1, 128]
i_cols = ['xkcd:dusty blue', 'xkcd:purple']
i_labels = ['Feedforward (L)', 'Fully recurrent']

# Load every available network of interest; idxs[k] is the architecture index of agents[k]
agents, idxs = [], []
for path in tqdm(paths):

    # Load data
    for f in interest:
        # NOTE: pickle.load can execute arbitrary code — only load result files you trust.
        try:
            with open(path + str(f) + ".pickle", 'rb') as file:
                agnt = pickle.load(file)
        except FileNotFoundError:
            # Not every directory holds every architecture: skip missing files, but let
            # genuine loading errors (corrupt pickle, missing class) surface.
            continue
        agents.append(agnt)
        idxs.append(f)

print(np.unique(idxs, return_counts=True))

### Sensitivity (Jacobian)

In [None]:
# Functions
from scipy.optimize import curve_fit
import copy 

def gaussian(x, amp, mean, stddev):
    """Unnormalised Gaussian bump, used as the model function for ``curve_fit``.

    Args:
        x: point(s) to evaluate at (scalar or ndarray).
        amp: peak height, reached at ``x == mean``.
        mean: centre of the bump.
        stddev: width parameter.

    Returns:
        ``amp * exp(-(x - mean)**2 / (2 * stddev**2))`` with the same shape as ``x``.
    """
    sq_dist = (x - mean) ** 2
    two_var = 2 * stddev ** 2
    return amp * np.exp(-sq_dist / two_var)

In [None]:
# Calculating sensitivity
# For each loaded network of interest: record its states on a fresh set of mazes, compute its
# input-sensitivity curve, fit a Gaussian to it and plot the fit. popts gets exactly one
# (amp, mean, stddev) entry per network, in the order of `interest` (NaNs when the fit fails).
import seaborn as sns

sensor_noise_scale = 0.05
n_steps = 6

maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0, gaps=2)

popts = []
for a, i in enumerate(tqdm(interest)): # for each architecture
    for v in np.where(np.array(idxs) == i)[0]: # for each network
        agnt = copy.deepcopy(agents[v])

        # Collect states
        _, all_states = multimodal_mazes.eval_fitness(
            genome=None, config=None, channels=agnt.channels,
            sensor_noise_scale=sensor_noise_scale, drop_connect_p=0.0,
            maze=maze, n_steps=n_steps, agnt=agnt, record_states=True)

        # Calculate jacobian norms: xs is plotted as the contrast (L-R)/(L+R), ys as sensitivity
        xs, ys = multimodal_mazes.calculate_dqn_input_sensitivity(all_states, agnt)
        idx = np.argsort(xs)

        # Fit curve (amplitude and mean unbounded, stddev >= 0)
        try:
            popt, _ = curve_fit(gaussian, xs[idx], ys[idx], bounds=((-np.inf, -np.inf, 0), (np.inf, np.inf, np.inf)))
        except (RuntimeError, ValueError):
            # curve_fit raises RuntimeError when the fit does not converge and ValueError on
            # invalid (e.g. non-finite) data. Record a NaN fit; a bare `except:` would also hide
            # unrelated errors and KeyboardInterrupt.
            popts.append(np.array([np.nan, np.nan, np.nan]))
            continue
        popts.append(popt)

        # Plot curve. Kept outside the try, so a plotting error cannot add a second
        # entry to popts for the same network.
        y_fit = gaussian(xs[idx], *popt)
        plt.plot(xs[idx], y_fit, color=i_cols[a], alpha=0.75)

plt.xticks(ticks=[-1, 0, 1], labels=("-1.0", "0.0", "1.0"))
plt.xlabel('(L-R)/(L+R)')
plt.ylabel("Sensitivity")

# Empty plots, used only to create one legend entry per architecture
for a, i in enumerate(interest):
    plt.plot([], [], color=i_cols[a], alpha=0.75, label=i_labels[a])
plt.legend()

In [None]:
# Architecture label of each entry in popts. popts is filled network-by-network in the order
# of `interest`, so label k repeats once per loaded network of interest[k]. Counting from
# idxs instead of hard-coding 4 architectures x 10 networks keeps labels the same length as
# popts (the boolean index below fails otherwise).
# NOTE(review): assumes `interest`/`idxs` are unchanged since popts was built.
idxs_arr = np.array(idxs)
labels = np.concatenate([np.full(np.sum(idxs_arr == i), k) for k, i in enumerate(interest)])
# Mean fitted (amp, mean, stddev) per architecture
for k in range(len(interest)):
    print(np.round(np.nanmean(np.array(popts)[labels==k], axis=0),2))

In [None]:
# Fitted Gaussian width (stddev = popt[2]) of every network, one column per architecture
# label; label 0 is left out of the plot.
for pos, fit in enumerate(popts):
    arch = labels[pos]
    if arch == 0:
        continue
    plt.scatter(arch, fit[2], color=i_cols[arch], alpha=0.75)

plt.xlabel('Architectures')

### Ablations

In [None]:
# Ablations: evaluate every network intact and with each of its enabled weight matrices
# zeroed in turn. fitness[k, 0] is the intact network; fitness[k, j] (j = 1..7) is the
# network with weight matrix j-1 ablated (NaN where that matrix is disabled).
sensor_noise_scale = 0.05
n_steps = 6

maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0, gaps=2)

fitness = np.zeros((len(agents), 8)) * np.nan

for a, agent in enumerate(tqdm(agents)): # for each network
    # Column filled by each evaluation, consumed in order: 0 first (no ablation), then one per enabled matrix
    wm_idx = [0] + list(np.where(agent.wm_flags)[0] + 1)
    n_params = len(list(agent.parameters()))
    # np.inf never equals a parameter index -> first pass is the intact network; then each
    # parameter from index 2 onwards is zeroed in turn.
    # NOTE(review): assumes parameters 2.. map one-to-one, in order, onto the enabled wm_flags — confirm.
    for wm in [np.inf] + list(np.arange(2, n_params)): # for each weight matrix
        agnt = copy.deepcopy(agent)
        for c, param in enumerate(agnt.parameters()):
            if c == wm:
                # zeros_like keeps the parameter's dtype and device; torch.zeros(*shape)
                # always uses the default dtype on the default (CPU) device.
                param.data = torch.zeros_like(param.data)

        fitness[a, wm_idx.pop(0)] = multimodal_mazes.eval_fitness(genome=None, config=None, channels=[1,1], sensor_noise_scale=sensor_noise_scale, drop_connect_p=0.0, maze=maze, n_steps=n_steps, agnt=agnt)

# Change in fitness relative to the intact network
fitness_norm = fitness - fitness[:,0][:,None]

# np.save("./fitness", fitness)
# np.save("./fitness_norm", fitness_norm)

In [None]:
# Reload previously saved ablation results instead of re-running the cell above.
# NOTE(review): rows must line up with the current `agents`/`idxs` — confirm the saved arrays
# come from the same set of loaded networks.
fitness = np.load("./fitness.npy")
fitness_norm = np.load("./fitness_norm.npy")

# Architectures to highlight in the ablation plots.
# NOTE(review): 74 and 36 are presumably the "most accurate"/"most robust" indices found in the
# causal-inference section — confirm.
interest = [1, 128, 74, 36]
i_cols = ['xkcd:dusty blue', 'xkcd:purple', 'xkcd:dark seafoam', 'xkcd:yellowish']
i_labels = ['Feedforward (L)', 'Fully recurrent', 'Most accurate', 'Most robust']

In [None]:
# Fitness of every network under each ablation (x = 0: intact, 1..7: one weight matrix zeroed)
plt.scatter(x=np.tile(np.arange(0,8), reps=fitness.shape[0]), y=fitness.reshape(-1), color='k', marker='.', alpha=0.5)
plt.ylim([-0.05, 1.05])

# Overlay the architectures of interest. The number of networks per architecture is counted
# from idxs rather than assumed to be 10, so x and y always have the same length.
idxs_arr = np.array(idxs)
for a, i in enumerate(interest):
    sel = idxs_arr == i
    plt.scatter(x=np.tile(np.arange(0,8), reps=np.sum(sel)), y=fitness[sel].reshape(-1), color=i_cols[a], marker='.', alpha=0.5)

plt.xticks(ticks=np.arange(8), labels=("N", "L0", "L1", "L2", "S0", "S1", "B0", "B1"))
plt.xlabel('Ablation')
plt.ylabel("Fitness")

In [None]:
# Mean +/- std over all networks of the change in fitness caused by each ablation
# (relative to the intact network, column 0)
plt.errorbar(x=range(8),y=np.nanmean(fitness_norm, axis=0), yerr=np.nanstd(fitness_norm, axis=0), color='k', ls='None', marker='o')
plt.xticks(ticks=np.arange(8), labels=("N", "L0", "L1", "L2", "S0", "S1", "B0", "B1"));
plt.xlabel('Ablation')
plt.ylabel("Change in fitness")

In [None]:
# Change in fitness per ablation, split by architecture of interest. Small horizontal offsets
# (one per architecture, derived from len(interest) rather than a fixed 4) keep the error
# bars from overlapping.
offsets = np.linspace(start=-0.1, stop=0.1, num=len(interest))
idxs_arr = np.array(idxs)
for a, i in enumerate(interest):
    sel = idxs_arr == i
    plt.errorbar(
        x=np.arange(8) + offsets[a],
        y=np.nanmean(fitness_norm[sel], axis=0),
        yerr=np.nanstd(fitness_norm[sel], axis=0),
        color=i_cols[a], ls='None', marker='o')

plt.xticks(ticks=np.arange(8), labels=("N", "L0", "L1", "L2", "S0", "S1", "B0", "B1"));
plt.xlabel('Ablation')
plt.ylabel("Change in fitness")

### Memory

In [None]:
# WIP: 
# Run the package's DQN test routine on one loaded network over a range of sensor noise levels.
# NOTE: rebinds the notebook-wide `noises` (now 11 levels), `maze` and `results`, which earlier
# cells use with different shapes/meanings.

# Collect data 
exp_config = multimodal_mazes.load_exp_config("../exp_config.ini")
noises = np.linspace(start=0.0, stop=0.5, num=11)

maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0, gaps=1)

results = []  # overwritten by test_dqn_agent below
v = -1  # the last loaded network
agnt = copy.deepcopy(agents[v])

# Test
results, input_sensitivity, memory = multimodal_mazes.test_dqn_agent(
    maze_test=maze,
    agnt=agnt,
    exp_config=exp_config,
    noises=noises,
)

In [None]:
# As partial derivatives 
from torch.autograd.functional import jacobian

def forward(*states):
    """
    Replay a recorded trial through the global ``agnt`` and return its final outputs.

    Written as a function of a flat tuple of tensors so that
    ``torch.autograd.functional.jacobian`` can differentiate the outputs with
    respect to every recorded state.

    Arguments:
        states: flat tuple of tensors, five per time step, in the order
            inputs, prev_inputs, hidden, prev_outputs, outputs.
            The sensor inputs are therefore every 5th tensor (indices 0, 5, 10, ...).
            Only the sensor inputs of every step and the prev_inputs / hidden /
            prev_outputs of the first step enter the computation; later steps use
            the recurrent state returned by the previous ``agnt.forward`` call.
    Returns:
        agnt.outputs: output activations at the last time step.

    Side effects:
        overwrites ``agnt.channel_inputs`` and ``agnt.outputs``.
    """
    sensor_inputs = states[::5]
    for step, sensor_in in enumerate(sensor_inputs):
        agnt.channel_inputs = sensor_in

        # First step starts from the recorded recurrent state; afterwards the
        # state produced by the previous call is carried forward.
        if step == 0:
            carried = (states[1], states[2], states[3])
        else:
            carried = (prev_input, hidden, prev_output)

        agnt.outputs, prev_input, hidden, prev_output = agnt.forward(
            *carried, tensor_input=True)

    return agnt.outputs

In [None]:
# Collect data 
# Jacobian-based memory: for every network, how strongly the output at each time step depends
# on the sensor input k steps earlier, relative to its dependence on the current input (k = 0).
# results[v, k] = mean of that ratio over all trials and time steps (inf ratios become NaN).
# NOTE: rebinds the notebook-wide `maze` and `results`.
sensor_noise_scale = 0.05
n_steps = 6 

maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0, gaps=1)

results = []
for v in tqdm(range(len(agents))):
    # `forward` reads this global `agnt`; eval_fitness below receives a separate deep copy
    # (presumably so recording states does not alter the network that forward replays).
    agnt = copy.deepcopy(agents[v])

    # Fitness
    _, all_states = multimodal_mazes.eval_fitness(
        genome=None,
        config=None,
        channels=[1,1],
        sensor_noise_scale=sensor_noise_scale,
        drop_connect_p=0.0,
        maze=maze,
        n_steps=n_steps,
        agnt=copy.deepcopy(agnt),
        record_states=True,
    )

    # Calculate 
    memory = [[] for _ in range(n_steps)]  # memory[k]: ratios for lag k

    for a in range(len(all_states)): # for each trial

        trial_states = tuple(itertools.chain.from_iterable(all_states[a])) # a tuple of tensor states (5 per time step)

        for b, _ in enumerate(trial_states[::5]): # for each time point 
            
            # jacobian calls forward(*states) on every recorded state up to and including step b
            jm = jacobian(forward, trial_states[:(5 * (b+1))]) # a tuple of Jacobians (one per state)

            # jm[::5] are the Jacobians w.r.t. the sensor inputs; reversed so c is the lag from step b.
            # NOTE(review): assumes a trial has at most n_steps time points, otherwise memory[c] overflows.
            for c, j in enumerate(jm[::5][::-1]): # for each (sensor) input state (from t backwards)
                memory[c].append(torch.norm(j, p="fro") / torch.norm(jm[::5][-1], p="fro")) # append the norm divided by the norm at time t

    # Store: mean ratio per lag; inf (nonzero norm divided by a zero current-input norm) becomes NaN
    tmp = []
    for i in range(n_steps):
        memory[i] = np.array(memory[i])
        memory[i][np.isinf(memory[i])] = np.nan
        tmp.append(np.nanmean(memory[i]))
    results.append(tmp)

results = np.array(results)  # networks x lags

In [None]:
# Figure 
# Jacobian input-output influence vs. lag, one line per network of architectures 1 and 128.
# Columns are reversed so x runs from the oldest input (label -(n_steps-1)) to the current one (0);
# the grey line marks the current input's influence, which is 1.0 by construction (norm / itself).
plt.plot(results[np.array(idxs)==1, ::-1].T, c=i_cols[0], alpha=0.75);
plt.plot(results[np.array(idxs)==128, ::-1].T, c=i_cols[1], alpha=0.75);
plt.hlines(y=1.0, xmin=0, xmax=5, color='k', alpha=0.25)

plt.xticks(np.arange(n_steps), labels=np.arange(n_steps)[::-1]*-1)
plt.xlabel('Time')
plt.ylabel('Input-output influence')

In [None]:
# As estimated MI

# Estimate memory via mutual information for every network of interest
sensor_noise_scale = 0.05
n_steps = 6

maze = multimodal_mazes.TrackMaze(size=11, n_channels=2)
maze.generate(number=1000, noise_scale=0, gaps=1)

# mi_norms[k]: one row of estimated MI values (one per temporal lag, as plotted below)
# per network of architecture interest[k]
mi_norms = [[] for _ in interest]
idx_array = np.array(idxs)
for arch_pos, arch in enumerate(tqdm(interest)): # for each architecture
    for net in np.where(idx_array == arch)[0]: # for each network
        agnt = copy.deepcopy(agents[net])

        # Record the network's states while it solves the mazes
        _, all_states = multimodal_mazes.eval_fitness(
            genome=None, config=None, channels=agnt.channels,
            sensor_noise_scale=sensor_noise_scale, drop_connect_p=0.0,
            maze=maze, n_steps=n_steps, agnt=agnt, record_states=True)

        mi_norms[arch_pos].append(
            multimodal_mazes.estimate_dqn_memory(all_states=all_states, agnt=agnt, n_steps=n_steps)
        )

# Stack each architecture's rows into an array (networks x lags)
mi_norms = [np.array(rows) for rows in mi_norms]

In [None]:
# Plot
# Mean estimated MI per temporal lag for each architecture, with a +/- SEM band
# (SEM over networks: nanstd / sqrt(number of networks)).
# NOTE: rebinds the notebook-wide `x` and `y`.
for a, i in enumerate(interest):    
    x = range(mi_norms[a].shape[1])
    y = np.nanmean(mi_norms[a], axis=0)
    sem = np.nanstd(mi_norms[a], axis=0) / np.sqrt(mi_norms[a].shape[0])
    plt.plot(x, y, color=i_cols[a], label=i_labels[a])
    plt.fill_between(x, y-sem, y+sem, color=i_cols[a], alpha=0.25, edgecolor=None)

plt.ylim([-0.05,1.50])
plt.legend()
plt.xlabel('Temporal lag')
plt.ylabel("Estimated MI")

In [None]:
# Plot
# Total estimated MI (summed over lags) of every network, one column of points per architecture.
# The x positions are sized from each architecture's actual network count instead of a fixed 10,
# and the x axis shows architectures (it was mislabelled 'Temporal lag').
for a, i in enumerate(interest):
    totals = np.sum(mi_norms[a], axis=1)
    plt.scatter(np.full(len(totals), a), totals, color=i_cols[a])

plt.xticks(range(len(interest)), labels=i_labels, rotation=90)
plt.xlabel('Architecture')
plt.ylabel("Total estimated MI")

In [None]:
# First pass 
# Pulse-response probe: present a single input at t = 0 (channel_inputs[0] <- L_input,
# channel_inputs[1] <- R_inputs[b]), then 5 steps of zero input, and record output 0
# (plotted as p(Left)). Each network is probed twice (b = 0, 1), starting from a zeroed state.
# NOTE(review): channel_inputs[0]/[1] are presumably the left/right sensors (2 channels each) — confirm.

L_input = [0.1,0.1]
R_inputs = [[0.1, 0.0], [0.0, 0.1]]

for a, i in enumerate(interest): # for each architecture 
    L_outputs = []
    for i_num, v in enumerate(np.where(np.array(idxs) == i)[0]): # for each network
        agnt = copy.deepcopy(agents[v])

        for b in range(2):
            l_outputs = [] 
            
            # Reset agent
            agnt.prev_input *= 0.0
            agnt.hidden *= 0.0
            agnt.prev_output *= 0.0
            agnt.outputs *= 0.0
            agnt.channel_inputs *= 0.0    

            for t in range(6):

                if t == 0:
                    agnt.channel_inputs[0] = np.copy(L_input)
                    agnt.channel_inputs[1] = np.copy(R_inputs[b])
                else: 
                    agnt.channel_inputs *= 0.0

                agnt.policy()
                l_outputs.append(np.copy(agnt.outputs[0]))
            L_outputs.append(l_outputs)

    # Mean +/- SEM over all traces of this architecture.
    # NOTE(review): L_outputs holds 2 traces per network, but the SEM divides by
    # sqrt(number of networks) (i_num + 1); i_num is also undefined if no network matched — confirm.
    x = range(np.array(L_outputs).shape[1])
    y = np.nanmean(np.array(L_outputs),axis=0)
    sem = np.nanstd(np.array(L_outputs),axis=0) / np.sqrt(i_num + 1)
    plt.plot(x, y, color=i_cols[a], label=i_labels[a])
    plt.fill_between(x, y-sem, y+sem, color=i_cols[a], alpha=0.25, edgecolor=None)

plt.ylim([-0.05,1.05])
plt.legend()
plt.xlabel('Time since input')
plt.ylabel("p(Left)")

### Noise suppression

In [None]:
# Noise suppression probe: starting from a zeroed state, feed each network 6 steps of pure
# sensor noise (N(0, 0.05) clipped to [0, 1]) and record the std over its output vector at
# every step; 100 noise sequences per network.
for a, i in enumerate(interest): # for each architecture 
    L_outputs = []
    for i_num, v in enumerate(np.where(np.array(idxs) == i)[0]): # for each network
        agnt = copy.deepcopy(agents[v])

        for _ in range(100):
            l_outputs = [] 
            
            # Reset agent
            agnt.prev_input *= 0.0
            agnt.hidden *= 0.0
            agnt.prev_output *= 0.0
            agnt.outputs *= 0.0
            agnt.channel_inputs *= 0.0    

            for t in range(6):
                agnt.channel_inputs = np.random.normal(
                    loc=0.0, scale=0.05, size=agnt.channel_inputs.shape
                )
                agnt.channel_inputs = np.clip(agnt.channel_inputs, a_min=0.0, a_max=1.0)
                agnt.policy()
                l_outputs.append(np.std(np.array(agnt.outputs)))
            L_outputs.append(l_outputs)

    # Mean +/- SEM over all sequences of this architecture.
    # NOTE(review): L_outputs holds 100 sequences per network, but the SEM divides by
    # sqrt(number of networks) (i_num + 1) — confirm the intended n.
    x = range(np.array(L_outputs).shape[1])
    y = np.nanmean(np.array(L_outputs),axis=0)
    sem = np.nanstd(np.array(L_outputs),axis=0) / np.sqrt(i_num + 1)
    plt.plot(x, y, color=i_cols[a], label=i_labels[a])
    plt.fill_between(x, y-sem, y+sem, color=i_cols[a], alpha=0.25, edgecolor=None)

plt.ylim([-0.05,0.45])
plt.legend()
plt.xlabel('Time with noise input')
plt.ylabel("Std(p(outputs))")