In [None]:
import glob
import os

from addict import Dict
import xarray as xr
import matplotlib.pyplot as plt

from vmodel import plot as vplot
from vmodel.util import color as vcolor
from vmodel.util import mpl as vmpl

%load_ext autoreload
%autoreload 2


In [None]:
# Root directory searched (recursively) for the trajectory *.nc files.
# Override it with the VMODEL_DATA_DIR environment variable instead of editing
# the notebook; the original local path is kept as the fallback.
data_dir = os.environ.get('VMODEL_DATA_DIR', '/home/fabian/vmodel_datasets/trajectories/pointmass')

In [None]:
# Swarm size: only trajectory files named agents_<num_agents>_*.nc are loaded.
num_agents: int = 30

In [None]:
# Collect every trajectory file for the chosen swarm size from all experiment
# sub-directories. data_dir is escaped so glob metacharacters in it ([, *, ?)
# are matched literally. Derived "*.metrics.nc" files match the same pattern
# and are excluded by file name.
pattern = os.path.join(glob.escape(data_dir), '**', f'agents_{num_agents}_*.nc')
paths = sorted(
    p for p in glob.glob(pattern, recursive=True)
    if not os.path.basename(p).endswith('.metrics.nc')
)
# Fail here with a clear message rather than later inside xr.concat.
assert paths, f'no trajectory files matching {pattern!r}'
len(paths)

In [None]:
# Read datasets
# Read one dataset per experiment. The experiment name is the name of the
# directory that contains the file (e.g. .../visual/agents_30_*.nc -> 'visual').
expdict = Dict()
for path in paths:
    expname = os.path.basename(os.path.dirname(path))
    # Two files in the same experiment directory would otherwise silently
    # overwrite each other and drop data.
    if expname in expdict:
        raise ValueError(
            f'duplicate experiment name {expname!r} (from {path}); '
            'expected one trajectory file per experiment directory'
        )
    expdict[expname] = xr.open_dataset(path)

In [None]:
# Experiment names in insertion order, i.e. the sorted order of `paths`.
experiment_names = expdict.keys()
experiment_names

In [None]:
# Stack the per-experiment datasets along a new `exp` dimension and label it
# with the experiment names (values() and keys() iterate in the same order).
ds = xr.concat(list(expdict.values()), dim='exp').assign_coords(exp=list(expdict.keys()))

In [None]:
# Positional window over the time axis (start inclusive, stop exclusive).
start_step, stop_step = 0, 60
timeslice = slice(start_step, stop_step)

In [None]:
# Keep the `timeslice` window of time steps from simulation run 1.
# NOTE(review): this rebinds `ds`, so re-running the cell on its own output
# is not supported; re-run the concatenation cell first.
ds = ds.isel(time=timeslice).sel(run=1)

In [None]:
# Width and height (matplotlib figsize, inches) shared by all trajectory figures.
fig_width, fig_height = 10, 3
figsize = (fig_width, fig_height)

