# Validation

In [1]:
import collections
from pathlib import Path

from geo_transformers.data_loading import TransformedDataset
from geo_transformers.models.any_horizon_forecast_transformer import AnyHorizonForecastTransformerCLI, SelectionHead
from geo_transformers import notebook_utils, training_utils
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
from torch.nn import functional as F
import seaborn as sns
from tqdm.auto import tqdm
import wandb
import yaml

## Load model and data

In [2]:
# W&B entity and project that the evaluated runs belong to.
ENTITY = "lirmm-zenith"
PROJECT = "geo-transformers"

def load_run(run_id):
    """Load the configuration and instantiated objects of a W&B run.

    The run's ``config.yaml`` is read from the local run directory when it is
    already there; otherwise it is fetched from W&B into that directory.

    Args:
        run_id: W&B run id; also the name of the run's local directory.

    Returns:
        Tuple ``(run_dir, config, objects)``: the (relative) run directory and
        the two values returned by ``notebook_utils.load_experiment``.

    Raises:
        FileNotFoundError: if W&B returns no ``config.yaml`` for the run.
    """
    run_path = f"{ENTITY}/{PROJECT}/{run_id}"
    # Relative path: resolved against the current working directory.
    run_dir = Path(f"exp/forecast_mbk+ung_12h/{PROJECT}/{run_id}")

    config_path = run_dir / "config.yaml"
    if config_path.exists():
        f = open(config_path)
    else:
        # parents=True: on a fresh checkout the experiment directory itself
        # may not exist yet.
        run_dir.mkdir(parents=True, exist_ok=True)
        f = wandb.restore("config.yaml", run_path=run_path, root=run_dir)
        if f is None:
            raise FileNotFoundError(f"config.yaml not found for W&B run {run_path}")
    with f:
        config_dict = yaml.load(f, Loader=yaml.SafeLoader)
    # The W&B config file nests the actual settings under fit -> value.
    config_dict = config_dict["fit"]["value"]
    config_dict["trainer"]["logger"] = False
    # Drop these keys before handing the config to load_experiment; tolerate
    # configs that do not contain them.
    for key in ("seed_everything", "ckpt_path"):
        config_dict.pop(key, None)

    config, objects = notebook_utils.load_experiment(AnyHorizonForecastTransformerCLI, config_dict)

    return run_dir, config, objects

In [3]:
with notebook_utils.chdir(".."):
    # Only the data module of this run is kept; everything else returned by
    # load_run is released right away.
    loaded = load_run("uxc2lb3w")
    dmodule = loaded[2]["data"]
    del loaded

    dmodule.setup(stage="predict")
    test_dataset = dmodule.test_dataloader().dataset
    # Bypass preprocessing (identity function) and request 100 candidates.
    test_dataset.preprocess_fn = lambda x: x
    test_dataset.num_candidates = 100
    # Materialize the dataset once so every evaluated run sees the same items.
    dataset_raw = list(tqdm(test_dataset))

  0%|          | 0/151 [00:00<?, ?it/s]

In [4]:
def cat_padded(tensors):
    """Concatenate tensors along dim 0 after right-padding dim 1 to a common length.

    Args:
        tensors: Non-empty sequence of tensors with at least two dimensions
            that agree on every dimension except 0 and 1.

    Returns:
        A single tensor whose dim 1 equals the longest dim 1 among the inputs.
    """
    target_len = max(t.shape[1] for t in tensors)

    padded = []
    for t in tensors:
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so every dimension after dim 1 gets an explicit (0, 0) pair first.
        trailing_dims = (0, 0) * (t.dim() - 2)
        padded.append(F.pad(t, trailing_dims + (0, target_len - t.shape[1])))
    return torch.cat(padded, dim=0)

## Get predictions and metrics

