# Toy Example of AttentionXML Model

## Requirements
This notebook requires the following third-party Python packages:
* numpy
* pytorch
* transformers
* treelib
* spacy
* matplotlib
* tqdm

In [1]:
import os
import spacy
import treelib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from transformers import Trainer, TrainingArguments
from itertools import chain
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

In [2]:
# add base directory to path
if '../' not in os.sys.path:
    os.sys.path.insert(0, '../')
# import extreme multi label stuff
from xmlc.dataset import LabelTreePureRandomDataset, LabelTreeGroupBasedDataset
from xmlc.modules import (
    MLP, 
    Attention, 
    MultiHeadAttention, 
    LabelAttentionClassifier,
    ProbabilityLabelTree
)
from xmlc.metrics import (
    precision, 
    coverage, 
    hits
)
from xmlc.tree_utils import (
    index_tree, 
    yield_tree_levels
)

## Paths and Hyperparameters

In [3]:
# Locations used throughout the notebook. The first two are absolute,
# machine-specific paths and must be adapted when running elsewhere.
# directory holding train_texts.txt / train_labels.txt / test_texts.txt / test_labels.txt
data_path = "/data/share/gsg_consulting/AttentionXML/data/ops"
# directory holding the fastText vocabulary (vocab.npy) and vectors (vectors.npy)
fasttext_path = "/data/share/gsg_consulting/AttentionXML/models/gsg-fasttext"
# directory passed to the Trainer and used for the final weights and the metrics figure
output_dir = "../output/ops-fasttext-simple"

In [4]:
# data hyperparameters
# sequence length used when building input features (see the feature-building cell)
max_length = 512
# number of candidates handed to both the training and the evaluation dataset
num_candidates = 128
# model hyperparameters
# LSTM hidden size per direction; the encoder is bidirectional, so the
# classifier heads are built with 2*hidden_size
hidden_size=128
num_lstm_layers=1
# dropout probability used by the encoder
dropout=0.5
# training hyperparameters
lr = 1e-3
num_epochs = 8

In [5]:
# prefer the GPU whenever torch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'
print("Using device %s!" % device)

Using device cuda!


## Load raw data

In [6]:
def load_texts(fpath:str, encoding:str ="utf-8"):
    """ Read a text file and return its lines.

    Args:
        fpath: path of the file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8 so
            that non-ASCII characters (e.g. German umlauts) are decoded the
            same way on every machine instead of depending on the platform's
            locale default.

    Returns:
        The lines of the file in order. Every line keeps its trailing
        newline, exactly as returned by `readlines`.
    """
    with open(fpath, "r", encoding=encoding) as f:
        return f.readlines()

def load_labels(fpath:str):
    """ Read a label file with one document per line.

    Each line is split on whitespace, so the result holds one list of label
    strings per line of the file (an empty list for a blank line).
    """
    label_rows = []
    for line in load_texts(fpath):
        label_rows.append(line.split())
    return label_rows

# training split: texts and label rows are stored in parallel files,
# so both must contain the same number of lines
train_texts, train_labels = (
    load_texts(os.path.join(data_path, "train_texts.txt")),
    load_labels(os.path.join(data_path, "train_labels.txt")),
)
assert len(train_labels) == len(train_texts)
# test split: same file layout as the training split
test_texts, test_labels = (
    load_texts(os.path.join(data_path, "test_texts.txt")),
    load_labels(os.path.join(data_path, "test_labels.txt")),
)
assert len(test_labels) == len(test_texts)

In [7]:
# get a list of all unique labels
unique_labels = np.unique(tuple(chain(*train_labels)))
print("# unique labels:", len(unique_labels))

# unique labels: 2703


## Build simple label tree

In [8]:
# flat label tree: a single root with every label attached directly below it
tree = treelib.Tree()
root = tree.create_node(tag="Root", identifier="Root")
for label in unique_labels:
    tree.create_node(tag=label, identifier=label, parent=root)
# alternative (disabled): one extra level that splits the labels by
# code-type (OPS-3 vs OPS-5) using the first two characters of the label
# code_types = {
#     '3-': tree.create_node("OPS-3", "OPS-3", parent=root),
#     '5-': tree.create_node("OPS-5", "OPS-5", parent=root)
# }
# # add labels to corresponding group
# for label in unique_labels:
#     node = code_types[label[:2]]
#     tree.create_node(label, label, parent=node)

In [9]:
# summarise the tree: depth, node count and number of non-leaf nodes
n_nodes = len(tree.all_nodes())
n_leaves = len(tree.leaves())
print("Depth:      ", tree.depth())
print("Totel nodes:", n_nodes)
print("Inner nodes:", n_nodes - n_leaves)

