In [None]:
# Auto-reload imported project modules (e.g. `pim`) before each cell runs, so
# edits to them are picked up without restarting the kernel; render figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import json
import operator
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as clr
import matplotlib.patches as patches
import matplotlib.cm as cm

from tqdm import tqdm
from itertools import groupby
from matplotlib.ticker import MaxNLocator
from pathlib import Path
from loguru import logger
# Remove every loguru handler (including the default stderr sink) so log output,
# presumably from the `pim` package, does not flood the notebook.
logger.remove()

from pim.simulator import SimulationExperiment
from pim.setup import load_results, enumerate_results

In [None]:
def get_eval_data(result, radius=20):
    """Summarise one simulation result as a flat dict of evaluation metrics.

    Parameters
    ----------
    result :
        A loaded simulation result. This function reads ``name``,
        ``velocities``, ``parameters['T_outbound']``,
        ``parameters['cx']['params']`` (``noise`` and the optional
        ``parameter_noise``) and calls the metric methods used below.
    radius : float
        Passed to ``result.compute_disk_leaving_angle``.

    Returns
    -------
    dict
        Keys: name, T_outbound, noise, parameter_noise, min_dist, tort_score,
        angle_offset, Tort_T, actual, optimal, mem_error, angular_mem_error,
        heading_error, velocities. ``min_dist`` is the Euclidean norm of
        ``result.closest_position()``; ``parameter_noise`` is None when absent.
    """
    cx_params = result.parameters['cx']['params']
    noise_level = cx_params['noise']
    param_noise_level = cx_params.get('parameter_noise', None)
    outbound_steps = result.parameters['T_outbound']
    run_name = result.name

    closest_dist = np.linalg.norm(result.closest_position())
    tortuosity = result.tortuosity_score()
    leaving_angle = result.compute_disk_leaving_angle(radius)
    homing_T, _, actual, _, optimal = result.homing_tortuosity()

    # Keyword arguments are evaluated left to right, so the remaining metric
    # methods are called in the same order as before.
    return dict(
        name=run_name,
        T_outbound=outbound_steps,
        noise=noise_level,
        parameter_noise=param_noise_level,
        min_dist=closest_dist,
        tort_score=tortuosity,
        angle_offset=leaving_angle,
        Tort_T=homing_T,
        actual=actual,
        optimal=optimal,
        mem_error=result.memory_error(),
        angular_mem_error=result.angular_memory_error(),
        heading_error=result.heading_error(),
        velocities=result.velocities,
    )

In [None]:
# ----------------------------------- Stone functions -----------------------------------
def plot_angular_distances(noise_levels, angular_distances, bins=18, ax=None,
                           label_font_size=11, log_scale=False, title=None,
                           log_rmax=10001):
    """Overlay polar histograms of angular distances, one per noise level.

    Parameters
    ----------
    noise_levels : sequence
        One entry per histogram; used as legend label and to pick a viridis colour.
    angular_distances : sequence of array_like
        ``angular_distances[i]`` holds the angles (radians) for ``noise_levels[i]``.
    bins : int
        Number of angular bins per histogram.
    ax : matplotlib polar Axes, optional
        Axes to draw on; a new 10x10 polar figure is created when None.
    label_font_size : int
        Font size of the title.
    log_scale : bool
        Use a logarithmic radial axis.
    title : str, optional
        Axes title.
    log_rmax : float
        Upper radial limit used when ``log_scale`` is True (previously a
        hard-coded 10001).

    Returns
    -------
    (fig, ax)
        ``fig`` is None when ``ax`` was supplied.
    """
    fig = None
    if ax is None:
        fig, ax = plt.subplots(subplot_kw=dict(projection='polar'),
                               figsize=(10, 10))

    colors = [cm.viridis(x) for x in np.linspace(0, 1, len(noise_levels))]

    # Draw in reverse order so the first entry of noise_levels ends up on top.
    for i in reversed(range(len(noise_levels))):
        plot_angular_distance_histogram(angular_distance=angular_distances[i],
                                        ax=ax, bins=bins, color=colors[i],
                                        noise=noise_levels[i])

    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)
    ax.set_rlabel_position(22)
    ax.set_title(title, y=1.08, fontsize=label_font_size)

    if log_scale:
        ax.set_rscale('log')
        # A log radial axis cannot start at 0: matplotlib ignores a
        # non-positive limit (with a warning). Start at 0.5, the floor that
        # plot_angular_distance_histogram assigns to empty bins.
        ax.set_rlim(0.5, log_rmax)

    plt.tight_layout()
    return fig, ax

