<a href="https://colab.research.google.com/github/jdasam/mas1004-2022/blob/main/notebooks/Data_AI_Week13_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RNN Tutorial
### With Names from Different Countries

In [None]:
import torch
import torch.nn as nn
from pathlib import Path
import matplotlib.pyplot as plt

import random
from tqdm.auto import tqdm

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
DEV = 'cuda' if torch.cuda.is_available() else 'cpu'

## 0. Get Dataset

In [None]:
!wget https://download.pytorch.org/tutorial/data.zip
!unzip data.zip

In [None]:
data_dir = Path('data/names')
# Sort the files so categories are always visited in the same order;
# the order produced by glob depends on the filesystem.
txt_fns = sorted(data_dir.glob("*.txt"))

# Maps category name (file stem) -> list of raw lines read from that file.
names_in_dict = {}
for txt_fn in txt_fns:
  # Decode explicitly instead of relying on the platform default encoding
  # (assumes the downloaded files are UTF-8 -- TODO confirm).
  with open(txt_fn, encoding='utf-8') as f:
    name_of_countries = f.readlines()
  print(f"Category: {txt_fn.stem}")
  names_in_dict[txt_fn.stem] = name_of_countries

In [None]:
# For every category, show its first name and how many names it contains.
for key, names in names_in_dict.items():
  # Lines read in text mode end with a single '\n', so strip whitespace rather
  # than slicing off a fixed number of characters (which would also drop the
  # last letter of the name). An empty category prints an empty name.
  first_name = names[0].strip() if names else ''
  print(f"{key}: {first_name}, {len(names)}")

## 1. Define Dataset

In [None]:
def normalize_name(name):
  """Return `name` lower-cased, without newlines, and with non-breaking
  spaces (U+00A0) turned into ordinary spaces."""
  # One translation pass: drop '\n', map '\xa0' to ' '.
  char_map = str.maketrans({'\n': None, '\xa0': ' '})
  cleaned = name.translate(char_map)
  return cleaned.lower()


### 1-1 Add zero-padding 
- Names in the dataset have different lengths
- Therefore, you have to add zero-padding so that every sample in a batch has the same length
  - If you use zero-padding, it is better not to use index 0 for any input category
  - For example, if you represent `"a"` as categorical index 0 and then use zero-padding, you cannot tell whether `[0, 0, 0, 0]` is just padding or `[a, a, a, a]`

## 2. Make RNN Model
### 2-1. See how RNN works
- Since our input is a categorical index, we will use nn.Embedding
- ![Diagram](https://datascience-enthusiast.com/figures/rnn_step_forward.png)

### 2-2. Make Name Classification Model
- Input: A sequence of characters, given as categorical indices
  - The length of the input sequence is arbitrary
- Output: The probability of each nationality for the given name (a sequence of characters)
  - Regardless of the input length, the output is a single vector
  - Softmax output over the classes

### 2-3. Complete Trainer

In [4]:
class Trainer:
  def __init__(self, model, train_loader, valid_loader, model_name='resnet'):
    self.model = model
    self.train_loader = train_loader
    self.valid_loader = valid_loader
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.model.to(self.device)
    self.criterion = nn.NLLLoss()
    self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
    self.best_loss = np.inf
    self.best_acc = 0.0
    self.train_losses = []
    self.valid_losses = []
    self.train_accs = []
    self.valid_accs = []
    self.model_name = model_name

  def validation(self):
    self.model.eval() # change the model from train mode to evaluation mode
    # Some models work in different ways based on whtehter it is on training step
    # or on inference step

    # In validation step, you don't have to calculate the gradient
    # with torch.no_grad():

    current_loss = 0
    num_total_correct_pred = 0
    with torch.inference_mode(): # every torch computation under this indent
    # will be run without calculating the gradient or computation history
      for batch in self.valid_loader:
        images, labels = batch
        images, labels = images.to(self.device), labels.to(self.device)
        outputs = self.model(images)
        probs = torch.softmax(outputs, dim=-1)
        log_probs = torch.log(probs)

        loss = self.criterion(log_probs, labels)
        predicted_classes = torch.argmax(outputs, dim=-1)
        num_acc_pred = (predicted_classes == labels.to(self.device)).sum()
        #num_acc_pred is on self.device
        num_total_correct_pred += num_acc_pred.item()
        # in validation stage, we don't care about single batch's loss
        # we want to see the result for total images of validation set

        current_loss += loss.item() * len(labels)
        # instead of adding the mean loss, we add sum of loss
        # because the batch size can be different
    mean_loss = current_loss / len(self.valid_loader.dataset)
    mean_acc = num_total_correct_pred / len(self.valid_loader.dataset) # number of total datasample in the validation loader
    return mean_loss, mean_acc
    # return {'loss': mean_loss, 'acc': mean_acc}



  def train_by_number_of_epochs(self, num_epochs):
    for epoch in tqdm(range(num_epochs)):
      self.model.train()
      for batch in tqdm(self.train_loader, leave=False):
        images, labels = batch
        images, labels = images.to(self.device), labels.to(self.device)
        self.optimizer.zero_grad()
        outputs = self.model(images) # this is logits
        probs = torch.softmax(outputs, dim=-1)
        log_probs = torch.log(probs)
        loss = self.criterion(log_probs, labels) # you have to feed log_probs

        acc = (torch.argmax(outputs, dim=-1) == labels.to(self.device)).sum() / len(labels)
        # for torch.nn.NLLLoss
        loss.backward()
        self.optimizer.step()

        self.train_losses.append(loss.item())
        self.train_accs.append(acc.item())
        # don't try self.train_losses.append(loss)
      # training step has ended
      # we want to test our model on the validation set
      valid_loss, valid_acc = self.validation()

      # is this model the best? 
      # let's decide it based on valid_acc
      if valid_acc > self.best_acc:
        self.best_acc = valid_acc

        # If it is the best model, save the model's weight'
        models_parameters = self.model.state_dict()
        print(f"Saving best model at epoch {len(self.valid_accs)}, acc: {valid_acc}")
        torch.save(models_parameters, f'{self.model_name}_best.pt')

      self.valid_losses.append(valid_loss)
      self.valid_accs.append(valid_acc)

    # Plot Accuracy curve
    plt.plot(self.train_accs)
    plt.plot(range(len(self.train_loader)-1, len(self.train_accs), len(self.train_loader)) ,self.valid_accs)
    plt.title("Accuracy")