In [1]:
import os
import plotly.express as px
import numpy as np
import h3
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import cartopy.crs as ccrs

# Connection settings, set before the `mirrorverse.utils` import below.
# NOTE(review): presumably read by read_data_w_cache (possibly at import time) — confirm.
os.environ.update(HAVEN_DATABASE='haven', AWS_PROFILE='admin')

from mirrorverse.utils import read_data_w_cache

In [2]:
# Diverging red/blue colormap; `quiver` looks this global up when it draws.
cmap = plt.cm.RdBu

# Base map extent per catch-region index, in cartopy `set_extent` order
# [lon_min, lon_max, lat_min, lat_max] (degrees, PlateCarree).
# Regions 6 and -1 ('Other') have no entry.
extents_map = dict(enumerate([
    [-167, -153, 52, 58],  # 0: Unalaska
    [-172, -150, 50, 58],  # 1: Chignik
    [-155, -140, 55, 62],  # 2: Nanwalek
    [-155, -144, 57, 62],  # 3: Kodiak
    [-148, -130, 52, 62],  # 4: Yakutat
    [-137, -130, 51, 58],  # 5: Sitka
]))

def setup(ax, extent):
    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines()

def quiver(ax, df, col):
    """Draw arrows at (x, y) with components (u, v), coloured by ``df[col]``.

    The colour scale is symmetric about zero (``-limit..limit``, where ``limit``
    is the largest magnitude in ``df[col]``). Positive and negative values therefore
    land on opposite ends of the global diverging ``cmap``.

    Returns:
        The artist returned by ``ax.quiver`` (usable as a colorbar mappable).
    """
    values = df[col]
    limit = max(values.max(), -values.min())
    arrow_style = dict(
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        norm=Normalize(vmin=-limit, vmax=limit),
        width=0.005,
        scale=12,
        linestyle=[':'],
    )
    return ax.quiver(df['x'], df['y'], df['u'], df['v'], values, **arrow_style)

def plot_it(df, val, stretch=[0,0,0,0]):
    """Map one tag's movement arrows coloured by ``df[val]`` (title-less variant).

    The map extent is ``extents_map[catch_region]`` of the first row, padded
    element-wise by ``stretch``. A horizontal colorbar is placed along the top.

    NOTE(review): a later cell redefines ``plot_it(df, val, title, ...)``, and that
    definition shadows this one.
    """
    base_extent = extents_map[df['catch_region'].values[0]]
    padded_extent = [edge + delta for edge, delta in zip(base_extent, stretch)]

    fig, ax = plt.subplots(
        figsize=(10, 5),
        ncols=1,
        nrows=1,
        subplot_kw={'projection': ccrs.PlateCarree()},
    )
    setup(ax, padded_extent)
    arrows = quiver(ax, df, val)
    colorbar_ax = fig.add_axes([0.12, 0.89, 0.35, 0.02])
    fig.colorbar(arrows, cax=colorbar_ax, orientation='horizontal')

    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    fig.subplots_adjust(top=0.8)  # leave room above the map for the colorbar
    return fig

