In [64]:
import torch
import torch.nn as nn
import numpy as np
import os, re
from collections import Counter
from torch.utils.data import Dataset, DataLoader


import warnings

In [None]:
import kagglehub
path = kagglehub.dataset_download("pranayprasad/aclimdb")

Downloading from https://www.kaggle.com/api/v1/datasets/download/pranayprasad/aclimdb?dataset_version_number=1...


100%|██████████| 111M/111M [00:11<00:00, 10.1MB/s] 

Extracting files...





In [31]:
import nltk
#nltk.download('stopwords') 
from nltk.corpus import stopwords
stopwords = set(stopwords.words("english"))

In [52]:
def pre_process_text(all_files):
    all_words, seq_len = [], []
    for f_name in all_files:
        text = open(f_name).readlines()[0].lower()
        text = re.sub ( r'[^\w\s]', '', text)
        words = text.split(" ")
        words = [w for w in words if (w not in stopwords) and (len(w) >=0) ]
        all_words+=words
        seq_len.append(len(words))
    return (all_words, seq_len)

train_dir = "../../data/aclImdb/train"

all_files = ([ os.path.join ( train_dir,  f"pos/{f_name}") for f_name in  os.listdir(f"{train_dir}/pos")] + 
             [ os.path.join ( train_dir,  f"neg/{f_name}") for f_name in  os.listdir(f"{train_dir}/neg")] )

train_words, sentence_len = pre_process_text(all_files)

In [57]:
f"avg sentence length: {np.mean(sentence_len)}"

'avg sentence length: 125.16576'

#### Creating Tokenizer

In [None]:
bog = dict(Counter(train_words))
words = sorted([key for (key,value) in bog.items() if value > 500])

words.append("<UNK>")
words.append("<PAD>")

w2i = {w: i for i, w in enumerate(words)}
i2w = {i: w for i, w in enumerate(words)}

In [None]:
class IMDBDataLoader(Dataset):
    def __init__(self, data_path, tokenizer, max_seq_len = 200):
        self.tokenizer = w2i
        self.max_seq_len = max_seq_len
        self.data_files = ([ os.path.join ( data_path,  f"pos/{f_name}") for f_name in  os.listdir(f"{data_path}/pos")] + 
             [ os.path.join ( data_path,  f"neg/{f_name}") for f_name in  os.listdir(f"{data_path}/neg")] )

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        file_path = self.data_files[idx]

        def get_sample(f_name):
            text = open(f_name).readlines()[0].lower()
            text = re.sub ( r'[^\w\s]', '', text)
            words = text.split(" ")
            return [w for w in words if (w not in stopwords) and (len(w) >=0) ]

        def get_tokenzied_sample_seq(f_name):
            sample = get_sample(f_name)
            # if more than seq_len, trim it
            if len(sample) > self.max_seq_len:
                rand_start_idx = np.random.randint( (len(sample) - self.max_seq_len) )
                sample = sample[rand_start_idx: (rand_start_idx + self.max_seq_len) ]

            ## tokenized result
            tokenized = []
            for w in sample:
                if w in self.tokenizer:
                    tokenized.append(self.tokenizer[w])
                else:
                    tokenized.append(self.tokenizer["<UNK>"])
            
            sample = torch.tensor(tokenized)
            return sample

        sample = get_tokenzied_sample_seq(file_path)

        ## label
        label = 1
        if "neg" in file_path: label = 0
        return sample, label

def data_collator(batch):
    word_tokens, labels = [], []
    for token, label in batch:
        labels.append(label)
        word_tokens.append(token)

    labels = torch.tensor(labels)
    
    word_tokens = nn.utils.rnn.pad_sequence(word_tokens, batch_first=True, padding_value=w2i["<PAD>"])
    return word_tokens, labels

train_ds = IMDBDataLoader(train_dir, tokenizer=w2i)

data_loader = DataLoader(dataset=train_ds, batch_size=16, shuffle=True, collate_fn=data_collator)

for (s, l) in data_loader:
    # (B x T x C)
    print(s, l)
    break

tensor([[990, 371, 664,  ..., 991, 991, 991],
        [990, 723, 990,  ..., 991, 991, 991],
        [990, 600, 314,  ..., 991, 991, 991],
        ...,
        [321, 682, 886,  ...,  87, 480, 510],
        [516, 369, 562,  ..., 991, 991, 991],
        [306, 314, 990,  ..., 991, 991, 991]]) tensor([0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1])
