# Incremental Training with Replay Buffer

In [None]:
import os
from importnb import Notebook
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
import random
from diamond_world_model_trainer import train_diamond_model
import pickle

import config
with Notebook():
    from jetbot_dataset import JetbotDataset
    from combine_session_data import combine_sessions_append, gather_new_sessions_only
    from compare_diamond_models import load_sampler, evaluate_models_alternating

import models

In [None]:
class ReplayBuffer(Dataset):
    """Replay buffer holding integer indices into a backing dataset.

    Only indices are stored; items are fetched from ``self.dataset`` on access.
    The newest indices sit at the front of ``self.indices`` and the list is
    capped at ``max_size`` entries, so the oldest entries are dropped first
    when the buffer overflows.

    If ``index_path`` is given, the index list is persisted there with
    ``pickle`` and reloaded on the next construction. ``pickle.load`` can
    execute arbitrary code, so ``index_path`` must point to a trusted file.

    Args:
        dataset: indexable, sized dataset the stored indices refer to.
        max_size: maximum number of indices kept.
        index_path: optional path of the pickled index list.
    """

    def __init__(self, dataset, max_size=50000, index_path=None):
        self.dataset = dataset
        self.max_size = max_size
        self.index_path = index_path
        if index_path and os.path.exists(index_path):
            with open(index_path, 'rb') as f:
                # Re-apply the cap in case max_size shrank since the file was written.
                self.indices = list(pickle.load(f))[:max_size]
        else:
            # Seed with the first ``max_size`` dataset positions without
            # materialising the whole range first.
            self.indices = list(range(min(len(dataset), max_size)))
            self._save()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def sample(self, k):
        """Return up to ``k`` distinct buffered items chosen uniformly at random."""
        idxs = random.sample(self.indices, min(k, len(self.indices)))
        return [self.dataset[i] for i in idxs]

    def add_episode(self, new_idx):
        """Prepend ``new_idx`` to the buffer, drop duplicates, cap at ``max_size`` and persist.

        Duplicates are removed keeping the first (newest) occurrence. Re-adding
        an index therefore moves it to the front instead of over-weighting it
        in ``sample``.
        """
        merged = list(new_idx) + self.indices
        self.indices = list(dict.fromkeys(merged))[:self.max_size]
        self._save()

    def _save(self):
        """Atomically write ``self.indices`` to ``index_path`` (no-op when unset).

        Dumping to a sibling temp file and then ``os.replace``-ing it means an
        interruption mid-write cannot leave a truncated pickle behind.
        """
        if not self.index_path:
            return
        tmp_path = f"{self.index_path}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(self.indices, f)
        os.replace(tmp_path, self.index_path)

class MixedDataset(IterableDataset):
    """Endless stream mixing fresh samples with samples from a replay buffer.

    When both sources are non-empty, each item comes from ``fresh_ds`` with
    probability ``alpha`` and from ``replay_buffer`` otherwise. If one source
    is empty, every item is drawn from the other.

    Args:
        fresh_ds: indexable, sized dataset of new samples (may be an empty list).
        replay_buffer: object exposing ``__len__`` and ``sample(k)`` returning a
            list, e.g. a ``ReplayBuffer``.
        alpha: probability of drawing from ``fresh_ds`` on each step.

    Raises:
        ValueError: on iteration, if both sources are empty.
    """

    def __init__(self, fresh_ds, replay_buffer, alpha=0.2):
        self.fresh_ds = fresh_ds
        self.replay_buffer = replay_buffer
        self.alpha = alpha

    def __iter__(self):
        while True:
            # Lengths are re-read every step because the sources may be mutated
            # while iterating (e.g. replay_buffer.add_episode).
            has_fresh = len(self.fresh_ds) > 0
            has_replay = len(self.replay_buffer) > 0
            if not (has_fresh or has_replay):
                raise ValueError("MixedDataset: both fresh_ds and replay_buffer are empty")
            # Fall back to fresh data when the replay buffer is empty instead of
            # failing with an IndexError on replay_buffer.sample(1)[0].
            use_fresh = has_fresh and (not has_replay or random.random() < self.alpha)
            if use_fresh:
                yield self.fresh_ds[random.randrange(len(self.fresh_ds))]
            else:
                yield self.replay_buffer.sample(1)[0]


def build_batch(samples):
    """Collate dataset samples into a ``models.Batch``.

    Each sample is an ``(img, act, prev)`` triple of tensors. The per-sample
    ``prev`` tensors are reshaped to ``(NUM_PREV_FRAMES, C, H, W)`` with
    ``C = DM_IMG_CHANNELS`` and ``H = W = IMAGE_SIZE``. The current ``img`` is
    appended as the last frame, giving ``obs`` of shape
    ``(B, NUM_PREV_FRAMES + 1, C, H, W)``. The padding mask is all ``True``
    with shape ``(B, NUM_PREV_FRAMES + 1)`` on the images' device.

    Args:
        samples: sequence of ``(img, act, prev)`` tensor triples.

    Returns:
        models.Batch with ``obs``, ``act``, ``mask_padding`` and one empty
        ``info`` dict per sample.
    """
    imgs, acts, prevs = zip(*samples)
    imgs = torch.stack(imgs, 0)
    acts = torch.stack(acts, 0)
    prevs = torch.stack(prevs, 0)

    b = len(samples)
    num_prev = config.NUM_PREV_FRAMES
    c, h, w = config.DM_IMG_CHANNELS, config.IMAGE_SIZE, config.IMAGE_SIZE

    prev_seq = prevs.view(b, num_prev, c, h, w)
    obs = torch.cat((prev_seq, imgs.unsqueeze(1)), dim=1)
    # NOTE(review): assumes each `act` is a 1-D tensor (e.g. shape (1,)), so
    # that `acts` is (B, 1) and the repeat yields (B, num_prev). With 0-d acts,
    # `acts` is (B,) and repeat(1, n) would give (1, B * n). TODO confirm.
    act_seq = acts.repeat(1, num_prev).long()
    mask = torch.ones(b, num_prev + 1, dtype=torch.bool, device=imgs.device)
    # One dict per sample. `[{}] * b` would alias a single dict b times, so
    # mutating one sample's info would silently change all the others.
    info = [{} for _ in range(b)]
    return models.Batch(obs=obs, act=act_seq, mask_padding=mask, info=info)


