In [1]:
from keras.layers import Input, Dropout, Dense, Embedding
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import Model
from keras.optimizers import Adam
from keras.regularizers import l2
import pickle as pkl 
from sklearn.metrics import f1_score, classification_report
from layers.graph import SpectralGraphConvolution
from utils import *

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
def fix_labels(labels):
    for i in range(len(labels)):
        if labels[i][0] == "I":
            if i == 0 or labels[i-1][2:] != labels[i][2:]:
                labels[i] = "B-{}".format(labels[i][2:])
    return labels


def decode_labels(labels, idx2label):
    labels = np.array(labels)
    prediction_indices = labels.argmax(axis=1)
    prediction_labels = [idx2label[i] for i in prediction_indices]
    return prediction_labels


def predict_labels(predictions, actuals, idx2label):
    predictions_labels = []
    actuals_labels = []
    for i in range(len(predictions)):
#     for i in range(predictions.shape[0]):
        prediction = predictions[i]
        actual = actuals[i]
        prediction_labels = decode_labels(prediction, idx2label)
        prediction_labels = fix_labels(prediction_labels)
        actual_labels = decode_labels(actual, idx2label)
        predictions_labels.append(prediction_labels)
        actuals_labels.append(actual_labels)
    return predictions_labels, actuals_labels

In [3]:
def evaluate_metrics(y_true, y_pred):
        ## calc metric
    num_proposed = sum(1 for n in y_pred if n != 'O')
    num_correct = 0
    for i,j in zip(y_true,y_pred):
        if i != 'O' and i == j:
            num_correct +=1
    num_gold = sum(1 for n in y_true if n != 'O')
    print("num_proposed: ", num_proposed)
    print("num_correct: ", num_correct)
    print("num_gold: ", num_gold)
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall==0:
            f1=1.0
        else:
            f1=0
    final = ".P%.2f_R%.2f_F%.2f" %(precision, recall, f1)
    print("precision=%.4f"%precision)
    print("recall=%.4f"%recall)
    print("f1=%.4f"%f1)
    print("final ",final)
    return f1

In [4]:
def f1_metric(y_true, y_pred):
        ## calc metric
    y_pred, y_true = predict_labels(
        y_pred, y_true, meta['idx2label'])
    for i in range(len(y_pred)):
        y_pred[i] = [x.split('-')[1] if '-' in x else x for x in y_pred[i]]
    for i in range(len(y_true)):
        y_true[i] = [x.split('-')[1] if '-' in x else x for x in y_true[i]]
    
    gt = []
    pr = []
    for i in range(len(y_pred)):
        gt.extend(y_pred[i])
    for i in range(len(y_true)):
        pr.extend(y_true[i])
        
    num_proposed = sum(1 for n in pr if n != 'O')
    num_correct = 0
    for i,j in zip(gt,pr):
        if i != 'O' and i == j:
            num_correct +=1
    num_gold = sum(1 for n in gt if n != 'O')
    try:
        precision = num_correct / num_proposed
    except ZeroDivisionError:
        precision = 1.0

    try:
        recall = num_correct / num_gold
    except ZeroDivisionError:
        recall = 1.0

    try:
        f1 = 2*precision*recall / (precision + recall)
    except ZeroDivisionError:
        if precision*recall==0:
            f1=1.0
        else:
            f1=0
    return f1

In [5]:
DATASET = 'conll2003'
EPOCHS = 4
LR = 0.001
L2 = 0
DO = 0.5
BATCH_SIZE = 16

In [6]:
print("Loading dataset...")

A, X, Y, meta = pkl.load(open('pkl/' + DATASET + '.pkl', 'rb'))

print("Loading embedding matrix...")

embedding_matrix = pkl.load(
    open('pkl/' + DATASET + '.embedding_matrix.pkl', 'rb'))

print("Processing dataset...")

val_y = load_output(A, X, Y, 'val')
test_y = load_output(A, X, Y, 'test')

num_nodes = A['train'][0][0].shape[0]
num_relations = len(A['train'][0]) - 1
num_labels = len(meta['label2idx'])