Depth:       1
Totel nodes: 2704
Inner nodes: 1


In [10]:
# index the tree nodes
# NOTE(review): `index_tree` comes from xmlc.tree_utils and its behavior is not
# visible in this notebook. Its result replaces `tree`, so re-running this cell
# feeds an already-indexed tree back in — confirm that this is safe.
tree = index_tree(tree)

## Build Training and Evaluation Dataset

In [11]:
# load vocabulary
vocab = np.load(os.path.join(fasttext_path, "vocab.npy"))
embed = np.load(os.path.join(fasttext_path, "vectors.npy"))
# change special tokens
vocab[vocab == "<SEP>"] = "[SEP]"
vocab[vocab == "<PAD>"] = "[PAD]"
vocab[vocab == "<UNK>"] = "[UNK]"
# build token-id map
vocab_map = {token.lower(): i for i, token in enumerate(vocab)}

In [12]:
# get german tokenizer
from spacy.lang.de import German
# build tokenizer parameters
prefixes = German.Defaults.prefixes
suffixes = German.Defaults.suffixes
infixes = German.Defaults.infixes
prefix_search = spacy.util.compile_prefix_regex(prefixes).search if prefixes else None
suffix_search = spacy.util.compile_suffix_regex(suffixes).search if suffixes else None
infix_finditer = spacy.util.compile_infix_regex(infixes).finditer if infixes else None
# add tokenizer exception for special tokens
exc = German.Defaults.tokenizer_exceptions
exc = spacy.util.update_exc(exc, {
    '[SEP]': [{spacy.symbols.ORTH: "[SEP]"}]
})
# create tokenizer
tokenizer = spacy.tokenizer.Tokenizer(
    vocab=spacy.vocab.Vocab(strings=vocab.tolist()),
    rules=exc,
    prefix_search=prefix_search,
    suffix_search=suffix_search,
    infix_finditer=infix_finditer,
    token_match=German.Defaults.token_match,
    url_match=German.Defaults.url_match
)

In [13]:
def build_input_features(texts, max_length=256):
    """ Tokenize texts and turn them into fixed-length id and mask tensors.

    Every text is tokenized with the module-level `tokenizer`, cut to the
    first `max_length` tokens, mapped through `vocab_map` (lower-cased lookup,
    unknown tokens become the '[unk]' id) and right-padded with the '[pad]'
    id up to `max_length`.

    Args:
        texts: iterable of raw text strings.
        max_length: number of token ids kept per text.

    Returns:
        (input_ids, input_mask): a LongTensor with one row of ids per text and
        a boolean tensor of the same shape that is True wherever the id
        differs from the padding id.
    """
    unk_token_id = vocab_map['[unk]']
    pad_token_id = vocab_map['[pad]']

    def encode(text):
        # truncate first, then look every token up in its lower-cased form
        tokens = tokenizer(text)[:max_length]
        row = [vocab_map.get(str(tok).lower(), unk_token_id) for tok in tokens]
        # fill the remainder of the row with padding
        row.extend([pad_token_id] * max(max_length - len(row), 0))
        return row

    rows = [encode(text) for text in tqdm(texts, "Tokenizing")]
    # convert to tensors
    input_ids = torch.LongTensor(rows)
    input_mask = input_ids != pad_token_id
    return input_ids, input_mask

In [14]:
# The tokenized features are cached in "data.bin" so the tokenization pass
# does not have to be repeated on every run. When the cache is missing it is
# built from the raw texts and written to disk, which keeps the notebook
# runnable from a fresh kernel.
# NOTE: the cache is not invalidated when `max_length` or the vocabulary
# change — delete the file to force a rebuild.
# NOTE(security): torch.load unpickles the file; only load a cache you created.
features_cache = "data.bin"
if os.path.exists(features_cache):
    data = torch.load(features_cache)
else:
    train_input_ids, train_input_mask = build_input_features(train_texts, max_length=max_length)
    test_input_ids, test_input_mask = build_input_features(test_texts, max_length=max_length)
    data = {
        "train-input-ids": train_input_ids,
        "train-input-mask": train_input_mask,
        "test-input-ids": test_input_ids,
        "test-input-mask": test_input_mask,
    }
    torch.save(data, features_cache)
train_input_ids = data["train-input-ids"]
train_input_mask = data["train-input-mask"]
test_input_ids = data["test-input-ids"]
test_input_mask = data["test-input-mask"]

