# What Happens When You Train a CNN to Think like an LLM
## Understanding DPO with Digit Classification

Direct Preference Optimization (DPO) has been the go-to technique for aligning LLMs. It's elegant, it works, and it's everywhere. But most discussions live in the world of transformers, tokenization, and text generation.

**The Experiment:**
I wanted to strip away all that complexity and see the DPO algorithm in its purest form. So, I trained a **Convolutional Neural Net (CNN)** with DPO to identify handwritten digits (MNIST).

**What is DPO Really?**
Imagine you're teaching a kid to identify animals. Instead of just saying "this is correct" or "this is wrong," you show them pairs: "This fluffy thing is more likely a dog than a cat" or "This one with stripes is definitely a zebra, not a horse."

That's DPO in a nutshell. Rather than training on absolute labels, you're training on **preferences**. You're telling the model: *"Between these two options, prefer this one over that one."*
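Under the hood, DPO assumes preferences follow the Bradley-Terry model: the probability that option $y_w$ is preferred over $y_l$ depends only on the difference between their latent scores, passed through a sigmoid. Here is a minimal sketch in plain Python. The scores are made-up numbers for illustration.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))

def preference_probability(score_w: float, score_l: float) -> float:
    """Bradley-Terry probability that option w is preferred over option l."""
    return sigmoid(score_w - score_l)

# Hypothetical scores: how "dog-like" vs "cat-like" the fluffy animal looks
p_dog_over_cat = preference_probability(2.0, 0.5)
p_cat_over_dog = preference_probability(0.5, 2.0)
print(f"P(dog > cat) = {p_dog_over_cat:.3f}")
print(f"P(cat > dog) = {p_cat_over_dog:.3f}")
```

Equal scores give 0.5, and the two directions always sum to 1. DPO's loss is the negative log of this probability, with each option's score replaced by $\beta$ times the log-probability ratio between the policy and a reference model.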

In [None]:
# 1. Setup & Imports
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.metrics import accuracy_score, f1_score, ConfusionMatrixDisplay
from tqdm.auto import tqdm
import tensorflow as tf # For loading original MNIST easily
import os

# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. The Data (MNIST)
MNIST is the data equivalent of "Hello World" in machine learning: 70,000 grayscale 28×28 images of handwritten digits (0–9), split into 60,000 training and 10,000 test images.

In standard MNIST training, the task is simple: *"Look at an image and tell me which digit it is."*

**Our DPO training is different:** *"For this image, increase the probability of the preferred label relative to the rejected one."*
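For a classifier, the "probability of a label" is the network's softmax output: $\pi(y \mid x)$ is the probability the model assigns to digit $y$. This minimal NumPy sketch uses made-up logits in place of the CNN's 10 outputs.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

# Hypothetical logits for one noisy image (10 digit classes)
logits = np.array([0.1, 2.5, 0.0, 0.3, -1.0, 0.2, 0.0, 2.0, 0.4, -0.5])
log_probs = log_softmax(logits)

y_w, y_l = 7, 1  # preferred label (ground truth) and rejected label (a mistake)
print(f"log pi(y_w|x) = {log_probs[y_w]:.3f}")
print(f"log pi(y_l|x) = {log_probs[y_l]:.3f}")
# DPO pushes the gap log pi(y_w|x) - log pi(y_l|x) upward, measured relative to the reference model
```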

In [None]:
# Load standard MNIST using TensorFlow/Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Normalize pixel values to [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

print(f"Training data shape: {x_train.shape}")
print(f"Test data shape: {x_test.shape}")

# Helper function to plot images
def plot_mnist_image(index, dataset, title=None):
    """Plots a single MNIST image from a numpy array or torch tensor."""
    image = dataset[index]
    if isinstance(image, torch.Tensor):
        image = image.squeeze().cpu().numpy()
    elif isinstance(image, np.ndarray) and image.ndim == 3:
        image = image.squeeze()
    
    plt.imshow(image, cmap='gray')
    if title:
        plt.title(title)
    else:
        plt.title(f"Index: {index}")
    plt.axis('off')
    plt.show()

### Corrupting the Data
To make this interesting, we need "hard" examples. We will synthetically corrupt the MNIST images by adding Gaussian noise.

What was once a crisp "8" will now look like it went through a blender.
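We can put a number on the "blender". Because the noisy pixels are clipped back to [0, 1], the noise does not average out. A black background pixel (0) can only move up, and a white stroke pixel (1) can only move down, so contrast shrinks. This quick Monte Carlo sketch uses the same `noise_factor = 0.5` as the next cell. The seed and sample size are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed for reproducibility
noise_factor = 0.5
n_samples = 200_000

def mean_after_noise(pixel_value: float) -> float:
    """Average value of a pixel after adding N(0, noise_factor^2) noise and clipping to [0, 1]."""
    noisy = pixel_value + noise_factor * rng.normal(0.0, 1.0, size=n_samples)
    return float(np.clip(noisy, 0.0, 1.0).mean())

background = mean_after_noise(0.0)  # black background pixel
stroke = mean_after_noise(1.0)      # white pen-stroke pixel
print(f"background: 0.0 -> {background:.3f}")
print(f"stroke:     1.0 -> {stroke:.3f}")
print(f"contrast:   1.0 -> {stroke - background:.3f}")
```

Analytically, the background mean comes out near 0.195, and by symmetry the stroke mean is near 0.805. On average, the digit-to-background contrast therefore drops from 1.0 to roughly 0.61, on top of the per-pixel randomness.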

In [None]:
# Generate Synthetic Noisy Data
print("Generating noisy data...")

# Add Gaussian noise
noise_factor = 0.5
aug_mnist_noise = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
aug_mnist_noise_test = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)

# Clip to valid range [0, 1]
aug_mnist_noise = np.clip(aug_mnist_noise, 0., 1.)
aug_mnist_noise_test = np.clip(aug_mnist_noise_test, 0., 1.)

# Labels remain the same
aug_labels = y_train
aug_test_labels = y_test

print(f"Augmented Noise Data Shape: {aug_mnist_noise.shape}")

# Visualize a noisy sample
plot_mnist_image(0, aug_mnist_noise, title=f"Noisy Label: {aug_labels[0]}")

## 3. The Key Components
To apply DPO to any model, you need three key ingredients:

1.  **The Policy (Main Model) $\pi_\theta$**: This is your student model, the one whose weights DPO updates. In our case, it's a CNN initialized as an exact copy of the reference model. From there, it learns from the preference pairs.
2.  **The Reference Model $\pi_{ref}$**: Think of this as your anchor point. It's a frozen, pre-trained model that provides a baseline. The policy model learns to deviate from this reference in the "right" direction.
3.  **Preferences**: Triples $(x, y_w, y_l)$ where, for the same input $x$, one output is preferred over the other.
    *   $y_w$: the preferred choice (winner).
    *   $y_l$: the rejected choice (loser).
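In this classification setting, a single preference example is just an image plus two labels. The sketch below stores that record as a dataclass. The name `PreferencePair` and its validation rules are illustrative assumptions and not part of the notebook's code, which keeps the same triples in parallel tensors.

```python
from dataclasses import dataclass
import numpy as np

NUM_CLASSES = 10  # digits 0-9

@dataclass(frozen=True)
class PreferencePair:
    """Hypothetical container for one DPO example (x, y_w, y_l)."""
    image: np.ndarray  # the input x, e.g. a 28x28 noisy digit
    y_w: int           # preferred label (winner)
    y_l: int           # rejected label (loser)

    def __post_init__(self):
        # A preference only carries information if the two choices differ
        if self.y_w == self.y_l:
            raise ValueError("y_w and y_l must be different labels")
        for y in (self.y_w, self.y_l):
            if not 0 <= y < NUM_CLASSES:
                raise ValueError(f"label {y} is outside 0..{NUM_CLASSES - 1}")

pair = PreferencePair(image=np.zeros((28, 28), dtype=np.float32), y_w=7, y_l=1)
print(pair.y_w, pair.y_l, pair.image.shape)
```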

### Step 1: Train the Reference Model ($\pi_{ref}$)
We train a standard CNN on clean MNIST. Nothing fancy: just a single convolutional layer (64 filters of size 3×3) followed by a linear classifier head.
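The `64 * 28 * 28` in the classifier head comes from a short calculation. A 3×3 convolution with `padding=1` and stride 1 preserves the 28×28 spatial size, so flattening 64 feature maps gives 50,176 features. This small sketch reproduces the shapes and parameter counts by hand, using pure arithmetic with no PyTorch.

```python
def conv2d_output_size(size: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Spatial output size of a 2D convolution: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

in_channels, out_channels, kernel, padding = 1, 64, 3, 1
side = conv2d_output_size(28, kernel, padding)  # stays 28
flat_features = out_channels * side * side      # input size of the linear head

conv_params = out_channels * in_channels * kernel * kernel + out_channels  # weights + biases
fc_params = flat_features * 10 + 10                                        # weights + biases
print(f"feature map: {out_channels}x{side}x{side} -> {flat_features} features")
print(f"conv params: {conv_params}, fc params: {fc_params}, total: {conv_params + fc_params}")
```

Almost all of the roughly 0.5M parameters sit in the linear head. That's because no pooling layer shrinks the feature maps before flattening.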

In [None]:
class PyTorchRefModel(nn.Module):
    def __init__(self):
        super(PyTorchRefModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64 * 28 * 28, 10) # 64 channels * 28x28 image size

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Instantiate the model
ref_model = PyTorchRefModel().to(device)
print("Reference model initialized.")

# Prepare DataLoaders for Clean Data
x_train_tensor = torch.tensor(x_train).unsqueeze(1).float()
y_train_tensor = torch.tensor(y_train).long()
train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=64, shuffle=True)

x_test_tensor = torch.tensor(x_test).unsqueeze(1).float()
y_test_tensor = torch.tensor(y_test).long()
test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=64, shuffle=False)

# Training Function
def train_model(model, loader, epochs=2):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1} Loss: {running_loss/len(loader):.4f}")

# Train the Reference Model
print("Training Reference Model on Clean Data...")
train_model(ref_model, train_loader, epochs=2)

## 4. Manufacturing Preferences
To make this work without human labeling, we need a strategy to generate preference pairs ($y_w$, $y_l$).

**The Strategy:**
1.  Feed the **corrupted images** to the Reference Model ($\pi_{ref}$).
2.  Because the images are so noisy, the Reference Model will often misclassify them, sometimes with high confidence.
3.  For every corrupted image where $\pi_{ref}$ fails, we create a pair:
    *   **$y_w$ (Winner)**: The Ground Truth label (what the digit really is).
    *   **$y_l$ (Loser)**: The Reference Model's incorrect prediction.

**The Instruction:** "On this noisy image, prefer the ground truth ($y_w$) over the reference model's mistake ($y_l$)."
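The pair-building rule fits in a few lines of NumPy. This minimal sketch runs on toy arrays (the predictions and labels are made up) and mirrors the masking logic of the next cell.

```python
import numpy as np

def build_preference_pairs(ref_preds: np.ndarray, labels: np.ndarray):
    """Return (indices, y_w, y_l) for every example the reference model got wrong."""
    wrong = ref_preds != labels      # boolean mask of mistakes
    indices = np.flatnonzero(wrong)  # which images become DPO samples
    y_w = labels[wrong]              # winner: ground-truth label
    y_l = ref_preds[wrong]           # loser: the reference model's mistake
    return indices, y_w, y_l

labels    = np.array([7, 2, 1, 0, 4, 1])
ref_preds = np.array([1, 2, 1, 6, 9, 1])  # toy reference predictions on noisy images
idx, y_w, y_l = build_preference_pairs(ref_preds, labels)
print(idx, y_w, y_l)
```

Correctly classified images produce no pair at all. As a result, the DPO dataset is only as large as the reference model's error count on the noisy training set.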

In [None]:
# Prepare Noisy Data Loader
aug_images = aug_mnist_noise

aug_tensor = torch.tensor(aug_images).unsqueeze(1).float()
aug_labels_tensor = torch.tensor(aug_labels).long()
aug_loader = DataLoader(TensorDataset(aug_tensor, aug_labels_tensor), batch_size=64, shuffle=False)

# Run Inference to find mistakes
ref_model.eval()
all_preds = []
print("Running inference on noisy data to generate preference pairs...")
with torch.no_grad():
    for inputs, _ in aug_loader:
        inputs = inputs.to(device)
        outputs = ref_model(inputs)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())

all_preds = np.array(all_preds)
incorrect_mask = all_preds != aug_labels

# Create DPO Training Data
dpo_images = aug_tensor[incorrect_mask]
dpo_y_plus = aug_labels_tensor[incorrect_mask]       # Chosen (Correct)
dpo_y_minus = torch.tensor(all_preds[incorrect_mask]).long() # Rejected (Wrong)

print(f"Total Noisy Images: {len(aug_images)}")
print(f"Incorrect Predictions (DPO Samples): {len(dpo_images)}")

# Visualize a mistake
if len(dpo_images) > 0:
    idx = 0
    plot_mnist_image(idx, dpo_images, title=f"True: {dpo_y_plus[idx].item()}, Pred: {dpo_y_minus[idx].item()}")

## 5. DPO Training
Now we train the Policy Model using the DPO loss function.

### The DPO Formula
Don't let the math scare you. Here is what it actually says:

$$ \mathcal{L}_{DPO}(\pi_\theta; \pi_{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)} \right) \right] $$

**Translation:** "Make the Policy Model assign a higher probability to the Winner ($y_w$) and a lower probability to the Loser ($y_l$) *relative* to what the Reference Model did."
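Here is the formula in action without a network: a minimal NumPy version of the loss computed from raw logits, mirroring the `dpo_loss` function in the training cell. The logits are made up. The key sanity check is that when the policy is an exact copy of the reference (as it is at the start of training here), every log-ratio is zero and the loss equals $\log 2 \approx 0.693$.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the class axis (axis 1)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def dpo_loss_np(policy_logits, ref_logits, y_w, y_l, beta=0.1):
    """Mean DPO loss over a batch; logits have shape (batch, num_classes)."""
    rows = np.arange(len(y_w))
    pol, ref = log_softmax(policy_logits), log_softmax(ref_logits)
    chosen_ratio = pol[rows, y_w] - ref[rows, y_w]    # log pi_theta(y_w|x) - log pi_ref(y_w|x)
    rejected_ratio = pol[rows, y_l] - ref[rows, y_l]  # log pi_theta(y_l|x) - log pi_ref(y_l|x)
    z = beta * (chosen_ratio - rejected_ratio)
    # -log(sigmoid(z)) == log(1 + exp(-z)), computed stably with logaddexp
    return float(np.mean(np.logaddexp(0.0, -z)))

rng = np.random.default_rng(0)
ref_logits = rng.normal(size=(4, 10))  # made-up reference logits for 4 images
y_w = np.array([7, 0, 3, 5])           # winners
y_l = np.array([1, 6, 8, 3])           # losers

# Policy identical to the reference: every log-ratio is 0, so the loss is log(2)
loss_same = dpo_loss_np(ref_logits, ref_logits, y_w, y_l)

# Policy that raises each winner's logit by 2: the margin becomes exactly 2, so the loss drops
better_logits = ref_logits.copy()
better_logits[np.arange(4), y_w] += 2.0
loss_better = dpo_loss_np(better_logits, ref_logits, y_w, y_l)
print(f"identical policy: {loss_same:.4f}, boosted winners: {loss_better:.4f}")
```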

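The variable $\beta$ (sometimes called a **temperature parameter**) controls how tight the leash is. It scales the log-ratio terms and plays the role of the KL-penalty strength in the RLHF objective that DPO is derived from. A larger $\beta$ keeps the policy closer to the reference model's behavior, while a smaller $\beta$ lets it wander further. The `dpo_loss` function below uses $\beta = 0.1$ by default.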
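One way to see the leash is through the log-ratio margin $m = \log\frac{\pi_\theta(y_w|x)}{\pi_{ref}(y_w|x)} - \log\frac{\pi_\theta(y_l|x)}{\pi_{ref}(y_l|x)}$. The loss depends only on $\beta m$. Solving $-\log\sigma(\beta m) = L$ for $m$ gives $m = -\log(e^{L}-1)/\beta$, so the margin needed to reach a given loss scales like $1/\beta$. In the small sketch below, the target loss of 0.1 is an arbitrary choice.

```python
import math

def margin_needed(target_loss: float, beta: float) -> float:
    """Log-ratio margin m at which -log(sigmoid(beta * m)) equals target_loss."""
    if not 0.0 < target_loss < math.log(2.0):
        raise ValueError("target_loss must lie in (0, log 2) for a positive margin")
    return -math.log(math.expm1(target_loss)) / beta

for beta in (0.1, 0.5, 1.0):
    print(f"beta={beta}: margin needed for loss 0.1 = {margin_needed(0.1, beta):.2f} nats")
```

With $\beta = 0.1$, the policy must open a gap of roughly 22.5 nats between its winner and loser log-ratios to push the loss that low. With $\beta = 1$, a gap of about 2.25 suffices. In this sense, a larger $\beta$ reaches a low loss with the policy still close to the reference.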

In [None]:
# DPO Dataset Class
class DPODataset(Dataset):
    def __init__(self, images, y_plus, y_minus):
        self.images = images
        self.y_plus = y_plus
        self.y_minus = y_minus
    def __len__(self): return len(self.images)
    def __getitem__(self, idx): return self.images[idx], self.y_plus[idx], self.y_minus[idx]

dpo_loader = DataLoader(DPODataset(dpo_images, dpo_y_plus, dpo_y_minus), batch_size=64, shuffle=True)

# Initialize Policy Model (Clone of Ref Model)
policy_model = PyTorchRefModel().to(device)
policy_model.load_state_dict(ref_model.state_dict())

# DPO Loss Function
def dpo_loss(policy_logits, ref_logits, y_chosen, y_rejected, beta=0.1):
    # Convert logits to log probabilities
    policy_log_probs = F.log_softmax(policy_logits, dim=1)
    ref_log_probs = F.log_softmax(ref_logits, dim=1)
    
    # Gather log probs for chosen and rejected classes
    policy_chosen_log_prob = policy_log_probs.gather(1, y_chosen.unsqueeze(1)).squeeze()
    policy_rejected_log_prob = policy_log_probs.gather(1, y_rejected.unsqueeze(1)).squeeze()
    
    ref_chosen_log_prob = ref_log_probs.gather(1, y_chosen.unsqueeze(1)).squeeze()
    ref_rejected_log_prob = ref_log_probs.gather(1, y_rejected.unsqueeze(1)).squeeze()
    
    # Calculate logits for the sigmoid
    # (log(pi(yw)/ref(yw)) - log(pi(yl)/ref(yl)))
    logits = beta * ((policy_chosen_log_prob - ref_chosen_log_prob) - (policy_rejected_log_prob - ref_rejected_log_prob))
    
    # Loss is negative log sigmoid
    loss = -F.logsigmoid(logits).mean()
    return loss

# DPO Training Loop
optimizer = optim.Adam(policy_model.parameters(), lr=0.001)
print("Starting DPO Training...")

ref_model.eval()                 # Reference model stays in eval mode...
ref_model.requires_grad_(False)  # ...and frozen: DPO never updates pi_ref
policy_model.train()

for epoch in range(5):
    running_loss = 0.0
    for inputs, y_c, y_r in tqdm(dpo_loader, desc=f"DPO Epoch {epoch+1}"):
        inputs, y_c, y_r = inputs.to(device), y_c.to(device), y_r.to(device)
        
        optimizer.zero_grad()
        policy_logits = policy_model(inputs)
        
        # Get ref logits (no grad)
        with torch.no_grad():
            ref_logits = ref_model(inputs)
            
        loss = dpo_loss(policy_logits, ref_logits, y_c, y_r)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1} DPO Loss: {running_loss/len(dpo_loader):.4f}")

## 6. Results and Key Insights
We compare the accuracy of the original Reference Model and the new Policy Model on the noisy test set.

**Why does this matter?**
You might ask: *"Why not just train a supervised model on the noisy data?"*

Here is the critical distinction:
*   **Supervised Learning** asks: *"What is the truth?"* It tries to maximize the probability of the correct digit (e.g., 7).
*   **DPO** asks: *"What is preferred?"* It explicitly tries to widen the gap between the correct digit (7) and the *specific error* the reference model makes (e.g., 1).

In objective tasks like digit classification, "truth" is all you need. But in subjective tasks (like LLM responses), "truth" is fuzzy, and "preference" is often the most practical signal we have.
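The "specific error" point can be made precise. For a softmax classifier, let $z$ be the argument of the sigmoid and $e_k$ a one-hot vector. The gradient of the DPO loss with respect to the policy logits $s$ is then $-\beta\,\sigma(-z)\,(e_{y_w} - e_{y_l})$: the log-sum-exp terms cancel, so only the winner and loser logits receive a gradient. Cross-entropy's gradient, $p - e_y$, instead pushes down every incorrect class in proportion to its probability. This NumPy sketch checks the DPO formula with finite differences. All numbers are made up.

```python
import numpy as np

def log_softmax(s: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax of a single logit vector."""
    shifted = s - s.max()
    return shifted - np.log(np.exp(shifted).sum())

def dpo_loss_single(s, ref_log_probs, y_w, y_l, beta):
    """DPO loss for one example as a function of the policy logits s; also returns z."""
    lp = log_softmax(s)
    z = beta * ((lp[y_w] - ref_log_probs[y_w]) - (lp[y_l] - ref_log_probs[y_l]))
    return np.logaddexp(0.0, -z), z  # -log(sigmoid(z)) and z

rng = np.random.default_rng(1)
s = rng.normal(size=10)                           # policy logits
ref_log_probs = log_softmax(rng.normal(size=10))  # frozen reference
y_w, y_l, beta = 7, 1, 0.1

loss, z = dpo_loss_single(s, ref_log_probs, y_w, y_l, beta)
sigma_neg_z = 1.0 / (1.0 + np.exp(z))             # sigmoid(-z)
analytic = np.zeros(10)
analytic[y_w] = -beta * sigma_neg_z
analytic[y_l] = +beta * sigma_neg_z

# Central finite differences on every logit
eps = 1e-6
numeric = np.array([
    (dpo_loss_single(s + eps * np.eye(10)[k], ref_log_probs, y_w, y_l, beta)[0]
     - dpo_loss_single(s - eps * np.eye(10)[k], ref_log_probs, y_w, y_l, beta)[0]) / (2 * eps)
    for k in range(10)
])
ce_grad = np.exp(log_softmax(s)) - np.eye(10)[y_w]  # cross-entropy gradient w.r.t. logits

print("max |analytic - numeric|:", np.abs(analytic - numeric).max())
print("DPO nonzero gradient entries:", np.flatnonzero(np.abs(numeric) > 1e-8))
print("cross-entropy nonzero entries:", np.count_nonzero(np.abs(ce_grad) > 1e-8))
```

Each pair pushes only the winner and loser logits, by equal and opposite amounts. The push is scaled by $\sigma(-z)$, so it fades once the pair is ordered correctly relative to the reference.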

In [None]:
def evaluate(model, loader, name="Model"):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = correct / total
    print(f"{name} Accuracy on Noisy Test Data: {acc:.4f}")
    return acc

# Prepare Noisy Test Loader
aug_test_images = aug_mnist_noise_test
aug_test_tensor = torch.tensor(aug_test_images).unsqueeze(1).float()
aug_test_labels_tensor = torch.tensor(aug_test_labels).long()
aug_test_loader = DataLoader(TensorDataset(aug_test_tensor, aug_test_labels_tensor), batch_size=64, shuffle=False)

# Compare
print("--- Results ---")
evaluate(ref_model, aug_test_loader, "Reference Model (Baseline)")
evaluate(policy_model, aug_test_loader, "Policy Model (DPO Trained)")

## Final Words
This experiment wasn't really about handwritten digits. It was about demystifying one of the algorithms used to align today's most capable language models.

We showed that **DPO isn't some magic that only lives inside a Transformer**. It is a model-agnostic training objective: all it needs is a way to compute $\log \pi(y \mid x)$ for the policy and the reference model.
*   In our case, we used it to help a CNN navigate **visual noise** (additive Gaussian noise).
*   In the real world, researchers use the same objective to help LLMs navigate **semantic noise** (hallucinations, toxicity, and style errors).
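Moving from this notebook to LLMs takes a single change in how $\log \pi(y \mid x)$ is computed. Here $y$ is one of ten digits. For a language model, $y$ is a whole response, and its log-probability is the sum of the per-token log-probabilities given the prompt and the preceding tokens. Those sums plug into the same DPO loss. In this toy NumPy sketch, random logits stand in for a model's outputs over a made-up 5-token vocabulary.

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the vocabulary axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def sequence_log_prob(step_logits: np.ndarray, response_tokens: np.ndarray) -> float:
    """Sum of log pi(token_t | prompt, tokens_<t) over the response.

    step_logits has shape (response_length, vocab_size): row t holds the model's
    logits for position t, already conditioned on the prompt and earlier tokens.
    """
    positions = np.arange(len(response_tokens))
    token_log_probs = log_softmax(step_logits)[positions, response_tokens]
    return float(token_log_probs.sum())

rng = np.random.default_rng(0)
vocab_size = 5
chosen = np.array([2, 4, 1])       # hypothetical preferred response (token ids)
rejected = np.array([2, 0, 3, 3])  # hypothetical rejected response

# Random stand-ins for per-position logits (a real LM would produce these)
logp_chosen = sequence_log_prob(rng.normal(size=(len(chosen), vocab_size)), chosen)
logp_rejected = sequence_log_prob(rng.normal(size=(len(rejected), vocab_size)), rejected)
print(f"log pi(y_w|x) = {logp_chosen:.3f}, log pi(y_l|x) = {logp_rejected:.3f}")
```

The usual recipe for applying DPO to language models computes these sums under both the policy and a frozen reference model, then feeds the four values into the loss.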

By stripping away the complexity of tokenizers and billions of parameters, we can see DPO for what it truly is: **A way to teach a model not just what is right, but specifically what is wrong.**