In [97]:
import os 
from datetime import datetime

import pandas as pd
import numpy as np
import haven.db as db
import boto3
import tensorflow.keras as keras
from tqdm import tqdm

# Environment for the AWS (boto3) and haven.db clients used below.
# NOTE(review): HAVEN_DATABASE presumably selects the database haven.db reads from — confirm.
# NOTE(review): this notebook only downloads from S3 and reads the DB; a least-privilege
# profile would do instead of 'admin'.
os.environ.update({
    'AWS_PROFILE': 'admin',
    'HAVEN_DATABASE': 'haven',
})

In [70]:
def load_model(space, experiment_name, run_id, local_path=None):
    """
    Download a saved Keras model from S3 and load it.

    The object is read from bucket ``f"{space}-models"`` at key
    ``f"{experiment_name}/{run_id}/model.keras"``.

    Parameters
    ----------
    space : str
        Bucket prefix; ``-models`` is appended to form the bucket name.
    experiment_name : str
        First component of the object key.
    run_id : str
        Second component of the object key.
    local_path : str or None, optional
        Where to write the downloaded file. When None (default) the file is
        written into a temporary directory that is deleted after loading, so
        nothing in the current working directory is created or overwritten.

    Returns
    -------
    The object returned by ``keras.models.load_model``.
    """
    import tempfile

    bucket_name = f"{space}-models"
    model_key = f"{experiment_name}/{run_id}/model.keras"
    s3 = boto3.client("s3")

    if local_path is not None:
        s3.download_file(bucket_name, model_key, local_path)
        return keras.models.load_model(local_path)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name (and its .keras extension) inside the temp dir.
        path = os.path.join(tmp_dir, "model.keras")
        s3.download_file(bucket_name, model_key, path)
        return keras.models.load_model(path)

# Fetch the trained movement model for this experiment run.
model = load_model(
    space='mimic-log-odds',
    experiment_name='movement-model-experiment-v3-s1',
    run_id='e864f08e675a8bd39b0764be4827adf827b49064ed473695c4509cf0cabda693',
)

In [None]:
# Starting state: 10 quanta in each of two H3 cells, both on 2020-04-17.
INPUT = pd.DataFrame({
    '_quanta': [10.0, 10.0],
    'h3_index': ['840c9ebffffffff', '840c699ffffffff'],
    'date': [datetime(2020, 4, 17), datetime(2020, 4, 17)],
})

CONTEXT = {
    # Neighbor search radius for find_neighbors, and divisor for normed_distance.
    'max_km': 100.0,
    # Subtracted from log(net_primary_production) / log(mixed_layer_thickness) in normalize.
    'mean_log_npp': 1.9670798,
    'mean_log_mlt': 3.0952279761654187,
    # Model input columns, in the order they are passed to the model.
    'features': [
        "normed_log_mlt", "normed_log_npp", "normed_distance",
        "water_heading", "movement_heading",
    ],
    # Columns that identify a row after quanta are aggregated in group.
    'essential': ['date', 'h3_index'],
    # Per-choice quanta below this threshold are pruned in group.
    'min_quanta': 0.01,
}
INPUT

## Components

- Initial integrity checks (all input rows must share the same date)
- Choice expansion
- Deduplication check (after selection)

In [30]:
import h3
import geopy.distance


def find_neighbors(h3_index, max_km):
    """
    Find all H3 cells whose centroids lie within `max_km` of `h3_index`'s centroid.

    Input:
    - h3_index (str): the H3 index of the origin cell
    - max_km (float): maximum geodesic distance, in km, between centroids

    Output:
    - sorted list of H3 indices (str), always including `h3_index` itself

    The search grows the k-ring one step at a time and stops after the first
    step that adds no cell within `max_km` (this assumes centroid distance grows
    with ring number).
    """
    origin = h3.h3_to_geo(h3_index)
    checked = {h3_index}
    neighbors = {h3_index}
    k = 1
    while True:
        found = False
        # k_ring returns the whole disk; only cells not seen in earlier rings are tested.
        for candidate in set(h3.k_ring(h3_index, k)) - checked:
            if geopy.distance.geodesic(origin, h3.h3_to_geo(candidate)).km <= max_km:
                neighbors.add(candidate)
                found = True
            checked.add(candidate)
        if not found:
            break
        k += 1
    # Sort so the result (and the downstream `_choice` ids / row order) does not depend
    # on set iteration order, which changes between runs due to str hash randomization.
    return sorted(neighbors)

