In [None]:
from tqdm import tqdm
%run model_evaluation
import torch
from torch import nn, optim
from sklearn.metrics import accuracy_score, confusion_matrix
from collections import defaultdict

import matplotlib.pyplot as plt
import proplot as pplt
import umap

# model, obs_rms, kwargs = load_model_and_env('nav_auxiliary_tasks/nav_aux_wall_1', 0)
# env = gym.make('NavEnv-v0', **kwargs)

save = 'plots/representation_heatmaps/'

%run representation_analysis
%run model_evaluation


def gaussian_smooth(pos, y, extent=(5, 295), num_grid=30, sigma=10,
                    ret_hasval=False, threshold=0.1):
    '''Smooth scattered samples `y` taken at positions `pos` onto a square grid.

    Each grid point gets the Gaussian-weighted mean of the samples whose
    kernel weight exp(-d**2 / (2 * sigma**2)) exceeds `threshold`. Grid points
    with no such sample are left at 0.

    Parameters
    ----------
    pos : array-like, shape (n, 2)
        (x, y) position of every sample.
    y : array-like, shape (n,)
        Sample values, e.g. one unit's activations.
    extent : (float, float)
        Min and max coordinate of the grid; used for both axes.
    num_grid : int
        Number of grid points per axis.
    sigma : float
        Gaussian kernel width, in the units of `pos`.
    ret_hasval : bool
        If True, also return a 0/1 mask marking grid points that had samples.
    threshold : float
        Minimum kernel weight a sample needs to contribute. The default of 0.1
        is the value that used to be hard-coded.

    Returns
    -------
    smoothed : ndarray, shape (num_grid, num_grid)
        Row 0 is the largest y coordinate, so the result can go straight to
        ``imshow(..., extent=...)``.
    hasval : ndarray, shape (num_grid, num_grid)
        Only returned when `ret_hasval` is True.
    '''
    pos = np.asarray(pos, dtype=float)
    y = np.asarray(y)

    grid = np.linspace(extent[0], extent[1], num_grid)
    xs, ys = np.meshgrid(grid, grid)
    # Flip rows so row 0 holds the top (largest) y, matching imshow's origin
    ys = ys[::-1]
    smoothed = np.zeros(xs.shape)
    hasval = np.zeros(xs.shape)
    for i in range(num_grid):
        for j in range(num_grid):
            p = np.array([xs[i, j], ys[i, j]])
            sq_dists = np.sum((pos - p)**2, axis=1)
            g = np.exp(-sq_dists / (2*sigma**2))

            near = g > threshold  # compute the mask once and reuse it
            if not near.any():
                continue  # no nearby samples: value and hasval stay 0
            weights = g[near]
            smoothed[i, j] = np.sum(y[near] * weights) / np.sum(weights)
            hasval[i, j] = 1

    if ret_hasval:
        return smoothed, hasval
    return smoothed


def clean_eps(eps, prune_first=5, activations_key='shared_activations',
             activations_layer=0, clip=False,
             save_inview=True, save_seen=True):
    '''Clean up an eps data dictionary collected from evalu for heatmapping.

    Parameters
    ----------
    eps : dict
        Evaluation output (or the result of `stack_all_ep`) with keys 'dones',
        'activations', 'actions' and 'data'. 'data' holds 'pos', 'angle',
        'poster_in_view' and 'poster_seen'.
    prune_first : int or None
        Number of initial steps to drop from every episode; 0/None disables
        pruning. This value used to be overwritten with 5, so any caller's
        value other than 5 was silently ignored.
    activations_key, activations_layer :
        Which entry of the stacked activations to use, and which layer of it.
    clip : bool
        If True, clip activations to [0, 1].
    save_inview, save_seen : bool
        If True, also store the data split by the 'poster_in_view' /
        'poster_seen' masks. Assumes those hold booleans, since `~` is used
        as a logical not.

    Returns
    -------
    dict
        Keys 'pos', 'activ', 'pinview', 'pseen', 'angles', 'dones' and
        'actions', plus the optional *_inview/*_notinview and
        *_seen/*_notseen splits.
        NOTE: 'dones' is returned unpruned, so after pruning it no longer
        lines up step-for-step with the other arrays.
    '''
    dones = eps['dones'].copy()
    pos = np.vstack(eps['data']['pos'])
    stacked = stack_activations(eps['activations'])
    # asarray so the boolean-mask indexing below also works when the angles
    # arrive as a plain list and no pruning happens
    angles = np.asarray(eps['data']['angle'])
    acts = eps['actions']

    activ = stacked[activations_key][activations_layer, :, :].numpy()
    pinview = np.array(eps['data']['poster_in_view'])
    pseen = np.array(eps['data']['poster_seen'])

    if prune_first and prune_first > 0:
        def _prune(values, join=np.concatenate):
            # Drop the first prune_first steps of every episode, then re-join
            return join([ep[prune_first:] for ep in split_by_ep(values, dones)])

        pos = _prune(pos, np.vstack)
        activ = _prune(activ, np.vstack)
        pinview = _prune(pinview)
        angles = _prune(angles)
        pseen = _prune(pseen)
        acts = _prune(acts)

    if clip:
        activ = np.clip(activ, 0, 1)

    result_dict = {
        'pos': pos,
        'activ': activ,
        'pinview': pinview,
        'pseen': pseen,
        'angles': angles,
        'dones': dones,
        'actions': acts
    }

    if save_inview:
        result_dict.update({
            'pos_inview': pos[pinview],
            'pos_notinview': pos[~pinview],
            'activ_inview': activ[pinview],
            'activ_notinview': activ[~pinview],
            'angles_inview': angles[pinview],
            'angles_notinview': angles[~pinview],
        })
    if save_seen:
        result_dict.update({
            'pos_seen': pos[pseen],
            'pos_notseen': pos[~pseen],
            'activ_seen': activ[pseen],
            'activ_notseen': activ[~pseen],
            'angles_seen': angles[pseen],
            'angles_notseen': angles[~pseen],
        })

    return result_dict
    
    
def stack_all_ep(all_ep):
    '''
    Merge a list of evalu results into a single eps-style dict.

    Each element of `all_ep` is the dict returned by one evalu /
    forced_action_evaluate call. The merged dict has the same layout
    ('dones', 'activations', 'actions', 'data'), so it can be passed straight
    to clean_eps.
    '''
    def gather(getter, join=np.concatenate):
        # Pull one field out of every episode and join the pieces end to end
        return join([getter(ep) for ep in all_ep])

    merged_dones = gather(lambda ep: ep['dones'])
    merged_data = {
        'pos': gather(lambda ep: ep['data']['pos'], np.vstack),
        'angle': gather(lambda ep: ep['data']['angle']),
        'poster_seen': gather(lambda ep: ep['data']['poster_seen']),
        'poster_in_view': gather(lambda ep: ep['data']['poster_in_view']),
    }
    merged_actions = gather(lambda ep: np.vstack(ep['actions']), np.vstack).squeeze()
    merged_activations = [step for ep in all_ep for step in ep['activations']]

    return {
        'dones': merged_dones,
        'activations': merged_activations,
        'actions': merged_actions,
        'data': merged_data,
    }
    

    
def split_by_angle(target, angles):
    '''
    Partition `target` by heading into four quadrants keyed 0-3.

    Inclusive quadrant bounds (radians):
        0: [-pi/4, pi/4]
        1: [pi/4, 3pi/4]
        3: [-3pi/4, -pi/4]
        2: every angle not matched above (headings near +/-pi)
    An angle exactly on a shared boundary (e.g. pi/4) lands in both
    neighbouring quadrants. Returns a dict with keys inserted as 0, 1, 3, 2.
    '''
    bounded = {
        0: (-np.pi/4, np.pi/4),
        1: (np.pi/4, 3*np.pi/4),
        3: (-3*np.pi/4, -np.pi/4),
    }
    result = {}
    covered = np.zeros(angles.shape, dtype=bool)
    for key, (low, high) in bounded.items():
        in_quadrant = (angles >= low) & (angles <= high)
        result[key] = target[in_quadrant]
        covered |= in_quadrant

    # Quadrant 2 wraps around +/-pi, so it is simply the complement
    result[2] = target[~covered]
    return result
    
        
    
def _path_directness(p, goal):
    '''Directness of one sequence of positions `p` relative to `goal`.'''
    d = np.sqrt(np.sum((np.asarray(p) - goal)**2, axis=1))
    # The last distance change is excluded, as in the original analysis
    # TODO(review): confirm why (presumably the terminal step)
    dist_changes = np.diff(d)[:-1]
    # nan (with a numpy RuntimeWarning) if no step changed the distance
    return np.sum(dist_changes < 0) / np.sum(dist_changes != 0)


def compute_directness(all_ep=None, ep=None, pos=None, goal_loc=(250, 70)):
    '''
    Compute the directness of paths taken either from an all_ep (split up
    eps generated from appending evalu() calls) or from a single ep.

    Directness = (# steps whose distance to the goal decreased) /
                 (# steps whose distance to the goal changed),
    with the final step ignored.

    Parameters (the first one that is not None is used)
    ----------
    all_ep : list of evalu result dicts -> ndarray, one value per episode
    ep : single evalu result dict -> scalar
    pos : (n, 2) array of positions -> scalar
    goal_loc : (x, y) goal position; defaults to the previously hard-coded (250, 70)

    Raises
    ------
    ValueError
        If none of all_ep / ep / pos is given.
    '''
    goal = np.array(goal_loc)
    if all_ep is None and ep is None and pos is None:
        raise ValueError('No proper parameters given')

    if all_ep is not None:
        return np.array([_path_directness(np.vstack(e['data']['pos']), goal)
                         for e in all_ep])
    if ep is not None:
        return _path_directness(np.vstack(ep['data']['pos']), goal)
    return _path_directness(pos, goal)
    
        
            
            
def filter_all_ep_directness(all_ep, bound=0.9):
    '''Keep only the episodes whose directness (see compute_directness) exceeds `bound`.

    `bound` used to be ignored in favour of a hard-coded 0.9. The default
    keeps that behaviour for existing callers. Episodes with nan directness
    are dropped.
    '''
    directness = compute_directness(all_ep)
    return [ep for ep, d in zip(all_ep, directness) if d > bound]



def load_heatmaps(file='data/pdistal_rim_heatmap/rim_heatmaps',
                  widths=(4, 8, 16, 32, 64), trials=3):
    '''
    Load the per-unit rim heatmaps and flatten them into one clipped matrix.

    The pickle is expected to map width -> list (per trial) of `width`
    heatmaps, one per unit. It is loaded with pickle, so only use trusted
    files.

    Parameters
    ----------
    file : str
        Path of the pickled heatmap dict.
    widths : iterable of int
        Network widths to include (default: the previously hard-coded set).
    trials : int
        Number of trials to take for each width.

    Returns
    -------
    heatmaps : ndarray, shape (sum(widths) * trials, n_pixels)
        One flattened heatmap per row, clipped to [0, 1]. The row count used
        to be hard-coded as 372; it is now derived from `widths` and `trials`.
    heatmap_idx_to_model : list of [width, trial, unit]
        Maps each row of `heatmaps` back to its model and unit.
    heatmap_model_to_idxs : dict
        width -> list of [start, stop) row ranges, one per trial.
    '''
    import pickle

    with open(file, 'rb') as f:
        all_heatmaps = pickle.load(f)

    heatmaps = []
    heatmap_idx_to_model = []
    heatmap_model_to_idxs = {}

    current_idx = 0
    for width in widths:
        heatmap_model_to_idxs[width] = []
        for trial in range(trials):
            heatmaps.append(all_heatmaps[width][trial])

            #create indexers to map back and forth between heatmap idxs and models
            heatmap_idx_to_model.extend([width, trial, i] for i in range(width))
            heatmap_model_to_idxs[width].append([current_idx, current_idx + width])
            current_idx += width

    heatmaps = np.clip(np.vstack(heatmaps).reshape(current_idx, -1), 0, 1)
    return heatmaps, heatmap_idx_to_model, heatmap_model_to_idxs


def count_labels(clabels, ignore_cluster=None, remove_zeros=False,
                 n_clusters=None):
    '''
    Count how many labels fall in each cluster and convert the counts to ratios.

    Parameters
    ----------
    clabels : array-like of int
        Cluster label of every item. Lists are now accepted; previously a
        plain list produced all-zero counts, because `list == i` is just False.
    ignore_cluster : int, list/tuple/array of int, or None
        Cluster(s) whose counts are zeroed before the ratios are computed.
    remove_zeros : bool
        If True, drop zero entries from both outputs.
    n_clusters : int or None
        Number of clusters. Defaults to the notebook-global `num_clusters`,
        which is what this function always used before.

    Returns
    -------
    cluster_counts, cluster_ratios : ndarray of float
        Ratios are counts / total. They are nan if every count is zero.
    '''
    if n_clusters is None:
        n_clusters = num_clusters  # notebook global; must be set before calling
    labels = np.asarray(clabels)
    cluster_counts = np.array([np.sum(labels == i) for i in range(n_clusters)],
                              dtype=float)

    if ignore_cluster is not None:
        # Fancy indexing handles a single int (incl. numpy ints) and sequences
        cluster_counts[np.asarray(ignore_cluster, dtype=int)] = 0

    cluster_ratios = cluster_counts / np.sum(cluster_counts)

    if remove_zeros:
        cluster_ratios = cluster_ratios[cluster_ratios != 0]
        cluster_counts = cluster_counts[cluster_counts != 0]
    return cluster_counts, cluster_ratios



def pred_kmeans(heatmaps, kmeans):
    '''
    Predict a cluster label for each heatmap.

    Every heatmap is flattened into a single row, so the batch has the
    (n_heatmaps, n_pixels) layout that `kmeans.predict` expects. Returns
    whatever `kmeans.predict` returns: one label per heatmap.
    '''
    flat_rows = np.stack([hm.reshape(-1) for hm in heatmaps])
    return kmeans.predict(flat_rows)

## Preliminary Exploration

In [None]:
# Positions visited by each width-64 trial (panels 0-2) and all trials overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width64_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Directness-filtered width-64 episodes: each trial (panels 0-2) and all overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width64_filt_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Same plot as the first width-64 cell (unfiltered width64_t* files), repeated
# here for side-by-side comparison with the other widths
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width64_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Width-32 trials: each trial (panels 0-2) and all overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width32_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Width-16 trials: each trial (panels 0-2) and all overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width16_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Width-8 trials: each trial (panels 0-2) and all overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width8_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Width-4 trials: each trial (panels 0-2) and all overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width4_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Width-3 trials: each trial (panels 0-2) and all overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width3_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Width-2 trials: each trial (panels 0-2) and all overlaid (panel 3)
fig, ax = pplt.subplots(ncols=4)
for i in range(3):
    with open(f'data/pdistal_rim_heatmap/width2_t{i}', 'rb') as fh:
        ep = pickle.load(fh)

    for panel in (ax[i], ax[3]):
        panel.scatter(ep['pos'].T[0], ep['pos'].T[1], alpha=0.2)
    

In [None]:
# Smoothed activation heatmap of each of the 32 units of width-32 trial 2 (4 x 8 grid)
with open('data/pdistal_rim_heatmap/width32_t2', 'rb') as fh:
    ep = pickle.load(fh)
fig, ax = pplt.subplots(nrows=4, ncols=8, wspace=0, hspace=0)
for i in range(32):
    heatmap = np.clip(gaussian_smooth(ep['pos'], ep['activ'][:, i]), 0, 1)
    ax[i].imshow(heatmap, extent=(5, 295, 5, 295))

In [None]:
# Run the width-64 model (trial 2) from every rim start state, then replay its
# exact action sequences through the width-32 model (trial 2) so both models'
# activations are recorded along identical trajectories.
model_name1 = 'nav_poster_netstructure/nav_pdistal_width64batch200'
model1, obs_rms1, kwargs = load_model_and_env(model_name1, 2)
model_name2 = 'nav_poster_netstructure/nav_pdistal_width32batch200'
# NOTE(review): this overwrites `kwargs`, so both models below run with model2's
# env kwargs -- presumably identical to model1's; confirm
model2, obs_rms2, kwargs = load_model_and_env(model_name2, 2)

#Starting around rim - First generate start points and angles
WINDOW_SIZE = (300, 300)
step_size = 10.
xs = np.arange(0+step_size, WINDOW_SIZE[0], step_size)
ys = np.arange(0+step_size, WINDOW_SIZE[1], step_size)
# thetas = np.linspace(0, 2*np.pi, 12, endpoint=False)
start_points = []
start_angles = []
# Start points along the y=5 / y=295 edges, then the x=5 / x=295 edges; each
# start angle points toward (150, 150)
for x in xs:
    for y in [5., 295.]:
        point = np.array([x, y])
        angle = np.arctan2(150 - y, 150 - x)
        start_points.append(point)
        start_angles.append(angle)
for y in ys:
    for x in [5, 295]:
        point = np.array([x, y])
        angle = np.arctan2(150 - y, 150 - x)
        start_points.append(point)
        start_angles.append(angle)
        
start_points = np.vstack(start_points)

# Free-running episodes of the width-64 model, one per start state
all_ep1 = []
for i in range(len(start_points)):
    kw = kwargs.copy()
    kw['fixed_reset'] = [start_points[i].copy(), start_angles[i].copy()]
    ep = forced_action_evaluate(model1, obs_rms1, seed=0, num_episodes=1, 
                                env_kwargs=kw, data_callback=poster_data_callback,
                                with_activations=True)
    all_ep1.append(ep)
# Action sequences of the width-64 episodes, replayed below
saved_actions = [ep['actions'] for ep in all_ep1]


#Force CW actions to CCW model
# NOTE(review): the comment above says CW/CCW, but the two models here are the
# width-64 and width-32 networks -- confirm what it refers to.
all_ep2 = []
for i in range(len(start_points)):
    # `i` is looked up when the lambda is called; this is only correct if
    # forced_action_evaluate uses the callback before the loop advances
    # (presumably it runs synchronously)
    copied_actions = lambda step: saved_actions[i][step]
    kw = kwargs.copy()
    kw['fixed_reset'] = [start_points[i].copy(), start_angles[i].copy()]
    ep = forced_action_evaluate(model2, obs_rms2, seed=0, num_episodes=1, 
                                env_kwargs=kw, data_callback=poster_data_callback,
                                with_activations=True, forced_actions=copied_actions)
    all_ep2.append(ep)
# saved_actions = [ep['actions'] for ep in all_ep2]

In [None]:
# Heatmaps of the width-32 model's 32 units along the width-64 trajectories
ep = clean_eps(stack_all_ep(all_ep2), prune_first=0)

fig, ax = pplt.subplots(nrows=4, ncols=8, wspace=0, hspace=0)
for i in range(32):
    unit_activity = ep['activ'][:, i]
    smoothed_unit = gaussian_smooth(ep['pos'], unit_activity)
    heatmap = np.clip(smoothed_unit, 0, 1)
    ax[i].imshow(heatmap, extent=(5, 295, 5, 295))

# Data Collection (the file saved by each section is given in parentheses)

In [None]:
# Collect rim trajectories for the three width-64 trials. For each trial, save
# the raw and directness-filtered episodes plus their action sequences.
width = 64

#Starting around rim - First generate start points and angles
WINDOW_SIZE = (300, 300)
step_size = 10.
xs = np.arange(0+step_size, WINDOW_SIZE[0], step_size)
ys = np.arange(0+step_size, WINDOW_SIZE[1], step_size)
start_points = []
start_angles = []
# Start points along the y=5 / y=295 edges, then the x=5 / x=295 edges; each
# start angle points toward (150, 150)
for x in xs:
    for y in [5., 295.]:
        point = np.array([x, y])
        angle = np.arctan2(150 - y, 150 - x)
        start_points.append(point)
        start_angles.append(angle)
for y in ys:
    for x in [5, 295]:
        point = np.array([x, y])
        angle = np.arctan2(150 - y, 150 - x)
        start_points.append(point)
        start_angles.append(angle)

for trial in range(3):
    # Uses `width`; this was hard-coded to 'width64' while the saved file names used {width}
    model_name = f'nav_poster_netstructure/nav_pdistal_width{width}batch200'
    model, obs_rms, kwargs = load_model_and_env(model_name, trial)

    all_ep = []
    for i in range(len(start_points)):
        kw = kwargs.copy()
        kw['fixed_reset'] = [start_points[i].copy(), start_angles[i].copy()]
        ep = forced_action_evaluate(model, obs_rms, seed=0, num_episodes=1,
                                    env_kwargs=kw, data_callback=poster_data_callback,
                                    with_activations=True)
        all_ep.append(ep)

    all_ep_f = filter_all_ep_directness(all_ep)
    eps_f = clean_eps(stack_all_ep(all_ep_f), prune_first=0, save_inview=False, save_seen=False)
    eps = clean_eps(stack_all_ep(all_ep), prune_first=0, save_inview=False, save_seen=False)

    saved_actions = [ep['actions'] for ep in all_ep]
    saved_actions_f = [ep['actions'] for ep in all_ep_f]

    outputs = {
        f'width{width}_t{trial}_acts': saved_actions,
        f'width{width}_filt_t{trial}_acts': saved_actions_f,
        f'width{width}_t{trial}': eps,
        f'width{width}_filt_t{trial}': eps_f,
    }
    for fname, obj in outputs.items():
        with open(f'data/pdistal_rim_heatmap/{fname}', 'wb') as fh:
            pickle.dump(obj, fh)
    

## Constructing some consistent trajectories to use (width64_comb)

Here, we generate conserved trajectories on which all of our agents are tested. They consist of the most direct trajectories that still give good overall spatial coverage, so activations can be measured accurately across the entire area without hallucinations.

**Final paths are saved in data/pdistal_rim_heatmap/width64_comb** which contains (combined_actions, start_points, start_angles)

In [None]:
# Load one saved run. NOTE: relies on the notebook globals `width` and `trial`;
# right after the data-collection cell above these are 64 and 2.
with open(f'data/pdistal_rim_heatmap/width{width}_t{trial}', 'rb') as fh:
    eps = pickle.load(fh)
with open(f'data/pdistal_rim_heatmap/width{width}_t{trial}_acts', 'rb') as fh:
    saved_actions = pickle.load(fh)

In [None]:
#Figuring out where the trajectories differ: overlay episodes start:end of each trial

start = 24
# `end` was never defined in this cell (NameError). None means "through the
# last episode"; narrow the range as needed.
end = None

fig, ax = pplt.subplots(ncols=3)
ax.format(xlim=[0, 300], ylim=[0, 300])

for trial in range(3):
    with open(f'data/pdistal_rim_heatmap/width{width}_t{trial}', 'rb') as fh:
        eps = pickle.load(fh)

    pos = eps['pos']
    dones = eps['dones']
    all_pos = np.vstack(split_by_ep(pos, dones)[start:end])
    ax[trial].scatter(all_pos.T[0], all_pos.T[1], alpha=0.2)


In [None]:
# Compare a single episode (index `exp`) across the three width-64 trials
exp = 31

fig, ax = pplt.subplots(ncols=3)
ax.format(xlim=[0, 300], ylim=[0, 300])

for trial in range(3):
    with open(f'data/pdistal_rim_heatmap/width{width}_t{trial}', 'rb') as fh:
        eps = pickle.load(fh)

    pos = eps['pos']
    dones = eps['dones']
    all_pos = np.vstack(split_by_ep(pos, dones)[exp])
    ax[trial].scatter(all_pos.T[0], all_pos.T[1], alpha=0.2)


**Combine trial 1 (file `_t0`), for fuller coverage of the north trajectories, with trial 3 (file `_t2`), for generally more direct trajectories**

In [None]:
#Generate Conserved Trajectories
# Each episode is taken from trial t2, except the indices in trials_from_t1,
# which come from trial t0. Only episodes whose t2 run has directness > 0.9
# are kept.

trials_from_t1 = [29, 31]
with open(f'data/pdistal_rim_heatmap/width{width}_t0', 'rb') as fh:
    eps1 = pickle.load(fh)
with open(f'data/pdistal_rim_heatmap/width{width}_t0_acts', 'rb') as fh:
    saved_actions1 = pickle.load(fh)
with open(f'data/pdistal_rim_heatmap/width{width}_t2', 'rb') as fh:
    eps2 = pickle.load(fh)
with open(f'data/pdistal_rim_heatmap/width{width}_t2_acts', 'rb') as fh:
    saved_actions2 = pickle.load(fh)

#Check which episodes have enough directness to be considered for these conserved trajectories
all_pos = split_by_ep(eps2['pos'], eps2['dones'])
# dtype=int so the array can still be used as an index when nothing is kept
keep_idxs = np.array([i for i, ep_pos in enumerate(all_pos)
                      if compute_directness(pos=ep_pos) > 0.9], dtype=int)
keep_set = set(keep_idxs.tolist())  # O(1) membership tests below

combined_eps = {}
for key in eps1:
    d1 = split_by_ep(eps1[key], eps1['dones'])
    d2 = split_by_ep(eps2[key], eps2['dones'])
    comb_d = [d1[t] if t in trials_from_t1 else d2[t]
              for t in range(len(d1)) if t in keep_set]
    # Per-step 2-D data (e.g. positions) is stacked row-wise; 1-D data is concatenated
    join = np.vstack if comb_d[0].ndim == 2 else np.concatenate
    combined_eps[key] = join(comb_d)

with open('data/pdistal_rim_heatmap/width64_comb', 'wb') as fh:
    pickle.dump(combined_eps, fh)

combined_actions = [
    saved_actions1[t] if t in trials_from_t1 else saved_actions2[t]
    for t in range(len(saved_actions1)) if t in keep_set
]

keep_start_points = np.array(start_points)[keep_idxs]
keep_start_angles = np.array(start_angles)[keep_idxs]
#Save the actions, start points and angles
with open('data/pdistal_rim_heatmap/width64_comb_acts', 'wb') as fh:
    pickle.dump([combined_actions, keep_start_points, keep_start_angles], fh)


In [None]:

# Positions of the three width-64 trials (panels 0-2) vs the combined
# conserved trajectories (panel 3)
fig, ax = pplt.subplots(ncols=4)
ax.format(xlim=[0, 300], ylim=[0, 300])

for trial in range(3):
    with open(f'data/pdistal_rim_heatmap/width{width}_t{trial}', 'rb') as fh:
        eps = pickle.load(fh)

    pos = eps['pos']
    ax[trial].scatter(pos.T[0], pos.T[1], alpha=0.2)

p = combined_eps['pos']
ax[3].scatter(p.T[0], p.T[1], alpha=0.2)

## Collect activations along fixed trajectories for each width network and trial of agent (width{width}_copied)

To generate activation heatmaps, we record the activations of each model along the fixed trajectories from 1.1 (taken from the effectively optimal width-64 agents).

In [None]:
# For every width and trial, replay the conserved action sequences from their
# start states and record the activations. The per-width file is rewritten
# after every trial, so completed trials are kept if a later one fails.
widths = [2, 3, 4, 8, 16, 32, 64]
num_trials = 10

with open('data/pdistal_rim_heatmap/width64_comb_acts', 'rb') as fh:
    combined_actions, keep_start_points, keep_start_angles = pickle.load(fh)


for width in widths:
    all_eps = []
    model_name = f'nav_poster_netstructure/nav_pdistal_width{width}batch200'

    for trial in tqdm(range(num_trials)):
        model, obs_rms, kwargs = load_model_and_env(model_name, trial)

        all_ep = []
        for i in range(len(keep_start_points)):
            # `i` is looked up when the lambda is called; this is only correct
            # if the evaluator uses the callback before `i` changes
            # (presumably it runs synchronously)
            copied_actions = lambda step: combined_actions[i][step]
            kw = kwargs.copy()
            kw['fixed_reset'] = [keep_start_points[i].copy(), keep_start_angles[i].copy()]
            ep = forced_action_evaluate(model, obs_rms, seed=0, num_episodes=1,
                                        env_kwargs=kw, data_callback=poster_data_callback,
                                        with_activations=True, forced_actions=copied_actions)
            all_ep.append(ep)

        eps = clean_eps(stack_all_ep(all_ep), prune_first=0, save_inview=False, save_seen=False)
        all_eps.append(eps)

        with open(f'data/pdistal_rim_heatmap/width{width}_copied', 'wb') as fh:
            pickle.dump(all_eps, fh)


## Convert collected activations to heatmaps through smoothing (rim_heatmaps)

Use Gaussian smoothing to generate activation heatmaps from the forced trajectory activations for each agent

In [None]:
# Smooth every unit's activations (one column of `activ`) into a heatmap.
# Result: all_heatmaps[width] = one list of per-unit heatmaps per trial.
widths = [2, 3, 4, 8, 16, 32, 64]

all_heatmaps = {}

for width in tqdm(widths):
    with open(f'data/pdistal_rim_heatmap/width{width}_copied', 'rb') as fh:
        all_eps = pickle.load(fh)

    all_heatmaps[width] = [
        [gaussian_smooth(eps['pos'], eps['activ'][:, unit])
         for unit in range(eps['activ'].shape[1])]
        for eps in all_eps
    ]

with open('data/pdistal_rim_heatmap/rim_heatmaps', 'wb') as fh:
    pickle.dump(all_heatmaps, fh)


## Collect trajectories along rim with original policies to compute directness stats (width{width})

Next, with the same rim initial conditions, allow the agents to perform their original policy to collect behavior statistics like directness

In [None]:
# Let every width/trial agent run its own policy from the rim start states and
# save the raw and directness-filtered episodes for behaviour statistics.
widths = [2, 3, 4, 8, 16, 32, 64]
num_trials = 10

#Starting around rim - First generate start points and angles
WINDOW_SIZE = (300, 300)
step_size = 10.
xs = np.arange(0+step_size, WINDOW_SIZE[0], step_size)
ys = np.arange(0+step_size, WINDOW_SIZE[1], step_size)
start_points = []
start_angles = []
# Start points along the y=5 / y=295 edges, then the x=5 / x=295 edges; each
# start angle points toward (150, 150)
for x in xs:
    for y in [5., 295.]:
        point = np.array([x, y])
        angle = np.arctan2(150 - y, 150 - x)
        start_points.append(point)
        start_angles.append(angle)
for y in ys:
    for x in [5, 295]:
        point = np.array([x, y])
        angle = np.arctan2(150 - y, 150 - x)
        start_points.append(point)
        start_angles.append(angle)

start_points = np.vstack(start_points)

# filter_all_ep_directness is defined in the helper-function cell at the top
# of the notebook, so it is not re-defined here.

for width in widths:
    all_eps = []
    all_eps_f = []
    model_name = f'nav_poster_netstructure/nav_pdistal_width{width}batch200'

    for trial in tqdm(range(num_trials)):
        model, obs_rms, kwargs = load_model_and_env(model_name, trial)

        all_ep = []
        for i in range(len(start_points)):
            kw = kwargs.copy()
            kw['fixed_reset'] = [start_points[i].copy(), start_angles[i].copy()]
            ep = forced_action_evaluate(model, obs_rms, seed=0, num_episodes=1,
                                        env_kwargs=kw, data_callback=poster_data_callback,
                                        with_activations=True)
            all_ep.append(ep)

        all_ep_f = filter_all_ep_directness(all_ep)
        eps_f = clean_eps(stack_all_ep(all_ep_f), prune_first=0, save_inview=False, save_seen=False)
        eps = clean_eps(stack_all_ep(all_ep), prune_first=0, save_inview=False, save_seen=False)

        all_eps.append(eps)
        all_eps_f.append(eps_f)

        # Rewrite both files after every trial so completed trials are kept
        # if a later one fails
        with open(f'data/pdistal_rim_heatmap/width{width}', 'wb') as fh:
            pickle.dump(all_eps, fh)
        with open(f'data/pdistal_rim_heatmap/width{width}_filt', 'wb') as fh:
            pickle.dump(all_eps_f, fh)

## Generate KMeans model (kmeans_heatmap_clusterer) and summarize clustering results and behaviors

Generate the KMeans clusterer model, then summarize the behavior data from the width{width} files together with the clustering results. The second block here must be run to produce the figures in section 4.

In [None]:
from sklearn.cluster import KMeans
from scipy.stats import entropy

# Fit the KMeans clusterer on the unit heatmaps and save it.
# NOTE(review): load_heatmaps() defaults to widths 4-64 x 3 trials, a subset of
# the widths/trials stored in rim_heatmaps -- confirm this is intended.
heatmaps, heatmap_idx_to_model, heatmap_model_to_idxs = load_heatmaps()
num_clusters = 9
kmeans = KMeans(n_clusters=num_clusters, n_init=100, random_state=0)
preds = kmeans.fit_predict(heatmaps)

cluster_idxs = [preds == c for c in range(num_clusters)]
cluster_freqs = [cluster.sum() for cluster in cluster_idxs]
# Shannon entropy (natural log) of the cluster-size distribution
print(entropy(cluster_freqs))

fig, ax = pplt.subplots(refwidth=2)
ax.bar(np.arange(num_clusters), cluster_freqs)

with open('data/pdistal_rim_heatmap/kmeans_heatmap_clusterer', 'wb') as fh:
    pickle.dump(kmeans, fh)

### RUN FOR SECTION 4

In [None]:
'''
Older code that puts results into list for plots prior to 4.4. Probably will want to update
those plotting code to use next block dictionary entries
'''

heatmaps, heatmap_idx_to_model, heatmap_model_to_idxs = load_heatmaps()
with open('data/pdistal_rim_heatmap/kmeans_heatmap_clusterer', 'rb') as fh:
    kmeans = pickle.load(fh)

num_clusters = 9
widths = [2, 3, 4, 8, 16, 32, 64]
trials = 10
# Episodes of exactly this length are counted as failures below
# (presumably the step limit -- confirm against the env settings)
MAX_EP_LEN = 202

def convert_labels_to_ratios(clabels):
    '''Fraction of `clabels` assigned to each of the `num_clusters` clusters.'''
    labels = np.asarray(clabels)
    counts = np.array([np.sum(labels == c) for c in range(num_clusters)], dtype=float)
    return counts / len(labels)


# Summarize clustering and behavior statistics
with open('data/pdistal_rim_heatmap/rim_heatmaps', 'rb') as fh:
    all_heatmaps = pickle.load(fh)
results = {}
for width in widths:
    results[width] = []
    with open(f'data/pdistal_rim_heatmap/width{width}', 'rb') as fh:
        all_eps = pickle.load(fh)
    for trial in range(trials):
        eps = all_eps[trial]
        # NOTE(review): eps['pos'] concatenates every episode, so the jump from
        # one episode's last position to the next episode's start also counts
        # as a step here
        directness = compute_directness(pos=eps['pos'])

        ep_dones = split_by_ep(eps['dones'], eps['dones'])
        ep_lens = np.array([ep.shape[0] for ep in ep_dones])
        success_rate = 1 - np.sum(ep_lens == MAX_EP_LEN) / len(ep_lens)
        average_ep_len = np.mean(ep_lens)
        # nan (with a RuntimeWarning) if no episode finished early
        average_succ_ep_len = np.mean(ep_lens[ep_lens < MAX_EP_LEN])

        # Fraction of steps spent on each of the 4 actions
        acts = eps['actions']
        act_ratios = np.array([np.sum(acts == i) for i in range(4)]) / len(acts)

        # Cluster every unit heatmap of this model with the saved KMeans model
        hms = np.vstack([hm.reshape(1, -1) for hm in all_heatmaps[width][trial]])
        labels = kmeans.predict(hms)
        ratios = convert_labels_to_ratios(labels)

        results[width].append([labels, ratios, directness, success_rate,
                               average_ep_len, average_succ_ep_len, act_ratios])


'''
results is a dict indexed by width. Values are lists of model results
Each list item is [labels, ratios, directness, success_rate, av_ep_len, av_succ_ep_len, act_ratios]
'''

In [None]:
'''
Newer code with dictionary entries
'''

heatmaps, heatmap_idx_to_model, heatmap_model_to_idxs = load_heatmaps()
with open('data/pdistal_rim_heatmap/kmeans_heatmap_clusterer', 'rb') as fh:
    kmeans = pickle.load(fh)

num_clusters = 9
widths = [2, 3, 4, 8, 16, 32, 64]
trials = 10
# Episodes of exactly this length are counted as failures below
# (presumably the step limit -- confirm against the env settings)
MAX_EP_LEN = 202

# Summarize clustering and behavior statistics
with open('data/pdistal_rim_heatmap/rim_heatmaps', 'rb') as fh:
    all_heatmaps = pickle.load(fh)
results = {}
for width in widths:
    results[width] = []
    with open(f'data/pdistal_rim_heatmap/width{width}', 'rb') as fh:
        all_eps = pickle.load(fh)
    for trial in range(trials):
        eps = all_eps[trial]
        # NOTE(review): computed over all episodes concatenated, so the jumps
        # between consecutive episodes also count as steps
        directness = compute_directness(pos=eps['pos'])

        ep_dones = split_by_ep(eps['dones'], eps['dones'])
        ep_lens = np.array([ep.shape[0] for ep in ep_dones])
        success_rate = 1 - np.sum(ep_lens == MAX_EP_LEN) / len(ep_lens)
        average_ep_len = np.mean(ep_lens)
        # nan (with a RuntimeWarning) if no episode finished early
        average_succ_ep_len = np.mean(ep_lens[ep_lens < MAX_EP_LEN])

        # Fraction of steps spent on each of the 4 actions
        acts = eps['actions']
        act_ratios = np.array([np.sum(acts == i) for i in range(4)]) / len(acts)

        hms = np.vstack([hm.reshape(1, -1) for hm in all_heatmaps[width][trial]])
        labels = kmeans.predict(hms)
        _, ratios = count_labels(labels, remove_zeros=False)
        # Same as count_labels(..., remove_zeros=True) without a second call
        nonzero = ratios[ratios != 0]
        # Shannon diversity H' = -sum p ln p over the occupied clusters
        hprime = np.sum(-nonzero * np.log(nonzero))

        results[width].append({
            'cluster_labels': labels,
            'cluster_ratios': ratios,
            'directness': directness,
            'success_rate': success_rate,
            'avg_ep_len': average_ep_len,
            'avg_succ_ep_len': average_succ_ep_len,
            'act_ratios': act_ratios,
            'shannon': hprime
        })
        

# Heatmap Clustering

### Heatmaps for all forced rim trajectories (rim_heatmaps)

Convert the saved activations and trajectories from 1.1.1 to smoothed heatmaps

In [None]:
from sklearn.decomposition import PCA

# Explained-variance spectrum of the unit heatmaps
pca = PCA()
pca.fit(heatmaps)

n_components = 30
ev = pca.explained_variance_
plt.plot(ev[:n_components] / ev.sum())
# Fraction of variance captured by the first n_components. This expression
# used to sit mid-cell, so its value was never displayed.
print(ev[:n_components].sum() / ev.sum())

pca = PCA(n_components=n_components)
reduced = pca.fit_transform(heatmaps)

# 2-D projection for plotting. Previously `pca`, not `pca2`, was refit here,
# so reduced2 had 30 columns.
pca2 = PCA(n_components=2)
reduced2 = pca2.fit_transform(heatmaps)

In [None]:
# Total fraction of variance explained by the first 30 principal components
np.sum(pca.explained_variance_ratio_[:30])

In [None]:
# Heatmaps projected onto the first two principal components
plt.scatter(reduced2[:, 0], reduced2[:, 1])

KMeans seems to be doing something quite reasonable. 

### KMeans Clustering

In [None]:
from sklearn.cluster import DBSCAN, KMeans

# Exploratory KMeans on the PCA-reduced heatmaps. n_clusters=8 is scikit-learn's
# default, now stated explicitly. random_state pins the otherwise random
# cluster numbering, so the per-cluster plots below show the same clusters on
# every run.
kmeans = KMeans(n_clusters=8, random_state=0)
preds = kmeans.fit_predict(reduced)

cluster_idxs = [preds == c for c in range(kmeans.n_clusters)]

In [None]:
# Grid of all heatmaps assigned to this cluster, in a near-square layout
cluster = 0
ph = heatmaps[cluster_idxs[cluster], :].reshape(-1, 30, 30)
num_heatmaps = ph.shape[0]
ncols = int(np.ceil(np.sqrt(num_heatmaps)))
# Fewest rows of `ncols` panels that still hold every heatmap
nrows = int(np.ceil(num_heatmaps / ncols))

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for i, hm in enumerate(ph):
    ax[i].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Grid of all heatmaps assigned to this cluster. The row count is trimmed so
# no empty trailing row is drawn (same layout rule as the cluster-0 cell).
cluster = 1
ph = heatmaps[cluster_idxs[cluster], :].reshape(-1, 30, 30)
num_heatmaps = ph.shape[0]
ncols = int(np.ceil(np.sqrt(num_heatmaps)))
nrows = int(np.ceil(num_heatmaps / ncols))

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for i in range(ph.shape[0]):
    ax[i].imshow(ph[i], extent=(5, 295, 5, 295))

In [None]:
# Grid of all heatmaps assigned to this cluster. The row count is trimmed so
# no empty trailing row is drawn (same layout rule as the cluster-0 cell).
cluster = 2
ph = heatmaps[cluster_idxs[cluster], :].reshape(-1, 30, 30)
num_heatmaps = ph.shape[0]
ncols = int(np.ceil(np.sqrt(num_heatmaps)))
nrows = int(np.ceil(num_heatmaps / ncols))

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for i in range(ph.shape[0]):
    ax[i].imshow(ph[i], extent=(5, 295, 5, 295))

In [None]:
# Grid of all heatmaps assigned to this cluster. The row count is trimmed so
# no empty trailing row is drawn (same layout rule as the cluster-0 cell).
cluster = 3
ph = heatmaps[cluster_idxs[cluster], :].reshape(-1, 30, 30)
num_heatmaps = ph.shape[0]
ncols = int(np.ceil(np.sqrt(num_heatmaps)))
nrows = int(np.ceil(num_heatmaps / ncols))

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for i in range(ph.shape[0]):
    ax[i].imshow(ph[i], extent=(5, 295, 5, 295))

### Agglomerative Clustering

Agglomerative clustering lets us draw a dendrogram, which serves as a sanity check on how many clusters to use.

In [None]:
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram

def plot_dendrogram(model, **kwargs):
    '''Draw a scipy dendrogram for a fitted AgglomerativeClustering-style model.

    Builds a scipy linkage matrix with rows [child_a, child_b, distance, n_leaves]
    from ``model.children_`` and ``model.distances_``; ``model.distances_`` must
    therefore exist on the fitted model.

    Args:
        model: fitted model exposing ``children_``, ``distances_`` and ``labels_``.
        **kwargs: forwarded unchanged to ``scipy.cluster.hierarchy.dendrogram``
            (e.g. ``ax``, ``truncate_mode``, ``p``, ``no_plot``).

    Returns:
        The dict returned by ``dendrogram`` (leaf order, coordinates, colours).
    '''
    n_samples = len(model.labels_)

    # Node ids below n_samples are original samples (one leaf each); id
    # n_samples + k is the cluster formed by merge k, whose size is counts[k].
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        counts[i] = sum(1 if child < n_samples else counts[child - n_samples]
                        for child in merge)

    linkage_matrix = np.column_stack(
        [model.children_, model.distances_, counts]
    ).astype(float)

    return dendrogram(linkage_matrix, **kwargs)
    
# model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
# model = model.fit(heatmaps)

# plot_dendrogram(model, p=5, truncate_mode='level')


In [None]:
# Cut the agglomerative tree of heatmaps at `distance_threshold`, plot the
# dendrogram with the cut marked, and build one boolean mask per cluster.
distance_threshold = 31

model = AgglomerativeClustering(distance_threshold=distance_threshold, n_clusters=None)
preds = model.fit_predict(heatmaps)

fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xticks([])
plot_dendrogram(model, truncate_mode='level', ax=ax,
                color_threshold=distance_threshold, no_labels=True)
# axhline spans the whole axes width; the old line between hard-coded x=0 and
# x=3700 stopped short whenever the dendrogram was wider than that.
ax.axhline(distance_threshold, linestyle='--', c='red3', linewidth=1)

num_clusters = int(np.max(preds) + 1)
cluster_idxs = [preds == i for i in range(num_clusters)]

fig.savefig(save + '1_1_2_agglomerative_dendrogram.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 0
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 1
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 2
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 3
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 4
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 5
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 6
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 7
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid and save it.
cluster = 8
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

fig.save(save + f'1_1_2_cluster{cluster}.png')

### Checking whether those agents with circular trajectories have differing classes

In [None]:
# Plot one model's recorded positions next to the first four of its heatmaps.
width = 4
trial = 2

# Close the file handle instead of leaking it.
with open(f'data/pdistal_rim_heatmap/width{width}_t{trial}', 'rb') as f:
    ep = pickle.load(f)

# This model's heatmaps occupy rows heatmap_idxs[0]:heatmap_idxs[1] of `heatmaps`.
heatmap_idxs = heatmap_model_to_idxs[width][trial]
specific_heatmaps = heatmaps[heatmap_idxs[0]:heatmap_idxs[1]].reshape(-1, 30, 30)

fig, ax = pplt.subplots([[0,1,1,0],[2,2,3,3],[4,4,5,5]])

p = ep['pos']
ax[0].scatter(p.T[0], p.T[1], alpha=0.2)
for i in range(4):
    ax[i+1].imshow(specific_heatmaps[i], extent=(5, 295, 5, 295))

fig.save(save + '1_1_3_example_4width_circling_trajectories.png')

In [None]:
def convert_labels_to_bar(labels, num_classes=9):
    '''Count how many entries of `labels` equal each class id.

    Index k of the returned array is the number of labels equal to k, for
    k in 0..num_classes-1; labels outside that range are not counted.
    '''
    labels = np.asarray(labels)
    return np.array([np.sum(labels == k) for k in range(num_classes)])

widths = [4, 8, 16, 32, 64]
# widths = [4]
trials = 3

fig, ax = pplt.subplots(nrows=len(widths), ncols=trials, share=False)
# A panel above each trajectory plot holds that model's cluster-count bar chart.
taxs = ax.panel('t', space=0, share=False)

colors = pplt.Cycle('default').by_key()['color']
hex_to_rgb = lambda h: tuple(int(h.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
rgb_colors = np.array([hex_to_rgb(color) for color in colors])/255

titles = []

ax_idx = 0
for width in widths:
    for trial in range(trials):
        # Close each file handle instead of leaking one per model.
        with open(f'data/pdistal_rim_heatmap/width{width}_t{trial}', 'rb') as f:
            ep = pickle.load(f)
        p = ep['pos']

        # Cluster labels of this model's heatmaps occupy preds[start:end].
        heatmap_idxs = heatmap_model_to_idxs[width][trial]
        labels = preds[heatmap_idxs[0]:heatmap_idxs[1]]

        ax[ax_idx].scatter(p.T[0], p.T[1], alpha=0.2)
        taxs[ax_idx].bar(convert_labels_to_bar(labels), colors=rgb_colors)

        titles.append(f'Width {width}, #{trial}')
        ax_idx += 1

ax.format(title=titles, leftlabels=[f'Width {width}' for width in widths],
         suptitle='Original Trajectories Starting from Edge Initial Conditions and Clustering Classes of Nodes Under Forced Actions')
taxs.format(xlocator=[], ylocator=[])
fig.save(save + '1_1_3_trajectories_and_classes.png')

# Checking some strange forced node behavior

In [None]:
# Load a trained model (per its name, a width-64 nav_pdistal run) whose GRU
# parameters are inspected in the following cells.
model_name = 'nav_poster_netstructure/nav_pdistal_width64batch200'
# NOTE(review): this rebinds `model`, which earlier cells used for the
# AgglomerativeClustering model. `kwkargs` looks like a typo for `kwargs`;
# it is not referenced elsewhere in this section.
model, obs_rms, kwkargs = load_model_and_env(model_name, 0)

In [None]:
# Shapes of every parameter tensor in the model's GRU.
[tensor.shape for tensor in model.base.gru.parameters()]

In [None]:
# Materialise the GRU parameter tensors so they can be indexed.
params = [*model.base.gru.parameters()]

In [None]:
# Histogram of column 1 across the first 64 rows of the first GRU parameter tensor.
plt.hist(params[0].detach()[:64, 1].numpy())

In [None]:
# 8x8 grid of histograms: panel k shows row 64 + k of the second GRU parameter
# tensor (presumably weight_hh_l0 of a 64-unit GRU -- confirm).
fig, ax = pplt.subplots(nrows=8, ncols=8, share=True)
for unit in range(64):
    ax[unit].hist(params[1][64 + unit, :].detach().numpy())

In [None]:
# For each column of the second GRU parameter tensor, scatter its first 64 entries
# (blue) and mark the entry in row 1 with a black cross.
fig, ax = pplt.subplots()
for col in range(64):
    column_vals = params[1][:64, col].detach().numpy()
    ax.scatter([col] * 64, column_vals, c='blue', alpha=0.2)
    ax.scatter([col], params[1][1, col].detach().numpy(), c='black', marker='x')

# KMeans Clustering Tests 

## Generating KMeans Clustering Model

In [None]:
# Plot the 16 heatmaps whose PCA projections lie closest to any KMeans centre.
centers = kmeans.cluster_centers_
# Distance from every sample to every centre, shape (n_samples, n_centres); the
# row minimum is the distance to the nearest centre. (The old code summed the
# squared differences over axis=0, i.e. across centres, which is not a distance.)
dists = np.linalg.norm(reduced[:, None, :] - centers[None, :, :], axis=2).min(axis=1)

heatmap_idxs = np.argsort(dists)[:16]
ph = heatmaps[heatmap_idxs].reshape(-1, 30, 30)

num_heatmaps = ph.shape[0]
nrows = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = int(np.ceil(np.sqrt(num_heatmaps)))
# Drop the last row when it would be completely empty.
if ncols*(nrows-1) >= ph.shape[0]:
    nrows = nrows-1

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for i in range(ph.shape[0]):
    ax[i].imshow(ph[i], extent=(5, 295, 5, 295))

# NOTE(review): `preds` holds whichever clustering ran last (the agglomerative
# cell reassigns it), not necessarily KMeans labels -- confirm which is intended.
labels = preds[heatmap_idxs]
ax.format(title=[f'Cluster {label}' for label in labels])

In [None]:
# Re-cluster the heatmaps and report the entropy of the cluster-size distribution.
model = AgglomerativeClustering(distance_threshold=distance_threshold, n_clusters=None)
preds = model.fit_predict(heatmaps)
num_clusters = int(preds.max() + 1)
cluster_idxs = [preds == label for label in range(num_clusters)]
# Number of heatmaps in each cluster.
cluster_freqs = [mask.sum() for mask in cluster_idxs]

entropy(cluster_freqs)

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 0
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 1
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 2
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 3
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 4
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 5
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 6
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 7
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

In [None]:
# Show every 30x30 heatmap in cluster `cluster` on a near-square grid.
cluster = 8
ph = heatmaps[cluster_idxs[cluster]].reshape(-1, 30, 30)
num_heatmaps = len(ph)
# ceil(sqrt(n)) columns; use one row fewer when the last row would be empty.
side = int(np.ceil(np.sqrt(num_heatmaps)))
ncols = side
nrows = side - 1 if side * (side - 1) >= num_heatmaps else side

fig, ax = pplt.subplots(nrows=nrows, ncols=ncols)
for idx, hm in enumerate(ph):
    ax[idx].imshow(hm, extent=(5, 295, 5, 295))

## Directness and richness experimentation

In [None]:
# One panel per cluster: that cluster's ratio (res[1][i]) against res[2]
# (the value the later success-measure cell labels 'Directness') for every run.
fig, ax = pplt.subplots(nrows=3, ncols=3)

ax.format(title=[f'Cluster {c}' for c in range(num_clusters)])

all_runs = [res for width in widths for res in results[width]]
for i in range(num_clusters):
    x = [run[1][i] for run in all_runs]
    y = [run[2] for run in all_runs]
    ax[i].scatter(x, y)

In [None]:
# One panel per cluster: that cluster's ratio (res[1][i]) against res[4]
# (the value the later success-measure cell labels 'Average Ep Length') for every run.
fig, ax = pplt.subplots(nrows=3, ncols=3)

ax.format(title=[f'Cluster {c}' for c in range(num_clusters)])

all_runs = [res for width in widths for res in results[width]]
for i in range(num_clusters):
    x = [run[1][i] for run in all_runs]
    y = [run[4] for run in all_runs]
    ax[i].scatter(x, y)

In [None]:
# Richness measure:
# * zero out cluster 1's proportion;
# * richness = (smallest remaining nonzero proportion) * (number of remaining
#   clusters)**2, or 0 when no other cluster is present.
# (Also ignoring the largest cluster was tried; that step is disabled.)

fig, ax = pplt.subplots()
x = []
y = []

for width in widths:
    for res in results[width]:
        ratios = res[1].copy()
        ratios[1] = 0
        nonzero = ratios[ratios != 0]
        richness = nonzero.min() * len(nonzero)**2 if len(nonzero) else 0
        x.append(richness)
        y.append(res[2])

ax.scatter(x, y)

In [None]:
# Simple richness: number of clusters present after zeroing cluster 1, and that
# count divided by log(width).
# NOTE(review): the right panel is titled 'Menhinicks Index', but Menhinick's
# index is S / sqrt(N); this cell divides by log(width) instead.

fig, ax = pplt.subplots(ncols=2)
ax.format(title=['Richness', 'Menhinicks Index'])
x, x2, y = [], [], []

for width in widths:
    for res in results[width]:
        ratios = res[1].copy()
        ratios[1] = 0
        richness = np.count_nonzero(ratios)
        x.append(richness)
        x2.append(richness / np.log(width))
        y.append(res[2])

ax[0].scatter(x, y)
ax[1].scatter(x2, y)

In [None]:
#Diversity measures
# With p_i the proportion and n_i the count of cluster i (N = sum_i n_i):
#   Shannon-Wiener index  H'     = -sum_i p_i ln p_i
#   Simpson's dominance   lambda = sum_i n_i (n_i - 1) / (N (N - 1))
#   Simpson's diversity   1 - lambda
#   Shannon number        exp(H')     (effective number of clusters)
#   Simpson's number      1 / lambda  (inverse Simpson index)

fig, ax = pplt.subplots(ncols=2, nrows=2, share=False)
ax.format(title=['Shannon-Wiener Index', "Simpson's Diversity", 'Shannon Number', "Simpson's Number"])
x = []
x2 = []
x3 = []
x4 = []
y = []

for width in widths:
    for res in results[width]:
        counts, ratios = count_labels(res[0], remove_zeros=True)
        # Use the number of labels actually counted rather than assuming it is `width`.
        total = counts.sum()

        hprime = np.sum(-ratios * np.log(ratios))
        lambd = np.sum(counts * (counts - 1)) / (total * (total - 1))

        x.append(hprime)
        x2.append(1 - lambd)
        x3.append(np.exp(hprime))
        # Inverse Simpson is 1/lambda; the former 1/(1 - lambda) grew as diversity fell.
        x4.append(1 / lambd)
        y.append(res[2])

ax[0].scatter(x, y)
ax[1].scatter(x2, y)
ax[2].scatter(x3, y)
ax[3].scatter(x4, y)

In [None]:
# Inspect one saved evaluation; close the file handle instead of leaking it.
with open('data/pdistal_rim_heatmap/width16_t1', 'rb') as f:
    width16_t1 = pickle.load(f)
width16_t1

## Shannon-Wiener

### Success Measures

In [None]:
#Diversity measures

def count_labels(clabels, ignore_cluster=None, remove_zeros=False):
    '''Count cluster labels and convert the counts into proportions.

    Args:
        clabels: integer cluster labels; only ids 0..num_clusters-1 (global
            `num_clusters`) are counted.
        ignore_cluster: a cluster id (Python or NumPy int) or a list/tuple of ids
            whose counts are zeroed before proportions are computed.
        remove_zeros: if True, drop clusters with zero count from both outputs.

    Returns:
        (cluster_counts, cluster_ratios) as float arrays.
    '''
    # asarray so a plain list compares element-wise instead of to a scalar.
    clabels = np.asarray(clabels)
    cluster_counts = np.array([np.sum(clabels == i) for i in range(num_clusters)], dtype=float)

    if ignore_cluster is not None:
        # The old `type(...) == int/list` checks silently ignored NumPy ints and tuples.
        cluster_counts[np.atleast_1d(np.asarray(ignore_cluster, dtype=int))] = 0

    cluster_ratios = cluster_counts / np.sum(cluster_counts)

    if remove_zeros:
        # One shared mask keeps counts and ratios aligned.
        keep = cluster_counts != 0
        cluster_counts = cluster_counts[keep]
        cluster_ratios = cluster_ratios[keep]
    return cluster_counts, cluster_ratios

title = ['Directness', 'Success Rate', 'Average Ep Length',
         'Average Successful Ep Length']
fig, ax = pplt.subplots(ncols=2, nrows=2, share=False)
ax.format(title=title)
x = []
ys = {t: [] for t in title}

for width in widths:
    for res in results[width]:
        counts, ratios = count_labels(res[0], remove_zeros=True)
        # Shannon-Wiener index H' = -sum p_i ln p_i over the clusters present.
        hprime = np.sum(-ratios * np.log(ratios))

        x.append(hprime)
        ys['Directness'].append(res[2])
        ys['Success Rate'].append(res[3])
        ys['Average Ep Length'].append(res[4])
        ys['Average Successful Ep Length'].append(res[5])

x = np.array(x)
for i, t in enumerate(title):
    ys[t] = np.array(ys[t])
    ax[i].scatter(x, ys[t])


**Penalizing Cluster 1 richness**

To penalize, we add $p_1 \log p_1$ to the richness score, which cancels cluster 1's positive contribution $-p_1 \log p_1$ to the Shannon index.

This seems to smooth out one outlier in the Directness graph, but it adds an outlier to the Average Successful Ep Length graph, so the penalty seems unnecessary.

In [None]:
#Diversity measures

def count_labels(clabels, ignore_cluster=None, remove_zeros=False):
    '''Count cluster labels and convert the counts into proportions.

    Args:
        clabels: integer cluster labels; only ids 0..num_clusters-1 (global
            `num_clusters`) are counted.
        ignore_cluster: a cluster id (Python or NumPy int) or a list/tuple of ids
            whose counts are zeroed before proportions are computed.
        remove_zeros: if True, drop clusters with zero count from both outputs.

    Returns:
        (cluster_counts, cluster_ratios) as float arrays.
    '''
    # asarray so a plain list compares element-wise instead of to a scalar.
    clabels = np.asarray(clabels)
    cluster_counts = np.array([np.sum(clabels == i) for i in range(num_clusters)], dtype=float)

    if ignore_cluster is not None:
        # The old `type(...) == int/list` checks silently ignored NumPy ints and tuples.
        cluster_counts[np.atleast_1d(np.asarray(ignore_cluster, dtype=int))] = 0

    cluster_ratios = cluster_counts / np.sum(cluster_counts)

    if remove_zeros:
        # One shared mask keeps counts and ratios aligned.
        keep = cluster_counts != 0
        cluster_counts = cluster_counts[keep]
        cluster_ratios = cluster_ratios[keep]
    return cluster_counts, cluster_ratios

title = ['Directness', 'Success Rate', 'Average Ep Length',
         'Average Successful Ep Length']
fig, ax = pplt.subplots(ncols=2, nrows=2, share=False)
ax.format(title=title)
x2 = []
ys2 = {t: [] for t in title}

for width in widths:
    for res in results[width]:
        counts, ratios = count_labels(res[0], remove_zeros=False)
        nonzero = ratios[ratios != 0]
        # Shannon-Wiener index over the clusters that are present.
        hprime = np.sum(-nonzero * np.log(nonzero))
        # Penalize cluster 1: adding p_1 ln p_1 cancels its -p_1 ln p_1 term above.
        if ratios[1] != 0:
            hprime += ratios[1] * np.log(ratios[1])

        x2.append(hprime)
        ys2['Directness'].append(res[2])
        ys2['Success Rate'].append(res[3])
        ys2['Average Ep Length'].append(res[4])
        ys2['Average Successful Ep Length'].append(res[5])

x2 = np.array(x2)
for i, t in enumerate(title):
    ys2[t] = np.array(ys2[t])
    ax[i].scatter(x2, ys2[t])


In [None]:
#Compare with simple num nodes

def count_labels(clabels, ignore_cluster=None, remove_zeros=False):
    '''Count cluster labels and convert the counts into proportions.

    Args:
        clabels: integer cluster labels; only ids 0..num_clusters-1 (global
            `num_clusters`) are counted.
        ignore_cluster: a cluster id (Python or NumPy int) or a list/tuple of ids
            whose counts are zeroed before proportions are computed.
        remove_zeros: if True, drop clusters with zero count from both outputs.

    Returns:
        (cluster_counts, cluster_ratios) as float arrays.
    '''
    # asarray so a plain list compares element-wise instead of to a scalar.
    clabels = np.asarray(clabels)
    cluster_counts = np.array([np.sum(clabels == i) for i in range(num_clusters)], dtype=float)

    if ignore_cluster is not None:
        # The old `type(...) == int/list` checks silently ignored NumPy ints and tuples.
        cluster_counts[np.atleast_1d(np.asarray(ignore_cluster, dtype=int))] = 0

    cluster_ratios = cluster_counts / np.sum(cluster_counts)

    if remove_zeros:
        # One shared mask keeps counts and ratios aligned.
        keep = cluster_counts != 0
        cluster_counts = cluster_counts[keep]
        cluster_ratios = cluster_ratios[keep]
    return cluster_counts, cluster_ratios

title = ['Richness', 'Directness', 'Success Rate', 'Average Ep Length',
         'Average Successful Ep Length']
array = [[0, 1, 1, 0],
         [2, 2, 3, 3],
         [4, 4, 5, 5]]
fig, ax = pplt.subplots(array, share=False)
ax.format(title=title)
x2 = []
ys2 = {t: [] for t in title}

# Every metric is plotted against network width (number of nodes). The
# penalized-Shannon value this cell used to compute was never used, so it is gone.
for width in widths:
    for res in results[width]:
        x2.append(width)
        ys2['Richness'].append(res['shannon'])
        ys2['Directness'].append(res['directness'])
        ys2['Success Rate'].append(res['success_rate'])
        ys2['Average Ep Length'].append(res['avg_ep_len'])
        ys2['Average Successful Ep Length'].append(res['avg_succ_ep_len'])

x2 = np.array(x2)
ax[0].scatter(x2, ys2['Richness'])
plot_titles = ['Directness', 'Success Rate', 'Average Ep Length',
               'Average Successful Ep Length']
for i, t in enumerate(plot_titles):
    ys2[t] = np.array(ys2[t])
    ax[i+1].scatter(x2, ys2[t])


In [None]:
#Compare with simple num nodes

def count_labels(clabels, ignore_cluster=None, remove_zeros=False):
    '''Count cluster labels and convert the counts into proportions.

    Args:
        clabels: integer cluster labels; only ids 0..num_clusters-1 (global
            `num_clusters`) are counted.
        ignore_cluster: a cluster id (Python or NumPy int) or a list/tuple of ids
            whose counts are zeroed before proportions are computed.
        remove_zeros: if True, drop clusters with zero count from both outputs.

    Returns:
        (cluster_counts, cluster_ratios) as float arrays.
    '''
    # asarray so a plain list compares element-wise instead of to a scalar.
    clabels = np.asarray(clabels)
    cluster_counts = np.array([np.sum(clabels == i) for i in range(num_clusters)], dtype=float)

    if ignore_cluster is not None:
        # The old `type(...) == int/list` checks silently ignored NumPy ints and tuples.
        cluster_counts[np.atleast_1d(np.asarray(ignore_cluster, dtype=int))] = 0

    cluster_ratios = cluster_counts / np.sum(cluster_counts)

    if remove_zeros:
        # One shared mask keeps counts and ratios aligned.
        keep = cluster_counts != 0
        cluster_counts = cluster_counts[keep]
        cluster_ratios = cluster_ratios[keep]
    return cluster_counts, cluster_ratios

title = ['Directness', 'Success Rate', 'Average Ep Length',
         'Average Successful Ep Length']
fig, ax = pplt.subplots(nrows=2, ncols=2, share=False)
ax.format(title=title)
x2 = []
ys2 = {t: [] for t in title}

# Index into `widths` for each run, used to colour points by width.
color_plots = []

for n, width in enumerate(widths):
    for res in results[width]:
        x2.append(res['shannon'])
        ys2['Directness'].append(res['directness'])
        ys2['Success Rate'].append(res['success_rate'])
        ys2['Average Ep Length'].append(res['avg_ep_len'])
        ys2['Average Successful Ep Length'].append(res['avg_succ_ep_len'])
        color_plots.append(n)

x2 = np.array(x2)
color_plots = np.array(color_plots)
# Convert once, not on every pass of the plotting loop.
ys2 = {t: np.array(v) for t, v in ys2.items()}

# Iterate over this cell's own `title` (previously `plot_titles` from another
# cell was used, which fails if that cell has not been run).
for j, width in enumerate(widths):
    ps = color_plots == j
    for i, t in enumerate(title):
        ax[i].scatter(x2[ps], ys2[t][ps], label=f'Width {width}')
ax[0].legend(loc='best')


### Linear Regressions

Here we fit a couple of linear regressions to see whether it is important to penalize a model for having cluster 1. The results below again suggest that this penalization (described above) does not improve the $R^2$ scores. Ultimately it seems best to keep the diversity measure as the plain Shannon-Wiener index rather than invent an arbitrary penalization.

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# Fit Directness ~ diversity for both diversity variants, presumably the plain
# Shannon index (x/ys) and the cluster-1-penalized one (x2/ys2), and print
# slope, intercept and R^2 for each.
for feature, metrics in ((x, ys), (x2, ys2)):
    target = metrics['Directness'].reshape(-1, 1)
    lr = LinearRegression()
    lr.fit(feature.reshape(-1, 1), target)
    ypred = lr.predict(feature.reshape(-1, 1))
    print(lr.coef_, lr.intercept_, r2_score(target, ypred))


In [None]:
# Same comparison as above for Average Successful Ep Length: print slope,
# intercept and R^2 for x/ys, then for x2/ys2.
for feature, metrics in ((x, ys), (x2, ys2)):
    target = metrics['Average Successful Ep Length'].reshape(-1, 1)
    lr = LinearRegression()
    lr.fit(feature.reshape(-1, 1), target)
    ypred = lr.predict(feature.reshape(-1, 1))
    print(lr.coef_, lr.intercept_, r2_score(target, ypred))


## Action Ratios

It seems like agents may have a directional preference (in particular, a cw/ccw bias). We would like to see what factors influence the bias, and if the bias affects ultimate performance

In [None]:
#0-2 action ratios: scatter each run's action-0 (left) ratio against its
# action-2 (right) ratio.
fig, ax = pplt.subplots()

act_rows = [res[6] for width in widths for res in results[width]]
x = [row[0] for row in act_rows]
y = [row[2] for row in act_rows]

ax.scatter(x, y)
ax.format(xlabel='Ratio of Left Turns', ylabel='Ratio of Right Turns')

In [None]:
#0-2 action ratios: one panel per action (0=Left, 1=Forward, 2=Right) plotting
# its ratio against res[2] for every run.
fig, ax = pplt.subplots(ncols=3)

title = ['Left', 'Forward', 'Right']
ax.format(title=title)

xs = {t: [] for t in title}
y = []

for width in widths:
    for res in results[width]:
        for action, name in enumerate(title):
            xs[name].append(res[6][action])
        y.append(res[2])

for i, t in enumerate(title):
    ax[i].scatter(xs[t], y)


In [None]:
# One panel per cluster: that cluster's ratio against each of the three action
# ratios (three overlaid scatters per panel).
fig, ax = pplt.subplots(nrows=3, ncols=3)

xs = [[] for _ in range(9)]
ys = [[] for _ in range(3)]

for width in widths:
    for res in results[width]:
        for act in range(3):
            ys[act].append(res[6][act])
        for clus in range(9):
            xs[clus].append(res[1][clus])

for i in range(9):
    for j in range(3):
        ax[i].scatter(xs[i], ys[j])

In [None]:
# One panel per cluster: that cluster's ratio against the run's left-turn preference.
fig, ax = pplt.subplots(nrows=3, ncols=3)

xs = [[] for _ in range(9)]
ys = []

for width in widths:
    for res in results[width]:
        act_ratios = res[6]
        c_ratios = res[1]
        # Left share of turns, left / (left + right). Actions are 0=left,
        # 1=forward, 2=right (see the action-ratio cells), so the denominator
        # uses act_ratios[2]; it previously used the forward ratio.
        lr_pref = act_ratios[0] / (act_ratios[0] + act_ratios[2])
        ys.append(lr_pref)

        for i in range(9):
            xs[i].append(c_ratios[i])

for i in range(9):
    ax[i].scatter(xs[i], ys)

In [None]:
from sklearn.linear_model import LinearRegression

# Linear model predicting the left/right preference from the 9 cluster ratios;
# transposing makes rows = runs and columns = clusters.
X = np.array(xs).T
y = np.array(ys).reshape(-1, 1)
lr = LinearRegression().fit(X, y)
pred = lr.predict(X)

print(lr.coef_)

In [None]:
from torch.nn import Linear, ReLU, Sequential
from torch.nn.functional import mse_loss
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# Regress each run's left/(left+right) turn ratio on its cluster ratios with a small MLP.
xs = []
ys = []

for width in widths:
    for res in results[width]:
        act_ratios = res['act_ratios']
        lr_ratio = act_ratios[0] / (act_ratios[0] + act_ratios[2])
        xs.append(res['cluster_ratios'])
        ys.append(lr_ratio)

X = torch.tensor(np.array(xs), dtype=torch.float32)
y = torch.tensor(np.array(ys).reshape(-1, 1), dtype=torch.float32)
# Size the input layer from the data instead of hard-coding 9 clusters.
lm = Sequential(Linear(X.shape[1], 4), ReLU(), Linear(4, 1))
opt = Adam(lm.parameters(), lr=0.01)

In [None]:
# Fixed random_state: `lm` is not re-initialised here, so an unseeded re-split on
# re-run would move already-trained samples into the test set.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Full-batch training on the training split.
losses = []
for _ in range(1000):
    pred = lm(X_train)
    loss = mse_loss(pred, y_train)
    opt.zero_grad()
    loss.backward()
    losses.append(loss.item())
    opt.step()

# Evaluation only: no autograd graph needed.
with torch.no_grad():
    pred_train = lm(X_train)
    pred_test = lm(X_test)

fig, ax = pplt.subplots(ncols=2)
ax[0].scatter(pred_train.T[0], y_train.T[0])
ax[1].scatter(pred_test.T[0], y_test.T[0])
ax.format(xlabel='Predicted', ylabel='Actual', title=['Train', 'Test'])

### Directness/Richness vs LR Ratio

In [None]:
xs, xs2, ys = [], [], []

fig, ax = pplt.subplots(ncols=2, sharex=False)

# Per run: directness, Shannon richness, and the turn bias
# 2 * left / (left + right) - 1 (in [-1, 1] for non-negative ratios).
for width in widths:
    for res in results[width]:
        turns = res['act_ratios']
        left_share = turns[0] / (turns[0] + turns[2])
        xs.append(res['directness'])
        xs2.append(res['shannon'])
        ys.append(2*left_share - 1)

ax[0].scatter(xs, ys, alpha=0.5)
ax[1].scatter(xs2, ys, alpha=0.5)
ax[0].format(xlabel='directness', ylabel='LR Ratio')
ax[1].format(xlabel='richness')

In [None]:
from sklearn.linear_model import LinearRegression
# 3x3 grid: each cluster's ratio (one panel per cluster) against LR bias
# (2 * lr_ratio - 1) across all trained models.
# No estimator is created here: the previous unused `model = LinearRegression()`
# overwrote any `model` already defined in the notebook namespace.
fig, ax = pplt.subplots(nrows=3, ncols=3)
xs = []
ys = []

for width in widths:
    for res in results[width]:
        act_ratios = res['act_ratios']
        lr_ratio = act_ratios[0] / (act_ratios[0] + act_ratios[2])
        xs.append(res['cluster_ratios'])
        ys.append(lr_ratio)

# Convert once rather than on every panel; column i holds cluster i's ratio.
cluster_ratios = np.array(xs)
lr_bias = (np.array(ys) - 0.5) * 2
for i in range(9):
    ax[i].scatter(cluster_ratios[:, i], lr_bias)

# Development of Clusters over training

## Data Generation

Data generation was moved into the file `run_checkpoint_analysis.py` so that it can be run on remote machines.

### Heatmap from conserved trajectory activations

In [None]:
# For every checkpoint of one trial, smooth each unit's activations over agent
# position into a heatmap.
# Results are merged into the on-disk heatmap cache:
#     {trial: {checkpoint: [heatmap per unit]}}
width = 64
trial = 0
data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
with open(data_path, 'rb') as f:
    checkpoint_data = pickle.load(f)

heatmap_path = data_folder + f'width{width}_checkpoint_hms'
if Path(heatmap_path).exists():
    with open(heatmap_path, 'rb') as f:
        heatmap_data = pickle.load(f)
else:
    heatmap_data = {}

heatmap_data.setdefault(trial, {})

for chkp_val in tqdm(checkpoint_data['copied'][trial]):
    eps = checkpoint_data['copied'][trial][chkp_val]
    p = eps['pos']
    a = eps['activ']
    # One heatmap per unit, i.e. per column of the activation matrix.
    heatmaps = [gaussian_smooth(p, a[:, i]) for i in range(a.shape[1])]
    heatmap_data[trial][chkp_val] = heatmaps

with open(heatmap_path, 'wb') as f:
    pickle.dump(heatmap_data, f)



### Summary of policy pathways

In [None]:
'''
Older code that puts results into list for plots prior to 4.4. Probably will want to update
those plotting code to use next block dictionary entries
'''
# For every policy checkpoint of one trial, summarize:
#   - navigation metrics (directness, success rate, episode lengths);
#   - the action mix;
#   - k-means cluster membership and richness (Shannon index) of the cached unit heatmaps.
# Results are merged into the on-disk summary cache.

with open('data/pdistal_rim_heatmap/kmeans_heatmap_clusterer', 'rb') as f:
    kmeans = pickle.load(f)

width = 64
trial = 0

data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
with open(data_path, 'rb') as f:
    checkpoint_data = pickle.load(f)

heatmap_path = data_folder + f'width{width}_checkpoint_hms'
summary_path = data_folder + f'width{width}_checkpoint_summ'

with open(heatmap_path, 'rb') as f:
    heatmap_data = pickle.load(f)
if Path(summary_path).exists():
    with open(summary_path, 'rb') as f:
        summary_data = pickle.load(f)
else:
    summary_data = {}

summary_data.setdefault(trial, {})

# Episodes lasting exactly this many steps are counted as failures
# (presumably the environment's time limit — TODO confirm).
FAIL_EP_LEN = 202

for chkp_val in tqdm(checkpoint_data['policy'][trial]):
    eps = checkpoint_data['policy'][trial][chkp_val]
    directness = compute_directness(pos=eps['pos'])

    ep_dones = split_by_ep(eps['dones'], eps['dones'])
    ep_lens = np.array([ep.shape[0] for ep in ep_dones])
    succ_lens = ep_lens[ep_lens < FAIL_EP_LEN]
    success_rate = 1 - np.sum(ep_lens == FAIL_EP_LEN) / len(ep_lens)
    average_ep_len = np.mean(ep_lens)
    # NaN, without numpy's empty-slice warning, when no episode succeeded.
    average_succ_ep_len = np.mean(succ_lens) if succ_lens.size else np.nan

    acts = eps['actions']
    # Fraction of steps spent on each of the 4 discrete actions.
    act_ratios = np.array([np.sum(acts == i) for i in range(4)]) / len(acts)

    # One flattened heatmap per unit -> one k-means cluster label per unit.
    hms = np.vstack([hm.reshape(1, -1) for hm in heatmap_data[trial][chkp_val]])
    labels = kmeans.predict(hms)
    _, ratios = count_labels(labels, remove_zeros=False)
    _, nonzero = count_labels(labels, remove_zeros=True)
    # Shannon index over cluster membership ("richness"); empty clusters excluded.
    hprime = np.sum(-nonzero * np.log(nonzero))

    summary_data[trial][chkp_val] = {
        'cluster_labels': labels,
        'cluster_ratios': ratios,
        'directness': directness,
        'success_rate': success_rate,
        'avg_ep_len': average_ep_len,
        'avg_succ_ep_len': average_succ_ep_len,
        'act_ratios': act_ratios,
        'shannon': hprime
    }

with open(summary_path, 'wb') as f:
    pickle.dump(summary_data, f)
        


## Analysis of development

### Compression Test

It looks like we can compress heatmaps to 16-bit (half) precision, saving 75% of the space for transfer (relative to the default 64-bit floats). Visually the heatmaps look identical, and clustering produces identical results.

In [None]:


# Paths for the current `width`; load its full-precision heatmap cache.
data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
heatmap_path = data_folder + f'width{width}_checkpoint_hms'
summary_path = data_folder + f'width{width}_checkpoint_summ'

with open(heatmap_path, 'rb') as f:
    heatmaps = pickle.load(f)

In [None]:
# Float16 copy of every heatmap, mirroring {trial: {checkpoint: [heatmap, ...]}}.
# Plain loops on purpose: t / chk / hm remain bound afterwards, as before.
heatmap_half = {}
for t, per_chk in heatmaps.items():
    half_per_chk = heatmap_half[t] = {}
    for chk, hm_list in per_chk.items():
        half_per_chk[chk] = half_list = []
        for hm in hm_list:
            half_list.append(hm.astype(np.float16))

In [None]:
# Persist the float16 heatmaps; the context manager flushes and closes the file.
with open(data_folder + f'width{width}_checkpoint_hms_half', 'wb') as f:
    pickle.dump(heatmap_half, f)

In [None]:
# Sanity check: the float16 copy must mirror the original layout one-to-one.
# Never append here on its own: `heatmap_half[t][chk].append(...)` outside the
# conversion loop duplicates the last heatmap of the last checkpoint. That breaks
# the label comparison between full and half precision.
assert heatmap_half.keys() == heatmaps.keys()
for t in heatmaps:
    assert heatmap_half[t].keys() == heatmaps[t].keys()
    for chk in heatmaps[t]:
        assert len(heatmap_half[t][chk]) == len(heatmaps[t][chk])

In [None]:
# Visual check, trial 0 / unit 0: full precision in the top row, float16 in the bottom row.
chks = [0, 100, 200, 300, 400, 500]
fig, ax = pplt.subplots(nrows=2, ncols=len(chks))
for col, chk in enumerate(chks):
    for row, source in enumerate((heatmaps, heatmap_half)):
        ax[row, col].imshow(source[0][chk][0])
    

In [None]:
# Print every trial-0 checkpoint whose k-means labels change after float16 compression.
for chk in heatmaps[0]:
    labels_full = pred_kmeans(heatmaps[0][chk], kmeans)
    labels_half = pred_kmeans(heatmap_half[0][chk], kmeans)
    labels_match = (labels_full == labels_half).all()
    if not labels_match:
        print(chk)

In [None]:
# Richness (Shannon index) of trial 0 per checkpoint, in chronological order.
# Checkpoint keys were not stored in order, so they are sorted first.
xs = sorted(summary_data[0])
ys = [summary_data[0][chkp]['shannon'] for chkp in xs]

In [None]:
# Cached unit-0 heatmaps for the first ten trial-0 checkpoints.
# Stored (not sorted) order is kept so the panels line up with the recomputed ones below.
fig, ax = pplt.subplots(nrows=2, ncols=5)
chkp_keys = list(heatmap_data[0].keys())
for i in range(10):
    chkp = chkp_keys[i]
    ax[i].imshow(heatmap_data[0][chkp][0])

In [None]:
# Debug: recompute the unit-0 heatmap for the first 10 checkpoints (stored order)
# of one trial, for comparison against the cached heatmaps.
# The results are collected in `hms` and written into the in-memory
# `heatmap_data` only; the disk dump is intentionally left disabled.
from itertools import islice

width = 64
trial = 0
data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
with open(data_path, 'rb') as f:
    checkpoint_data = pickle.load(f)

heatmap_path = data_folder + f'width{width}_checkpoint_hms'
if Path(heatmap_path).exists():
    with open(heatmap_path, 'rb') as f:
        heatmap_data = pickle.load(f)
else:
    heatmap_data = {}

heatmap_data.setdefault(trial, {})

num_debug_chkps = 10
copied = checkpoint_data['copied'][trial]
hms = []
for chkp_val in tqdm(islice(copied, num_debug_chkps),
                     total=min(num_debug_chkps, len(copied))):
    eps = copied[chkp_val]
    p = eps['pos']
    a = eps['activ']

    # Unit 0 only. A distinct index name avoids shadowing any checkpoint counter.
    heatmaps = [gaussian_smooth(p, a[:, unit]) for unit in range(1)]
    hms.extend(heatmaps)

    # NOTE: replaces this checkpoint's in-memory entry with the single unit-0 heatmap.
    heatmap_data[trial][chkp_val] = heatmaps

# pickle.dump(heatmap_data, open(heatmap_path, 'wb'))



In [None]:
# Recomputed unit-0 heatmaps (from `hms`), one panel per checkpoint.
fig, ax = pplt.subplots(nrows=2, ncols=5)
for panel_idx in range(10):
    hm = hms[panel_idx]
    ax[panel_idx].imshow(hm)

In [None]:
'''
Older code that puts results into list for plots prior to 4.4. Probably will want to update
those plotting code to use next block dictionary entries
'''
# For every policy checkpoint of one trial, summarize:
#   - navigation metrics (directness, success rate, episode lengths);
#   - the action mix;
#   - k-means cluster membership and richness (Shannon index) of that checkpoint's
#     cached unit heatmaps.
# Results are merged into the on-disk summary cache.
# Heatmaps are read per checkpoint from `heatmap_data[trial][chkp_val]`. The
# previous `all_heatmaps[width][trial]` ignored the checkpoint, so every
# checkpoint got identical labels and richness.

with open('data/pdistal_rim_heatmap/kmeans_heatmap_clusterer', 'rb') as f:
    kmeans = pickle.load(f)

width = 64
trial = 0

data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
with open(data_path, 'rb') as f:
    checkpoint_data = pickle.load(f)

heatmap_path = data_folder + f'width{width}_checkpoint_hms'
summary_path = data_folder + f'width{width}_checkpoint_summ'

with open(heatmap_path, 'rb') as f:
    heatmap_data = pickle.load(f)
if Path(summary_path).exists():
    with open(summary_path, 'rb') as f:
        summary_data = pickle.load(f)
else:
    summary_data = {}

summary_data.setdefault(trial, {})

# Episodes lasting exactly this many steps are counted as failures
# (presumably the environment's time limit — TODO confirm).
FAIL_EP_LEN = 202

for chkp_val in tqdm(checkpoint_data['policy'][trial]):
    eps = checkpoint_data['policy'][trial][chkp_val]
    directness = compute_directness(pos=eps['pos'])

    ep_dones = split_by_ep(eps['dones'], eps['dones'])
    ep_lens = np.array([ep.shape[0] for ep in ep_dones])
    succ_lens = ep_lens[ep_lens < FAIL_EP_LEN]
    success_rate = 1 - np.sum(ep_lens == FAIL_EP_LEN) / len(ep_lens)
    average_ep_len = np.mean(ep_lens)
    # NaN, without numpy's empty-slice warning, when no episode succeeded.
    average_succ_ep_len = np.mean(succ_lens) if succ_lens.size else np.nan

    acts = eps['actions']
    # Fraction of steps spent on each of the 4 discrete actions.
    act_ratios = np.array([np.sum(acts == i) for i in range(4)]) / len(acts)

    # One flattened heatmap per unit of THIS checkpoint -> one cluster label per unit.
    hms = np.vstack([hm.reshape(1, -1) for hm in heatmap_data[trial][chkp_val]])
    labels = kmeans.predict(hms)
    _, ratios = count_labels(labels, remove_zeros=False)
    _, nonzero = count_labels(labels, remove_zeros=True)
    # Shannon index over cluster membership ("richness"); empty clusters excluded.
    hprime = np.sum(-nonzero * np.log(nonzero))

    summary_data[trial][chkp_val] = {
        'cluster_labels': labels,
        'cluster_ratios': ratios,
        'directness': directness,
        'success_rate': success_rate,
        'avg_ep_len': average_ep_len,
        'avg_succ_ep_len': average_succ_ep_len,
        'act_ratios': act_ratios,
        'shannon': hprime
    }

with open(summary_path, 'wb') as f:
    pickle.dump(summary_data, f)
        


In [None]:
# Directness of each trial-0 checkpoint in chronological order.
# Keys are not stored in order, so they are sorted first.
for chk in sorted(summary_data[0]):
    print(summary_data[0][chk]['directness'])

## Evolution of Cluster Ratios over Training

**IMPORTANT NOTE: Because of the way files were saved and named, checkpoints were not loaded in order and simply iterating over keys in dictionaries will lead to anachronistic results. To iterate, sort checkpoint names and iterate over sorted.**

### Preliminary Examination

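On preliminary viewing, it looks like directness and average episode lengths improve regardless of changes in representational richness (for width-16 networks). However, on closer inspection, periods of faster or slower improvement appear to coincide with increases and decreases in richness (see the Differentials section below).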

In [None]:
# Load the width-16 checkpoint summaries and float16 heatmaps.
# Also build a sorted checkpoint list, because dict key order is not chronological.
width = 16
trial = 0

data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
heatmap_path = data_folder + f'width{width}_checkpoint_hms_half'
summary_path = data_folder + f'width{width}_checkpoint_summ'

with open(summary_path, 'rb') as f:
    summ = pickle.load(f)
with open(heatmap_path, 'rb') as f:
    hms = pickle.load(f)

chks = np.sort(list(summ[0].keys()))

In [None]:
# Cluster-ratio trajectories over sorted checkpoints: one panel per trial, one line per cluster.
fig, ax = pplt.subplots(nrows=3, ncols=3)
for trial_idx in range(9):
    ratio_rows = [summ[trial_idx][chk]['cluster_ratios'] for chk in chks]
    all_ratios = np.vstack(ratio_rows)
    panel = ax[trial_idx]
    for cluster in range(num_clusters):
        panel.plot(all_ratios[:, cluster])

In [None]:
import pandas as pd

In [None]:

# Overlay the EWM-smoothed (alpha=0.1) richness trajectory of every trial.
# `shannons` is built as a Python list of per-trial arrays, so `.shape` is not
# available. Iterating works for a list or a 2-D array alike.
for trial_shannon in shannons:
    plt.plot(pd.Series(trial_shannon).ewm(alpha=0.1).mean())

In [None]:
# Per trial (one panel each): EWM-smoothed richness, directness and average
# episode length over sorted checkpoints.
fig, ax = pplt.subplots(nrows=3, ncols=3)

shannons, directnesses, avg_ep_lens = [], [], []
for j in summ:
    per_chk = [summ[j][chk] for chk in chks]
    shannons.append(np.vstack([s['shannon'] for s in per_chk]).squeeze())
    directnesses.append(np.vstack([s['directness'] for s in per_chk]).squeeze())
    avg_ep_lens.append(np.vstack([s['avg_ep_len'] for s in per_chk]).squeeze())

for i in range(9):
    curves = {
        'richness': pd.Series(shannons[i]).ewm(alpha=0.3).mean(),
        'directness': pd.Series(directnesses[i]).ewm(alpha=0.3).mean(),
        # Divided by 200 (presumably ~the episode-length cap) so it can share the axis.
        'avg_ep_len': pd.Series(avg_ep_lens[i]).ewm(alpha=0.3).mean() / 200,
    }
    for name, curve in curves.items():
        ax[i].plot(curve, label=name)

ax[0].legend(loc='ur')

In [None]:
# Width-8 models, one panel per trial: EWM-smoothed richness, directness and
# average episode length over chronologically sorted checkpoints.
fig, ax = pplt.subplots(nrows=3, ncols=3)

width = 8

data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
heatmap_path = data_folder + f'width{width}_checkpoint_hms_half'
summary_path = data_folder + f'width{width}_checkpoint_summ'

with open(summary_path, 'rb') as f:
    summ = pickle.load(f)
with open(heatmap_path, 'rb') as f:
    hms = pickle.load(f)

# Dict key order is not chronological; sort checkpoints before building series.
chks = np.sort(list(summ[0].keys()))

shannons = []
directnesses = []
avg_ep_lens = []
for j in summ:
    shannons.append(np.vstack([summ[j][chk]['shannon'] for chk in chks]).squeeze())
    directnesses.append(np.vstack([summ[j][chk]['directness'] for chk in chks]).squeeze())
    avg_ep_lens.append(np.vstack([summ[j][chk]['avg_ep_len'] for chk in chks]).squeeze())

for i in range(9):
    richness = pd.Series(shannons[i]).ewm(alpha=0.3).mean()
    directness = pd.Series(directnesses[i]).ewm(alpha=0.3).mean()
    # Divided by 200 (presumably ~the episode-length cap) so it can share the axis.
    avg_ep_len = pd.Series(avg_ep_lens[i]).ewm(alpha=0.3).mean() / 200
    ax[i].plot(richness, label='richness')
    ax[i].plot(directness, label='directness')
    ax[i].plot(avg_ep_len, label='avg_ep_len')

ax[0].legend(loc='ur')

### Differentials

Two interesting notes:
* It seems like richness is largely decided already at random initialization?! Specifically, a model that starts with a lot of cluster 1 tends to keep it through the end of training.
    * We definitely need to spend some time exploring how individual nodes evolve over training besides ensemble metrics. Is it that individual nodes stay the same or that it just happens that as an ensemble the clustering ratios are preserved?
* Despite this dependence on initial clustering and that all models qualitatively show the same behavior of learning, there do seem to be quantitative differences in how learning progresses
    * There are noticeable areas of rapid improvement vs. steadier improvement, which seem correlated with increases and decreases in richness
    * **In this section, we are trying to see whether the checkpoint to checkpoint differentials in richness can be statistically correlated with differentials in directness or performance as measured by ep length**

In [None]:
# Load width-16 summaries/heatmaps. Build per-trial richness, directness and
# average episode-length series over chronologically sorted checkpoints.
width = 16
trial = 0

data_folder = 'data/pdistal_rim_heatmap/'
data_path = data_folder + f'width{width}_checkpoint'
heatmap_path = data_folder + f'width{width}_checkpoint_hms_half'
summary_path = data_folder + f'width{width}_checkpoint_summ'

with open(summary_path, 'rb') as f:
    summ = pickle.load(f)
with open(heatmap_path, 'rb') as f:
    hms = pickle.load(f)

# Dict key order is not chronological; sort checkpoints before building series.
chks = np.sort(list(summ[0].keys()))

shannons = []
directnesses = []
avg_ep_lens = []
for j in summ:
    shannons.append(np.vstack([summ[j][chk]['shannon'] for chk in chks]).squeeze())
    directnesses.append(np.vstack([summ[j][chk]['directness'] for chk in chks]).squeeze())
    avg_ep_lens.append(np.vstack([summ[j][chk]['avg_ep_len'] for chk in chks]).squeeze())

In [None]:
# Checkpoint-to-checkpoint differences per trial, pooled across trials into flat arrays.
# Each difference is paired with the earlier checkpoint of its pair.
n_trials = len(shannons)
shannon_diffs = np.concatenate([np.diff(shannons[t]) for t in range(n_trials)])
directness_diffs = np.concatenate([np.diff(directnesses[t]) for t in range(n_trials)])
avg_ep_len_diffs = np.concatenate([np.diff(avg_ep_lens[t]) for t in range(n_trials)])
checkpoints = np.concatenate([chks[:-1] for t in range(n_trials)])

In [None]:
# Richness change vs. directness change between consecutive checkpoints.
plt.scatter(x=shannon_diffs, y=directness_diffs, alpha=0.1)

In [None]:
# Directness change against the checkpoint at which it starts.
plt.scatter(x=checkpoints, y=directness_diffs, alpha=0.1)

In [None]:
# Linear model: directness change ~ (checkpoint, richness change).
X = np.column_stack([checkpoints, shannon_diffs])
y = directness_diffs.reshape(-1, 1)
lm = LinearRegression()
lm.fit(X, y)


In [None]:
# In-sample predictions of the fitted linear model.
y_pred = lm.predict(X=X)

In [None]:
# In-sample coefficient of determination of the linear fit.
# r2_score is not among this notebook's explicit imports, so it is imported here.
from sklearn.metrics import r2_score

r2_score(y, y_pred)

In [None]:
# Squared Pearson correlation between targets and predictions (equals R^2 for OLS in-sample).
np.corrcoef(y.ravel(), y_pred.ravel())[0, 1] ** 2

In [None]:
# 3x3 correlation matrix of [checkpoint, richness diff, directness diff].
np.corrcoef(np.hstack([X, y]), rowvar=False)

In [None]:
# Same checkpoint-to-checkpoint differences, kept 2-D: one row per trial,
# one column per consecutive checkpoint pair.
n_trials = len(shannons)
shannon_diffs = np.vstack([np.diff(shannons[t]) for t in range(n_trials)])
directness_diffs = np.vstack([np.diff(directnesses[t]) for t in range(n_trials)])
avg_ep_len_diffs = np.vstack([np.diff(avg_ep_lens[t]) for t in range(n_trials)])
checkpoints = np.vstack([chks[:-1] for t in range(n_trials)])

In [None]:
# Shape of trial 0's directness diffs with two steps trimmed from each end.
directness_diffs[0, 2:-2].shape

In [None]:
# (trials, checkpoint pairs)
np.shape(directness_diffs)

In [None]:
# Lagged regressors: row `lag` holds richness diffs shifted by (lag - 2) steps
# relative to the directness diffs in y, i.e. offsets -2, -1, 0, +1.
# y drops two steps at each end so that all rows align column-for-column.
lagged_rows = [shannon_diffs[:, lag:lag - 4].reshape(-1) for lag in range(4)]
X = np.array(lagged_rows)
y = directness_diffs[:, 2:-2].reshape(-1)

In [None]:
# (lags, samples)
np.shape(X)

In [None]:
# Correlation matrix of the 4 lagged richness diffs plus the directness diff (last row/col).
np.corrcoef(np.vstack([X, y]))

In [None]:
# Correlation between the lagged richness-diff regressors themselves (rows are variables).
np.corrcoef(X, rowvar=True)

In [None]:
# One panel per lag: lagged richness diff vs. directness diff.
fig, ax = pplt.subplots(nrows=2, ncols=2)
for lag in range(4):
    panel = ax[lag]
    panel.scatter(X[lag], y)