In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [None]:
from src.models import HAN
from src.dataset import HANDataset

In [None]:
def load_yelp(file_path="data/yelp-2015.json"):
    """Read the Yelp review dump, keeping only the rating and the review text.

    Args:
        file_path: Path (or file-like object) to a JSON-lines file with one
            review object per line; each object must provide ``stars`` and
            ``text`` fields.

    Returns:
        A DataFrame with exactly the columns ``stars`` and ``text``, in that order.
    """
    reviews = pd.read_json(file_path, lines=True)
    return reviews.loc[:, ["stars", "text"]]

In [None]:
def split_data(df, train_frac=0.8, eval_frac=0.1, test_frac=0.1, random_state=0):
    """Shuffle ``df`` and split it into train / eval / test partitions.

    The split is done in two stages with ``train_test_split``: first the
    training share is separated from the held-out remainder, then the remainder
    is divided between eval and test in proportion ``eval_frac : test_frac``.
    With the default 0.8 / 0.1 / 0.1 fractions the second stage is the same
    50/50 split as before, so default partitions are unchanged.

    Args:
        df: DataFrame to split.
        train_frac: Fraction of rows for the training set.
        eval_frac: Fraction of rows for the validation set.
        test_frac: Fraction of rows for the test set.
        random_state: Seed for the shuffle and both splits (reproducible partition).

    Returns:
        ``(train_df, eval_df, test_df)``. A partition whose fraction is 0 is
        returned as an empty DataFrame with ``df``'s columns.

    Raises:
        AssertionError: if the three fractions do not sum to 1.0.
    """
    # Ensure the fractions sum to 1.0
    assert abs(train_frac + eval_frac + test_frac - 1.0) < 1e-6, (
        "Fractions must sum to 1.0"
    )

    df = df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    holdout_frac = eval_frac + test_frac
    if holdout_frac <= 0:
        # Everything is training data; train_test_split rejects test_size=0.
        return df, df.iloc[0:0], df.iloc[0:0]

    train_df, temp_df = train_test_split(
        df, test_size=(1 - train_frac), random_state=random_state
    )

    # Divide the held-out rows according to the requested eval/test ratio
    # instead of a fixed 50/50 split (which ignored eval_frac/test_frac).
    if eval_frac <= 0:
        return train_df, temp_df.iloc[0:0], temp_df
    if test_frac <= 0:
        return train_df, temp_df, temp_df.iloc[0:0]
    eval_df, test_df = train_test_split(
        temp_df, test_size=test_frac / holdout_frac, random_state=random_state
    )

    return train_df, eval_df, test_df