def prep_models(new, old):
    """Join two models' selected-step log likelihoods and attach map arrows.

    Both frames are first filtered to rows where ``_selected`` is true. For
    ``new``, the origin and next H3 cells are decoded to coordinates to build
    quiver inputs:
    - ``x``/``y``: origin lon/lat.
    - ``u``/``v``: lon/lat change to the next cell.

    ``catch_region`` is added via ``catch_region_map``. The two frames are then
    inner-joined on ``(_train, _decision, tag_key)``.

    Parameters
    ----------
    new : DataFrame
        Needs ``_selected``, ``tag_key``, ``_decision``, ``_train``,
        ``log_likelihood``, ``origin_h3_index``, ``next_h3_index`` and ``time``.
    old : DataFrame
        Needs ``_selected``, ``tag_key``, ``_decision``, ``_train`` and
        ``log_likelihood``.

    Returns
    -------
    DataFrame
        The joined rows, with ``log_likelihood_new`` / ``log_likelihood_old``
        and ``'new - old'`` = their difference. The input frames are not modified.
    """
    new = new[new['_selected']]
    old = old[old['_selected']]

    # Decode every H3 cell once; h3_to_geo returns a (lat, lon) pair.
    origin = new['origin_h3_index'].map(h3.h3_to_geo)
    dest = new['next_h3_index'].map(h3.h3_to_geo)
    origin_lat = origin.map(lambda latlon: latlon[0])
    origin_lon = origin.map(lambda latlon: latlon[1])
    dest_lat = dest.map(lambda latlon: latlon[0])
    dest_lon = dest.map(lambda latlon: latlon[1])

    # assign() builds a new frame rather than writing into a filtered slice,
    # which avoids pandas' SettingWithCopyWarning.
    new = new.assign(
        x=origin_lon,
        y=origin_lat,
        u=dest_lon - origin_lon,
        v=dest_lat - origin_lat,
        catch_region=new['tag_key'].apply(catch_region_map),
    )

    new = new[[
        'tag_key', '_decision', '_train', 'log_likelihood',
        'x', 'y', 'u', 'v', 'catch_region', 'time'
    ]]
    old = old[[
        'tag_key', '_decision', '_train', 'log_likelihood'
    ]]
    df = new.merge(
        old, on=['_train', '_decision', 'tag_key'],
        suffixes=('_new', '_old'), how='inner'
    )
    df['new - old'] = df['log_likelihood_new'] - df['log_likelihood_old']
    return df

In [56]:
def catch_region_map(tag_key):
    """Return the catch-region index (0-6) encoded by ``tag_key``'s prefix, or -1.

    The index is the position of the first matching prefix in the tuple below;
    it keys ``PALETTE``, ``REGION_NAMES`` and (for 0-5) ``extents_map``.
    """
    prefixes = ('172', '202', '159', '205', '210', '229', '142')
    matches = (index for index, prefix in enumerate(prefixes) if tag_key.startswith(prefix))
    return next(matches, -1)

# Bar/legend colour for each catch-region index from catch_region_map.
# -1 (no known tag prefix) shares the colour of 6 ('Other').
PALETTE = {
    0: '#648FFF',
    1: '#785EF0',
    2: '#DC267F',
    3: '#FE6100',
    4: '#FFB000',
    5: '#3E589E',
    6: '#060606',
}
PALETTE[-1] = PALETTE[6]

# Legend label for each catch-region index; -1 shares the label of 6.
REGION_NAMES = {
    0: 'Unalaska',
    1: 'Chignik',
    2: 'Nanwalek',
    3: 'Kodiak',
    4: 'Yakutat',
    5: 'Sitka',
    6: 'Other',
}
REGION_NAMES[-1] = REGION_NAMES[6]

def _load_inference(table, run_id):
    """Load one movement-model inference run and add the derived plotting columns.

    Adds three columns:
    - ``log_likelihood``: natural log of ``probability``; -inf where it is 0.
    - ``catch_region``: via ``catch_region_map``.
    - ``color``: via ``PALETTE``.

    ``table`` and ``run_id`` are trusted notebook constants that are spliced
    into the SQL text. Do not pass external input here.
    """
    frame = read_data_w_cache(
        f"select * from {table} where run_id = '{run_id}'"
    )
    frame['log_likelihood'] = np.log(frame['probability'])
    frame['catch_region'] = frame['tag_key'].apply(catch_region_map)
    frame['color'] = frame['catch_region'].apply(lambda region: PALETTE[region])
    return frame


# Model versions compared below; the comparison cells label them
# 'Distance' (v1), 'Heading' (v2) and 'Food' (v3).
v1 = _load_inference(
    'movement_model_inference_m3_a4_v5',
    '4085c6b4443b44924e5a4b49570215d7135424a228cb72cc91689e1f20a561d0',
)
v2 = _load_inference(
    'movement_model_inference_m3_a4_v6',
    'f69050939419fc0438d1130f57ccad8cef8ee1dbea2b9785a7817231c310e623',
)
v3 = _load_inference(
    'movement_model_inference_m3_a4_v7',
    '49ee9ca7158e7620df5cc726ab286e7f5069184aeb9579da8694ebd91d43b3ec',
)

