In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from addict import Dict

import vmodel.geometry as vgeom
import vmodel.util.color as vcolor
from vmodel.visibility import visibility_set

In [None]:
# Figure defaults for the whole notebook: small square figures with
# matplotlib's automatic layout adjustment enabled.
plt.rcParams.update({
    'figure.figsize': [4, 4],
    'figure.autolayout': True,
})

In [None]:
# Root directory holding the datasets. Set the VMODEL_DATA_DIR environment
# variable to run this notebook on another machine; the default is the
# original author's local directory.
DATA_DIR = Path(os.environ.get('VMODEL_DATA_DIR', '/home/fabian/vmodel_datasets'))

# States file analysed in this notebook (kept as a plain string).
path = str(DATA_DIR / 'visual_switching' / 'agents_300_runs_1_times_100_dist_1.0_perc_10.0_topo_0_rngstd_0.0.states.nc')

In [None]:
# Open the states file, then keep only the run labelled 1.
dataset_all_runs = xr.open_dataset(path)
ds = dataset_all_runs.sel(run=1)

In [None]:
# Visibility variable of the selected run, at every stored timestep
# (no subsampling is applied here).
vis = ds['visibility']

In [None]:
# Count visibility switches: XOR each timestep with the previous one and sum
# over 'agent2'. The first timestep has no predecessor and is compared against
# all-False (the shift's fill value).
vis_previous = vis.shift(time=1, fill_value=False)
switches = np.logical_xor(vis, vis_previous).sum('agent2')

In [None]:
# Module-level time index.
# NOTE(review): the plotting loop at the end of the notebook assigns its own
# `timestep` values, so this value appears unused in the visible cells - confirm.
timestep = 12

