In [None]:
import os

import h3
import numpy as np
import pandas as pd
import plotly.express as px

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import cartopy.crs as ccrs

from mirrorverse.utils import read_data_w_cache
from mirrorverse.plotting import plot_h3_slider, plot_h3_animation

# Query/connection settings (presumably consumed by `read_data_w_cache` and the
# AWS SDK it uses — confirm). `setdefault` only fills in these defaults, so a
# value already exported in the shell (e.g. a different AWS profile) wins.
os.environ.setdefault('HAVEN_DATABASE', 'haven')
os.environ.setdefault('AWS_PROFILE', 'admin')

In [None]:
def get_data(version, table_template='movement_model_inference_m3_a4_v{version}'):
    """Load movement-model inference rows for one model version.

    Parameters
    ----------
    version : int or str
        Model version. It is substituted into ``table_template`` and stored,
        as a string, in the ``version`` column.
    table_template : str, optional
        Format string for the source table; ``{version}`` is replaced with
        ``version``. Defaults to the ``m3_a4`` inference table.

    Returns
    -------
    pandas.DataFrame
        Every queried row plus three derived columns:
        ``case`` ('train' where ``_train`` is True, otherwise 'val'),
        ``log_prob`` (natural log of ``probability``) and ``version``.
    """
    # NOTE(review): the table name is spliced into SQL by string formatting;
    # only pass trusted values here.
    data = read_data_w_cache(
        f'select * from {table_template.format(version=version)}'
    )
    # Assumes `_train` is a boolean column (it is used as a mask and negated).
    data.loc[data['_train'], 'case'] = 'train'
    data.loc[~data['_train'], 'case'] = 'val'
    data['log_prob'] = np.log(data['probability'])
    data['version'] = str(version)
    return data

# Load the three model versions, in order 5, 6, 7.
v5, v6, v7 = (get_data(model_version) for model_version in (5, 6, 7))

# Reference "version 0": same rows as v5, but every candidate gets the constant
# probability 1/19 (presumably a uniform choice over 19 candidate cells — TODO
# confirm). It has no model odds, so those columns are blanked out.
uniform_probability = 1 / 19
v0 = v5.assign(
    probability=uniform_probability,
    log_prob=np.log(uniform_probability),
    version='0',
    log_odds=np.nan,
    odds=np.nan,
)

data = pd.concat([v0, v5, v6, v7])
data.head()


In [None]:
# One wide row per observed (selected) move, with one log-likelihood column per
# model version, followed by pairwise differences diff_{a}-{b} = log_prob_a - log_prob_b.
BASELINE_COLUMNS = [
    '_train', '_individual', '_decision', 'origin_h3_index', 'next_h3_index',
    'log_prob', 'time', 'distance', 'tag_key',
]
BASELINE_KEYS = [col for col in BASELINE_COLUMNS if col != 'log_prob']

baselines_dict = {}
versions = list(data['version'].unique())
for version in versions:
    selected = (data['version'] == version) & data['_selected']
    baselines_dict[version] = (
        data.loc[selected, BASELINE_COLUMNS]
        .rename(columns={'log_prob': f'log_prob_{version}'})
    )

baselines = baselines_dict[versions[0]]
for version in versions[1:]:
    # Inner join on the move's identifying columns; a move absent from any
    # version is dropped (and duplicate keys would multiply rows).
    baselines = baselines.merge(baselines_dict[version], on=BASELINE_KEYS)

# `version` values are strings ('0', '5', ...); compare them numerically so
# that e.g. '10' ranks above '9' (string comparison would not).
for v1 in sorted(versions, key=int, reverse=True):
    for v2 in versions:
        if int(v1) > int(v2):
            baselines[f'diff_{v1}-{v2}'] = baselines[f'log_prob_{v1}'] - baselines[f'log_prob_{v2}']
baselines

In [None]:
# Decode both ends of every move from H3 cell ids to cell-center coordinates.
# `h3.h3_to_geo` (h3-py v3 API; renamed `cell_to_latlng` in v4) returns a
# (lat, lng) tuple, so each column is decoded once and the tuple split,
# instead of calling the decoder separately for lat and lon.
for move_end in ('origin', 'next'):
    centers = baselines[f'{move_end}_h3_index'].map(h3.h3_to_geo)
    baselines[f'{move_end}_lat'] = centers.map(lambda latlng: latlng[0])
    baselines[f'{move_end}_lon'] = centers.map(lambda latlng: latlng[1])

