In [None]:
# One-time setup (uncomment to run). fastxml is installed straight from GitHub
# because the latest version on PyPI isn't up to date.
# !pip install cython
# !pip install git+https://github.com/Refefer/fastxml

In [1]:
import fastxml
import numpy as np
import pickle
from sklearn.metrics import label_ranking_average_precision_score

from utils import load_data

In [2]:
# Load the train/dev splits with the project helper `utils.load_data`.
# The four returned objects are used below as sparse matrices (`.tolil()`,
# `.tocsr()` and `.toarray()` are called on them).
# NOTE(review): "data" is presumably a directory relative to the notebook — confirm.
train_features, train_labels, dev_features, dev_labels = load_data("data")

FastXML expects labels as a list of lists (one list of label indices per sample) and features as a list of CSR matrices (one single-row matrix per sample).

In [3]:
# LIL format keeps, for every row, the list of its non-zero column indices in
# `.rows`; wrapping that in `list` gives one list of label indices per sample.
y_train, y_dev = (
    list(label_matrix.tolil().rows) for label_matrix in (train_labels, dev_labels)
)

In [4]:
# Iterating a CSR matrix yields its rows one at a time, so each split becomes
# a list with one single-row sparse matrix per sample.
X_train = [sample_row for sample_row in train_features.tocsr()]
X_dev = [sample_row for sample_row in dev_features.tocsr()]

In [6]:
# Baseline model: a trainer built with no arguments, i.e. the library's
# default hyperparameters. Tuning happens in the search section further down.
fxml_trainer = fastxml.Trainer()  # default hyperparams for now

In [7]:
# Fit the baseline, write it to disk, then reload it for inference: the
# Inferencer is constructed from the saved model's path.
baseline_model_path = "fastxml_v0"
fxml_trainer.fit(X_train, y_train)
fxml_trainer.save(baseline_model_path)
clf = fastxml.Inferencer(baseline_model_path)



In [8]:
# Score every dev sample with the baseline model. `fmt="sparse"` is requested
# so the result can be densified with `.toarray()` for the sklearn metric.
y_hat = clf.predict(X_dev, fmt="sparse")

Performance on the dev set is poor (label ranking average precision ≈ 0.19):

In [9]:
# Label ranking average precision of the baseline on the dev set. The metric
# is fed dense arrays, so both sparse matrices are densified first.
label_ranking_average_precision_score(
    y_true=dev_labels.toarray(),
    y_score=y_hat.toarray(),
)

0.18526724818456133

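As a sanity check, here is the performance on the training data. It is also low (≈ 0.34), which suggests the default model underfits rather than merely failing to generalize: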

In [10]:
# Same metric on the data the baseline was fitted on, as a sanity check.
y_hat_train = clf.predict(X_train, fmt="sparse")
label_ranking_average_precision_score(
    y_true=train_labels.toarray(),
    y_score=y_hat_train.toarray(),
)

0.33567487165301385

## Hyperparameter Search

The trainer has many parameters, and the inferencer several more, that are worth experimenting with. The grid below searches over a subset of the trainer parameters.

In [90]:
# Search space for the trainer: every key is later passed to
# `fastxml.Trainer(**...)` as a keyword argument, and every list holds the
# candidate values to try for that argument.
params = dict(
    max_leaf_size=[10, 50, 100],
    max_labels_per_leaf=[20, 50],
    alpha=[1e-5, 1e-3, 1e-1],
    n_trees=[1, 50, 100],
    n_jobs=[-1],  # single candidate: held fixed rather than searched
)

In [91]:
from sklearn.model_selection import ParameterGrid

In [92]:
# Cartesian product of the candidate lists in `params`:
# 3 * 2 * 3 * 3 * 1 = 54 settings, each yielded as a dict of keyword arguments.
param_grid = ParameterGrid(params)

In [93]:
# Fitted trainers, keyed by the tuple of hyperparameter values used to build them.
estimators = dict()

In [94]:
# Train one FastXML model per grid point and keep every trainer in memory.
# The dict key is the tuple of that grid point's values; as the printed dicts
# show, the order is
# (alpha, max_labels_per_leaf, max_leaf_size, n_jobs, n_trees).
for grid_point in param_grid:
    print(grid_point)
    trainer = fastxml.Trainer(**grid_point)
    estimators[tuple(grid_point.values())] = trainer
    trainer.fit(X_train, y_train)

{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 10, 'n_jobs': -1, 'n_trees': 1}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 10, 'n_jobs': -1, 'n_trees': 50}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 10, 'n_jobs': -1, 'n_trees': 100}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 50, 'n_jobs': -1, 'n_trees': 1}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 50, 'n_jobs': -1, 'n_trees': 50}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 50, 'n_jobs': -1, 'n_trees': 100}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 100, 'n_jobs': -1, 'n_trees': 1}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 100, 'n_jobs': -1, 'n_trees': 50}
{'alpha': 1e-05, 'max_labels_per_leaf': 20, 'max_leaf_size': 100, 'n_jobs': -1, 'n_trees': 100}
{'alpha': 1e-05, 'max_labels_per_leaf': 50, 'max_leaf_size': 10, 'n_jobs': -1, 'n_trees': 1}
{'alpha': 1e-05, 'max_labels_per_leaf': 50, 'max_leaf_size

In [102]:
# NOTE(review): `estimator_fns` is not referenced anywhere else in the visible
# notebook; it looks like a leftover — confirm before removing.
estimator_fns = {}

In [106]:
# Spot-check a single entry. Keys are tuples of the parameter values in the
# order shown by the printed grid:
# (alpha, max_labels_per_leaf, max_leaf_size, n_jobs, n_trees).
# NOTE(review): the recorded output below does not look like a trainer object;
# it may be stale output from an earlier kernel state — confirm on a fresh run.
estimators[(1e-05, 20, 10, -1, 1)]

10

In [115]:
# `i` numbers the scratch model directories used during evaluation;
# `results` maps each hyperparameter tuple to its dev-set score.
i, results = 0, {}

In [117]:
import shutil

In [118]:
# Evaluate every fitted trainer on the dev set.
#
# An Inferencer is built from a saved model path, so each trainer is written
# to a scratch directory, reloaded, scored, and the directory is removed again.
#
# Changes from the earlier version of this cell:
#   * the directory index comes from `enumerate`, so re-running this cell on
#     its own always starts at 0 instead of continuing a counter that lives in
#     another cell;
#   * the dense dev labels are built once instead of on every iteration;
#   * the scratch directory is removed in `finally`, so a failure while
#     saving, loading or scoring does not leave model files behind.
dev_labels_dense = dev_labels.toarray()

for i, (param_key, trainer) in enumerate(estimators.items()):
    print(param_key)
    model_path = f"./fastxml_models/{i}"
    try:
        trainer.save(model_path)
        inferencer = fastxml.Inferencer(model_path)
        y_hat = inferencer.predict(X_dev, fmt="sparse")
        results[param_key] = label_ranking_average_precision_score(
            dev_labels_dense, y_hat.toarray()
        )
    finally:
        # ignore_errors: if saving failed before the directory was created,
        # the cleanup must not mask the original exception.
        shutil.rmtree(model_path, ignore_errors=True)
    print(results[param_key])

(1e-05, 20, 10, -1, 1)
0.1849028313223925
(1e-05, 20, 10, -1, 50)
0.44220462618546486
(1e-05, 20, 10, -1, 100)
0.452127399708976
(1e-05, 20, 50, -1, 1)
0.2121308776574277
(1e-05, 20, 50, -1, 50)
0.3944658003494442
(1e-05, 20, 50, -1, 100)
0.4032892929318616
(1e-05, 20, 100, -1, 1)
0.2030716676106386
(1e-05, 20, 100, -1, 50)
0.36749581570362067
(1e-05, 20, 100, -1, 100)
0.37212391413378104
(1e-05, 50, 10, -1, 1)
0.18566932808801223
(1e-05, 50, 10, -1, 50)
0.44366768830386755
(1e-05, 50, 10, -1, 100)
0.45387800608261775
(1e-05, 50, 50, -1, 1)
0.21664849870855518
(1e-05, 50, 50, -1, 50)
0.398973500828358
(1e-05, 50, 50, -1, 100)
0.4077735034122924
(1e-05, 50, 100, -1, 1)
0.20992459751469103
(1e-05, 50, 100, -1, 50)
0.3729999619749015
(1e-05, 50, 100, -1, 100)
0.3774440102786261
(0.001, 20, 10, -1, 1)
0.18865454050519156
(0.001, 20, 10, -1, 50)
0.44345989493114635
(0.001, 20, 10, -1, 100)
0.453787867158334
(0.001, 20, 50, -1, 1)
0.21024357827995613
(0.001, 20, 50, -1, 50)
0.394057204806760

In [119]:
# Persist the score table so the grid search does not have to be repeated.
with open("results.pickle", "wb") as results_file:
    pickle.dump(results, results_file)

In [121]:
# Report the winning configuration alongside its score, not just the score.
# Keys are (alpha, max_labels_per_leaf, max_leaf_size, n_jobs, n_trees).
# The score stays the cell's value, so the displayed result is unchanged.
best_params, best_score = max(results.items(), key=lambda item: item[1])
print(best_params)
best_score

0.4552619511425552