def plot_angular_distance_histogram(angular_distance, ax=None, bins=36,
                                    color='b', noise=0.1):
    """Plot a closed histogram curve of angles (radians) on ``ax``.

    Parameters
    ----------
    angular_distance : array_like
        Angles in radians; values outside the histogram range
        [-pi - pi/bins, pi + pi/bins] are not counted.
    ax : matplotlib Axes, optional
        Axes to draw on (normally a polar axes). A new figure is created if None.
    bins : int
        Number of angular bins around the full circle.
    color : matplotlib colour or falsy
        Line and fill colour; when falsy the fill uses the default colour.
    noise : any
        Used as the legend label of the line.

    Returns
    -------
    (fig, ax)
        ``fig`` is None when ``ax`` was supplied.
    """
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(6*1.2, 6))

    # bins + 1 bins of width 2*pi/bins, centred on -pi + k*2*pi/bins; the last
    # one (centred on +pi) covers the same direction as the first (-pi).
    edges = np.linspace(-np.pi - np.pi / bins,
                        np.pi + np.pi / bins,
                        bins + 2,
                        endpoint=True)
    # np.histogram returns integer counts; convert to float so the 0.5 floor
    # below is not truncated back to 0.
    radii = np.histogram(angular_distance, edges)[0].astype(float)
    radii[0] += radii[-1]          # fold the +pi bin into the -pi bin
    radii = radii[:-1]
    # Rotate so that (for even `bins`) the bin centred on 0 comes first,
    # aligned with theta = 0.
    radii = np.roll(radii, int(bins/2))
    radii = np.append(radii, radii[0])   # repeat the first point to close the curve
    # Give empty bins a small non-zero radius; needed for the fill to render
    # reliably (and for log radial axes).
    radii[radii == 0] = 0.5
    theta = np.linspace(0, 2 * np.pi, bins+1, endpoint=True)

    ax.plot(theta, radii, color=color, alpha=0.5, label=noise)
    if color:
        ax.fill_between(theta, 0, radii, alpha=0.2, color=color)
    else:
        ax.fill_between(theta, 0, radii, alpha=0.2)

    return fig, ax

def get_xy_from_velocity(V):
    """Integrate velocity vectors into x/y position traces.

    ``V`` is indexed as ``V[route, step, component]`` with component 0 = x and
    1 = y. Positions are the running sum over axis 1. Returns ``(X, Y)``.
    """
    positions = np.cumsum(V, axis=1)
    xs, ys = positions[:, :, 0], positions[:, :, 1]
    return xs, ys

def compute_path_straightness(V, T_outbound, x_count=500):
    """Normalised "distance to home" curves for the homebound part of routes.

    Parameters
    ----------
    V : np.ndarray
        Velocities indexed ``V[route, step, component]`` (x, y).
    T_outbound : int
        Step index at which the homebound part starts.
    x_count : int
        Number of samples per curve (500 by default, as before).

    Returns
    -------
    np.ndarray, shape (x_count, n_routes)
        Column i is route i's running minimum distance to home, as a fraction
        of its distance at step ``T_outbound``, resampled at ``x_count`` evenly
        spaced distances travelled in [0, 2 * that distance).
    """
    X, Y = get_xy_from_velocity(V)
    n_routes = X.shape[0]

    # Distances to the nest at each homebound point.
    D = np.sqrt(X[:, T_outbound:]**2 + Y[:, T_outbound:]**2)
    turn_dists = D[:, 0]

    # Shortest distance so far at each step, as a proportion of the
    # turning-point distance (makes the y axis comparable across routes).
    cum_min_dist = np.minimum.accumulate(D.T / turn_dists)

    # Cumulative path length travelled on the way home.
    cum_speed = np.cumsum(np.sqrt(V[:, T_outbound:, 0]**2 + V[:, T_outbound:, 1]**2), axis=1)

    # Resample every curve on a common grid of distance travelled (relative to
    # its own turning-point distance) so routes of different lengths align.
    curves = [np.interp(np.linspace(0, turn_dists[i] * 2, x_count, endpoint=False),
                        cum_speed[i],
                        cum_min_dist[:, i])
              for i in range(n_routes)]
    return np.array(curves).T

