In [10]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pickle
import h3
import plotly.io as pio
pio.renderers.default = "notebook_connected"

from choice_network import RunOrDriftNetwork
from build_network import create_pairs, group_headings

In [3]:
# Raw tag fixes: map the export's verbose column headers onto short snake_case names.
COLUMN_NAMES = {
    "Ptt": "ptt",
    "Latitude": "lat",
    "Longitude": "lon",
    "Dates - Date Key → Date": "date",
    "Dates - Date Key → Year": "year",
    "Dates - Date Key → Month": "month",
    "Dates - Date Key → Day": "day",
}

data = pd.read_csv("data/tag_tracks.csv").rename(columns=COLUMN_NAMES)
print(data.shape)
data.head()

(7532, 7)


Unnamed: 0,ptt,lat,lon,date,year,month,day
0,129843,54.13176,-166.922615,2013-12-19,2013,12,19
1,129843,54.258072,-166.884086,2013-12-20,2013,12,20
2,129843,54.312433,-166.910525,2013-12-21,2013,12,21
3,129843,54.35828,-166.817057,2013-12-22,2013,12,22
4,129843,54.389694,-166.676901,2013-12-23,2013,12,23


In [4]:
# One row per consecutive pair of fixes (start_*/end_* columns, heading, H3 cells).
pairs = create_pairs(data)
print(pairs.shape)
pairs.head(5)

100%|██████████| 111/111 [00:00<00:00, 242.94it/s]

(7421, 15)





Unnamed: 0,ptt,start_lat,start_lon,end_lat,end_lon,start_date,end_date,heading,start_h3,end_h3,start_month,start_day,end_month,end_day,remained
0,129843,53.98098,-166.800355,54.360925,-166.742418,2013-12-19,2013-12-20,1.419474,8422d0bffffffff,8422d03ffffffff,12,19,12,20,False
1,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-20,2013-12-21,,8422d03ffffffff,8422d03ffffffff,12,20,12,21,True
2,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-21,2013-12-22,,8422d03ffffffff,8422d03ffffffff,12,21,12,22,True
3,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-22,2013-12-23,,8422d03ffffffff,8422d03ffffffff,12,22,12,23,True
4,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-23,2013-12-24,,8422d03ffffffff,8422d03ffffffff,12,23,12,24,True


In [5]:
# Group each tag's steps by heading. group_headings receives one tag's rows per call,
# so no group can contain steps from two different tags.
# A single groupby pass replaces re-filtering `pairs` once per tag; sort=False keeps
# tags in order of first appearance, matching pairs['ptt'].unique().
# NOTE(review): np.pi / 4 is presumably an angular tolerance in radians and 150 a
# second threshold; their exact meaning lives in build_network.group_headings — TODO
# confirm and promote both to named constants in a config cell.
grouped_pairs = pd.concat(
    [
        group_headings(tag_pairs, np.pi / 4, 150)
        for _, tag_pairs in tqdm(pairs.groupby('ptt', sort=False))
    ]
)
grouped_pairs.head()

100%|██████████| 111/111 [00:04<00:00, 23.75it/s]


Unnamed: 0,ptt,start_lat,start_lon,end_lat,end_lon,start_date,end_date,heading,start_h3,end_h3,...,start_day,end_month,end_day,remained,group,steps_in_group,momentum,mean_heading,steps_since_group,drift_group
0,129843,53.98098,-166.800355,54.360925,-166.742418,2013-12-19,2013-12-20,1.419474,8422d0bffffffff,8422d03ffffffff,...,19,12,20,False,-1.0,,False,,1.0,1.0
1,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-20,2013-12-21,,8422d03ffffffff,8422d03ffffffff,...,20,12,21,True,-1.0,,False,,2.0,1.0
2,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-21,2013-12-22,,8422d03ffffffff,8422d03ffffffff,...,21,12,22,True,-1.0,,False,,3.0,1.0
3,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-22,2013-12-23,,8422d03ffffffff,8422d03ffffffff,...,22,12,23,True,-1.0,,False,,4.0,1.0
4,129843,54.360925,-166.742418,54.360925,-166.742418,2013-12-23,2013-12-24,,8422d03ffffffff,8422d03ffffffff,...,23,12,24,True,-1.0,,False,,5.0,1.0


In [6]:
# Reshape the grouped steps into the per-step state schema fed to the simulation:
# starting H3 cell, month and date of each step, plus its drifting flag and run length.
data = (
    grouped_pairs
    .assign(
        # Steps without momentum are treated as drifting.
        drifting=lambda frame: ~frame['momentum'],
        # Run length: steps since the last group while drifting, steps within the group otherwise.
        steps_in_state=lambda frame: frame['steps_since_group'].where(
            frame['drifting'], frame['steps_in_group']
        ),
        h3_index=lambda frame: frame['start_h3'],
        month=lambda frame: frame['start_month'],
        mean_heading=lambda frame: frame['mean_heading'].fillna(0),
        date=lambda frame: frame['start_date'],
    )
    [['ptt', 'h3_index', 'month', 'mean_heading', 'drifting', 'steps_in_state', 'date']]
)
data.head()

Unnamed: 0,ptt,h3_index,month,mean_heading,drifting,steps_in_state,date
0,129843,8422d0bffffffff,12,0.0,True,1.0,2013-12-19
1,129843,8422d03ffffffff,12,0.0,True,2.0,2013-12-20
2,129843,8422d03ffffffff,12,0.0,True,3.0,2013-12-21
3,129843,8422d03ffffffff,12,0.0,True,4.0,2013-12-22
4,129843,8422d03ffffffff,12,0.0,True,5.0,2013-12-23


