In [14]:
import os
import csv
import re
import random
from collections import defaultdict, Counter
from tqdm import tqdm
import numpy as np

# --- Paths ---
DATA_DIR = "Amazon_products"
TEST_DIR = os.path.join(DATA_DIR, "test")
TRAIN_DIR = os.path.join(DATA_DIR, "train")

TEST_CORPUS_PATH = os.path.join(TEST_DIR, "test_corpus.txt")
TRAIN_CORPUS_PATH = os.path.join(TRAIN_DIR, "train_corpus.txt")
CLASSES_PATH = os.path.join(DATA_DIR, "classes.txt")
HIER_PATH = os.path.join(DATA_DIR, "class_hierarchy.txt")
KEYWORD_PATH = os.path.join(DATA_DIR, "class_related_keywords.txt")

SUBMISSION_PATH = "submission.csv"

In [15]:
# --- Constants ---
NUM_CLASSES = 531
MIN_LABELS = 2
MAX_LABELS = 3
SEED = 42

random.seed(SEED)
np.random.seed(SEED)

In [16]:
# --- Utils ---
TAG_RE = re.compile(r"<[^>]+>")
_kw_regex_cache = {}

def clean_text(s: str) -> str:
    s = s or ""
    s = TAG_RE.sub(" ", s)
    s = s.lower()
    s = re.sub(r"\s+", " ", s).strip()
    return s

In [17]:
def load_corpus(path):
    pid2text = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t", 1)
            if len(parts) == 2:
                pid, text = parts
                pid2text[pid] = clean_text(text)
    return pid2text

In [18]:
def load_classes(path):
    name2id = {}
    id2name = {}
    next_id = 0
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = re.split(r"[\t, ]+", line)
            if parts[0].isdigit():
                cid = int(parts[0])
                cname = parts[1] if len(parts) > 1 else str(cid)
            else:
                cid = next_id
                cname = parts[0]
                next_id += 1
            name2id[cname] = cid
            id2name[cid] = cname
    return name2id, id2name

In [19]:
def load_hierarchy(path):
    parents = defaultdict(set)  # child -> parents
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            p_str, c_str = line.split("\t")
            p, c = int(p_str), int(c_str)
            parents[c].add(p)
    return parents

def get_ancestors(cid, parents):
    anc = set()
    stack = [cid]
    while stack:
        x = stack.pop()
        for p in parents.get(x, []):
            if p not in anc:
                anc.add(p)
                stack.append(p)
    return anc

In [20]:
def compute_depths(num_classes, parents):
    depth = [-1] * num_classes
    def dfs(x):
        if depth[x] != -1:
            return depth[x]
        ps = parents.get(x, [])
        if not ps:
            depth[x] = 0
            return 0
        depth[x] = 1 + max(dfs(p) for p in ps)
        return depth[x]
    for i in range(num_classes):
        dfs(i)
    return depth

In [21]:
def load_keywords_name_format(path, name2id):
    kw2cids = defaultdict(set)
    missing = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or ":" not in line:
                continue
            cname, rest = line.split(":", 1)
            cname = cname.strip()
            if cname not in name2id:
                missing.append(cname)
                continue
            cid = name2id[cname]
            kws = [clean_text(k) for k in rest.split(",")]
            kws = [k for k in kws if k]
            for kw in kws:
                kw2cids[kw].add(cid)
    if missing:
        print("WARNING: keyword classes missing from classes.txt (showing 10):", missing[:10])
    return kw2cids

In [22]:
def kw_match(text, kw):
    if " " in kw:   # phrase
        return kw in text
    if kw not in _kw_regex_cache:
        _kw_regex_cache[kw] = re.compile(rf"\b{re.escape(kw)}\b")
    return _kw_regex_cache[kw].search(text) is not None

def score_doc(text, kw2cids, depths):
    scores = defaultdict(float)
    for kw, cids in kw2cids.items():
        if kw_match(text, kw):
            w = 1.0 / (1.0 + np.log(1 + len(cids)))  # generic penalty
            for cid in cids:
                scores[cid] += w
    for cid in list(scores.keys()):
        scores[cid] += 0.02 * depths[cid]            # specificity bonus
    return scores

In [23]:
def predict_labels(text, kw2cids, parents, depths, k_min=2, k_max=3):
    scores = score_doc(text, kw2cids, depths)

    # If nothing matched: we'll fill later using a global fallback (not hardcoded 0)
    if not scores:
        return []

    ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    seed_top = [cid for cid, _ in ranked[:k_max]]

    cand = set(seed_top)
    for cid in seed_top:
        cand |= get_ancestors(cid, parents)

    cand_ranked = sorted(list(cand), key=lambda c: (scores.get(c, 0.0), depths[c]), reverse=True)

    chosen = []
    for cid in cand_ranked:
        if cid not in chosen:
            chosen.append(cid)
        if len(chosen) >= k_max:
            break

    if len(chosen) < k_min:
        for cid, _ in ranked:
            if cid not in chosen:
                chosen.append(cid)
            if len(chosen) >= k_min:
                break

    return chosen[:k_max]

