In [1]:
# Imports

import os
import re
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from tqdm import tqdm

import warnings
# Silence warnings in notebook output.
# NOTE(review): simplefilter('ignore') suppresses *every* warning category for the
# whole kernel session (e.g. deprecation or numerical warnings from torch);
# consider narrowing it with a `category=` / `module=` filter instead.
warnings.simplefilter('ignore')

from IPython.core.interactiveshell import InteractiveShell
# Render the value of every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
# Custom Dataset to Read Local IMDB Files

class IMDBDataset(Dataset):
    """
    Custom PyTorch Dataset for IMDB movie reviews.

    Expects ``root_dir`` to contain ``pos`` and ``neg`` subfolders of ``.txt``
    files. Reviews under ``pos`` get label 1, those under ``neg`` label 0.
    Samples are ordered pos-then-neg, each folder sorted by filename, so the
    dataset (and any vocabulary built from it) is reproducible across machines.

    Args:
        root_dir: Directory of one split, holding ``pos``/``neg`` subfolders.
        vocab: Optional token -> index mapping. Pass the training split's
            ``vocab`` when constructing other splits so indices agree. When
            ``None``, a vocabulary is built from this split's files.
        tokenizer: Optional callable mapping a string to a list of tokens.
            Defaults to :meth:`basic_tokenizer`.
        max_length: Only the first ``max_length`` token ids of each review are
            returned. ``None`` disables truncation.
        min_freq: Minimum count for a token to enter a vocabulary built here.
            Ignored when ``vocab`` is supplied.

    Raises:
        ValueError: If a supplied ``vocab`` has no ``"<UNK>"`` entry, which
            :meth:`numericalize` needs for out-of-vocabulary tokens.
    """

    # (subfolder name, integer label), in the order samples are collected.
    _LABEL_DIRS = (("pos", 1), ("neg", 0))

    def __init__(self, root_dir, vocab=None, tokenizer=None, max_length=256, min_freq=2):
        self.root_dir = root_dir
        self.max_length = max_length

        # Collect all file paths & labels (deterministic order)
        self.samples = self._collect_samples(root_dir)

        # Use provided tokenizer or fall back to a simple regex-based one.
        self.tokenizer = tokenizer if tokenizer is not None else self.basic_tokenizer

        if vocab is not None:
            # Fail fast instead of raising KeyError later inside __getitem__.
            if "<UNK>" not in vocab:
                raise ValueError("Provided vocab must contain an '<UNK>' entry")
            self.vocab = vocab
        else:
            # No vocab given (e.g. training split): build one from this split.
            self.vocab = self.build_vocab(min_freq=min_freq)

    @classmethod
    def _collect_samples(cls, root_dir):
        """Return ``(path, label)`` pairs for every ``.txt`` file in pos/neg."""
        samples = []
        for label_name, label in cls._LABEL_DIRS:
            folder_path = os.path.join(root_dir, label_name)
            # os.listdir order is filesystem-dependent; sorting makes sample
            # order -- and therefore vocab index assignment -- reproducible.
            for filename in sorted(os.listdir(folder_path)):
                if filename.endswith(".txt"):
                    samples.append((os.path.join(folder_path, filename), label))
        return samples

    @staticmethod
    def _read_text(path):
        """Read one review file as UTF-8 text."""
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    def basic_tokenizer(self, text):
        """
        Default tokenizer: lowercase ``text`` and return its runs of word
        characters. Punctuation and whitespace are dropped.

        NOTE(review): markup such as ``<br />`` yields a ``"br"`` token;
        presumably worth stripping if the raw reviews contain HTML.
        """
        text = text.lower()
        tokens = re.findall(r"\b\w+\b", text)
        return tokens

    def build_vocab(self, min_freq=2):
        """
        Construct a token -> index mapping from this split's files.

        Index 0 is reserved for ``<PAD>`` and 1 for ``<UNK>``. Remaining
        tokens are numbered in first-seen order, and tokens occurring fewer
        than ``min_freq`` times are dropped.

        Returns:
            dict[str, int]: the vocabulary.
        """
        counter = Counter()
        print(" Building vocabulary...")
        for path, _ in tqdm(self.samples):
            counter.update(self.tokenizer(self._read_text(path)))

        # Reserve 0 for <PAD>, 1 for <UNK>
        vocab = {"<PAD>": 0, "<UNK>": 1}
        for token, freq in counter.items():
            # Skip tokens already present (e.g. a custom tokenizer emitting
            # "<UNK>") so the reserved indices are never reassigned.
            if freq >= min_freq and token not in vocab:
                vocab[token] = len(vocab)

        print(f" Vocab size: {len(vocab):,}")
        return vocab

    def numericalize(self, tokens):
        """
        Convert a list of tokens into a list of vocab indices. Unknown tokens
        map to the ``<UNK>`` index.
        """
        unk_idx = self.vocab["<UNK>"]
        return [self.vocab.get(t, unk_idx) for t in tokens]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return ``(token_ids, label)`` for sample ``idx``.

        ``token_ids`` is a 1-D ``torch.long`` tensor of at most ``max_length``
        ids (possibly empty). ``label`` is a 0-dim ``torch.long`` tensor.
        """
        path, label = self.samples[idx]
        tokens = self.tokenizer(self._read_text(path))
        # Truncate before lookup: same result, avoids numericalizing discarded tokens.
        token_ids = self.numericalize(tokens[: self.max_length])
        return torch.tensor(token_ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)

In [None]:
# Collate Function for Batching

def collate_batch(batch):
    """
    Merge a list of ``(token_ids, label)`` samples into batch tensors.

    Meant to be passed as ``collate_fn`` to a ``DataLoader``.

    Args:
        batch: sequence of ``(token_ids, label)`` pairs. ``token_ids`` are 1-D
            tensors whose lengths may differ; ``label`` is a tensor.

    Returns:
        tuple: ``(padded_seqs, labels)``. ``padded_seqs`` has shape
        ``(batch_size, longest_len)``, right-padded with 0. ``labels`` stacks
        the per-sample labels along a new leading dimension.
    """
    seqs, targets = zip(*batch)
    # 0 is the <PAD> index in the vocabulary built by IMDBDataset.
    return (
        pad_sequence(seqs, batch_first=True, padding_value=0),
        torch.stack(targets),
    )
