# K-view loader profiling (k=2, CPU-only)

This notebook compares data loading performance between:
- Current K-view loader (`ActivationDataset` via `ActivationParser.get_dataset(..., num_views=2)`)
- Legacy two-view loader (`LegacyActivationDataset`)

The setup mirrors the path, layer, and loader settings of `b_contrastive_training_with_new_trainer.ipynb`, but it runs no model forward/backward passes and does not use a GPU.

In [None]:
import os
import time
import random
from statistics import median

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from activation_logging.activation_parser import ActivationParser
from activation_logging.legacy_activation_dataset import LegacyActivationDataset

In [None]:
# ---- Paths (match training notebook defaults) ----
inference_json = 'shared/goodwiki_jsonv2/generation.jsonl'
eval_json = 'shared/goodwiki.zarr/eval_results.json'
activations_path = 'shared/goodwiki.zarr/activations.zarr'

# ---- Dataset parameters ----
backend = 'zarr'
relevant_layers = list(range(14, 30))  # layers 14..29 inclusive
fixed_layer = None
pad_length = 63
min_target_layers = 2
num_views = 2  # profile exactly k=2

# ---- DataLoader parameters (similar conditions to training notebook) ----
batch_size = 512
num_workers_candidates = [4, 16]
persistent_workers = True
prefetch_factor = 2

# ---- Profiling controls ----
seed = 42
getitem_samples = 2048  # upper bound on samples timed in the __getitem__ baseline
dataloader_batches = 40  # timed batches per DataLoader configuration
warmup_batches = 5  # untimed batches fetched before timing starts

