# 04 实验跟踪（轻量）：model_logger 接入演示

- **目标**：演示如何给 Trainer 传入 `model_logger`，以统一的方式记录训练指标与超参数。
- **默认**：`model_logger=None`（不做任何记录，也不需要安装任何 tracking 依赖）；将代码中的 `USE_WANDB` / `USE_SWANLAB` / `USE_TENSORBOARD` 设为 `True` 即可启用对应后端。

## 可选依赖
- Wandb：`pip install wandb`
- SwanLab：`pip install swanlab`
- TensorBoardX：`pip install tensorboardX`（查看日志需另行安装 TensorBoard：`pip install tensorboard`，然后运行 `tensorboard --logdir ./runs`）


In [None]:
# import os
# os.environ['WANDB_API_KEY'] = "your API_KEY"
# os.environ['SWANLAB_API_KEY'] = "your API_KEY"

In [1]:
import os
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from tqdm import tqdm

from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.models.ranking import DeepFM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator

# Tracking loggers. This import runs unconditionally. build_loggers() wraps each
# logger *constructor* in `try/except ImportError`, i.e. it expects a missing
# backend package (wandb / swanlab / tensorboardX) to surface when a logger is
# constructed, not here.
# NOTE(review): if torch_rechub.basic.tracking imports the backends eagerly,
# this line itself would raise when a backend is missing — confirm.
from torch_rechub.basic.tracking import WandbLogger, SwanLabLogger, TensorBoardXLogger

# Runtime / data settings.
SEED = 2022
DEVICE = "cpu"
# Relative to the notebook's working directory (printed below as an absolute path).
DATASET_PATH = "../examples/ranking/data/criteo/criteo_sample.csv"

# Training hyperparameters.
EPOCH = 1
BATCH_SIZE = 2048
LR = 1e-3
WEIGHT_DECAY = 1e-3
EARLYSTOP_PATIENCE = 2

# Tracking switches: all off by default, so no tracking dependency is required.
USE_WANDB = False
USE_SWANLAB = False
USE_TENSORBOARD = False
PROJECT_NAME = "tracking-demo"

# Seeds torch's RNG only.
torch.manual_seed(SEED)
print("DATASET_PATH:", os.path.abspath(DATASET_PATH))


DATASET_PATH: e:\RecommendSystemProject\torch-rechub\examples\ranking\data\criteo\criteo_sample.csv


In [2]:
def convert_numeric_feature(val):
    """Map one raw dense value to an integer bucket.

    The value is first truncated with ``int``.
    - Values of 2 or less are shifted down by 2, so they land in ``<= 0``.
    - Larger values are compressed logarithmically to ``int(log(v) ** 2)``.
      This is always ``>= 1`` for ``v >= 3``, so the two ranges never collide.
    """
    num = int(val)
    if num <= 2:
        return num - 2
    return int(np.log(num) ** 2)


def get_criteo_data_dict(data_path):
    """Load a Criteo-style CSV and build torch-rechub features plus (x, y).

    Processing steps:
      1. Fill missing values: ``"0"`` for sparse ``C*`` columns, ``0`` for
         dense ``I*`` columns.
      2. For every dense column, add a discretized ``<name>_cat`` sparse column
         computed by ``convert_numeric_feature`` from the raw, unscaled values.
      3. Min-max scale the dense columns.
      4. Label-encode every sparse column: the original ``C*`` columns, then the
         new ``*_cat`` columns, in that order.

    Args:
        data_path: path to the CSV file. A ``.gz`` suffix is read as gzip.

    Returns:
        ``(dense_feas, sparse_feas, x, y)``:
        - the ``DenseFeature`` list;
        - the ``SparseFeature`` list (embed_dim=16, vocab_size = number of
          distinct encoded values);
        - the DataFrame of all columns except ``label``;
        - the ``label`` Series.
    """
    data = pd.read_csv(data_path, compression="gzip") if data_path.endswith(".gz") else pd.read_csv(data_path)
    columns = data.columns.tolist()
    dense_features = [f for f in columns if f.startswith("I")]
    sparse_features = [f for f in columns if f.startswith("C")]

    data[sparse_features] = data[sparse_features].fillna("0")
    data[dense_features] = data[dense_features].fillna(0)

    # Discretize *before* scaling so the buckets are computed from raw values.
    for feat in tqdm(dense_features, desc="discretize dense"):
        cat_name = f"{feat}_cat"
        sparse_features.append(cat_name)
        data[cat_name] = data[feat].apply(convert_numeric_feature)

    data[dense_features] = MinMaxScaler().fit_transform(data[dense_features])

    # After encoding, values are 0..nunique-1, so nunique() is a valid vocab size.
    for feat in tqdm(sparse_features, desc="label encode sparse"):
        data[feat] = LabelEncoder().fit_transform(data[feat])

    dense_feas = [DenseFeature(name) for name in dense_features]
    sparse_feas = [SparseFeature(name, vocab_size=data[name].nunique(), embed_dim=16) for name in sparse_features]

    y = data["label"]
    x = data.drop(columns=["label"])
    return dense_feas, sparse_feas, x, y


