<a href="https://colab.research.google.com/github/hritiksth764/END-COURSE-SESSION-4.0-/blob/main/Session5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from torchtext.datasets import DBpedia, YelpReviewPolarity, SogouNews, AG_NEWS
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# torchtext "basic_english" tokenizer, shared by vocabulary building and the
# text pipeline.
tokenizer = get_tokenizer("basic_english")

# Hyperparameters
EPOCHS = 5        # number of training epochs run by train_and_validate
LR = 5            # initial SGD learning rate (decayed by the StepLR scheduler)
BATCH_SIZE = 64   # mini-batch size for the train / validation / test loaders


# Model: bag-of-embeddings text classifier (EmbeddingBag pooling + one Linear layer)
class Classification(nn.Module):
    """Bag-of-embeddings text classifier.

    The token ids of each example are pooled by an ``nn.EmbeddingBag`` (its
    default pooling mode), and the pooled vector is projected to class logits
    by a single linear layer.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: size of each embedding vector.
        num_class: number of output logits.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # sparse=True: the embedding table receives sparse gradients.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw both weight matrices from U(-0.5, 0.5) and zero the fc bias."""
        bound = 0.5
        # Embedding first, then fc: the same RNG draw order as before.
        for weight in (self.embedding.weight, self.fc.weight):
            weight.data.uniform_(-bound, bound)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return class logits, one row per bag.

        Args:
            text: 1-D tensor of concatenated token ids for the whole batch.
            offsets: start index of each example inside ``text``.
        """
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)

# Driver: builds the vocabulary and data pipeline, trains/evaluates the model, and classifies single texts
class MainClass:
    """End-to-end driver for a torchtext text-classification dataset.

    On construction, the training split of ``dataset_class`` is tokenized
    with the module-level ``tokenizer`` to build a vocabulary with an
    ``<unk>`` fallback. :meth:`train_and_validate` trains a
    :class:`Classification` model, reports validation accuracy after every
    epoch and test accuracy at the end. :meth:`predict` /
    :meth:`predict_ouput` classify one piece of text.

    The code treats dataset labels as 1-based integers: ``label_pipeline``
    subtracts 1 for training, and :meth:`predict` adds 1 back.
    """

    def __init__(self, dataset_class=AG_NEWS):
        """Build the vocabulary and the text/label pipelines.

        Args:
            dataset_class: torchtext dataset callable. It is called with
                ``split='train'`` here, and with no arguments (expected to
                return ``(train, test)`` iterators) in
                :meth:`train_and_validate`.
        """
        self.model = None
        self.dataset_class = dataset_class

        train_iter = self.dataset_class(split='train')
        self.vocab = build_vocab_from_iterator(self.yield_tokens(train_iter), specials=["<unk>"])
        # Tokens missing from the vocabulary map to the "<unk>" index.
        self.vocab.set_default_index(self.vocab["<unk>"])

        self.text_pipeline = lambda x: self.vocab(tokenizer(x))
        self.label_pipeline = lambda x: int(x) - 1

    @staticmethod
    def yield_tokens(data_iter):
        """Yield the token list of each ``(label, text)`` pair in ``data_iter``."""
        for _, text in data_iter:
            yield tokenizer(text)

    def train(self, dataloader, epoch):
        """Run one training epoch over ``dataloader``.

        Prints the running accuracy every ``log_interval`` batches. The
        counters are reset after each report.

        Args:
            dataloader: yields ``(label, text, offsets)`` batches as built by
                :meth:`collate_batch`.
            epoch: epoch number, used only in the log lines.
        """
        self.model.train()
        total_acc, total_count = 0, 0
        log_interval = 500

        for idx, (label, text, offsets) in enumerate(dataloader):
            self.optimizer.zero_grad()
            predicted_label = self.model(text, offsets)
            loss = self.criterion(predicted_label, label)
            loss.backward()
            # Clip the total gradient norm to 0.1 (presumably to tame the
            # large SGD learning rate).
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.1)
            self.optimizer.step()
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
            if idx % log_interval == 0 and idx > 0:
                print('| epoch {:3d} | {:5d}/{:5d} batches '
                      '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                                  total_acc / total_count))
                total_acc, total_count = 0, 0

    def evaluate(self, dataloader):
        """Return the model's classification accuracy on ``dataloader``.

        Raises:
            ZeroDivisionError: if ``dataloader`` yields no examples.
        """
        self.model.eval()
        total_acc, total_count = 0, 0

        with torch.no_grad():
            for label, text, offsets in dataloader:
                predicted_label = self.model(text, offsets)
                total_acc += (predicted_label.argmax(1) == label).sum().item()
                total_count += label.size(0)
        return total_acc / total_count

    def collate_batch(self, batch):
        """Collate ``(label, text)`` pairs into ``EmbeddingBag`` inputs.

        Returns:
            ``(labels, token_ids, offsets)``, all moved to ``device``:
            - ``labels``: the 0-based labels.
            - ``token_ids``: every example's token ids concatenated into one
              1-D tensor.
            - ``offsets``: the start index of each example within
              ``token_ids``.
        """
        label_list, text_list, offsets = [], [], [0]

        for (_label, _text) in batch:
            label_list.append(self.label_pipeline(_label))
            processed_text = torch.tensor(self.text_pipeline(_text), dtype=torch.int64)
            text_list.append(processed_text)
            offsets.append(processed_text.size(0))

        label_list = torch.tensor(label_list, dtype=torch.int64)
        # The running sum of the preceding lengths is each example's start offset.
        offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
        text_list = torch.cat(text_list)

        return label_list.to(device), text_list.to(device), offsets.to(device)

    def train_and_validate(self, emsize=64):
        """Train a fresh model, validating after each epoch, then test it.

        95% of the training split is used for training and the remaining 5%
        for validation. Whenever validation accuracy falls below the best seen
        so far, the StepLR scheduler is stepped, multiplying the learning rate
        by ``gamma=0.1``.

        Args:
            emsize: embedding dimension of the model. Defaults to 64, the
                previously hard-coded value.
        """
        train_iter, test_iter = self.dataset_class()
        train_dataset = to_map_style_dataset(train_iter)
        test_dataset = to_map_style_dataset(test_iter)

        # Count the classes from the already-materialised training data rather
        # than streaming the raw training split a second time.
        num_class = len({train_dataset[i][0] for i in range(len(train_dataset))})
        vocab_size = len(self.vocab)
        self.model = Classification(vocab_size, emsize, num_class).to(device)

        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=LR)
        scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1.0, gamma=0.1)

        num_train = int(len(train_dataset) * 0.95)
        split_train_, split_valid_ = random_split(
            train_dataset, [num_train, len(train_dataset) - num_train])

        train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                      shuffle=True, collate_fn=self.collate_batch)
        # Evaluation order does not change accuracy, so these loaders are not shuffled.
        valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                      shuffle=False, collate_fn=self.collate_batch)
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                     shuffle=False, collate_fn=self.collate_batch)

        best_accu = None  # best validation accuracy seen so far
        for epoch in range(1, EPOCHS + 1):
            epoch_start_time = time.time()
            self.train(train_dataloader, epoch)
            accu_val = self.evaluate(valid_dataloader)
            if best_accu is not None and best_accu > accu_val:
                # Validation accuracy regressed, so decay the learning rate.
                scheduler.step()
            else:
                best_accu = accu_val
            print('-' * 59)
            print('| end of epoch {:3d} | time: {:5.2f}s | '
                  'valid accuracy {:8.3f} '.format(epoch,
                                                   time.time() - epoch_start_time,
                                                   accu_val))
            print('-' * 59)

        print('Checking the results of test dataset.')
        accu_test = self.evaluate(test_dataloader)
        print('test accuracy {:8.3f}'.format(accu_test))

    def predict(self, text, text_pipeline):
        """Classify a single text and return its label in 1-based numbering.

        The inputs are built on the same device as the model, so the model
        does not need to be moved first. The raw logits are printed.

        Args:
            text: the raw input string.
            text_pipeline: callable mapping ``text`` to a list of token ids.

        Returns:
            ``argmax(logits) + 1``, matching the dataset's original labels.
        """
        model_device = next(self.model.parameters()).device
        self.model.eval()
        with torch.no_grad():
            # An explicit int64 dtype keeps an empty token list valid for EmbeddingBag.
            token_ids = torch.tensor(text_pipeline(text), dtype=torch.int64,
                                     device=model_device)
            offsets = torch.tensor([0], dtype=torch.int64, device=model_device)
            output = self.model(token_ids, offsets)
            print(output)
            return output.argmax(1).item() + 1

    def predict_ouput(self, input_text, labels_dict):
        """Print the human-readable label predicted for ``input_text``.

        Args:
            input_text: the raw text to classify.
            labels_dict: maps the 1-based label returned by :meth:`predict` to
                a display name. A missing key raises ``KeyError``.
        """
        # The model is no longer moved to CPU here: predict() follows the
        # model's device, so later train/evaluate calls still work on GPU.
        label = self.predict(input_text, self.text_pipeline)
        print("This is a %s news" % labels_dict[label])

    # Correctly spelled alias; ``predict_ouput`` is kept for existing callers.
    predict_output = predict_ouput


# Train on SogouNews, then classify one pinyin-romanised news snippet.
sogou_news_labels = {
    1: 'Sports',
    2: 'Finance',
    3: 'Entertainment',
    4: 'Automobile',
    5: 'Technology',
}
sogou_sample = ('su4 du4 : ( shuo1 mi2ng : dia3n ji1 zi4 do4ng bo1 fa4ng )\n'
                ' shuo1 mi2ng : dia3n ji1 ga1i a4n niu3 , xua3n ze2 yi1 lu4n ta2n ji2 ke3 ')

main_class = MainClass(SogouNews)
main_class.train_and_validate()
main_class.predict_ouput(input_text=sogou_sample, labels_dict=sogou_news_labels)

In [2]:
# Train on DBpedia (14 classes) and classify one article.
# MainClass.predict returns ``argmax + 1`` (1..14 here), so labels_dict must be
# keyed 1..14. The previous 0..13 keys shifted every prediction by one class
# and raised KeyError for class 14.
main_class = MainClass(DBpedia)
main_class.train_and_validate()
main_class.predict_ouput(
    input_text='Brekke Church (Norwegian: Brekke kyrkje is a parish church in Gulen Municipality in Sogn og Fjordane county, Norway. It is located in the village of Brekke. The church is part of the Brekke parish in the Nordhordland deanery in the Diocese of Bjørgvin. The white, wooden church, which has 390 seats, was consecrated on 19 November 1862 by the local Dean Thomas Erichsen. The architect Christian Henrik Grosch made the designs for the church, which is the third church on the site.',
    labels_dict={
        1: 'Company',
        2: 'EducationalInstitution',
        3: 'Artist',
        4: 'Athlete',
        5: 'OfficeHolder',
        6: 'MeanOfTransportation',
        7: 'Building',
        8: 'NaturalPlace',
        9: 'Village',
        10: 'Animal',
        11: 'Plant',
        12: 'Album',
        13: 'Film',
        14: 'WrittenWork'}
)

100%|██████████| 68.3M/68.3M [00:00<00:00, 125MB/s]


| epoch   1 |   500/ 8313 batches | accuracy    0.694
| epoch   1 |  1000/ 8313 batches | accuracy    0.909
| epoch   1 |  1500/ 8313 batches | accuracy    0.936
| epoch   1 |  2000/ 8313 batches | accuracy    0.951
| epoch   1 |  2500/ 8313 batches | accuracy    0.957
| epoch   1 |  3000/ 8313 batches | accuracy    0.961
| epoch   1 |  3500/ 8313 batches | accuracy    0.965
| epoch   1 |  4000/ 8313 batches | accuracy    0.967
| epoch   1 |  4500/ 8313 batches | accuracy    0.968
| epoch   1 |  5000/ 8313 batches | accuracy    0.970
| epoch   1 |  5500/ 8313 batches | accuracy    0.970
| epoch   1 |  6000/ 8313 batches | accuracy    0.974
| epoch   1 |  6500/ 8313 batches | accuracy    0.973
| epoch   1 |  7000/ 8313 batches | accuracy    0.973
| epoch   1 |  7500/ 8313 batches | accuracy    0.973
| epoch   1 |  8000/ 8313 batches | accuracy    0.975
-----------------------------------------------------------
| end of epoch   1 | time: 79.90s | valid accuracy    0.972 
---------------

In [3]:
# Re-run SogouNews training, then classify the same pinyin-romanised snippet.
sogou_news_labels = {
    1: 'Sports',
    2: 'Finance',
    3: 'Entertainment',
    4: 'Automobile',
    5: 'Technology',
}
sogou_sample = ('su4 du4 : ( shuo1 mi2ng : dia3n ji1 zi4 do4ng bo1 fa4ng )\n'
                ' shuo1 mi2ng : dia3n ji1 ga1i a4n niu3 , xua3n ze2 yi1 lu4n ta2n ji2 ke3 ')

main_class = MainClass(SogouNews)
main_class.train_and_validate()
main_class.predict_ouput(input_text=sogou_sample, labels_dict=sogou_news_labels)

| epoch   1 |   500/ 6680 batches | accuracy    0.816
| epoch   1 |  1000/ 6680 batches | accuracy    0.910
| epoch   1 |  1500/ 6680 batches | accuracy    0.918
| epoch   1 |  2000/ 6680 batches | accuracy    0.916
| epoch   1 |  2500/ 6680 batches | accuracy    0.926
| epoch   1 |  3000/ 6680 batches | accuracy    0.925
| epoch   1 |  3500/ 6680 batches | accuracy    0.924
| epoch   1 |  4000/ 6680 batches | accuracy    0.924
| epoch   1 |  4500/ 6680 batches | accuracy    0.929
| epoch   1 |  5000/ 6680 batches | accuracy    0.928
| epoch   1 |  5500/ 6680 batches | accuracy    0.929
| epoch   1 |  6000/ 6680 batches | accuracy    0.927
| epoch   1 |  6500/ 6680 batches | accuracy    0.927
-----------------------------------------------------------
| end of epoch   1 | time: 311.26s | valid accuracy    0.927 
-----------------------------------------------------------
| epoch   2 |   500/ 6680 batches | accuracy    0.932
| epoch   2 |  1000/ 6680 batches | accuracy    0.929
| epoch 