In [5]:
def predict(run_id, no_context=False):
    """Score the cached test trajectories with the best checkpoint of a run.

    Reads the module-level ``dataset_raw`` list and requires a CUDA device.

    Args:
        run_id: W&B run id passed to ``load_run``; the checkpoint is looked up
            as ``<run_dir>/checkpoints/*-best.ckpt``.
        no_context: If true, the encoder receives an identity-matrix boolean
            ``attn_mask``.

    Returns:
        Tuple ``(preds, masks)``: the per-batch model outputs and
        ``batch["mask"]`` tensors, moved to CPU, padded along dim 1 and
        concatenated along dim 0.
    """
    with notebook_utils.chdir(".."):
        run_dir, _, objects = load_run(run_id)

        # Exactly one matching checkpoint is expected; the unpacking raises
        # ValueError otherwise.
        [ckpt_path] = sorted((run_dir / "checkpoints").glob("*-best.ckpt"))
        print(ckpt_path)

        model, dmodule = objects["model"], objects["data"]

        checkpoint = torch.load(ckpt_path)
        model.load_state_dict(checkpoint["state_dict"])

    model.cuda()
    model.eval()

    encoded_dataset = TransformedDataset(dataset_raw, dmodule.processor.encode)
    dloader = torch.utils.data.DataLoader(
        encoded_dataset,
        batch_size=8,
        collate_fn=training_utils.collate_sequence_dicts,
        num_workers=0,
    )

    batch_preds, batch_masks = [], []
    with torch.inference_mode():
        for batch in tqdm(dloader, leave=False):
            batch = dmodule.transfer_batch_to_device(batch, device=model.device, dataloader_idx=0)

            if no_context:
                # Square identity mask sized by dim 1 of the batch mask.
                # NOTE(review): how the encoder interprets this mask is not
                # visible here -- confirm it restricts attention to the diagonal.
                diag = torch.eye(batch["mask"].shape[1], dtype=bool, device=model.device)
                encoder_kwargs = {"attn_mask": diag}
            else:
                encoder_kwargs = {}

            pred, _, _ = model(batch, encoder_kwargs=encoder_kwargs)
            batch_preds.append(pred.cpu())
            batch_masks.append(batch["mask"].cpu())

    # Batches can differ in dim 1, hence the padded concatenation.
    return cat_padded(batch_preds), cat_padded(batch_masks)

In [6]:
def compute_accuracies(preds, mask, k=1, num_candidates=None):
    """Masked top-k hit rate of candidate 0, averaged over the last mask dim.

    A position is a hit when index 0 is among the ``k`` highest-scoring of the
    first ``num_candidates`` candidates (all candidates when ``None``).

    Args:
        preds: Scores with candidates on the last dimension.
        mask: Weights/flags matching ``preds.shape[:-1]``; masked-out
            positions do not count towards hits or the denominator.
        k: Number of top-scoring candidates considered.
        num_candidates: Optional cap on the number of candidates scored.

    Returns:
        Tensor of hit rates, one per entry of the leading dimensions.
    """
    if num_candidates is not None:
        assert preds.shape[-1] >= num_candidates

    scores = preds[..., :num_candidates]
    top_indices = scores.topk(k, dim=-1).indices
    hits = (top_indices == 0).any(dim=-1)

    num_hits = (hits * mask).sum(dim=-1)
    num_valid = mask.sum(dim=-1)
    return num_hits / num_valid

In [18]:
# Each entry: (W&B run id, evaluate with `no_context`, label used in the tables).
RUNS = [
    ("vdv96ee1", False, "full context"),
    ("vdv96ee1", True, "full context, diag"),
    ("37ld98g9", False, "var context"),
    ("37ld98g9", True, "var context, diag"),
    ("2u6ajp2u", False, "no att"),
    ("2rygu3l4", False, "no enc"),
]


def _masked_mean(values, mask):
    """Average `values` over dim 1, counting only positions selected by `mask`."""
    return (values * mask).sum(dim=1) / mask.sum(dim=1)


