In [1]:
import numpy as np
from learnable_crf import LearnableCrf
from scipy.special import expit as sigmoid

In [2]:
r = 50
iter_Phi_train = sigmoid(np.load('results/iter_Phi_train.{}.npy'.format(r)))
iter_Phi_val = sigmoid(np.load('results/iter_Phi_val.{}.npy'.format(r)))
iter_Phi_test = sigmoid(np.load('results/iter_Phi_test.{}.npy'.format(r)))

In [3]:
import pickle
with open('cache/df_train.{}.pickle'.format(r), mode='rb') as h:
    df = pickle.load(h)
Y_train = df['label']
with open('cache/df_val_test.pickle', mode='rb') as h:
    df_val, df_test = pickle.load(h)
Y_val = df_val['label']
Y_test = df_test['label']

In [4]:
def get_accuracy(Y_predict, Y_truth, lim_states=False):
    if Y_predict.dtype == bool:
        return float(np.count_nonzero(Y_predict[np.arange(len(Y_predict)), Y_truth])) / len(Y_predict)
    if lim_states:
        return float(np.count_nonzero(Y_predict[:, :20].argmax(axis=1) == Y_truth)) / len(Y_predict)
    return float(np.count_nonzero(Y_predict.argmax(axis=1) == Y_truth)) / len(Y_predict)

In [5]:
[get_accuracy(iter_Phi_test[i], Y_test) for i in range(0, 10)]  # CNN baseline

[0.00831353919239905,
 0.0029691211401425177,
 0.004156769596199525,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029,
 0.004750593824228029]

In [6]:
lcrf = [LearnableCrf(iter_Phi_train[i], Y_train) for i in range(0, 10)]

In [7]:
[x.opt_theta for x in lcrf]

[array([ 2.20014887,  1.39575951,  2.17304041,  0.68408807,  0.91085506,
         0.54329818,  2.13855473,  0.        ,  1.17597833,  0.        ,
         2.20261343,  2.16206545,  1.6350744 ,  1.61051975,  0.        ,
         0.        ,  2.07613943,  2.0925244 ,  1.27955287,  1.45120137,
         0.5274607 ,  1.73853521,  0.7255247 ,  3.23292013,  2.96907729,
         0.        ,  3.8042407 ,  0.99332188,  1.01660532,  0.96432302,
         0.51348657,  0.97932131,  0.99889503,  0.96804325,  0.91183639,
         0.96703131,  2.35966872,  0.90250768,  1.17492042,  0.94683125,
         1.00961717,  0.9340245 ,  0.97278454,  0.58909087,  1.6282265 ,
         0.99525614,  0.97042346,  0.9932059 ,  0.91191411,  0.92283271,
         1.9232554 ]),
 array([ 2.19769312,  1.3790705 ,  2.1646862 ,  0.66226643,  0.90846157,
         0.52693994,  2.1615499 ,  0.        ,  1.15680157,  0.        ,
         2.21395782,  2.16477958,  1.62072063,  1.61170322,  0.        ,
         0.        ,  2.0039

In [8]:
[get_accuracy(lcrf[i].predict(iter_Phi_val[i]), Y_val) for i in range(0, 10)]

[0.4002375296912114,
 0.4412114014251782,
 0.47030878859857483,
 0.4679334916864608,
 0.4691211401425178,
 0.4679334916864608,
 0.4679334916864608,
 0.4679334916864608,
 0.4679334916864608,
 0.4679334916864608]

In [9]:
[get_accuracy(lcrf[i].predict(iter_Phi_test[i]), Y_test) for i in range(0, 10)]

[0.3830166270783848,
 0.43883610451306415,
 0.4649643705463183,
 0.4637767220902613,
 0.4655581947743468,
 0.4661520190023753,
 0.4667458432304038,
 0.4661520190023753,
 0.4661520190023753,
 0.4661520190023753]

In [10]:
def confusion_matrix(Y_predict, Y_truth):  # crf only
    cm = np.zeros((20, 27), dtype=int)
    count = np.zeros(20, dtype=int)
    for i, y in enumerate(Y_predict):
        cm[Y_truth[i], :] += y
        count[Y_truth[i]] += 1
    return cm.astype(float) / count[:, None]

In [11]:
np.diagonal(confusion_matrix(lcrf[0].predict(iter_Phi_test[0]), Y_test))

array([ 0.25      ,  0.02222222,  0.10344828,  0.        ,  0.        ,
        0.525     ,  0.53333333,  0.86764706,  0.421875  ,  0.81481481,
        0.48214286,  0.47586207,  0.21917808,  0.        ,  0.71014493,
        0.38235294,  0.36931818,  0.29655172,  0.5       ,  0.23383085])

In [12]:
np.diagonal(confusion_matrix(lcrf[-1].predict(iter_Phi_test[-1]), Y_test))

array([ 0.46428571,  0.11111111,  0.13793103,  0.3       ,  0.10714286,
        0.475     ,  0.56666667,  0.82352941,  0.453125  ,  0.85185185,
        0.48214286,  0.45517241,  0.2739726 ,  0.05769231,  0.69565217,
        0.5       ,  0.50568182,  0.32413793,  0.58490566,  0.53731343])