In [None]:
def plot(ax, ds, timestep=1, focal_agent=1, center_agent=True, plot_text=False, plot_appearance=True):
    """Draw the swarm at one timestep from the point of view of a focal agent.

    Every agent is drawn as a disc of radius ``ds.radius`` with a short tail
    showing its preceding positions. Discs are coloured by their relation to
    the focal agent (focal, outside the perception radius, visible, not
    visible) and, optionally, by whether they recently appeared or
    disappeared. Behind every non-focal agent a light grey wedge is drawn,
    pointing away from the focal agent.

    Args:
        ax: Matplotlib axes to draw on.
        ds: xarray dataset providing ``position`` (indexed by ``agent`` and
            ``time``), ``visibility`` (indexed by ``agent``, ``agent2`` and
            ``time``), and the scalars ``radius`` and ``perception_radius``.
            Assumes that, at a single timestep, ``position`` has the agent axis
            first followed by two spatial components - TODO confirm.
        timestep: Integer index along ``time``; negative values count from
            the end.
        focal_agent: Label (value of the ``agent`` coordinate) of the focal agent.
        center_agent: If True, the view is centred on the focal agent with a
            margin of one perception radius; otherwise it spans all agents.
        plot_text: If True, write each agent's label on its disc.
        plot_appearance: If True, recolour agents whose current visibility
            differs from their rounded mean visibility over the preceding
            timesteps (up to 10).

    Raises:
        KeyError: If ``focal_agent`` is not a label of the ``agent`` coordinate.
    """
    history_length = 10  # number of preceding timesteps used for tails / appearance
    shadow_scale = 4  # wedges extend to this multiple of the perception radius

    # Resolve negative indices first so the history window can be clamped at 0.
    if timestep < 0:
        timestep += ds.sizes['time']
    # Without clamping, a negative start would wrap around to the end of the
    # time axis and yield an empty window for the first `history_length` steps.
    window = slice(max(timestep - history_length, 0), timestep)
    has_history = timestep > 0

    dst = ds.isel(time=timestep)

    perception_radius = float(dst.perception_radius)
    margin = perception_radius  # margin around focal agent
    radius = float(ds.radius)

    # Label-based lookup; raises KeyError for an unknown focal agent.
    px, py = dst.position.sel(agent=focal_agent).data
    origin = np.array([px, py])

    # Positions relative to the focal agent. Rows follow the order of the
    # agent coordinate, so agent labels need not be 1..N.
    agents = dst.agent.data
    positions = dst.position.data
    focal_index = int(np.flatnonzero(agents == focal_agent)[0])
    offsets = positions - positions[focal_index]
    distances = np.linalg.norm(offsets, axis=1)
    distances[focal_index] = np.inf  # the focal agent is not its own neighbour
    too_far = distances > perception_radius

    # Center figure around agent or swarm
    if center_agent:
        xlim = (px - margin, px + margin)
        ylim = (py - margin, py + margin)
    else:
        xlim = (positions[:, 0].min(), positions[:, 0].max())
        ylim = (positions[:, 1].min(), positions[:, 1].max())

    # White disc marking the perception area on top of the grey background
    ax.add_patch(plt.Circle((px, py), radius=perception_radius, color='white', zorder=1))
    # Dotted outline of the perception radius, drawn above everything else
    ax.add_patch(plt.Circle((px, py), radius=perception_radius, fill=False, ls=':', lw=0.5,
                            ec='grey', zorder=100))

    # Selections that do not depend on the agent being drawn
    visible_now = dst.visibility.sel(agent=focal_agent)
    if plot_appearance and has_history:
        mean_recent_visibility = ds.visibility.sel(agent=focal_agent).isel(time=window).mean('time')
    else:
        # No preceding timesteps to compare with: skip appearance colouring
        mean_recent_visibility = None

    for index, a in enumerate(agents):
        is_focal = (a == focal_agent)
        is_visible = bool(visible_now.sel(agent2=a).data)

        if is_focal:
            color = vcolor.focal
        elif too_far[index]:
            color = vcolor.invisible
        elif is_visible:
            color = vcolor.visual
        else:
            color = vcolor.grey

        if mean_recent_visibility is not None:
            # "Was visible" means visible for the majority of the window
            was_visible = bool(mean_recent_visibility.sel(agent2=a).data.round())
            if is_visible and not was_visible:
                color = vcolor.voronoi  # appeared
            elif was_visible and not is_visible:
                color = vcolor.topological  # disappeared

        # Plot tail
        xs, ys = ds.position.sel(agent=a).isel(time=window).data
        ax.plot(xs, ys, color=color, alpha=0.5, zorder=99)

        # Plot position
        x, y = positions[index]
        ax.add_patch(plt.Circle((x, y), radius=radius, color=color, zorder=99))

        # Plot agent number
        if plot_text:
            ax.text(x, y, s=f'{a}', ha='center', va='center', zorder=100)

        if not is_focal:
            # Presumably the two points where tangents from the focal agent
            # touch this agent's disc, in focal-relative coordinates - verify
            # against vmodel.geometry.
            p1, p2 = vgeom.tangent_points_to_circle(offsets[index], radius)
            p1, p2 = np.array(p1), np.array(p2)
            ps1, ps2 = p1 * perception_radius * shadow_scale, p2 * perception_radius * shadow_scale
            # Points are relative to the focal agent: translate back by origin
            poly = np.array([p1, ps1, ps2, p2]) + origin
            ax.add_patch(plt.Polygon(poly, color=vcolor.lightgrey, zorder=1))

    ax.set(xlim=xlim, ylim=ylim)
    ax.set(aspect='equal')
    ax.set(xlabel='x [m]', ylabel='y [m]')
    ax.set(facecolor=vcolor.lightgrey)


In [None]:
# Render the focal agent's neighbourhood at four evenly spaced timesteps and
# save each snapshot as its own PDF (visual_switching_1.pdf, ...).
focal_agent = 65  # take some agent as focal agent
plot_text = False
start = 10
offset = 10
timesteps = [start + k * offset for k in range(4)]
for i, timestep in enumerate(timesteps):
    fig, ax = plt.subplots()
    # Pass the configured flag instead of a hard-coded False
    plot(ax, ds, timestep=timestep, focal_agent=focal_agent, plot_text=plot_text)
    ax.locator_params(axis='y', nbins=4)
    ax.grid()
    fig.savefig(f'visual_switching_{i + 1}.pdf')
