In [6]:
import json
import random
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader



from sentence_transformers import SentenceTransformer
import nltk

nltk.download("punkt")
nltk.download("punkt_tab")

from nltk.tokenize import sent_tokenize
import pandas as pd

[nltk_data] Downloading package punkt to /Users/benjamin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/benjamin/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [None]:
df = pd.read_json("Digital_Music_5.json", lines=True)


df.head()

Unnamed: 0,reviewerID,asin,reviewerName,helpful,reviewText,overall,summary,unixReviewTime,reviewTime
0,A3EBHHCZO6V2A4,5555991584,"Amaranth ""music fan""","[3, 3]","It's hard to believe ""Memory of Trees"" came ou...",5,Enya's last great album,1158019200,"09 12, 2006"
1,AZPWAXJG9OJXV,5555991584,bethtexas,"[0, 0]","A clasically-styled and introverted album, Mem...",5,Enya at her most elegant,991526400,"06 3, 2001"
2,A38IRL0X2T4DPF,5555991584,bob turnley,"[2, 2]",I never thought Enya would reach the sublime h...,5,The best so far,1058140800,"07 14, 2003"
3,A22IK3I6U76GX0,5555991584,Calle,"[1, 1]",This is the third review of an irish album I w...,5,Ireland produces good music.,957312000,"05 3, 2000"
4,A1AISPOIIHTHXX,5555991584,"Cloud ""...""","[1, 1]","Enya, despite being a successful recording art...",4,4.5; music to dream to,1200528000,"01 17, 2008"


In [7]:
len(df)

64706

In [None]:

# -----------------------------
# Config
# -----------------------------

JSON_PATH = "Digital_Music_5.json"       
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
MAX_REVIEWS: Optional[int] = 100_000  
TRAIN_FRACTION = 0.8
VAL_FRACTION = 0.1
SEED = 42
BATCH_SIZE = 256
EPOCHS = 15
LR = 1e-3


# -----------------------------
# Data structures
# -----------------------------

@dataclass
class Interaction:
    user: str
    item: str
    rating: float
    review_emb: np.ndarray  


# -----------------------------
# SBERT
# -----------------------------

def build_sbert_model(model_name: str = MODEL_NAME) -> SentenceTransformer:
    model = SentenceTransformer(model_name)
    return model


def encode_review(text: str, model: SentenceTransformer) -> np.ndarray:
    sents = sent_tokenize(text)
    if not sents:
        return None
    emb = model.encode(sents, convert_to_numpy=True)
    return emb.mean(axis=0)