per_run_frames = []
for run_id, no_context, desc in tqdm(RUNS):
    preds, mask = predict(run_id, no_context=no_context)
    num_cand = preds.shape[-1]

    # One-hot target on candidate 0. With a one-hot target the KL divergence
    # reduces to the cross-entropy; it is then divided by log(#candidates).
    target = F.one_hot(torch.zeros(preds.shape[:-1], dtype=int), num_classes=num_cand)
    xent = F.kl_div(
        preds.log_softmax(dim=-1), target, log_target=False, reduction="none"
    ).sum(dim=-1) / np.log(num_cand)
    xent_16 = F.kl_div(
        preds[:, :, :16].log_softmax(dim=-1), target[:, :, :16], log_target=False, reduction="none"
    ).sum(dim=-1) / np.log(16)

    # Distribution over candidates 1..15 only (candidate 0 excluded).
    # NOTE(review): the entropy of these 15 candidates is divided by log(16),
    # not log(15) -- confirm this is intended.
    logprob_ctrl_16 = preds[:, :, 1:16].log_softmax(dim=-1)
    neg_entropy_ctrl_16 = (logprob_ctrl_16 * logprob_ctrl_16.exp()).sum(dim=-1)

    per_run_frames.append(pd.DataFrame({
        "desc": desc,
        "seq_id": [traj["seq_id"][0] for traj in dataset_raw],
        "xent@100": _masked_mean(xent, mask),
        "xent@16": _masked_mean(xent_16, mask),
        "ent@16": -_masked_mean(neg_entropy_ctrl_16, mask) / np.log(16),
        "acc 8/16": compute_accuracies(preds, mask, 8, 16),
        "acc 50/100": compute_accuracies(preds, mask, 50, 100),
        "acc 1/4": compute_accuracies(preds, mask, 1, 4),
        "acc 10/100": compute_accuracies(preds, mask, 10, 100),
        "acc 1/16": compute_accuracies(preds, mask, 1, 16),
    }))

# One row per (run configuration, test sequence).
results = pd.concat(per_run_frames, ignore_index=True)

  0%|          | 0/6 [00:00<?, ?it/s]

exp/forecast_mbk+ung_12h/geo-transformers/vdv96ee1/checkpoints/epoch=49-step=6949-best.ckpt


  0%|          | 0/19 [00:00<?, ?it/s]

exp/forecast_mbk+ung_12h/geo-transformers/vdv96ee1/checkpoints/epoch=49-step=6949-best.ckpt


  0%|          | 0/19 [00:00<?, ?it/s]

exp/forecast_mbk+ung_12h/geo-transformers/37ld98g9/checkpoints/epoch=169-step=23629-best.ckpt


  0%|          | 0/19 [00:00<?, ?it/s]

exp/forecast_mbk+ung_12h/geo-transformers/37ld98g9/checkpoints/epoch=169-step=23629-best.ckpt


  0%|          | 0/19 [00:00<?, ?it/s]

exp/forecast_mbk+ung_12h/geo-transformers/2u6ajp2u/checkpoints/epoch=179-step=24950-best.ckpt


  0%|          | 0/19 [00:00<?, ?it/s]

exp/forecast_mbk+ung_12h/geo-transformers/2rygu3l4/checkpoints/epoch=175-step=24394-best.ckpt


  0%|          | 0/19 [00:00<?, ?it/s]

In [31]:
# Mean of every metric per configuration, in the order the configurations
# first appear. `numeric_only=True` keeps non-numeric columns (such as string
# `seq_id`s) out of the aggregation; pandas >= 2.0 raises a TypeError on them
# instead of silently dropping them.
results.groupby("desc", sort=False).mean(numeric_only=True)

Unnamed: 0_level_0,xent@100,xent@16,ent@16,acc 8/16,acc 50/100,acc 1/4,acc 10/100,acc 1/16
desc,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
full context,0.9089,0.868632,0.834409,0.730447,0.737179,0.458209,0.293231,0.197842
"full context, diag",0.997791,0.990456,0.905633,0.590871,0.593715,0.322613,0.156565,0.101702
var context,0.894168,0.847382,0.816869,0.755732,0.761867,0.487215,0.323221,0.220644
"var context, diag",0.953701,0.931739,0.897415,0.660647,0.667223,0.381508,0.204419,0.135908
no att,0.945193,0.919238,0.866843,0.683045,0.689779,0.402607,0.231137,0.157229
no enc,0.950192,0.927839,0.905191,0.657577,0.661735,0.38331,0.217173,0.148166


### Save/load metrics

In [11]:
# Write the per-sequence metrics table to CSV (row index omitted).
results.to_csv("mbk+ung_12h_eval_results.csv", index=False)

In [30]:
# Replace the in-memory `results` with the metrics table stored in the CSV file.
results = pd.read_csv("mbk+ung_12h_eval_results.csv")

## Make result tables and plots

