In [None]:
import os
import plotly.express as px
import numpy as np
import h3
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import cartopy.crs as ccrs

# Connection settings for the data layer. Presumably read by mirrorverse when it
# queries the database, which is why they are set before the import below --
# TODO confirm whether import order actually matters.
os.environ.update({
    'HAVEN_DATABASE': 'haven',
    'AWS_PROFILE': 'admin',
})

from mirrorverse.utils import read_data_w_cache

In [None]:
def catch_region_map(tag_key):
    """Return the catch-region index encoded in the leading digits of ``tag_key``.

    The index is the position of the first prefix in
    ('172', '202', '159', '205', '210', '229', '142') that ``tag_key`` starts
    with; -1 is returned when none of them match.
    """
    region_prefixes = ('172', '202', '159', '205', '210', '229', '142')
    return next(
        (region for region, prefix in enumerate(region_prefixes) if tag_key.startswith(prefix)),
        -1,
    )

# Hex colour per catch-region index; keys are exactly the values that
# catch_region_map can return (0-6, or -1 when no prefix matched).
PALETTE = {
    0: '#648FFF',   # Unalaska
    1: '#785EF0',   # Chignik
    2: '#DC267F',   # Nanwalek
    3: '#FE6100',   # Kodiak
    4: '#FFB000',   # Yakutat
    5: '#3E589E',   # Sitka
    6: '#060606',   # Other ('142' prefix)
    -1: '#060606',  # Other (no prefix matched)
}

# Display name per catch-region index; 6 and -1 are both reported as "Other".
REGION_NAMES = {
    0: 'Unalaska',
    1: 'Chignik',
    2: 'Nanwalek',
    3: 'Kodiak',
    4: 'Yakutat',
    5: 'Sitka',
    6: 'Other',   # '142' prefix
    -1: 'Other',  # no prefix matched
}

def _load_inference(version, run_id):
    """Load one movement-model inference run and add plotting columns.

    Parameters
    ----------
    version : str
        Table suffix, e.g. ``'10_1_4'`` selects ``movement_model_inference_10_1_4``.
    run_id : str
        Value of ``run_id`` to filter the table on.

    Returns
    -------
    The frame returned by ``read_data_w_cache`` with three added columns:
    ``log_likelihood`` (natural log of ``probability``; a probability of 0
    gives -inf), ``catch_region`` (index from ``catch_region_map``) and
    ``color`` (hex colour looked up in ``PALETTE``).
    """
    # The SQL text is built to be byte-identical to the previously hand-written
    # queries, so any cache keyed on the query string keeps hitting
    # (assumes read_data_w_cache caches by query text -- TODO confirm).
    inference = read_data_w_cache(
        f"select * from movement_model_inference_{version} where run_id = '{run_id}'"
    )
    inference['log_likelihood'] = np.log(inference['probability'])
    inference['catch_region'] = inference['tag_key'].apply(catch_region_map)
    inference['color'] = inference['catch_region'].apply(lambda region: PALETTE[region])
    return inference


v1 = _load_inference('10_1_1', '1a3a6a2a897b810b9d50ec306b1e211979a05def2e7b10b7240c9fdca8c410c9')
v2 = _load_inference('10_1_2', 'f509c72818a716514763d2bad4ad4b20bedf5ac5bf0e5d4a6e79ac9456fd8af0')
v3 = _load_inference('10_1_3', 'ef05b957fd20684af87e0435e816b9141ad9a7160eb834bead1778dcbc83693f')
v4 = _load_inference('10_1_4', '5ef29500bcc61f80f43209feccbd8f7ae0c244f9bc2aea0a1d529ad1d6e0c06c')
v7 = _load_inference('10_1_7', '87e2447a56e06fc633f2c00711f28003f3f9bd8c6f1ac93415196469340afc74')

In [None]:
# Diverging colormap used by quiver() to colour arrows.
cmap = plt.cm.RdBu

