In [24]:
# ==========================================================
# Cell 1: Imports, Seeds, and Paths
# ==========================================================
import os, re, json, random
from collections import defaultdict, deque
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel

# Notebook-wide constants: RNG seed, dataset root, and size of the label space.
SEED = 42
DATA_DIR = "./Amazon_products"
NUM_CLASSES = 531

# Seed every random number generator used in this notebook (PyTorch CPU and
# CUDA, NumPy, and the stdlib `random` module). The calls are independent.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Device is kept as a plain string: "cuda" when a GPU is visible, else "cpu".
device = "cpu" if not torch.cuda.is_available() else "cuda"

In [25]:
# ==========================================================
# Cell 2: Text Cleaning & Corpus Loading
# ==========================================================
# Matches anything that looks like an HTML/XML tag: "<" ... ">".
TAG_RE = re.compile(r"<[^>]+>")

def clean_text(s):
    """Normalize a raw text string for matching and encoding.

    Tags matched by ``TAG_RE`` are replaced with a space, the text is
    lower-cased, runs of whitespace are collapsed to a single space, and
    leading/trailing whitespace is removed.

    Args:
        s: Input string; any falsy value (``None``, ``""``) yields ``""``.

    Returns:
        The cleaned string.
    """
    if not s:
        return ""
    lowered = TAG_RE.sub(" ", s).lower()
    collapsed = re.sub(r"\s+", " ", lowered)
    return collapsed.strip()

def load_corpus(path):
    """Read a tab-separated corpus file into parallel id/text lists.

    Each line is stripped and split on the first tab into ``(id, text)``.
    Lines that do not produce exactly two fields are skipped. The text
    field is passed through ``clean_text``.

    Args:
        path: Path to a UTF-8 text file.

    Returns:
        ``(pids, texts)``: two lists of equal length, aligned by position.
    """
    pids, texts = [], []
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            fields = raw_line.strip().split("\t", 1)
            if len(fields) != 2:
                # No tab (or nothing after stripping): not an (id, text) record.
                continue
            pid, body = fields
            pids.append(pid)
            texts.append(clean_text(body))
    return pids, texts

# Load the train and test corpora as aligned (ids, cleaned texts) lists.
train_ids, train_texts = load_corpus(
    os.path.join(DATA_DIR, "train", "train_corpus.txt")
)
test_ids, test_texts = load_corpus(
    os.path.join(DATA_DIR, "test", "test_corpus.txt")
)

