In [1]:
from operator import itemgetter
from pathlib import Path

import keras
import numpy as np
import tensorflow as tf
from absl import flags
from keras.optimizers import Adam
from sklearn.metrics import classification_report, recall_score, precision_score, f1_score
from sklearn.utils import class_weight

Using TensorFlow backend.


In [2]:
import importlib

import model as nn

# Re-execute the local `model` module so edits to model.py are picked up
# without restarting the kernel; reload() hands back the refreshed module.
nn = importlib.reload(nn)

# Model factory used by the training cell below.
elsa_architecture = nn.elsa_architecture

In [3]:
import os

# --- Hardware ------------------------------------------------------------------
# Expose only CUDA device 2 to this process.
# NOTE(review): tensorflow/keras are already imported in the first cell. This
# should still take effect as long as TF has not initialised CUDA yet, but
# setting it before those imports would be more robust.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# --- Data / output locations -----------------------------------------------------
lang = 'en'
data_dir = '/data/elsa2'
checkpoint_dir = './ckpt'

# --- Optimisation ------------------------------------------------------------------
optimizer = 'adam'   # 'adam' or 'rmsprop'
lr = 0.001
batch_size = 250
epochs = 100
patience = 3         # EarlyStopping patience, in epochs

In [None]:
# Model / training switches.
# NOTE(review): lstm_hidden and lstm_drop are not passed to elsa_architecture in
# the visible notebook, so changing them currently has no effect. Confirm
# against model.py whether they should be forwarded.
lstm_hidden, lstm_drop = 512, 0.5
embed_drop, final_drop = 0.0, 0.5
highway = False               # forwarded as elsa_architecture(high=...)
multilabel = True             # one sigmoid output per class + binary cross-entropy
compute_class_weight = False  # only consulted when multilabel is False

In [None]:
data_dir = Path(data_dir)
wv_path = str(data_dir / "{:s}_wv.npy".format(lang))
X_path = str(data_dir / "{:s}_X.npy".format(lang))
y_path = str(data_dir / "{:s}_y.npy".format(lang))
emoji_path = str(data_dir / "{:s}_emoji.txt".format(lang))


def to_multilabel(y):
    """Split a 2-D label matrix into a list holding one column vector per class.

    The multilabel model has one sigmoid output per class, so Keras is given one
    target array per output instead of a single matrix.
    """
    return [y[:, i] for i in range(y.shape[1])]


def _load_emoji_freq(path):
    """Read "<emoji> <count>" lines from `path` into an insertion-ordered dict.

    Blank lines are skipped. A repeated emoji keeps its first position but its
    last count, matching a dict comprehension over the lines.
    """
    freq = {}
    # NOTE(review): assumes the emoji list is UTF-8 encoded; previously the
    # platform default encoding was used.
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            fields = line.split()
            if fields:
                freq[fields[0]] = int(fields[1])
    return freq


# allow_pickle=True lets np.load unpickle arbitrary objects from these files;
# only point data_dir at trusted data.
wv = np.load(wv_path, allow_pickle=True)
input_vec = np.load(X_path, allow_pickle=True)
input_label = np.load(y_path, allow_pickle=True)

nb_tokens = len(wv)
embed_dim = wv.shape[1]
input_len = len(input_label)
nb_classes = input_label.shape[1]
maxlen = input_vec.shape[1]
assert len(input_vec) == input_len, "X and y must have the same number of rows"

# Label names: the nb_classes most frequent emojis, most frequent first.
# NOTE(review): assumes label column i corresponds to the i-th most frequent
# emoji in this file. Loaded before training so a short or malformed emoji
# file fails fast instead of after the fit.
freq = _load_emoji_freq(emoji_path)
freq_topn = sorted(freq.items(), key=itemgetter(1), reverse=True)[:nb_classes]
target_emojis = [emoji for emoji, _ in freq_topn]
assert len(target_emojis) == nb_classes, (
    "{:s} lists {:d} emojis but the labels have {:d} classes".format(
        emoji_path, len(target_emojis), nb_classes))