print("Number of nodes: {}".format(num_nodes))
print("Number of relations: {}".format(num_relations))
print("Number of classes: {}".format(num_labels))

Loading dataset...
Loading embedding matrix...
Processing dataset...
Number of nodes: 124
Number of relations: 44
Number of classes: 8


In [7]:
# Define model inputs
X_in = Input(shape=(num_nodes, ))
A_in = [Input(shape=(num_nodes, num_nodes)) for _ in range(num_relations)]

In [8]:
print("Define model")
# Define model architecture
X_embedding = Embedding(embedding_matrix.shape[0], embedding_matrix.shape[1], weights=[
                        embedding_matrix], trainable=False)(X_in)
H = SpectralGraphConvolution(256, activation='relu')([X_embedding] + A_in)
H = Dropout(DO)(H)
H = SpectralGraphConvolution(256, activation='relu')([H] + A_in)
H = Dropout(DO)(H)
output = Dense(num_labels, activation='softmax')(H)

# Compile model
model = Model(inputs=[X_in] + A_in, outputs=output)
model.compile(metrics=['acc'],loss='categorical_crossentropy', optimizer=Adam(lr=LR))
model.summary()

Define model
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 124)          0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 124, 300)     442806300   input_1[0][0]                    
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 124, 124)     0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 124, 124)     0                                            
________________________________________________________________________________________________

In [13]:
# callbacks = [EarlyStopping(monitor='f1_metric', patience=2, verbose=0),
#              ModelCheckpoint(filepath='model.{loss:.2f}.h5', monitor='f1_metric', save_best_only=True, verbose=0)
#             ]

