# Toy Example of AttentionXML Model

## Requirements
This notebook uses the following non-standard python packages:
* pytorch
* treelib

In [1]:
import os
import treelib
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from itertools import chain
from dataclasses import dataclass
from typing import Tuple, Callable, Iterator, Any

In [2]:
# add base directory to path
if '../' not in os.sys.path:
    os.sys.path.insert(0, '../')
# import extreme multi label stuff
from xmlc.dataset import XMLDataset
from xmlc.modules import MLP, Attention, PLTHierarchy
from xmlc.metrics import precision
from xmlc.tree_utils import index_tree, convert_labels_to_ids

## Build an example tree

In [3]:
tree = treelib.Tree()
# add root node
root = tree.create_node("Root", "Root")
# layer 1
nodeA = tree.create_node("A", "A", parent=root)
nodeB = tree.create_node("B", "B", parent=root)
nodeC = tree.create_node("C", "C", parent=root)
# leaves of A
_ = tree.create_node("A1", "A1", parent=nodeA)
_ = tree.create_node("A2", "A2", parent=nodeA)
_ = tree.create_node("A3", "A3", parent=nodeA)
# leaves of B
_ = tree.create_node("B1", "B1", parent=nodeB)
_ = tree.create_node("B2", "B2", parent=nodeB)
_ = tree.create_node("B3", "B3", parent=nodeB)
_ = tree.create_node("B4", "B4", parent=nodeB)
# leaves of C
_ = tree.create_node("C1", "C1", parent=nodeC)
_ = tree.create_node("C2", "C2", parent=nodeC)

In [4]:
# show the tree
tree.show()

Root
├── A
│   ├── A1
│   ├── A2
│   └── A3
├── B
│   ├── B1
│   ├── B2
│   ├── B3
│   └── B4
└── C
    ├── C1
    └── C2



In [5]:
print("Depth:      ", tree.depth())
print("Totel nodes:", len(tree.all_nodes()))
print("Inner nodes:", len(tree.all_nodes()) - len(tree.leaves()))

Depth:       2
Totel nodes: 13
Inner nodes: 4


In [6]:
tree = index_tree(tree)
tree.show(data_property='level_index', idhidden=False)

0[Root]
├── 0[A]
│   ├── 0[A1]
│   ├── 1[A2]
│   └── 2[A3]
├── 1[B]
│   ├── 3[B1]
│   ├── 4[B2]
│   ├── 5[B3]
│   └── 6[B4]
└── 2[C]
    ├── 7[C1]
    └── 8[C2]



## Build some toy dataset

In [7]:
# sample target labels
target_labels = [
    ["A1", "B1"],
    ["C2"],
    ["A2", "A3", "B1", "C1"],
    ["A3", "B4", "C1"],
    ["A2", "B2", "B3"],
    ["B4"]
]
# sample input
n_examples = len(target_labels)
x = torch.rand((n_examples, 8, 64))
# and a sample input mask
input_mask = torch.ones(x.size()[:2], dtype=bool)

In [8]:
# build train and test datasets
data = XMLDataset(
    input_dataset=TensorDataset(x, input_mask),
    labels=convert_labels_to_ids(tree, target_labels),
    num_candidates=6
)
loader = DataLoader(data, batch_size=2)

## Model

In [9]:
model = PLTHierarchy(
    hidden_size=64,
    num_labels=9,
    attention=Attention(),
    classifier=MLP(64, 32, 1)
)
optim = torch.optim.Adam(model.parameters())

In [10]:
n_trainable_params = sum((p.numel() for p in model.parameters() if p.requires_grad))
print("#Trainable Parameters: %i" % n_trainable_params)

#Trainable Parameters: 2689


In [12]:
# train the first hierarchy
for _ in range(100):
    running_loss = 0
    for x, mask, candidates, labels in loader:
        # predict and compute loss
        logits = model.forward(x, mask, candidates)
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        # optimize
        optim.zero_grad()
        loss.backward()
        optim.step()
        
        running_loss += loss.item()
    
    print(running_loss)

2.088257133960724
2.046960771083832
2.1164071559906006
2.0210236310958862
2.0359984636306763
2.0327056646347046
1.9874212145805359
1.9829522967338562
1.9522855877876282
2.045971095561981
2.058537006378174
2.089433431625366
2.106914222240448
2.0172876119613647
2.005183517932892
2.013204276561737
2.0451966524124146
1.9776143431663513
2.0099229216575623
1.9933056235313416
1.9899924397468567
2.004766583442688
1.9474824666976929
1.9656883478164673
2.012056529521942
1.9365354180335999
1.9883379936218262
1.9481623768806458
1.91921466588974
1.9738876223564148
1.9648289680480957
1.9492599368095398
1.984373152256012
1.9628809094429016
1.9049542546272278
2.010159492492676
1.9502894878387451
1.9784687757492065
1.9289073944091797
1.935418963432312
1.986329197883606
1.9656036496162415
1.9228374361991882
1.9117478132247925
1.9541298151016235
1.9317827820777893
1.9424652457237244
1.889069139957428
1.8679074048995972
1.939570665359497
1.8550596833229065
1.895934522151947
1.954360008239746
1.88725483417