# Energy-Based Models (EBM) - MNIST Generation and Classification

This notebook walks through the Energy-Based Model (EBM) project end to end. It covers loading the data, defining the model, and training it. It then visualizes both real MNIST digits and samples generated with Langevin dynamics.

## 1. Setup and Configuration

First, we'll set up the environment by importing necessary libraries and loading the project's configuration. The `config.py` file centralizes all hyperparameters, making it easy to experiment with different settings.
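The notebook relies on `config.py` without showing it. Below is a rough picture of what such a file might contain: a hypothetical `Config` written as a dataclass, listing only the attributes this notebook reads.

- All values are placeholders, not the project's actual settings.
- `INPUT_DIM = IMAGE_SIZE * IMAGE_SIZE` assumes `EnergyNet` is an MLP over flattened grayscale images.
- Device selection prefers CUDA, then Apple MPS, then CPU. It falls back to CPU if PyTorch is not importable.

```python
from dataclasses import dataclass, field
from typing import List


def pick_device() -> str:
    """Return "cuda", "mps", or "cpu" (hypothetical helper, not from the project)."""
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)  # absent on older PyTorch builds
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


@dataclass
class Config:
    # Paths (placeholder values)
    DATA_ROOT: str = "./data"
    FIGURE_SAVE_PATH: str = "./figures"
    MODEL_SAVE_PATH: str = "./models"
    # Data
    IMAGE_SIZE: int = 28
    BATCH_SIZE: int = 128
    NUM_CLASSES: int = 10
    # Model; a list default needs default_factory so instances don't share it
    HIDDEN_DIMS: List[int] = field(default_factory=lambda: [512, 256])
    # Optimization
    EPOCHS: int = 10
    LEARNING_RATE: float = 1e-4
    # Langevin dynamics
    LANGEVIN_STEPS: int = 20
    LANGEVIN_ALPHA: float = 1.0
    LANGEVIN_SIGMA: float = 0.01
    VISUALIZATION_FREQ: int = 1
    # Resolved once per instance
    DEVICE: str = field(default_factory=pick_device)

    @property
    def INPUT_DIM(self) -> int:
        # Assumes the network sees a flattened single-channel image
        return self.IMAGE_SIZE * self.IMAGE_SIZE
```

With a dataclass, per-run overrides are explicit, for example `Config(EPOCHS=5, LANGEVIN_STEPS=60)`, and you don't need to edit the file. The notebook doesn't show whether the project's real `Config` accepts keyword arguments.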

In [None]:
import torch
import torch.optim as optim
import logging
import os
import matplotlib.pyplot as plt

# Make the project root importable (assumes this notebook sits one directory below it)
import sys
sys.path.insert(0, os.path.abspath('..'))

from config import Config
from src.data_loader import get_mnist_loaders
from src.models import EnergyNet
from src.training import train_ebm, eval_ebm, loss_function
from src.utils import visualize_real, visualize_generated, energy_gradient, langevin_dynamics_step, sample

# Set up console logging (optional; main.py also logs to a file)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

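# Initialize configuration
config = Config()
logger.info(f"Current device: {config.DEVICE}.")

# Create output directories so the save calls below don't fail on a fresh checkout
os.makedirs(config.FIGURE_SAVE_PATH, exist_ok=True)
os.makedirs(config.MODEL_SAVE_PATH, exist_ok=True)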

## 2. Load Dataset

We'll load the MNIST dataset, apply transformations (resizing and normalization to `[-1, 1]`), and create `DataLoader` instances for training and testing.
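`get_mnist_loaders` is not shown here. The usual torchvision recipe for this preprocessing is `transforms.Compose([transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])`; that this project uses it is an assumption.

- `ToTensor` divides 8-bit pixels by 255.
- `Normalize` with mean 0.5 and std 0.5 then computes `(x - 0.5) / 0.5`.

Here is the same arithmetic in plain Python, with hypothetical helper names:

```python
def to_model_range(pixel: int) -> float:
    """Map an 8-bit pixel (0..255) to [-1, 1], as ToTensor + Normalize(0.5, 0.5) would."""
    if not 0 <= pixel <= 255:
        raise ValueError("pixel must be in [0, 255]")
    x = pixel / 255.0        # ToTensor: [0, 255] -> [0, 1]
    return (x - 0.5) / 0.5   # Normalize: [0, 1] -> [-1, 1]


def to_display_range(value: float) -> float:
    """Clamp to [-1, 1], then map back to [0, 1] for plotting."""
    value = min(max(value, -1.0), 1.0)
    return (value + 1.0) / 2.0
```

The clamp in `to_display_range` matters for generated images. Langevin updates are not constrained to the training range, so sampled pixel values can drift slightly outside `[-1, 1]`.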

In [None]:
logger.info("Loading MNIST data...")
trainloader, testloader = get_mnist_loaders(
    root=config.DATA_ROOT,
    image_size=config.IMAGE_SIZE,
    batch_size=config.BATCH_SIZE,
    device=config.DEVICE
)
logger.info(f"MNIST data loaded. Train: {len(trainloader.dataset)} samples, Test: {len(testloader.dataset)} samples.")

# Visualize a batch of real images
logger.info("Visualizing real images from the dataset...")
visualize_real(trainloader, num_images=16, save_path=os.path.join(config.FIGURE_SAVE_PATH, "real_images_notebook.png"))

## 3. Initialize Model and Optimizer

Here, we instantiate our `EnergyNet` model and the Adam optimizer. The model is moved to the device selected in the configuration (a CUDA GPU, Apple MPS, or the CPU).
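The model is built with `output_dim=config.NUM_CLASSES`. Together with the notebook's title, this suggests (but does not confirm) a Joint Energy-based Model (JEM) style design. In that design the network outputs one logit per digit, and a single network serves both tasks:

- Class probabilities come from a softmax over the logits.
- A scalar energy per image is defined as $E(x) = -\log \sum_y \exp f(x)[y]$.

The sketch below shows that reading in plain Python on a single list of logits. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def energy_from_logits(logits: Sequence[float]) -> float:
    """JEM-style energy: lower energy means the input looks more like data."""
    return -logsumexp(logits)


def class_probabilities(logits: Sequence[float]) -> List[float]:
    """Softmax over the logits, i.e. p(y | x)."""
    lse = logsumexp(logits)
    return [math.exp(v - lse) for v in logits]


def predict(logits: Sequence[float]) -> int:
    """Predicted class is the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])
```

For a batch in PyTorch, the equivalent energy would be `-torch.logsumexp(model(x), dim=1)`. If `EnergyNet` instead returns a single scalar per image, only the energy part of this sketch applies.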

In [None]:
model = EnergyNet(
    input_dim=config.INPUT_DIM,
    output_dim=config.NUM_CLASSES,
    hidden_dims=config.HIDDEN_DIMS
).to(config.DEVICE)
logger.info(f"Model initialized:\n{model}")

optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
logger.info(f"Optimizer initialized with learning rate: {config.LEARNING_RATE}")

## 4. Train the EBM

We will now train the EBM using the `train_ebm` function. This function handles the training loop, loss calculation, optimization, and periodic visualization of generated samples.
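The project's `loss_function` is not shown. A common EBM objective is contrastive divergence, as used by Du and Mordatch (2019). It pushes energy down on real images and up on Langevin samples. It is often paired with a small L2 penalty on the energies to keep their magnitudes bounded. The sketch below operates on plain lists of energies; `reg_weight` is a hypothetical parameter name.

```python
from statistics import fmean
from typing import Sequence


def contrastive_divergence_loss(
    real_energies: Sequence[float],
    fake_energies: Sequence[float],
    reg_weight: float = 0.1,
) -> float:
    """Mean energy of real data minus mean energy of samples, plus an L2 penalty."""
    cd = fmean(real_energies) - fmean(fake_energies)
    reg = fmean([e * e for e in real_energies]) + fmean([e * e for e in fake_energies])
    return cd + reg_weight * reg
```

Minimizing this lowers the energy of real digits and raises it for generated ones. In a PyTorch loop, the negative samples are typically detached (`.detach()`) before their energies are computed, so gradients flow only into the model parameters and not back through the sampling chain.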

In [None]:
logger.info("Starting EBM training...")
train_losses, val_losses = train_ebm(
    model=model,
    optimizer=optimizer,
    train_loader=trainloader,
    test_loader=testloader,
    epochs=config.EPOCHS,
    eta=config.LANGEVIN_STEPS,
    alpha=config.LANGEVIN_ALPHA,
    sigma=config.LANGEVIN_SIGMA,
    device=config.DEVICE,
    visualization_freq=config.VISUALIZATION_FREQ,
    figure_save_path=config.FIGURE_SAVE_PATH
)
logger.info("EBM training finished.")

## 5. Visualize Final Generated Images and Loss Curves

After training, we'll visualize a final batch of generated images to judge sample quality by eye. We'll also plot the training and validation loss curves to see how the losses evolved across epochs.

In [None]:
logger.info("Visualizing final generated images...")
visualize_generated(
    model=model,
    eta=config.LANGEVIN_STEPS,
    alpha=config.LANGEVIN_ALPHA,
    sigma=config.LANGEVIN_SIGMA,
    batch_size=config.BATCH_SIZE,
    image_size=config.IMAGE_SIZE,
    device=config.DEVICE,
    save_path=os.path.join(config.FIGURE_SAVE_PATH, "final_generated_images_notebook.png")
)

logger.info("Plotting loss curves...")
plt.figure(figsize=(10, 6))
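# Derive each x-axis from the recorded losses so the plot works even if their length differs from config.EPOCHS
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss")
plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss", linestyle='--')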
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss Curves")
plt.legend()
plt.grid()
plt.savefig(os.path.join(config.FIGURE_SAVE_PATH, "loss_curves_notebook.png"))
plt.show()
logger.info("Loss curves plotted and saved.")

# Optionally save the trained model
model_save_path = os.path.join(config.MODEL_SAVE_PATH, "ebm_model_notebook.pth")
torch.save(model.state_dict(), model_save_path)
logger.info(f"Trained model saved to {model_save_path}")

## 6. Experiment with Hyperparameters

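Try changing the parameters in `config.py`, or override attributes on the `config` object in a cell (for example, `config.EPOCHS = 5`). Then re-run the affected cells to see how the changes influence training and the quality of generated samples. Because `config.py` is imported only once per session, edits to the file take effect only after you restart the kernel or reload the module. Key parameters to experiment with include: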

-   `LANGEVIN_STEPS` (`eta`): Number of steps for Langevin dynamics. More steps can lead to better samples but also higher computational cost.
-   `LANGEVIN_ALPHA` (`alpha`): Step size for Langevin dynamics. A larger value allows faster movement but can lead to instability.
-   `LANGEVIN_SIGMA` (`sigma`): Noise level for Langevin dynamics. Higher noise encourages more exploration of the energy landscape.
-   `EPOCHS` and `LEARNING_RATE` for the main training loop.
-   `HIDDEN_DIMS`, passed to `EnergyNet`, to adjust model capacity.
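The three Langevin parameters interact through the standard update $x_{k+1} = x_k - \alpha \nabla_x E(x_k) + \sigma \epsilon_k$, with $\epsilon_k \sim \mathcal{N}(0, I)$, applied `eta` times. The project's `langevin_dynamics_step` and `sample` are not shown, so the following is a sketch of that update, not the project's code. It runs on a toy 1D energy $E(x) = (x - \mu)^2 / 2$, whose gradient is $x - \mu$.

Textbook Langevin dynamics ties the noise scale to the step size ($\sigma = \sqrt{2\alpha}$). Many EBM implementations decouple the two, which matches having separate `LANGEVIN_ALPHA` and `LANGEVIN_SIGMA` settings.

```python
import random
from typing import Callable, Optional


def langevin_sample(
    grad_energy: Callable[[float], float],
    x0: float,
    eta: int,
    alpha: float,
    sigma: float,
    rng: Optional[random.Random] = None,
) -> float:
    """Run `eta` steps of x <- x - alpha * dE/dx + sigma * N(0, 1)."""
    if rng is None:
        rng = random.Random()
    x = x0
    for _ in range(eta):
        x = x - alpha * grad_energy(x) + sigma * rng.gauss(0.0, 1.0)
    return x


# Toy energy E(x) = (x - MU)^2 / 2 with its minimum at MU, so dE/dx = x - MU.
MU = 2.0


def toy_grad(x: float) -> float:
    return x - MU
```

With `sigma=0.0` the loop reduces to gradient descent on $E$ and converges to the minimum. With noise, chains started from uniform noise in `[-1, 1]` scatter around `MU`, and a larger `sigma` widens that spread. For this quadratic, each noiseless step multiplies the distance to `MU` by `1 - alpha`, so any `alpha > 2` makes the iteration diverge. This is a small-scale version of the instability described for `LANGEVIN_ALPHA`.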