In [26]:
# ==========================================================
# Cell 3: Hierarchy & Ancestor Path Building
# ==========================================================
# Build parent mapping: child class id -> set of its direct parent class ids.
# Each non-empty line of class_hierarchy.txt is parsed as "<parent>\t<child>".
parents = defaultdict(set)
# Explicit UTF-8 keeps this consistent with the other file reads in the
# notebook instead of depending on the platform's default encoding.
with open(os.path.join(DATA_DIR, "class_hierarchy.txt"), "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines; int("") would otherwise raise ValueError.
            continue
        p, c = map(int, line.split("\t"))
        parents[c].add(p)

def build_ancestor_list(parents_dict, num_classes, max_anc=8):
    """Collect up to ``max_anc`` ancestors for every class id.

    For each class ``0 .. num_classes - 1`` the hierarchy is walked upwards
    breadth-first, starting from the class's direct parents, so nearer
    ancestors come first. Each ancestor is recorded once.

    Args:
        parents_dict: Mapping ``class id -> iterable of direct parent ids``.
            Classes missing from the mapping are treated as having no parents.
        num_classes: Number of classes; ids are assumed to be ``range(num_classes)``.
        max_anc: Maximum number of ancestors kept per class.

    Returns:
        A list of length ``num_classes``; entry ``c`` is the list of
        ancestor ids found for class ``c`` (possibly empty).
    """
    no_parents = set()
    ancestors = []
    for class_id in range(num_classes):
        frontier = deque(parents_dict.get(class_id, no_parents))
        visited = set()
        chain = []
        while frontier and len(chain) < max_anc:
            node = frontier.popleft()
            if node in visited:
                # Reached earlier through another path.
                continue
            visited.add(node)
            chain.append(node)
            # Queue this node's own parents for the next level up.
            frontier.extend(
                grand for grand in parents_dict.get(node, no_parents)
                if grand not in visited
            )
        ancestors.append(chain)
    return ancestors

# Per-class ancestor lists for all NUM_CLASSES classes (default max_anc cap).
ancestor_list = build_ancestor_list(parents_dict=parents, num_classes=NUM_CLASSES)

In [27]:
# ==========================================================
# Cell 4: Build Keyword Mapping
# ==========================================================
# Read classes.txt into id2name: each line is expected to look like
# "<integer id><whitespace><class name>"; lines that do not match are ignored.
id2name = {}
with open(os.path.join(DATA_DIR, "classes.txt"), "r", encoding="utf-8") as f:
    for line in f:
        m = re.match(r"^(\d+)\s+(.+)$", line.strip())
        if m is None:
            continue
        id2name[int(m.group(1))] = m.group(2)

# Invert into keyword -> list of class ids. The keyword is simply the
# lower-cased, stripped class name, so classes sharing a name share a keyword.
# NOTE(review): the name is used verbatim; if class names contain separators
# such as underscores they will only match text containing the same characters
# - confirm against classes.txt.
kw2cids = defaultdict(list)
for cid, name in id2name.items():
    kw = name.lower().strip()
    kw2cids[kw].append(cid)

In [28]:
# ==========================================================
# Cell 4a: Silver Labeling (Core Class Generation)
# ==========================================================
# 1. Load keywords (Assumes kw2cids is built from classes.txt)
# 2. Perform match to create y_core (Original Keyword matches)
# 3. Expand y_core to y_all (Keyword matches + All Ancestors)

def generate_silver_labels(texts, kw2cids, parents, num_classes):
    """Create keyword-based ("silver") multi-label targets for each text.

    A class is matched when one of its keywords occurs as a substring of the
    text (``kw in text``). Two label matrices are produced:

    * ``y_core`` marks only the directly matched classes.
    * ``y_all`` marks the matched classes plus every class reachable from
      them by following ``parents`` upwards.

    Args:
        texts: Sequence of text strings.
        kw2cids: Mapping ``keyword -> iterable of class ids``.
        parents: Mapping ``class id -> iterable of direct parent ids``;
            classes missing from the mapping have no parents.
        num_classes: Number of columns in the label matrices. All class ids
            must be valid column indices.

    Returns:
        ``(y_core, y_all)``: two ``float32`` arrays of shape
        ``(len(texts), num_classes)`` holding 0.0 / 1.0.
    """
    y_core = np.zeros((len(texts), num_classes), dtype=np.float32)
    y_all = np.zeros((len(texts), num_classes), dtype=np.float32)

    for i, text in enumerate(tqdm(texts, desc="Silver Labeling")):
        matched_cids = set()
        # Check each keyword against the product text
        for kw, cids in kw2cids.items():
            if kw in text:
                matched_cids.update(cids)

        # y_core: The specific categories found via keywords
        for cid in matched_cids:
            y_core[i, cid] = 1.0

        # y_all: The categories + all their ancestors (Hierarchy Expansion).
        # One visited set per text guarantees termination even if the
        # hierarchy contains a cycle, and avoids re-walking ancestors shared
        # by several matched classes or reachable through several paths.
        visited = set()
        stack = list(matched_cids)
        while stack:
            curr = stack.pop()
            if curr in visited:
                continue
            visited.add(curr)
            y_all[i, curr] = 1.0
            stack.extend(p for p in parents.get(curr, ()) if p not in visited)

    return y_core, y_all

# Run silver labeling on the training texts; y_core and y_all are used later
# in the notebook, so this call must execute before those cells.
y_core, y_all = generate_silver_labels(
    texts=train_texts,
    kw2cids=kw2cids,
    parents=parents,
    num_classes=NUM_CLASSES,
)

Silver Labeling: 100%|██████████| 29487/29487 [00:03<00:00, 7411.31it/s]


In [29]:
# ==========================================================
# Cell 5: BERT Feature Extraction
# ==========================================================
MODEL_NAME = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the encoder, move it to the selected device, and switch it to
# evaluation mode so dropout is disabled during feature extraction.
bert_model = AutoModel.from_pretrained(MODEL_NAME)
bert_model.to(device)
bert_model.eval()

@torch.no_grad()  # Disable gradient calculation to save memory/time
def get_bert_features(texts, batch_size=32):
    """Encode texts into fixed-size vectors by mean-pooling BERT token states.

    Uses the module-level ``tokenizer``, ``bert_model`` and ``device``.
    Texts are tokenized in batches (padded, truncated to 256 tokens), run
    through the model, and the last hidden states are averaged over the
    non-padding tokens of each text.

    Args:
        texts: Sequence of strings (must support ``len`` and slicing).
        batch_size: Number of texts encoded per forward pass.

    Returns:
        A NumPy array with one row per text and one column per hidden unit.
        For empty ``texts`` an empty ``(0, hidden_size)`` float32 array is
        returned.
    """
    if len(texts) == 0:
        # np.vstack([]) raises ValueError; return a well-shaped empty result.
        return np.zeros((0, bert_model.config.hidden_size), dtype=np.float32)

    all_embeddings = []
    for i in tqdm(range(0, len(texts), batch_size), desc="BERT Encoding"):
        batch = texts[i : i + batch_size]
        # Tokenize and move to the model's device
        inputs = tokenizer(
            batch, padding=True, truncation=True, max_length=256, return_tensors="pt"
        ).to(device)

        # Get hidden states
        outputs = bert_model(**inputs)

        # MEAN POOLING: Instead of just taking the [CLS] token, we average
        # all token embeddings for a more stable representation of the description.
        last_hidden_state = outputs.last_hidden_state  # [batch, seq_len, hidden]
        # Float mask so the multiply/divide below do not rely on implicit
        # integer -> float type promotion.
        mask = inputs["attention_mask"].unsqueeze(-1).float()  # [batch, seq_len, 1]

        # Sum embeddings and divide by the number of non-padding tokens
        sum_embeddings = torch.sum(last_hidden_state * mask, 1)
        # Clamp guards against division by zero for an all-padding row.
        count_embeddings = torch.clamp(mask.sum(1), min=1e-9)
        mean_pooled = sum_embeddings / count_embeddings

        all_embeddings.append(mean_pooled.cpu().numpy())

    return np.vstack(all_embeddings)

# Execute for Train and Test
# Encode the train and test corpora
X_train = get_bert_features(texts=train_texts)
X_test = get_bert_features(texts=test_texts)

# Label embeddings (used by the attention layer): encode one name per class
# id, falling back to "Category <id>" for ids missing from id2name.
class_names = [
    id2name.get(class_id, f"Category {class_id}")
    for class_id in range(NUM_CLASSES)
]
E_label_768 = get_bert_features(texts=class_names)  # one row per class: [531, 768]

BERT Encoding: 100%|██████████| 922/922 [03:18<00:00,  4.65it/s]
BERT Encoding: 100%|██████████| 615/615 [02:21<00:00,  4.33it/s]
BERT Encoding: 100%|██████████| 17/17 [00:00<00:00, 92.74it/s]


In [32]:
# ==========================================================
# Cell 7: Save Preprocessed Data 
# ==========================================================
# Bundle features, silver labels, label embeddings, ancestor lists and test
# ids into a single file for the downstream training notebook.
torch.save(
    obj={
        "X_train": torch.tensor(X_train, dtype=torch.float32),
        "X_test": torch.tensor(X_test, dtype=torch.float32),
        "y_core": torch.tensor(y_core, dtype=torch.float32),
        "y_all": torch.tensor(y_all, dtype=torch.float32),
        # NOTE(review): stored as-is, not wrapped in torch.tensor like the
        # arrays above - confirm the consumer expects that.
        "E_label_768": E_label_768,
        "ancestors": ancestor_list,
        # Test ids are stored so predictions can be mapped back to products.
        "test_ids": test_ids,
    },
    f="preprocessed_features2.pth",
)

print("Preprocessing saved.")

Preprocessing saved.