# Contiguous 70% / 20% / 10% train / val / test split.
# NOTE(review): rows are not shuffled before splitting. Confirm the .npy files
# are not ordered (e.g. by time or label), or the splits will be biased.
train_end = int(input_len * 0.7)
val_end = int(input_len * 0.9)

(X_train, y_train) = (input_vec[:train_end], input_label[:train_end])
(X_val, y_val) = (input_vec[train_end:val_end], input_label[train_end:val_end])
(X_test, y_test) = (input_vec[val_end:], input_label[val_end:])

if multilabel:
    y_train = to_multilabel(y_train)
    y_val = to_multilabel(y_val)
    y_test = to_multilabel(y_test)

# NOTE(review): lstm_hidden and lstm_drop from the hyper-parameter cell are not
# passed here, so they currently have no effect. Check elsa_architecture's
# signature in model.py before wiring them through.
model = elsa_architecture(nb_classes=nb_classes,
                          nb_tokens=nb_tokens,
                          maxlen=maxlen,
                          final_dropout_rate=final_drop,
                          embed_dropout_rate=embed_drop,
                          load_embedding=True,
                          pre_embedding=wv,
                          high=highway,
                          embed_dim=embed_dim,
                          multilabel=multilabel)
model.summary()

computed_class_weight = None

if multilabel:
    loss = "binary_crossentropy"
else:
    loss = "categorical_crossentropy"
    if compute_class_weight:
        # Column index of every positive entry, row by row (a multi-hot row
        # contributes one entry per positive label).
        y_train_sps = np.nonzero(y_train)[1]
        weights = class_weight.compute_class_weight(
            class_weight='balanced', classes=np.arange(nb_classes), y=y_train_sps)
        # Keras documents fit(class_weight=...) as a {class_index: weight} dict.
        computed_class_weight = dict(enumerate(weights))
        print("computed class weight = {:s}".format(str(computed_class_weight)))

if optimizer == 'adam':
    # `lr` is the argument name in the standalone Keras version this notebook uses.
    opt = Adam(clipnorm=1, lr=lr)
elif optimizer == 'rmsprop':
    opt = 'rmsprop'
else:
    # Previously an unknown name silently left the model uncompiled.
    raise ValueError("unsupported optimizer: {!r}".format(optimizer))
model.compile(loss=loss, optimizer=opt, metrics=['accuracy'])

checkpoint_dir = Path(checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_weight_path = str(checkpoint_dir / "elsa_{:s}.hdf5".format(lang))

callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss', min_delta=0, patience=patience, verbose=0, mode='auto'),
    keras.callbacks.ModelCheckpoint(checkpoint_weight_path, monitor='val_loss',
                                    verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)
]
model.fit(X_train,
          y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(X_val, y_val),
          class_weight=computed_class_weight,
          callbacks=callbacks,
          verbose=1)

# After EarlyStopping the in-memory model holds the last epoch's weights,
# typically up to `patience` epochs past the best val_loss. Reload the best
# checkpoint (save_best_only=True) so the test metrics describe the model
# actually selected on the validation set.
model.load_weights(checkpoint_weight_path)

threshold = 0.5  # probability cut-off for a positive sigmoid prediction

if multilabel:
    y_pred = model.predict([X_test], batch_size=batch_size)
    if not isinstance(y_pred, list):
        # A single-output model returns one array rather than a list of outputs.
        y_pred = [y_pred]
    # One flat probability vector (length len(X_test)) per class output.
    y_pred = [np.reshape(p, -1) for p in y_pred]

    y_test_1d = np.array(y_test).flatten()
    y_pred_1d = np.array(y_pred).flatten()
    print(f1_score(y_test_1d, y_pred_1d > threshold))
    print(classification_report(y_test_1d, y_pred_1d > threshold))

    # Encode each (sample, class) pair as its class id (1..nb_classes) when
    # positive and 0 otherwise, flattened sample-major.
    class_ids = np.arange(1, nb_classes + 1)
    gold = np.where(np.stack(y_test, axis=1) == 1.0, class_ids, 0).ravel()
    pred = np.where(np.stack(y_pred, axis=1) > threshold, class_ids, 0).ravel()

    # Explicit labels keep target_names aligned even if a class never occurs.
    print(classification_report(gold, pred,
                                labels=list(range(nb_classes + 1)),
                                target_names=[""] + target_emojis))