In [24]:
# --- Load resources ---
pid2text_test = load_corpus(TEST_CORPUS_PATH)
pid_list_test = list(pid2text_test.keys())

pid2text_train = load_corpus(TRAIN_CORPUS_PATH)  # only used to build fallback
name2id, id2name = load_classes(CLASSES_PATH)
parents = load_hierarchy(HIER_PATH)
depths = compute_depths(NUM_CLASSES, parents)
kw2cids = load_keywords_name_format(KEYWORD_PATH, name2id)

In [26]:
# ------------------------
# Generate train_silver_labels.csv
# (keyword + hierarchy silver labels)
# ------------------------
import pandas as pd

TRAIN_SILVER_PATH = os.path.join("output", "train_silver_labels.csv")
os.makedirs("artifacts", exist_ok=True)

train_rows = []

for pid, text in tqdm(pid2text_train.items(), desc="Generating train silver labels"):
    labels = predict_labels(
        text,
        kw2cids,
        parents,
        depths,
        k_min=MIN_LABELS,
        k_max=MAX_LABELS
    )

    # Skip low-signal rows (no keyword matched)
    if not labels:
        continue

    train_rows.append({
        "pid": pid,
        "text": text,
        "labels": ",".join(map(str, labels))
    })

train_silver_df = pd.DataFrame(train_rows)
train_silver_df.to_csv(TRAIN_SILVER_PATH, index=False)

print(f"Saved train silver labels to: {TRAIN_SILVER_PATH}")
print(f"Train silver samples: {len(train_silver_df)}")
train_silver_df.head()

Generating train silver labels: 100%|██████████| 29487/29487 [15:30<00:00, 31.70it/s]

Saved train silver labels to: artifacts/train_silver_labels.csv
Train silver samples: 19500





Unnamed: 0,pid,text,labels
0,0,omron hem 790it automatic blood pressure monit...,502493145
1,1,natural factors whey factors chocolate works w...,355455271
2,2,"clif bar builder 's bar , 2 . 4 ounce bars i l...",4988366
3,4,clif bar energy bars these were cheaper than w...,480449376
4,9,lumiscope stirrup stockings pair these are ver...,382199169


In [40]:
# --- Just checking ---
# 1) How many training samples survived?
len(train_silver_df)

# Label count distribution
train_silver_df["labels"].apply(lambda s: len(s.split(","))).value_counts()

3    18251
2     1196
1       53
Name: labels, dtype: int64

In [33]:
# 3) Top frequent labels (make sure 0 is not dominating)
from collections import Counter
freq = Counter()
for s in train_silver_df["labels"]:
    for x in s.split(","):
        freq[int(x)] += 1
freq.most_common(10)

[(0, 2563),
 (455, 2085),
 (340, 1991),
 (3, 1523),
 (271, 1342),
 (241, 1229),
 (40, 1216),
 (313, 1210),
 (220, 1144),
 (194, 1143)]

In [41]:
import pandas as pd
from collections import Counter

train_silver_df = pd.read_csv("output/train_silver_labels.csv")

freq = Counter()
for s in train_silver_df["labels"]:
    for x in str(s).split(","):
        freq[int(x)] += 1

fallback_default = [cid for cid, _ in freq.most_common(2)]  # 2 labels fallback
print("Fallback default:", fallback_default)

Fallback default: [0, 455]


In [36]:
# --- Generate predictions ---
all_pids, all_labels = [], []
for pid in tqdm(pid_list_test, desc="Generating silver-label predictions"):
    text = pid2text_test[pid]
    labels = predict_labels(text, kw2cids, parents, depths, k_min=MIN_LABELS, k_max=MAX_LABELS)

    # fill if empty
    if not labels:
        labels = fallback_default

    labels = sorted(set(labels))[:MAX_LABELS]
    # ensure at least MIN_LABELS
    if len(labels) < MIN_LABELS:
        labels = (labels + fallback_default)[:MIN_LABELS]

    all_pids.append(pid)
    all_labels.append(labels)

Generating silver-label predictions: 100%|██████████| 19658/19658 [10:44<00:00, 30.50it/s]


In [37]:
# --- Save submission file ---
with open(SUBMISSION_PATH, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["pid", "labels"])
    for pid, labels in zip(all_pids, all_labels):
        writer.writerow([pid, ",".join(map(str, labels))])

print(f"Submission file saved to: {SUBMISSION_PATH}")
print(f"Total samples: {len(all_pids)}, Labels per sample: {MIN_LABELS}-{MAX_LABELS}")

Submission file saved to: submission.csv
Total samples: 19658, Labels per sample: 2-3


In [38]:
from collections import Counter
import pandas as pd

sub = pd.read_csv("submission.csv")
cnt = sub["labels"].apply(lambda s: len(str(s).split(","))).value_counts()
print(cnt)

freq = Counter()
for s in sub["labels"]:
    for x in str(s).split(","):
        freq[int(x)] += 1
print(freq.most_common(10))

3    12244
2     7414
Name: labels, dtype: int64
[(0, 8388), (455, 7959), (340, 1338), (3, 990), (271, 914), (313, 798), (241, 793), (220, 781), (194, 774), (40, 737)]
