# 01d: RL With Human Feedback (MNIST)

This notebook demonstrates a small interactive loop that combines:

- A conditional variational autoencoder (VAE) for digit generation (unsupervised pretraining).
- A lightweight policy over latent codes updated with REINFORCE from human thumbs-up/down on generated images.
- A simple classifier trained both with standard supervision (when available) and policy-gradient style updates from human correctness feedback.

The goal is to provide an interactive demo where a human guides both generative and discriminative behavior via feedback.


## Setup
- Uses the repo's `mlp.data_providers.MNISTDataProvider` (expects `MLP_DATA_DIR` with `mnist-*.npz`).
- Falls back to torchvision MNIST if data provider is unavailable.
- Interactivity uses `ipywidgets`. If the widgets don't render, install `ipywidgets` and, if needed, enable the JupyterLab widgets extension.
- Training is kept intentionally small for demo speed.


In [None]:
# Imports and environment
import os, math, time, random, sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Try to import torch + widgets
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import Categorical
except Exception as e:
    print('PyTorch is required for this demo. Please install torch.')
    raise

try:
    import ipywidgets as widgets
    from IPython.display import display, clear_output
except Exception as e:
    print('ipywidgets not available. Install with `pip install ipywidgets`.')
    raise

# Seed every RNG the notebook touches so repeated runs behave the same.
SEED = 123
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# If the caller has not set MLP_DATA_DIR, point it at the first well-known
# directory that contains the MNIST training archive (if any does).
if 'MLP_DATA_DIR' not in os.environ:
    _here = Path.cwd()
    _search_dirs = (
        _here / 'data',
        _here / 'notebooks' / 'res',
        _here.parent / 'data',
    )
    _hit = next((d for d in _search_dirs if (d / 'mnist-train.npz').exists()), None)
    if _hit is not None:
        os.environ['MLP_DATA_DIR'] = str(_hit.resolve())
print('Using MLP_DATA_DIR =', os.environ.get('MLP_DATA_DIR', '<unset>'))


In [None]:
# Data loading: prefer the repo data provider, fall back to torchvision.
# Either path leaves four float32 NumPy arrays in scope: train_x / valid_x
# (inputs) and train_y / valid_y (targets). The torchvision path produces
# flattened 784-pixel rows and 10-way one-hot targets.
# NOTE(review): the provider path is assumed to yield the same layout
# (2-D inputs, one-hot targets) -- confirm against mlp.data_providers.
use_provider = False
train_x = valid_x = train_y = valid_y = None

try:
    from mlp.data_providers import MNISTDataProvider
    print('Loading MNIST via MNISTDataProvider...')
    train_dp = MNISTDataProvider('train', batch_size=256, max_num_batches=200, shuffle_order=True)
    valid_dp = MNISTDataProvider('valid', batch_size=256, max_num_batches=50, shuffle_order=False)

    # Aggregate a subset into memory for quick demo loops
    def dp_to_arrays(dp):
        """Drain a data provider and stack its batches into two float32 arrays."""
        xs, ys = [], []
        for xb, yb in dp:
            xs.append(xb.astype(np.float32))
            ys.append(yb.astype(np.float32))
        X = np.vstack(xs)
        Y = np.vstack(ys)
        return X, Y

    train_x, train_y = dp_to_arrays(train_dp)
    valid_x, valid_y = dp_to_arrays(valid_dp)
    # Flag success only after every array has really been loaded; setting it
    # earlier left the flag True even when the fallback below supplied the data.
    use_provider = True
    print('Train/Valid shapes:', train_x.shape, train_y.shape, '|', valid_x.shape, valid_y.shape)
except Exception as e:
    # Any failure (missing package, missing data file, ...) triggers the fallback.
    use_provider = False
    print('Falling back to torchvision MNIST due to:', e)
    from torchvision import datasets, transforms
    tfm = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST(root=str(Path.cwd()/'data'), train=True, download=True, transform=tfm)
    mnist_test = datasets.MNIST(root=str(Path.cwd()/'data'), train=False, download=True, transform=tfm)

    # Subset for speed
    def subset(ds, n=15000):
        """Return a random sample of `ds` as (flattened images, one-hot labels).

        At most `len(ds)` examples are drawn, so asking for more than the
        dataset holds returns the whole dataset instead of raising.
        """
        n = min(n, len(ds))
        idx = np.random.choice(len(ds), size=n, replace=False)
        xs, ys = [], []
        for i in idx:
            x, y = ds[i]
            xs.append(x.view(-1).numpy())
            one = np.zeros(10, dtype=np.float32); one[y] = 1.0
            ys.append(one)
        return np.stack(xs).astype(np.float32), np.stack(ys).astype(np.float32)

    train_x, train_y = subset(mnist_train, n=15000)
    valid_x, valid_y = subset(mnist_test, n=3000)
    print('Train/Valid shapes:', train_x.shape, train_y.shape, '|', valid_x.shape, valid_y.shape)

# Image geometry and label-space size shared by every model in the notebook.
NUM_CLASSES = 10
IMG_SHAPE = (28, 28)
INPUT_DIM = IMG_SHAPE[0] * IMG_SHAPE[1]


def to_img(arr):
    """Reshape a flat pixel array into the 2-D layout given by IMG_SHAPE."""
    return arr.reshape(IMG_SHAPE)

# Move the four NumPy arrays onto DEVICE as torch tensors, in one pass.
train_X_t, train_Y_t, valid_X_t, valid_Y_t = (
    torch.from_numpy(_arr).to(DEVICE)
    for _arr in (train_x, train_y, valid_x, valid_y)
)


## Conditional VAE (unsupervised pretraining)
We train a small VAE conditioned on digit class. This serves as a base generator; RL will later select latent codes based on human reward.


In [None]:
# Conditional VAE
class CondVAE(nn.Module):
    """Fully connected VAE whose encoder and decoder both see a class one-hot.

    The encoder maps the concatenation [x ; y_onehot] to the mean and
    log-variance of a diagonal Gaussian over a `z_dim`-dimensional latent.
    The decoder maps [z ; y_onehot] to one logit per input dimension.
    """

    def __init__(self, input_dim=784, num_classes=10, z_dim=16, hidden=400):
        super().__init__()
        self.z_dim, self.num_classes = z_dim, num_classes
        # Encoder: shared hidden layer, then separate heads for mu and logvar.
        self.enc1 = nn.Linear(input_dim + num_classes, hidden)
        self.enc_mu = nn.Linear(hidden, z_dim)
        self.enc_logvar = nn.Linear(hidden, z_dim)
        # Decoder: one hidden layer back up to the input dimensionality.
        self.dec1 = nn.Linear(z_dim + num_classes, hidden)
        self.dec_out = nn.Linear(hidden, input_dim)

    def encode(self, x, y_onehot):
        """Return (mu, logvar) of the approximate posterior for x given its class."""
        joint = torch.cat((x, y_onehot), dim=-1)
        hidden = F.relu(self.enc1(joint))
        return self.enc_mu(hidden), self.enc_logvar(hidden)

    def reparam(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z, y_onehot):
        """Return reconstruction logits for latent z under class y_onehot."""
        joint = torch.cat((z, y_onehot), dim=-1)
        return self.dec_out(F.relu(self.dec1(joint)))

    def forward(self, x, y_onehot):
        """Encode, sample a latent, decode; return (logits, mu, logvar)."""
        mu, logvar = self.encode(x, y_onehot)
        sampled = self.reparam(mu, logvar)
        return self.decode(sampled, y_onehot), mu, logvar

def vae_loss(x, logits, mu, logvar):
    """Return the batch-averaged negative ELBO for a Bernoulli-output VAE.

    The reconstruction term is binary cross-entropy between `logits` and `x`,
    and the KL term is the closed form for a diagonal Gaussian against
    N(0, I). Both are summed over all elements and divided by `x.size(0)`.
    """
    batch_size = x.size(0)
    # Bernoulli likelihood over pixels
    recon_term = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
    kl_term = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return (recon_term + kl_term) / batch_size

vae = CondVAE(input_dim=INPUT_DIM, num_classes=NUM_CLASSES, z_dim=16, hidden=400).to(DEVICE)
opt_vae = torch.optim.Adam(vae.parameters(), lr=1e-3)


def train_vae(epochs=3, batch=256):
    """Train the module-level `vae` on `train_X_t` / `train_Y_t` with `opt_vae`.

    Args:
        epochs: number of full passes over the training tensors.
        batch: minibatch size; the last minibatch of an epoch may be smaller.

    Prints the per-example average loss after each epoch. Returns None.
    """
    vae.train()
    N = train_X_t.size(0)
    for ep in range(epochs):
        # Draw a fresh permutation every epoch. Drawing it once before the loop
        # made every epoch visit identical minibatches in identical order.
        idxs = torch.randperm(N, device=DEVICE)
        total = 0.0
        for i in range(0, N, batch):
            sel = idxs[i:i+batch]
            xb = train_X_t[sel]
            yb = train_Y_t[sel]
            logits, mu, logvar = vae(xb, yb)
            loss = vae_loss(xb, logits, mu, logvar)
            opt_vae.zero_grad()
            loss.backward()
            opt_vae.step()
            # vae_loss is a batch mean, so weight it by the batch size to get
            # a correct per-example average over the epoch.
            total += loss.item() * xb.size(0)
        print(f'Epoch {ep+1}: loss={total/N:.4f}')


train_vae(epochs=3, batch=256)  # keep small for demo


## Generator Policy over Latents (RL)
We freeze the decoder and learn a policy that selects among K latent codes per class to maximize human reward (thumbs up).


In [None]:
# Freeze the whole VAE (encoder and decoder) for the RL phase and switch it
# to evaluation mode.
vae.requires_grad_(False)
vae.eval()

K = 8  # number of latent options per class


class LatentPolicy(nn.Module):
    """Per-class categorical policy over K learnable latent codes.

    Attributes:
        logits: (num_classes, K) unnormalized selection scores, initialized to
            zero so every option starts out equally likely.
        latents: (num_classes, K, z_dim) candidate latent codes. They are not
            part of the sampled action's log-probability, so a REINFORCE loss
            built from `logp` alone sends them no gradient.
        baseline: (num_classes,) moving-average reward baseline. Registered as
            a buffer (not a Parameter) so it follows `.to(device)` and is
            included in `state_dict()` without being touched by the optimizer.
    """

    def __init__(self, num_classes=10, z_dim=16, K=8):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_classes, K))
        self.latents = nn.Parameter(0.1 * torch.randn(num_classes, K, z_dim))
        # A plain tensor attribute would be dropped by state_dict() and ignored
        # by .to(); a buffer keeps the baseline with the rest of the policy.
        self.register_buffer('baseline', torch.zeros(num_classes))

    def sample(self, c_idx):
        """Sample one latent option for class `c_idx`.

        Returns:
            z: the selected latent code, `self.latents[c_idx, a]`.
            a: the sampled option index (0-d tensor).
            logp: log-probability of `a`, differentiable w.r.t. `self.logits`.
        """
        # Build the distribution from logits directly rather than from softmax
        # probabilities, which avoids a separate log(softmax(.)) round trip.
        dist = Categorical(logits=self.logits[c_idx])
        a = dist.sample()
        z = self.latents[c_idx, a]
        logp = dist.log_prob(a)
        return z, a, logp

gen_policy = LatentPolicy(NUM_CLASSES, z_dim=vae.z_dim, K=K)
gen_policy = gen_policy.to(DEVICE)
opt_policy = torch.optim.Adam(gen_policy.parameters(), lr=5e-3)


def decode_logits(z, c):
    """Decode a single latent code `z` for digit class `c` (an int).

    Runs the VAE decoder without gradient tracking and returns the pixel
    logits flattened to one dimension.
    """
    onehot = torch.zeros(1, NUM_CLASSES, device=DEVICE)
    onehot[0, c] = 1.0
    latent_row = z.view(1, -1)
    with torch.no_grad():
        pixel_logits = vae.decode(latent_row, onehot)
    return pixel_logits.view(-1)

def show_img(vec, title=None):
    """Draw a flat 784-value tensor as a small 28x28 grayscale image.

    `title`, when truthy, is placed above the image.
    """
    pixels = vec.detach().cpu().numpy().reshape(28, 28)
    plt.figure(figsize=(2, 2))
    plt.axis('off')
    plt.imshow(pixels, cmap='gray')
    if title:
        plt.title(title)
    plt.show()

# Interactive widgets for generation: a digit selector, a generate button,
# thumbs-up / thumbs-down feedback buttons, and an output area for the image.
c_dropdown = widgets.Dropdown(
    options=[(str(digit), digit) for digit in range(10)],
    value=0,
    description='Digit:',
)
btn_generate = widgets.Button(description='Generate', button_style='')
btn_up = widgets.Button(description='👍', button_style='success')
btn_down = widgets.Button(description='👎', button_style='danger')
out_gen = widgets.Output()

# Most recent generation awaiting feedback, stored as
# (c, z, a, logp, img_probs); None until something has been generated.
state = dict(last=None)

def on_generate(_):
    """Click handler: sample a latent for the selected digit and display it.

    The sampled class, latent, option index, log-probability (still attached
    to the policy graph) and pixel probabilities are stored in `state['last']`
    so that a later feedback click can use them.
    """
    out_gen.clear_output()
    digit = int(c_dropdown.value)
    latent, choice, log_prob = gen_policy.sample(digit)
    pixel_probs = torch.sigmoid(decode_logits(latent, digit))
    caption = f'Generated digit {digit} (option {int(choice)})'
    with out_gen:
        show_img(pixel_probs, title=caption)
    state['last'] = (
        digit,
        latent.detach(),
        choice.detach(),
        log_prob,  # kept with its graph: the RL update backpropagates through it
        pixel_probs.detach(),
    )

def rl_update(reward):
    """Apply one REINFORCE step to the generator policy for the last sample.

    Args:
        reward: scalar human feedback for the image stored in `state['last']`.

    Does nothing when there is no pending sample. Each generated image can be
    rated once: the pending sample is consumed before the update runs.
    """
    item = state.get('last', None)
    if item is None:
        return
    # Consume the sample up front. Its log-prob graph is freed by backward(),
    # so a second click on the same image would otherwise raise a RuntimeError
    # ("backward through the graph a second time").
    state['last'] = None
    c_idx, _z, _a, logp, _probs = item
    # Simple moving average baseline per class to reduce variance. The
    # advantage uses the baseline value from before this reward is folded in.
    b = gen_policy.baseline[c_idx].item()
    gen_policy.baseline[c_idx] = 0.9 * gen_policy.baseline[c_idx] + 0.1 * reward
    adv = reward - b
    loss = -(adv) * logp
    opt_policy.zero_grad()
    loss.backward()
    opt_policy.step()

def on_up(_):
    """Thumbs-up handler: reward the last generated image with 1.0."""
    rl_update(1.0)


def on_down(_):
    """Thumbs-down handler: reward the last generated image with 0.0."""
    rl_update(0.0)


# Hook each button up to its handler, then render the controls and the output.
for _button, _handler in (
    (btn_generate, on_generate),
    (btn_up, on_up),
    (btn_down, on_down),
):
    _button.on_click(_handler)
display(widgets.HBox([c_dropdown, btn_generate, btn_up, btn_down]))
display(out_gen)


## Classifier + Human Feedback
The classifier predicts a digit for shown images. Provide feedback:
- If incorrect and true label is known, pick the correct label and we train with cross-entropy.
- Otherwise, we can apply a simple policy-gradient update with reward 1/0 for correct/incorrect on the sampled class.


In [None]:
# Simple classifier
class Classifier(nn.Module):
    """Two-layer MLP: input -> ReLU hidden layer -> one logit per class."""

    def __init__(self, input_dim=784, hidden=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return unnormalized class scores (logits) for the batch `x`."""
        activations = F.relu(self.fc1(x))
        return self.fc2(activations)

clf = Classifier(INPUT_DIM, 128, NUM_CLASSES).to(DEVICE)
opt_clf = torch.optim.Adam(clf.parameters(), lr=1e-3)


# Optionally warm-start the classifier a bit for nicer demos
def warmstart_clf(steps=200, batch=256):
    """Run `steps` supervised cross-entropy updates on random training batches.

    Each step draws `batch` indices uniformly with replacement from the
    training tensors and uses the arg-max of the one-hot targets as the label.
    """
    clf.train()
    n_train = train_X_t.size(0)
    done = 0
    while done < steps:
        picks = torch.randint(0, n_train, (batch,), device=DEVICE)
        class_ids = train_Y_t[picks].argmax(dim=-1)
        loss = F.cross_entropy(clf(train_X_t[picks]), class_ids)
        opt_clf.zero_grad()
        loss.backward()
        opt_clf.step()
        done += 1
    print('Warm-start complete')


warmstart_clf(steps=200, batch=256)

# Interactive classifier loop: controls for stepping through samples and for
# telling the classifier whether its sampled prediction was right.
btn_next = widgets.Button(description='Next sample')
btn_correct = widgets.Button(description='Correct', button_style='success')
btn_incorrect = widgets.Button(description='Incorrect', button_style='danger')
dd_true = widgets.Dropdown(
    options=[(str(digit), digit) for digit in range(10)],
    description='True:',
)
out_clf = widgets.Output()

# Currently displayed sample and the sampled prediction for it; every field
# is None until the first sample has been drawn.
clf_state = dict.fromkeys(('xb', 'yb', 'pred', 'logp', 'act'))

def draw_sample():
    """Pick one random validation example; return it as single-row (x, y) tensors."""
    n_valid = valid_X_t.size(0)
    row = np.random.randint(0, n_valid)
    window = slice(row, row + 1)  # slicing keeps the leading batch dimension
    return valid_X_t[window], valid_Y_t[window]

def show_clf(x, pred, prob):
    """Redraw the classifier output area with image `x` and its prediction.

    The title shows the predicted class `pred` and its probability `prob`
    formatted to two decimals.
    """
    with out_clf:
        clear_output(wait=True)
        plt.figure(figsize=(2, 2))
        plt.axis('off')
        image = x.view(28, 28).cpu()
        plt.imshow(image, cmap='gray')
        plt.title(f'Pred: {pred} (p={prob:.2f})')
        plt.show()

def on_next(_):
    """Click handler: draw a validation sample and show a sampled prediction.

    The class is sampled from the classifier's softmax rather than taken as
    the arg-max, and its log-probability is stored in `clf_state` so it can
    later be used for a policy-gradient update.
    """
    x, y = draw_sample()
    class_probs = F.softmax(clf(x), dim=-1)
    policy = Categorical(probs=class_probs)
    act = policy.sample()
    pred = int(act.item())
    clf_state.update(xb=x, yb=y, pred=pred, logp=policy.log_prob(act), act=act)
    show_clf(x, pred, float(class_probs[0, pred].item()))
    show_clf(x, pred, prob)

def on_correct(_):
    """Click handler: reinforce the sampled prediction with reward 1.

    Performs one REINFORCE step (loss = -log p(sampled class)) on the
    classifier. Does nothing when no prediction is pending. Each displayed
    prediction can be reinforced once.
    """
    item = clf_state
    logp = item['logp']
    if logp is None:
        return
    # Consume the stored log-prob. Its graph is freed by backward(), so a
    # second click on the same sample would otherwise raise a RuntimeError.
    item['logp'] = None
    # Reward 1 for sampled action
    loss = -1.0 * logp.sum()
    opt_clf.zero_grad()
    loss.backward()
    opt_clf.step()

def on_incorrect(_):
    """Click handler: the sampled prediction was wrong.

    When a true label is selected in `dd_true`, take one supervised
    cross-entropy step on the displayed sample. Without a label there is
    nothing to learn from: a reward of 0 with no baseline gives a zero
    REINFORCE gradient, so no update is made. Does nothing when no sample is
    being displayed.
    """
    item = clf_state
    if item['xb'] is None:
        return
    # Drop the stored log-prob of the wrong prediction. The optimizer step
    # below modifies the weights in place, which invalidates that graph; a
    # later "Correct" click on the same sample would otherwise fail in
    # backward() (and would be reinforcing a prediction just marked wrong).
    item['logp'] = None
    true_label = dd_true.value
    if true_label is None:
        # Zero reward (no update).
        return
    logits = clf(item['xb'])
    target = torch.tensor([int(true_label)], dtype=torch.long, device=DEVICE)
    loss = F.cross_entropy(logits, target)
    opt_clf.zero_grad()
    loss.backward()
    opt_clf.step()

# Hook each classifier button up to its handler, then render the controls
# followed by the output area.
for _button, _handler in (
    (btn_next, on_next),
    (btn_correct, on_correct),
    (btn_incorrect, on_incorrect),
):
    _button.on_click(_handler)
display(widgets.HBox([btn_next, btn_correct, btn_incorrect, dd_true]))
display(out_clf)


## Notes
- VAE pretraining is minimal; increase epochs for better quality.
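- Generator RL updates only the latent-selection policy logits, not the decoder. The latent codes are passed to the optimizer, but the REINFORCE loss depends only on the selection log-probability, so they receive no gradient and keep their initial values.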
- Classifier RL uses sampled actions for proper policy-gradient updates; cross-entropy is used when a true label is provided.
- For persistent improvements, consider saving/loading policy and classifier weights.
