In [1]:
import keras
import numpy as np
import tensorflow as tf

from absl import flags
from pathlib import Path
from operator import itemgetter
from keras.optimizers import Adam
from sklearn.metrics import classification_report, recall_score, precision_score, f1_score

Using TensorFlow backend.


In [2]:
import importlib

import model as nn
importlib.reload(nn)

elsa_architecture = nn.elsa_architecture

In [3]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

lang = "en"
batch_size = 250
lr = 1e-3
epochs = 100
patience = 3
data_dir = "/data/elsa2"
checkpoint_dir = "./ckpt"
optimizer = "adam"

In [4]:
lstm_hidden = 512
lstm_drop = 0.5
final_drop = 0.5
embed_drop = 0.0
highway = False
compute_class_weight = False
multilabel = True

In [5]:
data_dir = Path(data_dir)
wv_path = (data_dir / "{:s}_wv.npy".format(lang)).__str__()
X_path = (data_dir / "{:s}_X.npy".format(lang)).__str__()
y_path = (data_dir / "{:s}_y.npy".format(lang)).__str__()
emoji_path = (data_dir / "{:s}_emoji.txt".format(lang)).__str__()

wv = np.load(wv_path, allow_pickle=True)
input_vec = np.load(X_path, allow_pickle=True)
input_label = np.load(y_path, allow_pickle=True)

nb_tokens = len(wv)
embed_dim = wv.shape[1]
input_len = len(input_label)
nb_classes = input_label.shape[1]
maxlen = input_vec.shape[1]

train_end = int(input_len*0.7)
val_end = int(input_len*0.9)

(X_train, y_train) = (input_vec[:train_end], input_label[:train_end])
(X_val, y_val) = (input_vec[train_end:val_end], input_label[train_end:val_end])
(X_test, y_test) = (input_vec[val_end:], input_label[val_end:])

if multilabel:
    def to_multilabel(y):
        outputs = []
        for i in range(nb_classes):
            outputs.append(y[:, i])
        return outputs

    y_train = to_multilabel(y_train)
    y_val = to_multilabel(y_val)
    y_test = to_multilabel(y_test)

model = elsa_architecture(nb_classes=nb_classes,
                          nb_tokens=nb_tokens,
                          maxlen=maxlen,
                          final_dropout_rate=final_drop,
                          embed_dropout_rate=embed_drop,
                          load_embedding=True,
                          pre_embedding=wv,
                          high=highway,
                          embed_dim=embed_dim,
                          multilabel=multilabel)
model.summary()

computed_class_weight = None

if multilabel:
    loss = "binary_crossentropy"
else:
    loss = "categorical_crossentropy"
    if compute_class_weight:
        y_train_sps = []
        for row in y_train:
            y_train_sps.extend(np.where(row)[0].tolist())
        computed_class_weight = class_weight.compute_class_weight(
            'balanced', list(range(nb_classes)), y_train_sps)
        print("computed class weight = {:s}".format(str(computed_class_weight)))

if optimizer == 'adam':
    adam = Adam(clipnorm=1, lr=lr)
    model.compile(loss=loss, optimizer=adam, metrics=['accuracy'])
elif optimizer == 'rmsprop':
    model.compile(loss=loss, optimizer='rmsprop', metrics=['accuracy'])

checkpoint_dir = Path(checkpoint_dir)
if not checkpoint_dir.exists():
    checkpoint_dir.mkdir()
checkpoint_weight_path = (checkpoint_dir / "elsa_{:s}.hdf5".format(lang)).__str__()

callbacks = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss', min_delta=0, patience=patience, verbose=0, mode='auto'),
    keras.callbacks.ModelCheckpoint(checkpoint_weight_path, monitor='val_loss',
                                    verbose=0, save_best_only=True, save_weights_only=False, mode='auto', period=1)
]
model.fit(X_train,
          y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(X_val, y_val),
          class_weight=computed_class_weight,
          callbacks=callbacks,
          verbose=1)