else:
    _, acc = model.evaluate(X_test, y_test, batch_size=batch_size, verbose=0)
    print(acc)

    y_pred = model.predict(X_test, batch_size=batch_size)
    print(classification_report(y_test.argmax(axis=1), y_pred.argmax(axis=1),
                                labels=list(range(nb_classes)),
                                target_names=target_emojis))

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 20)           0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 20, 200)      64132000    input_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 20, 200)      0           embedding[0][0]                  
__________________________________________________________________________________________________
bi_lstm_0 (Bidirectio

Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Train on 2677236 samples, validate on 764924 samples
Epoch 1/100
Epoch 2/100


    250/2677236 [..............................] - ETA: 1:06:41 - loss: 5.1072 - sigmoid_0_loss: 0.3930 - sigmoid_1_loss: 0.3113 - sigmoid_2_loss: 0.2373 - sigmoid_3_loss: 0.1676 - sigmoid_4_loss: 0.2221 - sigmoid_5_loss: 0.1413 - sigmoid_6_loss: 0.0552 - sigmoid_7_loss: 0.0975 - sigmoid_8_loss: 0.0677 - sigmoid_9_loss: 0.0895 - sigmoid_10_loss: 0.0533 - sigmoid_11_loss: 0.1114 - sigmoid_12_loss: 0.0463 - sigmoid_13_loss: 0.0561 - sigmoid_14_loss: 0.0689 - sigmoid_15_loss: 0.0887 - sigmoid_16_loss: 0.0454 - sigmoid_17_loss: 0.0808 - sigmoid_18_loss: 0.0684 - sigmoid_19_loss: 0.0358 - sigmoid_20_loss: 0.0789 - sigmoid_21_loss: 0.0516 - sigmoid_22_loss: 0.0728 - sigmoid_23_loss: 0.0493 - sigmoid_24_loss: 0.0771 - sigmoid_25_loss: 0.0816 - sigmoid_26_loss: 0.0393 - sigmoid_27_loss: 0.0665 - sigmoid_28_loss: 0.0266 - sigmoid_29_loss: 0.0896 - sigmoid_30_loss: 0.0992 - sigmoid_31_loss: 0.0276 - sigmoid_32_loss: 0.0892 - sigmoid_33_loss: 0.0079 - sigmoid_34_loss: 0.1177 - sigmoid_35_loss: 0.

Epoch 3/100


    250/2677236 [..............................] - ETA: 1:11:22 - loss: 4.9178 - sigmoid_0_loss: 0.3904 - sigmoid_1_loss: 0.2814 - sigmoid_2_loss: 0.3243 - sigmoid_3_loss: 0.1691 - sigmoid_4_loss: 0.1677 - sigmoid_5_loss: 0.1115 - sigmoid_6_loss: 0.0918 - sigmoid_7_loss: 0.0692 - sigmoid_8_loss: 0.0489 - sigmoid_9_loss: 0.1160 - sigmoid_10_loss: 0.0618 - sigmoid_11_loss: 0.1124 - sigmoid_12_loss: 0.0464 - sigmoid_13_loss: 0.0497 - sigmoid_14_loss: 0.1165 - sigmoid_15_loss: 0.1071 - sigmoid_16_loss: 0.0282 - sigmoid_17_loss: 0.1370 - sigmoid_18_loss: 0.0555 - sigmoid_19_loss: 0.0905 - sigmoid_20_loss: 0.0899 - sigmoid_21_loss: 0.0490 - sigmoid_22_loss: 0.0973 - sigmoid_23_loss: 0.1026 - sigmoid_24_loss: 0.1022 - sigmoid_25_loss: 0.0858 - sigmoid_26_loss: 0.0928 - sigmoid_27_loss: 0.0709 - sigmoid_28_loss: 0.0670 - sigmoid_29_loss: 0.0323 - sigmoid_30_loss: 0.0651 - sigmoid_31_loss: 0.0318 - sigmoid_32_loss: 0.0833 - sigmoid_33_loss: 0.0433 - sigmoid_34_loss: 0.0399 - sigmoid_35_loss: 0.

Epoch 4/100


    250/2677236 [..............................] - ETA: 1:01:29 - loss: 4.8854 - sigmoid_0_loss: 0.3973 - sigmoid_1_loss: 0.3150 - sigmoid_2_loss: 0.2530 - sigmoid_3_loss: 0.1665 - sigmoid_4_loss: 0.1514 - sigmoid_5_loss: 0.1339 - sigmoid_6_loss: 0.1129 - sigmoid_7_loss: 0.1144 - sigmoid_8_loss: 0.0546 - sigmoid_9_loss: 0.1638 - sigmoid_10_loss: 0.0597 - sigmoid_11_loss: 0.1346 - sigmoid_12_loss: 0.0587 - sigmoid_13_loss: 0.0326 - sigmoid_14_loss: 0.1175 - sigmoid_15_loss: 0.0981 - sigmoid_16_loss: 0.0782 - sigmoid_17_loss: 0.0849 - sigmoid_18_loss: 0.0522 - sigmoid_19_loss: 0.0324 - sigmoid_20_loss: 0.0452 - sigmoid_21_loss: 0.0624 - sigmoid_22_loss: 0.0675 - sigmoid_23_loss: 0.0540 - sigmoid_24_loss: 0.0740 - sigmoid_25_loss: 0.0718 - sigmoid_26_loss: 0.0646 - sigmoid_27_loss: 0.0691 - sigmoid_28_loss: 0.0588 - sigmoid_29_loss: 0.0512 - sigmoid_30_loss: 0.1304 - sigmoid_31_loss: 0.0965 - sigmoid_32_loss: 0.0337 - sigmoid_33_loss: 0.0152 - sigmoid_34_loss: 0.0554 - sigmoid_35_loss: 0.

Epoch 5/100


    250/2677236 [..............................] - ETA: 1:04:22 - loss: 4.6846 - sigmoid_0_loss: 0.3925 - sigmoid_1_loss: 0.2181 - sigmoid_2_loss: 0.1922 - sigmoid_3_loss: 0.1322 - sigmoid_4_loss: 0.1675 - sigmoid_5_loss: 0.1052 - sigmoid_6_loss: 0.0569 - sigmoid_7_loss: 0.0784 - sigmoid_8_loss: 0.0754 - sigmoid_9_loss: 0.1184 - sigmoid_10_loss: 0.0496 - sigmoid_11_loss: 0.1683 - sigmoid_12_loss: 0.0582 - sigmoid_13_loss: 0.0430 - sigmoid_14_loss: 0.1481 - sigmoid_15_loss: 0.0691 - sigmoid_16_loss: 0.0807 - sigmoid_17_loss: 0.0430 - sigmoid_18_loss: 0.0759 - sigmoid_19_loss: 0.0943 - sigmoid_20_loss: 0.1026 - sigmoid_21_loss: 0.0297 - sigmoid_22_loss: 0.0851 - sigmoid_23_loss: 0.0546 - sigmoid_24_loss: 0.1125 - sigmoid_25_loss: 0.0672 - sigmoid_26_loss: 0.0471 - sigmoid_27_loss: 0.0824 - sigmoid_28_loss: 0.0267 - sigmoid_29_loss: 0.0695 - sigmoid_30_loss: 0.0901 - sigmoid_31_loss: 0.0912 - sigmoid_32_loss: 0.0902 - sigmoid_33_loss: 0.0030 - sigmoid_34_loss: 0.1194 - sigmoid_35_loss: 0.

Epoch 6/100


    250/2677236 [..............................] - ETA: 1:03:57 - loss: 4.7724 - sigmoid_0_loss: 0.3609 - sigmoid_1_loss: 0.2462 - sigmoid_2_loss: 0.2587 - sigmoid_3_loss: 0.2014 - sigmoid_4_loss: 0.1869 - sigmoid_5_loss: 0.1441 - sigmoid_6_loss: 0.0895 - sigmoid_7_loss: 0.0860 - sigmoid_8_loss: 0.0461 - sigmoid_9_loss: 0.1070 - sigmoid_10_loss: 0.1003 - sigmoid_11_loss: 0.1579 - sigmoid_12_loss: 0.0409 - sigmoid_13_loss: 0.0444 - sigmoid_14_loss: 0.0792 - sigmoid_15_loss: 0.1189 - sigmoid_16_loss: 0.0377 - sigmoid_17_loss: 0.1083 - sigmoid_18_loss: 0.0326 - sigmoid_19_loss: 0.0921 - sigmoid_20_loss: 0.1048 - sigmoid_21_loss: 0.0515 - sigmoid_22_loss: 0.0663 - sigmoid_23_loss: 0.0474 - sigmoid_24_loss: 0.0464 - sigmoid_25_loss: 0.0767 - sigmoid_26_loss: 0.1169 - sigmoid_27_loss: 0.0406 - sigmoid_28_loss: 0.0661 - sigmoid_29_loss: 0.0356 - sigmoid_30_loss: 0.0394 - sigmoid_31_loss: 0.0823 - sigmoid_32_loss: 0.0181 - sigmoid_33_loss: 0.0079 - sigmoid_34_loss: 0.0742 - sigmoid_35_loss: 0.

Epoch 7/100


    250/2677236 [..............................] - ETA: 1:02:53 - loss: 4.3293 - sigmoid_0_loss: 0.3914 - sigmoid_1_loss: 0.3041 - sigmoid_2_loss: 0.2079 - sigmoid_3_loss: 0.1917 - sigmoid_4_loss: 0.1059 - sigmoid_5_loss: 0.1059 - sigmoid_6_loss: 0.1151 - sigmoid_7_loss: 0.0723 - sigmoid_8_loss: 0.0413 - sigmoid_9_loss: 0.0842 - sigmoid_10_loss: 0.0560 - sigmoid_11_loss: 0.1210 - sigmoid_12_loss: 0.0430 - sigmoid_13_loss: 0.0400 - sigmoid_14_loss: 0.0922 - sigmoid_15_loss: 0.0656 - sigmoid_16_loss: 0.0257 - sigmoid_17_loss: 0.0392 - sigmoid_18_loss: 0.0329 - sigmoid_19_loss: 0.0606 - sigmoid_20_loss: 0.1362 - sigmoid_21_loss: 0.0700 - sigmoid_22_loss: 0.1081 - sigmoid_23_loss: 0.0334 - sigmoid_24_loss: 0.0809 - sigmoid_25_loss: 0.0141 - sigmoid_26_loss: 0.0727 - sigmoid_27_loss: 0.0955 - sigmoid_28_loss: 0.0290 - sigmoid_29_loss: 0.0577 - sigmoid_30_loss: 0.0501 - sigmoid_31_loss: 0.0533 - sigmoid_32_loss: 0.0486 - sigmoid_33_loss: 0.0540 - sigmoid_34_loss: 0.0521 - sigmoid_35_loss: 0.

Epoch 8/100


    250/2677236 [..............................] - ETA: 1:03:24 - loss: 4.7375 - sigmoid_0_loss: 0.3603 - sigmoid_1_loss: 0.2189 - sigmoid_2_loss: 0.2378 - sigmoid_3_loss: 0.1782 - sigmoid_4_loss: 0.1831 - sigmoid_5_loss: 0.1657 - sigmoid_6_loss: 0.0827 - sigmoid_7_loss: 0.0731 - sigmoid_8_loss: 0.0667 - sigmoid_9_loss: 0.1033 - sigmoid_10_loss: 0.0523 - sigmoid_11_loss: 0.0983 - sigmoid_12_loss: 0.0567 - sigmoid_13_loss: 0.0924 - sigmoid_14_loss: 0.1213 - sigmoid_15_loss: 0.0922 - sigmoid_16_loss: 0.0439 - sigmoid_17_loss: 0.0940 - sigmoid_18_loss: 0.0380 - sigmoid_19_loss: 0.0783 - sigmoid_20_loss: 0.1154 - sigmoid_21_loss: 0.0679 - sigmoid_22_loss: 0.0664 - sigmoid_23_loss: 0.0438 - sigmoid_24_loss: 0.0573 - sigmoid_25_loss: 0.0584 - sigmoid_26_loss: 0.0974 - sigmoid_27_loss: 0.0962 - sigmoid_28_loss: 0.1105 - sigmoid_29_loss: 0.0498 - sigmoid_30_loss: 0.0652 - sigmoid_31_loss: 0.0632 - sigmoid_32_loss: 0.0867 - sigmoid_33_loss: 0.0275 - sigmoid_34_loss: 0.0606 - sigmoid_35_loss: 0.

Epoch 9/100


    250/2677236 [..............................] - ETA: 1:05:16 - loss: 4.5347 - sigmoid_0_loss: 0.3371 - sigmoid_1_loss: 0.2778 - sigmoid_2_loss: 0.2097 - sigmoid_3_loss: 0.2045 - sigmoid_4_loss: 0.1361 - sigmoid_5_loss: 0.0806 - sigmoid_6_loss: 0.1270 - sigmoid_7_loss: 0.0289 - sigmoid_8_loss: 0.0665 - sigmoid_9_loss: 0.1173 - sigmoid_10_loss: 0.0377 - sigmoid_11_loss: 0.0908 - sigmoid_12_loss: 0.0573 - sigmoid_13_loss: 0.0141 - sigmoid_14_loss: 0.1700 - sigmoid_15_loss: 0.0718 - sigmoid_16_loss: 0.0510 - sigmoid_17_loss: 0.0994 - sigmoid_18_loss: 0.0733 - sigmoid_19_loss: 0.0335 - sigmoid_20_loss: 0.0815 - sigmoid_21_loss: 0.0491 - sigmoid_22_loss: 0.0762 - sigmoid_23_loss: 0.1052 - sigmoid_24_loss: 0.0508 - sigmoid_25_loss: 0.0809 - sigmoid_26_loss: 0.0446 - sigmoid_27_loss: 0.0601 - sigmoid_28_loss: 0.0213 - sigmoid_29_loss: 0.0452 - sigmoid_30_loss: 0.0875 - sigmoid_31_loss: 0.0663 - sigmoid_32_loss: 0.0300 - sigmoid_33_loss: 0.0122 - sigmoid_34_loss: 0.0850 - sigmoid_35_loss: 0.

Epoch 10/100


    250/2677236 [..............................] - ETA: 1:04:55 - loss: 4.2997 - sigmoid_0_loss: 0.3674 - sigmoid_1_loss: 0.2058 - sigmoid_2_loss: 0.2250 - sigmoid_3_loss: 0.1591 - sigmoid_4_loss: 0.1571 - sigmoid_5_loss: 0.1215 - sigmoid_6_loss: 0.0867 - sigmoid_7_loss: 0.1172 - sigmoid_8_loss: 0.0457 - sigmoid_9_loss: 0.1535 - sigmoid_10_loss: 0.0736 - sigmoid_11_loss: 0.1382 - sigmoid_12_loss: 0.0425 - sigmoid_13_loss: 0.0510 - sigmoid_14_loss: 0.1151 - sigmoid_15_loss: 0.0789 - sigmoid_16_loss: 0.0475 - sigmoid_17_loss: 0.0869 - sigmoid_18_loss: 0.0362 - sigmoid_19_loss: 0.0255 - sigmoid_20_loss: 0.1101 - sigmoid_21_loss: 0.0388 - sigmoid_22_loss: 0.0585 - sigmoid_23_loss: 0.0969 - sigmoid_24_loss: 0.0480 - sigmoid_25_loss: 0.0997 - sigmoid_26_loss: 0.0873 - sigmoid_27_loss: 0.0920 - sigmoid_28_loss: 0.0807 - sigmoid_29_loss: 0.0638 - sigmoid_30_loss: 0.0484 - sigmoid_31_loss: 0.0127 - sigmoid_32_loss: 0.0235 - sigmoid_33_loss: 0.0118 - sigmoid_34_loss: 0.0641 - sigmoid_35_loss: 0.

