# MOSS Tokenization + Retrieval Training in Colab

This notebook runs the project's full baseline pipeline in one place.

Pipeline in this notebook:
1. Clone project repo
2. Install dependencies
3. Prepare audio corpus (upload / Google Drive / HuggingFace dataset / custom mix)
4. Run tokenization (`audio -> tokens`)
5. Validate output token files
6. Build train/val/test split files
7. Train baseline contrastive embedder
8. Evaluate retrieval metrics: `Recall@1/10/100` and `MRR`


## 0) Runtime

In Colab: `Runtime -> Change runtime type -> GPU` (optional but recommended).

In [None]:
# Report the Python / PyTorch / CUDA environment of this runtime.
# Best-effort only: any torch problem is printed, never raised.
import platform

print('Python:', platform.python_version())

try:
    import torch

    print('Torch:', torch.__version__)
    has_cuda = torch.cuda.is_available()
    print('CUDA available:', has_cuda)
    if has_cuda:
        print('GPU:', torch.cuda.get_device_name(0))
except Exception as exc:
    print('Torch check failed:', exc)

In [None]:
# Clone/update repo and switch into it
from pathlib import Path
import subprocess

REPO_ROOT = Path('/content/CL_ml')
if not REPO_ROOT.exists():
    !git clone https://github.com/epitaph76/CL_ml.git /content/CL_ml
else:
    print('Repo already exists at', REPO_ROOT)

subprocess.run(['git', '-C', str(REPO_ROOT), 'pull', '--ff-only'], check=False)
head = subprocess.check_output(['git', '-C', str(REPO_ROOT), 'rev-parse', '--short', 'HEAD'], text=True).strip()
print('Repo HEAD:', head)
%cd /content/CL_ml


In [None]:
# Install project dependencies
!pip -q install -r /content/CL_ml/requirements.txt


In [None]:
# Diagnostic: confirm the cloned repository has the layout the later cells rely on.
from pathlib import Path

repo = Path('/content/CL_ml')
required_paths = [
    ('Repo exists:', repo),
    ('src exists:', repo / 'src'),
    ('moss_tokenize exists:', repo / 'src' / 'tokenizer' / 'moss_tokenize.py'),
]
for label, path in required_paths:
    print(label, path.exists())

if repo.exists():
    top_level = sorted(entry.name for entry in repo.iterdir())
    print('Top-level files:', top_level[:20])


## 1) Configure input/output folders

Audio source options (local uploads first):
- Upload a few files directly from your computer
- Upload a zip archive with many tracks (recommended for ~1000 files)
- Download Navrasa-5000 from HuggingFace
- Combine several of these sources in one corpus (everything under `INPUT_ROOT` is tokenized)


In [None]:
# Working folders: raw audio lives outside the repo; tokens and splits live under data/.
from pathlib import Path

REPO_ROOT = Path('/content/CL_ml')
if not REPO_ROOT.exists():
    raise RuntimeError('Repo root not found. Run clone cell first.')

INPUT_ROOT = Path('/content/audio_input')
OUTPUT_ROOT = REPO_ROOT / 'data' / 'tokens'
SPLITS_ROOT = REPO_ROOT / 'data' / 'splits'

working_dirs = {
    'INPUT_ROOT': INPUT_ROOT,
    'OUTPUT_ROOT': OUTPUT_ROOT,
    'SPLITS_ROOT': SPLITS_ROOT,
}
# exist_ok=True keeps this cell safe to re-run.
for folder in working_dirs.values():
    folder.mkdir(parents=True, exist_ok=True)

print('REPO_ROOT =', REPO_ROOT)
for label, folder in working_dirs.items():
    print(label, '=', folder)


In [None]:
# Optional: upload a few local audio files directly from your computer into INPUT_ROOT.
# For large batches (~1000 files) use the zip upload cell below.
#
# files.upload() returns {filename: bytes} and also saves a copy into the *current
# working directory*. That is the repo root after the `%cd` in the clone cell, not
# /content, so the returned bytes are written to INPUT_ROOT directly.
from google.colab import files
from pathlib import Path

uploaded = files.upload()
for name, content in uploaded.items():
    dst = INPUT_ROOT / Path(name).name
    dst.write_bytes(content)
    # Drop the stray cwd copy, but only when it is clearly the upload itself
    # (same name and size, and not the file just written).
    stray = Path.cwd() / name
    if stray.is_file() and stray.resolve() != dst.resolve() and stray.stat().st_size == len(content):
        stray.unlink()

print('Uploaded files:', len(uploaded))
print('Files directly in INPUT_ROOT:', sum(1 for p in INPUT_ROOT.iterdir() if p.is_file()))


In [None]:
# Optional: upload one or more zip archives from your computer and extract to INPUT_ROOT/custom.
# Each archive is extracted into its own sub-folder named after the archive.
#
# files.upload() returns {filename: bytes} and saves a copy into the current working
# directory (the repo root after `%cd`, not /content). Extracting from the returned
# bytes does not depend on the cwd.
from google.colab import files
from pathlib import Path
import io
import zipfile

uploaded = files.upload()
zip_names = [name for name in uploaded if name.lower().endswith('.zip')]

if not zip_names:
    print('No .zip files uploaded. Skip.')
else:
    target_root = INPUT_ROOT / 'custom'
    target_root.mkdir(parents=True, exist_ok=True)
    for name in zip_names:
        content = uploaded[name]
        out_dir = target_root / Path(name).stem
        out_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(io.BytesIO(content), 'r') as zf:
            zf.extractall(out_dir)
        print('Extracted:', name, '->', out_dir)
        # Remove the (possibly multi-GB) copy Colab left in the cwd once it is extracted,
        # but only when it is clearly the upload itself (same name and size).
        stray = Path.cwd() / name
        if stray.is_file() and stray.stat().st_size == len(content):
            stray.unlink()


### Optional: download Navrasa-5000 from HuggingFace

Set `DOWNLOAD_NAVRASA = True` in the next cell to download the dataset archives and extract them into `INPUT_ROOT/navrasa`.


In [None]:
# Optional: fetch Navrasa-5000 archives from the HuggingFace Hub and unpack each one
# into its own folder under INPUT_ROOT/navrasa.
from pathlib import Path
import zipfile
from huggingface_hub import snapshot_download

DOWNLOAD_NAVRASA = False
NAVRASA_SUBSETS = ['Global', 'India']   # choose: ['Global'] or ['India'] or both
NAVRASA_ZIP_LIMIT = None                # for quick test set a number, e.g. 3
HF_DATASET_ID = 'beastLucifer/navrasa-5000-dataset'
HF_LOCAL_DIR = Path('/content/navrasa_hf')

if not DOWNLOAD_NAVRASA:
    print('Skip Navrasa download. Set DOWNLOAD_NAVRASA=True to enable.')
else:
    # Only the zip archives of the chosen subsets plus the manifest are downloaded.
    patterns = [f'{subset}/*.zip' for subset in NAVRASA_SUBSETS]
    patterns.append('master_manifest.json')
    print('Downloading dataset snapshot...')
    snapshot_download(
        repo_id=HF_DATASET_ID,
        repo_type='dataset',
        local_dir=str(HF_LOCAL_DIR),
        allow_patterns=patterns,
        local_dir_use_symlinks=False,
    )

    # Subsets in the configured order, archives sorted by name within each subset.
    archives = [
        archive
        for subset in NAVRASA_SUBSETS
        for archive in sorted((HF_LOCAL_DIR / subset).glob('*.zip'))
    ]
    if NAVRASA_ZIP_LIMIT is not None:
        archives = archives[: int(NAVRASA_ZIP_LIMIT)]
    total = len(archives)

    target_root = INPUT_ROOT / 'navrasa'
    target_root.mkdir(parents=True, exist_ok=True)

    print('Extracting zip archives:', total)
    for idx, archive in enumerate(archives, start=1):
        dest = target_root / archive.stem
        dest.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(archive, 'r') as zf:
            zf.extractall(dest)
        # Progress line every 10 archives and after the last one.
        if idx == total or idx % 10 == 0:
            print(f'  extracted {idx}/{total}: {archive.name}')

    print('Navrasa extracted to:', target_root)


In [None]:
# Inspect the final audio corpus (every supported file under INPUT_ROOT) before tokenization.
from collections import Counter

exts = {'.mp3', '.wav', '.flac', '.ogg', '.m4a'}
audio_files = []
for candidate in INPUT_ROOT.rglob('*'):
    if candidate.is_file() and candidate.suffix.lower() in exts:
        audio_files.append(candidate)
counts = Counter(audio.suffix.lower() for audio in audio_files)

print('INPUT_ROOT:', INPUT_ROOT)
print('Total audio files:', len(audio_files))
print('By extension:', dict(sorted(counts.items())))
for audio in audio_files[:10]:
    print('-', audio)


## 2) Run MOSS tokenization

In [None]:
# Full tokenization run (audio -> tokens) with live logging.
# The tokenizer's stdout is streamed into the cell while it runs, so a long run over
# many files shows progress instead of looking hung. stderr is collected and shown
# if the run fails.
import os
import shlex
import subprocess
import sys
import tempfile

# Optional chunk length in seconds, forwarded as `--chunk-seconds`. The cross_chunk
# evaluation in section 4 needs at least 2 chunks per track. None = do not pass the flag.
CHUNK_SECONDS = None

script_path = REPO_ROOT / 'src' / 'tokenizer' / 'moss_tokenize.py'
if not script_path.exists():
    raise RuntimeError(f'Script not found: {script_path}. Re-run clone cell.')

exts = {'.mp3', '.wav', '.flac', '.ogg', '.m4a'}
if INPUT_ROOT.is_file():
    found = [INPUT_ROOT]
else:
    found = sorted(p for p in INPUT_ROOT.rglob('*') if p.is_file() and p.suffix.lower() in exts)

print('Found audio files:', len(found))
for p in found[:10]:
    print('-', p)
if not found:
    raise RuntimeError(f'No audio files found under: {INPUT_ROOT}')

cmd = [
    sys.executable, str(script_path),  # same interpreter as this kernel
    '--input-root', str(INPUT_ROOT),
    '--output-root', str(OUTPUT_ROOT),
    '--device', 'auto',
]
if CHUNK_SECONDS is not None:
    cmd += ['--chunk-seconds', str(CHUNK_SECONDS)]
print('Running:', ' '.join(shlex.quote(x) for x in cmd))

env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(filter(None, [str(REPO_ROOT), env.get('PYTHONPATH')]))
env['PYTHONUNBUFFERED'] = '1'  # otherwise the child block-buffers stdout into the pipe

print('\n--- STDOUT ---')
stdout_lines = []
with tempfile.TemporaryFile(mode='w+') as stderr_file:
    with subprocess.Popen(cmd, cwd=str(REPO_ROOT), env=env, text=True,
                          stdout=subprocess.PIPE, stderr=stderr_file) as proc:
        try:
            for line in proc.stdout:
                print(line, end='')
                stdout_lines.append(line)
        except BaseException:
            # e.g. the cell was interrupted: do not leave the tokenizer running in the background.
            proc.kill()
            raise
    stderr_file.seek(0)
    stderr_text = stderr_file.read()

# Same shape as subprocess.run's result, for anyone inspecting `result` afterwards.
result = subprocess.CompletedProcess(cmd, proc.returncode, ''.join(stdout_lines), stderr_text)
if result.returncode != 0:
    print('--- STDERR ---')
    print(result.stderr)
    raise RuntimeError(f'moss_tokenize failed with code {result.returncode}')


In [None]:
# Optional smoke run of the tokenizer on a small subset (`--max-files 10`).
# Set RUN_SMOKE_TOKENIZE = True to run it; otherwise this cell does nothing.
# Note: it writes into the same OUTPUT_ROOT as the full run.
import os
import shlex
import subprocess
import sys

RUN_SMOKE_TOKENIZE = False

if RUN_SMOKE_TOKENIZE:
    script_path = REPO_ROOT / 'src' / 'tokenizer' / 'moss_tokenize.py'
    cmd = [
        sys.executable, str(script_path),
        '--input-root', str(INPUT_ROOT),
        '--output-root', str(OUTPUT_ROOT),
        '--device', 'auto',
        '--max-files', '10',
    ]
    print('Running:', ' '.join(shlex.quote(x) for x in cmd))
    env = os.environ.copy()
    env['PYTHONPATH'] = os.pathsep.join(filter(None, [str(REPO_ROOT), env.get('PYTHONPATH')]))
    subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, check=True)
else:
    print('Skip smoke tokenization. Set RUN_SMOKE_TOKENIZE=True to enable.')


## 3) Inspect token outputs

In [None]:
# Inspect tokenizer output: list the token files and summarize the first one.
# .pt files are listed before .npz files, so a .pt sample is shown when any exist.
from pathlib import Path

import numpy as np
import torch

token_root = Path(OUTPUT_ROOT)
token_files = sorted(token_root.glob('*.pt')) + sorted(token_root.glob('*.npz'))
print('Token files:', len(token_files))
for p in token_files[:5]:
    print('-', p.name)

if not token_files:
    print('No token files found in', token_root, '- run the tokenization cell first.')
elif token_files[0].suffix == '.pt':
    # torch.load unpickles: only load files you produced yourself (as here).
    sample = torch.load(token_files[0], map_location='cpu')
    print('sample track_id:', sample.get('track_id'))
    print('sample token_shape:', sample.get('token_shape'))
    print('tensor shape:', tuple(sample['tokens'].shape))
else:
    # Only .npz files: report every stored array's shape and dtype (schema not assumed).
    with np.load(token_files[0], allow_pickle=False) as npz:
        print('npz arrays in', token_files[0].name)
        for key in npz.files:
            arr = npz[key]
            print(f'  {key}: shape={arr.shape}, dtype={arr.dtype}')

### Build split files

Run the next code cell to generate `train/val/test` split files.


In [None]:
# Build train/val/test split files (80/10/10) from the token directory.
import shlex
import subprocess
import os

script_path = REPO_ROOT / 'src' / 'dataset' / 'build_splits.py'
if not script_path.exists():
    raise RuntimeError(f'Script not found: {script_path}. Re-run clone cell.')

split_args = {
    '--tokens-root': str(OUTPUT_ROOT),
    '--output-root': str(SPLITS_ROOT),
    '--val-ratio': '0.1',
    '--test-ratio': '0.1',
}
cmd = ['python', str(script_path)]
for flag, value in split_args.items():
    cmd.extend([flag, value])
print('Running:', ' '.join(map(shlex.quote, cmd)))

# Put the repo root first on PYTHONPATH so `src.*` imports resolve in the child process.
env = dict(os.environ)
prev_path = env.get('PYTHONPATH')
env['PYTHONPATH'] = f'{REPO_ROOT}:{prev_path}' if prev_path else str(REPO_ROOT)

result = subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, text=True, capture_output=True)
print('\n--- STDOUT ---')
print(result.stdout)
if result.returncode:
    print('--- STDERR ---')
    print(result.stderr)
    raise RuntimeError(f'build_splits failed with code {result.returncode}')


In [None]:
# Preview the first 500 characters of each generated split file.
# (The header previously used '\\n', which printed a literal backslash-n instead of a newline.)
for name in ['train.txt', 'val.txt', 'test.txt', 'summary.json']:
    p = SPLITS_ROOT / name
    print('\n===', name, '===')
    if p.exists():
        print(p.read_text(encoding='utf-8')[:500])
    else:
        print('not found')

## 4) Train and evaluate retrieval

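This section trains the embedder and computes offline retrieval metrics in the same notebook.

The evaluation cells below use `--eval-protocol cross_chunk --exclude-self`, which is a stricter (more honest) setting.
That protocol needs at least 2 chunks per track, so tokenization must be run with `--chunk-seconds`. The full tokenization command in section 2 does not pass it by default.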


In [None]:
# Output folders for model checkpoints and evaluation reports (safe to re-run).
from pathlib import Path

CHECKPOINTS_ROOT = REPO_ROOT / 'data' / 'checkpoints'
REPORTS_ROOT = REPO_ROOT / 'data' / 'reports'
for folder in (CHECKPOINTS_ROOT, REPORTS_ROOT):
    folder.mkdir(parents=True, exist_ok=True)

for label, folder in [
    ('OUTPUT_ROOT', OUTPUT_ROOT),
    ('SPLITS_ROOT', SPLITS_ROOT),
    ('CHECKPOINTS_ROOT', CHECKPOINTS_ROOT),
    ('REPORTS_ROOT', REPORTS_ROOT),
]:
    print(label, '=', folder)


In [None]:
# Full training run of the contrastive embedder.
# Training can take a long time, so stdout is streamed into the cell as it is produced
# instead of appearing only after the process exits. stderr is collected and shown on failure.
import shlex
import subprocess
import os
import sys
import tempfile

cmd = [
    sys.executable, '-m', 'src.train.train_contrastive',  # same interpreter as this kernel
    '--config', 'configs/train.yaml',
    '--device', 'auto',
    '--output-dir', str(CHECKPOINTS_ROOT),
    '--batch-size', '4',
    '--num-workers', '2',
]
print('Running:', ' '.join(shlex.quote(x) for x in cmd))
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(filter(None, [str(REPO_ROOT), env.get('PYTHONPATH')]))
env['PYTHONUNBUFFERED'] = '1'  # otherwise the child block-buffers stdout into the pipe

print('\n--- STDOUT ---')
stdout_lines = []
with tempfile.TemporaryFile(mode='w+') as stderr_file:
    with subprocess.Popen(cmd, cwd=str(REPO_ROOT), env=env, text=True,
                          stdout=subprocess.PIPE, stderr=stderr_file) as proc:
        try:
            for line in proc.stdout:
                print(line, end='')
                stdout_lines.append(line)
        except BaseException:
            # e.g. the cell was interrupted: do not leave training running in the background.
            proc.kill()
            raise
    stderr_file.seek(0)
    stderr_text = stderr_file.read()

# Same shape as subprocess.run's result, for anyone inspecting `result` afterwards.
result = subprocess.CompletedProcess(cmd, proc.returncode, ''.join(stdout_lines), stderr_text)
if result.returncode != 0:
    print('--- STDERR ---')
    print(result.stderr)
    raise RuntimeError(f'train_contrastive failed with code {result.returncode}')


In [None]:
# Optional smoke training run: a few steps per epoch, written to a separate
# checkpoints_smoke folder so the real checkpoints are untouched.
# Set RUN_SMOKE_TRAIN = True to run it; otherwise this cell does nothing.
import shlex
import subprocess
import os
import sys

RUN_SMOKE_TRAIN = False

if RUN_SMOKE_TRAIN:
    cmd = [
        sys.executable, '-m', 'src.train.train_contrastive',
        '--config', 'configs/train.yaml',
        '--device', 'auto',
        '--max-steps-per-epoch', '5',
        '--output-dir', str(REPO_ROOT / 'data' / 'checkpoints_smoke'),
    ]
    print('Running:', ' '.join(shlex.quote(x) for x in cmd))
    env = os.environ.copy()
    env['PYTHONPATH'] = os.pathsep.join(filter(None, [str(REPO_ROOT), env.get('PYTHONPATH')]))
    subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, check=True)
else:
    print('Skip smoke training. Set RUN_SMOKE_TRAIN=True to enable.')


In [None]:
# Report which training artifacts are present in CHECKPOINTS_ROOT.
for artifact in ('best.pt', 'last.pt', 'history.json'):
    artifact_path = CHECKPOINTS_ROOT / artifact
    status = 'exists' if artifact_path.exists() else 'missing'
    print(artifact, '->', status, artifact_path)


In [None]:
# Exact retrieval metrics on the validation split, using the best checkpoint.
# Results are written to REPORTS_ROOT/val_exact.json.
import shlex
import subprocess
import os

cmd = ['python', '-m', 'src.index.evaluate_retrieval']
cmd += ['--config', 'configs/train.yaml']
cmd += ['--checkpoint', str(CHECKPOINTS_ROOT / 'best.pt')]
cmd += ['--tokens-root', str(OUTPUT_ROOT), '--splits-root', str(SPLITS_ROOT)]
cmd += ['--split', 'val', '--topk', '1,10,100', '--device', 'auto']
cmd += ['--eval-protocol', 'cross_chunk', '--exclude-self']
cmd += ['--batch-size', '16']
cmd += ['--output-json', str(REPORTS_ROOT / 'val_exact.json')]
print('Running:', ' '.join(map(shlex.quote, cmd)))

# Put the repo root first on PYTHONPATH so `src.*` imports resolve in the child process.
env = dict(os.environ)
prev_path = env.get('PYTHONPATH')
env['PYTHONPATH'] = f'{REPO_ROOT}:{prev_path}' if prev_path else str(REPO_ROOT)

result = subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, text=True, capture_output=True)
print('\n--- STDOUT ---')
print(result.stdout)
if result.returncode:
    print('--- STDERR ---')
    print(result.stderr)
    raise RuntimeError(f'evaluate_retrieval failed with code {result.returncode}')


In [None]:
# Optional: the same validation evaluation, additionally run through a FAISS index (`--use-faiss`).
# Results are written to REPORTS_ROOT/val_exact_faiss.json.
import shlex
import subprocess
import os

cmd = ['python', '-m', 'src.index.evaluate_retrieval']
cmd += ['--config', 'configs/train.yaml']
cmd += ['--checkpoint', str(CHECKPOINTS_ROOT / 'best.pt')]
cmd += ['--tokens-root', str(OUTPUT_ROOT), '--splits-root', str(SPLITS_ROOT)]
cmd += ['--split', 'val', '--topk', '1,10,100', '--device', 'auto']
cmd += ['--eval-protocol', 'cross_chunk', '--exclude-self']
cmd += ['--batch-size', '16', '--use-faiss']
cmd += ['--output-json', str(REPORTS_ROOT / 'val_exact_faiss.json')]
print('Running:', ' '.join(map(shlex.quote, cmd)))

# Put the repo root first on PYTHONPATH so `src.*` imports resolve in the child process.
env = dict(os.environ)
prev_path = env.get('PYTHONPATH')
env['PYTHONPATH'] = f'{REPO_ROOT}:{prev_path}' if prev_path else str(REPO_ROOT)

result = subprocess.run(cmd, cwd=str(REPO_ROOT), env=env, text=True, capture_output=True)
print('\n--- STDOUT ---')
print(result.stdout)
if result.returncode:
    print('--- STDERR ---')
    print(result.stderr)
    raise RuntimeError(f'evaluate_retrieval (faiss) failed with code {result.returncode}')


In [None]:
# Print the `results` section of each evaluation report, if it was produced.
import json

for report_name in ('val_exact.json', 'val_exact_faiss.json'):
    report_path = REPORTS_ROOT / report_name
    print('\n===', report_name, '===')
    if not report_path.exists():
        print('not found')
        continue
    payload = json.loads(report_path.read_text(encoding='utf-8'))
    print(json.dumps(payload.get('results', {}), indent=2, ensure_ascii=False))


In [None]:
# Optional: copy checkpoints, reports and configs to Google Drive so they outlive the
# Colab runtime. Set SAVE_TO_DRIVE = True to run it; otherwise this cell does nothing.
# Existing folders in the Drive run directory are merged and overwritten (like `cp -r`).
import shutil
from pathlib import Path

SAVE_TO_DRIVE = False
DRIVE_RUN_DIR = Path('/content/drive/MyDrive/CL_ml_runs/run1')

if SAVE_TO_DRIVE:
    from google.colab import drive

    drive.mount('/content/drive')
    DRIVE_RUN_DIR.mkdir(parents=True, exist_ok=True)
    for src_dir in (CHECKPOINTS_ROOT, REPORTS_ROOT, REPO_ROOT / 'configs'):
        if not src_dir.is_dir():
            print('Skip (not found):', src_dir)
            continue
        dst_dir = DRIVE_RUN_DIR / src_dir.name
        shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)
        print('Copied:', src_dir, '->', dst_dir)
else:
    print('Skip Drive export. Set SAVE_TO_DRIVE=True to enable.')


## Done

You now have a single Colab notebook for the end-to-end baseline: `audio -> tokens -> splits -> train -> retrieval metrics`.