def compute_tortuosity(cum_min_dist):
    """Return the tortuosity tau = L / C of a set of homing curves.

    ``cum_min_dist`` is indexed ``[sample, route]`` (as returned by
    compute_path_straightness). The routes are averaged with ``nanmean`` and
    the mean curve is read at its middle sample; ``C`` is one minus that value.
    Returns ``1 / C``, or None when ``C`` is exactly zero.
    """
    mean_curve = np.nanmean(cum_min_dist, axis=1)
    remaining_at_mid = mean_curve[len(mean_curve) // 2]
    covered = 1.0 - remaining_at_mid
    if covered == 0.0:
        return None
    return 1.0 / covered

def plot_route_straightness(cum_min_dist, x_count=500, ax=None,
                            label_font_size=14, unit_font_size=10):
    """Plot the mean (+/- std) normalised homing curve against the ideal path.

    Parameters
    ----------
    cum_min_dist : np.ndarray, shape (x_count, n_routes)
        Curves as returned by compute_path_straightness.
    x_count : int
        Number of samples per curve; must match ``cum_min_dist.shape[0]``.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 10x10 figure is created when None.
    label_font_size, unit_font_size : int
        Font sizes of axis labels and tick labels.

    Returns
    -------
    (fig, ax)
        ``fig`` is None when ``ax`` was supplied.
    """
    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 10))

    mu = np.nanmean(cum_min_dist, axis=1)
    sigma = np.nanstd(cum_min_dist, axis=1)
    # Same grid as compute_path_straightness: x_count samples over [0, 2)
    # turning-point distances, without the endpoint.
    t = np.linspace(0, 2, x_count, endpoint=False)
    # Sample at exactly one turning-point distance travelled (used for C);
    # same index as compute_tortuosity instead of a hard-coded 250.
    mid = len(mu) // 2

    ax.plot(t, mu, label='Mean path', color='b')
    ax.fill_between(t, mu+sigma, mu-sigma, color='b', alpha=0.2)
    ax.set_ylim(0, 1.01)
    ax.plot([0, 1], [1, 0], 'r', label='Best possible path')
    ax.set_xlabel('Distance travelled relative to turning point distance',
                  fontsize=label_font_size)
    ax.set_ylabel('Distance from home', fontsize=label_font_size)
    ax.set_title('(f)')

    # Format ticks as percentages with a formatter rather than by overwriting
    # labels of auto-located ticks (which warns and goes stale on rescale).
    def as_percent(value, _pos):
        return '{:3.0f}%'.format(value * 100)

    ax.xaxis.set_major_formatter(as_percent)
    ax.yaxis.set_major_formatter(as_percent)
    ax.tick_params(labelsize=unit_font_size)

    ax.axvline(x=1, ymin=0, ymax=mu[mid], color='black', linestyle='dotted')

    ax.annotate(text='',
                xy=(1, mu[mid]),
                xytext=(1, 1),
                arrowprops=dict(facecolor='black',
                                arrowstyle='<->'))

    ax.text(1.05, mu[mid]+(1-mu[mid])/2, '$C$', fontsize=15, color='k',
            ha='left', va='center')

    legend = ax.legend(loc='lower left', prop={'size': 12}, handlelength=0,
                       handletextpad=0)
    # Colour each legend entry like its line and hide the line samples.
    for text, text_color in zip(legend.get_texts(), ['b', 'r']):
        text.set_color(text_color)
        text.set_ha('right')  # ha is alias for horizontalalignment
        text.set_position((103, 0))
    # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and the
    # old name was later removed; support both.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle.set_visible(False)
    legend.set_frame_on(False)
    plt.tight_layout()
    return fig, ax

In [None]:
# ------------------------------- Evaluation functions ---------------------------------

def mem_errors(data, noise=0.1, T_outbound=1500, ax=None, color="b"):
    """Collect per-run memory-error traces and optionally plot mean +/- std.

    Parameters
    ----------
    data : iterable of dict
        Records as produced by get_eval_data.
    noise, T_outbound :
        Only records with exactly these values are used.
    ax : matplotlib Axes, optional
        If given, the mean trace and a +/- one std band are drawn on it.
    color :
        Line/fill colour; the line is labelled with ``noise``.

    Returns
    -------
    np.ndarray
        The 'mem_error' traces of the matching runs, one row per run.
    """
    errors = np.array([result['mem_error'] for result in data
                       if result['noise'] == noise
                       and result['T_outbound'] == T_outbound])

    # Only plot when there is a 2-D (runs x steps) array to average.
    if ax is not None and errors.ndim == 2:
        mean_errors = errors.mean(axis=0)
        std_errors = errors.std(axis=0)
        # One x value per recorded step, taken from the data instead of
        # assuming a 1500-step inbound phase.
        timesteps = np.arange(len(mean_errors))

        ax.fill_between(timesteps, mean_errors+std_errors, mean_errors, color=color, alpha=0.2)
        ax.plot(timesteps, mean_errors, color=color, label=noise)
        ax.fill_between(timesteps, mean_errors-std_errors, mean_errors, color=color, alpha=0.2)

        ax.set_xlabel("Time (steps)")
        ax.set_ylabel("Memory error")
        ax.set_title("(d)")

        ax.legend(title="Noise", loc='upper left')

    return errors

