# Train + Evaluate Retrieval in Colab

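This notebook picks up after tokenization (`01_moss_tokenization_colab.ipynb`): it trains the retrieval embedder on the prepared tokens and evaluates it.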

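Pipeline in this notebook:
1. Clone/update the repo and install dependencies
2. Check the token files and build train/val/test splits if they are missing
3. Train the contrastive embedder
4. Evaluate retrieval on the validation split: `Recall@1/10/100` and `MRR` (exact search, optionally FAISS)
5. Save artifacts to Google Drive (optional)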


## 0) Runtime

In Colab, select a GPU runtime before running the cells below: `Runtime -> Change runtime type -> GPU`.


In [None]:
# Report the Python/Torch/CUDA environment. Torch problems are printed, not raised,
# so this diagnostic cell never stops the notebook.
import platform

print('Python:', platform.python_version())

try:
    import torch

    print('Torch:', torch.__version__)
    has_cuda = torch.cuda.is_available()
    print('CUDA available:', has_cuda)
    if has_cuda:
        print('GPU:', torch.cuda.get_device_name(0))
except Exception as err:
    print('Torch check failed:', err)


In [None]:
# Clone the repo on first run (otherwise fast-forward it), print the checked-out commit,
# and make the repo the working directory for the cells below.
import os
import subprocess
from pathlib import Path

REPO_ROOT = Path('/content/CL_ml')
REPO_URL = 'https://github.com/epitaph76/CL_ml.git'

if not REPO_ROOT.exists():
    # check=True: a failed clone should stop here, not surface as a confusing error later.
    subprocess.run(['git', 'clone', REPO_URL, str(REPO_ROOT)], check=True)
else:
    print('Repo already exists at', REPO_ROOT)
    # Best-effort update: --ff-only refuses non-fast-forward pulls; keep the existing checkout then.
    pull = subprocess.run(['git', '-C', str(REPO_ROOT), 'pull', '--ff-only'], check=False)
    if pull.returncode != 0:
        print('WARNING: git pull --ff-only failed; continuing with the existing checkout.')

head = subprocess.check_output(
    ['git', '-C', str(REPO_ROOT), 'rev-parse', '--short', 'HEAD'], text=True
).strip()
print('Repo HEAD:', head)

os.chdir(REPO_ROOT)
print('Working directory:', Path.cwd())


In [None]:
# Install project dependencies
!pip -q install -r /content/CL_ml/requirements.txt


## 1) Data paths and quick checks


In [None]:
# Data layout: every working directory lives under the repo's data/ folder and is
# created up front so later cells can write into it.
from pathlib import Path

REPO_ROOT = Path('/content/CL_ml')
DATA_ROOT = REPO_ROOT / 'data'
TOKENS_ROOT = DATA_ROOT / 'tokens'
SPLITS_ROOT = DATA_ROOT / 'splits'
CHECKPOINTS_ROOT = DATA_ROOT / 'checkpoints'
REPORTS_ROOT = DATA_ROOT / 'reports'

_data_dirs = {
    'TOKENS_ROOT': TOKENS_ROOT,
    'SPLITS_ROOT': SPLITS_ROOT,
    'CHECKPOINTS_ROOT': CHECKPOINTS_ROOT,
    'REPORTS_ROOT': REPORTS_ROOT,
}

for _dir in _data_dirs.values():
    _dir.mkdir(parents=True, exist_ok=True)

for _label, _dir in _data_dirs.items():
    print(_label, '=', _dir)


In [None]:
# Optional: read prepared tokens from Google Drive instead of the repo's data/tokens.
# Set USE_DRIVE_TOKENS = True to enable. Run this after the paths cell and before
# the token-validation cell so that cell picks up the new TOKENS_ROOT.
# NOTE: the split-building cell only checks that train/val/test.txt exist, so delete
# data/splits if it was built from a different token root.
from pathlib import Path

USE_DRIVE_TOKENS = False

if USE_DRIVE_TOKENS:
    from google.colab import drive

    drive.mount('/content/drive')
    TOKENS_ROOT = Path('/content/drive/MyDrive/CL_ml_tokens')
    print('TOKENS_ROOT switched to', TOKENS_ROOT)


In [None]:
# Validate tokens: list the .pt/.npz token files under TOKENS_ROOT (first 10 shown)
# and fail fast if fewer than 3 are present.
pt_files = sorted(TOKENS_ROOT.rglob('*.pt'))
npz_files = sorted(TOKENS_ROOT.rglob('*.npz'))
token_files = [*pt_files, *npz_files]

print('Token files found:', len(token_files))
for token_path in token_files[:10]:
    print('-', token_path)

MIN_TOKEN_FILES = 3
if len(token_files) < MIN_TOKEN_FILES:
    raise RuntimeError(
        'Need at least 3 token files. Run notebook 01_moss_tokenization_colab.ipynb or copy tokens to TOKENS_ROOT.'
    )


In [None]:
# Build train/val/test splits (10% val, 10% test) unless all three split files already exist.
import os
import shlex
import subprocess
import sys

SPLIT_FILES = ['train.txt', 'val.txt', 'test.txt']
need_splits = not all((SPLITS_ROOT / name).exists() for name in SPLIT_FILES)

if need_splits:
    cmd = [
        # sys.executable: run the module with the kernel's interpreter, where dependencies were installed.
        sys.executable, '-m', 'src.dataset.build_splits',
        '--tokens-root', str(TOKENS_ROOT),
        '--output-root', str(SPLITS_ROOT),
        '--val-ratio', '0.1',
        '--test-ratio', '0.1',
    ]
    print('Running:', shlex.join(cmd))
    env = os.environ.copy()
    # Prepend the repo so `-m src....` resolves; keep any existing PYTHONPATH entries.
    env['PYTHONPATH'] = os.pathsep.join(
        entry for entry in (str(REPO_ROOT), env.get('PYTHONPATH')) if entry
    )
    subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, check=True)
else:
    print('Splits already exist, skip build.')


In [None]:
# Inspect split summary: show the first 500 characters of each split file and summary.json.
for name in ['train.txt', 'val.txt', 'test.txt', 'summary.json']:
    split_path = SPLITS_ROOT / name
    # Leading '\n' separates sections; it must stay an escape, not a literal line break.
    print('\n===', name, '===')
    if split_path.exists():
        print(split_path.read_text(encoding='utf-8')[:500])
    else:
        print('not found')


## 2) Train baseline embedder


In [None]:
# Full training run: train the contrastive embedder with configs/train.yaml and
# write checkpoints to CHECKPOINTS_ROOT.
import os
import shlex
import subprocess
import sys

cmd = [
    # sys.executable: run with the kernel's interpreter, where dependencies were installed.
    sys.executable, '-m', 'src.train.train_contrastive',
    '--config', 'configs/train.yaml',  # relative to cwd=REPO_ROOT below
    '--device', 'auto',
    '--output-dir', str(CHECKPOINTS_ROOT),
]
print('Running:', shlex.join(cmd))
env = os.environ.copy()
# Prepend the repo so `-m src....` resolves; keep any existing PYTHONPATH entries.
env['PYTHONPATH'] = os.pathsep.join(
    entry for entry in (str(REPO_ROOT), env.get('PYTHONPATH')) if entry
)
subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, check=True)


In [None]:
# Smoke training run: 5 steps per epoch to exercise the training pipeline quickly.
# Set RUN_SMOKE_TRAINING = True to run it; it writes to data/checkpoints_smoke so the
# full-run checkpoints are untouched.
import os
import shlex
import subprocess
import sys

RUN_SMOKE_TRAINING = False

if RUN_SMOKE_TRAINING:
    cmd = [
        sys.executable, '-m', 'src.train.train_contrastive',
        '--config', 'configs/train.yaml',
        '--device', 'auto',
        '--max-steps-per-epoch', '5',
        '--output-dir', str(REPO_ROOT / 'data' / 'checkpoints_smoke'),
    ]
    print('Running:', shlex.join(cmd))
    env = os.environ.copy()
    env['PYTHONPATH'] = os.pathsep.join(
        entry for entry in (str(REPO_ROOT), env.get('PYTHONPATH')) if entry
    )
    subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, check=True)
else:
    print('Smoke run skipped (set RUN_SMOKE_TRAINING = True to enable).')


In [None]:
# Report which training artifacts are present in CHECKPOINTS_ROOT.
expected_artifacts = ('best.pt', 'last.pt', 'history.json')
for artifact in expected_artifacts:
    artifact_path = CHECKPOINTS_ROOT / artifact
    status = 'exists' if artifact_path.exists() else 'missing'
    print(artifact, '->', status, artifact_path)


## 3) Evaluate retrieval metrics


In [None]:
# Exact retrieval metrics on the val split (Recall@1/10/100, MRR) using best.pt;
# the report is written to REPORTS_ROOT/val_exact.json.
import os
import shlex
import subprocess
import sys

checkpoint = CHECKPOINTS_ROOT / 'best.pt'
if not checkpoint.exists():
    raise FileNotFoundError(f'Checkpoint not found: {checkpoint}. Run the training cell first.')

cmd = [
    sys.executable, '-m', 'src.index.evaluate_retrieval',
    '--config', 'configs/train.yaml',
    '--checkpoint', str(checkpoint),
    '--tokens-root', str(TOKENS_ROOT),
    '--splits-root', str(SPLITS_ROOT),
    '--split', 'val',
    '--topk', '1,10,100',
    '--device', 'auto',
    '--output-json', str(REPORTS_ROOT / 'val_exact.json'),
]
print('Running:', shlex.join(cmd))
env = os.environ.copy()
# Prepend the repo so `-m src....` resolves; keep any existing PYTHONPATH entries.
env['PYTHONPATH'] = os.pathsep.join(
    entry for entry in (str(REPO_ROOT), env.get('PYTHONPATH')) if entry
)
subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, check=True)


In [None]:
# Exact + FAISS metrics (optional) on the val split; report goes to
# REPORTS_ROOT/val_exact_faiss.json. A failure here is reported but does not stop
# Run All, because the exact metrics from the previous cell are already available.
import os
import shlex
import subprocess
import sys

cmd = [
    sys.executable, '-m', 'src.index.evaluate_retrieval',
    '--config', 'configs/train.yaml',
    '--checkpoint', str(CHECKPOINTS_ROOT / 'best.pt'),
    '--tokens-root', str(TOKENS_ROOT),
    '--splits-root', str(SPLITS_ROOT),
    '--split', 'val',
    '--topk', '1,10,100',
    '--device', 'auto',
    '--use-faiss',
    '--output-json', str(REPORTS_ROOT / 'val_exact_faiss.json'),
]
print('Running:', shlex.join(cmd))
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(
    entry for entry in (str(REPO_ROOT), env.get('PYTHONPATH')) if entry
)
faiss_run = subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, check=False)
if faiss_run.returncode != 0:
    print(f'WARNING: FAISS evaluation failed (exit code {faiss_run.returncode}); skipping this optional step.')


In [None]:
# Print the 'results' section of each metrics report written by the evaluation cells.
import json

for name in ['val_exact.json', 'val_exact_faiss.json']:
    report_path = REPORTS_ROOT / name
    # Leading '\n' separates sections; it must stay an escape, not a literal line break.
    print('\n===', name, '===')
    if report_path.exists():
        payload = json.loads(report_path.read_text(encoding='utf-8'))
        print(json.dumps(payload.get('results', {}), indent=2, ensure_ascii=False))
    else:
        print('not found')


## 4) Save artifacts to Google Drive (optional)


In [None]:
# Optional: persist run artifacts (checkpoints, reports, configs) to Google Drive.
# Set SAVE_TO_DRIVE = True to enable; RUN_NAME selects the destination folder.
import shutil
from pathlib import Path

SAVE_TO_DRIVE = False
RUN_NAME = 'run1'

if SAVE_TO_DRIVE:
    from google.colab import drive

    drive.mount('/content/drive')
    run_dir = Path('/content/drive/MyDrive/CL_ml_runs') / RUN_NAME
    run_dir.mkdir(parents=True, exist_ok=True)
    artifact_dirs = [
        Path('/content/CL_ml/data/checkpoints'),
        Path('/content/CL_ml/data/reports'),
        Path('/content/CL_ml/configs'),
    ]
    for src_dir in artifact_dirs:
        dest_dir = run_dir / src_dir.name
        # dirs_exist_ok=True: re-running updates files in place; `cp -r` into an
        # existing directory would nest a second copy (e.g. checkpoints/checkpoints).
        shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
        print('Copied', src_dir, '->', dest_dir)
else:
    print('Drive export skipped (set SAVE_TO_DRIVE = True to enable).')