# Quiver inputs in degrees: arrow tail at the origin cell, vector to the next cell.
baselines['x'] = baselines['origin_lon']
baselines['y'] = baselines['origin_lat']
baselines['u'] = baselines['next_lon'] - baselines['origin_lon']
baselines['v'] = baselines['next_lat'] - baselines['origin_lat']

In [None]:
# Shared diverging color scale for every quiver panel: log-likelihood
# differences are colored on a fixed [-1, 1] range so panels are comparable.
norm = Normalize(vmin=-1.0, vmax=1.0)
cmap = plt.cm.RdBu

# Fish (`_individual` ids) drawn per region. The trailing comments appear to
# list the fuller per-region set each selection was picked from.
individuals_map = {
    "unalaska": [56, 58, 104],  # from [0, 4, 56, 58, 101, 104, 107, 108]
    "chignik": [21, 48, 95],  # from [19, 21, 48, 95]
    "nanwalek": [32, 10, 13],  # from [10, 13, 32, 37, 92, 96, 103, 105, 109]
    "kodiak": [99, 60, 49],  # from [35, 49, 60, 99, 110]
    "yakutat": [93, 102, 106],  # from [39, 93, 97, 102, 106]
    "sitka": [75, 83, 86],  # from [75, 79, 83, 86]
    "ebs": [98],  # no extent/height configured below, so not plottable yet
}
# Map extent per region as [lon_min, lon_max, lat_min, lat_max] (cartopy set_extent order).
extents_map = {
    "chignik": [-167, -153, 52, 58],
    "unalaska": [-172, -150, 50, 58],
    "nanwalek": [-155, -140, 55, 62],
    "kodiak": [-155, -144, 57, 62],
    "yakutat": [-148, -130, 52, 62],
    "sitka": [-137, -130, 51, 58],
}
# Figure height in inches per row of panels (one row per fish).
heights_map = {
    "chignik": 2.7,
    "unalaska": 2.3,
    "nanwalek": 3,
    "kodiak": 3,
    "yakutat": 3.3,
    "sitka": 4,
}

region = "sitka"
if region not in extents_map or region not in heights_map:
    raise KeyError(f"region {region!r} needs entries in both extents_map and heights_map")
individuals = individuals_map[region]
extent = extents_map[region]

# squeeze=False keeps `axes` 2-D even when the region has a single fish, so
# the plotting loop can always index axes[i, 0] / axes[i, 1].
fig, axes = plt.subplots(
    figsize=(10, heights_map[region] * len(individuals)),
    ncols=2,
    nrows=len(individuals),
    squeeze=False,
    subplot_kw={'projection': ccrs.PlateCarree()},
)

def setup(ax, extent):
    """Restrict a cartopy map axis to `extent` (given in PlateCarree lon/lat) and draw coastlines."""
    lonlat = ccrs.PlateCarree()
    ax.set_extent(extent, crs=lonlat)
    ax.coastlines()

def quiver(ax, df, col):
    """Draw one arrow per row of `df` on `ax`, colored by `df[col]`.

    Arrows start at (`x`, `y`) and have components (`u`, `v`), interpreted in
    PlateCarree (lon/lat) coordinates. Colors come from the module-level `cmap`
    and `norm`. Returns the matplotlib Quiver artist.
    """
    arrow_kwargs = dict(
        transform=ccrs.PlateCarree(),
        cmap=cmap,
        norm=norm,
        width=0.005,
        scale=9,
        linestyle=[':'],
    )
    tails = (df['x'], df['y'])
    vectors = (df['u'], df['v'])
    return ax.quiver(*tails, *vectors, df[col], **arrow_kwargs)

# One row per fish. The left panel colors each observed move by the change in
# its log likelihood from model 1 to model 2 (diff_6-5, i.e. version 6 minus
# version 5). The right panel shows model 2 to model 3 (diff_7-6).
for i, individual in enumerate(individuals):
    left, right = axes[i, 0], axes[i, 1]
    setup(left, extent)
    setup(right, extent)

    # Rows with distance == 0 presumably stay in the origin cell (u = v = 0),
    # so they have no arrow to draw.
    df = baselines[(baselines['_individual'] == individual) & (baselines['distance'] > 0)]
    if df.empty:
        left.set_title(f'{individual}: no moves to plot')
        continue

    quiver(left, df, 'diff_6-5')
    quiver(right, df, 'diff_7-6')

    tag_key = df['tag_key'].iloc[0]
    left.set_title(f'tag id: {tag_key}, {individual}')