# 'Null' baseline: v2's candidate rows, each with uniform probability within
# its (_individual, _decision) group.
v0 = v2.copy()
v0['odds'] = 1.0
v0['sum_odds'] = v0.groupby(['_individual', '_decision'])['odds'].transform('sum')
v0['probability'] = v0['odds'] / v0['sum_odds']
v0['log_likelihood'] = np.log(v0['probability'])
# catch_region and color are inherited unchanged from v2 (same tag_key values).

In [75]:
def compare_models(models, names, agg, incremental=True, figsize=(10, 8)):
    """Bar-chart per-tag log-likelihood differences between models.

    For each model after the first:
    1. Its ``_selected`` rows are inner-joined with the current baseline's on
       ``(tag_key, _decision)``.
    2. The per-decision difference ``log_likelihood(model) - log_likelihood(baseline)``
       is aggregated per tag with ``agg``.
    3. The result is drawn as one bar per tag, coloured by the baseline's ``color``.

    Parameters
    ----------
    models : sequence of DataFrame
        Inference frames with ``tag_key``, ``_decision``, ``_selected`` and
        ``log_likelihood``. Every frame used as a baseline also needs ``color``.
    names : sequence of str
        Display names, one per model.
    agg : callable or str
        Aggregation for ``groupby(...).agg`` (e.g. ``np.mean``). Its name is
        used in the shared y-axis label.
    incremental : bool
        If True, each model is compared with the previous one; otherwise every
        model is compared with ``models[0]``.
    figsize : tuple
        Passed to ``plt.subplots``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If fewer than two models are given, or ``names`` and ``models`` differ
        in length.
    """
    if len(models) < 2:
        raise ValueError('compare_models needs at least two models')
    if len(names) != len(models):
        raise ValueError(f'got {len(models)} models but {len(names)} names')

    n_panels = len(models) - 1
    # squeeze=False keeps `axes` 2-D even when there is a single panel.
    fig, axes = plt.subplots(figsize=figsize, nrows=n_panels, squeeze=False)

    base_model, base_name = models[0], names[0]
    for i, (name, model) in enumerate(zip(names[1:], models[1:])):
        diff_col = f'{name} - {base_name}'
        df = base_model[base_model['_selected']][['tag_key', '_decision', 'log_likelihood', 'color']].merge(
            model[model['_selected']][['tag_key', '_decision', 'log_likelihood']],
            on=['tag_key', '_decision'], how='inner', suffixes=(base_name, name)
        )
        df[diff_col] = df[f'log_likelihood{name}'] - df[f'log_likelihood{base_name}']
        df = df.groupby(['tag_key', 'color'])[diff_col].agg(agg).reset_index()
        df = df.sort_values('tag_key')

        ax = axes[i, 0]
        ax.bar(df['tag_key'], df[diff_col], color=df['color'])

        if i < n_panels - 1:
            # Hide labels on upper panels without swapping out the categorical
            # tick formatter (set_xticklabels([]) does that).
            ax.tick_params(axis='x', labelbottom=False)
        else:
            # Only the bottom panel shows tag labels, rotated to fit.
            ax.set_xticks(df['tag_key'])
            ax.set_xticklabels(df['tag_key'], rotation=90)

        ax.set_title(diff_col)

        if incremental:
            base_model, base_name = model, name

    # Label the y-axis with the aggregation actually used ('mean' -> 'Mean').
    agg_name = agg if isinstance(agg, str) else getattr(agg, '__name__', '')
    if not agg_name.isidentifier():  # e.g. '<lambda>' or no name at all
        agg_name = 'aggregated'
    fig.text(0.0, 0.5, f'{agg_name.capitalize()} Difference in Log Likelihood',
             va='center', rotation='vertical')

    # Region legend; -1 is skipped because it duplicates 'Other'.
    legend_handles = [
        plt.Line2D(
            [0], [0], color=PALETTE[key], lw=4, label=REGION_NAMES[key]
        ) for key in PALETTE.keys() if key in REGION_NAMES and key != -1
    ]
    fig.legend(
        handles=legend_handles,
        loc='upper right',
        title='Regions',
        bbox_to_anchor=(1, 1)  # Adjust position as needed
    )

    fig.tight_layout()
    return fig

