# Reward Estimator Training – **ResNet‑18 + Multi‑Frame + Reward Labels**  (v4)
Now with live **progress bars** and a quick sanity check that PyTorch can see the GPU.
*If MSI Afterburner still shows 0 % GPU load, check the output of the first code cell – it prints the device PyTorch selected and whether CUDA is available.*

In [1]:
import os, random, numpy as np, pandas as pd, torch, torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm.auto import tqdm
from PIL import Image
from pathlib import Path
import config
from models import RewardEstimatorResNet

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print('Using device:', device, '| CUDA visible →', _cuda_ok)

Using device: cuda | CUDA visible → True


In [2]:
# ---------------- Hyper-parameters / paths ----------------
# All paths come from the project's ``config`` module. The reward-model output
# directory is created up front so the training loop can write checkpoints.
REWARD_MODEL_OUTPUT_DIR = os.path.join(config.OUTPUT_DIR, "reward_estimator")
os.makedirs(REWARD_MODEL_OUTPUT_DIR, exist_ok=True)

REWARD_CSV_PATH = config.MANUAL_COLLECTED_REWARD_CSV   # labels CSV
MAIN_CSV_PATH   = config.CSV_PATH                      # master frames CSV
MAIN_DATA_DIR   = config.DATA_DIR                      # image folder

print("REWARD_MODEL_OUTPUT_DIR  →", os.path.abspath(REWARD_MODEL_OUTPUT_DIR))
print("REWARD_CSV_PATH          →", REWARD_CSV_PATH)
print("MAIN_CSV_PATH            →", MAIN_CSV_PATH)
print("MAIN_DATA_DIR            →", MAIN_DATA_DIR)

# Run-specific hyper-parameters. NOTE: these assign attributes on the shared
# ``config`` module, so anything else in this kernel that reads ``config``
# afterwards sees these values.
config.NUM_PREV_FRAMES = 4               # N previous frames (→ 5‑frame input)
config.BATCH_SIZE      = 64
config.LR              = 3e-4
# Keep a project-defined IMAGE_SIZE if present; otherwise default to 128.
config.IMAGE_SIZE      = getattr(config, 'IMAGE_SIZE', 128)
print('Config ready ✨')

REWARD_MODEL_OUTPUT_DIR  → C:\Projects\jetbot-diffusion-world-model-kong-finder-aux\output_model_small_session_split_data\reward_estimator
REWARD_CSV_PATH          → C:\Projects\jetbot-diffusion-world-model-kong-finder-aux\jetbot_data_two_actions\interactive_reward_labels_subset.csv
MAIN_CSV_PATH            → C:\Projects\jetbot-diffusion-world-model-kong-finder-aux\jetbot_data_two_actions\data.csv
MAIN_DATA_DIR            → C:\Projects\jetbot-diffusion-world-model-kong-finder-aux\jetbot_data_two_actions
Config ready ✨


In [3]:
# ---------------- Dataset ----------------
class StackedRewardDataset(Dataset):
    """Stack a labelled frame with its preceding frames for reward regression.

    Each sample is ``(stacked, reward)``. ``stacked`` concatenates along dim 0
    the ``num_prev_frames`` preceding frames (oldest first), followed by the
    labelled frame itself. ``reward`` is a float32 scalar tensor.

    Args:
        main_csv_path: CSV of all frames. It must have an ``image_path`` column
            and may have a ``session_id`` column.
        reward_csv_path: CSV with ``dataframe_index`` and ``assigned_reward``
            columns. ``dataframe_index`` is treated as a row position in the
            main CSV.
        data_dir: directory that the ``image_path`` values are relative to.
        image_size: stored as ``self.image_size`` but NOT applied here. Any
            resizing must be done by ``transform``.
        num_prev_frames: number of frames preceding the labelled one to stack.
        transform: callable applied to each RGB PIL image. When ``None``,
            ``transforms.ToTensor()`` is used.

    Raises:
        ValueError: if no labelled frame has enough in-session history.
    """

    def __init__(self, main_csv_path, reward_csv_path, data_dir, image_size, num_prev_frames, transform=None):
        super().__init__()
        self.main_df   = pd.read_csv(main_csv_path)
        self.reward_df = pd.read_csv(reward_csv_path)
        self.data_dir  = data_dir
        self.image_size = image_size
        self.transform = transform
        self.num_prev_frames = num_prev_frames

        # Row position -> reward. A later duplicate dataframe_index overwrites an earlier one.
        self.reward_map = dict(zip(self.reward_df['dataframe_index'], self.reward_df['assigned_reward']))

        # Extract the needed columns once. Per-row ``iloc`` lookups inside Python
        # loops are very slow on large CSVs.
        self._image_paths = self.main_df['image_path'].tolist()
        sessions = (self.main_df['session_id'].to_numpy()
                    if 'session_id' in self.main_df.columns else None)

        k = self.num_prev_frames
        # A labelled frame i is usable when it has k predecessors and the oldest
        # one (i - k) belongs to the same session. Only the two endpoints are
        # compared, which assumes each session's rows are contiguous in the CSV.
        self.valid_indices = [
            i for i in range(k, len(self.main_df))
            if i in self.reward_map and (sessions is None or sessions[i] == sessions[i - k])
        ]
        if not self.valid_indices:
            raise ValueError('No valid indices found')
        print(f'Dataset loaded → {len(self.valid_indices)} sequences with labels')

    def __len__(self):
        return len(self.valid_indices)

    def _load(self, rel):
        """Open ``data_dir/rel`` as RGB and apply the transform (default: ToTensor)."""
        # The context manager releases the file handle even if decoding fails.
        # convert() returns a fully loaded copy that outlives the closed file.
        with Image.open(os.path.join(self.data_dir, rel)) as img:
            rgb = img.convert('RGB')
        if self.transform is not None:
            return self.transform(rgb)
        return transforms.ToTensor()(rgb)

    def __getitem__(self, idx):
        i = self.valid_indices[idx]
        # Frames i-k .. i inclusive: oldest predecessor first, labelled frame last.
        frames = [self._load(self._image_paths[j]) for j in range(i - self.num_prev_frames, i + 1)]
        stacked = torch.cat(frames, dim=0)
        return stacked, torch.tensor(self.reward_map[i], dtype=torch.float32)

In [4]:
# ---------------- Build loaders ----------------
# Resize every frame to IMAGE_SIZE × IMAGE_SIZE, then convert it to a tensor.
tfm = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor()
])

full_ds = StackedRewardDataset(
    MAIN_CSV_PATH,
    REWARD_CSV_PATH,
    MAIN_DATA_DIR,
    config.IMAGE_SIZE,
    config.NUM_PREV_FRAMES,
    tfm,
)

# 80/20 split with a fixed seed so the split is reproducible across runs.
# NOTE(review): neighbouring samples share frames (sliding window), so a random
# split puts near-duplicate sequences in both train and val. A split by
# session_id would give a more honest validation loss.
n_train = int(0.8 * len(full_ds))
n_val = len(full_ds) - n_train
train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=torch.Generator().manual_seed(42))

# Pinned host memory only speeds up host→CUDA copies. Without CUDA,
# DataLoader just warns, so only enable it on a CUDA device.
_pin = device.type == 'cuda'
train_loader = DataLoader(train_ds, batch_size=config.BATCH_SIZE, shuffle=True,  num_workers=0, pin_memory=_pin)
val_loader   = DataLoader(val_ds,   batch_size=config.BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=_pin)

print(f'Train/Val split → {len(train_ds)} / {len(val_ds)} samples')

Dataset loaded → 3164 sequences with labels
Train/Val split → 2531 / 633 samples


In [5]:
# ---------------- Model / Optim / AMP ----------------
# The network sees NUM_PREV_FRAMES history frames plus the labelled frame.
model = RewardEstimatorResNet(n_frames=config.NUM_PREV_FRAMES + 1)
model = model.to(device)

opt = torch.optim.AdamW(model.parameters(), lr=config.LR)
# Gradient scaling is only enabled on CUDA; when disabled, the scaler passes values through unchanged.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')
loss_fn = nn.MSELoss()

_n_params = sum(p.numel() for p in model.parameters())
print(f'Param count: {_n_params/1e6:.2f} M')

Param count: 11.21 M


In [7]:
# ---------------- Training loop with progress bars ----------------
EPOCHS   = 50
best_val = float("inf")
# Autocast only on CUDA, matching the condition used for the GradScaler.
use_amp = device.type == "cuda"
best_ckpt_path = os.path.join(REWARD_MODEL_OUTPUT_DIR,
                              "best_reward_estimator_weights.pth")

for epoch in range(1, EPOCHS + 1):
    # ── training ────────────────────────────────────────────────
    model.train()
    running = 0.0
    for x, y in tqdm(train_loader,
                     desc=f"E{epoch:02}/train",
                     leave=False,
                     unit="batch"):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            # reshape(-1) rather than squeeze(): squeeze() would collapse a
            # batch of one to a 0-d tensor that no longer matches y's shape (1,).
            pred = model(x).reshape(-1)
            loss = loss_fn(pred, y)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        # MSELoss uses mean reduction by default. Weighting by batch size gives
        # an exact per-sample mean even when the last batch is smaller.
        running += loss.item() * x.size(0)

    train_loss = running / len(train_loader.dataset)

    # ── validation ──────────────────────────────────────────────
    model.eval()
    running = 0.0
    with torch.no_grad():
        for x, y in tqdm(val_loader,
                         desc=f"E{epoch:02}/val ",
                         leave=False,
                         unit="batch"):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                running += loss_fn(model(x).reshape(-1), y).item() * x.size(0)

    val_loss = running / len(val_loader.dataset)

    # ── checkpoint (weights-only copy) ───────────────────────────
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), best_ckpt_path)
        checkpoint_flag = "✨  (best)"
    else:
        checkpoint_flag = ""

    # ── epoch summary line (no overwrite) ───────────────────────
    print(f"Epoch {epoch:02}/{EPOCHS} ▸ "
          f"train {train_loss:.4f}  val {val_loss:.4f}  "
          f"best {best_val:.4f}{checkpoint_flag}")


E01/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E01/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 01/50 ▸ train 0.0024  val 0.0072  best 0.0072✨  (best)


E02/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E02/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 02/50 ▸ train 0.0028  val 0.0049  best 0.0049✨  (best)


E03/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E03/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 03/50 ▸ train 0.0023  val 0.0050  best 0.0049


E04/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E04/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 04/50 ▸ train 0.0015  val 0.0044  best 0.0044✨  (best)


E05/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E05/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 05/50 ▸ train 0.0019  val 0.0056  best 0.0044


E06/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E06/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 06/50 ▸ train 0.0019  val 0.0043  best 0.0043✨  (best)


E07/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E07/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 07/50 ▸ train 0.0018  val 0.0039  best 0.0039✨  (best)


E08/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E08/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 08/50 ▸ train 0.0018  val 0.0038  best 0.0038✨  (best)


E09/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E09/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 09/50 ▸ train 0.0012  val 0.0034  best 0.0034✨  (best)


E10/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E10/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 10/50 ▸ train 0.0014  val 0.0037  best 0.0034


E11/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E11/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 11/50 ▸ train 0.0022  val 0.0068  best 0.0034


E12/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E12/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 12/50 ▸ train 0.0022  val 0.0067  best 0.0034


E13/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E13/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 13/50 ▸ train 0.0014  val 0.0047  best 0.0034


E14/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E14/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 14/50 ▸ train 0.0011  val 0.0039  best 0.0034


E15/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E15/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 15/50 ▸ train 0.0011  val 0.0037  best 0.0034


E16/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E16/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 16/50 ▸ train 0.0008  val 0.0035  best 0.0034


E17/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E17/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 17/50 ▸ train 0.0008  val 0.0034  best 0.0034✨  (best)


E18/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E18/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 18/50 ▸ train 0.0007  val 0.0037  best 0.0034


E19/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E19/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 19/50 ▸ train 0.0009  val 0.0036  best 0.0034


E20/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E20/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 20/50 ▸ train 0.0007  val 0.0036  best 0.0034


E21/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E21/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 21/50 ▸ train 0.0007  val 0.0041  best 0.0034


E22/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E22/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 22/50 ▸ train 0.0010  val 0.0035  best 0.0034


E23/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E23/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 23/50 ▸ train 0.0007  val 0.0036  best 0.0034


E24/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E24/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 24/50 ▸ train 0.0008  val 0.0054  best 0.0034


E25/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E25/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 25/50 ▸ train 0.0007  val 0.0036  best 0.0034


E26/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E26/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 26/50 ▸ train 0.0006  val 0.0045  best 0.0034


E27/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E27/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 27/50 ▸ train 0.0006  val 0.0036  best 0.0034


E28/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E28/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 28/50 ▸ train 0.0007  val 0.0037  best 0.0034


E29/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E29/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 29/50 ▸ train 0.0007  val 0.0059  best 0.0034


E30/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E30/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 30/50 ▸ train 0.0015  val 0.0036  best 0.0034


E31/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E31/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 31/50 ▸ train 0.0008  val 0.0037  best 0.0034


E32/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E32/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 32/50 ▸ train 0.0007  val 0.0035  best 0.0034


E33/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E33/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 33/50 ▸ train 0.0009  val 0.0040  best 0.0034


E34/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E34/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 34/50 ▸ train 0.0010  val 0.0034  best 0.0034✨  (best)


E35/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E35/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 35/50 ▸ train 0.0010  val 0.0033  best 0.0033✨  (best)


E36/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E36/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 36/50 ▸ train 0.0008  val 0.0050  best 0.0033


E37/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E37/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 37/50 ▸ train 0.0015  val 0.0034  best 0.0033


E38/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E38/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 38/50 ▸ train 0.0007  val 0.0033  best 0.0033


E39/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E39/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 39/50 ▸ train 0.0006  val 0.0047  best 0.0033


E40/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E40/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 40/50 ▸ train 0.0005  val 0.0036  best 0.0033


E41/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E41/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 41/50 ▸ train 0.0007  val 0.0044  best 0.0033


E42/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E42/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 42/50 ▸ train 0.0010  val 0.0041  best 0.0033


E43/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E43/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 43/50 ▸ train 0.0005  val 0.0036  best 0.0033


E44/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E44/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 44/50 ▸ train 0.0005  val 0.0035  best 0.0033


E45/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E45/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 45/50 ▸ train 0.0005  val 0.0034  best 0.0033


E46/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E46/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 46/50 ▸ train 0.0004  val 0.0033  best 0.0033


E47/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E47/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 47/50 ▸ train 0.0006  val 0.0039  best 0.0033


E48/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E48/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 48/50 ▸ train 0.0005  val 0.0032  best 0.0032✨  (best)


E49/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E49/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 49/50 ▸ train 0.0006  val 0.0043  best 0.0032


E50/train:   0%|          | 0/40 [00:00<?, ?batch/s]

E50/val :   0%|          | 0/10 [00:00<?, ?batch/s]

Epoch 50/50 ▸ train 0.0007  val 0.0034  best 0.0032


### Why the GPU might still sit idle
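1. **CPU transforms bottleneck** – image decoding and PIL transforms run on the CPU and can starve the GPU; raise `num_workers` above 0 (this notebook uses 0). On Windows, DataLoader workers are spawned as new processes, so the dataset class must live in an importable `.py` module rather than a notebook cell.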
2. **Small network / batch** – ResNet‑18 + 128×128 images at BS = 64 may use <10 % GPU; try bigger batches.
3. **CUDA toolkit mismatch** – if `torch.cuda.is_available()` prints **False**, reinstall PyTorch with the correct CUDA build.
4. **Tensors left on the CPU** – ensure `.to(device)` is called on the model and on every batch (this notebook does).

Monitor real-time usage with `nvidia-smi dmon` or MSI Afterburner while a training epoch is running.