# Default map window per catch-region index, passed to GeoAxes.set_extent in
# PlateCarree as [x0, x1, y0, y1], i.e. [lon_min, lon_max, lat_min, lat_max]
# in degrees. Indices 6 and -1 ("Other") have no entry.
extents_map = {
    0: [-167, -153, 52, 58],  # Unalaska
    1: [-172, -150, 50, 58],  # Chignik
    2: [-155, -140, 55, 62],  # Nanwalek
    3: [-155, -144, 57, 62],  # Kodiak
    4: [-148, -130, 52, 62],  # Yakutat
    5: [-137, -130, 51, 58],  # Sitka
}

def setup(ax, extent):
    """Frame the map axes ``ax`` on ``extent`` (PlateCarree lon/lat) and draw coastlines."""
    plate_carree = ccrs.PlateCarree()
    ax.set_extent(extent, crs=plate_carree)
    ax.coastlines()

def quiver(ax, df, col):
    """Draw one arrow per row of ``df`` on ``ax``, coloured by ``df[col]``.

    Arrows start at (``x``, ``y``) and have components (``u``, ``v``), all in
    PlateCarree (lon/lat) coordinates. Colours come from the module-level
    ``cmap`` with a norm symmetric about zero, so 0 lands on the colormap centre.

    Returns the quiver artist, usable as the mappable for a colorbar.
    """
    # Largest absolute value in the column (NaNs skipped); same as
    # max(col.max(), -col.min()).
    the_max = df[col].abs().max()
    # Normalize(vmin=0, vmax=0) maps every value to 0, i.e. the low end of the
    # colormap, so an all-zero column would be drawn in the extreme "negative"
    # colour. Use a unit range instead so 0 stays neutral. Empty or all-NaN
    # columns give a NaN maximum and take the same fallback.
    if not np.isfinite(the_max) or the_max == 0:
        the_max = 1.0
    norm = Normalize(vmin=-the_max, vmax=the_max)

    return ax.quiver(
        df['x'], df['y'], df['u'], df['v'], df[col],
        transform=ccrs.PlateCarree(),
        cmap=cmap, norm=norm,
        width=0.005,
        scale=12,
        linestyle=[':'],
    )

def _fallback_extent(df, margin=1.0):
    """Return [lon_min, lon_max, lat_min, lat_max] covering every arrow tail and tip in ``df``, padded by ``margin`` degrees."""
    lons = np.concatenate([df['x'].to_numpy(), (df['x'] + df['u']).to_numpy()])
    lats = np.concatenate([df['y'].to_numpy(), (df['y'] + df['v']).to_numpy()])
    return [
        float(np.nanmin(lons)) - margin,
        float(np.nanmax(lons)) + margin,
        float(np.nanmin(lats)) - margin,
        float(np.nanmax(lats)) + margin,
    ]


def plot_it(df, val, title, stretch=(0, 0, 0, 0), cbar_axes=(0.17, 0.05, 0.7, 0.02), figsize=(5, 5), orientation='horizontal'):
    """Plot the arrows in ``df`` on a coastline map, coloured by ``df[val]``, with a colorbar.

    Parameters
    ----------
    df : DataFrame
        Rows to draw. Needs ``catch_region``, ``x``, ``y``, ``u``, ``v`` and
        ``val`` columns (as produced by ``prep_models``). The map window is
        looked up from the first row's ``catch_region`` in ``extents_map``.
        Regions without an entry are framed on the arrows themselves.
    val : str
        Column used to colour the arrows.
    title : str
        Axes title.
    stretch : sequence of 4 numbers
        Added element-wise to the extent [x0, x1, y0, y1] to widen or shift the view.
    cbar_axes : sequence of 4 floats
        Colorbar rectangle (left, bottom, width, height) in figure coordinates.
    figsize : tuple
        Figure size in inches.
    orientation : str
        Colorbar orientation.

    Returns
    -------
    The matplotlib Figure.

    Raises
    ------
    ValueError
        If ``df`` has no rows.
    """
    if df.empty:
        raise ValueError('plot_it needs at least one row to plot')

    catch_region = df['catch_region'].values[0]
    # extents_map has no entry for the "Other" regions (6 and -1); rather than
    # failing with KeyError, frame the map on the data.
    base_extent = extents_map.get(catch_region)
    if base_extent is None:
        base_extent = _fallback_extent(df)
    extent = [edge + offset for edge, offset in zip(base_extent, stretch)]

    fig, ax = plt.subplots(figsize=figsize, ncols=1, nrows=1, subplot_kw={'projection': ccrs.PlateCarree()})
    setup(ax, extent)
    q = quiver(ax, df, val)
    cbar_ax = fig.add_axes(cbar_axes)
    fig.colorbar(q, cax=cbar_ax, orientation=orientation)

    ax.set_title(title)

    # NOTE(review): tight_layout() below recomputes the subplot parameters, so
    # these subplots_adjust calls are likely overridden; kept to preserve the
    # current layout.
    fig.subplots_adjust(hspace=0.1, wspace=0.1, bottom=0.15)
    fig.subplots_adjust(top=0.8)
    plt.tight_layout()
    return fig