In [None]:
# Per-tag mean change in log likelihood as each feature is added (training rows).
fig = compare_models(
    [frame[frame['_train']] for frame in (v0, v1, v2, v3)],
    ['Null', 'Distance', 'Heading', 'Food'],
    np.mean,
    figsize=(12, 10),
)
fig.savefig("ll_change_train.png")
fig.show()

In [None]:
# Same comparison on the held-out (non-training) rows.
fig = compare_models(
    [frame[~frame['_train']] for frame in (v0, v1, v2, v3)],
    ['Null', 'Distance', 'Heading', 'Food'],
    np.mean,
    figsize=(12, 10),
)
fig.savefig("ll_change_val.png")
fig.show()

In [165]:
# Re-declared with the same values as the first setup cell so this cell runs on its own.
cmap = plt.cm.RdBu  # diverging palette for signed log-likelihood differences
extents_map = {
    region: extent
    for region, extent in enumerate((
        [-167, -153, 52, 58],
        [-172, -150, 50, 58],
        [-155, -140, 55, 62],
        [-155, -144, 57, 62],
        [-148, -130, 52, 62],
        [-137, -130, 51, 58],
    ))
}  # region index -> [lon_min, lon_max, lat_min, lat_max] for cartopy set_extent

def setup(ax, extent):
    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines()

def quiver(ax, df, col):
    """Plot ``df``'s (u, v) arrows at (x, y), coloured by ``df[col]`` on a zero-centred scale.

    Returns:
        The artist from ``ax.quiver``, for use with ``fig.colorbar``.
    """
    highest, lowest = df[col].max(), df[col].min()
    bound = max(highest, -lowest)  # symmetric range puts 0 at the middle of the diverging cmap
    return ax.quiver(
        df['x'],
        df['y'],
        df['u'],
        df['v'],
        df[col],
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        norm=Normalize(vmin=-bound, vmax=bound),
        width=0.005,
        scale=12,
        linestyle=[':'],
    )

def plot_it(df, val, title, stretch=(0, 0, 0, 0), cbar_axes=(0.17, 0.05, 0.7, 0.02),
            figsize=(5, 5), orientation='horizontal', extent=None):
    """Draw one tag's movement arrows on a map, coloured by ``df[val]``.

    Parameters
    ----------
    df : DataFrame
        Rows for a single tag with ``x``, ``y``, ``u``, ``v``, ``catch_region``
        and ``val`` columns (as produced by ``prep_models``).
    val : str
        Column used to colour the arrows (zero-centred diverging scale).
    title : str
        Axes title.
    stretch : sequence of 4 numbers
        Added element-wise to the base extent ``[lon_min, lon_max, lat_min, lat_max]``.
    cbar_axes : sequence of 4 floats
        Figure-relative ``[left, bottom, width, height]`` of the colorbar.
    figsize : tuple
        Figure size in inches.
    orientation : str
        Colorbar orientation ('horizontal' or 'vertical').
    extent : sequence of 4 numbers, optional
        Base extent to use instead of ``extents_map[catch_region]``. This is
        required for regions that have no ``extents_map`` entry (e.g. 6 / -1).

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``df`` is empty (e.g. the ``tag_key`` filter matched nothing).
    KeyError
        If ``extent`` is not given and the tag's region has no ``extents_map`` entry.
    """
    if df.empty:
        raise ValueError('plot_it received an empty DataFrame; check the tag_key filter')
    if extent is None:
        catch_region = df['catch_region'].values[0]
        if catch_region not in extents_map:
            raise KeyError(
                f'no map extent for catch_region {catch_region!r}; pass extent= explicitly'
            )
        extent = extents_map[catch_region]
    extent = [e + s for e, s in zip(extent, stretch)]

    fig, ax = plt.subplots(figsize=figsize, ncols=1, nrows=1,
                           subplot_kw={'projection': ccrs.PlateCarree()})
    setup(ax, extent)
    arrows = quiver(ax, df, val)
    cax = fig.add_axes(cbar_axes)
    fig.colorbar(arrows, cax=cax, orientation=orientation)

    ax.set_title(title)

    # NOTE(review): tight_layout() recomputes the subplot margins, so these
    # subplots_adjust values are likely overridden; kept from the original.
    fig.subplots_adjust(hspace=0.1, wspace=0.1, bottom=0.15, top=0.8)
    fig.tight_layout()
    return fig



