In [1]:
import numpy as np
import os
from scipy.stats import hmean
from sklearn import metrics

from scarce_learn.data import load_awa2
from scarce_learn import zero_shot



In [2]:
import torch

In [3]:
torch.__version__

'1.10.0+cu113'

$\phi$ - input feature
$$
    loss(\phi, label) = \sum_{j \neq label} ReLU(margin - t_{label} W \phi + t_j W \phi) 
$$

In [4]:
awa2_dataset = load_awa2()
X_train, label_embeddings_train, labels_train = awa2_dataset['train']
X_val, label_embeddings_val, labels_val = awa2_dataset['val']
X_test, label_embeddings_test, labels_test = awa2_dataset['test']
X_train = X_train.astype('float32')
label_embeddings_train = label_embeddings_train.astype('float32') 

In [5]:
X_trainval = np.row_stack([X_train, X_val])
labels_trainval = np.concatenate([labels_train, labels_val])
label_embeddings_trainval = np.row_stack([label_embeddings_train, label_embeddings_val])

In [6]:
X_train.shape

(16187, 2048)

In [7]:
torch.tensor(X_trainval).dtype, torch.tensor(label_embeddings_trainval).dtype

(torch.float64, torch.float64)

In [8]:
devise_learner = zero_shot.devise_torch.DEVISELearner(margin=10)

In [9]:
%%time
devise_learner.fit(X_trainval, labels_trainval, label_embeddings_trainval, n_epochs=2)

[1/736]   0%|           [00:00<?]

[1/736]   0%|           [00:00<?]

CPU times: user 26.3 s, sys: 1.24 s, total: 27.6 s
Wall time: 26.7 s


DEVISELearner(margin=10)

In [10]:
predictions_trainval = devise_learner.predict(X_trainval, label_embeddings_train)
predictions_test = devise_learner.predict(X_test, label_embeddings_test)

In [11]:
def get_metrics(model, embeddings, labels, label_embeddings):
    predictions = model.predict(embeddings, label_embeddings)
    accuracy = metrics.accuracy_score(predictions, labels)
    loss = model.get_loss(embeddings, label_embeddings, labels)
    return accuracy, loss

In [12]:
test_accuracy, test_loss = get_metrics(devise_learner, X_test, labels_test, label_embeddings_test)
trainval_accuracy, trainval_loss = get_metrics(devise_learner, X_trainval, labels_trainval, label_embeddings_trainval)

## Train

In [13]:
print('accuracy', round(trainval_accuracy, 3))
print('loss', round(trainval_loss, 3))

accuracy 0.74
loss 16.882


## Test

In [15]:
print('accuracy', round(test_accuracy, 3))
print('loss', round(test_loss, 3))

accuracy 0.403
loss 59.276


In [16]:
import scipy

scipy.stats.hmean([trainval_accuracy, test_accuracy])

0.5220912426542748