# In [None]:
import torch
import argparse

# --- Training / evaluation --- #

def train_epoch(model, optimizer, train_loader, criterion, args, device=None):
    """Run one training epoch over ``train_loader`` and return its metrics.

    Args:
        model: Module called on each feature batch; predictions are taken as
            ``argmax`` over dim 1 of its output.
        optimizer: Optimizer that updates ``model``'s parameters.
        train_loader: Iterable of ``(feature, target)`` batches.
        criterion: Loss called as ``criterion(out, target)``. Presumably a
            classification loss on logits vs. class indices (e.g.
            CrossEntropyLoss) — not enforced here.
        args: Object exposing ``grad_clip``, the max gradient norm passed to
            ``torch.nn.utils.clip_grad_norm_``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or ``"cuda"`` if the model has none.

    Returns:
        ``(avg_loss, avg_accuracy)``: the mean of the per-batch loss values,
        and the fraction of samples whose prediction equals the target.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if device is None:
        # Follow the model rather than hard-coding .cuda(). This is identical on a
        # single-GPU setup and also works on CPU or a non-default GPU.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    model.train()
    total_loss, total_correct = 0.0, 0
    num_batches, num_samples = 0, 0

    for feature, target in train_loader:
        feature, target = feature.to(device), target.to(device)

        optimizer.zero_grad()

        out = model(feature)
        predicted = torch.argmax(out, dim=1)

        correct = predicted.squeeze() == target.int()
        total_correct += correct.sum().item()

        loss = criterion(out, target)
        total_loss += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        num_batches += 1
        # Count real samples. The last batch may be smaller than batch_size,
        # and loader.batch_size can be None with a custom batch_sampler.
        num_samples += target.size(0)

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    avg_loss = total_loss / num_batches
    avg_accuracy = total_correct / num_samples

    return avg_loss, avg_accuracy

def evaluate_epoch(model, data_loader, criterion, device=None):
    """Evaluate ``model`` on ``data_loader`` in eval mode with no gradients.

    Args:
        model: Module called on each feature batch; predictions are taken as
            ``argmax`` over dim 1 of its output.
        data_loader: Iterable of ``(feature, target)`` batches.
        criterion: Loss called as ``criterion(out, target)``. Presumably a
            classification loss on logits vs. class indices — not enforced here.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or ``"cuda"`` if the model has none.

    Returns:
        ``(avg_loss, avg_accuracy)``: the mean of the per-batch loss values,
        and the fraction of samples whose prediction equals the target.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if device is None:
        # Follow the model rather than hard-coding .cuda().
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    model.eval()
    total_loss, total_correct = 0.0, 0
    num_batches, num_samples = 0, 0

    with torch.no_grad():
        for feature, target in data_loader:
            feature, target = feature.to(device), target.to(device)

            out = model(feature)
            predicted = torch.argmax(out, dim=1)

            correct = predicted.squeeze() == target.int()
            total_correct += correct.sum().item()

            total_loss += criterion(out, target).item()

            num_batches += 1
            # Count real samples; the last batch may be smaller than batch_size.
            num_samples += target.size(0)

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    avg_loss = total_loss / num_batches
    avg_accuracy = total_correct / num_samples

    return avg_loss, avg_accuracy

def train_f(model, optimizer, train_loader, val_loader, criterion, epochs, es_threshold, args,
            checkpoint_path='./sentiment_lstm.pt'):
    """Train for up to ``epochs`` epochs with checkpointing and early stopping.

    After each epoch the model is evaluated on ``val_loader``. Whenever the
    validation loss strictly improves on the best value so far, the model's
    ``state_dict`` is saved to ``checkpoint_path``. Training stops once
    ``es_threshold`` consecutive epochs pass without such an improvement.

    Note: with ``es_threshold <= 0``, training stops after the first epoch.

    Args:
        model, optimizer, criterion, args: Forwarded to ``train_epoch`` (and,
            except ``optimizer``/``args``, to ``evaluate_epoch``).
        train_loader: Batches used for training.
        val_loader: Batches used for validation and checkpoint selection.
        epochs: Maximum number of epochs to run.
        es_threshold: Early-stopping patience, in epochs.
        checkpoint_path: Where the best ``state_dict`` is written with
            ``torch.save``.

    Returns:
        ``(train_losses, train_accuracies, val_losses, val_accuracies)``: lists
        with one entry per epoch actually run.
    """
    # float("inf") does not depend on the installed torch exposing torch.inf.
    val_loss_min = float("inf")
    es_trigger = 0

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    for e in range(1, epochs + 1):
        train_loss, train_accuracy = train_epoch(model, optimizer, train_loader, criterion, args)
        val_loss, val_accuracy = evaluate_epoch(model, val_loader, criterion)

        print("[EPOCH: %d], Train Loss: %5.2f | Train Accuracy: %5.2f%%" % (e, train_loss, train_accuracy * 100))
        print("[EPOCH: %d], Val Loss: %5.2f | Val Accuracy: %5.2f%%" % (e, val_loss, val_accuracy * 100))

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # A NaN val_loss compares False here, so it counts as "no improvement".
        if val_loss < val_loss_min:
            val_loss_min = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            es_trigger = 0
        else:
            es_trigger += 1

        if es_trigger >= es_threshold:
            print("Early stopping triggered!")
            break

    return train_losses, train_accuracies, val_losses, val_accuracies
