In [2]:
!pip install datasets



In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import datasets
import os
import random
import torchmetrics

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from datasets import load_dataset
from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split
import nltk

from collections import Counter


os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(42)
random.seed(42)

torch.manual_seed(42) 
torch.cuda.manual_seed(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [4]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

## Data Preparation


In [6]:
dataset = datasets.load_dataset('ag_news')

Downloading readme:   0%|          | 0.00/8.07k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 18.6M/18.6M [00:00<00:00, 71.8MB/s]
Downloading data: 100%|██████████| 1.23M/1.23M [00:00<00:00, 12.0MB/s]


Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

In [7]:
words = Counter()

for example in tqdm(dataset['train']['text']):
    prccessed_text = example.lower().translate(
        str.maketrans('', '', string.punctuation))

    for word in word_tokenize(prccessed_text):
        words[word] += 1


vocab = set(['<unk>', '<pad>'])
counter_threshold = 25

for char, cnt in words.items():
    if cnt > counter_threshold:
        vocab.add(char)

print(f'Vocab size: {len(vocab)}')

word2ind = {char: i for i, char in enumerate(vocab)}
ind2word = {i: char for char, i in word2ind.items()}

100%|██████████| 120000/120000 [00:49<00:00, 2411.83it/s]

Vocab size: 11840





In [8]:
class NewsDataset:
    def __init__(self, sentences):
        self.data = sentences
        self.unk_id = word2ind['<unk>']
        self.pad_id = word2ind['<pad>']

    def __getitem__(self, idx):
        processed_text = self.data[idx]['text'].lower().translate(
            str.maketrans('', '', string.punctuation))
        tokenized_sentence = []
        tokenized_sentence += [
            word2ind.get(word, self.unk_id) for word in word_tokenize(processed_text)
            ] 

        train_sample = {
            "text": torch.LongTensor(tokenized_sentence),
            "label": torch.LongTensor([self.data[idx]['label']])
        }

        return train_sample

    def __len__(self):
        return len(self.data)


In [9]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    sequences = [item['text'] for item in batch]
    targets = [item['label'] for item in batch]
    
    MAX_LENGTH = 512
    sequences = [seq[:MAX_LENGTH] for seq in sequences]
    
    sequences_padded = pad_sequence(sequences=sequences, batch_first=True, padding_value=word2ind['<pad>'])
    
    return sequences_padded, torch.LongTensor(targets)

In [10]:
train_dataset = NewsDataset(dataset['train'])

np.random.seed(42)
idx = np.random.choice(np.arange(len(dataset['test'])), 5000)
eval_dataset = NewsDataset(dataset['test'].select(idx))

batch_size = 128
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=batch_size , num_workers=os.cpu_count(), pin_memory=True)

eval_dataloader = DataLoader(
    eval_dataset, shuffle=False, collate_fn=collate_fn, batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)

# Model

In [11]:
class RNN(nn.Module):
    def __init__( self,  vocab_size, hidden_dim=256, embed_dim=512, num_classes = 4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=word2ind['<pad>'])
        self.rnn = nn.LSTM(input_size=embed_dim, 
                           hidden_size=hidden_dim, 
                           num_layers=3, 
                           batch_first=True,
                           bidirectional=True,
                           dropout=0.2)
        
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_classes)
        )
        

    def forward(self, input_batch):
        embeddings = self.embedding(input_batch) 
        output, (_, _) = self.rnn(embeddings) 
        output = output.mean(dim=1)

        return self.fc(output)

In [12]:
model = RNN(vocab_size=len(vocab)).to(device)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
amp_scaler = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=3, factor=0.5)

accuracy = torchmetrics.classification.MulticlassAccuracy(4).to(device)

In [13]:
for epoch in range(300):
    
    train_total_loss = 0
    train_total_acc = 0
    model.train()
    loop = tqdm(train_dataloader)
    for X,y in loop:
        X,y = X.to(device), y.to(device)
        with torch.cuda.amp.autocast():
            logits = model(X)
            loss = loss_fn(logits, y)
        
        optim.zero_grad(set_to_none='True')
        amp_scaler.scale(loss).backward()
        amp_scaler.step(optim)
        amp_scaler.update()
        
        train_total_loss += loss
        train_total_acc += accuracy(logits, y)
        
    train_total_loss /= len(train_dataloader)
    train_total_acc /= len(train_dataloader)
    
    
    test_total_loss = 0
    test_total_acc = 0
    model.eval()
    with torch.inference_mode():
        loop = tqdm(eval_dataloader)
        for X,y in loop:
            X,y = X.to(device), y.to(device)
            logits = model(X)
            loss = loss_fn(logits, y)
            
            test_total_loss += loss
            test_total_acc += accuracy(logits, y)
            
        test_total_loss /= len(eval_dataloader)
        test_total_acc /= len(eval_dataloader)
    
    scheduler.step(test_total_acc)
    
    print(f"epoch: {epoch}   train_loss: {train_total_loss}, train_acc: {train_total_acc}, test_loss: {test_total_loss}, test_acc: {test_total_acc}")

100%|██████████| 938/938 [01:22<00:00, 11.38it/s]
100%|██████████| 40/40 [00:01<00:00, 23.43it/s]


epoch: 0   train_loss: 0.6439827680587769, train_acc: 0.8486865162849426, test_loss: 0.5729407668113708, test_acc: 0.8958719372749329


100%|██████████| 938/938 [01:21<00:00, 11.46it/s]
100%|██████████| 40/40 [00:02<00:00, 19.78it/s]


