In [1]:
import json
import pathlib

import matplotlib
from matplotlib import pyplot as plt
import ml_collections
import numpy as np
import torch

import hier
import infer
import main
import metrics
import progmet
import tree_util

In [2]:
# Hierarchy definition (relative to the repository root).
hierarchy_file = 'resources/hierarchy/inat21.csv'

# Experiment whose trained model is analysed in this notebook. Both the config and the
# checkpoint live under the same directory, so derive them from a single path.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
experiment_dir = pathlib.Path(
    '/mnt/ssd1/projects/2022-01-hierarchical/experiments/2022-03-31-inat21mini/'
    'share_random_cut-0.1-lr-0.01-b-64-wd-0.0003-ep-20')
config_file = experiment_dir / 'config.json'
model_file = experiment_dir / 'checkpoints' / 'epoch-0020.pth'

with open(config_file, 'r') as f:
    config = ml_collections.ConfigDict(json.load(f))

# `tree` is the hierarchy built from the edge list; `node_keys` is returned alongside it.
with open(hierarchy_file, 'r') as f:
    tree, node_keys = hier.make_hierarchy_from_edges(hier.load_edges(f))

In [3]:
num_outputs = main.get_num_outputs(config.predict, tree)
net = main.make_model(config.model, num_outputs)
# Map the checkpoint tensors onto the CPU: `net` is presumably still on the CPU here
# (it is moved to `device` in a later cell), and this avoids requiring the device the
# checkpoint was saved from. The loaded dict is passed inline so no extra reference to
# it is kept alive. With strict=True, any key mismatch raises, so both returned lists
# are expected to be empty.
missing_keys, unexpected_keys = net.load_state_dict(
    torch.load(model_file, map_location='cpu'), strict=True)

In [4]:
# Model parameters/buffers absent from the checkpoint. The state dict was loaded with
# strict=True, which raises on any mismatch, so this is expected to be [].
missing_keys

[]

In [5]:
# Checkpoint entries that do not correspond to any model parameter/buffer. With
# strict=True a mismatch would have raised during loading, so this is expected to be [].
unexpected_keys

[]

In [6]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)
# Loss and prediction functions for the configured output head, built on `device`.
loss_fn, pred_fn = main.make_loss(config, tree, device)

In [7]:
# Build the train/eval datasets described by `config`.
# The unused 4th return value is bound to `_unused` rather than `_`: assigning `_` in an
# IPython namespace stops IPython from updating its last-output variables (`_`, `__`, `___`).
# NOTE(review): this rebinds `tree`, replacing the hierarchy loaded from `hierarchy_file`
# that was used to size the model and build `loss_fn`/`pred_fn`. Confirm both describe the
# same hierarchy, otherwise node indices may not line up.
(train_dataset, eval_dataset, tree, _unused,
 train_label_map, eval_label_map) = main.make_datasets(config)

In [8]:
# Evaluation loader: fixed order (no shuffling) and data loaded in the main process.
EVAL_BATCH_SIZE = 256
eval_loader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=0,  # no worker processes
)

In [9]:
# Per-node properties of the hierarchy, consumed by the inference rules in `infer_fns`.
is_leaf = tree.leaf_mask()
# Negated leaf count, so a node covering fewer leaves gets a larger value.
specificity = -tree.num_leaf_descendants()
# False exactly for nodes that have a single child.
not_trivial = tree.num_children() != 1

In [10]:
def _infer_leaf(prob):
    """Leaf-only prediction: `infer.argmax_where` over `prob` restricted by `is_leaf`."""
    return infer.argmax_where(prob, is_leaf)


def _infer_majority(prob):
    """Confidence-thresholded prediction via `infer.argmax_with_confidence`.

    Uses `specificity` as the score, 0.5 as the threshold and `not_trivial` as the mask
    (presumably: the most specific non-trivial node whose probability exceeds 0.5).
    """
    return infer.argmax_with_confidence(specificity, prob, 0.5, not_trivial)


# Name of each inference rule -> function mapping predicted probabilities to predictions.
infer_fns = {
    'leaf': _infer_leaf,
    'majority': _infer_majority,
}

In [14]:
# Run the model on the first evaluation batch only, for interactive inspection.
# `next(iter(...))` replaces a `for ... break` loop. An empty loader now raises
# StopIteration instead of silently leaving a stale `theta` from an earlier run.
net.eval()
with torch.inference_mode():
    minibatch = next(iter(eval_loader))
    inputs, gt_labels = minibatch
    theta = net(inputs.to(device))  # raw network outputs for the batch
    # Predictions from `pred_fn`, presumably per-node probabilities, as a NumPy array.
    prob = pred_fn(theta).cpu().numpy()