def encode_reviews_with_sbert(
    json_path: str,
    model: SentenceTransformer,
    max_reviews: Optional[int] = None,
) -> List[Interaction]:
    interactions: List[Interaction] = []
    with open(json_path, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if max_reviews is not None and len(interactions) >= max_reviews:
                break

            line = line.strip()
            if not line:
                continue
            try:
                r = json.loads(line)
            except json.JSONDecodeError:
                continue

            # minimal required fields
            if "reviewerID" not in r or "asin" not in r or "overall" not in r or "reviewText" not in r:
                continue

            text = r["reviewText"]
            emb = encode_review(text, model)
            if emb is None:
                continue

            inter = Interaction(
                user=str(r["reviewerID"]),
                item=str(r["asin"]),
                rating=float(r["overall"]),
                review_emb=emb,
            )
            interactions.append(inter)

    return interactions


# -----------------------------
# Split
# -----------------------------

def split_interactions(
    interactions: List[Interaction],
    train_frac: float = TRAIN_FRACTION,
    val_frac: float = VAL_FRACTION,
    seed: int = SEED,
) -> Tuple[List[Interaction], List[Interaction], List[Interaction]]:
    rand = random.Random(seed)
    shuffled = interactions[:]
    rand.shuffle(shuffled)

    n = len(shuffled)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    n_test = n - n_train - n_val

    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    assert len(train) + len(val) + len(test) == n

    return train, val, test


# -----------------------------
# Embedding Aggregation
# -----------------------------

def aggregate_embeddings(
    interactions: List[Interaction],
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
    user_sum: Dict[str, np.ndarray] = {}
    user_count: Dict[str, int] = {}
    item_sum: Dict[str, np.ndarray] = {}
    item_count: Dict[str, int] = {}

    for inter in interactions:
        u, i, emb = inter.user, inter.item, inter.review_emb

        # user
        if u not in user_sum:
            user_sum[u] = emb.copy()
            user_count[u] = 1
        else:
            user_sum[u] += emb
            user_count[u] += 1

        # item
        if i not in item_sum:
            item_sum[i] = emb.copy()
            item_count[i] = 1
        else:
            item_sum[i] += emb
            item_count[i] += 1

    user_vecs = {u: user_sum[u] / user_count[u] for u in user_sum}
    item_vecs = {i: item_sum[i] / item_count[i] for i in item_sum}

    return user_vecs, item_vecs


# -----------------------------
# PyTorch Dataset
# -----------------------------

class RatingDataset(Dataset):
    def __init__(
        self,
        interactions: List[Interaction],
        user_vecs: Dict[str, np.ndarray],
        item_vecs: Dict[str, np.ndarray],
    ):
        samples_x = []
        samples_y = []

        for inter in interactions:
            if inter.user not in user_vecs or inter.item not in item_vecs:
                continue

            u_vec = user_vecs[inter.user]
            i_vec = item_vecs[inter.item]
            x = np.concatenate([u_vec, i_vec], axis=0).astype(np.float32)
            y = np.float32(inter.rating)

            samples_x.append(x)
            samples_y.append(y)

        self.x = np.stack(samples_x) if samples_x else np.zeros((0,))
        self.y = np.array(samples_y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.x[idx])
        y = torch.tensor(self.y[idx], dtype=torch.float32)
        return x, y


# -----------------------------
# Rating Model (MLP)
# -----------------------------

class RatingMLP(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.net(x).squeeze(-1)  


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    epochs: int = EPOCHS,
    lr: float = LR,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        n_batches = 0

        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            preds = model(x)
            loss = loss_fn(preds, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        avg_train_loss = total_loss / max(1, n_batches)

        val_mse = evaluate_mse(model, val_loader, device=device)

        print(f"Epoch {epoch}: train MSE={avg_train_loss:.4f}, val MSE={val_mse:.4f}")


def evaluate_mse(
    model: nn.Module,
    loader: DataLoader,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> float:
    model.eval()
    loss_fn = nn.MSELoss()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            preds = model(x)
            loss = loss_fn(preds, y)
            total_loss += loss.item()
            n_batches += 1

    return total_loss / max(1, n_batches)


# -----------------------------
# Main
# -----------------------------

def main():
    print("Loading SBERT model...")
    sbert = build_sbert_model(MODEL_NAME)

    print("Encoding reviews with SBERT...")
    interactions = encode_reviews_with_sbert(
        JSON_PATH,
        model=sbert,
        max_reviews=MAX_REVIEWS,
    )
    print(f"Encoded {len(interactions)} interactions.")

    if len(interactions) < 100:
        print("Warning: very few interactions; consider increasing MAX_REVIEWS.")

    train, val, test = split_interactions(interactions)
    print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")

    user_vecs, item_vecs = aggregate_embeddings(train)
    emb_dim = next(iter(user_vecs.values())).shape[0]
    input_dim = emb_dim * 2

    train_ds = RatingDataset(train, user_vecs, item_vecs)
    val_ds = RatingDataset(val, user_vecs, item_vecs)
    test_ds = RatingDataset(test, user_vecs, item_vecs)

    print(
        f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}, Test samples: {len(test_ds)}"
    )

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

    model = RatingMLP(input_dim=input_dim)
    print("Training model...")
    train_model(model, train_loader, val_loader)

    test_mse = evaluate_mse(model, test_loader)
    print(f"Final TEST MSE: {test_mse:.4f}")


if __name__ == "__main__":
    main()

Loading SBERT model...
Encoding reviews with SBERT...
Encoded 64705 interactions.
Train: 51764, Val: 6470, Test: 6471
Train samples: 51764, Val samples: 6470, Test samples: 6471
Training model...
Epoch 1: train MSE=2.7165, val MSE=1.0777
Epoch 2: train MSE=0.9697, val MSE=1.0021
Epoch 3: train MSE=0.9050, val MSE=0.9869
Epoch 4: train MSE=0.8853, val MSE=0.9836
Epoch 5: train MSE=0.8730, val MSE=0.9724
Epoch 6: train MSE=0.8655, val MSE=0.9730
Epoch 7: train MSE=0.8616, val MSE=0.9713
Epoch 8: train MSE=0.8546, val MSE=0.9678
Epoch 9: train MSE=0.8513, val MSE=0.9645
Epoch 10: train MSE=0.8463, val MSE=0.9602
Epoch 11: train MSE=0.8407, val MSE=0.9566
Epoch 12: train MSE=0.8366, val MSE=0.9543
Epoch 13: train MSE=0.8227, val MSE=0.9442
Epoch 14: train MSE=0.8200, val MSE=0.9458
Epoch 15: train MSE=0.8139, val MSE=0.9722
Final TEST MSE: 0.9844
