In [None]:
from process_tweets import get_data
from rnn import MyModel
from train_validate import (train, validate, convert_tweet2tensor)

import copy
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import torch
from torch.utils.data import DataLoader

### Train and validate vanilla LSTM model

In [None]:
# Hyperparameters and run configuration.
# NOTE(review): `bidrectional` is a misspelling of "bidirectional"; the name is kept
# because other cells or modules may refer to it.
train_batch_size = 250
test_batch_size = 250
embedding_size = 12
hidden_size = 64
num_layers = 1
dropout = 0.1
num_classes = 2
bidrectional = True
EPOCH = 20
train_file="../data/train_en.tsv"
test_file="../data/dev_en.tsv"
device="cpu"
seed = 0

# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are repeatable across "Restart Kernel -> Run All".
torch.manual_seed(seed)

train_dataset, test_dataset, alphabet, longest_sent, train_x, test_x = get_data(train_file, test_file)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps validation deterministic.
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = MyModel(
    len(alphabet),
    longest_sent,
    embedding_size,
    hidden_size,
    num_layers,
    dropout,
    num_classes,
    bidrectional,
    device
)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

## training loop
# Track the checkpoint with the highest validation accuracy seen so far.
best = 0.0
best_cm = None
best_model = None
train_acc_epoch = []
valid_acc_epoch = []
for epoch in range(EPOCH):
    # train loop
    train_acc, train_cm = train(epoch, train_loader, model, optimizer, criterion)
    train_acc_epoch.append(train_acc.detach().cpu().numpy())

    # validation loop
    valid_acc, valid_cm = validate(epoch, test_loader, model, criterion)
    valid_acc_epoch.append(valid_acc.detach().cpu().numpy())

    # Always record the first epoch so best_cm/best_model are never left as None
    # (previously a run whose accuracy never exceeded 0.0 crashed below).
    if best_model is None or valid_acc > best:
        best = valid_acc
        best_cm = valid_cm
        # Deep copy: `model` keeps being updated in later epochs.
        best_model = copy.deepcopy(model)

print('Best Prec @1 Acccuracy: {:.4f}'.format(best))
# .cpu() mirrors the accuracy handling above so this also works if the
# confusion matrix lives on a non-CPU device.
# NOTE(review): assumes the diagonal of the confusion matrix returned by
# validate() already holds per-class accuracy (i.e. it is normalised) — confirm.
per_cls_acc = best_cm.diag().detach().cpu().numpy().tolist()
for i, acc_i in enumerate(per_cls_acc):
    print("Accuracy of Class {}: {:.4f}".format(i, acc_i))

# save best model
# The context manager closes the file handle even if pickling fails.
# NOTE(review): this pickles the whole model object, so loading requires the same
# MyModel class to be importable; saving a state_dict is the more portable option.
with open("best_RNN_model.pkl", 'wb') as model_file:
    pickle.dump(best_model, model_file)

# Accuracy curves for both splits, one point per epoch.
fig, ax = plt.subplots()
ax.plot(range(EPOCH), train_acc_epoch, label='train')
ax.plot(range(EPOCH), valid_acc_epoch, label='validation')
ax.legend()
ax.set_title("accuracy curve")
ax.set_xlabel('epoch')
ax.set_ylabel('accuracy')
plt.show()

### Save best model

In [None]:
# Persist the best checkpoint. The context manager guarantees the file handle
# is flushed and closed (the bare open() call previously leaked it).
with open("best_RNN_model.pkl", 'wb') as model_file:
    pickle.dump(best_model, model_file)


### Test model

In [None]:
# Reload the saved checkpoint, closing the file handle once unpickling finishes.
# SECURITY: pickle.load can execute arbitrary code — only load files you created
# yourself or otherwise trust.
with open("best_RNN_model.pkl", 'rb') as model_file:
    loaded_model = pickle.load(model_file)

In [None]:
# Read the raw evaluation TSV into a DataFrame so the original tweet text and
# labels can be inspected next to the model's predictions.
test_df = pd.read_csv(
    test_file,
    sep='\t',
    skiprows=0,
    encoding='utf-8',
)
# Preview the first rows (rich display as the cell's last expression).
test_df.head()

In [None]:
# Spot-check the reloaded model on a fixed random sample of evaluation tweets.
# NOTE(review): assumes MyModel is a torch.nn.Module (it exposes .parameters()
# and .forward) — confirm.
loaded_model.eval()  # inference mode: disables dropout so predictions are deterministic
test_sample = test_df.sample(10, random_state=0)
with torch.no_grad():  # no gradients needed for prediction
    for idx, row in test_sample.iterrows():
        test_tensor = convert_tweet2tensor(row['text'], alphabet, longest_sent)
        print(f"test text {row['text']}")
        # Call the module itself rather than .forward so registered hooks run.
        pred_prob = loaded_model(test_tensor)
        # argmax without `dim` returns a 0-dim tensor; .item() prints the plain
        # label instead of "tensor(...)".
        pred_label = torch.argmax(pred_prob).item()
        print(f"true label {row['HS']}")
        print(f"best model predicted label {pred_label}")
        print()