def angular_mem_errors(data, noise=0.1, T_outbound=1500, ax=None, color="b"):
    """Collect per-run angular memory-error traces and optionally plot them.

    Parameters
    ----------
    data : iterable of dict
        Records as produced by get_eval_data.
    noise, T_outbound :
        Only records with exactly these values are used.
    ax : matplotlib Axes, optional
        If given, mean +/- std (converted from radians to degrees) is drawn.
    color :
        Line/fill colour; the line is labelled with ``noise``.

    Returns
    -------
    np.ndarray
        The raw 'angular_mem_error' traces (radians), one row per matching run.
    """
    errors = np.array([result['angular_mem_error'] for result in data
                       if result['noise'] == noise
                       and result['T_outbound'] == T_outbound])

    # Only plot when there is a 2-D (runs x steps) array to average.
    if ax is not None and errors.ndim == 2:
        mean_errors = errors.mean(axis=0) * 180/np.pi
        std_errors = errors.std(axis=0) * 180/np.pi
        # One x value per recorded step instead of assuming a 1500-step inbound phase.
        timesteps = np.arange(len(mean_errors))

        ax.fill_between(timesteps, mean_errors+std_errors, mean_errors, color=color, alpha=0.2)
        ax.plot(timesteps, mean_errors, color=color, label=noise)
        ax.fill_between(timesteps, mean_errors-std_errors, mean_errors, color=color, alpha=0.2)

        ax.set_xlabel("Timesteps")
        ax.set_ylabel("Angular memory error")
        ax.set_yticks([-180,-90,0,90,180])

        ax.legend(title="Noise", loc='upper left')

    return errors

def heading_errors(data, noise=0.1, T_outbound=1500, ax=None, color="b"):
    """Collect heading-error traces and optionally plot their homebound part.

    Parameters
    ----------
    data : iterable of dict
        Records as produced by get_eval_data.
    noise, T_outbound :
        Only records with exactly these values are used; ``T_outbound`` is
        also the index at which the homebound part of each trace starts.
    ax : matplotlib Axes, optional
        If given, mean +/- std of the homebound errors is drawn on it.
    color :
        Line/fill colour; the line is labelled with ``noise``.

    Returns
    -------
    (errors, home_errors) : tuple of np.ndarray
        Full traces and their homebound slices, one row per matching run.
    """
    errors = np.array([result['heading_error'] for result in data
                       if result['noise'] == noise
                       and result['T_outbound'] == T_outbound])
    home_errors = np.array([error[T_outbound:] for error in errors])

    # Only plot when there is a 2-D (runs x steps) array to average.
    if ax is not None and home_errors.ndim == 2:
        mean_errors = home_errors.mean(axis=0)
        std_errors = home_errors.std(axis=0)
        # One x value per recorded inbound step instead of assuming 1500.
        timesteps = np.arange(len(mean_errors))

        ax.fill_between(timesteps, mean_errors+std_errors, mean_errors, color=color, alpha=0.2)
        ax.plot(timesteps, mean_errors, color=color, label=noise)
        ax.fill_between(timesteps, mean_errors-std_errors, mean_errors, color=color, alpha=0.2)

        ax.set_xlabel("Inbound time (steps)")
        ax.set_ylabel("Heading error")
        ax.set_title("(e)")

        ax.legend(title="Noise", loc='upper right')

    return errors, home_errors

def tortuosity_plot(data, noise=0.1, T_outbound=1500, ax=None):
    """Plot the mean remaining homing distance against the optimal curve.

    Uses the records in ``data`` whose 'noise' and 'T_outbound' match. The x
    values are the 'Tort_T' of the last matching record. Returns nothing;
    without ``ax`` the statistics are computed but not drawn.

    NOTE(review): the x axis (homing time) is limited to ``T_outbound``; confirm
    that is the intended range.
    """
    matching = [rec for rec in data
                if rec['noise'] == noise and rec['T_outbound'] == T_outbound]
    T = matching[-1]['Tort_T'] if matching else []

    actual = np.array([rec['actual'] for rec in matching])
    optimal = np.array([rec['optimal'] for rec in matching])
    mean_actual, std_actual = actual.mean(axis=0), actual.std(axis=0)
    mean_optimal = optimal.mean(axis=0)

    if ax is None:
        return

    ax.fill_between(T, mean_actual, mean_actual + std_actual, color="blue", alpha=0.2)
    ax.plot(T, mean_actual, label="mean distance from home", color="blue")
    ax.fill_between(T, mean_actual, mean_actual - std_actual, color="blue", alpha=0.2)
    ax.plot(T, mean_optimal, label="mean optimal distance", color="orange")
    ax.set_xlabel("Timesteps homing")
    ax.set_ylabel("% of homing distance remaining")
    ax.set_xlim(0, T_outbound)
    ax.legend()

