In [1]:
import os

import numpy as np
import os
from scipy.stats import hmean

import jax.numpy as jnp
import jax

from sklearn import metrics

from scarce_learn.data import load_awa2
from scarce_learn import zero_shot



In [2]:
# Disable XLA's up-front GPU memory preallocation (presumably so JAX can share
# the GPU with PyTorch, which is imported below). XLA reads this only when the
# JAX backend is first initialized, so it must run before any JAX computation;
# ideally it would sit above `import jax`.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")

In [3]:
import torch

In [4]:
# Record the installed PyTorch build (version + CUDA tag) in the notebook output.
getattr(torch, "__version__")

'1.10.0+cu113'

$\phi$ - example embedding, a $dim$-dimensional row vector

$T$ - label (class) embeddings, an $n\_classes \times l\_dim$ matrix whose row $T_j$ is the embedding of class $j$

$W$ - bilinear similarity weights, a $dim \times l\_dim$ matrix

$label$ - integer class label

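The model is trained with a modified hinge loss:

$$
    loss(\phi, label) = \sum_{j \neq label} ReLU\left(margin - \phi W T_{label}^\top + \phi W T_j^\top\right)
$$

where $ReLU(x) = \max(0, x)$ and $\phi W T_j^\top$ is the bilinear similarity score between the example and class $j$.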

In [5]:
import numpy as np

In [6]:
def get_not_in_array(range_size):
    """Build a function returning the indices ``0 .. range_size-1`` minus one index.

    The returned function is traceable by ``jax.jit`` / ``jax.vmap`` /
    ``jax.pmap``. It uses no Python control flow on its argument, and its
    output always has the static length ``range_size - 1``.

    Args:
        range_size: number of indices in the full range (a Python int).

    Returns:
        ``_f(not_idx)``, which returns an integer array of the
        ``range_size - 1`` indices in ``[0, range_size)`` other than
        ``not_idx``, in increasing order. ``not_idx`` must lie in
        ``[0, range_size)``.
    """
    def _f(not_idx):
        # Enumerate range_size - 1 slots and shift every slot at or past
        # not_idx up by one; this skips exactly the excluded index without
        # branching on a (possibly traced) value.
        slots = jnp.arange(range_size - 1)
        return slots + (slots >= not_idx)
    return _f

In [7]:
# Batch the "all indices except one" helper over a vector of class indices.
# vmap rather than pmap: pmap maps over *devices* and requires the mapped axis
# to be no larger than the local device count, whereas this axis is a batch of
# indices of arbitrary size.
f = jax.vmap(get_not_in_array(10))

In [8]:
# Scratch check: repeat a 1-D row twice along a new leading axis -> shape (2, 2).
jnp.tile(jnp.array([1, 2]), reps=(2, 1))

DeviceArray([[1, 2],
             [1, 2]], dtype=int32)

In [9]:
def get_not_idxs(a, idxs):
    """Locate, row by row, the first column where ``a[i]`` equals ``idxs[i]``.

    Args:
        a: array indexable by row; presumably 2-D (one row per example).
        idxs: one target value per row of ``a``.

    Returns:
        1-D array stacking, for each row ``i``, the first column index at
        which ``a[i] == idxs[i]``.

    NOTE(review): every row is assumed to contain its target; a row with no
    match is not handled specially here.
    """
    first_match_positions = [
        jnp.where(a[row] == idxs[row])[0][0] for row in range(len(a))
    ]
    return jnp.stack(first_match_positions)

In [10]:
# Toy inputs, presumably for trying out get_not_idxs interactively.
a = jnp.array([[0, 1]] * 2)
idxs = jnp.arange(2)

In [11]:
# Small synthetic problem for smoke-testing DEVISELearner. The generator is
# seeded so that re-running the notebook reproduces the same data.
SYNTHETIC_SEED = 0
n_examples = 100
n_classes = 5
synthetic_rng = np.random.default_rng(SYNTHETIC_SEED)
weights = jnp.array(synthetic_rng.random((10, 20)))  # (dim, l_dim)
X = jnp.array(synthetic_rng.random((n_examples, weights.shape[0])))  # (n_examples, dim)
label_embeddings = jnp.array(synthetic_rng.random((n_classes, weights.shape[1])))  # (n_classes, l_dim)
# Labels assigned round-robin: 0, 1, ..., n_classes - 1, 0, 1, ...
y = jnp.array(np.arange(n_examples) % n_classes, dtype='int32')

In [12]:
from scarce_learn.zero_shot import torch_util

In [13]:
from itertools import islice

In [14]:
import tqdm

In [15]:
from sklearn import preprocessing

In [16]:
from jax.experimental import optimizers

In [17]:
from scarce_learn.zero_shot import devise_jax

In [18]:
# Model for the synthetic smoke test.
# NOTE(review): 10 is passed positionally, presumably the `margin` argument that
# the AwA2 model below sets by keyword (`margin=10`). Confirm against
# DEVISELearner's signature and pass it by keyword for clarity.
devise = devise_jax.DEVISELearner(10)

In [19]:
# Smoke-test fit on the random synthetic data. Argument order is
# (features, labels, label_embeddings).
devise.fit(X, y, label_embeddings)

100%|██████████| 10/10 [00:00<00:00, 12.56it/s]


In [20]:
# Predictions on random features/labels are not meaningful. They may collapse
# to a single class, as in the recorded output.
devise.predict(X, label_embeddings)

array([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], dtype=int32)

In [21]:
# Note the argument order: get_loss takes (features, label_embeddings, labels),
# whereas fit takes (features, labels, label_embeddings).
devise.get_loss(X, label_embeddings, y)

39.95183563232422

In [22]:
from sklearn import metrics

In [23]:
from scarce_learn.data import load_awa2

In [24]:
# Each AwA2 split unpacks to (features, class label embeddings, labels).
awa2_dataset = load_awa2()
(
    (X_train, label_embeddings_train, labels_train),
    (X_val, label_embeddings_val, labels_val),
    (X_test, label_embeddings_test, labels_test),
) = (awa2_dataset[split] for split in ('train', 'val', 'test'))

In [25]:
# Number of validation examples.
np.shape(labels_val)

(7340,)

In [26]:
# Train on train + val together: stack the example features and the class
# embeddings of both splits, and concatenate their labels.
# np.vstack replaces np.row_stack, a deprecated alias (NumPy 2.0) with
# identical behavior.
# NOTE(review): this assumes labels_val already index rows of the stacked
# label_embeddings_trainval (i.e. are offset by len(label_embeddings_train)).
# If load_awa2 returns per-split 0-based labels, that offset must be added here.
X_trainval = np.vstack([X_train, X_val])
labels_trainval = np.concatenate([labels_train, labels_val])
label_embeddings_trainval = np.vstack([label_embeddings_train, label_embeddings_val])

In [27]:
# (number of train classes, label-embedding dimension)
np.shape(label_embeddings_train)

(27, 85)

In [28]:
# (number of train examples, feature dimension)
np.shape(X_train)

(16187, 2048)

In [29]:
# Presumably the `margin` term of the hinge loss defined in the markdown above.
DEVISE_MARGIN = 10
awa_devise = devise_jax.DEVISELearner(margin=DEVISE_MARGIN)

In [30]:
# Training schedule for the AwA2 model.
N_EPOCHS = 2
BATCH_SIZE = 32
awa_devise.fit(
    X_trainval,
    labels_trainval,
    label_embeddings_trainval,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
)

100%|██████████| 2/2 [00:04<00:00,  2.21s/it]


In [31]:
# Score train+val examples against the train+val class embeddings, matching the
# trainval evaluation done with get_metrics further down. Scoring them against
# label_embeddings_train alone would leave the val-split classes out of the
# candidate set.
predictions_trainval = awa_devise.predict(X_trainval, label_embeddings_trainval)
predictions_test = awa_devise.predict(X_test, label_embeddings_test)

In [32]:
def get_metrics(model, embeddings, labels, label_embeddings):
    """Compute accuracy and loss of ``model`` on one evaluation split.

    Args:
        model: fitted learner exposing ``predict(embeddings, label_embeddings)``
            and ``get_loss(embeddings, label_embeddings, labels)``.
        embeddings: example features to evaluate on.
        labels: ground-truth class labels for ``embeddings``.
        label_embeddings: class embeddings the examples are scored against.

    Returns:
        ``(accuracy, loss)``: the fraction of correctly predicted examples and
        whatever ``model.get_loss`` returns for this split.
    """
    predictions = model.predict(embeddings, label_embeddings)
    # sklearn's contract is accuracy_score(y_true, y_pred). Pass by keyword so
    # the roles stay correct if an asymmetric metric is swapped in later.
    accuracy = metrics.accuracy_score(y_true=labels, y_pred=predictions)
    # get_loss takes (features, label_embeddings, labels), unlike fit's order.
    loss = model.get_loss(embeddings, label_embeddings, labels)
    return accuracy, loss

In [33]:
# Evaluate on the held-out test split and on the training (train+val) split.
test_accuracy, test_loss = get_metrics(
    model=awa_devise,
    embeddings=X_test,
    labels=labels_test,
    label_embeddings=label_embeddings_test,
)
trainval_accuracy, trainval_loss = get_metrics(
    model=awa_devise,
    embeddings=X_trainval,
    labels=labels_trainval,
    label_embeddings=label_embeddings_trainval,
)

## Training metrics (train + val)

In [34]:
# Training-split metrics, rounded to 3 decimals.
for metric_name, metric_value in (('accuracy', trainval_accuracy), ('loss', trainval_loss)):
    print(metric_name, round(metric_value, 3))

accuracy 0.826
loss 15.697


## Test metrics

In [35]:
# Test-split metrics, rounded to 3 decimals.
for metric_name, metric_value in (('accuracy', test_accuracy), ('loss', test_loss)):
    print(metric_name, round(metric_value, 3))

accuracy 0.42
loss 39.183


In [36]:
import scipy

# Harmonic mean of the two accuracies (presumably the zero-shot "H" score).
# Uses the hmean already imported from scipy.stats at the top of the notebook:
# a bare `import scipy` does not guarantee that the `scipy.stats` submodule is
# loaded.
# NOTE(review): the usual generalized zero-shot H pairs seen-class *test*
# accuracy with unseen-class accuracy. trainval_accuracy here is measured on
# the training data, so this number is likely optimistic. Confirm intent.
hmean([trainval_accuracy, test_accuracy])

0.5571186581935672