In [1]:
import numpy as np
import pickle
from learnable_crf import LearnableCrf
from scipy.special import expit as sigmoid

In [3]:
r = 50
with open('cache/df_train.{}.pickle'.format(r), mode='rb') as h:
    df = pickle.load(h)
leaf_indices = np.nonzero([x[0] in x[1] for x in zip(df['label'], df['pseudo_label'])])[0]
Y_train = df['label'][leaf_indices]
with open('cache/df_val_test.pickle', mode='rb') as h:
    df_val, df_test = pickle.load(h)
Y_val = df_val['label']
Y_test = df_test['label']

In [4]:
iter_Phi_train = sigmoid(np.load('results/iter_Phi_train.{}.npy'.format(r))[:, leaf_indices])
iter_Phi_val = sigmoid(np.load('results/iter_Phi_val.{}.npy'.format(r)))
iter_Phi_test = sigmoid(np.load('results/iter_Phi_test.{}.npy'.format(r)))

In [5]:
def get_accuracy(Y_predict, Y_truth, lim_states=False):
    if Y_predict.dtype == bool:
        return float(np.count_nonzero(Y_predict[np.arange(len(Y_predict)), Y_truth])) / len(Y_predict)
    if lim_states:
        return float(np.count_nonzero(Y_predict[:, :20].argmax(axis=1) == Y_truth)) / len(Y_predict)
    return float(np.count_nonzero(Y_predict.argmax(axis=1) == Y_truth)) / len(Y_predict)

In [6]:
[get_accuracy(iter_Phi_test[i], Y_test) for i in range(0, 10)]  # CNN baseline

[0.00831353919239905,
 0.0029691211401425177,
 0.004156769596199525,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029]

In [7]:
lcrf = [LearnableCrf(iter_Phi_train[i], Y_train) for i in range(0, 10)]

In [8]:
[x.opt_theta for x in lcrf]

[array([ 2.1939038 ,  1.33625349,  2.16850978,  0.64685469,  0.93553538,
         0.71872943,  2.11896663,  0.00478724,  1.21611635,  0.        ,
         2.18995653,  2.14358186,  1.73381367,  1.59213518,  0.        ,
         0.        ,  2.05900137,  2.07314393,  1.69131754,  1.45917046,
         0.65351471,  1.67580992,  0.82598483,  3.2120469 ,  2.75124153,
         0.        ,  3.78086442,  0.9934507 ,  1.0165862 ,  0.95963574,
         0.54270242,  0.97668465,  1.00069711,  0.96611547,  0.90016389,
         0.96193266,  2.29557467,  0.89082662,  1.18460583,  0.93284638,
         1.00509559,  0.92619231,  0.96539575,  0.57692204,  1.67022539,
         0.99496498,  0.96530932,  0.99078107,  0.89869586,  0.90798525,
         2.16530878]),
 array([ 2.18634057,  1.31785395,  2.15712766,  0.62115169,  0.93225085,
         0.69878997,  2.12895995,  0.        ,  1.18609924,  0.        ,
         2.19278143,  2.13184232,  1.71834207,  1.58940389,  0.        ,
         0.        ,  1.9643

In [9]:
[get_accuracy(lcrf[i].predict(iter_Phi_val[i]), Y_val) for i in range(0, 10)]

[0.39667458432304037,
 0.43646080760095013,
 0.4667458432304038,
 0.4637767220902613,
 0.4649643705463183,
 0.4655581947743468,
 0.4655581947743468,
 0.4655581947743468,
 0.4655581947743468,
 0.4655581947743468]

In [10]:
[get_accuracy(lcrf[i].predict(iter_Phi_test[i]), Y_test) for i in range(0, 10)]

[0.3812351543942993,
 0.43764845605700714,
 0.46140142517814725,
 0.46140142517814725,
 0.46199524940617576,
 0.4637767220902613,
 0.4643705463182898,
 0.4643705463182898,
 0.4643705463182898,
 0.4643705463182898]

In [11]:
def confusion_matrix(Y_predict, Y_truth):  # crf only
    cm = np.zeros((20, 27), dtype=int)
    count = np.zeros(20, dtype=int)
    for i, y in enumerate(Y_predict):
        cm[Y_truth[i], :] += y
        count[Y_truth[i]] += 1
    return cm.astype(float) / count[:, None]

In [12]:
np.diagonal(confusion_matrix(lcrf[0].predict(iter_Phi_test[0]), Y_test))

array([ 0.25      ,  0.02222222,  0.10344828,  0.        ,  0.        ,
        0.5       ,  0.53333333,  0.86764706,  0.421875  ,  0.83333333,
        0.48214286,  0.47586207,  0.21917808,  0.        ,  0.71014493,
        0.38235294,  0.36931818,  0.29655172,  0.47169811,  0.23383085])

In [13]:
np.diagonal(confusion_matrix(lcrf[-1].predict(iter_Phi_test[-1]), Y_test))

array([ 0.46428571,  0.11111111,  0.13793103,  0.3       ,  0.10714286,
        0.475     ,  0.56666667,  0.82352941,  0.453125  ,  0.85185185,
        0.48214286,  0.45517241,  0.2739726 ,  0.05769231,  0.69565217,
        0.5       ,  0.50568182,  0.32413793,  0.55660377,  0.53731343])