In [None]:
sel = prep_models(new=v2, old=v1)  # Heading (v2) relative to Distance (v1)

In [None]:
fig = plot_it(
    df=sel.loc[sel['tag_key'] == '229202'],
    val='new - old',
    title='229202: Heading - Distance (LL)',
    stretch=[1, 0, 0, -1],
    cbar_axes=[0.07, 0.12, 0.02, 0.7],
    orientation='vertical',
)
fig.savefig('229202_hd.png')
fig.show()

In [None]:
fig = plot_it(
    df=sel.loc[sel['tag_key'] == '159017b'],
    val='new - old',
    title='159017b: Heading - Distance (LL)',
    stretch=[-1, 3, 0, 1],
    cbar_axes=[0.07, 0.26, 0.7, 0.02],
    orientation='horizontal',
    figsize=(5, 3),
)
fig.savefig('159017b_hd.png')
fig.show()

In [None]:
fig = plot_it(
    df=sel.loc[sel['tag_key'] == '202600'],
    val='new - old',
    title='202600: Heading - Distance (LL)',
    stretch=[9, 10, 10, -5],
    cbar_axes=[0.07, 0.35, 0.7, 0.02],
    orientation='horizontal',
    figsize=(5, 3),
)
fig.savefig('202600_hd.png')
fig.show()

In [None]:
fig = plot_it(
    df=sel.loc[sel['tag_key'] == '229209'],
    val='new - old',
    title='229209: Heading - Distance (LL)',
    stretch=[1, 0, 0, -1],
    cbar_axes=[0.07, 0.12, 0.02, 0.7],
    orientation='vertical',
)
fig.savefig('229209_hd.png')
fig.show()

In [None]:
fig = plot_it(
    df=sel.loc[sel['tag_key'] == '159005b'],
    val='new - old',
    title='159005b: Heading - Distance (LL)',
    stretch=[-1, 3, 0, 1],
    cbar_axes=[0.07, 0.26, 0.7, 0.02],
    orientation='horizontal',
    figsize=(5, 3),
)
fig.savefig('159005b_hd.png')
fig.show()

In [None]:
fig = plot_it(
    df=sel.loc[sel['tag_key'] == '202588'],
    val='new - old',
    title='202588: Heading - Distance (LL)',
    stretch=[1, 1, 10, -7],
    cbar_axes=[0.07, 0.27, 0.7, 0.02],
    orientation='horizontal',
    figsize=(5, 3),
)
fig.savefig('202588_hd.png')
fig.show()

In [None]:
sel = prep_models(new=v3, old=v2)  # Food (v3) relative to Heading (v2)

In [None]:
# Food - Heading comparison: saved with an `_fh` suffix, mirroring the `_hd`
# suffix used for the Heading - Distance plots.
fig = plot_it(
    sel[sel['tag_key'] == '159020'], 'new - old', '159020: Food - Heading (LL)',
    [-1,0,0,1], cbar_axes=[0.07, 0.26, 0.7, 0.02], orientation='horizontal', figsize=(5, 3)
)
fig.savefig('159020_fh.png')
fig.show()

In [None]:
# Food - Heading comparison: saved with an `_fh` suffix, mirroring the `_hd`
# suffix used for the Heading - Distance plots.
fig = plot_it(
    sel[sel['tag_key'] == '172913'], 'new - old', '172913: Food - Heading (LL)',
    [-4,-2,0,1], cbar_axes=[0.24, 0.26, 0.7, 0.02], orientation='horizontal', figsize=(5, 3)
)
fig.savefig('172913_fh.png')
fig.show()

In [None]:
# Food - Heading comparison: saved with an `_fh` suffix, mirroring the `_hd`
# suffix used for the Heading - Distance plots.
fig = plot_it(
    sel[sel['tag_key'] == '172915'], 'new - old', '172915: Food - Heading (LL)',
    [-4,-2,0,1], cbar_axes=[0.24, 0.26, 0.7, 0.02], orientation='horizontal', figsize=(5, 3)
)
fig.savefig('172915_fh.png')
fig.show()