def example_path(result, ax=None, decode=False):
    """Draw one example route as panel "(a)".

    Parameters
    ----------
    result :
        A simulation result providing ``plot_path(ax=..., decode=...,
        search_pattern=...)``.
    ax : matplotlib Axes, optional
        Axes to draw on; the current pyplot axes is used when None (the
        previous code crashed on the None default).
    decode : bool
        Forwarded to ``result.plot_path``.
    """
    if ax is None:
        ax = plt.gca()
    ax.axis("equal")
    ax.set_title('(a)')
    result.plot_path(ax=ax, decode=decode, search_pattern=False)
    
def min_dist_histogram(data, noise=0.1, T_outbound=1500, ax=None, binwidth=1, confidence = 0.95):
    """Histogram of closest approach distances, with the upper confidence bound.

    Uses the 'min_dist' of records matching ``noise`` and ``T_outbound``. The
    central ``confidence`` percentile interval is computed; its upper end is
    drawn as a dashed vertical line. Bins are ``binwidth`` wide, starting at
    the smallest distance. Nothing is drawn when ``ax`` is None.
    """
    closest = np.array([rec['min_dist'] for rec in data
                        if rec['noise'] == noise and rec['T_outbound'] == T_outbound])

    lower_pct = 100*(1-confidence)/2
    upper_pct = 100*(1-(1-confidence)/2)
    interval = np.percentile(closest, [lower_pct, upper_pct])

    if ax is None:
        return

    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel("Closest distance")
    ax.set_ylabel("Frequency")
    ax.set_title("(b)")
    bin_edges = np.arange(min(closest), max(closest) + binwidth, binwidth)
    ax.hist(closest, bins=bin_edges)
    ax.axvline(interval[1], color="k", linestyle="--")
    
def angle_offset_after_radius(data, noise_levels=[0.1,0.2,0.3,0.4], radius=20, ax=None):
    """Polar histograms of the disk-leaving angle, one per noise level.

    Parameters
    ----------
    data : iterable of dict
        Records from get_eval_data; their 'angle_offset' values are plotted.
    noise_levels : sequence of float
        Noise levels to show, in this order; levels absent from ``data`` are
        skipped.
    radius : float
        Not used here: the angles were already computed when the records were
        built (see get_eval_data).
    ax : matplotlib polar Axes, optional
        Nothing is drawn when None.
    """
    # Key each group by its actual noise value so histograms, colours and
    # legend labels line up with `noise_levels` even if `data` holds other
    # levels or lacks some of them.
    offsets_by_noise = {group[0]['noise']: [r['angle_offset'] for r in group]
                        for group in group_by_noise(data)}
    levels = [level for level in noise_levels if level in offsets_by_noise]
    # A list of arrays rather than one np.array: groups may differ in size,
    # which numpy rejects as a ragged array.
    angular_dists = [np.asarray(offsets_by_noise[level]) for level in levels]

    if ax is not None:
        plot_angular_distances(levels, angular_dists, ax=ax)
        ax.set_title("(g)", loc="left")
        ax.legend(title="Noise", loc='lower left')
    
def min_dist_v_route_length(model, noise=0.1, ax=None, color="b", param=False):
    """Mean +/- std closest distance to home as a function of outbound length.

    Parameters
    ----------
    model : iterable of dict
        Records from get_eval_data.
    noise : float
        With ``param=False``: the sensory noise level to select (runs without
        parameter noise only). With ``param=True``: the parameter-noise level
        to select, among runs with sensory noise 0.1.
    ax : matplotlib Axes, optional
        If given, the curve is drawn and labelled with ``noise``.
    color :
        Line/fill colour.
    param : bool
        Select on parameter noise instead of sensory noise.

    Returns
    -------
    (outbounds, mean_dists, std_dists)
        Sorted outbound lengths and the per-length mean/std of 'min_dist'.
    """
    if not param:
        selected = [r for r in model
                    if r['noise'] == noise and r['parameter_noise'] is None]
        legend_title = "Noise"
        title = "(c)"
    else:
        selected = [r for r in model
                    if r['noise'] == 0.1 and r['parameter_noise'] == noise]
        legend_title = "Parameter noise"
        title = ""

    dists_by_outbound = {}
    for r in selected:
        dists_by_outbound.setdefault(r['T_outbound'], []).append(r['min_dist'])
    # x values of the plot (previously referenced but never defined).
    outbounds = sorted(dists_by_outbound)
    # Statistics per outbound length: groups may hold different numbers of runs.
    mean_dists = np.array([np.mean(dists_by_outbound[t]) for t in outbounds])
    std_dists = np.array([np.std(dists_by_outbound[t]) for t in outbounds])

    if ax is not None:
        ax.set_xlabel("Outbound time (steps)")
        ax.set_ylabel("Closest distance")
        ax.set_title(title)

        ax.plot(outbounds, mean_dists, color=color, label=noise)
        ax.fill_between(outbounds, mean_dists+std_dists, mean_dists, alpha=0.2, color=color)
        ax.fill_between(outbounds, mean_dists-std_dists, mean_dists, alpha=0.2, color=color)
        ax.legend(title=legend_title, loc='upper left')

    return outbounds, mean_dists, std_dists
        

