# Train two-tower SBERT on Instacart data

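This notebook mirrors `src/train` so you can run and inspect each step interactively.
It assumes you have already run the data preparation step (`src.data.prepare_instacart_sbert`) and that
`processed/` contains `train_dataset`, optionally `eval_dataset`, and the `eval_queries.json`,
`eval_corpus.json`, and `eval_relevant_docs.json` files loaded below.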

## 1. Setup: imports and paths

In [1]:
from pathlib import Path
import json

from datasets import Dataset, load_from_disk
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

# Project root: the parent of notebooks/ if Jupyter was started there, otherwise the current directory
PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
PROCESSED_DIR = PROJECT_ROOT / "processed"
OUTPUT_DIR = PROJECT_ROOT / "models" / "two_tower_sbert_notebook"

print("PROJECT_ROOT:", PROJECT_ROOT)
print("PROCESSED_DIR:", PROCESSED_DIR)
print("OUTPUT_DIR:", OUTPUT_DIR)

PROJECT_ROOT: /Users/chen_bowen/AI & ML/Projects/Instacart_Personalization
PROCESSED_DIR: /Users/chen_bowen/AI & ML/Projects/Instacart_Personalization/processed
OUTPUT_DIR: /Users/chen_bowen/AI & ML/Projects/Instacart_Personalization/models/two_tower_sbert_notebook


## 2. Load processed datasets and information retrieval artifacts

In [2]:
train_dataset = load_from_disk(str(PROCESSED_DIR / "train_dataset"))
eval_dataset = None
if (PROCESSED_DIR / "eval_dataset").exists():
    eval_dataset = load_from_disk(str(PROCESSED_DIR / "eval_dataset"))

with open(PROCESSED_DIR / "eval_queries.json", "r") as f:
    eval_queries = json.load(f)
with open(PROCESSED_DIR / "eval_corpus.json", "r") as f:
    eval_corpus = json.load(f)
with open(PROCESSED_DIR / "eval_relevant_docs.json", "r") as f:
    _raw = json.load(f)
    eval_relevant_docs = {k: set(v) for k, v in _raw.items()}

print("train_dataset:", train_dataset)
print("eval_dataset:", eval_dataset)
print("#queries:", len(eval_queries), "#corpus docs:", len(eval_corpus), "#qrels:", len(eval_relevant_docs))

train_dataset: Dataset({
    features: ['anchor', 'positive'],
    num_rows: 1246220
})
eval_dataset: Dataset({
    features: ['anchor', 'positive'],
    num_rows: 138397
})
#queries: 13120 #corpus docs: 49688 #qrels: 13120


Inspect a sample training pair and one information retrieval example.

In [3]:
sample = train_dataset[0]
print("Anchor sample (user context):\n", sample["anchor"][:400], "...\n")
print("Positive sample (product):\n", sample["positive"])
qid = list(eval_queries.keys())[0]
print("\nSample eval query id:", qid)
print("Query text:\n", eval_queries[qid][:400], "...\n")
print("Relevant product ids (first 10):", list(eval_relevant_docs[qid])[:10])

Anchor sample (user context):
 Previously ordered: Tuna Ventresca, in Olive Oil (x1), Bulgarian Yogurt (x2), Organic 4% Milk Fat Whole Milk Cottage Cheese (x2), Organic Small Bunch Celery (x2), Organic Whole String Cheese (x2), Banana (x1), Plus Cranberry Almond + Antioxidants with Macadamia Nuts Bar (x1), Pure Sparkling Water (x3), Dark Chocolate Cinnamon Pecan Bar (x2), Sparkling Water Grapefruit (x1), Naturally Smoked Oyster ...

Positive sample (product):
 Product: Bulgarian Yogurt. Aisle: yogurt. Department: dairy eggs.

Sample eval query id: 3178496
Query text:
 Previously ordered: Organic Romaine Lettuce (x1), Organic Red Radish, Bunch (x1), Organic Rainbow Chard Vegetable (x1), Organic Dandelion Greens (x1), Chinese Eggplant (x1), Organic Zucchini (x1), Organic Grape Tomatoes (x1), Veggie Ground (x1), Mini Crispy Crabless Cakes (x1), Lemongrass Basil Simmer Sauce (x1), Orange Mango Chicken (x1), Golden Fishless Filet (x4), Chicken Drumsticks (x1), Organ ...

Relevant product id
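
Before training, it can be worth confirming that the three evaluation files agree with each other: every query should have a relevance entry, and every relevant product id should exist in the corpus. JSON object keys are always strings, so ids must match as strings. The helper below is a minimal sketch, not part of `src/train`; `check_ir_artifacts` and the toy data are hypothetical names and values used only for illustration.

```python
def check_ir_artifacts(queries, corpus, relevant_docs):
    """Return a list of human-readable problems; an empty list means consistent."""
    problems = []
    for qid, doc_ids in relevant_docs.items():
        if qid not in queries:
            problems.append(f"qrels query {qid!r} has no query text")
        if not doc_ids:
            problems.append(f"query {qid!r} has no relevant documents")
        missing = sorted(d for d in doc_ids if d not in corpus)
        if missing:
            problems.append(f"query {qid!r} references unknown documents {missing}")
    for qid in queries:
        if qid not in relevant_docs:
            problems.append(f"query {qid!r} has no qrels entry")
    return problems


# Toy data in the same shape as the JSON files loaded above (values are made up).
toy_queries = {"q1": "Previously ordered: Banana (x1)"}
toy_corpus = {"p1": "Product: Banana. Aisle: fresh fruits. Department: produce."}
toy_relevant = {"q1": {"p1", "p9"}}  # "p9" is deliberately missing from the corpus

for problem in check_ir_artifacts(toy_queries, toy_corpus, toy_relevant):
    print(problem)
```

On the real artifacts the call would be `check_ir_artifacts(eval_queries, eval_corpus, eval_relevant_docs)`, and an empty result suggests the evaluator inputs are consistent.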

## 3. Build model and loss

In [4]:
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # change if you like
MAX_SEQ_LENGTH = 256

model = SentenceTransformer(MODEL_NAME)
model.max_seq_length = MAX_SEQ_LENGTH

loss = MultipleNegativesRankingLoss(model)
model, loss

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED: can be ignored when loading from different task/architecture; not ok if you expect identical arch.


(SentenceTransformer(
   (0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
   (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
   (2): Normalize()
 ),
 MultipleNegativesRankingLoss(
   (model): SentenceTransformer(
     (0): Transformer({'max_seq_length': 256, 'do_lower_case': False, 'architecture': 'BertModel'})
     (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
     (2): Normalize()
   )
   (cross_entropy_loss): CrossEntropyLoss()
 ))
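
To make the loss less of a black box, here is a rough pure-Python sketch of the idea behind `MultipleNegativesRankingLoss`: for each anchor, the paired positive is treated as the correct class and every other positive in the batch acts as a negative, scored with cross-entropy over scaled cosine similarities. The `scale=20.0` default and the cosine similarity are assumptions based on the library's documented defaults; the vectors are made-up 2-d stand-ins for real embeddings, and `mnr_loss` is a hypothetical helper, not the library implementation.

```python
import math


def normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def mnr_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where positives[i] is the correct 'class' for anchors[i]."""
    anchors = [normalize(a) for a in anchors]
    positives = [normalize(p) for p in positives]
    total = 0.0
    for i, a in enumerate(anchors):
        # Scaled cosine similarity of anchor i to every positive in the batch.
        logits = [scale * sum(x * y for x, y in zip(a, p)) for p in positives]
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_z - logits[i]
    return total / len(anchors)


# Made-up 2-d "embeddings" for a batch of three (user context, product) pairs.
users = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
matched = [[0.9, 0.1], [0.1, 0.9], [-0.9, 0.1]]
mismatched = [matched[1], matched[2], matched[0]]

print("aligned pairs:   ", round(mnr_loss(users, matched), 4))
print("misaligned pairs:", round(mnr_loss(users, mismatched), 4))
```

Because other rows in the batch serve as negatives, a product that appears twice in one batch would be penalized as a false negative, which is the motivation for the `BatchSamplers.NO_DUPLICATES` sampler used in the training arguments.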

## 4. Build InformationRetrievalEvaluator

In [5]:
information_retrieval_evaluator = InformationRetrievalEvaluator(
    queries=eval_queries,
    corpus=eval_corpus,
    relevant_docs=eval_relevant_docs,
    name="instacart-two-tower-notebook",
)
information_retrieval_evaluator

<sentence_transformers.evaluation.InformationRetrievalEvaluator.InformationRetrievalEvaluator at 0x1284424e0>
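
The evaluator reports ranking metrics such as Recall@k and MRR. As a rough illustration of what those numbers mean for a single query, the sketch below computes both from a ranked list of product ids. The function names and the toy ranking are hypothetical; the evaluator's own implementation also handles encoding, batching, and averaging over all queries.

```python
def recall_at_k(ranked_ids, relevant_ids, k):
    """Fraction of the relevant documents that appear in the top k results."""
    hits = sum(1 for doc_id in ranked_ids[:k] if doc_id in relevant_ids)
    return hits / len(relevant_ids)


def reciprocal_rank(ranked_ids, relevant_ids, k):
    """1 / rank of the first relevant document in the top k, or 0.0 if none."""
    for rank, doc_id in enumerate(ranked_ids[:k], start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0


# Toy example: the model ranked four products; two of them are relevant.
ranked = ["p7", "p2", "p5", "p1"]
relevant = {"p2", "p1"}

print("Recall@3:", recall_at_k(ranked, relevant, k=3))     # 0.5
print("RR@3:    ", reciprocal_rank(ranked, relevant, k=3))  # 0.5
```

MRR is the mean of the reciprocal rank over all queries, so with 13,120 queries each query contributes equally regardless of how many products it has marked relevant.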

## 5. Define training arguments

In [6]:
# Performance notes:
# - dataloader_num_workers=4: load batches in parallel worker processes (0 would load in the main process)
# - dataloader_pin_memory=True: faster host-to-GPU transfer (only helps when CUDA is available)
# - eval_steps: the information retrieval evaluator is slow; raise eval_steps to run it less often,
#   or pass evaluator=None when creating the trainer to skip it
# - Consider increasing the batch size if GPU memory allows
training_args = SentenceTransformerTrainingArguments(
    output_dir=str(OUTPUT_DIR),
    num_train_epochs=1,
    per_device_train_batch_size=32,  # Increase if GPU memory allows (64, 128, etc.)
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="steps",
    eval_steps=1000,  # information retrieval evaluation is expensive, so run it sparingly
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    # Data loading and batching:
    dataloader_num_workers=4,  # load batches in parallel worker processes
    dataloader_pin_memory=True,  # faster host-to-GPU transfer (only helps with CUDA)
    gradient_accumulation_steps=1,  # increase to 2-4 for a larger effective batch size
)
training_args

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
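
To see what these settings imply for this dataset, the arithmetic below estimates the number of optimizer steps, warmup steps, evaluations, and checkpoints for one epoch. It is a back-of-the-envelope sketch that assumes a single device and ignores any batches the `NO_DUPLICATES` sampler may drop or reshape, so the trainer's own count may differ slightly.

```python
import math

num_rows = 1_246_220   # train_dataset.num_rows from the load step
batch_size = 32        # per_device_train_batch_size
grad_accum = 1         # gradient_accumulation_steps
epochs = 1             # num_train_epochs
warmup_ratio = 0.1
eval_steps = 1000
save_steps = 500

# One optimizer step consumes batch_size * grad_accum rows (single device assumed).
steps_per_epoch = math.ceil(num_rows / (batch_size * grad_accum))
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)
num_evals = total_steps // eval_steps
num_checkpoints = total_steps // save_steps

print("total optimizer steps:", total_steps)
print("warmup steps:", warmup_steps)
print("evaluator runs:", num_evals)
print("checkpoints written:", num_checkpoints)
```

The computed `warmup_steps` value is one option for replacing `warmup_ratio`, as the deprecation warning suggests, though it has to be recomputed whenever the dataset size, batch size, or epoch count changes.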




## 6. Create trainer

In [7]:
# Note: the information retrieval evaluator can be slow. For faster training, set evaluator=None below.
# You will still get validation loss from eval_dataset, but no Recall@k/MRR metrics during training.
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=information_retrieval_evaluator,  # Set to None for faster training (only validation loss)
)
trainer


<sentence_transformers.trainer.SentenceTransformerTrainer at 0x10ad562d0>

## 7. Train (run a short experiment)

Run the cell below to start training. For quick experiments, keep the number of epochs small
and optionally downsample `train_dataset` before creating the trainer. The information retrieval
evaluator runs every `eval_steps` steps during training.
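
One possible way to downsample is to shuffle with a fixed seed and keep the first N rows. The helper below is a hypothetical sketch, not part of `src/train`; it relies only on `shuffle(seed=...)`, `select(indices)`, and `len()`, which Hugging Face `datasets.Dataset` objects provide.

```python
def downsample(dataset, max_rows, seed=42):
    """Return a random subset of at most max_rows rows.

    Works with any object exposing shuffle(seed=...), select(indices) and
    __len__, such as a Hugging Face `datasets.Dataset`.
    """
    n = min(max_rows, len(dataset))
    # Shuffle first so the subset is a random sample rather than the first rows on disk.
    return dataset.shuffle(seed=seed).select(range(n))
```

For example, `train_dataset = downsample(train_dataset, 50_000)` could be run before the trainer is created in section 6; the 50,000 figure is an arbitrary illustration, not a recommendation.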

In [None]:
# Start training; the information retrieval evaluator runs periodically along the way.
trainer.train()



## 8. Save final model (optional)

In [None]:
final_dir = OUTPUT_DIR / "final_notebook"
final_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(final_dir))
print("Saved model to", final_dir)