# Seed every RNG in use so repeated runs are comparable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# os.cpu_count() returns None when the count cannot be determined; fall back
# to 1 so the integer division below cannot raise TypeError.
detected_cpus = os.cpu_count()
torch.set_num_threads(max(1, (detected_cpus or 1) // 2))

print('torch version:', torch.__version__)
print('cpu count:', detected_cpus)

In [None]:
# ---- Build parser and datasets ----
# Both loaders are built from the same parser and the same layer/padding
# settings so that only the loader implementation differs between them.
ap = ActivationParser(
    inference_json=inference_json,
    eval_json=eval_json,
    activations_path=activations_path,
    verbose=False,
)

# Keyword arguments that both dataset constructors take unchanged.
shared_dataset_kwargs = dict(
    relevant_layers=relevant_layers,
    fixed_layer=fixed_layer,
    pad_length=pad_length,
    min_target_layers=min_target_layers,
)

# Current K-view loader, obtained through the parser.
current_train = ap.get_dataset(
    'train',
    num_views=num_views,
    backend=backend,
    **shared_dataset_kwargs,
)

# Legacy two-view loader, constructed directly from the parser's dataframe.
# NOTE(review): logger_type is a literal here rather than `backend` — confirm
# the two are meant to stay in sync if `backend` is ever changed.
legacy_train = LegacyActivationDataset(
    df=ap.df,
    activations_path=activations_path,
    split='train',
    logger_type='zarr',
    random_seed=seed,
    verbose=False,
    return_all_activations=False,
    **shared_dataset_kwargs,
)

for label, train_set in (('train size (current):', current_train), ('train size (legacy) :', legacy_train)):
    print(label, len(train_set))

In [None]:
def _summary_from_durations(durations, units='sample'):
    arr = np.asarray(durations, dtype=np.float64)
    total = float(arr.sum())
    count = int(arr.size)
    return {
        'count': count,
        f'total_s_per_{units}': total,
        f'mean_s_per_{units}': float(arr.mean()) if count else float('nan'),
        f'median_s_per_{units}': float(np.median(arr)) if count else float('nan'),
        f'p95_s_per_{units}': float(np.percentile(arr, 95)) if count else float('nan'),
        f'p99_s_per_{units}': float(np.percentile(arr, 99)) if count else float('nan'),
    }


def profile_getitem(dataset, n_samples=1024):
    """Time sequential ``dataset[idx]`` lookups in the calling process.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        n_samples: Maximum number of leading indices to time; capped at
            ``len(dataset)``.

    Returns:
        The per-sample statistics from ``_summary_from_durations`` with an
        added ``'samples_per_sec'`` entry (``inf`` if the total time is 0).
    """
    n = min(n_samples, len(dataset))

    elapsed = []
    for index in range(n):
        started = time.perf_counter()
        dataset[index]  # the fetched item is discarded; only the latency matters
        elapsed.append(time.perf_counter() - started)

    stats = _summary_from_durations(elapsed, units='sample')
    total_seconds = stats['total_s_per_sample']
    if total_seconds > 0:
        stats['samples_per_sec'] = n / total_seconds
    else:
        stats['samples_per_sec'] = float('inf')
    return stats


def _passthrough_collate(batch):
    return batch


def profile_dataloader(dataset, *, batch_size, num_workers, persistent_workers, prefetch_factor, warmup_batches, timed_batches):
    """Time batch retrieval from a ``DataLoader`` wrapped around ``dataset``.

    The loader iterates in order (no shuffling), does not pin memory, and
    uses ``_passthrough_collate`` so batches stay plain lists of samples.

    Args:
        dataset: Dataset handed to the ``DataLoader``.
        batch_size: Samples per batch.
        num_workers: Worker process count; 0 loads in the calling process.
        persistent_workers: Forwarded to the loader only when
            ``num_workers > 0``.
        prefetch_factor: Forwarded to the loader only when
            ``num_workers > 0``.
        warmup_batches: Batches fetched and discarded before timing starts.
        timed_batches: Batches whose fetch latency is recorded.

    Returns:
        The per-batch statistics from ``_summary_from_durations`` extended
        with ``'samples_seen'``, ``'samples_per_sec'`` and
        ``'batches_per_sec'`` (rates are ``inf`` if the total time is 0).
    """
    # The worker-only options are passed only when workers are actually used.
    if num_workers > 0:
        worker_options = {
            'persistent_workers': persistent_workers,
            'prefetch_factor': prefetch_factor,
        }
    else:
        worker_options = {}

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=_passthrough_collate,
        **worker_options,
    )

    batch_iter = iter(loader)

    def _next_batch():
        """Fetch the next batch, restarting the loader once it is exhausted."""
        nonlocal batch_iter
        try:
            return next(batch_iter)
        except StopIteration:
            batch_iter = iter(loader)
            return next(batch_iter)

    # Warm-up: fetched batches are discarded and not timed.
    for _ in range(warmup_batches):
        _next_batch()

    elapsed_per_batch = []
    total_samples = 0
    for _ in range(timed_batches):
        started = time.perf_counter()
        batch = _next_batch()
        elapsed_per_batch.append(time.perf_counter() - started)
        total_samples += len(batch)

    stats = _summary_from_durations(elapsed_per_batch, units='batch')
    wall_seconds = stats['total_s_per_batch']
    if wall_seconds > 0:
        sample_rate = total_samples / wall_seconds
        batch_rate = len(elapsed_per_batch) / wall_seconds
    else:
        sample_rate = batch_rate = float('inf')

    stats['samples_seen'] = total_samples
    stats['samples_per_sec'] = sample_rate
    stats['batches_per_sec'] = batch_rate
    return stats

In [None]:
# ---- 1) Pure __getitem__ timing (single-process baseline) ----
# One row per loader: the timing summary plus a 'loader' label used as index.
getitem_results = [
    {**profile_getitem(ds, n_samples=getitem_samples), 'loader': name}
    for name, ds in (('current_k2', current_train), ('legacy_2view', legacy_train))
]

getitem_df = pd.DataFrame(getitem_results)
getitem_df = getitem_df.set_index('loader').sort_index()
getitem_df

In [None]:
# ---- 2) DataLoader timing sweep (multi-worker) ----
# Every worker count is profiled against both loaders, current loader first.
datasets_by_name = {'current_k2': current_train, 'legacy_2view': legacy_train}

rows = []
for nw in num_workers_candidates:
    # Persistent workers are only requested when workers are in use.
    keep_workers_alive = persistent_workers and nw > 0
    for name, ds in datasets_by_name.items():
        summary = profile_dataloader(
            ds,
            batch_size=batch_size,
            num_workers=nw,
            persistent_workers=keep_workers_alive,
            prefetch_factor=prefetch_factor,
            warmup_batches=warmup_batches,
            timed_batches=dataloader_batches,
        )
        summary.update(loader=name, num_workers=nw)
        rows.append(summary)
        print(f"done: loader={name} num_workers={nw} samples/s={summary['samples_per_sec']:.2f}")

dl_df = pd.DataFrame(rows).sort_values(['num_workers', 'loader']).reset_index(drop=True)
dl_df

In [None]:
# ---- Relative comparison table ----
# Rows are worker counts, columns are loaders, cells are samples/sec.
pivot = dl_df.pivot(index='num_workers', columns='loader', values='samples_per_sec')

# The ratio column is only added when both loaders were profiled.
# Values above 1 mean the legacy loader had the higher throughput.
if 'current_k2' in pivot.columns and 'legacy_2view' in pivot.columns:
    pivot['legacy_over_current_speedup'] = pivot['legacy_2view'] / pivot['current_k2']
pivot

## Notes

- This profile isolates data loading; model compute is intentionally excluded.
- The current K-view loader with `num_views=2` may still load all relevant layers in its Zarr code path, which can dominate I/O time.
- If the machine cannot handle high worker counts, reduce `num_workers_candidates`.