In [1]:
import torch

from openmax.weibull_fitting import WeibullFitting
from openmax.openmax import OpenMax

In [2]:
# --- toy data parameters ---
num_class = 10
emb_dim = 100
samples_per_class = [100, 200, 300, 400, 500, 100, 200, 300, 400, 500]
class_names = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
num_samples_test = 20
# fixed RNG seed so that "Restart & Run All" reproduces the printed labels
seed = 0

# --- parameters of the OpenMax process ---
# passed to WeibullFitting; presumably the number of largest train distances per
# class used for the Weibull fit — confirm against openmax.weibull_fitting
tailsize = 20
# distance type shared by WeibullFitting and OpenMax; presumably Euclidean
d_type = 'eucl'
# passed to OpenMax as `alpha`; presumably the OpenMax "alpha rank" (how many
# top-scoring classes get recalibrated) — confirm against openmax.openmax
alpha = 3
# rejection threshold on the recalibrated probabilities; in a real setting it is
# selected on the train data (see the OpenMax paper, Bendale & Boult, CVPR 2016)
threshold = 0.3

# every per-class list must have exactly one entry per known class
assert len(samples_per_class) == len(class_names) == num_class

# --- generate toy data ---
torch.manual_seed(seed)
# train_embs[k] has shape (samples_per_class[k], emb_dim)
train_embs = [torch.randn(n_samples, emb_dim) for n_samples in samples_per_class]
# test_logits: (num_samples_test, num_class); test_embs: (num_samples_test, emb_dim)
test_logits = torch.randn(num_samples_test, num_class)
test_embs = torch.randn(num_samples_test, emb_dim)

In [3]:
# Fit the Weibull models on the train embeddings: first obtain the class
# centroids and the train distances to them (measured with `d_type`), then fit
# the Weibull models on those distances.
wf = WeibullFitting(tailsize, num_class)
distances, centroids = wf.compute_centroids_and_distances(
    train_embs,
    distance_type=d_type,
)
weibull_models = wf.fit_all_models(distances)

In [4]:
# OpenMax step: recalibrate the toy test logits with the centroids and Weibull
# models fitted on the train data, then convert the recalibrated logits into
# SoftMax scores (the unknown class ends up in column 0).
om = OpenMax(
    centroids,
    weibull_models,
    alpha=alpha,
    distance_type=d_type,
)
logits_hat = om.recalibrate_logits(test_logits, embs=test_embs)
preds = om.compute_probs(logits_hat)

In [5]:
# One row per test sample and num_class + 1 columns: the probability of the
# unknown class is in column 0, known classes follow in columns 1..num_class.
# The label-assignment cell relies on this layout, so check it explicitly.
assert preds.shape == (num_samples_test, num_class + 1), preds.shape
preds.shape

torch.Size([20, 11])

In [6]:
# A sample is labelled "unknown" if its highest recalibrated score does not
# exceed the threshold, or if that highest score is in column 0 (the unknown
# class). Known classes live in columns 1..num_class, hence the `- 1` when
# mapping back into `class_names`.
max_probs, argmax_ids = preds.max(dim=-1)
test_label_ids = torch.where(max_probs > threshold, argmax_ids, torch.zeros_like(argmax_ids))
test_labels = ["unk" if idx == 0 else class_names[idx - 1] for idx in test_label_ids.tolist()]
# labels before recalibration: plain argmax over the original num_class logits
test_labels_original = [class_names[idx] for idx in test_logits.argmax(dim=-1).tolist()]

In [7]:
# labels after OpenMax recalibration ("unk" = rejected as unknown)
print(f"{test_labels}")

['h', 'unk', 'unk', 'g', 'unk', 'unk', 'f', 'i', 'j', 'd', 'j', 'unk', 'unk', 'unk', 'unk', 'unk', 'unk', 'unk', 'unk', 'b']


In [8]:
# labels from the raw logits, before recalibration, for comparison
print(f"{test_labels_original}")

['h', 'c', 'g', 'g', 'g', 'g', 'f', 'i', 'j', 'd', 'j', 'j', 'e', 'j', 'h', 'f', 'b', 'h', 'h', 'b']