In [7]:
# Load the pickled models and hand them to RunOrDriftNetwork.import_models.
# NOTE: pickle.load can execute arbitrary code — only open models.pkl from a trusted source.
with open("models.pkl", mode="rb") as model_file:
    models = pickle.load(model_file)
RunOrDriftNetwork.import_models(models)

In [8]:
def simulate(ptt, h3_index, date, steps):
    """Roll one simulated track forward, one day per step, via RunOrDriftNetwork.

    The track starts drifting in ``h3_index`` on ``date``. At every step
    ``RunOrDriftNetwork.choose`` picks the next drifting flag, H3 cell and
    (when not drifting) mean heading.

    Args:
        ptt: Tag identifier copied onto every output row.
        h3_index: H3 cell the track starts in.
        date: Start date; must support ``.month`` and ``+ pd.Timedelta``
            (a ``pd.Timestamp`` in this notebook).
        steps: Number of transitions to simulate.

    Returns:
        pd.DataFrame with ``steps + 1`` rows (the start plus one per step) and
        columns drifting, steps_in_state, h3_index, month, mean_heading, ptt, date.
    """
    one_day = pd.Timedelta(days=1)
    current = {
        'drifting': True,
        'steps_in_state': 1,
        'h3_index': h3_index,
        'month': date.month,
        'mean_heading': 0,
    }
    # Each record is a fresh dict, so later changes to `current` never leak into history.
    history = [{**current, 'ptt': ptt, 'date': date}]

    for _ in range(steps):
        choice = {}
        # choose() fills `choice`; drifting and h3_index are read below, and
        # mean_heading only when the new state is not drifting.
        RunOrDriftNetwork.choose(current, choice)

        date = date + one_day
        is_drifting = choice['drifting']
        if current['drifting'] == is_drifting:
            run_length = current['steps_in_state'] + 1
        else:
            run_length = 1

        current = {
            'drifting': is_drifting,
            'steps_in_state': run_length,
            'h3_index': choice['h3_index'],
            'month': date.month,
            'mean_heading': 0 if is_drifting else choice['mean_heading'],
        }
        history.append({**current, 'ptt': ptt, 'date': date})

    return pd.DataFrame(history)

# Simulate one track per tag: start from the tag's earliest observed step and run
# for as many steps as that tag has rows in `data`.
# NOTE(review): if RunOrDriftNetwork.choose draws random numbers, seed that RNG here
# so a Restart & Run All reproduces these tracks — TODO confirm which RNG it uses.
ptts = list(data['ptt'].unique())
# Group once instead of re-filtering `data` (twice) for every tag.
tracks_by_ptt = dict(tuple(data.groupby('ptt', sort=False)))

dfs = []
for ptt in tqdm(ptts):
    track = tracks_by_ptt[ptt]
    # Dates appear to be 'YYYY-MM-DD' strings, so lexical order is chronological; a
    # stable sort deterministically keeps the first-listed row when dates tie.
    first_step = track.sort_values('date', kind='stable').iloc[0]
    dfs.append(
        simulate(
            first_step['ptt'],
            first_step['h3_index'],
            pd.to_datetime(first_step['date']),
            len(track),
        )
    )

100%|██████████| 111/111 [05:58<00:00,  3.23s/it]


In [14]:
def plot_it(data, lat, lon, color, *, center_lat=58, center_lon=-150,
            projection_scale=6, height=600, title=None):
    """Scatter points on a geographic map centred on a configurable region.

    Args:
        data: DataFrame holding the columns to plot.
        lat: Name of the latitude column.
        lon: Name of the longitude column.
        color: Name of the column used to colour the points.
        center_lat: Latitude of the map centre (default reproduces the original view).
        center_lon: Longitude of the map centre (default reproduces the original view).
        projection_scale: Zoom factor of the geo projection.
        height: Figure height in pixels.
        title: Optional figure title.

    Returns:
        The plotly Figure built by ``px.scatter_geo``.
    """
    fig = px.scatter_geo(data, lat=lat, lon=lon, color=color, title=title)
    fig.update_layout(
        autosize=True,
        height=height,
        geo=dict(
            center=dict(lat=center_lat, lon=center_lon),
            projection_scale=projection_scale,
        ),
    )
    return fig

# Stack the simulated tracks and plot them at the centre of each H3 cell.
df = pd.concat(dfs)
# h3.h3_to_geo returns a (lat, lon) pair; decode each distinct cell once rather than
# twice per row via apply(axis=1).
# NOTE(review): h3_to_geo is the h3-py v3 name; v4 renamed it cell_to_latlng.
sim_cell_centres = {cell: h3.h3_to_geo(cell) for cell in df['h3_index'].unique()}
df['lat'] = df['h3_index'].map(lambda cell: sim_cell_centres[cell][0])
df['lon'] = df['h3_index'].map(lambda cell: sim_cell_centres[cell][1])
plot_it(
    df,
    "lat",
    "lon",
    "ptt",
)

In [15]:
# Plot the observed steps at the centre of each step's starting H3 cell.
# h3.h3_to_geo returns a (lat, lon) pair; decode each distinct cell once rather than
# twice per row via apply(axis=1).
# NOTE(review): h3_to_geo is the h3-py v3 name; v4 renamed it cell_to_latlng.
obs_cell_centres = {cell: h3.h3_to_geo(cell) for cell in data['h3_index'].unique()}
data['lat'] = data['h3_index'].map(lambda cell: obs_cell_centres[cell][0])
data['lon'] = data['h3_index'].map(lambda cell: obs_cell_centres[cell][1])
plot_it(
    data,
    "lat",
    "lon",
    "ptt",
)