In [None]:
import csv
import os, random, sys, copy
import torch, torch.nn as nn, numpy as np
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from nltk.tokenize import word_tokenize
import matplotlib.pyplot as plt
# Adapted From HW3
def train_sarcasm_classification(model, train_dataset, valid_dataset, epochs=10, batch_size=32, learning_rate=.001, print_frequency=25, shuffle=True):
    """Train ``model`` with BCE-with-logits loss and keep the best-validating weights.

    Each epoch runs one pass over ``train_dataset`` with AdamW, then evaluates
    on ``valid_dataset`` through ``getMetrics`` (defined elsewhere in this
    file) and snapshots the model's state dict whenever the reported accuracy
    improves.

    Args:
        model: Module called as ``model(x, x_lengths)``. Its output is squeezed
            on dim 1 before the loss, so it presumably has shape (batch, 1)
            -- TODO confirm against the model definition.
        train_dataset: Dataset whose collated batches unpack as
            ``((x, x_lengths), y)``.
        valid_dataset: Dataset handed to ``getMetrics`` in batches of 128.
        epochs: Number of passes over ``train_dataset``.
        batch_size: Training batch size.
        learning_rate: AdamW learning rate.
        print_frequency: Log the running mean loss every this many steps.
            A falsy value (``0`` / ``None``) disables the per-batch log.
        shuffle: Reshuffle the training data every epoch. Pass ``False`` to
            iterate the dataset in its stored order.

    Returns:
        Tuple ``(final_state_dict, best_state_dict)``. ``final_state_dict`` is
        ``model.state_dict()`` after the last epoch; ``best_state_dict`` is a
        deep copy taken at the epoch with the highest validation accuracy, or
        ``None`` when ``epochs`` is 0.
    """
    criteria = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # The DataLoader handles batching and (for training) per-epoch shuffling.
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, shuffle=False)

    # len(DataLoader) counts the trailing partial batch, unlike a raw division.
    print('Total train batches: {}'.format(len(train_dataloader)))

    best_model_sd = None
    best_accuracy = 0.0

    for i in range(epochs):
        print('### Epoch: ' + str(i+1) + ' ###')

        model.train()

        running_loss = 0.0
        batches_since_print = 0
        for step, data in enumerate(train_dataloader):

            # The dataset yields the padded input together with the true lengths.
            (x, x_lengths), y = data
            optimizer.zero_grad()

            model_output = model(x, x_lengths)

            # BCEWithLogitsLoss needs float targets shaped like the logits.
            loss = criteria(model_output.squeeze(1), y.float())

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_since_print += 1

            if print_frequency and step % print_frequency == 0:
                # Divide by the batches actually accumulated: the window that
                # ends at step 0 holds a single batch, not print_frequency.
                print('epoch: {} batch: {} loss: {}'.format(
                    i + 1,
                    step,
                    running_loss / batches_since_print
                ))
                running_loss = 0.0
                batches_since_print = 0

        print('Evaluating...')
        model.eval()
        with torch.no_grad():
            acc = getMetrics(model, valid_dataloader, 'accuracy')
            # Always keep the first evaluated epoch so a snapshot exists even
            # when accuracy never rises above the initial 0.0.
            if best_model_sd is None or acc > best_accuracy:
                best_model_sd = copy.deepcopy(model.state_dict())
                best_accuracy = acc

    return model.state_dict(), best_model_sd