In [None]:
import os
import random

import h3
import numpy as np
import pandas as pd
import geopy.distance
import plotly.express as px
from mirrorverse.plotting import plot_h3_slider, plot_h3_animation
from mirrorverse.utils import read_data_w_cache

# Point the session at the database and AWS profile used by the queries below.
os.environ.update({
    'HAVEN_DATABASE': 'haven',
    'AWS_PROFILE': 'admin',
})

In [None]:
# Pair every model inference row with the features of the same candidate
# choice. Both tables are joined on (_individual, _decision, _choice); the
# inner join keeps only rows present in both tables.
sql = '''
select  
    i._individual,
    i._decision,
    i._choice,
    f._selected,
    i._train,
    i.log_odds,
    i.odds,
    i.probability,
    f.h3_index,
    f.time,
    f.net_primary_production,
    f.mixed_layer_thickness,
    f.distance,
    f.water_heading,
    f.movement_heading
from 
    movement_model_inference_m2_a1 i 
    inner join movement_model_features_m2_a1 f 
        on i._individual = f._individual
        and i._decision = f._decision
        and i._choice = f._choice
'''
# NOTE(review): read_data_w_cache is a project helper; presumably it runs the
# query and caches the result -- confirm in mirrorverse.utils.
data = read_data_w_cache(sql)
# Day-resolution "YYYY-MM-DD" string form of `time` (read later by
# add_common_time). The .dt accessor requires `time` to be datetime-like.
data['date'] = data['time'].dt.strftime("%Y-%m-%d")
print(data.shape)
data.head()

In [None]:
# Per-individual score: the negated average of ln(probability) over the rows
# flagged `_selected` (`group by 1` groups on the first select column,
# `_individual`).
# NOTE(review): presumably this is the mean negative log-likelihood of the
# choices actually taken, so a larger score means a poorer fit -- confirm.
sql = '''
select  
    _individual,
    -avg(ln(probability)) as score
from 
    movement_model_inference_m2_a1
where 
    _selected
group by 
    1
'''
# Sort so the individuals with the highest scores come first.
rankings = read_data_w_cache(sql).sort_values('score', ascending=False)
print(rankings.shape)
rankings

In [None]:
def set_line_color(row):
    """Return the outline color for a single row.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) exposing the
            'distance' and '_selected' entries.

    Returns:
        "orange" when the row's distance is exactly 0 (checked first, so it
        wins even if the row is also selected), "purple" when the row is
        otherwise truthy for '_selected', and "black" for everything else.
    """
    # Zero-distance rows take precedence over the selected flag.
    if row['distance'] == 0:
        return "orange"
    return "purple" if row['_selected'] else "black"
    
# Evaluate set_line_color once per row and store the result as a new column.
data['color'] = data.apply(set_line_color, axis='columns')

In [None]:
def add_common_time(data, reference_year=2020):
    """Project each row's calendar date onto a single reference year.

    Adds two columns to ``data`` (the frame is modified in place and also
    returned):

    * ``mod_date`` -- the ``date`` string with its first '-'-separated field
      replaced by ``reference_year``.
    * ``_time`` -- ``mod_date`` parsed with ``pd.to_datetime``.

    Rows whose ``date`` is missing (None/NaN) keep a missing ``mod_date`` and
    get ``NaT`` in ``_time`` instead of raising.

    Args:
        data: DataFrame with a ``date`` column of "YYYY-MM-DD" strings.
        reference_year: Year substituted into every date. The default, 2020,
            is a leap year, so "02-29" dates stay parseable; a non-leap year
            would make ``pd.to_datetime`` fail on such rows.

    Returns:
        The same ``data`` object, with ``mod_date`` and ``_time`` added.
    """
    year = str(reference_year)

    def _swap_year(date_str):
        # Keep everything after the first field (month, day) and prepend the
        # reference year.
        return '-'.join([year] + date_str.split('-')[1:])

    # na_action='ignore' passes missing values through untouched rather than
    # calling .split on a non-string.
    data['mod_date'] = data['date'].map(_swap_year, na_action='ignore')
    data['_time'] = pd.to_datetime(data['mod_date'])
    return data

def _zero_distance_summary(ranked_subset, color):
    """Summarize zero-distance rows for one group of individuals.

    Takes the rows of ``data`` colored 'orange' (distance == 0, per
    set_line_color), joins them to ``ranked_subset`` on the columns the two
    frames share, maps their dates onto a common year, and averages
    ``_individual`` per (h3_index, _time). The result is tagged with
    ``color``.
    """
    zero_distance_rows = data[data['color'] == 'orange']
    joined = add_common_time(zero_distance_rows.merge(ranked_subset))
    # NOTE(review): this averages the `_individual` identifiers -- presumably
    # only a placeholder value for the plot; confirm that is intended.
    summary = (
        joined
        .groupby(['h3_index', '_time'])[['_individual']]
        .mean()
        .reset_index()
    )
    summary['color'] = color
    return summary


fit_scores = rankings['score']
# Individuals at or above the 75th percentile of score.
poor_fits = _zero_distance_summary(
    rankings[fit_scores >= fit_scores.quantile(0.75)], 'orange'
)
# Individuals at or below the 25th percentile of score.
good_fits = _zero_distance_summary(
    rankings[fit_scores <= fit_scores.quantile(0.25)], 'purple'
)
overall = pd.concat([poor_fits, good_fits])

animation_options = dict(
    value_col='_individual',
    h3_col='h3_index',
    slider_col='_time',
    line_color_col='color',
    bold_colors=['orange', 'purple'],
    zoom=3,
    center={"lat": 55, "lon": -165},
    duration=200,
)
plot_h3_animation(overall, **animation_options)