In [36]:
import keras
import pandas as pd
from keras_self_attention import SeqSelfAttention
from keras.layers import Dense,Embedding, Flatten, Conv1D, GlobalMaxPooling1D, LSTM, Bidirectional, Dropout
from keras.preprocessing.text import Tokenizer, text_to_word_sequence
from keras.preprocessing.sequence import pad_sequences

In [37]:
upsampled_input_cc_types_df = pd.read_csv("../upsampled_train_val_test/train_no_charade.csv")
val_cc_types_df = pd.read_csv("../upsampled_train_val_test/val_no_charade.csv")
test_cc_types_df = pd.read_csv("../upsampled_train_val_test/test_no_charade.csv")

In [38]:
tokenizer = Tokenizer(filters='"#$%&()*+-/:;<=>?@[\]^_`{|}~')
tokenizer.fit_on_texts(pd.concat([upsampled_input_cc_types_df,val_cc_types_df,test_cc_types_df])['clue'])

In [49]:
cc_input_df = upsampled_input_cc_types_df.drop(['category'],axis=1)
cc_val_df = val_cc_types_df.drop_duplicates().drop(['category'],axis=1)
cc_test_df = test_cc_types_df.drop_duplicates().drop(['category'],axis=1)

In [50]:
cc_input_df['clue'] = cc_input_df['clue'].apply(lambda x:' '.join(text_to_word_sequence(x)))
cc_val_df['clue'] = cc_val_df['clue'].apply(lambda x:' '.join(text_to_word_sequence(x)))
cc_test_df['clue'] = cc_test_df['clue'].apply(lambda x:' '.join(text_to_word_sequence(x)))

In [51]:
cc_input_data = pad_sequences(tokenizer.texts_to_sequences(cc_input_df.clue.tolist()),maxlen=15)
cc_val_data = pad_sequences(tokenizer.texts_to_sequences(cc_val_df.clue.tolist()),maxlen=15)
cc_test_data = pad_sequences(tokenizer.texts_to_sequences(cc_test_df.clue.tolist()),maxlen=15)

In [52]:
cc_input_data_out = cc_input_df[cc_input_df.columns[2:]] * 1
cc_val_data_out = cc_val_df[cc_val_df.columns[2:]] * 1
cc_test_data_out = cc_test_df[cc_test_df.columns[2:]] * 1

In [69]:
import keras
from keras_self_attention import SeqSelfAttention


model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=len(tokenizer.index_word)+1,input_shape=(15,),output_dim=30,mask_zero=False, name='Embedding'))
model.add(SeqSelfAttention(attention_activation='sigmoid',name='self_attention'))
model.add(keras.layers.Bidirectional(keras.layers.LSTM(units=128,name='LSTM')))
# model.add(Flatten())
model.add(keras.layers.Dense(units=14,activation='softmax',name='output'))
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['categorical_accuracy'],
)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Embedding (Embedding)        (None, 15, 30)            866340    
_________________________________________________________________
self_attention (SeqSelfAtten (None, 15, 30)            1985      
_________________________________________________________________
bidirectional_21 (Bidirectio (None, 256)               162816    
_________________________________________________________________
output (Dense)               (None, 14)                3598      
Total params: 1,034,739
Trainable params: 1,034,739
Non-trainable params: 0
_________________________________________________________________


In [70]:
history = model.fit(cc_input_data,cc_input_data_out,validation_data=(cc_val_data,cc_val_data_out),validation_split=0.1,epochs=10,batch_size=1028)

Train on 109707 samples, validate on 1920 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [75]:
model.get_layer('self_attention').get_weights()

[array([[ 6.56322688e-02, -2.51274049e-01,  5.90076819e-02,
         -4.07265604e-01, -9.18292403e-02, -2.27919623e-01,
          3.14504474e-01,  1.17114253e-01, -2.48483613e-01,
         -7.26953819e-02, -1.30351111e-01,  6.37510419e-02,
          2.20103525e-02,  3.70873272e-01, -7.57387141e-03,
         -5.41849881e-02, -1.28409758e-01,  1.39092922e-01,
         -6.19974881e-02, -8.08947608e-02,  1.32738262e-01,
          4.97642234e-02,  9.81043503e-02, -2.39714399e-01,
         -3.53506178e-01,  1.00431532e-01,  4.45404723e-02,
          6.08091317e-02, -2.07012996e-01, -2.57985413e-01,
         -1.28019691e-01, -1.08170278e-01],
        [ 3.77329835e-03, -1.51742455e-02,  1.69313341e-01,
         -1.19187452e-01,  3.03264529e-01,  1.33833885e-01,
          1.76806096e-02,  6.77089393e-02, -2.46614948e-01,
          4.74581011e-02, -7.02788979e-02, -1.05065234e-01,
         -1.06939830e-01,  5.65955266e-02, -2.10264012e-01,
          4.67859721e-03, -9.56497565e-02, -6.91840351e-

In [33]:
len(cc_test_data)

393