def build_loggers():
    """Construct the tracking loggers enabled by the module-level ``USE_*`` flags.

    The flags ``USE_WANDB`` / ``USE_SWANLAB`` / ``USE_TENSORBOARD`` are read at
    call time. If a logger constructor raises ImportError (backend package not
    installed), that backend is reported and skipped instead of aborting the
    notebook.

    Returns:
        A non-empty list of logger instances, or None when no logger was
        created. None is passed to the trainer as ``model_logger``, meaning
        "no tracking".
    """
    run_name = f"deepfm-{SEED}"
    run_config = {"lr": LR, "batch_size": BATCH_SIZE, "seed": SEED}
    log_dir = f"./runs/{run_name}"
    loggers = []

    if USE_WANDB:
        try:
            # Each logger gets its own config dict, so in-place edits by one
            # backend cannot leak into another.
            wandb_logger = WandbLogger(
                project=PROJECT_NAME,
                name=run_name,
                config=dict(run_config),
                tags=["criteo", "ctr", "deepfm"],
            )
        except ImportError as e:
            print("✗ Wandb not installed, skipped:", e)
        else:
            loggers.append(wandb_logger)
            print("✓ WandbLogger initialized")

    if USE_SWANLAB:
        try:
            swanlab_logger = SwanLabLogger(
                project=PROJECT_NAME,
                experiment_name=run_name,
                config=dict(run_config),
            )
        except ImportError as e:
            print("✗ SwanLab not installed, skipped:", e)
        else:
            loggers.append(swanlab_logger)
            print("✓ SwanLabLogger initialized")

    if USE_TENSORBOARD:
        try:
            tb_logger = TensorBoardXLogger(log_dir=log_dir)
        except ImportError as e:
            print("✗ tensorboardX not installed, skipped:", e)
        else:
            loggers.append(tb_logger)
            print(f"✓ TensorBoardXLogger initialized: {log_dir}")

    return loggers or None


In [3]:
# Data: features, train/val/test dataloaders (70% / 10% / remainder).
dense_feas, sparse_feas, x, y = get_criteo_data_dict(DATASET_PATH)
dg = DataGenerator(x, y)
train_dl, val_dl, test_dl = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=BATCH_SIZE)

model = DeepFM(
    deep_features=dense_feas,
    fm_features=sparse_feas,
    mlp_params={"dims": [64, 32], "dropout": 0.1, "activation": "relu"},
)

model_logger = build_loggers()  # None means: no tracking
# Iterate uniformly whether or not any logger is enabled.
loggers = model_logger or []

try:
    # Hyperparameters can also be logged manually. The trainer receives the
    # same loggers via `model_logger` and may log to them itself.
    for lg in loggers:
        lg.log_hyperparams({"epoch": EPOCH, "lr": LR, "batch_size": BATCH_SIZE, "weight_decay": WEIGHT_DECAY})

    ctr_trainer = CTRTrainer(
        model,
        optimizer_params={"lr": LR, "weight_decay": WEIGHT_DECAY},
        n_epoch=EPOCH,
        earlystop_patience=EARLYSTOP_PATIENCE,
        device=DEVICE,
        model_path="./",
        model_logger=model_logger,
    )

    ctr_trainer.fit(train_dl, val_dl)
    auc = ctr_trainer.evaluate(ctr_trainer.model, test_dl)
    print(f"test auc: {auc}")

    # Record one extra final metric manually, one step after the last epoch.
    # NOTE(review): if CTRTrainer.fit already calls finish() on its loggers,
    # the backend may reject this post-fit write — confirm.
    for lg in loggers:
        lg.log_metrics({"test/auc": float(auc)}, step=EPOCH)
finally:
    # Always close the tracking runs, even when training or evaluation raised,
    # so no run is left open.
    for lg in loggers:
        lg.finish()
        lg.finish()


discretize dense: 100%|██████████| 13/13 [00:00<00:00, 4332.96it/s]
label encode sparse: 100%|██████████| 39/39 [00:00<00:00, 7790.16it/s]

the samples of train : val : test are  80 : 11 : 24





epoch: 0


train: 100%|██████████| 1/1 [00:00<00:00, 24.02it/s]
validation: 100%|██████████| 1/1 [00:00<00:00, 308.22it/s]


epoch: 0 validation: auc: 0.19999999999999998


validation: 100%|██████████| 1/1 [00:00<00:00, 221.46it/s]

test auc: 0.2875



