<a href="https://colab.research.google.com/github/edgeemer/hillel_ml_2025/blob/main/HW_10.IMDB_RNN/IMDB_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RNN-based classification for IMDB dataset

Architecture: embedding → single-layer LSTM → max-pooling over time steps → linear layer <br>
Criterion: BCEWithLogitsLoss <br>
Optimizer: Adam (learning rate 1e-3) <br>


Link to dataset: https://keras.io/api/datasets/imdb/

In [None]:
import collections

import nltk
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
from nltk.corpus import stopwords

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

from keras import datasets

In [None]:
class IMDBDataset(Dataset):
    """Map-style dataset pairing integer-encoded reviews with binary labels.

    Args:
        reviews: Sequence of reviews, each a sequence of integer token ids.
        labels: Sequence of labels (0 or 1), one per review.

    Raises:
        ValueError: If ``reviews`` and ``labels`` differ in length.
    """

    def __init__(self, reviews, labels):
        n_reviews, n_labels = len(reviews), len(labels)
        if n_reviews != n_labels:
            raise ValueError("Different sizes of documents and labels!")

        # Reviews have varying lengths, so each one is kept as its own
        # long tensor instead of being stacked into a single matrix.
        self.reviews = []
        for token_ids in reviews:
            self.reviews.append(torch.tensor(token_ids, dtype=torch.long))

        # Labels are stored as floats in one 1-D tensor.
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        """Return the number of (review, label) pairs."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return the ``(review, label)`` tensors stored at position ``idx``."""
        review = self.reviews[idx]
        label = self.labels[idx]
        return review, label

In [None]:
# A collate function to pad sequences in a batch
def seq_collate_fn(batch):
    """Collate ``(review, label)`` pairs into one padded batch.

    Args:
        batch: Iterable of ``(review, label)`` pairs, where each review is a
            1-D tensor of token ids and each label is a scalar.

    Returns:
        Tuple ``(texts, labels)``: ``texts`` holds the reviews right-padded
        with 0 to the longest review in the batch (batch dimension first);
        ``labels`` is a float tensor with one column and one row per review.
    """
    sequences, targets = zip(*batch)

    padded_texts = pad_sequence(list(sequences), batch_first=True, padding_value=0)

    # Build a flat float vector first, then add the trailing column axis.
    label_column = torch.tensor(list(targets), dtype=torch.float).unsqueeze(1)

    return padded_texts, label_column

In [None]:
# Load dataset

# All loader settings gathered in one place so they are easy to inspect and tune.
# NOTE(review): the evaluation report lists 24944 test samples, so `maxlen`
# appears to drop long reviews from both splits rather than truncate them —
# confirm against the Keras documentation.
imdb_load_options = dict(
    path="imdb.npz",
    num_words=5000,
    skip_top=0,
    maxlen=1000,
    seed=113,
    start_char=1,
    oov_char=2,
    index_from=3,
)

train_split, test_split = datasets.imdb.load_data(**imdb_load_options)
X_train, y_train = train_split
X_test, y_test = test_split

# Wrap both splits so they can be served through a DataLoader.
train_dataset = IMDBDataset(X_train, y_train)
test_dataset = IMDBDataset(X_test, y_test)

In [None]:
# DataLoader initialization

# Settings shared by both loaders; only the shuffling differs between splits.
loader_options = dict(batch_size=64, collate_fn=seq_collate_fn)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

In [None]:
# Sanity check: shape of the label tensor in the first training batch.
first_batch = next(iter(train_loader))
first_batch[1].shape

torch.Size([64, 1])

In [None]:
# Model constructor initialization