def min_dist_v_route_length_2(model, noise=0.1, ax=None, color="b", param=False):
    """Plot one model's mean +/- std closest distance against outbound length.

    Like min_dist_v_route_length with ``param=False``, but the line is
    labelled with the model name (the part of the record name before the
    first '.') so several models can share one axes.

    Parameters
    ----------
    model : iterable of dict
        Records of a single model (e.g. one group from group_by_name).
    noise : float
        Sensory noise level to select; runs with parameter noise are excluded.
    ax : matplotlib Axes, optional
        Nothing is drawn when None or when no run matches.
    color :
        Line/fill colour.
    param : bool
        Accepted for signature parity; ignored.

    Returns
    -------
    (outbounds, mean_dists, std_dists)
        Sorted outbound lengths and the per-length mean/std of 'min_dist'.
    """
    selected = [r for r in model
                if r['noise'] == noise and r['parameter_noise'] is None]

    dists_by_outbound = {}
    for r in selected:
        dists_by_outbound.setdefault(r['T_outbound'], []).append(r['min_dist'])
    # x values of the plot (previously referenced but never defined).
    outbounds = sorted(dists_by_outbound)
    # Statistics per outbound length: groups may hold different numbers of runs.
    mean_dists = np.array([np.mean(dists_by_outbound[t]) for t in outbounds])
    std_dists = np.array([np.std(dists_by_outbound[t]) for t in outbounds])

    # The label needs at least one run, so skip drawing for an empty selection.
    if ax is not None and selected:
        ax.set_xlabel("Outbound time (steps)")
        ax.set_ylabel("Closest distance")

        ax.plot(outbounds, mean_dists, color=color, label=selected[0]['name'].split(".")[0])
        ax.fill_between(outbounds, mean_dists+std_dists, mean_dists, alpha=0.2, color=color)
        ax.fill_between(outbounds, mean_dists-std_dists, mean_dists, alpha=0.2, color=color)
        ax.legend(title='Models', loc='upper left')

    return outbounds, mean_dists, std_dists


# ---------------------------------- Helper functions ----------------------------------
def get_random_result(model, noise=0.1, T_outbound=1500):
    """Return a uniformly random raw result with the given noise and outbound length.

    The chosen index (within the filtered list) is printed so the pick can be
    identified later.

    Raises
    ------
    ValueError
        If no result matches.
    """
    # Filter on the requested noise (previously hard-coded to 0.1).
    candidates = [result for result in model
                  if result.parameters['cx']['params']['noise'] == noise
                  and result.parameters['T_outbound'] == T_outbound]
    if not candidates:
        raise ValueError(f"no result with noise={noise} and T_outbound={T_outbound}")
    random_idx = np.random.randint(0, len(candidates))
    print(random_idx)
    return candidates[random_idx]

def get_noise_levels(model):
    """Sorted list of the distinct 'noise' values present in ``model``."""
    distinct = {record["noise"] for record in model}
    return sorted(distinct)

def group_by_noise(data):
    """Split records into lists sharing the same 'noise', lowest noise first."""
    def noise_of(record):
        return record["noise"]

    ordered = sorted(data, key=noise_of)
    return [list(members) for _key, members in groupby(ordered, noise_of)]

def get_outbounds(model):
    """Sorted list of the distinct 'T_outbound' values present in ``model``."""
    distinct = {record["T_outbound"] for record in model}
    return sorted(distinct)

def group_by_outbound(data):
    """Split records into lists sharing the same 'T_outbound', shortest first."""
    def outbound_of(record):
        return record["T_outbound"]

    ordered = sorted(data, key=outbound_of)
    return [list(members) for _key, members in groupby(ordered, outbound_of)]