In [None]:
def train_model(model, train_dataloader, eval_dataloader, num_epochs=5, lr=1e-3, device=torch.device("cpu")):
    """Fit ``model`` with Adam and cross-entropy, reporting validation accuracy each epoch.

    Labels yielded by the dataloaders are star ratings on a 1-5 scale; they are
    shifted to 0-4 class indices before being compared with the logits, which
    are the first item of the 3-tuple returned by the model's forward pass.

    Args:
        model: Module to train; moved to ``device`` in place.
        train_dataloader: Yields ``(docs, labels)`` training batches; the length of
            its ``.dataset`` is used to average the epoch loss per document.
        eval_dataloader: Yields ``(docs, labels)`` validation batches.
        num_epochs: Number of passes over the training data.
        lr: Adam learning rate.
        device: Device on which the model and every batch are placed.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    def _validation_accuracy():
        # Share of validation docs whose argmax class equals the shifted label.
        model.eval()
        n_correct, n_seen = 0, 0
        with torch.no_grad():
            for docs, stars in eval_dataloader:
                docs = docs.to(device)
                targets = (stars - 1).to(device)
                logits, _, _ = model(docs)
                n_correct += (torch.argmax(logits, dim=1) == targets).sum().item()
                n_seen += targets.size(0)
        return n_correct / n_seen if n_seen > 0 else 0

    for epoch_idx in range(1, num_epochs + 1):
        model.train()
        loss_sum = 0.0
        for docs, stars in tqdm(train_dataloader):
            docs = docs.to(device)
            targets = (stars - 1).to(device)  # 1-5 stars -> 0-4 class ids

            optimizer.zero_grad()
            logits, _, _ = model(docs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-document mean.
            loss_sum += batch_loss.item() * docs.size(0)

        avg_loss = loss_sum / len(train_dataloader.dataset)
        print(f"Epoch {epoch_idx}/{num_epochs} - Training Loss: {avg_loss:.4f}")

        val_acc = _validation_accuracy()
        print(f"Epoch {epoch_idx}/{num_epochs} - Validation Accuracy: {val_acc:.4f}")

In [None]:
def evaluate_model(model, test_dataloader, device=torch.device("cpu")):
    """Compute and print the classification accuracy of ``model`` on a dataloader.

    Labels are star ratings on a 1-5 scale and are shifted to 0-4 class indices
    before comparison with the argmax of the logits (the first item of the
    3-tuple returned by the model's forward pass).

    Args:
        model: Module to evaluate. It is moved to ``device`` so it matches the
            batches even if it was last used on another device (e.g. trained
            on GPU, evaluated with the default CPU device).
        test_dataloader: Yields ``(docs, labels)`` batches.
        device: Device on which evaluation runs.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 when the dataloader yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for test_docs, test_labels in test_dataloader:
            test_docs = test_docs.to(device)
            test_labels = (test_labels - 1).to(device)  # 1-5 stars -> 0-4 class ids
            logits, _, _ = model(test_docs)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == test_labels).sum().item()
            total += test_labels.size(0)
    # Avoid ZeroDivisionError when the loader is empty.
    test_acc = correct / total if total > 0 else 0.0
    print("Test Accuracy: {:.4f}".format(test_acc))
    return test_acc

In [None]:
def visualize_attention(model, dataset, index, device=torch.device("cpu")):
    """Plot sentence- and word-level attention weights for one document.

    Runs one forward pass on ``dataset[index]`` and draws two figures: a bar
    chart of the sentence attention weights, and one bar chart per sentence row
    of the word attention weights labelled with tokens recovered through
    ``dataset.vocab``. Finally prints the true label and the predicted rating.

    Args:
        model: Module whose forward returns ``(logits, word_attn, sent_attn)``.
            It is moved to ``device`` so it matches the input tensor (otherwise a
            model left on GPU by training fails with the default CPU device).
        dataset: Indexable dataset returning ``(doc_tensor, label)`` and exposing
            a ``vocab`` dict mapping token -> id.
        index: Position of the document in ``dataset``.
        device: Device for the forward pass.
    """
    model.to(device)
    model.eval()
    doc_tensor, label = dataset[index]
    # Add batch dimension: shape (1, max_sentences, max_sentence_length)
    doc_tensor = doc_tensor.unsqueeze(0).to(device)

    with torch.no_grad():
        logits, word_attn_weights, sent_attn_weights = model(doc_tensor)
        pred = torch.argmax(logits, dim=1).item() + 1  # class id 0-4 -> 1-5 scale

    # Invert the vocabulary (id -> token) to label the word-level bars.
    inv_vocab = {v: k for k, v in dataset.vocab.items()}

    # Every row/position of the (presumably padded) tensor is plotted; ids not
    # present in the vocab are shown as "<UNK>".
    doc_array = doc_tensor.squeeze(0).cpu().numpy()  # (max_sentences, max_sentence_length)
    doc_tokens = [
        [inv_vocab.get(token_id, "<UNK>") for token_id in sent] for sent in doc_array
    ]

    # Sentence-level attention.
    sent_attn = sent_attn_weights.squeeze(0).cpu().numpy()
    fig, ax = plt.subplots(figsize=(8, 4))
    sns.barplot(x=list(range(len(sent_attn))), y=sent_attn, ax=ax)
    ax.set_title("Sentence Attention Weights")
    ax.set_xlabel("Sentence Index")
    ax.set_ylabel("Attention Weight")
    plt.show()

    # Word-level attention, one row of bars per sentence.
    # word_attn_weights shape: (batch_size*num_sentences, max_sentence_length)
    word_attn = word_attn_weights.cpu().numpy().reshape(-1, doc_tensor.size(-1))
    num_sentences = len(doc_tokens)
    # squeeze=False always yields a 2-D array of axes, even for a single sentence.
    fig, axes = plt.subplots(
        num_sentences, 1, figsize=(12, num_sentences * 1.5), squeeze=False
    )
    for i, (ax, tokens) in enumerate(zip(axes[:, 0], doc_tokens)):
        ax.bar(range(len(tokens)), word_attn[i], tick_label=tokens)
        ax.set_title(f"Word Attention for Sentence {i+1}")
        ax.tick_params(axis='x', rotation=45)
    fig.tight_layout()
    plt.show()

    print(f"True Label: {label}, Predicted Label: {pred}")