In [9]:
EPOCHS = 50
for epoch in range(EPOCHS):

    print("=== EPOCH {} ===".format(epoch + 1))

    model.fit_generator(batch_generator(A, X, Y, 'train', batch_size=BATCH_SIZE),
                        steps_per_epoch=len(A['train'])//BATCH_SIZE, verbose=1)


    val_predictions = model.predict_generator(batch_generator(
        A, X, Y, 'val', batch_size=BATCH_SIZE), steps=len(A['val'])//BATCH_SIZE, verbose=1)
    val_predicted_labels, val_actual_labels = predict_labels(
        val_predictions, val_y, meta['idx2label'])

    for i in range(len(val_predicted_labels)):
        val_predicted_labels[i] = [x.split('-')[1] if '-' in x else x for x in val_predicted_labels[i]]
    for i in range(len(val_actual_labels)):
        val_actual_labels[i] = [x.split('-')[1] if '-' in x else x for x in val_actual_labels[i]]
    
    gt = []
    pr = []
    for i in range(len(val_predicted_labels)):
        gt.extend(val_predicted_labels[i])
    for i in range(len(val_actual_labels)):
        pr.extend(val_actual_labels[i])
        
    print("=== Validation Results ===")
    print("Weighted F1-score: ",f1_score(gt,pr, average = 'weighted'))
    print("Classification report:\n", classification_report(gt,pr))
    evaluate_metrics(gt, pr)

    test_predictions = model.predict_generator(batch_generator(
        A, X, Y, 'test', batch_size=BATCH_SIZE), steps=len(A['test']) // BATCH_SIZE, verbose=1)

    test_predicted_labels, test_actual_labels = predict_labels(
        test_predictions, test_y, meta['idx2label'])
    for i in range(len(test_predicted_labels)):
        test_predicted_labels[i] = [x.split('-')[1] if '-' in x else x for x in test_predicted_labels[i]]
    for i in range(len(test_actual_labels)):
        test_actual_labels[i] = [x.split('-')[1] if '-' in x else x for x in test_actual_labels[i]]

    print("=== Test Results ===")

    gt = []
    pr = []
    for i in range(len(test_predicted_labels)):
        gt.extend(test_predicted_labels[i])
    for i in range(len(test_actual_labels)):
        pr.extend(test_actual_labels[i])
    print("Weighted F1-score: ",f1_score(gt,pr, average = 'weighted'))
    print("Classification report:\n", classification_report(gt,pr))
    evaluate_metrics(gt, pr)

=== EPOCH 1 ===
Epoch 1/1
=== Validation Results ===
Weighted F1-score:  0.9968688929547209
Classification report:
              precision    recall  f1-score   support

        LOC       0.92      0.90      0.91      2144
       MISC       0.70      0.92      0.80       965
          O       1.00      1.00      1.00    420276
        ORG       0.82      0.80      0.81      2121
        PER       0.94      0.97      0.95      3038

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  7472
num_gold:  8268
precision=0.8706
recall=0.9037
f1=0.8868
final  .P0.87_R0.90_F0.89
=== Test Results ===
Weighted F1-score:  0.9955219840668679
Classification report:
              precision    recall  f1-score   support

        LOC       0.88      0.83      0.85      2036
       MISC       0.63      0.77      0.69       743
          O       1.00      1.00      1.00    448245
        ORG       0.77      0.72      0.75      2660
        PER       0.90      0.95      

=== Validation Results ===
Weighted F1-score:  0.9992808670014541
Classification report:
              precision    recall  f1-score   support

        LOC       0.99      0.97      0.98      2120
       MISC       0.91      0.98      0.95      1173
          O       1.00      1.00      1.00    420065
        ORG       0.96      0.97      0.96      2050
        PER       0.99      0.99      0.99      3136

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8323
num_gold:  8479
precision=0.9697
recall=0.9816
f1=0.9756
final  .P0.97_R0.98_F0.98
=== Test Results ===
Weighted F1-score:  0.9968234511233303
Classification report:
              precision    recall  f1-score   support

        LOC       0.90      0.86      0.88      2011
       MISC       0.69      0.84      0.76       755
          O       1.00      1.00      1.00    448434
        ORG       0.80      0.85      0.83      2354
        PER       0.95      0.95      0.95      2766

avg / tota

=== Validation Results ===
Weighted F1-score:  0.9995891170210366
Classification report:
              precision    recall  f1-score   support

        LOC       0.99      0.99      0.99      2092
       MISC       0.95      0.99      0.97      1215
          O       1.00      1.00      1.00    420031
        ORG       0.98      0.98      0.98      2067
        PER       0.99      1.00      1.00      3139

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8434
num_gold:  8513
precision=0.9826
recall=0.9907
f1=0.9867
final  .P0.98_R0.99_F0.99
=== Test Results ===
Weighted F1-score:  0.9968880146646313
Classification report:
              precision    recall  f1-score   support

        LOC       0.90      0.89      0.89      1955
       MISC       0.71      0.82      0.76       794
          O       1.00      1.00      1.00    448482
        ORG       0.81      0.87      0.84      2326
        PER       0.94      0.95      0.94      2763

avg / tota

=== Validation Results ===
Weighted F1-score:  0.9997328669208086
Classification report:
              precision    recall  f1-score   support

        LOC       1.00      0.99      1.00      2099
       MISC       0.97      1.00      0.98      1230
          O       1.00      1.00      1.00    420021
        ORG       0.98      0.99      0.99      2054
        PER       1.00      1.00      1.00      3140

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8480
num_gold:  8523
precision=0.9880
recall=0.9950
f1=0.9915
final  .P0.99_R0.99_F0.99
=== Test Results ===
Weighted F1-score:  0.9968565834281566
Classification report:
              precision    recall  f1-score   support

        LOC       0.89      0.90      0.90      1910
       MISC       0.71      0.84      0.77       769
          O       1.00      1.00      1.00    448555
        ORG       0.79      0.86      0.82      2281
        PER       0.95      0.94      0.94      2805

avg / tota

=== Validation Results ===
Weighted F1-score:  0.9997971684488277
Classification report:
              precision    recall  f1-score   support

        LOC       0.99      0.99      0.99      2093
       MISC       0.98      0.99      0.99      1257
          O       1.00      1.00      1.00    419972
        ORG       0.99      0.99      0.99      2075
        PER       1.00      1.00      1.00      3147

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8520
num_gold:  8572
precision=0.9927
recall=0.9939
f1=0.9933
final  .P0.99_R0.99_F0.99
=== Test Results ===
Weighted F1-score:  0.9969147327735749
Classification report:
              precision    recall  f1-score   support

        LOC       0.88      0.91      0.90      1855
       MISC       0.73      0.81      0.77       828
          O       1.00      1.00      1.00    448467
        ORG       0.82      0.85      0.84      2407
        PER       0.94      0.94      0.94      2763

avg / tota

=== Validation Results ===
Weighted F1-score:  0.9998251974375654
Classification report:
              precision    recall  f1-score   support

        LOC       1.00      0.99      1.00      2108
       MISC       0.98      0.99      0.99      1253
          O       1.00      1.00      1.00    419976
        ORG       0.99      0.99      0.99      2069
        PER       1.00      1.00      1.00      3138

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8525
num_gold:  8568
precision=0.9932
recall=0.9950
f1=0.9941
final  .P0.99_R0.99_F0.99
=== Test Results ===
Weighted F1-score:  0.9968711750760331
Classification report:
              precision    recall  f1-score   support

        LOC       0.91      0.88      0.89      1971
       MISC       0.73      0.79      0.76       836
          O       1.00      1.00      1.00    448421
        ORG       0.81      0.86      0.83      2347
        PER       0.94      0.95      0.95      2745

avg / tota

=== Validation Results ===
Weighted F1-score:  0.9998071328591823
Classification report:
              precision    recall  f1-score   support

        LOC       1.00      0.99      1.00      2107
       MISC       0.99      1.00      0.99      1261
          O       1.00      1.00      1.00    420009
        ORG       0.97      1.00      0.99      2028
        PER       1.00      1.00      1.00      3139

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8508
num_gold:  8535
precision=0.9913
recall=0.9968
f1=0.9940
final  .P0.99_R1.00_F0.99
=== Test Results ===
Weighted F1-score:  0.9969110652744383
Classification report:
              precision    recall  f1-score   support

        LOC       0.91      0.88      0.89      1972
       MISC       0.72      0.80      0.76       825
          O       1.00      1.00      1.00    448568
        ORG       0.78      0.89      0.83      2185
        PER       0.95      0.95      0.95      2770

avg / tota

=== Validation Results ===
Weighted F1-score:  0.9998903391425382
Classification report:
              precision    recall  f1-score   support

        LOC       1.00      1.00      1.00      2092
       MISC       0.99      1.00      1.00      1257
          O       1.00      1.00      1.00    419967
        ORG       0.99      0.99      0.99      2084
        PER       1.00      1.00      1.00      3144

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8550
num_gold:  8577
precision=0.9962
recall=0.9969
f1=0.9965
final  .P1.00_R1.00_F1.00
=== Test Results ===
Weighted F1-score:  0.9969194657366502
Classification report:
              precision    recall  f1-score   support

        LOC       0.89      0.90      0.90      1886
       MISC       0.73      0.81      0.77       818
          O       1.00      1.00      1.00    448436
        ORG       0.82      0.86      0.84      2375
        PER       0.95      0.94      0.94      2805

avg / tota

=== Validation Results ===
Weighted F1-score:  0.9998695883363811
Classification report:
              precision    recall  f1-score   support

        LOC       1.00      1.00      1.00      2101
       MISC       0.99      1.00      0.99      1250
          O       1.00      1.00      1.00    419995
        ORG       0.99      1.00      0.99      2061
        PER       1.00      1.00      1.00      3137

avg / total       1.00      1.00      1.00    428544

num_proposed:  8583
num_correct:  8533
num_gold:  8549
precision=0.9942
recall=0.9981
f1=0.9961
final  .P0.99_R1.00_F1.00
=== Test Results ===
Weighted F1-score:  0.9968659735446297
Classification report:
              precision    recall  f1-score   support

        LOC       0.90      0.89      0.90      1946
       MISC       0.71      0.81      0.76       801
          O       1.00      1.00      1.00    448567
        ORG       0.79      0.87      0.83      2272
        PER       0.94      0.95      0.95      2734

avg / tota