# Energy-Based Model (EBM) on MNIST
This notebook implements a basic EBM using PyTorch and MNIST.

In [1]:
# Install and import dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

import torch.optim as optim

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
BATCH_SIZE = 128
LEARNING_RATE = 0.0001
num_epochs = 10


In [4]:
# ------------------------------------------
# Load the MNIST dataset in PyTorch
# ------------------------------------------

# Step 1: Define a transform to convert images to PyTorch tensors
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts images to range [0, 1] as float32 tensors
])

# Step 2: Download and load the training and test sets
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Step 3: Create DataLoaders for easy batch access
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [5]:
# ------------------------------------------
# Custom preprocessing function (if needed)
# ------------------------------------------

def preprocess_pytorch(images):
    """
    Normalize and reshape the images similar to your original TensorFlow preprocessing.

    - Normalize pixel values from [0, 255] → [-1, 1]
    - Pad from 28x28 → 32x32 with constant value -1
    - Add a channel dimension if not present
    """

    # Scale from [0, 1] to [-1, 1]
    images = images * 2 - 1

    # images: shape (batch_size, 1, 28, 28) → pad to (batch_size, 1, 32, 32)
    images = F.pad(images, pad=(2, 2, 2, 2), mode='constant', value=-1.0)

    return images

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),         # Converts image to [0, 1] tensor
    transforms.Lambda(preprocess_pytorch)  # Then normalize and pad
])

In [7]:
# Download MNIST dataset
mnist_data = datasets.MNIST(root='.', train=True, download=True)
x_train = mnist_data.data.numpy()

mnist_test = datasets.MNIST(root='.', train=False, download=True)
x_test = mnist_test.data.numpy()

In [8]:
# Convert numpy arrays to PyTorch tensors and add channel dimension
x_train_tensor = torch.tensor(x_train, dtype=torch.float32).unsqueeze(1)  # shape: [B, 1, 28, 28]
x_test_tensor = torch.tensor(x_test, dtype=torch.float32).unsqueeze(1)    # shape: [B, 1, 28, 28]

# Create TensorDatasets from tensors
train_dataset = TensorDataset(x_train_tensor)
test_dataset = TensorDataset(x_test_tensor)

# Wrap datasets in DataLoader to enable batching and shuffling
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
# Define a PyTorch EBM model similar to your TensorFlow implementation
class EBM(nn.Module):
    def __init__(self, image_size=28, channels=1):
        super(EBM, self).__init__()

        # Convolutional layers with Swish (SiLU) activations
        self.conv1 = nn.Conv2d(channels, 16, kernel_size=5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)

        # Calculate the flattened output size after 4 conv layers
        conv_output_size = image_size // (2**4)  # 4 strides of 2
        flattened_dim = 64 * conv_output_size * conv_output_size

        # Dense layers
        self.fc1 = nn.Linear(flattened_dim, 64)
        self.fc2 = nn.Linear(64, 1)  # Single energy output

    def forward(self, x):
        # Swish activation is available as F.silu in PyTorch
        x = F.silu(self.conv1(x))
        x = F.silu(self.conv2(x))
        x = F.silu(self.conv3(x))
        x = F.silu(self.conv4(x))

        x = x.view(x.size(0), -1)  # Flatten for dense layers
        x = F.silu(self.fc1(x))
        energy = self.fc2(x)  # Output energy (unnormalized score)
        return energy