class SimpleLSTM(nn.Module):
    """Sequence classifier: embedding -> LSTM -> max-over-time -> linear.

    Token id 0 is the padding id. Padded time steps are excluded from the
    max-pooling, so for right-padded input the output for a sequence does not
    depend on how much padding its batch adds.

    Args:
        vocab_size: Number of rows in the embedding table.
        input_size: Embedding dimension (feature size fed to the LSTM).
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of values produced per sequence.
    """

    def __init__(self, vocab_size, input_size, hidden_size, output_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, input_size, padding_idx=0)
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):  # long tensor [B, T]
        """Compute one output row per sequence.

        Args:
            x: Long tensor of token ids, shape [B, T], padded with 0.

        Returns:
            Float tensor of shape [B, OutputSize]. No sigmoid is applied here.
        """
        # True at real tokens, False at padding positions.
        valid = x != self.embedding.padding_idx  # bool tensor [B, T]
        # A row made only of padding would pool to -inf; fall back to
        # unmasked pooling for such rows so the output stays finite.
        valid = valid | ~valid.any(dim=1, keepdim=True)

        embedded = self.embedding(x)  # float tensor [B, T, InputSize]
        # nn.LSTM starts from zero hidden/cell states when none are passed.
        out, _ = self.lstm(embedded)  # float tensor [B, T, HiddenSize]

        # The LSTM still emits non-zero states on padded steps; hide them from
        # the max so pooling only sees real tokens.
        out = out.masked_fill(~valid.unsqueeze(-1), float("-inf"))
        out = out.max(1)[0]  # float tensor [B, HiddenSize]
        out = self.fc(out)  # float tensor [B, OutputSize]
        return out

In [None]:
# The embedding table needs one row per token id, up to the largest id
# found in the training reviews.
largest_token_id = max(map(max, X_train))
vocab_size = largest_token_id + 1

embed_dim = hidden_dim = 512
output_dim = 1

In [None]:
# Use the GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device - {device}")

# Seed the generator before the model is constructed.
torch.manual_seed(42)

model = SimpleLSTM(vocab_size, embed_dim, hidden_dim, output_dim).to(device)

print(model)
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print("Number of trainable parameters -", n_trainable)

# The model returns one value per review; the loss consumes it directly.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

Device - cuda
SimpleLSTM(
  (embedding): Embedding(5000, 512, padding_idx=0)
  (lstm): LSTM(512, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=1, bias=True)
)
Number of trainable parameters - 4661761


In [None]:
# Model training

n_epochs = 20
train_losses = []  # mean batch loss of every finished epoch

for epoch in range(n_epochs):
    print(f"Epoch {epoch + 1}/{n_epochs}")

    model.train()
    losses = []  # per-batch losses of the current epoch

    for docs, labels in train_loader:
        docs = docs.to(device)
        targets = labels.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        outputs = model(docs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    epoch_loss = np.mean(losses)
    train_losses.append(epoch_loss)
    print(f"  loss: {epoch_loss}")

Epoch 1/20
  loss: 0.3788175328148668
Epoch 2/20
  loss: 0.1944662390185812
Epoch 3/20
  loss: 0.10511103584634614
Epoch 4/20
  loss: 0.04228810432359845
Epoch 5/20
  loss: 0.01791515239097829
Epoch 6/20
  loss: 0.00797848591843146
Epoch 7/20
  loss: 0.003754651333010279
Epoch 8/20
  loss: 0.006274871153240439
Epoch 9/20
  loss: 0.008244252478635623
Epoch 10/20
  loss: 0.006783718116950138
Epoch 11/20
  loss: 0.0020133348524954725
Epoch 12/20
  loss: 0.0003339142543582454
Epoch 13/20
  loss: 0.0001500684004949222
Epoch 14/20
  loss: 9.563873993059795e-05
Epoch 15/20
  loss: 6.826135261977648e-05
Epoch 16/20
  loss: 5.052605536356353e-05
Epoch 17/20
  loss: 7.436339149986583e-05
Epoch 18/20
  loss: 6.806096840067917e-05
Epoch 19/20
  loss: 2.9018520621166623e-05
Epoch 20/20
  loss: 2.0224888096336822e-05


In [None]:
# Model evaluation
model.eval()

all_preds, all_labels = [], []

# Gradients are not needed for inference.
with torch.no_grad():
    for docs, labels in test_loader:
        outputs = model(docs.to(device))

        # Turn the outputs into probabilities, then threshold at 0.5.
        probabilities = outputs.sigmoid()
        preds = (probabilities >= 0.5).long().detach().cpu()

        # Both tensors are [B, 1]; keep only the single column.
        all_preds.extend(preds[:, 0].numpy())
        all_labels.extend(labels[:, 0].numpy())

print(classification_report(all_labels, all_preds))
print('Accuracy:', accuracy_score(all_labels, all_preds))

              precision    recall  f1-score   support

         0.0       0.90      0.89      0.90     12472
         1.0       0.90      0.90      0.90     12472

    accuracy                           0.90     24944
   macro avg       0.90      0.90      0.90     24944
weighted avg       0.90      0.90      0.90     24944

Accuracy: 0.897490378447723