In [None]:
# Load the raw reviews (star rating + text) and print a short preview.
print("Loading Yelp dataset...")
df = load_yelp()
print("Loaded {} samples from Yelp dataset.".format(len(df)))
print(df.head())

Loading Yelp dataset...
Loaded 1569264 samples from Yelp dataset.
   stars                                               text
0      5  dr. goldberg offers everything i look for in a...
1      2  Unfortunately, the frustration of being Dr. Go...
2      4  Dr. Goldberg has been my doctor for years and ...
3      4  Been going to Dr. Goldberg for over 10 years. ...
4      4  Got a letter in the mail last week that said D...


In [None]:
print("Splitting data into train, eval, and test sets...")
train_df, eval_df, test_df = split_data(df)

# Unpack each split into parallel lists of review texts and star ratings.
train_documents, train_labels = train_df["text"].tolist(), train_df["stars"].tolist()
eval_documents, eval_labels = eval_df["text"].tolist(), eval_df["stars"].tolist()
test_documents, test_labels = test_df["text"].tolist(), test_df["stars"].tolist()

Splitting data into train, eval, and test sets...


In [None]:
# Tokenize the training documents and build the dataset (the saved output shows
# ~24 min for ~1.25M documents). batch_size / n_process here are HANDataset's own
# settings — presumably forwarded to the spaCy tokenizer — and are unrelated to
# the DataLoader batch_size used for training; confirm in src/dataset.py.
train_dataset = HANDataset(train_documents, train_labels, batch_size=1000, n_process=10)

Creating HAN datasets...
Loading spacy model...
Tokenizing documents...


Tokenizing: 100%|██████████| 1255411/1255411 [24:13<00:00, 863.48it/s] 


Caching tokenized documents...


In [None]:
# NOTE(review): vocab_size for the model is taken from train_dataset.vocab. If
# HANDataset builds its own vocabulary from the documents it is given, token ids
# here will not match the training ids and validation accuracy is meaningless —
# confirm HANDataset reuses/accepts the training vocab (TODO check src/dataset.py).
eval_dataset = HANDataset(eval_documents, eval_labels, batch_size=1000, n_process=10)

In [None]:
# NOTE(review): as with the eval set, if HANDataset builds a separate vocabulary
# per instance, test token ids will not line up with train_dataset.vocab, which
# sizes the model — confirm the training vocab is shared (TODO check src/dataset.py).
test_dataset = HANDataset(test_documents, test_labels, batch_size=1000, n_process=10)

In [None]:
# Seeded generator so the training shuffle order is reproducible across runs
# (seed 0, matching split_data's default random_state). Eval/test keep their order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
eval_dataloader = DataLoader(eval_dataset, batch_size=64, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
vocab_size = len(train_dataset.vocab)
embedding_dim = 200  # Word embedding dimension (chosen arbitrarily)

# Hyperparameters for GRU layers
word_hidden_dim = 50
sent_hidden_dim = 50
num_classes = 5  # 1-5 star ratings

print("Initializing HAN model...")
# Seed torch's RNG right before construction so the initial weights are reproducible.
torch.manual_seed(0)
model = HAN(
    vocab_size,
    embedding_dim,
    word_hidden_dim,
    sent_hidden_dim,
    num_classes,
)

In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
NUM_EPOCHS = 5
LEARNING_RATE = 1e-3

print("Training HAN model...")
train_model(
    model,
    train_dataloader,
    eval_dataloader,
    num_epochs=NUM_EPOCHS,
    lr=LEARNING_RATE,
    device=device,
)