In [10]:
def generate_samples(
    model,                 # The energy-based model
    inp_imgs,              # Initial images (random noise or seeds)
    steps,                 # Number of Langevin steps
    step_size,             # Step size (learning rate)
    noise,                 # Stddev of added Gaussian noise
    return_img_per_step=False,  # Whether to save images at each step
):
    imgs_per_step = []

    inp_imgs = inp_imgs.clone().detach().to(device).requires_grad_(True)

    for _ in range(steps):
        # Step 1: Add Gaussian noise to encourage exploration
        inp_imgs.data += torch.randn_like(inp_imgs) * noise

        # Step 2: Clamp values to stay in [-1, 1] range (MNIST normalized)
        inp_imgs.data = torch.clamp(inp_imgs.data, -1.0, 1.0)

        # Step 3: Forward pass to compute score (energy)
        out_score = model(inp_imgs)

        # Step 4: Compute gradient of score w.r.t. input image
        grads = torch.autograd.grad(outputs=out_score.sum(), inputs=inp_imgs)[0]

        # Step 5: Clip gradients for stability
        grads = torch.clamp(grads, -GRADIENT_CLIP, GRADIENT_CLIP)

        # Step 6: Gradient ascent step on input image
        inp_imgs.data += step_size * grads

        # Step 7: Clamp again to stay in valid range
        inp_imgs.data = torch.clamp(inp_imgs.data, -1.0, 1.0)

        if return_img_per_step:
            imgs_per_step.append(inp_imgs.detach().clone())

    if return_img_per_step:
        return torch.stack(imgs_per_step, dim=0)
    else:
        return inp_imgs.detach()

In [11]:
class EBM(nn.Module):
    def __init__(self, base_model, alpha=0.1):
        super().__init__()
        self.model = base_model  # scoring network
        self.alpha = alpha       # regularization weight

    def forward(self, x):
        return self.model(x).squeeze()

def compute_loss(model, real_imgs, steps, step_size, noise_scale):
    """
    Contrastive divergence loss between real data and fake (noise) samples.
    """
    batch_size = real_imgs.size(0)

    # Generate fake images from random noise
    fake_imgs = torch.empty_like(real_imgs).uniform_(-1, 1).to(real_imgs.device)
    fake_imgs.requires_grad = True

    # Langevin dynamics steps (optional: here 0 steps means no update)
    for _ in range(steps):
        fake_imgs.data += torch.randn_like(fake_imgs) * noise_scale
        fake_imgs.data = torch.clamp(fake_imgs.data, -1, 1)

        energy = model(fake_imgs)
        grads = torch.autograd.grad(energy.sum(), fake_imgs, create_graph=True)[0]
        fake_imgs.data += step_size * grads
        fake_imgs.data = torch.clamp(fake_imgs.data, -1, 1)

    # Get scores
    real_scores = model(real_imgs)
    fake_scores = model(fake_imgs.detach())

    # Contrastive Divergence (CD-1) Loss
    cdiv_loss = fake_scores.mean() - real_scores.mean()

    # Regularization: penalize high scores
    reg_loss = model.alpha * ((real_scores ** 2).mean() + (fake_scores ** 2).mean())

    total_loss = cdiv_loss + reg_loss

    return total_loss, cdiv_loss.item(), reg_loss.item(), real_scores.mean().item(), fake_scores.mean().item()

In [12]:
class ScoreNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # [batch, 32, 14, 14]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # [batch, 64, 7, 7]
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # Final scalar output (energy score)
        )

    def forward(self, x):
        return self.net(x)

In [13]:
base_model = ScoreNet()
ebm = EBM(base_model=base_model, alpha=0.1)

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ebm = ebm.to(device)  # Move the model to the device

In [17]:
optimizer = optim.Adam(ebm.parameters(), lr=LEARNING_RATE)

In [18]:
for epoch in range(num_epochs):
    for batch in train_loader:
        if isinstance(batch, (list, tuple)):
            real_images = batch[0]
        else:
            real_images = batch

        real_images = real_images.to(device)
        real_images.requires_grad = True

        # Forward pass
        scores = ebm(real_images)
        loss = -scores.mean()  # Basic negative log-score loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")

Epoch 1/10, Loss: -1949711.2500
Epoch 2/10, Loss: -34621388.0000
Epoch 3/10, Loss: -168554992.0000
Epoch 4/10, Loss: -460923584.0000
Epoch 5/10, Loss: -974023040.0000
Epoch 6/10, Loss: -1813268352.0000
Epoch 7/10, Loss: -2972573696.0000
Epoch 8/10, Loss: -4526898176.0000
Epoch 9/10, Loss: -6693363200.0000
Epoch 10/10, Loss: -9541715968.0000
