In [None]:
# Load IPython's autoreload extension and select mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import itertools
import mlflow
import numpy as np
import os
import pandas as pd
import pathlib
import plotly.express as px
import torch
import xarray as xr
import yaml

import omegaconf

In [None]:
from smc01.postprocessing.lightning import SMC01Module
from smc01.postprocessing.cli.util import find_checkpoint_file, make_datasets, make_dataloader

In [None]:
from smc01.postprocessing.util import load_checkpoint_from_run

In [None]:
# Root directory for all data paths in this notebook, read from the DATA_DIR
# environment variable.
# os.environ is indexed directly so that an unset variable fails with a KeyError
# naming 'DATA_DIR', rather than the opaque TypeError raised by pathlib.Path(None).
DATA_DIR = pathlib.Path(os.environ['DATA_DIR'])

In [None]:
# Compose the 'validate' configuration from the packaged config module,
# selecting the attention_gdps_metar experiment and overriding
# experiment.model.big_features_embedding to False.
config_overrides = [
    'experiment=attention_gdps_metar',
    'experiment.model.big_features_embedding=False',
]
with hydra.initialize_config_module('smc01.postprocessing.conf'):
    cfg = hydra.compose('validate', config_overrides)

In [None]:
# Build the datasets from the composed config. make_datasets returns five
# values: the three dataset splits followed by the station and feature counts.
(
    train_dataset,
    val_dataset,
    test_dataset,
    n_stations,
    n_features,
) = make_datasets(cfg)

In [None]:
# Loader over the validation dataset, with shuffling and concat_collate both
# turned off.
val_dataloader = make_dataloader(
    cfg,
    val_dataset,
    shuffle=False,
    concat_collate=False,
)

In [None]:
# Directory of the training run whose checkpoint is inspected in this notebook.
CHECKPOINT_DIR = DATA_DIR.joinpath('runs/postprocessing/multirun/2022-05-31/13-53-07/4854681')

In [None]:
# Instantiate the model described by the experiment config, then load the
# Lightning module from the checkpoint found under CHECKPOINT_DIR, handing it
# that model and the stations file.
# NOTE(review): 1226 and 18 are presumably the station and feature counts
# (cf. n_stations / n_features returned by make_datasets) — confirm.
model = hydra.utils.instantiate(cfg.experiment.model, 1226, 18)
checkpoint_file = find_checkpoint_file(CHECKPOINT_DIR)
full_stations_file = DATA_DIR / 'interpolated/2021-12-20-gdps-metar/stations.csv'
module = SMC01Module.load_from_checkpoint(
    checkpoint_file,
    model=model,
    full_stations_file=full_stations_file,
)

In [None]:
# Pull the first batch out of the validation loader to use as a sample input.
one_example = next(iter(val_dataloader))

In [None]:
# Run the model on the sampled validation batch for inspection.
# eval() switches off train-only behaviour (e.g. dropout, if the model has any),
# no_grad() avoids building an autograd graph that is never used, and calling
# the module itself rather than .forward() lets any registered hooks run.
model.eval()
with torch.no_grad():
    pred = model(one_example)

In [None]:
# Display the model output for the sampled batch.
pred

In [None]:
# Shape of the model's station embedding (multiplied with its own transpose
# further down to build the affinity matrix).
model.station_embedding.shape

In [None]:
# Station-to-station affinity: the product of the station embedding with its
# own transpose, detached from autograd and converted to a NumPy array.
station_gram = model.station_embedding @ model.station_embedding.T
affinity = station_gram.detach().numpy()

In [None]:
# Shape of the affinity matrix (square, since it is embedding @ embedding.T).
affinity.shape

In [None]:
# Heatmap of the station-by-station affinity matrix.
# `affinity` was already converted to a NumPy array when it was computed, and
# ndarrays have no .detach() method, so it is handed to plotly as-is.
px.imshow(affinity)

In [None]:
# Number of stations listed by the validation dataset.
len(val_dataset.stations)

In [None]:
# Position of station 'CYUL' in the validation dataset's station list.
val_dataset.stations.index('CYUL')

In [None]:
# Affinity of station 'CYUL' with every station. The row is looked up by
# station code instead of a hard-coded position, which would silently point at
# a different station if the station list changed.
affinity[val_dataset.stations.index('CYUL')]

In [None]:
# Load the station table (with metadata) from the interpolated dataset directory.
stations_metadata_file = DATA_DIR / 'interpolated/2021-12-20-gdps-metar/stations_w_metadata.csv'
stations = pd.read_csv(stations_metadata_file)

In [None]:
# Display the station table as loaded from the CSV.
stations

In [None]:
# One-row frame whose station, latitude, longitude and elevation are all None;
# it is appended to the station table in the next step.
placeholder_columns = ['station', 'latitude', 'longitude', 'elevation']
to_append = pd.DataFrame([dict.fromkeys(placeholder_columns)])

In [None]:
# Append the placeholder row to the station table.
# ignore_index=True renumbers the rows 0..n-1; without it the placeholder keeps
# its own label 0 and duplicates the label of the first station.
# NOTE(review): presumably this pads the table to the number of rows in the
# affinity matrix — confirm.
# NOTE: re-running this cell appends one more placeholder row each time.
stations = pd.concat([stations, to_append], ignore_index=True)

In [None]:
# Display the station table after the placeholder row has been appended.
stations

In [None]:
# Position of station 'OAK' in the validation dataset's station list.
val_dataset.stations.index('OAK')

In [None]:
# Attach each station's affinity with 'CYUL' as a new column. The affinity row
# is looked up by station code rather than by a hard-coded position, so it
# stays correct if the station list changes.
stations['cyul_affinity'] = affinity[val_dataset.stations.index('CYUL')]

In [None]:
# Attach each station's affinity with 'OAK' as a new column. The affinity row
# is looked up by station code rather than by a hard-coded position, so it
# stays correct if the station list changes.
stations['oak_affinity'] = affinity[val_dataset.stations.index('OAK')]

In [None]:
# Display the station table with the two affinity columns added.
stations

In [None]:
# Geographic scatter plot of the stations, coloured by their affinity with
# station 'OAK'; hovering a point shows the station code.
px.scatter_geo(
    data_frame=stations,
    lat='latitude',
    lon='longitude',
    color='oak_affinity',
    hover_name='station',
)