In [None]:
%matplotlib notebook
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import tqdm
import tqdm.contrib.concurrent
import re
from collections import namedtuple
import ipywidgets
import math
from IPython.display import display
from scipy import special, optimize

from experimental.beacon_sim.rollout_statistics_pb2 import AllStatistics, RolloutStatistics
from experimental.beacon_sim import plot_trials
from experimental.beacon_sim import correlated_beacons as cb

np.set_printoptions(linewidth=200)  # wider line width when printing numpy arrays

import importlib
# Re-import the project modules so edits to them take effect without restarting the kernel.
importlib.reload(plot_trials)
importlib.reload(cb)

In [None]:
# Directory holding the serialized AllStatistics files (*.pb) loaded below.
# NOTE(review): hard-coded per-user path; adjust for your machine.
data_dir = os.path.expanduser("~/scratch/initial_state_fix/")

In [None]:
# One rollout of an instance: `trial_idx` is its position in
# AllStatistics.statistics (it is also used as the missing-beacon bitmask when
# plotting / computing expectations), `cov_size` is
# plot_trials.compute_covariance_size of that rollout.
Trial = namedtuple('Trial', ['trial_idx', 'cov_size'])
# Summary of one serialized AllStatistics file (built by load_instance).
Instance = namedtuple('Instance', ['instance_idx', 'start_position', 'goal_position', 'path_length', 'plan', 'trials'])
# Instance index = the digits immediately before a trailing ".pb" extension.
# Anchored with `$` so only the file's own extension counts: names such as
# "x_7.pbtxt", or a "7.pb" directory component, do not match.
instance_finder = re.compile(r'([0-9]+)\.pb$')

def unpack_trials(trials: list[RolloutStatistics]):
    """Summarize each rollout as a Trial.

    `trial_idx` is the rollout's position in `trials`; `cov_size` is computed by
    plot_trials.compute_covariance_size.
    """
    summaries = []
    for position, rollout in enumerate(trials):
        cov_size = plot_trials.compute_covariance_size(rollout)
        summaries.append(Trial(trial_idx=position, cov_size=cov_size))
    return summaries

def compute_path_length(stats: 'AllStatistics'):
    """Return the Euclidean length of the planned path stored in `stats`.

    The walk starts at `stats.local_from_start.translation` and then visits
    every entry of `stats.plan` after the first (plan[0] is skipped; presumably
    it is the start sentinel). Entries are road-map node indices, except:
      * -1 -> the start position,
      * -2 (or any other negative index) -> the goal `(stats.goal.x, stats.goal.y)`.
    This matches the -1/-2 convention used when plotting the plan.

    Only the first two coordinates of each position are used.
    """
    start_pos = tuple(stats.local_from_start.translation.data)
    goal_pos = (stats.goal.x, stats.goal.y)

    def node_position(path_idx):
        # -1 is the start sentinel; previously it was (wrongly) mapped to the goal.
        if path_idx == -1:
            return start_pos
        if path_idx < 0:
            return goal_pos
        return tuple(stats.road_map.points[path_idx].data)

    dist_m = 0.0
    current_pos = start_pos
    for path_idx in stats.plan[1:]:
        new_pos = node_position(path_idx)
        dist_m += math.hypot(new_pos[0] - current_pos[0], new_pos[1] - current_pos[1])
        current_pos = new_pos
    return dist_m

def load_instance(file_name: str) -> Instance:
    """Load one serialized AllStatistics file and summarize it as an Instance.

    The instance index is parsed from `file_name` with `instance_finder` (the
    digits immediately before ".pb"). The name is checked before the file is
    read, so a badly named file fails fast.

    Raises:
        ValueError: if no instance index can be parsed from `file_name`.
    """
    match = instance_finder.search(file_name)
    if match is None:
        # Without this check the failure is an opaque AttributeError on None.group.
        raise ValueError(f'Cannot parse an instance index from file name: {file_name!r}')
    instance_idx = int(match.group(1))

    with open(file_name, 'rb') as file_in:
        serialized = file_in.read()
    stats = AllStatistics()
    stats.ParseFromString(serialized)

    return Instance(instance_idx=instance_idx,
                    start_position=list(stats.local_from_start.translation.data),
                    goal_position=(stats.goal.x, stats.goal.y),
                    path_length=compute_path_length(stats),
                    plan=list(stats.plan),
                    trials=unpack_trials(stats.statistics))


In [None]:
def load_road_map_and_world_map(file_path):
    """Read one serialized AllStatistics file and return (road_map, world_map_config)."""
    with open(file_path, 'rb') as file_in:
        serialized = file_in.read()
    parsed = AllStatistics()
    parsed.ParseFromString(serialized)
    return parsed.road_map, parsed.world_map_config

In [None]:
# Only serialized protobuf files are instances. Sort so the listing (and hence
# `file_paths[0]`, used below for the maps) is deterministic across runs;
# os.listdir order is arbitrary.
file_paths = sorted(
    os.path.join(data_dir, file_name)
    for file_name in os.listdir(data_dir)
    if file_name.endswith('.pb')
)
assert file_paths, f'no .pb files found in {data_dir}'

In [None]:
# The road map and world map are read from a single file only.
# NOTE(review): assumes every instance in data_dir shares the same road map and world map; confirm.
road_map, world_map = load_road_map_and_world_map(file_paths[0])

In [None]:


# Parse every instance file in worker processes, then order by instance index.
instances = sorted(
    tqdm.contrib.concurrent.process_map(load_instance, file_paths, chunksize=1),
    key=lambda inst: inst.instance_idx,
)

In [None]:
def missing_from_configuration(configuration, world_map):
    """Return the sorted ids of the beacons whose bit is set in `configuration`.

    Bit i of `configuration` corresponds to world_map.fixed_beacons.beacons[i];
    a set bit marks that beacon as missing. Bits beyond the number of beacons
    are ignored.
    """
    beacons = world_map.fixed_beacons.beacons
    return sorted(
        beacon.id
        for bit, beacon in enumerate(beacons)
        if configuration & (1 << bit)
    )

def plot_environment(road_map, world_map, plan, start_pos, goal_pos, configuration=None,
                     annotate_nodes=False):
    """Draw the road map, the plan, start/goal and the present beacons on the current axes.

    Args:
        road_map: road map message. `adj.data` is read as a row-major
            num_rows x num_cols matrix where 1 marks an edge, and
            `points[i].data` holds node i's coordinates.
        world_map: world map config; beacons come from `fixed_beacons.beacons`.
        plan: sequence of node indices; -1 denotes `start_pos` and -2 denotes `goal_pos`.
        start_pos, goal_pos: start and goal positions (first two entries used as x, y).
        configuration: missing-beacon bitmask (see missing_from_configuration);
            None means every beacon is drawn.
        annotate_nodes: if True, label each road-map node with its index.
    """
    if configuration is None:
        configuration = 0
    # Set for O(1) membership tests while filtering beacons.
    missing_beacons = set(missing_from_configuration(configuration, world_map))

    text_x_offset = 0.2
    text_y_offset = 0.2

    # Road-map edges: only c > r is visited, presumably because adjacency is symmetric.
    num_cols = road_map.adj.num_cols
    line_segments = [
        [tuple(road_map.points[r].data), tuple(road_map.points[c].data)]
        for r in range(road_map.adj.num_rows)
        for c in range(r + 1, num_cols)
        if road_map.adj.data[r * num_cols + c] == 1
    ]
    edges = mpl.collections.LineCollection(line_segments, colors=(0.8, 0.8, 0.6, 1.0))
    ax = plt.gca()
    ax.add_collection(edges)

    # The plan, with the -1 / -2 sentinels resolved to start / goal.
    plan_points = []
    for node_idx in plan:
        if node_idx == -1:
            plan_points.append(start_pos)
        elif node_idx == -2:
            plan_points.append(goal_pos)
        else:
            plan_points.append(tuple(road_map.points[node_idx].data))
    ax.plot([pt[0] for pt in plan_points], [pt[1] for pt in plan_points], 'b')

    # Start and goal markers.
    ax.plot(start_pos[0], start_pos[1], 'md', markersize=10)
    ax.plot(goal_pos[0], goal_pos[1], 'g*', markersize=15)

    # Road-map nodes.
    ax.plot([pt.data[0] for pt in road_map.points], [pt.data[1] for pt in road_map.points], 'rs')
    if annotate_nodes:
        for i, pt in enumerate(road_map.points):
            ax.text(pt.data[0] + text_x_offset, pt.data[1] + text_y_offset, i)

    # Beacons present in this configuration, labelled with their ids.
    present_beacons = [beacon for beacon in world_map.fixed_beacons.beacons
                       if beacon.id not in missing_beacons]
    for beacon in present_beacons:
        ax.text(beacon.pos_x_m + text_x_offset, beacon.pos_y_m + text_y_offset, beacon.id)
    ax.plot([beacon.pos_x_m for beacon in present_beacons],
            [beacon.pos_y_m for beacon in present_beacons], 'b^')
    ax.axis('equal')
    
def plot_covariances(trials, current_trial):
    """Plot covariance sizes against their rank and mark `current_trial`.

    Trials are sorted by cov_size. The x axis is cov_size and the y axis is the
    0-based rank, giving an un-normalized empirical CDF. `current_trial` is drawn
    as a green star at its own rank.
    """
    ranked = sorted(trials, key=lambda t: t.cov_size)
    cov_sizes = [t.cov_size for t in ranked]
    ranks = list(range(len(ranked)))
    plt.plot(cov_sizes, ranks)
    plt.plot(current_trial.cov_size, ranked.index(current_trial), 'g*', markersize=15)
    
def make_plots(fig, trial_idx, instance, trials, road_map, world_map):
    """Redraw `fig`: environment for one trial (left), covariance-size CDF (right).

    Args:
        fig: matplotlib figure that is cleared and drawn into.
        trial_idx: position in `trials` of the trial to show. This is not
            necessarily Trial.trial_idx, because the caller may re-sort `trials`.
        instance: Instance whose plan, start, goal and path length are drawn.
        trials: Trial list used for both the selected trial and the CDF.
        road_map, world_map: environment forwarded to plot_environment.
    """
    trial = trials[trial_idx]
    plt.figure(fig.number)  # make `fig` current so the pyplot calls below target it
    fig.clear()
    plt.subplot(121)
    # Trial.trial_idx doubles as the missing-beacon bitmask.
    plot_environment(
        road_map, world_map, instance.plan, instance.start_position,
        instance.goal_position, configuration=trial.trial_idx)
    plt.title(f'Instance: {instance.instance_idx} Trial: {trial.trial_idx}, Path Length: {instance.path_length: 0.2f}')
    plt.xlabel('X (m)')
    plt.ylabel('Y (m)')
    plt.subplot(122)
    plot_covariances(trials, trial)
    # Raw strings: '\S' is not a valid escape sequence in a normal literal.
    plt.xlabel(r'$|\Sigma|$')
    plt.ylabel('Count')
    plt.title(r'CDF of $|\Sigma|$')
    plt.tight_layout()
    
def plot_results(road_map, world_map, instances):
    """Display an interactive viewer for browsing instances and their trials.

    One HBox holds four widgets:
      * a checkbox to order instances by path length,
      * an instance slider,
      * a checkbox to order trials by covariance size,
      * a per-instance trial slider.
    Every change redraws a single 12x6 figure via make_plots.

    Raises:
        ValueError: if `instances` is empty (the instance slider would have no range).
    """
    if not instances:
        raise ValueError('plot_results requires at least one instance')

    path_length_order_checkbox = ipywidgets.Checkbox(value=True, description='Sorted Instances?')
    instance_slider = ipywidgets.IntSlider(min=0, max=len(instances)-1, step=1, description='Instance Idx', value=0)
    sorted_trial_checkbox = ipywidgets.Checkbox(value=True, description="Sorted Trials?")
    fig = plt.figure(figsize=(12, 6))

    hbox = ipywidgets.HBox()

    sorted_instances = sorted(instances, key=lambda x: x.path_length)
    trial_slider = None  # rebuilt whenever the selected instance changes

    def current_instance():
        """Instance selected by the slider under the current ordering."""
        instances_to_use = sorted_instances if path_length_order_checkbox.value else instances
        return instances_to_use[instance_slider.value]

    def trials_to_show(instance):
        """The instance's trials, sorted by cov_size when that checkbox is set."""
        if sorted_trial_checkbox.value:
            return sorted(instance.trials, key=lambda trial: trial.cov_size)
        return instance.trials

    def redraw(change=None):
        instance = current_instance()
        make_plots(fig, trial_slider.value, instance, trials_to_show(instance), road_map, world_map)

    def on_instance_change(change=None):
        # The trial slider's range and label are per-instance, so replace it.
        nonlocal trial_slider
        if trial_slider is not None:
            trial_slider.close()
            hbox.children = hbox.children[:-1]

        instance = current_instance()
        trial_slider = ipywidgets.IntSlider(
            min=0, max=len(instance.trials)-1, step=1,
            description=f'Trial Idx for instance {instance.instance_idx}', value=0,
            style={"description_width": "initial"})
        trial_slider.observe(redraw, 'value')
        redraw()
        hbox.children = (*hbox.children, trial_slider)

    instance_slider.observe(on_instance_change, 'value')
    # Re-ordering instances changes which instance the slider points at. Rebuild
    # the trial slider too; otherwise its label (and the instance it would
    # redraw) would still belong to the previously selected instance.
    path_length_order_checkbox.observe(on_instance_change, 'value')
    sorted_trial_checkbox.observe(redraw, 'value')
    hbox.children = (path_length_order_checkbox, instance_slider, sorted_trial_checkbox)
    display(hbox)
    on_instance_change()
    
    

In [None]:
# Launch the interactive instance/trial browser.
plot_results(road_map, world_map, instances)

In [None]:

def independent_beacons(marginal_prob):
    """Beacon potential over every fixed beacon in the global `world_map`.

    A single clique is used with p_no_beacon = (1 - marginal_prob) ** num_beacons.
    That is the probability that no beacon is present if each beacon were
    independently present with probability `marginal_prob`.
    """
    beacon_ids = [b.id for b in world_map.fixed_beacons.beacons]
    p_none_present = (1 - marginal_prob) ** len(beacon_ids)
    return cb.create_correlated_beacons(
        cb.BeaconClique(p_beacon=marginal_prob, p_no_beacon=p_none_present, members=beacon_ids))

def all_correlated(marginal_prob, p_no_beacon):
    """Beacon potential with a single clique containing every fixed beacon in the global `world_map`."""
    all_ids = [b.id for b in world_map.fixed_beacons.beacons]
    return cb.create_correlated_beacons(
        cb.BeaconClique(p_beacon=marginal_prob, p_no_beacon=p_no_beacon, members=all_ids))

def clustered_beacons(marginal_prob, p_no_beacon, clusters):
    """Build a beacon potential as the product of one correlated clique per cluster.

    Each cluster gets a BeaconClique with p_beacon=marginal_prob and
    p_no_beacon = p_no_beacon ** (1 / len(clusters)), so the per-clique
    no-beacon probabilities multiply back to `p_no_beacon`. This assumes the
    product of potentials treats the cliques as independent (TODO confirm in cb).

    Args:
        marginal_prob: p_beacon for every clique.
        p_no_beacon: target overall no-beacon probability.
        clusters: non-empty sequence of beacon-id collections, one per clique.

    Raises:
        ValueError: if `clusters` is empty, rather than silently returning None.
    """
    if not clusters:
        raise ValueError('clustered_beacons requires at least one cluster')
    # Loop-invariant: every clique receives the same share of p_no_beacon.
    p_no_beacon_per_cluster = p_no_beacon ** (1.0 / len(clusters))
    out = None
    for members in clusters:
        clique = cb.BeaconClique(p_beacon=marginal_prob, p_no_beacon=p_no_beacon_per_cluster, members=members)
        potential = cb.create_correlated_beacons(clique)
        if out is None:
            out = potential
        else:
            out *= potential
    return out


def corner_beacons(marginal_prob, p_no_beacon):
    """Clustered potential with four hard-coded three-beacon groups ("corners").

    NOTE(review): assumes the world map contains beacons with these ids; the
    corner grouping is taken from the function name and is not verified here.
    """
    return clustered_beacons(
        marginal_prob,
        p_no_beacon,
        clusters=[(51, 5, 53), (1, 3, 54), (50, 4, 52), (0, 2, 55)],
    )

def sides_beacons(marginal_prob, p_no_beacon):
    """Clustered potential with four hard-coded three-beacon groups ("sides").

    NOTE(review): assumes the world map contains beacons with these ids; the
    side grouping is taken from the function name and is not verified here.
    """
    return clustered_beacons(
        marginal_prob,
        p_no_beacon,
        clusters=[(0, 2, 4), (50, 52, 54), (1, 3, 5), (51, 53, 55)],
    )

def cycle_beacons(marginal_prob, p_no_beacon):
    """Clustered potential with three hard-coded four-beacon groups ("cycle").

    NOTE(review): assumes the world map contains beacons with these ids; the
    cycle grouping is taken from the function name and is not verified here.
    """
    return clustered_beacons(
        marginal_prob,
        p_no_beacon,
        clusters=[(0, 50, 1, 51), (2, 52, 3, 53), (4, 54, 5, 55)],
    )

def compute_expected_covariance_size(beacon_pot, world_map, trials):
    """Return sum over `trials` of P(beacon configuration) * cov_size.

    Each trial's trial_idx is decoded as a missing-beacon bitmask. The
    probability is exp(beacon_pot.log_prob(presence)), where `presence` maps
    every fixed beacon id to True (present) or False (missing). The sum is only
    a full expectation if `trials` covers every configuration.
    """
    def weighted_cov_size(trial):
        missing_ids = missing_from_configuration(trial.trial_idx, world_map)
        presence = {b.id: b.id not in missing_ids for b in world_map.fixed_beacons.beacons}
        return np.exp(beacon_pot.log_prob(presence)) * trial.cov_size

    return sum((weighted_cov_size(trial) for trial in trials), 0.0)

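# Correlation structures defined above (each is an entry in `distributions` below):
# - all independent  (independent_beacons)
# - all correlated   (all_correlated)
# - sides            (sides_beacons)
# - corners          (corner_beacons)
# - cycling          (cycle_beacons)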

In [None]:
# Shared parameters for every beacon model below.
marginal_prob = 0.4  # passed as p_beacon to every BeaconClique
p_no_beacons = 0.1  # target no-beacon probability for the correlated models

distributions = {}
distributions['independent'] = independent_beacons(marginal_prob)
distributions['all_correlated'] = all_correlated(marginal_prob, p_no_beacons)
distributions['corners'] = corner_beacons(marginal_prob, p_no_beacons)
distributions['sides'] = sides_beacons(marginal_prob, p_no_beacons)
distributions['cycle'] = cycle_beacons(marginal_prob, p_no_beacons)


In [None]:
# NOTE(review): appears unused in the visible notebook.
expected_cov_size = {name: [] for name in distributions}

def compute_cov_size_for_instance(instance):
    """Return {'instance_idx': ..., <distribution name>: expected cov size, ...} for one instance.

    Reads the global `distributions` and `world_map`.
    """
    return {
        'instance_idx': instance.instance_idx,
        **{
            name: compute_expected_covariance_size(potential, world_map, instance.trials)
            for name, potential in distributions.items()
        },
    }

results = tqdm.contrib.concurrent.process_map(compute_cov_size_for_instance, instances, chunksize=1)
        

In [None]:

# Raw string: '\m' and '\S' are not valid Python escape sequences, so the
# non-raw literal triggers a DeprecationWarning / SyntaxWarning.
expected_cov_label = r'$\mathbb{E}[|\Sigma|]$'
# NOTE: values beyond the last bin edge (245) are not counted in the histogram.
hist_bins = list(range(0, 250, 5))

# clear=True makes re-running this cell redraw the named figures instead of
# stacking another set of curves on top of the previous run's.
cdf_fig, cdf_ax = plt.subplots(num='cdf', clear=True)
pdf_fig, pdf_ax = plt.subplots(num='pdf', clear=True)

for key in distributions:
    covs = sorted(stats[key] for stats in results)
    cdf_ax.plot(covs, list(range(len(covs))), label=key)
    pdf_ax.hist(covs, label=key, histtype='step', bins=hist_bins, linewidth=1)

for plot_ax in (cdf_ax, pdf_ax):
    plot_ax.legend()
    plot_ax.set_xlabel(expected_cov_label)
    plot_ax.set_ylabel('Instance Counts')
for plot_fig in (cdf_fig, pdf_fig):
    plot_fig.tight_layout()

