# In [None]:
import os
import random
import time
import datetime
import numpy as np
from tqdm import tqdm as tq
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

import torch
from transformers import AdamW, get_linear_schedule_with_warmup



def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every random number generator below.
    """
    random.seed(seed)
    # Only affects child processes: the running interpreter's hash seed is
    # fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # NumPy's legacy global seeding only accepts values in [0, 2**32).
    np.random.seed(int(seed) % 2**32)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly different kernels
    # from run to run, which defeats reproducibility; keep it off.
    torch.backends.cudnn.benchmark = False
    
def train(model, train_dataloader, val_dataloader, criterion, epochs, learning_rate):
    """Train ``model`` and restore the weights with the lowest validation loss.

    Uses AdamW with a linear-decay schedule (no warmup) stepped once per
    batch. After each epoch the model is scored with ``evaluate``. Whenever
    the validation loss improves, a CPU copy of the weights is snapshotted.
    At the end, the best snapshot is loaded back into ``model``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``.
            Its output goes to ``criterion`` and its ``argmax(1)`` is used as
            the class prediction.
        train_dataloader: sized iterable yielding
            ``(input_ids, attention_mask, labels)`` tensor triples.
        val_dataloader: same format, used for validation after each epoch.
        criterion: loss callable ``criterion(outputs, labels)``.
        epochs: number of passes over ``train_dataloader``.
        learning_rate: initial learning rate, decayed linearly to 0.

    Returns:
        ``(best_model, train_losses, train_accs, val_losses, val_accs)``.
        The lists hold one value per epoch; losses are means of per-batch
        losses. ``best_model`` is ``model`` itself (moved to the selected
        device) with the best-validation weights loaded. It is ``None`` if
        no epoch produced a validation loss below ``inf`` (e.g.
        ``epochs == 0``).
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    num_batches = len(train_dataloader)
    # The scheduler is stepped per batch, so its horizon is counted in batches.
    total_steps = num_batches * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,
                                                num_training_steps=total_steps)

    best_val_loss = float('inf')
    best_state = None

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        print('-' * 10)

        model.train()
        train_loss = 0
        correct_preds = 0
        total_preds = 0

        start_time = time.time()
        for idx, batch in enumerate(train_dataloader):
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            correct_preds += (outputs.argmax(1) == labels).sum().item()
            total_preds += labels.size(0)

            # Single-line progress report, overwritten in place via '\r':
            # elapsed time, projected total time and remaining time (ETA).
            elapsed_time = time.time() - start_time
            elapsed_time_per_step = elapsed_time / (idx + 1)
            eta_dt = datetime.timedelta(seconds=int(elapsed_time_per_step * (num_batches - idx - 1)))
            spent_dt = datetime.timedelta(seconds=int(elapsed_time))
            print(f"[TRAINING]  {idx+1}/{num_batches} | loss : {loss.item():.4f} | time : {spent_dt}-{spent_dt + eta_dt} | eta : {eta_dt}", end='\r')
        print()

        train_loss /= num_batches
        train_acc = correct_preds / total_preds
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        print(f"Train loss {np.round(train_loss, 4)} accuracy {np.round(train_acc, 4)}")

        val_loss, val_acc = evaluate(model, val_dataloader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f"Val   loss {np.round(val_loss, 4)} accuracy {np.round(val_acc, 4)}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Copy the weights. Keeping a reference to `model` would not work:
            # later epochs keep updating that same object, so the "best" model
            # would silently become the last one. Copies live on CPU to spare
            # accelerator memory.
            best_state = {
                k: v.detach().to('cpu', copy=True) if torch.is_tensor(v) else v
                for k, v in model.state_dict().items()
            }

    best_model = None
    if best_state is not None:
        # load_state_dict copies the CPU tensors back into the on-device params.
        model.load_state_dict(best_state)
        best_model = model

    return best_model, train_losses, train_accs, val_losses, val_accs


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without tracking gradients.

    Puts the model in eval mode and prints a single-line progress/ETA report
    that is overwritten in place with '\\r'.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``.
        dataloader: sized iterable of ``(input_ids, attention_mask, labels)``
            tensor triples.
        criterion: loss callable ``criterion(outputs, labels)``.
        device: device each batch tensor is moved to before the forward pass.

    Returns:
        ``(mean of per-batch losses, accuracy over all samples)``.
    """
    model.eval()
    n_batches = len(dataloader)
    summed_loss = 0
    n_correct = 0
    n_seen = 0

    t0 = time.time()
    with torch.no_grad():
        for step, batch in enumerate(dataloader, start=1):
            input_ids, attention_mask, labels = (t.to(device) for t in batch)
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            summed_loss += criterion(logits, labels).item()

            n_correct += (logits.argmax(1) == labels).sum().item()
            n_seen += labels.size(0)

            # Remaining time is projected from the mean duration of batches so far.
            spent_secs = time.time() - t0
            secs_per_step = spent_secs / step
            eta_dt = datetime.timedelta(seconds=int(secs_per_step * (n_batches - step)))
            spent_dt = datetime.timedelta(seconds=int(spent_secs))
            print(f"[VALIDATION]  {step}/{n_batches} | time : {spent_dt}-{spent_dt + eta_dt} | eta : {eta_dt}", end='\r')
        print()
    return summed_loss / n_batches, n_correct / n_seen


def inference(model, dataloader, *, return_predictions=False):
    """Compute the classification accuracy of ``model`` on ``dataloader``.

    Runs on CUDA when available, otherwise on CPU, in eval mode with
    gradients disabled.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``.
            ``argmax(1)`` of its output is taken as the predicted class.
        dataloader: iterable of ``(input_ids, attention_mask, labels)``
            tensor triples.
        return_predictions: if True, also return the predicted class indices.

    Returns:
        ``accuracy`` (float), or ``(accuracy, predictions)`` when
        ``return_predictions`` is True. ``predictions`` is a list of ints in
        dataloader order.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    model.eval()

    all_preds = []
    correct_preds = 0
    total_preds = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = outputs.argmax(1)  # computed once, reused below
            if return_predictions:
                # .tolist() forces a device->host copy; only pay it when asked.
                all_preds.extend(preds.tolist())
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)

    accuracy = correct_preds / total_preds

    if return_predictions:
        return accuracy, all_preds
    return accuracy


def result(train_losses, train_accs, val_losses, val_accs, save_path='result.png'):
    """Plot per-epoch training/validation curves in a 2x2 grid and save them.

    Layout: the left column shows accuracy (train on top, validation below)
    and the right column shows loss (train on top, validation below). The
    figure is saved to ``save_path``. It is not closed, so an interactive or
    inline backend can still display it.

    Args:
        train_losses, train_accs, val_losses, val_accs: one value per epoch,
            all of the same length.
        save_path: output image path (default ``'result.png'``).
    """
    # 1-based, to match the "Epoch k/N" numbering printed by train().
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(20, 10))
    # (subplot index in the 2x2 grid, series, y-axis label)
    panels = (
        (1, train_accs, 'train accuracy'),
        (3, val_accs, 'validation accuracy'),
        (2, train_losses, 'train loss'),
        (4, val_losses, 'validation loss'),
    )
    for position, values, ylabel in panels:
        plt.subplot(2, 2, position)
        plt.plot(epochs, values, '.-')
        plt.ylabel(ylabel)
        if position > 2:  # only the bottom row gets the x-axis label
            plt.xlabel('epochs')

    plt.savefig(save_path)