In [None]:
def _make_jetbot_dataset(csv_path, data_dir):
    """Build a ``JetbotDataset`` with the project-wide image size, history length and transform."""
    return JetbotDataset(
        csv_path,
        data_dir,
        config.IMAGE_SIZE,
        config.NUM_PREV_FRAMES,
        transform=config.TRANSFORM,
    )


def _keep_better_checkpoint(best_path, candidate_path, holdout_ds):
    """Compare two checkpoints on ``holdout_ds`` and keep the winner at ``best_path``.

    Model A is loaded from ``best_path`` and model B from ``candidate_path``.
    If B's ``avg_mse`` is strictly lower, ``candidate_path`` is moved onto
    ``best_path``. Otherwise ``candidate_path`` is deleted.
    """
    sampler_a = load_sampler(best_path, config.DEVICE)
    sampler_b = load_sampler(candidate_path, config.DEVICE)
    dl_holdout = DataLoader(holdout_ds, batch_size=1, shuffle=False)
    results = evaluate_models_alternating(
        sampler_a, sampler_b, dl_holdout, config.DEVICE, config.NUM_PREV_FRAMES
    )
    if results['B']['avg_mse'] < results['A']['avg_mse']:
        os.replace(candidate_path, best_path)
    else:
        os.remove(candidate_path)


def main():
    """Run one incremental-training round.

    1. Gather new session data into ``NEW_CSV_PATH`` and load it as the
       "fresh" dataset (an empty list if that CSV does not exist).
    2. Train from the best checkpoint on an endless stream that draws fresh
       samples with probability 0.2 and replay-buffer samples otherwise.
    3. Keep whichever of the old best and the new checkpoint has the lower
       holdout ``avg_mse``.
    4. Append the sessions to the full dataset and add the new rows' indices
       to the replay buffer.
    """
    # Step 1: collect new sessions into their own CSV / image directory.
    gather_new_sessions_only(
        config.SESSION_DATA_DIR,
        config.CSV_PATH,
        config.NEW_IMAGE_DIR,
        config.NEW_CSV_PATH,
    )
    # An empty list stands in for "no fresh data"; MixedDataset only needs len().
    fresh_ds = (
        _make_jetbot_dataset(config.NEW_CSV_PATH, config.NEW_DATA_DIR)
        if os.path.exists(config.NEW_CSV_PATH)
        else []
    )

    full_ds = _make_jetbot_dataset(config.CSV_PATH, config.DATA_DIR)
    replay_ds = ReplayBuffer(full_ds, max_size=50000, index_path=config.REPLAY_INDEX_PATH)

    mixed_dataset = MixedDataset(fresh_ds, replay_ds, alpha=0.2)
    # mixed_dataset is an endless IterableDataset (no shuffle flag allowed);
    # run length is presumably bounded by max_steps below.
    train_loader = DataLoader(
        mixed_dataset,
        batch_size=config.BATCH_SIZE,
        collate_fn=build_batch,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )

    val_dataset = _make_jetbot_dataset(config.HOLDOUT_CSV_PATH, config.HOLDOUT_DATA_DIR)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        collate_fn=build_batch,
        num_workers=4,
        pin_memory=True,
    )

    # Step 2: train a new model starting from the last best checkpoint.
    ckpt_path = os.path.join(config.CHECKPOINT_DIR, 'denoiser_model_best_val_loss.pth')
    # NOTE(review): ckpt_path is passed even when it does not exist yet;
    # presumably train_diamond_model handles a missing start checkpoint. Confirm.
    new_ckpt = train_diamond_model(
        train_loader,
        val_loader,
        start_checkpoint=ckpt_path,
        max_steps=config.NUM_TRAIN_STEPS,
    )

    # Step 3: compare the old best with the newly trained checkpoint.
    if os.path.abspath(new_ckpt) == os.path.abspath(ckpt_path):
        # The trainer wrote straight to the best-checkpoint path. Comparing the
        # file against itself can never favour B and would then os.remove() the
        # only copy, so leave it in place.
        pass
    elif os.path.exists(ckpt_path):
        # The holdout set is the same data as val_dataset, so reuse it.
        _keep_better_checkpoint(ckpt_path, new_ckpt, val_dataset)
    else:
        os.replace(new_ckpt, ckpt_path)

    # Step 4: after training, permanently add new sessions to the full dataset.
    old_len = len(full_ds)
    combine_sessions_append(config.SESSION_DATA_DIR, config.IMAGE_DIR, config.CSV_PATH)
    updated_ds = _make_jetbot_dataset(config.CSV_PATH, config.DATA_DIR)
    # Assumes combine_sessions_append only appends rows, so the new samples
    # occupy positions old_len .. len(updated_ds) - 1. TODO confirm.
    new_indices = range(old_len, len(updated_ds))
    replay_ds.dataset = updated_ds
    replay_ds.add_episode(new_indices)


if __name__ == '__main__':
    main()