In [3]:
import pandas as pd
import argparse
import os
import pickle
import json
import logging
from surprise import Dataset, Reader, SVD, NMF, KNNBasic, accuracy
from surprise.model_selection import train_test_split

In [4]:
# Logging setup: every record goes both to log/train_mf_model.log and to the
# notebook's output stream, using one shared timestamped format.
log_dir = "log"
log_file = os.path.join(log_dir, "train_mf_model.log")
os.makedirs(log_dir, exist_ok=True)  # the file handler needs the directory to exist

log_handlers = [
    logging.FileHandler(log_file, encoding="utf-8"),
    logging.StreamHandler(),
]
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=log_handlers,
)

In [15]:
def load_data(file_path: str, rating_scale=(1, 5)):
    """Read a ratings CSV and wrap it in a Surprise ``Dataset``.

    Args:
        file_path: Path of a CSV file that must contain the columns
            ``userId``, ``productId`` and ``rating``.
        rating_scale: ``(min, max)`` pair handed to ``surprise.Reader``.
            Defaults to ``(1, 5)``, the scale that was previously hard-coded.

    Returns:
        The object returned by ``Dataset.load_from_df`` for the three
        required columns.

    Raises:
        KeyError: If any required column is missing from the CSV.
    """
    logging.info(f"Loading data from {file_path}")
    df = pd.read_csv(file_path)

    # Fail fast, naming the file and the columns, before any Surprise object is built.
    required_columns = ["userId", "productId", "rating"]
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise KeyError(
            f"{file_path} is missing required column(s): {missing_columns}"
        )

    reader = Reader(rating_scale=rating_scale)
    data = Dataset.load_from_df(df[required_columns], reader)
    return data


def train_model(data, algorithm="SVD", params=None, test_size=0.2, random_state=None):
    """Fit a Surprise model on a random train/test split and score the hold-out set.

    Args:
        data: Dataset accepted by ``surprise.model_selection.train_test_split``
            (for example the value returned by ``load_data``).
        algorithm: One of ``"SVD"``, ``"NMF"`` or ``"KNNBasic"``.
        params: Keyword arguments for the chosen algorithm's constructor.
            When truthy they are used *instead of* the built-in defaults
            (the two dictionaries are not merged).
        test_size: Forwarded to ``train_test_split``; defaults to the
            previously hard-coded ``0.2``.
        random_state: Seed forwarded to ``train_test_split`` so the split, and
            therefore the reported metrics, can be reproduced. ``None`` keeps
            the previous unseeded behaviour. This seeds the split only; any
            randomness inside the algorithm itself is governed by ``params``.

    Returns:
        Tuple ``(model, rmse, mae)``: the fitted model and its RMSE / MAE on
        the hold-out set.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    logging.info(f"Training model using algorithm: {algorithm}")

    default_params_by_algorithm = {
        "SVD": {
            "n_factors": 20,
            "n_epochs": 50,
            "lr_all": 0.005,
            "reg_all": 0.01,
        },
        "NMF": {"n_factors": 50, "n_epochs": 50},
        "KNNBasic": {
            "k": 40,
            "sim_options": {"name": "cosine", "user_based": False},
        },
    }

    # Validate before splitting so a typo in the name fails immediately
    # instead of after the data has already been shuffled and partitioned.
    if algorithm not in default_params_by_algorithm:
        raise ValueError(f"Unsupported algorithm: {algorithm}")

    trainset, testset = train_test_split(
        data, test_size=test_size, random_state=random_state
    )

    model_classes = {"SVD": SVD, "NMF": NMF, "KNNBasic": KNNBasic}
    model = model_classes[algorithm](
        **(params or default_params_by_algorithm[algorithm])
    )

    model.fit(trainset)
    predictions = model.test(testset)

    rmse = accuracy.rmse(predictions, verbose=False)
    mae = accuracy.mae(predictions, verbose=False)

    logging.info(f"RMSE: {rmse:.4f}, MAE: {mae:.4f}")

    return model, rmse, mae


def save_model(model, output_path="model.pkl"):
    """Pickle ``model`` to ``output_path``, creating the parent directory if needed.

    Args:
        model: Any picklable object (the notebook passes the fitted model).
        output_path: Destination file. Missing parent directories are created
            so that a path such as ``models/mf_model.pkl`` works on a fresh
            checkout.
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:  # empty string when output_path is a bare file name
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "wb") as f:
        pickle.dump(model, f)
    logging.info(f"Model saved to {output_path}")