In [15]:
# build train and test datasets
train_data = LabelTreeGroupBasedDataset(
    input_dataset=TensorDataset(train_input_ids, train_input_mask),
    tree=tree,
    labels=train_labels,
    num_candidates=num_candidates,
)
# observe that the task is much harder when using a Group-based dataset here
# as the negative candidates are more similar to the positives
# this seems the most fair as it mirrors the setup when no label tree is utilized
eval_data = LabelTreePureRandomDataset(
    input_dataset=TensorDataset(test_input_ids, test_input_mask),
    tree=tree,
    labels=test_labels,
    num_candidates=num_candidates,
)

## Model

In [16]:
class LSTMEncoder(nn.Module):
    """ Bidirectional LSTM encoder over token embeddings.

    The initial hidden and cell states are learned parameters that are
    shared by all sequences of a batch.

    Args:
        embed_size: dimensionality of the token embeddings.
        hidden_size: LSTM hidden size per direction; the output feature
            size is 2*hidden_size.
        num_layers: number of stacked LSTM layers.
        vocab_size: number of rows of the embedding table.
        padding_idx: id of the padding token in the embedding table.
        emb_init: optional (vocab_size, embed_size) tensor used as the
            initial embedding weights.
        dropout: dropout probability applied to the embeddings and to the
            encoder output; between LSTM layers it is only used when
            num_layers > 1.
    """
    
    def __init__(self, 
        embed_size:int,
        hidden_size:int, 
        num_layers:int,
        vocab_size:int,
        padding_idx:int,
        emb_init:torch.FloatTensor =None,
        dropout:float =0.2
    ) -> None:
        super(LSTMEncoder, self).__init__()
        self.dropout = dropout
        # create embedding (randomly initialised when emb_init is None)
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_size,
            padding_idx=padding_idx,
            _weight=emb_init
        )
        # create lstm encoder
        # nn.LSTM applies its dropout only between stacked layers; passing a
        # non-zero value for a single layer has no effect and only triggers
        # a warning, so it is disabled in that case
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True
        )
        # initial hidden and cell states for lstm (one per layer and direction)
        self.h0 = nn.Parameter(torch.zeros(num_layers*2, 1, hidden_size))
        self.c0 = nn.Parameter(torch.zeros(num_layers*2, 1, hidden_size))
                
    def forward(self, 
        input_ids:torch.LongTensor, 
        input_mask:torch.BoolTensor
    ) -> torch.Tensor:
        """ Encode a batch of padded token-id sequences.

        Args:
            input_ids: (batch, seq) token ids.
            input_mask: (batch, seq) mask that is True for real tokens. The
                number of True entries per row is used as the sequence
                length, i.e. padding is assumed to sit at the end of a row.

        Returns:
            (batch, seq, 2*hidden_size) tensor of encoder states; positions
            beyond a sequence's length are zero.
        """
        # flatten parameters
        self.lstm.flatten_parameters()
        # pass through embedding
        b, s = input_ids.size()
        x = self.embedding(input_ids)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # pack padded sequences
        # pack_padded_sequence rejects zero-length sequences, so a row that
        # consists of padding only is treated as having length one
        lengths = input_mask.sum(dim=-1).clamp(min=1).cpu()
        packed_x = nn.utils.rnn.pack_padded_sequence(
            input=x, 
            lengths=lengths, 
            batch_first=True, 
            enforce_sorted=False
        )
        # apply lstm encoder, broadcasting the learned initial states over the batch
        h0 = self.h0.repeat_interleave(b, dim=1)
        c0 = self.c0.repeat_interleave(b, dim=1)
        packed_x, _ = self.lstm(packed_x, (h0, c0))
        # unpack packed sequences back to the original padded length
        x, _ = nn.utils.rnn.pad_packed_sequence(
            sequence=packed_x, 
            batch_first=True, 
            padding_value=0,
            total_length=s
        )
        return F.dropout(x, p=self.dropout, training=self.training)
    