# Colorbars above each column. Every panel uses the shared `cmap`/`norm`, so a
# standalone ScalarMappable reproduces their scale even if no arrows were drawn.
cbar_ax_left = fig.add_axes([0.12, 0.89, 0.35, 0.02])  # [left, bottom, width, height] in figure fraction
cbar_ax_right = fig.add_axes([0.53, 0.89, 0.35, 0.02])

cbar_left = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax_left, orientation='horizontal')
cbar_right = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax_right, orientation='horizontal')

# Positive values mean the later model assigns the observed move a higher log likelihood.
cbar_left.set_label('Log Likelihood Change, Model 1 → Model 2')
cbar_right.set_label('Log Likelihood Change, Model 2 → Model 3')

fig.subplots_adjust(hspace=0.1, wspace=0.1, top=0.8)

fig.suptitle("Changes to Log Likelihoods Across Models 1, 2 and 3", fontsize=16, y=0.95)

plt.show()

In [None]:
# Total log likelihood of each fish's observed (selected) moves per model
# version and train/val split, pivoted to one column per version, followed by
# pairwise differences diff_{a}-{b} = log_prob_a - log_prob_b.
FISH_KEYS = ['_individual', 'tag_key', '_train']

by_fish_long = (
    data[data['_selected']]
    .groupby(['_individual', 'tag_key', 'version', '_train'])['log_prob']
    .sum()
    .reset_index()
)

versions = list(by_fish_long['version'].unique())
by_fish_dict = {
    version: (
        by_fish_long[by_fish_long['version'] == version]
        .rename(columns={'log_prob': f'log_prob_{version}'})
        .drop(columns='version')
    )
    for version in versions
}

by_fish = by_fish_dict[versions[0]]
for version in versions[1:]:
    # Inner join: a fish/split missing from any version is dropped.
    by_fish = by_fish.merge(by_fish_dict[version], on=FISH_KEYS)

# `version` values are strings ('0', '5', ...); compare them numerically so
# that e.g. '10' ranks above '9' (string comparison would not).
for v1 in sorted(versions, key=int, reverse=True):
    for v2 in versions:
        if int(v1) > int(v2):
            by_fish[f'diff_{v1}-{v2}'] = by_fish[f'log_prob_{v1}'] - by_fish[f'log_prob_{v2}']
by_fish

In [None]:
# Full region -> fish (`_individual` id) assignment for the per-fish bar chart.
individuals_map = {
    "unalaska": [0, 4, 56, 58, 101, 104, 107, 108],
    "chignik": [19, 21, 48, 95],
    "nanwalek": [8, 10, 13, 32, 37, 92, 96, 103, 105, 109, 91, 100],
    "kodiak": [35, 49, 60, 99, 110],
    "yakutat": [39, 93, 97, 102, 106],
    "sitka": [75, 79, 83, 86, 94],
    "ebs": [98]
}

# Invert to fish -> region. The comprehension's loop names are local, so this
# cell does not overwrite the plotting cell's `region` / `individuals` globals.
map_individuals = {
    fish_id: region_name
    for region_name, fish_ids in individuals_map.items()
    for fish_id in fish_ids
}

# Validation rows only. `.copy()` makes this an independent frame, so adding
# 'region' does not raise SettingWithCopyWarning.
df = by_fish[~by_fish['_train']].copy()
unassigned = sorted(set(df['_individual']) - set(map_individuals))
assert not unassigned, f'fish with no region in individuals_map: {unassigned}'
df['region'] = df['_individual'].map(map_individuals)

# A stable sort keeps each region's fish in their existing order; the second
# reset_index exposes the 0..n-1 position as an 'index' column for the x axis.
df = df.sort_values('region', kind='stable').reset_index(drop=True).reset_index()
px.bar(
    df, x='index', y='diff_7-5', color='region',
    title='Validation log-likelihood change per fish (version 7 − version 5)',
)