In [3]:
import sys
sys.path.append('../')
from data.dataset import get_data
from models.rnn import MyModel
from metrics import eval
from solver import solver_llm

import copy
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

In [None]:
# ---- experiment configuration ----

# Data: paths to the train / dev .tsv files and the mini-batch size per split.
train_file, test_file = "../data/train_en.tsv", "../data/dev_en.tsv"
train_batch_size = test_batch_size = 250

# Model dimensions and regularisation (handed to MyModel when it is built).
embedding_size, hidden_size = 12, 64
num_layers, num_class = 1, 2
dropout = 0.1
bidrectional = True  # (sic) spelling kept: the rest of the script uses this name

# Training run: number of epochs and the device string given to the model.
EPOCH = 20
device = "cpu"

# Load both splits. Besides the two datasets, get_data returns `alphabet` and
# `longest_sent` (both used to size the model) plus train_x / test_x.
train_dataset, test_dataset, alphabet, longest_sent, train_x, test_x = get_data(train_file, test_file)

# Shuffle only the training split. The validation pass just measures accuracy,
# so a fixed batch order keeps evaluation reproducible and avoids a needless
# permutation of the dev set every epoch.
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

# Build the network. MyModel is called positionally, so the argument order
# below must stay exactly as it is.
model = MyModel(
    len(alphabet), longest_sent,
    embedding_size, hidden_size, num_layers,
    dropout, num_class, bidrectional,
    device,
)

# Adam with its default settings over all model parameters, and cross-entropy
# as the training loss.
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

# The solver bundles model, optimizer and loss; it provides the train() and
# validate() calls used by the epoch loop.
solver = solver_llm.SolverLLM(model, optimizer, criterion, model_type="RNN")

## training loop
# Best validation result so far: its accuracy, the matrix returned alongside
# it (presumably a confusion matrix -- confirm in the solver), and a snapshot
# of the model from that epoch.
best, best_cm, best_model = 0.0, None, None

# Per-epoch accuracy histories, stored as NumPy values for plotting.
train_acc_epoch, valid_acc_epoch = [], []

for epoch in range(EPOCH):
    # Expose the current epoch index on the solver before running it.
    solver.epoch = epoch

    # Training pass over the shuffled training batches.
    train_acc, train_cm = solver.train(train_loader)
    train_acc_epoch.append(train_acc.detach().cpu().numpy())

    # Evaluation pass over the held-out batches.
    valid_acc, valid_cm = solver.validate(test_loader)
    valid_acc_epoch.append(valid_acc.detach().cpu().numpy())

    # Strict '>' means the earliest epoch wins on ties.
    if valid_acc > best:
        best, best_cm = valid_acc, valid_cm
        # Deep copy so the snapshot is independent of the live `model` object.
        best_model = copy.deepcopy(model)

# Report the best validation accuracy seen during training.
print('Best Prec @1 Acccuracy: {:.4f}'.format(best))
if best_cm is None:
    # best_cm is only assigned when an epoch beats the initial best of 0.0, so
    # it is still None when validation accuracy never rose above zero (or no
    # epoch ran). Say so instead of failing on None.diag().
    print("No per-class accuracy available: validation accuracy never exceeded 0.")
else:
    # NOTE(review): the diagonal is printed as per-class accuracy, which assumes
    # the solver returns a row-normalised confusion matrix -- confirm.
    # .cpu() comes first because Tensor.numpy() only works on CPU tensors; this
    # mirrors how the per-epoch accuracies are converted in the training loop.
    per_cls_acc = best_cm.diag().detach().cpu().numpy().tolist()
    for i, acc_i in enumerate(per_cls_acc):
        print("Accuracy of Class {}: {:.4f}".format(i, acc_i))


# Plot training vs. validation accuracy, one point per epoch.
epoch_axis = range(EPOCH)
plt.figure()
for history, curve_label in ((train_acc_epoch, 'train'), (valid_acc_epoch, 'validation')):
    plt.plot(epoch_axis, history, label=curve_label)
plt.legend()
plt.title("accuracy curve")
plt.xlabel('epoch')
plt.ylabel('accuracy')