In [23]:
import tensorflow as tf
import numpy as np

from sklearn.metrics import accuracy_score

In [24]:
# File names of the easy test set -- edit these two paths to match your files.
data_path = 'easy_data.npy'
labels_path = 'easy_labels.npy'

# Read the samples (converted to unsigned 8-bit integers) and their labels.
data = np.load(data_path).astype('uint8')
labels = np.load(labels_path)
data.shape

(270000, 5933)

In [39]:
# File names of the hard test set -- edit these two paths to match your files.
hard_data_path = 'hard_data.npy'
hard_labels_path = 'hard_labels.npy'

# Read the samples and their labels; unlike the easy set, no dtype cast is applied.
hard_data, hard_labels = np.load(hard_data_path), np.load(hard_labels_path)
hard_data.shape

(270000, 110)

In [40]:
# Location of the trained model; test() and test_hard() pass this path to
# tf.keras.models.load_model every time they are called.
CHECKPOINT_FILEPATH = 'efficientnetv2-s/checkpoints/'

In [41]:
def test(X, t):
    """Evaluate the saved model on a data set and report its accuracy.

    Args:
        X (np.array): Data of shape (270000, num_samples)
        t (np.array): Targets of shape (num_samples,)

    Returns:
        tuple: ``(acc, preds)`` where ``preds`` is the arg-max over the last
        axis of the model output and ``acc`` is ``accuracy_score(t, preds)``.
    """
    # Restore the best model saved during training
    model = tf.keras.models.load_model(CHECKPOINT_FILEPATH)

    # The model is fed the transpose of X
    model_output = model.predict(X.T)
    preds = np.argmax(model_output, axis=-1)

    return accuracy_score(t, preds), preds

In [42]:
def test_hard(X, t, threshold):
    """Score the saved model on a data set, abstaining on low-confidence samples.

    A sample is assigned the model's arg-max class when its highest predicted
    score is at least ``threshold``; otherwise it is assigned ``-1``.

    Args:
        X (np.array): Data of shape (270000, num_samples)
        t (np.array): Targets of shape (num_samples,)
        threshold (float): A threshold for predicting a value vs predicting -1

    Returns:
        tuple: ``(acc, ret_preds)`` where ``ret_preds`` is a list with one
        integer per sample (a class index, or -1 when the top score is below
        ``threshold``) and ``acc`` is ``accuracy_score(t, ret_preds)``.
    """
    # Transpose data
    X = X.T
    # Load best model from training
    saved_model = tf.keras.models.load_model(CHECKPOINT_FILEPATH)

    # Model scores, one row per sample.
    # NOTE(review): assumes predict() returns a 2-D (num_samples, num_classes)
    # array of softmax scores -- confirm against the saved model.
    scores = saved_model.predict(X)

    # Vectorised thresholding: keep the arg-max class where the top score
    # reaches the threshold, and fall back to -1 everywhere else.
    confident = np.max(scores, axis=-1) >= threshold
    ret_preds = np.where(confident, np.argmax(scores, axis=-1), -1).tolist()

    acc = accuracy_score(t, ret_preds)

    return acc, ret_preds

In [43]:
# Score the easy test set with the saved model.
acc, ret_preds = test(X=data, t=labels)

# Show the overall accuracy followed by the predicted class of each sample.
print("Reported easy test accuracy: ", acc)
print("Reported easy test predictions: ", ret_preds)

Reported easy test accuracy:  0.9710096072813079
Reported easy test predictions:  [8 9 3 ... 5 0 4]


In [51]:
# Score the hard test set; any sample whose highest predicted score is below
# 0.26 is reported as -1 instead of a class index.
hard_acc, hard_ret_preds = test_hard(hard_data, hard_labels, threshold=0.26)

# Show the overall accuracy followed by the per-sample predictions.
print("Reported hard test accuracy: ", hard_acc)
print("Reported hard_test predictions: ", hard_ret_preds)

Reported hard test accuracy:  0.5909090909090909
Reported hard_test predictions:  [7, 1, 3, 3, 3, 4, 5, 6, 7, 8, 6, 7, 1, 7, 2, 2, 4, 5, 6, 7, 8, 9, 1, 1, 7, 2, 3, 7, 5, 6, 7, 8, 6, 1, 7, 8, 2, 3, 4, 5, 6, 7, 8, 2, 1, 3, 7, 2, 3, 4, 5, 8, 7, 3, 9, 0, 8, 8, 2, 2, 9, 5, 6, 7, 7, 1, 7, 8, 7, 2, 2, 7, 2, 8, 7, 8, 9, 7, 2, 4, 2, 3, 4, 5, 6, 7, 8, 9, 0, 7, 7, 2, 3, 4, 2, 6, 7, 8, 2, 1, 1, 8, 2, 3, 4, 5, 7, 7, 9, 3]