epoch: 1   train_loss: 0.5221606492996216, train_acc: 0.917437732219696, test_loss: 0.5527138113975525, test_acc: 0.9036030769348145


100%|██████████| 938/938 [01:21<00:00, 11.45it/s]
100%|██████████| 40/40 [00:01<00:00, 21.89it/s]


epoch: 2   train_loss: 0.486167848110199, train_acc: 0.9356526136398315, test_loss: 0.54457026720047, test_acc: 0.9071971774101257


100%|██████████| 938/938 [01:21<00:00, 11.48it/s]
100%|██████████| 40/40 [00:01<00:00, 22.98it/s]


epoch: 3   train_loss: 0.4589358866214752, train_acc: 0.9490643739700317, test_loss: 0.5478227734565735, test_acc: 0.910982608795166


100%|██████████| 938/938 [01:21<00:00, 11.54it/s]
100%|██████████| 40/40 [00:01<00:00, 23.48it/s]


epoch: 4   train_loss: 0.435666024684906, train_acc: 0.9605237245559692, test_loss: 0.5387428998947144, test_acc: 0.9124493598937988


100%|██████████| 938/938 [01:22<00:00, 11.36it/s]
100%|██████████| 40/40 [00:01<00:00, 23.01it/s]


epoch: 5   train_loss: 0.40533602237701416, train_acc: 0.9758902192115784, test_loss: 0.54695063829422, test_acc: 0.9135432243347168


100%|██████████| 938/938 [01:22<00:00, 11.33it/s]
100%|██████████| 40/40 [00:01<00:00, 23.43it/s]


epoch: 6   train_loss: 0.39358994364738464, train_acc: 0.9818167090415955, test_loss: 0.5532257556915283, test_acc: 0.9145427942276001


100%|██████████| 938/938 [01:22<00:00, 11.39it/s]
100%|██████████| 40/40 [00:02<00:00, 19.24it/s]


epoch: 7   train_loss: 0.3870829939842224, train_acc: 0.9846965074539185, test_loss: 0.5572208762168884, test_acc: 0.9152964949607849


100%|██████████| 938/938 [01:22<00:00, 11.42it/s]
100%|██████████| 40/40 [00:01<00:00, 23.60it/s]


epoch: 8   train_loss: 0.38191717863082886, train_acc: 0.9870542287826538, test_loss: 0.5659303665161133, test_acc: 0.9092499017715454


100%|██████████| 938/938 [01:22<00:00, 11.36it/s]
100%|██████████| 40/40 [00:01<00:00, 22.13it/s]


epoch: 9   train_loss: 0.37380334734916687, train_acc: 0.9909405708312988, test_loss: 0.5659406781196594, test_acc: 0.9147921800613403


100%|██████████| 938/938 [01:22<00:00, 11.31it/s]
100%|██████████| 40/40 [00:01<00:00, 20.58it/s]


epoch: 10   train_loss: 0.3703511655330658, train_acc: 0.9924708604812622, test_loss: 0.5708441138267517, test_acc: 0.9144525527954102


100%|██████████| 938/938 [01:22<00:00, 11.41it/s]
100%|██████████| 40/40 [00:01<00:00, 22.96it/s]


epoch: 11   train_loss: 0.36941230297088623, train_acc: 0.9927763342857361, test_loss: 0.5688592195510864, test_acc: 0.9157541394233704


100%|██████████| 938/938 [01:22<00:00, 11.39it/s]
100%|██████████| 40/40 [00:01<00:00, 23.49it/s]


epoch: 12   train_loss: 0.36800897121429443, train_acc: 0.9933384656906128, test_loss: 0.568216860294342, test_acc: 0.9168359637260437


100%|██████████| 938/938 [01:22<00:00, 11.31it/s]
100%|██████████| 40/40 [00:01<00:00, 23.56it/s]


epoch: 13   train_loss: 0.36570724844932556, train_acc: 0.9943622946739197, test_loss: 0.5692890286445618, test_acc: 0.9178921580314636


100%|██████████| 938/938 [01:22<00:00, 11.36it/s]
100%|██████████| 40/40 [00:01<00:00, 22.19it/s]


epoch: 14   train_loss: 0.36422836780548096, train_acc: 0.9947353005409241, test_loss: 0.5711308121681213, test_acc: 0.9171568155288696


100%|██████████| 938/938 [01:22<00:00, 11.41it/s]
100%|██████████| 40/40 [00:01<00:00, 23.84it/s]


epoch: 15   train_loss: 0.3633574843406677, train_acc: 0.9951123595237732, test_loss: 0.5732951760292053, test_acc: 0.917816162109375


100%|██████████| 938/938 [01:22<00:00, 11.41it/s]
100%|██████████| 40/40 [00:01<00:00, 23.70it/s]


epoch: 16   train_loss: 0.36299243569374084, train_acc: 0.9952059388160706, test_loss: 0.5777866840362549, test_acc: 0.9164170622825623


100%|██████████| 938/938 [01:22<00:00, 11.32it/s]
100%|██████████| 40/40 [00:01<00:00, 23.41it/s]


epoch: 17   train_loss: 0.3617021441459656, train_acc: 0.995691180229187, test_loss: 0.5765952467918396, test_acc: 0.9161124229431152


100%|██████████| 938/938 [01:23<00:00, 11.29it/s]
100%|██████████| 40/40 [00:01<00:00, 22.56it/s]


epoch: 18   train_loss: 0.36093997955322266, train_acc: 0.9960164427757263, test_loss: 0.5758065581321716, test_acc: 0.91604083776474


 42%|████▏     | 393/938 [00:34<00:48, 11.33it/s]


KeyboardInterrupt: 