In [None]:
def plot_trajectories(ax, ds, focal_agent=1, exp='visual', focal_color='black', offset=7, xmax=29):
    """Plot the trajectory of every agent, highlighting one focal agent.

    Each agent's path is drawn as a line, with a square marker at the first
    time step, a '>' marker at the middle time step and a filled circle of
    radius 0.25 at the last time step. The focal agent is drawn thicker
    (lw=2), on top (zorder=1) and in ``focal_color``; all other agents are
    drawn in grey. The legend lists the focal agent first and the others
    second, regardless of where the focal agent appears in ``ds.agent``.

    Args:
        ax (matplotlib.axes.Axes): axes to draw on.
        ds (xarray.Dataset): dataset with ``time`` and ``agent`` coordinates
            and a ``position`` variable. For each agent, ``position`` is
            unpacked along its leading dimension into x and y, so that
            dimension must have exactly two entries, and each part must
            still have a ``time`` dimension.
        focal_agent: ``agent`` label of the agent to highlight.
        exp (str): experiment name used in the legend labels.
        focal_color: matplotlib color for the focal agent.
        offset (float): margin added on every side of the plotted region;
            the y range is (-offset, offset).
        xmax (float): largest x value of the plotted region before the
            margin is added.
    """
    start, mid, end = 0, len(ds.time.data) // 2, -1

    focal_handle = None   # line of the focal agent, used for the legend
    others_handle = None  # first non-focal line, stands in for all others
    for agent in ds.agent.data:
        isfocal = agent == focal_agent

        # Unpack the leading dimension of position into x and y time series
        xs, ys = ds.position.sel(agent=agent)

        alpha = 1.0
        lw = 2.0 if isfocal else 1.0
        color = focal_color if isfocal else vcolor.grey
        zorder = 1 if isfocal else 0  # draw the focal agent above the others

        # One legend entry for the focal agent and one shared by all others
        if isfocal:
            label = f'{exp} (focal)'
        elif others_handle is None:
            label = f'{exp} (others)'
        else:
            label = None

        # Plot trajectory
        line, = ax.plot(xs, ys, color=color, lw=lw, alpha=alpha, zorder=zorder, label=label)
        if isfocal:
            focal_handle = line
        elif others_handle is None:
            others_handle = line

        # Start (square) and mid-point (triangle) markers
        for step, marker in ((start, 's'), (mid, '>')):
            ax.plot(xs.isel(time=step), ys.isel(time=step), color=color, marker=marker,
                    alpha=alpha, lw=lw, zorder=zorder)

        # End position as a filled circle patch
        xt, yt = float(xs.isel(time=end)), float(ys.isel(time=end))
        circle = plt.Circle((xt, yt), color=color, radius=0.25, alpha=alpha, lw=lw, zorder=zorder)
        ax.add_patch(circle)

    ax.set(ylim=(-offset, offset), xlim=(-offset, xmax + offset))
    ax.set(xlabel=r'$x$ position [$m$]', ylabel=r'$y$ position [$m$]')
    ax.set(aspect='equal')

    # Build the legend from the stored handles (focal first). Reordering
    # ax.get_legend_handles_labels() with a fixed index list depended on
    # whether the focal agent came before the first other agent, and raised
    # IndexError when fewer than two labelled lines existed.
    handles = [h for h in (focal_handle, others_handle) if h is not None]
    if handles:
        ax.legend(handles, [h.get_label() for h in handles], loc='upper center', ncol=2)
    ax.locator_params(axis='y', nbins=5)

In [None]:
# Agent label highlighted in every trajectory plot below.
focal_agent: int = 12

## Visual

In [None]:
# Plot the 'visual' experiment and save the figure as a PDF in the working directory.
fig, ax = plt.subplots(figsize=figsize)
exp = 'visual'
plot_trajectories(ax, ds.sel(exp='visual'), focal_agent=focal_agent, exp=exp, focal_color=vcolor.visual)
fig.savefig('trajectories_visual.pdf', bbox_inches='tight')

## Visual + Voronoi

In [None]:
# Plot the 'voronoi' experiment (legend label 'visual + voronoi') and save it as a PDF.
fig, ax = plt.subplots(figsize=figsize)
exp = 'visual + voronoi'
plot_trajectories(ax, ds.sel(exp='voronoi'), focal_agent=focal_agent, exp=exp, focal_color=vcolor.voronoi)
fig.savefig('trajectories_visual_voronoi.pdf', bbox_inches='tight')

## Visual + myopic

In [None]:
# Plot the 'myopic' experiment (legend label 'visual + myopic') and save it as a PDF.
fig, ax = plt.subplots(figsize=figsize)
exp = 'visual + myopic'
plot_trajectories(ax, ds.sel(exp='myopic'), focal_agent=focal_agent, exp=exp, focal_color=vcolor.myopic)
fig.savefig('trajectories_visual_myopic.pdf', bbox_inches='tight')

## Visual + topological

In [None]:
# Plot the 'topo' experiment (legend label 'visual + topological') and save it as a PDF.
fig, ax = plt.subplots(figsize=figsize)
exp = 'visual + topological'
plot_trajectories(ax, ds.sel(exp='topo'), focal_agent=focal_agent, exp=exp, focal_color=vcolor.topological)
fig.savefig('trajectories_visual_topo.pdf', bbox_inches='tight')