# TODO: per-batch evaluation is not implemented yet. Planned steps (previously left
# here as commented-out code):
#   - apply each rule in `infer_fns` to `prob`;
#   - map ground-truth labels to nodes with `eval_label_map.to_node[gt_labels]`;
#   - per example, compute `infer.pareto_optimal_predictions(specificity, p, 0., not_trivial)`
#     and the probabilities `prob[i, pred_i]` along each resulting sequence.
#   Caution: those predictions are *not* truncated.

In [21]:
import hier_torch

In [26]:
# Print tensors in fixed-point rather than scientific notation (affects display only).
torch.set_printoptions(sci_mode=False)

In [31]:
# Accumulate `theta` along the hierarchy with `hier_torch.SumAncestors` (root excluded;
# presumably each node's score plus those of its ancestors).
sum_ancestors = hier_torch.SumAncestors(tree, exclude_root=True).to(device)
cumsum_theta = sum_ancestors(theta)
cumsum_theta

tensor([[  0.0000,   8.6953,  -0.6484,  ..., -11.5918, -12.1372, -12.6753],
        [  0.0000,   6.6484,  -1.1562,  ...,  -8.7016,  -8.5163,  -8.3366],
        [  0.0000,  -0.0764,  -7.9904,  ...,  -9.4089, -10.1082, -10.7209],
        ...,
        [  0.0000,   2.5156,  -6.3516,  ...,  -9.4768, -10.1394, -10.7942],
        [  0.0000,   6.3203,  -1.6562,  ..., -11.7842, -12.4482, -13.0137],
        [  0.0000,   4.8398,  -3.8711,  ...,  -9.2546,  -9.4617,  -9.6505]],
       device='cuda:0')

In [33]:
# For each row of the batch, the largest accumulated score among the columns selected by
# `tree.leaf_subset()` (presumably the leaf nodes), with its index.
leaf_scores = cumsum_theta[:, tree.leaf_subset()]
leaf_scores.max(dim=-1)

torch.return_types.max(
values=tensor([11.5344,  9.4092,  3.9820, 14.9359,  3.2108, 11.4448,  6.3154, 10.5169,
         6.5687, 11.0928,  8.6688, 23.4878, 15.7400, 11.1978,  5.2783,  7.5959,
        13.3012, 13.4761, 19.1018, 10.8657,  7.7310, 11.2949,  0.6331,  5.9296,
        11.7687,  4.2833, 35.6245, 14.0425,  1.8423, 16.8071, 11.6873, 14.0469,
        18.6725,  6.4614,  9.8580,  3.1846,  6.3537, 13.8838, 11.7275,  9.2106,
        12.6653,  3.7393, 14.5012, 15.0444,  9.9788, 13.6125,  7.0927, 11.2204,
        15.0648, 27.5705, 10.7206,  2.1573,  0.9351, 11.4608,  7.7632,  7.2695,
         5.4991,  8.7581, 20.7783,  9.9878,  1.5166,  9.7607,  4.1229, 12.8275,
        12.7468, 11.0901,  6.9628,  8.8615, 12.2487,  4.0459,  1.2185, 13.0119,
         9.9775,  2.4891, 14.7791, 16.5249, 10.6653,  5.1162,  9.3907,  9.4849,
         4.6335, 15.8599, 11.5303,  5.2402,  2.8973, 14.4236, 18.4309,  6.4275,
        18.1780,  9.7974,  6.2887,  1.2339,  0.4567, 12.3842,  1.6608,  8.0033,
         

In [29]:
# Constant shift applied to the accumulated scores before the sigmoid.
# NOTE(review): the intended meaning of this offset is not documented — confirm.
logit_offset = 5
# Reuse `cumsum_theta` from the cell above instead of rebuilding SumAncestors and
# recomputing the same tensor.
torch.sigmoid(logit_offset + cumsum_theta)

tensor([[    0.9933,     1.0000,     0.9873,  ...,     0.0014,     0.0008,
             0.0005],
        [    0.9933,     1.0000,     0.9790,  ...,     0.0241,     0.0289,
             0.0343],
        [    0.9933,     0.9928,     0.0479,  ...,     0.0120,     0.0060,
             0.0033],
        ...,
        [    0.9933,     0.9995,     0.2056,  ...,     0.0112,     0.0058,
             0.0030],
        [    0.9933,     1.0000,     0.9659,  ...,     0.0011,     0.0006,
             0.0003],
        [    0.9933,     0.9999,     0.7556,  ...,     0.0140,     0.0114,
             0.0095]], device='cuda:0')