# Implementation of *[An embarrassingly simple approach to zero-shot learning](http://proceedings.mlr.press/v37/romera-paredes15.pdf)*

In [1]:
import numpy as np
from scipy.stats import hmean

from sklearn.metrics import accuracy_score

from few_shot_learn.data.awa2 import load_awa2
from few_shot_learn.zero import ESZSLearner

In [2]:
# Load the AwA2 dataset; the next cell reads its 'train', 'val' and 'test' entries,
# each of which unpacks into (features, attribute features, labels).
awa2_dataset = load_awa2()

In [3]:
# Unpack every split into (features, attribute features, labels) in one go.
# NOTE: the 'val' split is unpacked but not used anywhere below.
(
    (X_train, attributes_features_train, labels_train),
    (X_val, attributes_features_val, labels_val),
    (X_test, attributes_features_test, labels_test),
) = (awa2_dataset[split_name] for split_name in ('train', 'val', 'test'))

In [4]:
# ESZSL hyperparameters, presumably the paper's lambda / gamma regularisation
# weights (confirm against ESZSLearner). Named here instead of inlined so they
# are easy to find and change.
# TODO: select them on the 'val' split (X_val, attributes_features_val,
# labels_val), which is loaded above but currently unused.
LMBDA = 0.01
GAMMA = 0.01

eszs_learner = ESZSLearner(lmbda=LMBDA, gamma=GAMMA)

In [5]:
%%time

# Fit ESZSL on the training split (features, attribute features, labels).
# %%time reports the cost of this cell; the recorded run took ~9 s wall time.
eszs_learner.fit(X_train, attributes_features_train, labels_train)

CPU times: user 17.8 s, sys: 487 ms, total: 18.3 s
Wall time: 9.23 s


In [6]:
# Predict on the training images, choosing only among the training split's
# attribute features (not among all classes), so this measures fit on seen data.
predictions_train = eszs_learner.predict(X_train, attributes_features_train)

In [7]:
# Predict on the test images, choosing only among the test split's attribute
# features (presumably the unseen classes' signatures — confirm in load_awa2).
predictions_test = eszs_learner.predict(X_test, attributes_features_test)

In [8]:
# Fraction of training samples whose predicted label equals the true label.
# NOTE(review): accuracy_score is per-sample; AwA2 benchmarks usually report
# per-class averaged accuracy instead.
train_accuracy = accuracy_score(labels_train, predictions_train)

In [9]:
# Fraction of test samples whose predicted label equals the true label.
# NOTE(review): accuracy_score is per-sample; AwA2 benchmarks usually report
# per-class averaged accuracy instead.
test_accuracy = accuracy_score(labels_test, predictions_test)

### Final metric

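Generalized Zero-Shot Learning (GZSL) is usually scored by the harmonic mean of the accuracy on seen (training) classes and on unseen (test) classes; here we report the harmonic mean of the train and test accuracy computed above as that metric. Note that this only approximates the standard GZSL protocol:

- each split is predicted only among its own classes' attributes rather than among all classes;
- the train accuracy is measured on the training images themselves rather than on held-out images of seen classes;
- `accuracy_score` is a per-sample accuracy rather than a per-class averaged one.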

In [10]:
# Harmonic mean of the train and test accuracies. It is 0 whenever either
# accuracy is 0; that case is handled explicitly because older SciPy releases
# make `hmean` raise ValueError (instead of returning 0) for a zero input.
train_test_harmonic_mean = (
    hmean([train_accuracy, test_accuracy])
    if train_accuracy > 0 and test_accuracy > 0
    else 0.0
)

In [11]:
# Report each metric rounded to 4 decimal places, one per line.
for metric_label, metric_value in (
    ('train accuracy:', train_accuracy),
    ('test accuracy:', test_accuracy),
    ('train/test accuracy harmonic mean:', train_test_harmonic_mean),
):
    print(metric_label, round(metric_value, 4))

train accuracy: 0.9833
test accuracy: 0.3978
train/test accuracy harmonic mean: 0.5665