In [32]:
# Metrics reported in the paper-style table.
cols = ["xent@16", "xent@100", "acc 1/16", "acc 10/100"]
# Select the metric columns *before* averaging so non-numeric columns never
# reach `mean()` (pandas >= 2.0 raises on them instead of dropping them).
df = results.groupby("desc", sort=False)[cols].mean().reset_index()
# Turn the run labels into LaTeX macros. The replacements are literal
# (`regex=False`) and use a single backslash, so the output does not depend on
# the pandas version's default for `regex` (True before 2.0, False since).
df["desc"] = (df["desc"]
              .str.replace("full context", r"\textsc{FullCtx}", regex=False)
              .str.replace("var context", r"\textsc{VarCtx}", regex=False)
              .str.replace(", diag", r"+\textsc{diag}", regex=False)
              .str.replace("no att", r"\textsc{NoAtt}", regex=False)
              .str.replace("no enc", r"\textsc{NoEnc}", regex=False))
styler = df.style
# Bold the best value per column: highest accuracy, lowest cross-entropy.
styler.highlight_max(subset=[c for c in cols if c.startswith("acc")], props="font-weight: bold")
styler.highlight_min(subset=[c for c in cols if not c.startswith("acc")], props="font-weight: bold")
# NOTE(review): this can never match, because "var context" has already been
# rewritten to a LaTeX macro above. Left unchanged so the rendered table stays
# the same; compare against the macro instead if the label should be bold.
styler.applymap(lambda x: "font-weight: bold" if x == "var context" else "", subset="desc")
styler.format(precision=3)
styler.hide(axis="index")

ipd.display(styler)
print(styler.to_latex(siunitx=True, convert_css=True, hrules=True, column_format="lS[table-format=1.3]S[table-format=1.3]S[table-format=1.3]S[table-format=1.3]"))

desc,xent@16,xent@100,acc 1/16,acc 10/100
\textsc{FullCtx},0.869,0.909,0.198,0.293
\textsc{FullCtx}+\textsc{diag},0.99,0.998,0.102,0.157
\textsc{VarCtx},0.847,0.894,0.221,0.323
\textsc{VarCtx}+\textsc{diag},0.932,0.954,0.136,0.204
\textsc{NoAtt},0.919,0.945,0.157,0.231
\textsc{NoEnc},0.928,0.95,0.148,0.217


\begin{tabular}{lS[table-format=1.3]S[table-format=1.3]S[table-format=1.3]S[table-format=1.3]}
\toprule
{desc} & {xent@16} & {xent@100} & {acc 1/16} & {acc 10/100} \\
\midrule
\textsc{FullCtx} & 0.869 & 0.909 & 0.198 & 0.293 \\
\textsc{FullCtx}+\textsc{diag} & 0.990 & 0.998 & 0.102 & 0.157 \\
\textsc{VarCtx} & \bfseries 0.847 & \bfseries 0.894 & \bfseries 0.221 & \bfseries 0.323 \\
\textsc{VarCtx}+\textsc{diag} & 0.932 & 0.954 & 0.136 & 0.204 \\
\textsc{NoAtt} & 0.919 & 0.945 & 0.157 & 0.231 \\
\textsc{NoEnc} & 0.928 & 0.950 & 0.148 & 0.217 \\
\bottomrule
\end{tabular}



In [33]:
# Smallest absolute pairwise Pearson correlation among the `cols` metrics,
# computed over all rows of `results`.
results[cols].corr(method="pearson").abs().to_numpy().min()

0.8760016421711023

In [41]:
# Individuals table, indexed by its `id` column.
individuals_df = (
    pq.read_table("../data/movebank+ungulates/individuals.parquet")
    .to_pandas()
    .set_index("id")
)

# Distinct (seq_id, individual_id) pairs found in the 12h location data.
seq_individual_pairs = pq.read_table(
    "../data/movebank+ungulates/locations_12h/", columns=["seq_id", "individual_id"]
).to_pandas().drop_duplicates()

# Attach the individual's columns to each pair and index the result by sequence id.
seq_to_individual = (
    seq_individual_pairs
    .join(individuals_df, on="individual_id")
    .set_index("seq_id")
)

