In [126]:
import os, random

# Directory passed to load_train_dataset below; that function reads the
# `pos` and `neg` sub-directories found under it.
train_path = "data/aclImdb/train/"

#adapted from https://developers.google.com/machine-learning/guides/text-classification/step-2
def load_train_dataset(path: str) -> tuple:
    """Load the labelled reviews stored under ``path``.

    ``path`` must contain a ``pos`` and a ``neg`` sub-directory. Every file in
    them is read as one review; files from ``pos`` get label 1 and files from
    ``neg`` get label 0.

    Side effect: reseeds the module-level ``random`` generator with 1 before
    shuffling, so code that runs afterwards sees a deterministic generator.

    Args:
        path: Directory holding the ``pos`` and ``neg`` sub-directories.

    Returns:
        A ``(texts, labels)`` tuple of two equally long lists, shuffled with a
        fixed seed such that ``labels[i]`` is the label of ``texts[i]``.

    Raises:
        OSError: If a category directory or one of its files cannot be read.
        UnicodeDecodeError: If a file is not valid UTF-8.
    """
    samples = []

    for label in ['pos', 'neg']:
        cat_path = os.path.join(path, label)
        target = 0 if label == 'neg' else 1
        # os.listdir returns entries in arbitrary order; sorting makes the
        # seeded shuffle below reproducible across machines and file systems.
        for file_name in sorted(os.listdir(cat_path)):
            file_path = os.path.join(cat_path, file_name)
            # Decode explicitly instead of relying on the platform default
            # encoding (assumes the corpus is UTF-8 - confirm for other data).
            with open(file_path, 'r', encoding='utf-8') as file:
                samples.append((file.read(), target))

    # Shuffle (text, label) pairs in a single pass so the two lists can never
    # drift out of alignment.
    random.seed(1)
    random.shuffle(samples)

    train_texts = [text for text, _ in samples]
    train_labels = [target for _, target in samples]

    return (train_texts, train_labels)

# Read every review under train_path together with its 0/1 label.
train_texts, train_labels = load_train_dataset(train_path)

In [127]:
from sklearn.model_selection import train_test_split
import numpy as np

def get_smaller_dataset(size: int, texts: list[str], labels: list[int], seed=10) -> tuple:
    """Draw a random subset of a labelled corpus and split it into train/test.

    ``size`` aligned (text, label) pairs are sampled without replacement, then
    divided with ``train_test_split`` using its default split proportions.

    Side effect: reseeds the module-level ``random`` generator with ``seed``.

    Args:
        size: Number of samples to draw from the corpus.
        texts: The documents.
        labels: The label of each document; must be as long as ``texts``.
        seed: Seed used for both the sampling and the train/test split, so
            the whole result is reproducible for a given value.

    Returns:
        ``((train_texts, train_labels), (test_texts, test_labels))`` where the
        text parts are lists and the label parts are NumPy arrays.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length, or if
            ``size`` is larger than the corpus (raised by ``random.sample``).
    """
    if len(texts) != len(labels):
        raise ValueError(
            f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
        )

    # Sample positions once and index both lists with them; this keeps every
    # text paired with its own label without relying on re-seeding.
    random.seed(seed)
    picked = random.sample(range(len(texts)), size)
    smaller_texts = [texts[i] for i in picked]
    smaller_labels = [labels[i] for i in picked]

    # Forward the seed so the split is reproducible too, not only the sampling.
    sm_train_texts, sm_test_texts, sm_train_labels, sm_test_labels = train_test_split(
        smaller_texts, np.array(smaller_labels), random_state=seed
    )

    return ((sm_train_texts, sm_train_labels), (sm_test_texts, sm_test_labels))

In [128]:
import re

def clear_br_tags(text) -> str:
    """Strip HTML-style tags such as ``<br />`` from ``text``.

    Each tag is replaced by a single space instead of being deleted, so words
    that were separated only by a tag (``"great<br />movie"``) are not fused
    into one token. Text without tags is returned unchanged.

    Args:
        text: The raw document.

    Returns:
        The document with every ``<...>`` span replaced by a space. Because
        ``.`` does not match newlines, a tag that spans several lines is left
        in place.
    """
    # Non-greedy match so that each tag is handled on its own.
    return re.sub(r'<.*?>', ' ', text)

# Run every loaded review through clear_br_tags before any further processing.
train_texts = list(map(clear_br_tags, train_texts))

In [129]:
# Work on a 1000-review subset; the sampling seed is itself drawn from the
# module-level `random` generator.
(sm_train_texts, sm_train_labels), (sm_test_texts, sm_test_labels) = get_smaller_dataset(
    size=1000,
    texts=train_texts,
    labels=train_labels,
    seed=random.randint(1, 200),
)

In [130]:
from sklearn.feature_extraction.text import TfidfVectorizer

# TF-IDF over unigrams and bigrams, with the vocabulary capped at 1000 terms.
vectorizer = TfidfVectorizer(max_features=1000, ngram_range=(1, 2))

# Learn the vocabulary and IDF weights from the training split only.
X_sm_train = vectorizer.fit_transform(sm_train_texts).toarray()
Y_sm_train = np.array(sm_train_labels)

# Project the test split onto the vocabulary fitted above. Calling
# fit_transform here would refit on the test texts, so the columns would no
# longer correspond to the training features and test-set statistics would
# leak into the representation.
X_sm_test = vectorizer.transform(sm_test_texts).toarray()
Y_sm_test = np.array(sm_test_labels)