freq = {line.split()[0]: int(line.split()[1]) for line in open(emoji_path).readlines()}
freq_topn = sorted(freq.items(), key=itemgetter(1), reverse=True)[:nb_classes]

if multilabel:
    y_pred = model.predict([X_test], batch_size=batch_size)
    y_pred = [np.squeeze(p) for p in y_pred]

    y_test_1d = np.array(y_test).flatten()
    y_pred_1d = np.array(y_pred).flatten()
    print(f1_score(y_test_1d, y_pred_1d > 0.5))
    print(classification_report(y_test_1d, y_pred_1d > 0.5))

    gold, pred = [], []
    for i in range(len(X_test)):
        each_gold, each_pred = [], []
        for c in range(nb_classes):
            if y_test[c][i] == 1.0:
                each_gold.append(c+1)
            else:
                each_gold.append(0)
            if y_pred[c][i] > 0.5:
                each_pred.append(c+1)
            else:
                each_pred.append(0)
        gold.extend(each_gold)
        pred.extend(each_pred)

    target_name = [""] + [e[0] for e in freq_topn]
    print(classification_report(gold, pred, target_names=target_name))
else:
    _, acc = model.evaluate(X_test, y_test, batch_size=batch_size, verbose=0)
    print(acc)

    y_pred = model.predict(X_test)
    print(classification_report(y_test.argmax(axis=1), y_pred.argmax(
        axis=1), target_names=[e[0] for e in freq_topn]))

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 20)           0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 20, 200)      64132000    input_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 20, 200)      0           embedding[0][0]                  
__________________________________________________________________________________________________
bi_lstm_0 (Bidirectio

Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Train on 2106921 samples, validate on 601978 samples
Epoch 1/100
Epoch 2/100


    250/2106921 [..............................] - ETA: 50:15 - loss: 3.8118 - sigmoid_0_loss: 0.3868 - sigmoid_1_loss: 0.2217 - sigmoid_2_loss: 0.1930 - sigmoid_3_loss: 0.1627 - sigmoid_4_loss: 0.1439 - sigmoid_5_loss: 0.0990 - sigmoid_6_loss: 0.0630 - sigmoid_7_loss: 0.0559 - sigmoid_8_loss: 0.0082 - sigmoid_9_loss: 0.0816 - sigmoid_10_loss: 0.0152 - sigmoid_11_loss: 0.0753 - sigmoid_12_loss: 0.0080 - sigmoid_13_loss: 0.0115 - sigmoid_14_loss: 0.0614 - sigmoid_15_loss: 0.0652 - sigmoid_16_loss: 0.0270 - sigmoid_17_loss: 0.0195 - sigmoid_18_loss: 0.0484 - sigmoid_19_loss: 0.0864 - sigmoid_20_loss: 0.0339 - sigmoid_21_loss: 0.0830 - sigmoid_22_loss: 0.0430 - sigmoid_23_loss: 0.0572 - sigmoid_24_loss: 0.0740 - sigmoid_25_loss: 0.1552 - sigmoid_26_loss: 0.0395 - sigmoid_27_loss: 0.0386 - sigmoid_28_loss: 0.0477 - sigmoid_29_loss: 0.0956 - sigmoid_30_loss: 0.0663 - sigmoid_31_loss: 0.0201 - sigmoid_32_loss: 0.0404 - sigmoid_33_loss: 0.0063 - sigmoid_34_loss: 0.0669 - sigmoid_35_loss: 0.03

Epoch 3/100


    250/2106921 [..............................] - ETA: 47:07 - loss: 3.6806 - sigmoid_0_loss: 0.3323 - sigmoid_1_loss: 0.3035 - sigmoid_2_loss: 0.1704 - sigmoid_3_loss: 0.1669 - sigmoid_4_loss: 0.1047 - sigmoid_5_loss: 0.1543 - sigmoid_6_loss: 0.0324 - sigmoid_7_loss: 0.0210 - sigmoid_8_loss: 6.5937e-04 - sigmoid_9_loss: 0.0530 - sigmoid_10_loss: 0.1131 - sigmoid_11_loss: 0.0721 - sigmoid_12_loss: 3.6187e-04 - sigmoid_13_loss: 0.0510 - sigmoid_14_loss: 0.0285 - sigmoid_15_loss: 0.0789 - sigmoid_16_loss: 0.0387 - sigmoid_17_loss: 0.0758 - sigmoid_18_loss: 0.0230 - sigmoid_19_loss: 0.0899 - sigmoid_20_loss: 0.0829 - sigmoid_21_loss: 0.0839 - sigmoid_22_loss: 0.0747 - sigmoid_23_loss: 0.0459 - sigmoid_24_loss: 0.0414 - sigmoid_25_loss: 0.0579 - sigmoid_26_loss: 0.0096 - sigmoid_27_loss: 0.0602 - sigmoid_28_loss: 0.0336 - sigmoid_29_loss: 0.0132 - sigmoid_30_loss: 0.0475 - sigmoid_31_loss: 0.0436 - sigmoid_32_loss: 0.0740 - sigmoid_33_loss: 0.0054 - sigmoid_34_loss: 0.0430 - sigmoid_35_lo

Epoch 4/100


    250/2106921 [..............................] - ETA: 48:21 - loss: 3.6778 - sigmoid_0_loss: 0.3520 - sigmoid_1_loss: 0.2191 - sigmoid_2_loss: 0.1570 - sigmoid_3_loss: 0.1379 - sigmoid_4_loss: 0.1436 - sigmoid_5_loss: 0.0881 - sigmoid_6_loss: 0.0499 - sigmoid_7_loss: 0.0646 - sigmoid_8_loss: 0.0100 - sigmoid_9_loss: 0.0691 - sigmoid_10_loss: 0.0877 - sigmoid_11_loss: 0.1283 - sigmoid_12_loss: 0.0074 - sigmoid_13_loss: 0.0369 - sigmoid_14_loss: 0.0698 - sigmoid_15_loss: 0.0996 - sigmoid_16_loss: 0.0171 - sigmoid_17_loss: 0.0335 - sigmoid_18_loss: 0.0172 - sigmoid_19_loss: 0.1143 - sigmoid_20_loss: 0.1332 - sigmoid_21_loss: 0.1029 - sigmoid_22_loss: 0.0383 - sigmoid_23_loss: 0.0649 - sigmoid_24_loss: 0.0393 - sigmoid_25_loss: 0.0736 - sigmoid_26_loss: 0.0198 - sigmoid_27_loss: 0.0590 - sigmoid_28_loss: 0.0204 - sigmoid_29_loss: 0.0362 - sigmoid_30_loss: 0.0962 - sigmoid_31_loss: 0.0511 - sigmoid_32_loss: 0.0318 - sigmoid_33_loss: 0.0115 - sigmoid_34_loss: 0.0390 - sigmoid_35_loss: 0.04

Epoch 5/100


    250/2106921 [..............................] - ETA: 47:51 - loss: 3.5191 - sigmoid_0_loss: 0.3770 - sigmoid_1_loss: 0.2454 - sigmoid_2_loss: 0.1506 - sigmoid_3_loss: 0.1288 - sigmoid_4_loss: 0.1701 - sigmoid_5_loss: 0.0977 - sigmoid_6_loss: 0.0780 - sigmoid_7_loss: 0.0448 - sigmoid_8_loss: 0.0012 - sigmoid_9_loss: 0.0262 - sigmoid_10_loss: 0.0898 - sigmoid_11_loss: 0.0703 - sigmoid_12_loss: 0.0015 - sigmoid_13_loss: 0.0451 - sigmoid_14_loss: 0.0416 - sigmoid_15_loss: 0.1016 - sigmoid_16_loss: 0.0453 - sigmoid_17_loss: 0.0141 - sigmoid_18_loss: 0.0358 - sigmoid_19_loss: 0.0692 - sigmoid_20_loss: 0.0602 - sigmoid_21_loss: 0.0699 - sigmoid_22_loss: 0.0217 - sigmoid_23_loss: 0.0648 - sigmoid_24_loss: 0.0559 - sigmoid_25_loss: 0.0796 - sigmoid_26_loss: 0.0319 - sigmoid_27_loss: 0.0662 - sigmoid_28_loss: 0.0058 - sigmoid_29_loss: 0.0182 - sigmoid_30_loss: 0.0396 - sigmoid_31_loss: 0.0217 - sigmoid_32_loss: 0.0176 - sigmoid_33_loss: 0.0049 - sigmoid_34_loss: 0.0265 - sigmoid_35_loss: 0.04

Epoch 6/100


    250/2106921 [..............................] - ETA: 47:17 - loss: 3.5004 - sigmoid_0_loss: 0.3536 - sigmoid_1_loss: 0.1804 - sigmoid_2_loss: 0.1567 - sigmoid_3_loss: 0.0918 - sigmoid_4_loss: 0.1566 - sigmoid_5_loss: 0.1175 - sigmoid_6_loss: 0.0572 - sigmoid_7_loss: 0.0312 - sigmoid_8_loss: 0.0339 - sigmoid_9_loss: 0.0920 - sigmoid_10_loss: 0.0376 - sigmoid_11_loss: 0.1246 - sigmoid_12_loss: 0.0144 - sigmoid_13_loss: 0.0457 - sigmoid_14_loss: 0.0674 - sigmoid_15_loss: 0.0756 - sigmoid_16_loss: 0.0159 - sigmoid_17_loss: 0.0348 - sigmoid_18_loss: 0.0156 - sigmoid_19_loss: 0.1052 - sigmoid_20_loss: 0.0562 - sigmoid_21_loss: 0.0387 - sigmoid_22_loss: 0.0086 - sigmoid_23_loss: 0.0416 - sigmoid_24_loss: 0.0734 - sigmoid_25_loss: 0.0390 - sigmoid_26_loss: 0.0433 - sigmoid_27_loss: 0.0353 - sigmoid_28_loss: 0.0152 - sigmoid_29_loss: 0.0272 - sigmoid_30_loss: 0.0078 - sigmoid_31_loss: 0.0686 - sigmoid_32_loss: 0.0524 - sigmoid_33_loss: 0.0053 - sigmoid_34_loss: 0.0271 - sigmoid_35_loss: 0.03

0.12943768046914536
              precision    recall  f1-score   support

         0.0       0.99      1.00      0.99  18962307
         1.0       0.62      0.07      0.13    300989

    accuracy                           0.98  19263296
   macro avg       0.80      0.54      0.56  19263296
weighted avg       0.98      0.98      0.98  19263296



  'precision', 'predicted', average, warn_for)


              precision    recall  f1-score   support

                   0.99      1.00      0.99  18962307
           😂       0.56      0.11      0.19     53014
           😭       0.52      0.03      0.05     25023
           ❤       0.71      0.02      0.04     16602
           🤣       0.29      0.00      0.00     11424
           😍       0.64      0.01      0.02     11829
           🥺       0.54      0.02      0.05      9305
           💜       0.63      0.10      0.18      5267
           🔥       0.66      0.14      0.22      5587
           ♀       0.63      0.66      0.65      1916
           💕       0.74      0.00      0.01      4927
           🙏       0.66      0.24      0.35      5626
           😩       0.46      0.00      0.00      6178
           ♂       0.57      0.27      0.37      1169
           👏       0.61      0.12      0.20      3192
           🥰       0.00      0.00      0.00      4646
           😊       0.74      0.01      0.01      5827
           🤦       0.63    