In [None]:
import pickle

import numpy as np
import tensorflow
import tensorflow as tf
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from tensorflow.keras.utils import to_categorical

from src.dataloader import IMDB
from src.model import get_model
from src.embedding import *

# Load data and embeddings

In [None]:
# Load the run configuration (batch size, epoch budget, word2vec model path).
import yaml
import pickle

with open('config.yaml', 'r') as f:
    # safe_load: plain yaml.load() without an explicit Loader is deprecated since
    # PyYAML 5.1 and raises TypeError in PyYAML >= 6; safe_load also refuses to
    # construct arbitrary Python objects from the file.
    conf = yaml.safe_load(f)

BATCH_SIZE = conf["MODEL"]["BATCH_SIZE"]
MAX_EPOCHS = conf["MODEL"]["MAX_EPOCHS"]
WORD2VEC_MODEL = conf["EMBEDDING"]["WORD2VEC_MODEL"]

In [None]:
# Load the pre-split arrays (train / test / validation) and the word index
# saved by the preprocessing step.
print('load data ...')
X_train, y_train = np.load('data/X_train.npy'), np.load('data/y_train.npy')
X_test, y_test = np.load('data/X_test.npy'), np.load('data/y_test.npy')
X_val, y_val = np.load('data/X_val.npy'), np.load('data/y_val.npy')

# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open('data/word_index.pkl', 'rb') as pkl_file:
    word_index = pickle.load(pkl_file)

In [None]:
# Load the pretrained word2vec vectors (binary format) named in the config and
# build an embedding layer over the word index.
from gensim import models

word2vec_model = models.KeyedVectors.load_word2vec_format(
    WORD2VEC_MODEL,
    binary=True,
)
embeddings_index, embedding_dim = get_embeddings_index(word2vec_model)
# NOTE(review): the meaning of the trailing positional `True` is defined in
# src.embedding.get_embedding_layer (possibly a trainable flag) — confirm.
embedding_layer = get_embedding_layer(
    word_index,
    embeddings_index,
    embedding_dim,
    True,
)

# Drop the notebook's reference to the gensim model so it can be garbage
# collected once nothing else holds it.
word2vec_model = None

# Train model

In [None]:
# Build the model on top of the pretrained embedding layer.
# NOTE(review): the first argument `2` presumably is the number of output
# units/classes — confirm against src.model.get_model.
lstm = get_model(2, embedding_layer)

In [None]:
# Stop once validation loss has not improved for 2 consecutive epochs, and roll
# back to the best-scoring weights. Without restore_best_weights the model keeps
# the weights of the last epoch, which is `patience` epochs past the best one.
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=2, restore_best_weights=True)
# NOTE(review): binary_crossentropy together with get_model(2, ...) presumably
# means a 2-unit output trained against one-hot labels — confirm the output
# activation in src.model.get_model.
lstm.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(0.0001),
    metrics=['accuracy'])

In [None]:
# Print the layer-by-layer architecture and parameter counts.
lstm.summary()

In [None]:
# Train with early stopping. Validation uses the held-out *validation* split.
# Monitoring the test split here would leak test data into model selection,
# because EarlyStopping picks the stopping epoch from it, and that biases the
# final test score.
history = lstm.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=MAX_EPOCHS,
    callbacks=[callback],
    validation_data=(X_val, y_val),
)

# Training history

In [None]:
from matplotlib import pyplot as plt


def plot_history(history, metric, title, ylabel):
    """Draw and show train vs. validation curves for one metric of a Keras History.

    Args:
        history: the ``History`` object returned by ``model.fit``.
        metric: key in ``history.history`` for the training curve
            (e.g. ``'loss'``); the validation curve is read from
            ``'val_' + metric``.
        title: figure title.
        ylabel: y-axis label.
    """
    fig, ax = plt.subplots()
    ax.plot(history.history[metric])
    ax.plot(history.history['val_' + metric])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'validation'], loc='upper left')
    plt.show()


# TF 2.x / Keras >= 2.3 logs metrics=['accuracy'] as 'accuracy'/'val_accuracy';
# older Keras used 'acc'/'val_acc'. Use whichever key this run produced.
acc_metric = 'accuracy' if 'accuracy' in history.history else 'acc'
plot_history(history, acc_metric, 'model accuracy', 'accuracy')
plot_history(history, 'loss', 'model loss', 'loss')