class Model(nn.Module):
    """ Model Combining LSTM-Encoder and a PLT Classifier

    The model is configured from the notebook-level hyperparameters
    (`hidden_size`, `num_lstm_layers`, `dropout`), the loaded fastText
    vocabulary/embeddings (`vocab`, `vocab_map`, `embed`) and the label `tree`.
    """
    
    def __init__(self):
        super(Model, self).__init__()
        # create encoder
        # the embedding size is read from the loaded vectors instead of being
        # hard-coded, so other embedding files work without editing this class
        self.enc = LSTMEncoder(
            embed_size=embed.shape[1], 
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            vocab_size=vocab.shape[0], 
            padding_idx=vocab_map['[pad]'], 
            emb_init=torch.from_numpy(embed).float(),
            dropout=dropout
        )
        # create hierarchy classifier
        self.plt = ProbabilityLabelTree(
            tree=tree,
            cls_factory=self.classifier_factory
        )
        
    def classifier_factory(self, num_labels) -> nn.Module:
        """ Build one label-attention classifier for `num_labels` labels.

        Passed to `ProbabilityLabelTree` as `cls_factory`. The attention
        dropout follows the notebook-level `dropout` hyperparameter.
        """
        return LabelAttentionClassifier(
            hidden_size=2*hidden_size, 
            num_labels=num_labels,
            attention=Attention(dropout=dropout),
            #   attention=MultiHeadAttention(
            #       embed_dim=2*hidden_size,
            #       num_heads=16,
            #       dropout=dropout
            #   ),
            classifier=MLP(2*hidden_size, 128, 1)
        )
        
    def forward(self, input_ids, input_mask, candidate_paths, labels=None):
        """ Encode the inputs and score the given candidate paths.

        Args:
            input_ids: token ids, forwarded to the encoder.
            input_mask: mask forwarded to the encoder and the label tree.
            candidate_paths: candidates forwarded to the label tree
                (format defined by the xmlc datasets — not visible here).
            labels: optional targets; when given, the binary cross-entropy
                between the predicted probabilities and the targets is added.

        Returns:
            dict with 'logits' (the label-tree output, which is fed to
            binary cross-entropy directly and is therefore expected to hold
            probabilities rather than raw logits) and, when `labels` is
            given, 'loss'.
        """
        # pass through encoder
        x = self.enc(input_ids, input_mask)
        # apply classifier
        probs = self.plt(x, input_mask, candidate_paths=candidate_paths)
        # compute loss if targets given
        if labels is not None:
            loss = F.binary_cross_entropy(probs, labels)
            return {'loss': loss, 'logits': probs}
        # return predictions only
        return {'logits': probs}

In [17]:
# create model and optimizer
# Adam over all model parameters with the learning rate from the
# hyperparameter cell; the optimizer is later handed to the Trainer
model = Model()
optim = torch.optim.Adam(model.parameters(), lr=lr)



In [18]:
# add up the element counts of all parameters that receive gradients
n_trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        n_trainable_params += param.numel()
print("#Trainable Parameters: %i" % n_trainable_params)

#Trainable Parameters: 158576625


## Training
Instead of writing a training loop from scratch, let's just use the transformers `Trainer` class.

In [19]:
def compute_metrics(eval_preds):
    """ Compute precision, coverage and hits at k = 1, 2, 3, 5.

    Args:
        eval_preds: pair (predictions, labels) as handed over by the
            Trainer; predictions hold one score per candidate.

    Returns:
        dict mapping "P@k", "C@k" and "H@k" to the values returned by the
        xmlc `precision`, `coverage` and `hits` metrics.
    """
    # unpack predictions and labels
    preds, labels = eval_preds
    preds = torch.FloatTensor(preds)
    labels = torch.LongTensor(labels)
    # rank the candidates by score; k is capped by the number of candidates
    # because topk raises when asked for more entries than the dimension holds
    top_k = min(100, preds.size(-1))
    _, preds = torch.topk(preds, k=top_k, dim=-1)
    # compute metrics (same key order as before: all P@k, then C@k, then H@k)
    metrics = {}
    for tag, metric_fn in (("P", precision), ("C", coverage), ("H", hits)):
        for k in (1, 2, 3, 5):
            metrics["%s@%i" % (tag, k)] = metric_fn(preds, labels, k=k)
    return metrics

def collate(batch):
    """ Collate samples with torch's default collate and name the fields.

    The collated fields are paired, in order, with the keyword names the
    model's forward expects; surplus names or fields are dropped by `zip`.
    """
    field_names = ('input_ids', 'input_mask', 'candidate_paths', 'labels')
    collated = torch.utils.data._utils.collate.default_collate(batch)
    return {name: field for name, field in zip(field_names, collated)}

In [20]:
# training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=256,
    save_steps=5_000,
    save_total_limit=2,
    report_to="none",
    logging_steps=250,
    eval_steps=250,
    evaluation_strategy='steps'
)
# trainer
trainer = Trainer(
    optimizers=(optim, None),
    model=model,
    args=training_args,
    data_collator=collate,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=compute_metrics
)

In [None]:
# train the model
# runs the Trainer configured in the previous cell; the table below the cell
# shows the training loss, validation loss and metrics per evaluation step
trainer.train()