def group_by_name(data):
    """Group the records of the ``data`` dict by model, in a fixed model order.

    The model is the part of ``record["name"]`` before the first '.'. Groups
    follow the order of ``model_order`` below; a model name missing from it
    raises KeyError.
    """
    model_order = ['dye basic', 'dye var beta', 'dye pontine', 'dye var beta + pontine', 'weights', 'stone']
    rank = {model_name: position for position, model_name in enumerate(model_order)}

    def model_rank(record):
        return rank[record["name"].split('.')[0]]

    ordered = sorted(data.values(), key=model_rank)
    return [list(members) for _key, members in groupby(ordered, model_rank)]

def get_min_dists(model,noise=0.1,T_outbound=1500):
    """Return the 'min_dist' of every record in ``model`` as an array.

    NOTE: ``noise`` and ``T_outbound`` are accepted but ignored; no filtering
    is done (callers pass already-selected groups).
    """
    distances = [record['min_dist'] for record in model]
    return np.array(distances)
                            

def tortuosity_scores(model, noise=0.1, T_outbound=1500):
    """Array of 'tort_score' values of the records matching noise and T_outbound."""
    matches = (record for record in model
               if record['noise'] == noise and record['T_outbound'] == T_outbound)
    return np.array([record['tort_score'] for record in matches])


In [None]:
# Result directories to evaluate. Earlier runs are kept commented out for
# reference; only the directory appended below is loaded.
paths = []
# paths.append("../../results/dye-eval_20221009-145512")
# paths.append("../../results/beta-dye-eval_20221009-145808")
# paths.append("../../results/cheat-dye-eval_20221009-145601")
# paths.append("../../results/beta-cheat-dye_20221009-150619")
# paths.append("../../results/model-v-model_20221010-173329")
# paths.append("../../results/stone-eval_20221010-153927")
# paths.append("../../results/weights-eval_20221011-202459")

paths.append("../../results/model-v-model-outbound_20221012-105545")

# Load every result found under the listed directories.
results = load_results(enumerate_results(paths))

In [None]:
noise = 0.1
T_outbound = 1500

data = {}      # result name -> metrics record from get_eval_data
examples = []  # up to 10 baseline raw results, used for example-path plots

for result in tqdm(results):
    cx_params = result.parameters['cx']['params']
    # Baseline runs (default noise and outbound length, no parameter noise)
    # are kept as examples until 10 have been collected.
    is_example = (cx_params['noise'] == noise
                  and result.parameters['T_outbound'] == T_outbound
                  and len(examples) < 10
                  and cx_params.get('parameter_noise', None) is None)
    if is_example:
        examples.append(result)

    curr_data = get_eval_data(result)
    data[curr_data['name']] = curr_data

In [None]:
def single_model_eval(data, noise_levels=[0.1,0.2,0.3,0.4], example_idx=6):
    """Draw the single-model evaluation figures.

    Figure 1: (a) example path, (b) closest-distance histogram, (c) closest
    distance vs. outbound length per noise level, (d) memory error, (e)
    heading error, (f) route straightness, (g) disk-leaving angles.
    Figure 2: closest distance vs. outbound length per parameter-noise level.

    Parameters
    ----------
    data : dict
        Result name -> record from get_eval_data.
    noise_levels : list of float
        Sensory noise levels compared in panels (c)-(e) and (g).
    example_idx : int
        Index into the notebook-level ``examples`` list for panel (a).
    """
    fig = plt.figure(figsize=(15,20))

    noise = 0.1          # baseline sensory noise for single-level panels
    T_outbound = 1500    # baseline outbound length

    colors = [cm.viridis(x) for x in np.linspace(0, 1, len(noise_levels))]
    # Plot in reverse order so the first entry of noise_levels is drawn on top.
    colored_levels = list(reversed(list(zip(colors, noise_levels))))

    records = list(data.values())
    # Baseline runs: default outbound length and no parameter noise.
    model = [result for result in records
             if result['T_outbound'] == T_outbound and result['parameter_noise'] is None]

    # Example path
    example_path(examples[example_idx], ax=plt.subplot(421))

    # Min dist histogram
    min_dist_histogram(model, noise, T_outbound, ax=plt.subplot(422), binwidth=2)

    # Loop variables are named `level` so the baseline `noise` above is not
    # overwritten (it is used again for the straightness panel).

    # Min dist v outbound path (uses runs of every outbound length)
    ax_route = plt.subplot(423)
    for color, level in colored_levels:
        min_dist_v_route_length(records, level, ax=ax_route, color=color)

    # Memory errors
    ax_mem = plt.subplot(424)
    for color, level in colored_levels:
        mem_errors(model, level, T_outbound, ax=ax_mem, color=color)

    # Heading errors
    ax_heading = plt.subplot(425)
    for color, level in colored_levels:
        heading_errors(model, level, T_outbound, ax=ax_heading, color=color)

    # Route straightness at the baseline noise level
    V = np.array([result['velocities'] for result in model if result['noise'] == noise])
    cum_min_dist = compute_path_straightness(V, T_outbound)
    plot_route_straightness(cum_min_dist, ax=plt.subplot(426))

    # Angle after leaving a disk of radius 20
    angle_offset_after_radius(model, noise_levels=noise_levels, radius=20,
                              ax=plt.subplot(427, projection='polar'))

    fig.tight_layout()

    param_fig = plt.figure(figsize=(15,10))
    ax = param_fig.gca()
    param_noise_levels = [0.0,0.01,0.02,0.05,0.1]
    param_colors = [cm.viridis(x) for x in np.linspace(0, 1, len(param_noise_levels))]
    for color, level in reversed(list(zip(param_colors, param_noise_levels))):
        min_dist_v_route_length(records, level, ax=ax, color=color, param=True)
    ax.set_title("Minimum distance plot with noise on dye parameters")