def h3_index_expand(input, context):
    """
    Expand each input row (a "decision") into one row per candidate destination cell.

    Input:
    - input (DataFrame): must contain `h3_index`; all other columns are carried through.
    - context (dict): reads `max_km`, the radius passed to `find_neighbors`.

    Output: DataFrame with the input columns (`h3_index` renamed to `origin_h3_index`)
    plus `_decision` (0-based position of the source row), `_choice` (row id of the
    candidate) and `h3_index` (the candidate cell).
    """
    max_km = context['max_km']
    decisions = (
        input.reset_index(drop=True)
        .reset_index()
        .rename({'index': '_decision', 'h3_index': 'origin_h3_index'}, axis=1)
    )
    # Look up neighbors once per distinct origin. Repeated origins would otherwise
    # add duplicate neighbor rows, and the merge would then duplicate every choice.
    neighbor_rows = [
        {'origin_h3_index': origin, 'h3_index': neighbor}
        for origin in decisions['origin_h3_index'].unique()
        for neighbor in find_neighbors(origin, max_km)
    ]
    # Explicit columns keep the merge key present even when there are no rows.
    choices = pd.DataFrame(neighbor_rows, columns=['origin_h3_index', 'h3_index'])
    choices = choices.reset_index().rename({'index': '_choice'}, axis=1)
    return decisions.merge(choices, on='origin_h3_index')

In [50]:
def _pull_h3_table(input, context, table, columns):
    """
    Inner-join `columns` from `table` onto `input` by `h3_index`.

    The query is restricted to the single date shared by all rows of `input`,
    to the distinct `h3_index` values of `input`, and to
    `depth_bin = context.get('depth_bin', 25.0)`. Rows of `input` with no
    matching record are dropped (inner join).

    Raises ValueError if `input` spans more than one date, since only one date is queried.
    """
    if input.empty:
        # Nothing to look up (and `in ()` would be invalid SQL); return an empty
        # frame carrying the columns the query would have selected.
        empty = pd.DataFrame(columns=['h3_index', *columns])
        return input.merge(empty, how='inner', on='h3_index')
    dates = input['date'].dt.strftime('%Y-%m-%d').unique()
    if len(dates) != 1:
        raise ValueError(
            f"expected all rows to share one date, found {len(dates)}: {list(dates)}"
        )
    depth_bin = context.get('depth_bin', 25.0)
    # H3 indices are expected to be hex strings; doubling quotes is purely defensive.
    h3_indices = ','.join(
        "'" + str(h3_index).replace("'", "''") + "'"
        for h3_index in input['h3_index'].unique()
    )
    selected = ',\n        '.join(['h3_index', *columns])
    sql = f'''
    select 
        {selected}
    from 
        {table}
    where 
        date = '{dates[0]}'
        and h3_index in ({h3_indices})
        and depth_bin = {depth_bin}
    '''
    return input.merge(db.read_data(sql), how='inner', on='h3_index')


def pull_physics(input, context):
    """
    Join mixed_layer_thickness, velocity_east and velocity_north from
    copernicus_physics onto `input` (see `_pull_h3_table`).
    """
    return _pull_h3_table(
        input, context, 'copernicus_physics',
        ['mixed_layer_thickness', 'velocity_east', 'velocity_north'],
    )


def pull_biochemistry(input, context):
    """
    Join net_primary_production from copernicus_biochemistry onto `input`
    (see `_pull_h3_table`).
    """
    return _pull_h3_table(
        input, context, 'copernicus_biochemistry',
        ['net_primary_production'],
    )