***** Running training *****
  Num examples = 284607
  Num Epochs = 8
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 1
  Total optimization steps = 8896


Step,Training Loss,Validation Loss,P@1,P@2,P@3,P@5,C@1,C@2,C@3,C@5,H@1,H@2,H@3,H@5
250,0.1333,0.070224,0.509147,0.421454,0.351802,0.254996,0.236003,0.39071,0.489208,0.590987,0.509147,0.550524,0.587795,0.631981
500,0.066,0.059196,0.545735,0.442077,0.366516,0.268976,0.252962,0.409828,0.509669,0.623387,0.545735,0.577463,0.61238,0.666627
750,0.0595,0.054902,0.587788,0.472725,0.389229,0.282785,0.272455,0.438241,0.541254,0.65539,0.587788,0.617497,0.650329,0.70085
1000,0.0552,0.050987,0.615776,0.491661,0.404878,0.294911,0.285428,0.455795,0.563015,0.683494,0.615776,0.642232,0.676476,0.730904
1250,0.0508,0.046049,0.648325,0.518413,0.424614,0.30718,0.300515,0.480596,0.590459,0.711929,0.648325,0.677177,0.70945,0.761311
1500,0.0454,0.039218,0.678831,0.541174,0.444318,0.319696,0.314656,0.501696,0.617858,0.740937,0.678831,0.706908,0.742372,0.792331
1750,0.0401,0.036695,0.725445,0.570753,0.462944,0.329713,0.336263,0.529118,0.64376,0.764151,0.725445,0.745547,0.773494,0.817156
2000,0.0374,0.03358,0.748729,0.589926,0.476978,0.337049,0.347055,0.546892,0.663275,0.781155,0.748729,0.770592,0.796941,0.835339
2250,0.0347,0.031365,0.768686,0.604633,0.487748,0.342637,0.356306,0.560526,0.678252,0.794106,0.768686,0.789802,0.814936,0.849189
2500,0.0321,0.02947,0.795343,0.624329,0.499168,0.348301,0.368662,0.578785,0.694132,0.807233,0.795343,0.81553,0.834017,0.863226


***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 21045
  Batch size = 512
***** Running Evaluation *****
  Num examples = 210

In [None]:
# save final model state dict to disk
# make sure the target directory exists, so the save does not depend on the
# Trainer having created it while writing a checkpoint
os.makedirs(output_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(output_dir, "final-model.bin"))

## Plot Metrics

In [None]:
# select only train and test logs
train_logs = [log for log in trainer.state.log_history if 'loss' in log]
eval_logs = [log for log in trainer.state.log_history if 'eval_loss' in log]
# gather values from logs
train_metrics = {
    key: [log[key] for log in train_logs]
    for key in ['step', 'loss']
}
eval_logs = {
    key: [log[key] for log in eval_logs]
    for key in [
        'step', 
        'eval_loss', 
        'eval_P@1', 
        'eval_P@2', 
        'eval_P@3', 
        'eval_P@5',
        'eval_C@1', 
        'eval_C@2', 
        'eval_C@3', 
        'eval_C@5',
        'eval_H@1', 
        'eval_H@2', 
        'eval_H@3', 
        'eval_H@5',
    ]
}

In [None]:
# four stacked panels sharing the step axis: loss, precision, coverage, hits
fig, (ax_loss, ax_p, ax_c, ax_h) = plt.subplots(4, 1, figsize=(12, 20), sharex=True)
# loss panel: training and evaluation loss over the global step
ax_loss.plot(train_metrics['step'], train_metrics['loss'], label="train")
ax_loss.plot(eval_logs['step'], eval_logs['eval_loss'], label="eval")
ax_loss.set(
    title="Train and Test Loss",
    ylabel="Loss"
)
ax_loss.legend()
ax_loss.grid()
# metric panels: one curve per cut-off k on each axis
metric_panels = (
    (ax_p, 'P', "Precision@k", "Precision"),
    (ax_c, 'C', "Coverage@k", "Coverage"),
    (ax_h, 'H', "Hits@k", "Hits"),
)
for ax, metric, title, ylabel in metric_panels:
    for k in (1, 2, 3, 5):
        ax.plot(
            eval_logs['step'],
            eval_logs['eval_%s@%i' % (metric, k)],
            label="$k=%i$" % k
        )
    ax.set(title=title, ylabel=ylabel)
    ax.legend()
    ax.grid()
# only the bottom panel carries the shared x-axis label
ax_h.set(xlabel="Global Step")
# save figure
fig.savefig(os.path.join(output_dir, "metrics.pdf"))
plt.show()