def load_model(model_path="model.pkl"):
    """Unpickle and return the object stored at ``model_path``.

    Any exception raised while opening or unpickling the file is logged at
    ERROR level and ``None`` is returned instead of propagating it.

    NOTE(security): ``pickle.load`` can execute arbitrary code, so only point
    this at files you produced yourself.
    """
    try:
        with open(model_path, "rb") as handle:
            loaded = pickle.load(handle)
        logging.info(f"Model loaded from {model_path}")
    except Exception as err:
        logging.error(f"Failed to load model: {err}")
        loaded = None
    return loaded


def get_user_recommendations(model, df, userId, top_k=10, exclude_purchased=True):
    """Return up to ``top_k`` product IDs ranked by the model's estimate for a user.

    Args:
        model: Object whose ``predict(userId, productId)`` returns something
            with an ``est`` attribute.
        df: DataFrame with ``userId`` and ``productId`` columns; it supplies
            both the candidate products and the user's existing products.
        userId: User to score candidates for.
        top_k: Maximum number of product IDs to return.
        exclude_purchased: When true, products already paired with ``userId``
            in ``df`` are skipped.

    Returns:
        List of product IDs, highest estimate first. Ties keep the order in
        which the products first appear in ``df``.
    """
    already_seen = set(df.loc[df["userId"] == userId, "productId"].values)
    candidates = df["productId"].unique()

    scored = [
        (candidate, model.predict(userId, candidate).est)
        for candidate in candidates
        if not (exclude_purchased and candidate in already_seen)
    ]

    # sorted() is stable, so equal estimates stay in candidate order.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return [candidate for candidate, _ in ranked[:top_k]]

In [13]:
# Run configuration read by the training and inference cells below.
data_path = "data/processed/reviews.csv"  # ratings CSV passed to load_data / pd.read_csv
output_path = "models/mf_model.pkl"  # where the trained model is pickled and later reloaded
algorithm = "SVD"  # train_model accepts "SVD", "NMF" or "KNNBasic"
custom_params = None  # None -> train_model falls back to its built-in defaults for the algorithm

In [None]:
# End-to-end training run: load the ratings, fit and evaluate the model, persist it.
logging.info("Start training pipeline")
data = load_data(data_path)
model, rmse, mae = train_model(data, algorithm=algorithm, params=custom_params)
save_model(model, output_path=output_path)

# Plain-text summary of the run (same four lines as before, emitted in one call).
summary_lines = [
    "✅ Training complete",
    f"RMSE: {rmse!s}",
    f"MAE: {mae!s}",
    f"Model saved to: {output_path!s}",
]
print("\n".join(summary_lines))

2025-06-21 22:31:32,099 - INFO - Start training pipeline
2025-06-21 22:31:32,100 - INFO - Loading data from data/processed/reviews.csv
2025-06-21 22:31:32,810 - INFO - Training model using algorithm: SVD
2025-06-21 22:31:39,185 - INFO - RMSE: 0.5562, MAE: 0.2961
2025-06-21 22:31:39,636 - INFO - Model saved to models/mf_model.pkl


✅ Training complete
RMSE: 0.556219214330712
MAE: 0.2960820205624096
Model saved to: models/mf_model.pkl


In [17]:
# Reload the ratings table and the persisted model for inference.
df = pd.read_csv(data_path)
mf_model = load_model(output_path)

# load_model logs the error and returns None on failure; stop here with a clear
# message rather than hitting an AttributeError later when predict() is called.
if mf_model is None:
    raise RuntimeError(
        f"Could not load model from {output_path}; see the log output for details"
    )

2025-06-21 22:34:58,989 - INFO - Model loaded from models/mf_model.pkl


In [18]:
# Example query: recommendations for user 21665899 using the function's defaults
# (top_k=10, exclude_purchased=True). The returned list is the cell's output.
get_user_recommendations(mf_model, df, 21665899)

[140012510,
 14453314,
 1667493,
 167935897,
 175798776,
 187183513,
 191782133,
 194130726,
 196001721,
 201845361]