In [63]:
def add_lon_lats(input, context):
    """
    Add centroid coordinates for the origin and destination cells.

    Adds `origin_lat`, `origin_lon` (from `origin_h3_index`), then `lat`, `lon`
    (from `h3_index`) via `h3.h3_to_geo`, using element 0 as latitude and
    element 1 as longitude. `context` is unused. Returns a copy of `input`.
    """
    data = input.copy()
    for prefix, cell_column in (('origin_', 'origin_h3_index'), ('', 'h3_index')):
        cells = data[cell_column]
        data[prefix + 'lat'] = cells.apply(lambda cell: h3.h3_to_geo(cell)[0])
        data[prefix + 'lon'] = cells.apply(lambda cell: h3.h3_to_geo(cell)[1])
    return data

def add_distance(input, context):
    """
    Add `distance`: geodesic distance in km from each row's origin centroid
    (`origin_lat`, `origin_lon`) to its cell centroid (`lat`, `lon`).

    `context` is unused. Returns a copy of `input` with the new float column.
    """
    data = input.copy()
    # Iterate the coordinate columns directly. Row-wise DataFrame.apply(axis=1) is slow,
    # and on an empty frame it can return a DataFrame, which breaks the column assignment.
    distances = [
        geopy.distance.geodesic((origin_lat, origin_lon), (lat, lon)).km
        for origin_lat, origin_lon, lat, lon in zip(
            data['origin_lat'], data['origin_lon'], data['lat'], data['lon']
        )
    ]
    data['distance'] = pd.Series(distances, index=data.index, dtype='float64')
    return data

def add_headings(input, context):
    """
    Add heading features in radians (via np.arctan2):

    - water_heading: arctan2(velocity_north, velocity_east)
    - movement_heading: arctan2(lat - origin_lat, lon - origin_lon), or 0 where
      `distance` is exactly 0

    `context` is unused. Returns a copy of `input` with the two new columns.
    """
    data = input.copy()
    # Vectorized instead of row-wise apply: same element-wise values, no Python loop.
    data['water_heading'] = np.arctan2(data['velocity_north'], data['velocity_east'])
    # NOTE(review): heading is taken in raw degree space (no cos(lat) scaling, no
    # antimeridian wrap); presumably this matches the model's training features — confirm
    # before changing.
    movement = np.arctan2(data['lat'] - data['origin_lat'], data['lon'] - data['origin_lon'])
    # Matches the original truthiness test: NaN distance keeps the angle, 0 becomes 0.
    data['movement_heading'] = movement.where(data['distance'] != 0, 0.0)
    return data

In [66]:
def normalize(input, CONTEXT):
    """
    Add normalized model features, returning a new frame:

    - normed_distance = distance / CONTEXT['max_km']
    - normed_log_npp  = log(net_primary_production) - CONTEXT['mean_log_npp']
    - normed_log_mlt  = log(mixed_layer_thickness) - CONTEXT['mean_log_mlt']
    """
    return input.assign(
        normed_distance=input['distance'] / CONTEXT['max_km'],
        normed_log_npp=np.log(input['net_primary_production']) - CONTEXT['mean_log_npp'],
        normed_log_mlt=np.log(input['mixed_layer_thickness']) - CONTEXT['mean_log_mlt'],
    )

In [77]:
def run_model(input, model, CONTEXT):
    """
    Score every choice with `model` and turn the scores into per-decision probabilities.

    `model` is called on the `CONTEXT['features']` columns; its output is treated as
    one log-odds value per row (assumed shape (n,) or (n, 1) — TODO confirm for other
    models).

    Adds:
    - odds: exp(log-odds)
    - sum_odds: sum of `odds` over the row's `_decision`
    - probability: softmax of the log-odds within the row's `_decision`

    Returns a copy of `input` with these columns.
    """
    data = input.copy()
    # Flatten explicitly so the column assignment does not depend on the output rank.
    log_odds = pd.Series(
        np.asarray(model(data[CONTEXT['features']]), dtype=np.float64).reshape(-1),
        index=data.index,
    )
    decisions = data['_decision']
    with np.errstate(over='ignore'):
        # Kept for inspection; may overflow to inf for very large log-odds.
        data['odds'] = np.exp(log_odds)
    data['sum_odds'] = data.groupby('_decision')['odds'].transform('sum')
    # Stable softmax: subtracting each decision's max log-odds leaves the ratios
    # unchanged but keeps exp() finite, unlike odds / sum_odds (inf / inf = NaN).
    weights = np.exp(log_odds - log_odds.groupby(decisions).transform('max'))
    data['probability'] = weights / weights.groupby(decisions).transform('sum')
    return data