In [42]:
# Taxonomy columns used as sort keys for the per-taxon table; all but the last
# one also become its index levels.
TAXON_COLS = ["taxon_class", "taxon_order", "taxon_species"]

In [44]:
with notebook_utils.chdir(".."):
    dmodule.setup(stage="fit")
    dataset_train_raw = dmodule.train_dataloader().dataset
    # Bypass preprocessing (identity function) and disable the candidate count.
    dataset_train_raw.preprocess_fn = lambda x: x
    dataset_train_raw.num_candidates = None

    # One individual id per sequence id encountered in the training trajectories.
    train_individual_ids = (
        seq_to_individual.loc[sid]["individual_id"]
        for traj in tqdm(dataset_train_raw)
        for sid in traj["seq_id"]
    )
    # Number of such occurrences per individual.
    indiv_counts = collections.Counter(train_individual_ids)

  0%|          | 0/3333 [00:00<?, ?it/s]

In [45]:
# Load the taxon statistics table.
# NOTE: unpickling can execute arbitrary code -- only load this file from a trusted source.
taxon_stats = pd.read_pickle("../data/movebank+ungulates/taxon_stats.pickle")

In [46]:
# Per-sequence metrics of the "var context" run, joined with the taxonomy of
# each sequence's individual and sorted by taxon.
var_ctx_by_sequence = (
    results
    .query("desc == 'var context'")
    .join(seq_to_individual, on="seq_id")
    .sort_values(by=TAXON_COLS)
)

# Mean xent@16 per order; the class/order labels are carried along as columns.
index_cols = TAXON_COLS[:-1]
aggregations = {"xent@16": "mean"}
aggregations.update({col: "first" for col in index_cols})
xent_by_order = (
    var_ctx_by_sequence
    .groupby("taxon_order", sort=False)
    .agg(aggregations)
    .reset_index(drop=True)
)

# ("#obs", "train") column of the taxon statistics, summed per order.
train_obs_by_order = (
    taxon_stats.groupby("order").sum()[[("#obs", "train")]]
    .set_axis(["#train"], axis="columns")
)

df = xent_by_order.join(train_obs_by_order, on="taxon_order").set_index(index_cols)
df.index.names = [name.replace("taxon_", "") for name in df.index.names]

# Correlation between the per-order error and the per-order training count.
ipd.display(df[["xent@16", "#train"]].corr(method="pearson"))

ipd.display(df)
latex_table = df.style.format(precision=3).to_latex(
    siunitx=True, convert_css=True, hrules=True, multirow_align="t",
    column_format="llS[table-format=1.3]S[table-format=6]"
)
# Render missing values as dashes in the LaTeX output.
print(latex_table.replace(" nan ", " --- "))

Unnamed: 0,xent@16,#train
xent@16,1.0,-0.714386
#train,-0.714386,1.0


Unnamed: 0_level_0,Unnamed: 1_level_0,xent@16,#train
class,order,Unnamed: 2_level_1,Unnamed: 3_level_1
Aves,Accipitriformes,0.814555,58994
Aves,Anseriformes,0.826731,64008
Aves,Cathartiformes,1.057042,8653
Aves,Charadriiformes,0.815012,205602
Aves,Ciconiiformes,0.696822,237304
Mammalia,Artiodactyla,0.927873,201464
Mammalia,Carnivora,0.986052,12282
Mammalia,Proboscidea,0.980351,24870
Reptilia,Testudines,0.997962,34577


\begin{tabular}{llS[table-format=1.3]S[table-format=6]}
\toprule
{} & {} & {xent@16} & {#train} \\
{class} & {order} & {} & {} \\
\midrule
\multirow[t]{5}{*}{Aves} & Accipitriformes & 0.815 & 58994 \\
 & Anseriformes & 0.827 & 64008 \\
 & Cathartiformes & 1.057 & 8653 \\
 & Charadriiformes & 0.815 & 205602 \\
 & Ciconiiformes & 0.697 & 237304 \\
\multirow[t]{3}{*}{Mammalia} & Artiodactyla & 0.928 & 201464 \\
 & Carnivora & 0.986 & 12282 \\
 & Proboscidea & 0.980 & 24870 \\
Reptilia & Testudines & 0.998 & 34577 \\
\bottomrule
\end{tabular}

