# DLRM Training for Book Recommendation System


## 1: Import Libraries and Setup Environment


In [1]:
import contextlib
import tempfile
from typing import Generator
import csv
import os
import random
import pickle
import sys
import itertools
from functools import partial
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Iterable, cast

# PyTorch and distributed training
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.distributed._sharded_tensor import ShardedTensor

# TorchRec for DLRM
from torchrec import EmbeddingBagCollection
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_parallel import (
    DistributedModelParallel,
    get_default_sharders,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.models.dlrm import DLRM, DLRM_DCN, DLRM_Projection, DLRMTrain
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.datasets.utils import Batch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Streaming data
from streaming import StreamingDataset, StreamingDataLoader
from streaming.base import MDSWriter

# Metrics and utilities
import torchmetrics as metrics
from tqdm import tqdm
from pyre_extensions import none_throws
import mlflow


  from torch.distributed._sharded_tensor import ShardedTensor
  ALLREDUCE = partial(_ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook)
  FP16_COMPRESS = partial(
  BF16_COMPRESS = partial(
  QUANTIZE_PER_TENSOR = partial(
  QUANTIZE_PER_CHANNEL = partial(
  POWER_SGD = partial(
  POWER_SGD_RANK2 = partial(
  BATCHED_POWER_SGD = partial(
  BATCHED_POWER_SGD_RANK2 = partial(
  NOOP = partial(
  return torch._C._cuda_getDeviceCount() > 0
  _register_pytree_node(JaggedTensor, _jt_flatten, _jt_unflatten)
  _register_pytree_node(
  return arg(*args, **kwargs)
  _register_pytree_node(KeyedTensor, _kt_flatten, _kt_unflatten)


## 2: Setup Environment Variables for Distributed Training


In [2]:
# Setup environment for single GPU/CPU training
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0" 
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

# Setup device and backend
device = torch.device("cpu")  # Using CPU for compatibility
backend = "gloo"

print(f"Device: {device}")

Device: cpu


## 3: Load Preprocessing Information


In [5]:
print("Loading book data preprocessing info...")

with open('/home/mr-behdadi/PROJECT/ICE/book_dlrm_preprocessing.pkl', 'rb') as f:
        preprocessing_info = pickle.load(f)
    
dense_cols = preprocessing_info['dense_cols']
cat_cols = preprocessing_info['cat_cols'] 
emb_counts = preprocessing_info['emb_counts']
    
print("✅ Preprocessing info loaded successfully")
print(f"Dense columns ({len(dense_cols)}): {dense_cols}")
print(f"Categorical columns ({len(cat_cols)}): {cat_cols}")
print(f"Embedding counts: {dict(zip(cat_cols, emb_counts))}")
    

Loading book data preprocessing info...
✅ Preprocessing info loaded successfully
Dense columns (6): ['age_normalized', 'year_normalized', 'user_activity', 'book_popularity', 'user_avg_rating', 'book_avg_rating']
Categorical columns (7): ['user_id_encoded', 'book_id_encoded', 'publisher_encoded', 'country_encoded', 'age_group', 'decade_encoded', 'rating_level']
Embedding counts: {'user_id_encoded': 2578, 'book_id_encoded': 4313, 'publisher_encoded': 315, 'country_encoded': 45, 'age_group': 6, 'decade_encoded': 7, 'rating_level': 4}


## 4: Define Training Arguments Dataclass


In [6]:
@dataclass
class BookDLRMArgs:
    """Training arguments for book recommendation DLRM"""
    epochs: int = 3
    embedding_dim: int = 64
    dense_arch_layer_sizes: list = field(default_factory=lambda: [256, 128, 64])  
    over_arch_layer_sizes: list = field(default_factory=lambda: [512, 256, 128, 1])
    learning_rate: float = 0.01
    eps: float = 1e-8
    batch_size: int = 512
    print_sharding_plan: bool = True
    print_lr: bool = False
    lr_warmup_steps: int = 0  # Set to 0 to avoid scheduler issues
    lr_decay_start: int = 0
    lr_decay_steps: int = 0
    validation_freq: int = None
    limit_train_batches: int = None
    limit_val_batches: int = None
    limit_test_batches: int = None

@dataclass
class TrainValTestResults:
    """Results storage for training"""
    val_aurocs: List[float] = field(default_factory=list)
    test_auroc: Optional[float] = None

print("✅ Training arguments dataclass defined")

✅ Training arguments dataclass defined


## 5: Define Batch Transformation Function


In [7]:
def transform_to_torchrec_batch(batch, num_embeddings_per_feature: Optional[List[int]] = None) -> Batch:
    """Transform batch to TorchRec format for book recommendation data"""
    # Dense features
    cat_list = []
    for col_name in dense_cols:
        val = torch.tensor(batch[col_name], dtype=torch.float32)
        cat_list.append(val.unsqueeze(0).T)
    dense_features = torch.cat(cat_list, dim=1)

    # Sparse features
    kjt_values: List[int] = []
    kjt_lengths: List[int] = []
    for col_idx, col_name in enumerate(cat_cols):
        values = batch[col_name]
        for value in values:
            if value is not None and value >= 0:
                kjt_values.append(
                    int(value) % num_embeddings_per_feature[col_idx]
                )
                kjt_lengths.append(1)
            else:
                kjt_lengths.append(0)

    sparse_features = KeyedJaggedTensor.from_lengths_sync(
        cat_cols,
        torch.tensor(kjt_values),
        torch.tensor(kjt_lengths, dtype=torch.int32),
    )
    
    # Labels
    labels = torch.tensor(batch["label"], dtype=torch.int32)
    assert isinstance(labels, torch.Tensor)

    return Batch(
        dense_features=dense_features,
        sparse_features=sparse_features,
        labels=labels,
    )

print("✅ Batch transformation function defined")

✅ Batch transformation function defined


## 6: Create Transform Partial Function

In [8]:
# Create partial function with embedding counts
if emb_counts:
    transform_partial = partial(transform_to_torchrec_batch, num_embeddings_per_feature=emb_counts)
    print("✅ Transform partial function created")
else:
    print("❌ Cannot create transform function - missing embedding counts")
    transform_partial = None

✅ Transform partial function created


## 7: Define DataLoader Creation Function


In [9]:
def get_dataloader_with_mosaic(path, batch_size, label):
    """Get DataLoader for book recommendation data"""
    print(f"Getting {label} data from {path}")
    try:
        dataset = StreamingDataset(local=path, shuffle=True, batch_size=batch_size)
        dataloader = StreamingDataLoader(dataset, batch_size=batch_size)
        print(f"✅ {label} dataloader created successfully")
        return dataloader
    except Exception as e:
        print(f"❌ Error creating {label} dataloader: {e}")
        return None

print("✅ DataLoader creation function defined")

✅ DataLoader creation function defined


## 8: Define Learning Rate Scheduler


In [10]:
class LRPolicyScheduler(_LRScheduler):
    """Learning rate scheduler with warmup and decay"""
    def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
        self.num_warmup_steps = num_warmup_steps
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_start_step + num_decay_steps
        self.num_decay_steps = num_decay_steps

        if self.decay_start_step < self.num_warmup_steps:
            print("Warning: Learning rate warmup must finish before the decay starts")

        super(LRPolicyScheduler, self).__init__(optimizer)

    def get_lr(self):
        step_count = self._step_count
        if step_count < self.num_warmup_steps:
            # warmup
            scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps
            lr = [base_lr * scale for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.decay_start_step <= step_count and step_count < self.decay_end_step:
            # decay
            decayed_steps = step_count - self.decay_start_step
            scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2
            min_lr = 0.0000001
            lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs]
            self.last_lr = lr
        else:
            if self.num_decay_steps > 0:
                lr = self.last_lr
            else:
                lr = self.base_lrs
        return lr

print("✅ Learning rate scheduler defined")

✅ Learning rate scheduler defined


## 9: Define Utility Functions


In [11]:
def batched(it, n):
    """Helper function for batching"""
    assert n >= 1
    for x in it:
        yield itertools.chain((x,), itertools.islice(it, n - 1))

def get_relevant_fields(args, dense_cols, cat_cols, emb_counts):
    """Get relevant fields for MLflow logging"""
    fields_to_save = ["epochs", "embedding_dim", "dense_arch_layer_sizes", 
                     "over_arch_layer_sizes", "learning_rate", "eps", "batch_size"]
    result = {key: getattr(args, key) for key in fields_to_save}
    result["dense_cols"] = dense_cols
    result["cat_cols"] = cat_cols
    result["emb_counts"] = emb_counts
    return result

print("✅ Utility functions defined")

✅ Utility functions defined


## 10: Define Training Function


In [12]:
def train(
    pipeline: TrainPipelineSparseDist,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    epoch: int,
    lr_scheduler,
    print_lr: bool,
    validation_freq: Optional[int],
    limit_train_batches: Optional[int],
    limit_val_batches: Optional[int]) -> None:
    """Train model for 1 epoch"""
    
    pipeline._model.train()
    iterator = itertools.islice(iter(train_dataloader), limit_train_batches)
    
    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Epoch {epoch + 1}",
            total=len(train_dataloader),
            disable=False,
        )

    start_it = 0
    n = validation_freq if validation_freq else len(train_dataloader)
    for batched_iterator in batched(iterator, n):
        for it in itertools.count(start_it):
            try:
                if is_rank_zero and print_lr:
                    for i, g in enumerate(pipeline._optimizer.param_groups):
                        print(f"lr: {it} {i} {g['lr']:.6f}")
                pipeline.progress(map(transform_partial, batched_iterator))
                lr_scheduler.step()
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                if is_rank_zero:
                    print(f"Completed {it} training iterations")
                start_it = it
                break

        if validation_freq and start_it % validation_freq == 0:
            evaluate(limit_val_batches, pipeline, val_dataloader, "val")
            pipeline._model.train()

print("✅ Training function defined")

✅ Training function defined


## 11: Define Evaluation Function


In [13]:
def evaluate(
    limit_batches: Optional[int],
    pipeline: TrainPipelineSparseDist,
    eval_dataloader: DataLoader,
    stage: str) -> float:
    """Evaluate model and compute AUROC"""
    
    pipeline._model.eval()
    device = pipeline._device

    iterator = itertools.islice(iter(eval_dataloader), limit_batches)
    auroc = metrics.AUROC(task="binary").to(device)

    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        pbar = tqdm(
            iter(int, 1),
            desc=f"Evaluating {stage} set",
            total=len(eval_dataloader),
            disable=False,
        )
    
    with torch.no_grad():
        while True:
            try:
                _loss, logits, labels = pipeline.progress(map(transform_partial, iterator))
                preds = torch.sigmoid(logits)
                auroc(preds, labels)
                if is_rank_zero:
                    pbar.update(1)
            except StopIteration:
                break

    auroc_result = auroc.compute().item()
    num_samples = torch.tensor(sum(map(len, auroc.target)), device=device)
    dist.reduce(num_samples, 0, op=dist.ReduceOp.SUM)

    if is_rank_zero:
        print(f"AUROC over {stage} set: {auroc_result:.4f}")
        print(f"Number of {stage} samples: {num_samples}")
    return auroc_result

print("✅ Evaluation function defined")


✅ Evaluation function defined


## 12: Define Complete Training Loop Function


In [14]:
def train_val_test(args, model, optimizer, device, train_dataloader, val_dataloader, test_dataloader, lr_scheduler) -> TrainValTestResults:
    """Complete training loop"""
    
    results = TrainValTestResults()
    pipeline = TrainPipelineSparseDist(model, optimizer, device, execute_all_batches=True)
    
    # Initial validation
    print("Running initial validation...")
    val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
    results.val_aurocs.append(val_auroc)
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('val_auroc', val_auroc, step=0)

    # Training loop
    for epoch in range(args.epochs):
        print(f"\n=== Epoch {epoch + 1}/{args.epochs} ===")
        
        train(
            pipeline,
            train_dataloader,
            val_dataloader,
            epoch,
            lr_scheduler,
            args.print_lr,
            args.validation_freq,
            args.limit_train_batches,
            args.limit_val_batches,
        )

        # Validate after each epoch
        val_auroc = evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
        results.val_aurocs.append(val_auroc)
        if int(os.environ["RANK"]) == 0:
            mlflow.log_metric('val_auroc', val_auroc, step=epoch + 1)
            
            # Save model state
            model_path = f"dlrm_book_model_epoch_{epoch}.pth"
            torch.save(pipeline._model.state_dict(), model_path)
            print(f"Model saved: {model_path}")

    # Final test evaluation
    print("\nRunning final test evaluation...")
    test_auroc = evaluate(args.limit_test_batches, pipeline, test_dataloader, "test")
    results.test_auroc = test_auroc
    if int(os.environ["RANK"]) == 0:
        mlflow.log_metric('test_auroc', test_auroc)
        
        # Save final model
        final_model_path = "dlrm_book_model_final.pth"
        torch.save(pipeline._model.state_dict(), final_model_path)
        print(f"Final model saved: {final_model_path}")
        
    return results

print("✅ Complete training loop function defined")


✅ Complete training loop function defined


## 13: Initialize Training Arguments


In [15]:
# Cell 13: Initialize Training Arguments
# Create training arguments optimized for book recommendation
args = BookDLRMArgs(
    epochs=3,
    embedding_dim=64,
    batch_size=512,
    learning_rate=0.01,
    lr_warmup_steps=0,
    dense_arch_layer_sizes=[256, 128, 64],
    over_arch_layer_sizes=[512, 256, 128, 1]
)

print("✅ Training arguments initialized")
print(f"Training configuration:")
print(f"  Epochs: {args.epochs}")
print(f"  Batch size: {args.batch_size}")
print(f"  Embedding dim: {args.embedding_dim}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Dense features: {len(dense_cols)}")
print(f"  Categorical features: {len(cat_cols)}")

✅ Training arguments initialized
Training configuration:
  Epochs: 3
  Batch size: 512
  Embedding dim: 64
  Learning rate: 0.01
  Dense features: 6
  Categorical features: 7


## 14: Setup MLflow Experiment


In [16]:
# Setup MLflow experiment
username = "book_recommender"
experiment_path = f'dlrm-book-recommendation-{username}'

try:
    experiment = mlflow.set_experiment(experiment_path)
    print(f"✅ MLflow experiment set: {experiment_path}")
    
    # Log parameters
    param_dict = get_relevant_fields(args, dense_cols, cat_cols, emb_counts)
    mlflow.log_params(param_dict)
    print("✅ Parameters logged to MLflow")
    
except Exception as e:
    print(f"⚠️ MLflow setup warning: {e}")
    print("Training will continue without MLflow logging")

2025/09/13 21:54:48 INFO mlflow.tracking.fluent: Experiment with name 'dlrm-book-recommendation-book_recommender' does not exist. Creating a new experiment.


✅ MLflow experiment set: dlrm-book-recommendation-book_recommender
✅ Parameters logged to MLflow


## 15: Initialize Distributed Training


In [17]:
# Initialize distributed training
print("Initializing distributed training...")
try:
    # Disable JIT for compatibility
    torch.jit._state.disable()
    
    # Initialize process group
    dist.init_process_group(backend=backend)
    
    global_rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    
    print(f"✅ Distributed training initialized")
    print(f"Global rank: {global_rank}")
    print(f"Local rank: {local_rank}")
    print(f"Device: {device}")
    
except Exception as e:
    print(f"❌ Error initializing distributed training: {e}")


Initializing distributed training...
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
✅ Distributed training initialized
Global rank: 0
Local rank: 0
Device: cpu


## 16: Load Training Data

In [18]:
# Load data
print("Loading book recommendation data...")

# Check if data directories exist
data_paths = {
    "train": "dlrm_book_data/mds_train",
    "validation": "dlrm_book_data/mds_validation", 
    "test": "dlrm_book_data/mds_test"
}

missing_paths = [path for path in data_paths.values() if not os.path.exists(path)]

if missing_paths:
    print("❌ Missing data directories:")
    for path in missing_paths:
        print(f"   - {path}")
    print("Please run the preprocessing notebook first.")
    train_dataloader = val_dataloader = test_dataloader = None
else:
    train_dataloader = get_dataloader_with_mosaic(data_paths["train"], args.batch_size, "train")
    val_dataloader = get_dataloader_with_mosaic(data_paths["validation"], args.batch_size, "val")
    test_dataloader = get_dataloader_with_mosaic(data_paths["test"], args.batch_size, "test")
    
    print("✅ All dataloaders created successfully")




Loading book recommendation data...
Getting train data from dlrm_book_data/mds_train
✅ train dataloader created successfully
Getting val data from dlrm_book_data/mds_validation
✅ val dataloader created successfully
Getting test data from dlrm_book_data/mds_test
✅ test dataloader created successfully
✅ All dataloaders created successfully


## 17: Create DLRM Model


In [19]:
# Create embedding configs for book recommendation
if cat_cols and emb_counts:
    print("Creating embedding configurations...")
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=emb_counts[feature_idx],
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(cat_cols)
    ]
    
    print(f"✅ Created {len(eb_configs)} embedding configurations")
    
    # Create DLRM model for book recommendation
    print("Creating DLRM model...")
    dlrm_model = DLRM(
        embedding_bag_collection=EmbeddingBagCollection(
            tables=eb_configs, device=device
        ),
        dense_in_features=len(dense_cols),
        dense_arch_layer_sizes=args.dense_arch_layer_sizes,
        over_arch_layer_sizes=args.over_arch_layer_sizes,
        dense_device=device,
    )

    train_model = DLRMTrain(dlrm_model)
    model = train_model.to(device)
    
    num_params = sum(p.numel() for p in model.parameters())
    print(f"✅ DLRM model created with {num_params:,} parameters")
    
else:
    print("❌ Cannot create model - missing categorical columns or embedding counts")
    model = None


Creating embedding configurations...
✅ Created 7 embedding configurations
Creating DLRM model...
✅ DLRM model created with 720,065 parameters


## 18: Setup Optimizer and Learning Rate Scheduler  


In [20]:
if model is not None:
    # Setup optimizer
    print("Setting up optimizer...")
    optimizer = torch.optim.Adagrad(model.parameters(), lr=args.learning_rate, eps=args.eps)
    
    # Setup learning rate scheduler
    lr_scheduler = LRPolicyScheduler(
        optimizer, 
        args.lr_warmup_steps, 
        args.lr_decay_start, 
        args.lr_decay_steps
    )
    
    print("✅ Optimizer and learning rate scheduler created")
    
    # Save initial model
    try:
        initial_model_path = "dlrm_book_model_initial.pth"
        torch.save(model.state_dict(), initial_model_path)
        print(f"✅ Initial model saved: {initial_model_path}")
    except Exception as e:
        print(f"⚠️ Warning: Could not save initial model: {e}")
        
else:
    print("❌ Cannot setup optimizer - model not created")
    optimizer = lr_scheduler = None


Setting up optimizer...
✅ Optimizer and learning rate scheduler created
✅ Initial model saved: dlrm_book_model_initial.pth


## 19: Run Training


In [21]:
# Start training
if all([model, optimizer, train_dataloader, val_dataloader, test_dataloader, transform_partial]):
    print("\n🚀 Starting DLRM training for book recommendation...")
    print("=" * 60)
    
    try:
        results = train_val_test(
            args,
            model,
            optimizer,
            device,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            lr_scheduler,
        )
        
        print("\n🎉 Training completed successfully!")
        print(f"📊 Final Results:")
        print(f"   Final validation AUROC: {results.val_aurocs[-1]:.4f}")
        print(f"   Test AUROC: {results.test_auroc:.4f}")
        print(f"   Best validation AUROC: {max(results.val_aurocs):.4f}")
        
        training_completed = True
        
    except Exception as e:
        print(f"❌ Error during training: {e}")
        results = None
        training_completed = False
        
else:
    print("❌ Cannot start training - missing required components")
    missing_components = []
    if model is None: missing_components.append("model")
    if optimizer is None: missing_components.append("optimizer") 
    if train_dataloader is None: missing_components.append("train_dataloader")
    if val_dataloader is None: missing_components.append("val_dataloader")
    if test_dataloader is None: missing_components.append("test_dataloader")
    if transform_partial is None: missing_components.append("transform_partial")
    
    print(f"Missing: {', '.join(missing_components)}")
    results = None
    training_completed = False


🚀 Starting DLRM training for book recommendation...
Running initial validation...


  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
Evaluating val set:  95%|█████████▍| 73/77 [00:08<00:00, 16.12it/s]

AUROC over val set: 0.6262
Number of val samples: 39389

=== Epoch 1/3 ===


Evaluating val set: 100%|██████████| 77/77 [00:09<00:00,  8.14it/s]
  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)


Completed 270 training iterations


Epoch 1: 100%|██████████| 270/270 [00:33<00:00,  8.10it/s]
  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
Evaluating val set:  97%|█████████▋| 75/77 [00:09<00:00, 15.81it/s]

AUROC over val set: 0.9990
Number of val samples: 39389
Model saved: dlrm_book_model_epoch_0.pth

=== Epoch 2/3 ===


Evaluating val set: 100%|██████████| 77/77 [00:09<00:00,  7.76it/s]
  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
Epoch 2: 100%|██████████| 270/270 [00:33<00:00,  7.98it/s]


Completed 270 training iterations


  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
Evaluating val set: 100%|██████████| 77/77 [00:08<00:00,  8.82it/s]


AUROC over val set: 0.9989
Number of val samples: 39389
Model saved: dlrm_book_model_epoch_1.pth

=== Epoch 3/3 ===


  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
Epoch 3: 100%|██████████| 270/270 [00:35<00:00,  7.67it/s]


Completed 270 training iterations


  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
Evaluating val set: 100%|██████████| 77/77 [00:08<00:00,  8.83it/s]


AUROC over val set: 0.9989
Number of val samples: 39389
Model saved: dlrm_book_model_epoch_2.pth

Running final test evaluation...


  val = torch.tensor(batch[col_name], dtype=torch.float32)
  labels = torch.tensor(batch["label"], dtype=torch.int32)
Evaluating test set: 100%|██████████| 39/39 [00:04<00:00,  8.97it/s]

AUROC over test set: 0.9989
Number of test samples: 19714
Final model saved: dlrm_book_model_final.pth

🎉 Training completed successfully!
📊 Final Results:
   Final validation AUROC: 0.9989
   Test AUROC: 0.9989
   Best validation AUROC: 0.9990





# Cell 20: Save Training Results and Cleanup


In [22]:
if training_completed and results is not None:
    print("💾 Saving training results...")
    
    try:
        # Save comprehensive results
        results_dict = {
            'final_val_auroc': results.val_aurocs[-1],
            'test_auroc': results.test_auroc,
            'best_val_auroc': max(results.val_aurocs),
            'val_aurocs_history': results.val_aurocs,
            'args': asdict(args),
            'preprocessing_info': {
                'dense_cols': dense_cols,
                'cat_cols': cat_cols,
                'emb_counts': emb_counts
            },
            'model_info': {
                'num_parameters': sum(p.numel() for p in model.parameters()),
                'embedding_dim': args.embedding_dim,
                'dense_features': len(dense_cols),
                'categorical_features': len(cat_cols)
            }
        }
        
        with open('dlrm_book_training_results.pkl', 'wb') as f:
            pickle.dump(results_dict, f)
        
        print("✅ Results saved to dlrm_book_training_results.pkl")
        
        # Display training summary
        print(f"\n📋 Training Summary:")
        print(f"   - Model parameters: {results_dict['model_info']['num_parameters']:,}")
        print(f"   - Training epochs: {args.epochs}")
        print(f"   - Batch size: {args.batch_size}")
        print(f"   - Learning rate: {args.learning_rate}")
        print(f"   - Embedding dimension: {args.embedding_dim}")
        print(f"   - Dense features: {len(dense_cols)}")
        print(f"   - Categorical features: {len(cat_cols)}")
        print(f"   - Final test AUROC: {results.test_auroc:.4f}")
        print(f"   - Best validation AUROC: {max(results.val_aurocs):.4f}")
        
    except Exception as e:
        print(f"⚠️ Warning: Could not save results: {e}")

else:
    print("❌ No results to save - training was not completed successfully")

# Cleanup distributed training
try:
    if dist.is_initialized():
        dist.destroy_process_group()
        print("✅ Distributed training cleanup completed")
except Exception as e:
    print(f"⚠️ Warning during cleanup: {e}")

print("\n" + "="*60)
print("🏁 DLRM Book Recommendation Training Notebook Complete!")
print("="*60)


💾 Saving training results...
✅ Results saved to dlrm_book_training_results.pkl

📋 Training Summary:
   - Model parameters: 720,065
   - Training epochs: 3
   - Batch size: 512
   - Learning rate: 0.01
   - Embedding dimension: 64
   - Dense features: 6
   - Categorical features: 7
   - Final test AUROC: 0.9989
   - Best validation AUROC: 0.9990
✅ Distributed training cleanup completed

🏁 DLRM Book Recommendation Training Notebook Complete!


## 21: Load and Test Saved Model (Optional)


In [23]:

# Optional cell to test loading the saved model
try:
    print("🧪 Testing saved model loading...")
    
    # Load model state
    model_state = torch.load("dlrm_book_model_final.pth", map_location=device)
    print("✅ Model state loaded successfully")
    
    # Load training results
    with open('dlrm_book_training_results.pkl', 'rb') as f:
        saved_results = pickle.load(f)
    
    print("✅ Training results loaded successfully")
    print(f"Saved test AUROC: {saved_results['test_auroc']:.4f}")
    print(f"Saved best validation AUROC: {saved_results['best_val_auroc']:.4f}")
    
    # Display model information
    print(f"Model info: {saved_results['model_info']}")
    
except Exception as e:
    print(f"❌ Error testing saved model: {e}")

🧪 Testing saved model loading...
✅ Model state loaded successfully
✅ Training results loaded successfully
Saved test AUROC: 0.9989
Saved best validation AUROC: 0.9990
Model info: {'num_parameters': 720065, 'embedding_dim': 64, 'dense_features': 6, 'categorical_features': 7}


## 22: Utility Functions for Model Inference


In [24]:
def load_trained_model(model_path: str = "dlrm_book_model_final.pth"):
    """Load a trained DLRM model for inference"""
    try:
        # Load preprocessing info
        with open('book_dlrm_preprocessing.pkl', 'rb') as f:
            preprocessing_info = pickle.load(f)
        
        dense_cols = preprocessing_info['dense_cols']
        cat_cols = preprocessing_info['cat_cols']
        emb_counts = preprocessing_info['emb_counts']
        
        # Create model architecture (same as training)
        eb_configs = [
            EmbeddingBagConfig(
                name=f"t_{feature_name}",
                embedding_dim=64,  # Should match training
                num_embeddings=emb_counts[feature_idx],
                feature_names=[feature_name],
            )
            for feature_idx, feature_name in enumerate(cat_cols)
        ]
        
        dlrm_model = DLRM(
            embedding_bag_collection=EmbeddingBagCollection(
                tables=eb_configs, device=device
            ),
            dense_in_features=len(dense_cols),
            dense_arch_layer_sizes=[256, 128, 64],
            over_arch_layer_sizes=[512, 256, 128, 1],
            dense_device=device,
        )
        
        train_model = DLRMTrain(dlrm_model)
        model = train_model.to(device)
        
        # Load saved weights
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        
        print(f"✅ Model loaded from {model_path}")
        return model, preprocessing_info
        
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return None, None

def predict_recommendation(model, preprocessing_info, user_features, book_features):
    """Make a recommendation prediction for a user-book pair"""
    # This would need to be implemented based on your specific feature format
    # Placeholder for inference logic
    pass

print("✅ Inference utility functions defined")

# Final Summary
print(f"\n📚 DLRM Book Recommendation Training Complete!")
print(f"🎯 Key accomplishments:")
print(f"   ✓ Environment setup and library imports")
print(f"   ✓ Data preprocessing info loaded")
print(f"   ✓ DLRM model architecture created")
print(f"   ✓ Training pipeline implemented")
print(f"   ✓ Model evaluation with AUROC metrics")
print(f"   ✓ Results saving and model checkpointing")
print(f"   ✓ Utility functions for future inference")

print(f"\n📁 Generated files:")
print(f"   - dlrm_book_model_final.pth (final trained model)")
print(f"   - dlrm_book_training_results.pkl (training results)")

✅ Inference utility functions defined

📚 DLRM Book Recommendation Training Complete!
🎯 Key accomplishments:
   ✓ Environment setup and library imports
   ✓ Data preprocessing info loaded
   ✓ DLRM model architecture created
   ✓ Training pipeline implemented
   ✓ Model evaluation with AUROC metrics
   ✓ Results saving and model checkpointing
   ✓ Utility functions for future inference

📁 Generated files:
   - dlrm_book_model_final.pth (final trained model)
   - dlrm_book_training_results.pkl (training results)


## 23: Hyperparameter Tuning Helper


In [25]:
def create_training_config(
    epochs: int = 5,
    embedding_dim: int = 64,
    batch_size: int = 512,
    learning_rate: float = 0.01,
    dense_layers: List[int] = None,
    over_layers: List[int] = None
) -> BookDLRMArgs:
    """Helper function to create different training configurations"""
    
    if dense_layers is None:
        dense_layers = [256, 128, 64]
    if over_layers is None:
        over_layers = [512, 256, 128, 1]
    
    return BookDLRMArgs(
        epochs=epochs,
        embedding_dim=embedding_dim,
        batch_size=batch_size,
        learning_rate=learning_rate,
        dense_arch_layer_sizes=dense_layers,
        over_arch_layer_sizes=over_layers,
        lr_warmup_steps=0
    )

# Example configurations
config_small = create_training_config(
    epochs=3,
    embedding_dim=32,
    batch_size=256,
    dense_layers=[128, 64],
    over_layers=[256, 128, 1]
)

config_medium = create_training_config(
    epochs=5,
    embedding_dim=64,
    batch_size=512,
    dense_layers=[256, 128, 64],
    over_layers=[512, 256, 128, 1]
)

config_large = create_training_config(
    epochs=10,
    embedding_dim=128,
    batch_size=1024,
    dense_layers=[512, 256, 128],
    over_layers=[1024, 512, 256, 1]
)

print("✅ Hyperparameter tuning helper functions defined")
print("Available configurations: config_small, config_medium, config_large")



✅ Hyperparameter tuning helper functions defined
Available configurations: config_small, config_medium, config_large


# Cell 24: Model Performance Analysis


In [26]:
def analyze_training_results(results_path: str = 'dlrm_book_training_results.pkl'):
    """Analyze training results and plot metrics"""
    try:
        with open(results_path, 'rb') as f:
            results = pickle.load(f)
        
        print("📊 Training Results Analysis:")
        print("=" * 50)
        
        # Basic metrics
        val_aurocs = results['val_aurocs_history']
        test_auroc = results['test_auroc']
        best_val_auroc = max(val_aurocs)
        
        print(f"Final Test AUROC: {test_auroc:.4f}")
        print(f"Best Validation AUROC: {best_val_auroc:.4f}")
        print(f"Final Validation AUROC: {val_aurocs[-1]:.4f}")
        print(f"Total Epochs: {len(val_aurocs)-1}")
        
        # Model info
        model_info = results.get('model_info', {})
        print(f"\nModel Information:")
        print(f"  Parameters: {model_info.get('num_parameters', 'N/A'):,}")
        print(f"  Embedding Dimension: {model_info.get('embedding_dim', 'N/A')}")
        print(f"  Dense Features: {model_info.get('dense_features', 'N/A')}")
        print(f"  Categorical Features: {model_info.get('categorical_features', 'N/A')}")
        
        # Training config
        train_args = results.get('args', {})
        print(f"\nTraining Configuration:")
        print(f"  Batch Size: {train_args.get('batch_size', 'N/A')}")
        print(f"  Learning Rate: {train_args.get('learning_rate', 'N/A')}")
        print(f"  Dense Architecture: {train_args.get('dense_arch_layer_sizes', 'N/A')}")
        print(f"  Over Architecture: {train_args.get('over_arch_layer_sizes', 'N/A')}")
        
        # Validation progress
        print(f"\nValidation AUROC Progress:")
        for i, auroc in enumerate(val_aurocs):
            epoch_label = "Initial" if i == 0 else f"Epoch {i}"
            print(f"  {epoch_label}: {auroc:.4f}")
        
        # Simple plotting with matplotlib if available
        try:
            import matplotlib.pyplot as plt
            
            plt.figure(figsize=(10, 6))
            epochs = list(range(len(val_aurocs)))
            plt.plot(epochs, val_aurocs, 'b-o', label='Validation AUROC')
            plt.axhline(y=test_auroc, color='r', linestyle='--', label=f'Test AUROC ({test_auroc:.4f})')
            plt.xlabel('Epoch')
            plt.ylabel('AUROC')
            plt.title('DLRM Book Recommendation Training Progress')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig('dlrm_training_progress.png', dpi=300, bbox_inches='tight')
            plt.show()
            print("\n✅ Training progress plot saved as 'dlrm_training_progress.png'")
            
        except ImportError:
            print("\n⚠️ matplotlib not available - skipping plot generation")
        
        return results
        
    except FileNotFoundError:
        print("❌ Results file not found. Please complete training first.")
        return None
    except Exception as e:
        print(f"❌ Error analyzing results: {e}")
        return None

print("✅ Performance analysis function defined")

✅ Performance analysis function defined


## 25: Model Comparison Utilities


In [27]:
def compare_model_configs():
    """Compare different model configurations"""
    configs = {
        'Small': config_small,
        'Medium': config_medium, 
        'Large': config_large
    }
    
    print("🔍 Model Configuration Comparison:")
    print("=" * 60)
    
    for name, config in configs.items():
        print(f"\n{name} Configuration:")
        print(f"  Epochs: {config.epochs}")
        print(f"  Embedding Dim: {config.embedding_dim}")
        print(f"  Batch Size: {config.batch_size}")
        print(f"  Learning Rate: {config.learning_rate}")
        print(f"  Dense Layers: {config.dense_arch_layer_sizes}")
        print(f"  Over Layers: {config.over_arch_layer_sizes}")
        
        # Estimate parameter count (approximate)
        if cat_cols and emb_counts:
            emb_params = sum(count * config.embedding_dim for count in emb_counts)
            dense_params = 0
            
            # Dense architecture
            prev_size = len(dense_cols)
            for size in config.dense_arch_layer_sizes:
                dense_params += prev_size * size + size  # weights + bias
                prev_size = size
            
            # Over architecture  
            prev_size = prev_size + len(cat_cols) * config.embedding_dim
            for size in config.over_arch_layer_sizes:
                dense_params += prev_size * size + size
                prev_size = size
            
            total_params = emb_params + dense_params
            print(f"  Est. Parameters: ~{total_params:,}")
            print(f"    - Embedding: ~{emb_params:,}")
            print(f"    - Dense: ~{dense_params:,}")

print("✅ Model comparison utilities defined")


✅ Model comparison utilities defined


## 26: Advanced Training Options


In [28]:
def train_with_early_stopping(
    args: BookDLRMArgs,
    patience: int = 3,
    min_delta: float = 0.001,
    restore_best_weights: bool = True
):
    """Enhanced training with early stopping"""
    print(f"🔄 Training with early stopping (patience={patience}, min_delta={min_delta})")
    
    # This would be integrated into the main training loop
    # For now, just show the concept
    
    class EarlyStopping:
        def __init__(self, patience=3, min_delta=0.001):
            self.patience = patience
            self.min_delta = min_delta
            self.best_score = None
            self.counter = 0
            self.early_stop = False
            
        def __call__(self, val_score):
            if self.best_score is None:
                self.best_score = val_score
            elif val_score > self.best_score + self.min_delta:
                self.best_score = val_score
                self.counter = 0
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    self.early_stop = True
            
            return self.early_stop
    
    early_stopping = EarlyStopping(patience, min_delta)
    print("✅ Early stopping mechanism ready")
    
    return early_stopping

def create_learning_rate_schedule(
    base_lr: float = 0.01,
    schedule_type: str = "cosine",
    warmup_epochs: int = 1,
    total_epochs: int = 10
):
    """Create different learning rate schedules"""
    schedules = {
        "constant": f"Constant LR: {base_lr}",
        "step": f"Step decay from {base_lr}",
        "cosine": f"Cosine annealing from {base_lr}",
        "exponential": f"Exponential decay from {base_lr}"
    }
    
    print(f"📈 Learning Rate Schedule: {schedules.get(schedule_type, 'Custom')}")
    print(f"   Base LR: {base_lr}")
    print(f"   Warmup epochs: {warmup_epochs}")
    print(f"   Total epochs: {total_epochs}")
    
    return schedule_type

print("✅ Advanced training options defined")

✅ Advanced training options defined


## 27: Data Validation and Debugging


In [29]:
def validate_training_setup():
    """Validate that all components are ready for training"""
    print("🔍 Validating Training Setup...")
    print("=" * 40)
    
    checks = []
    
    # Check preprocessing info
    if dense_cols and cat_cols and emb_counts:
        checks.append(("✅", "Preprocessing info loaded"))
        print(f"   Dense features: {len(dense_cols)}")
        print(f"   Categorical features: {len(cat_cols)}")
    else:
        checks.append(("❌", "Missing preprocessing info"))
    
    # Check data directories
    data_dirs = ["dlrm_book_data/mds_train", "dlrm_book_data/mds_validation", "dlrm_book_data/mds_test"]
    all_dirs_exist = all(os.path.exists(d) for d in data_dirs)
    if all_dirs_exist:
        checks.append(("✅", "Data directories found"))
    else:
        checks.append(("❌", "Missing data directories"))
        for d in data_dirs:
            if not os.path.exists(d):
                print(f"     Missing: {d}")
    
    # Check GPU/CPU setup
    checks.append(("✅", f"Device: {device}"))
    
    # Check distributed setup
    try:
        rank = int(os.environ.get("RANK", "0"))
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        checks.append(("✅", f"Distributed setup: rank {rank}/{world_size}"))
    except:
        checks.append(("⚠️", "Distributed setup incomplete"))
    
    # Display results
    print(f"\nValidation Results:")
    for status, message in checks:
        print(f"{status} {message}")
    
    all_good = all(check[0] == "✅" for check in checks)
    if all_good:
        print(f"\n🎯 All systems ready for training!")
    else:
        print(f"\n⚠️ Please resolve issues before training")
    
    return all_good

print("✅ Training setup validation function defined")


✅ Training setup validation function defined


## 28: Quick Test Functions


In [31]:
def quick_data_test():
    """Quick test to verify data loading works"""
    try:
        print("🧪 Quick data loading test...")
        
        # Test loading a small batch
        test_dataloader = get_dataloader_with_mosaic(
            "dlrm_book_data/mds_train", 
            batch_size=32, 
            label="test_batch"
        )
        
        if test_dataloader:
            # Get one batch
            batch_iter = iter(test_dataloader)
            sample_batch = next(batch_iter)
            
            print(f"✅ Successfully loaded test batch")
            print(f"   Batch keys: {list(sample_batch.keys())}")
            print(f"   Label shape: {sample_batch['label'].shape if hasattr(sample_batch['label'], 'shape') else 'N/A'}")
            
            # Test transformation
            if transform_partial:
                transformed_batch = transform_partial(sample_batch)
                print(f"✅ Batch transformation successful")
                print(f"   Dense features shape: {transformed_batch.dense_features.shape}")
                print(f"   Sparse features keys: {transformed_batch.sparse_features.keys()}")
                print(f"   Labels shape: {transformed_batch.labels.shape}")
            else:
                print("❌ Transform function not available")
                
        else:
            print("❌ Failed to create test dataloader")
            
    except Exception as e:
        print(f"❌ Data test failed: {e}")

def quick_model_test():
    """Quick test to verify model creation works"""
    try:
        print("🧪 Quick model creation test...")
        
        if cat_cols and emb_counts:
            # Create minimal model for testing
            test_eb_configs = [
                EmbeddingBagConfig(
                    name=f"test_{feature_name}",
                    embedding_dim=8,  # Small for testing
                    num_embeddings=emb_counts[feature_idx],
                    feature_names=[feature_name],
                )
                for feature_idx, feature_name in enumerate(cat_cols[:2])  # Only first 2 features
            ]
            
            test_dlrm = DLRM(
                embedding_bag_collection=EmbeddingBagCollection(
                    tables=test_eb_configs, device=device
                ),
                dense_in_features=len(dense_cols),
                dense_arch_layer_sizes=[32, 16],
                over_arch_layer_sizes=[64, 32, 1],
                dense_device=device,
            )
            
            test_model = DLRMTrain(test_dlrm).to(device)
            num_params = sum(p.numel() for p in test_model.parameters())
            
            print(f"✅ Test model created successfully")
            print(f"   Parameters: {num_params:,}")
            print(f"   Device: {next(test_model.parameters()).device}")
            
        else:
            print("❌ Cannot create test model - missing feature info")
            
    except Exception as e:
        print(f"❌ Model test failed: {e}")

print("✅ Quick test functions defined")

✅ Quick test functions defined


## 29: Training Progress Monitoring


In [33]:
def setup_training_monitoring():
    """Setup training monitoring and logging"""
    
    class TrainingMonitor:
        def __init__(self):
            self.start_time = None
            self.epoch_times = []
            self.best_val_auroc = 0.0
            self.training_history = {
                'val_aurocs': [],
                'epoch_times': [],
                'learning_rates': []
            }
        
        def start_training(self):
            import time
            self.start_time = time.time()
            print(f"🚀 Training started at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        
        def log_epoch(self, epoch, val_auroc, lr=None):
            import time
            current_time = time.time()
            
            if epoch == 0:
                epoch_time = 0
            else:
                epoch_time = current_time - (self.start_time + sum(self.epoch_times))
            
            self.epoch_times.append(epoch_time)
            self.training_history['val_aurocs'].append(val_auroc)
            self.training_history['epoch_times'].append(epoch_time)
            if lr:
                self.training_history['learning_rates'].append(lr)
            
            if val_auroc > self.best_val_auroc:
                self.best_val_auroc = val_auroc
                print(f"🎯 New best validation AUROC: {val_auroc:.4f}")
            
            total_time = current_time - self.start_time
            avg_epoch_time = total_time / (epoch + 1) if epoch > 0 else 0
            
            print(f"📊 Epoch {epoch} Summary:")
            print(f"   Val AUROC: {val_auroc:.4f}")
            print(f"   Epoch time: {epoch_time:.1f}s")
            print(f"   Total time: {total_time:.1f}s")
            print(f"   Avg epoch time: {avg_epoch_time:.1f}s")
            if lr:
                print(f"   Learning rate: {lr:.6f}")
        
        def finish_training(self):
            import time
            total_time = time.time() - self.start_time
            print(f"🏁 Training completed!")
            print(f"   Total time: {total_time:.1f}s ({total_time/60:.1f} minutes)")
            print(f"   Best validation AUROC: {self.best_val_auroc:.4f}")
            print(f"   Average epoch time: {np.mean(self.epoch_times):.1f}s")
    
    monitor = TrainingMonitor()
    print("✅ Training monitor setup complete")
    return monitor

print("✅ Training monitoring setup defined")


✅ Training monitoring setup defined