def _add_lat_lon(df, h3_col, prefix):
    """Add ``<prefix>_lat`` / ``<prefix>_lon`` columns to ``df`` in place, decoding each H3 cell in ``df[h3_col]`` once."""
    coords = df[h3_col].map(h3.h3_to_geo)
    # h3_to_geo returns a pair: element 0 is used as latitude, element 1 as longitude.
    df[f'{prefix}_lat'] = coords.map(lambda point: point[0])
    df[f'{prefix}_lon'] = coords.map(lambda point: point[1])


def prep_models(new, old):
    """Pair the selected decisions of two model runs and compare their log-likelihoods.

    Parameters
    ----------
    new, old : DataFrame
        Inference frames with ``_selected`` (used as a boolean row mask),
        ``tag_key``, ``_decision``, ``_train`` and ``log_likelihood`` columns.
        ``new`` must also have ``origin_h3_index``, ``next_h3_index`` and
        ``time``. Neither input is modified.

    Returns
    -------
    DataFrame
        Inner join of the selected rows of both runs on
        (``_train``, ``_decision``, ``tag_key``). Columns:
        - from ``new``: ``x``/``y`` (origin lon/lat), ``u``/``v`` (lon/lat
          displacement to the next cell), ``catch_region`` and ``time``;
        - ``log_likelihood_new`` and ``log_likelihood_old``;
        - ``'new - old'``, their difference.
    """
    # .copy() so the column assignments below go to an independent frame rather
    # than a slice of the caller's data (avoids pandas' SettingWithCopyWarning).
    new = new[new['_selected']].copy()
    old = old[old['_selected']]

    _add_lat_lon(new, 'origin_h3_index', 'origin')
    _add_lat_lon(new, 'next_h3_index', 'next')

    # Arrow geometry for quiver(): tail at the origin, vector to the next cell.
    new['x'] = new['origin_lon']
    new['y'] = new['origin_lat']
    new['u'] = new['next_lon'] - new['origin_lon']
    new['v'] = new['next_lat'] - new['origin_lat']

    new['catch_region'] = new['tag_key'].apply(catch_region_map)

    new = new[[
        'tag_key', '_decision', '_train', 'log_likelihood',
        'x', 'y', 'u', 'v', 'catch_region', 'time'
    ]]
    old = old[[
        'tag_key', '_decision', '_train', 'log_likelihood'
    ]]
    df = new.merge(
        old, on=['_train', '_decision', 'tag_key'],
        suffixes=('_new', '_old'), how='inner'
    )
    df['new - old'] = df['log_likelihood_new'] - df['log_likelihood_old']
    return df

In [None]:
# Compare model 10_1_7 ("new") against 10_1_4 ("old"), decision by decision.
sel = prep_models(new=v7, old=v4)

In [None]:
# Tag keys in catch region 3 (Kodiak in REGION_NAMES).
sel.loc[sel['catch_region'] == 3, 'tag_key'].unique()

In [None]:
REGION = 1  # catch-region index to sample a tag from (see REGION_NAMES)
# Fixed so Restart & Run All plots the same tag; set to None for a fresh random pick each run.
SAMPLE_SEED = 0

region_rows = sel[sel['catch_region'] == REGION]
tag_key = region_rows['tag_key'].sample(random_state=SAMPLE_SEED).values[0]
fig = plot_it(
    sel[sel['tag_key'] == tag_key], 'new - old', f'{tag_key} ({REGION_NAMES[REGION]})',
)
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()