single_model_eval(data)


In [None]:
def _label_model_axis(ax, positions, labels, title, ylabel):
    """Label a model-comparison panel with one rotated model-name tick per position."""
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')


def _violin_panel(ax, values, positions, labels, title, ylabel):
    """Draw one violin (with mean marker) per model and label the panel."""
    ax.violinplot(values, positions=positions, showmeans=True, showmedians=False)
    _label_model_axis(ax, positions, labels, title, ylabel)


def model_to_model(models=[], labels=[], noise=0.1, T_outbound=1500):
    """Compare several models in one 2x2 figure.

    Panels: (a) closest distance per run, (b) mean memory error per run,
    (c) mean homing heading error per run, (d) tortuosity of each model's
    mean homing curve.

    Parameters
    ----------
    models : list of list of dict
        One list of get_eval_data records per model, e.g. group_by_name(data).
    labels : list of str, optional
        Tick labels. By default (and previously always), the part of each
        model's first record name before the first '.'.
    noise, T_outbound :
        Only runs with this sensory noise and outbound length are compared
        (the defaults match the values used before).
    """
    if not labels:
        labels = [model[0]['name'].split(".")[0] for model in models]
    # One x position per model instead of a hard-coded six.
    positions = list(range(1, len(models) + 1))

    # Use the same runs in every panel. mem_errors/heading_errors already
    # filter on noise and T_outbound; applying that filter here too keeps the
    # panels comparable and avoids stacking velocity traces from different
    # outbound lengths (presumably of different lengths) into one array.
    selected = [[r for r in model
                 if r['noise'] == noise and r['T_outbound'] == T_outbound]
                for model in models]

    fig = plt.figure(figsize=(15,15))

    _violin_panel(plt.subplot(221),
                  [get_min_dists(runs) for runs in selected],
                  positions, labels, "(a)", "Closest distance")
    _violin_panel(plt.subplot(222),
                  [mem_errors(model, noise, T_outbound).mean(axis=1) for model in models],
                  positions, labels, "(b)", "Memory error")
    _violin_panel(plt.subplot(223),
                  [heading_errors(model, noise, T_outbound)[1].mean(axis=1) for model in models],
                  positions, labels, "(c)", "Heading error during homing")

    tortuosities = []
    for runs in selected:
        V = np.array([r['velocities'] for r in runs])
        tort = compute_tortuosity(compute_path_straightness(V, T_outbound))
        # compute_tortuosity returns None for a degenerate curve; NaN keeps
        # the bar chart drawable.
        tortuosities.append(np.nan if tort is None else tort)

    ax = plt.subplot(224)
    bars = ax.bar(positions, tortuosities, alpha=0.5)
    _label_model_axis(ax, positions, labels, "(d)", "Mean tortuosity")
    ax.bar_label(bars)

    fig.tight_layout()

models = group_by_name(data)
model_to_model(models)

In [None]:
# Closest distance vs. outbound length, one coloured curve per model, on a
# single shared axes.
fig, ax = plt.subplots(figsize=(15,15))
models = group_by_name(data)
colors = [cm.rainbow(x) for x in np.linspace(0, 1, len(models))]

for model, color in zip(models, colors):
    min_dist_v_route_length_2(model, color=color, ax=ax)