In [94]:
def group(input, CONTEXT):
    """
    Split each decision's quanta across its choices, prune tiny shares, and aggregate.

    - Each row's `_quanta` is multiplied by its `probability`.
    - Shares below `CONTEXT['min_quanta']` are dropped, except the largest positive
      share of each decision, and the surviving shares of each decision are rescaled
      so the decision keeps its pre-pruning total.
    - Shares are summed over the `CONTEXT['essential']` columns.

    Returns a DataFrame with the essential columns plus `_quanta`.
    """
    data = input.copy()
    data['_quanta'] = data['_quanta'] * data['probability']
    shares = data.groupby('_decision')['_quanta']
    data['total_quanta'] = shares.transform('sum')
    # Always keep each decision's largest positive share. Otherwise a decision spread
    # thinner than min_quanta everywhere would vanish and its quanta would be lost.
    is_largest = (data['_quanta'] == shares.transform('max')) & (data['_quanta'] > 0)
    kept = data[(data['_quanta'] >= CONTEXT['min_quanta']) | is_largest].copy()
    kept['remaining_quanta'] = kept.groupby('_decision')['_quanta'].transform('sum')
    kept['_quanta'] = kept['_quanta'] * (kept['total_quanta'] / kept['remaining_quanta'])
    return kept.groupby(CONTEXT['essential'])[['_quanta']].sum().reset_index()

def step_forward(input, CONTEXT):
    """Return a copy of `input` with every `date` advanced by one day (`CONTEXT` is unused)."""
    one_day = pd.DateOffset(days=1)
    return input.assign(date=input['date'] + one_day)

In [None]:
def step(input, model, context):
    """
    Advance the quanta distribution by one day.

    Expands each row into candidate cells, joins physics and biochemistry data,
    derives and normalizes features, scores the choices with `model`, redistributes
    quanta by the resulting probabilities, and moves the date forward one day.
    """
    frame = h3_index_expand(input, context)
    # Environmental joins, then derived and normalized features; each stage is f(frame, context).
    for stage in (
        pull_physics,
        pull_biochemistry,
        add_lon_lats,
        add_distance,
        add_headings,
        normalize,
    ):
        frame = stage(frame, context)
    scored = run_model(frame, model, context)
    return step_forward(group(scored, context), context)

# Simulate N_STEPS days, feeding each day's output into the next step.
N_STEPS = 50

frames = [INPUT]
for _ in tqdm(range(N_STEPS)):
    frames.append(step(frames[-1], model, CONTEXT))

# Each frame carries its own 0..k index; ignore_index gives results unique row labels.
results = pd.concat(frames, ignore_index=True)
print(results.shape)
results.head()

In [None]:
from mirrorverse.plotting import plot_h3_slider, plot_h3_animation

# Animate `_quanta` per H3 cell with a slider over `date`.
# NOTE(review): zmax presumably caps the color scale — confirm in mirrorverse.plotting.
plot_h3_animation(
    results,
    h3_col='h3_index',
    slider_col='date',
    value_col='_quanta',
    zmax=0.5,
)

In [None]:
# Total quanta on the last simulated date; compare with INPUT['_quanta'].sum() to see
# how much quanta survived the simulation. Derived from the data, so it stays correct
# if the number of steps or the start date changes.
final_date = results['date'].max()
results.loc[results['date'] == final_date, '_quanta'].sum()