# RBM Package Demo

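This notebook demonstrates how to use the reorganized RBM package to train a Restricted Boltzmann Machine (RBM) on MNIST digit-6 images and to run inference (reconstruction and generation) with the Perturb-and-MAP methodology. The MAP problems are solved with Gurobi, so `gurobipy` must be installed to run the cells below.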

In [None]:
import sys
from pathlib import Path

# Add src to path for imports
sys.path.insert(0, str(Path.cwd().parent / "src"))

import torch
import torch.optim as optim
import matplotlib.pyplot as plt

## Import RBM Package Components

In [None]:
from rbm.models.rbm import RBM
from rbm.solvers.gurobi import GurobiSolver
from rbm.training.trainer import Trainer
from rbm.data.mnist import load_mnist_data
from rbm.utils.config import ConfigManager
from rbm.utils.visualization import plot_reconstruction, plot_generation
from rbm.inference.reconstruction import reconstruct_image
from rbm.inference.generation import generate_samples

## Load Configuration

In [None]:
# Load the experiment configuration for the digit-6 run from ../configs
config_manager = ConfigManager(config_dir="../configs")
config = config_manager.load("mnist_digit6")

# Echo the key settings so the reader can see what this run is using.
# Each entry is (printed label, config section, key within that section).
print("Configuration loaded:")
for label, section, key in (
    ("Model", 'model', 'model_type'),
    ("Visible units", 'model', 'n_visible'),
    ("Hidden units", 'model', 'n_hidden'),
    ("Solver", 'solver', 'name'),
    ("Epochs", 'training', 'epochs'),
):
    print(f"{label}: {config[section][key]}")

## Create Model and Components

In [None]:
# Seed torch's global RNG so that torch-based randomness (presumably weight
# initialisation, loader shuffling and sampling) is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Create RBM model sized from the loaded configuration
model = RBM(
    n_visible=config['model']['n_visible'],
    n_hidden=config['model']['n_hidden']
)

# Create the MAP solver. The training, reconstruction and generation cells all
# use `solver`, so stop here with a clear error instead of leaving it undefined
# (which would otherwise surface later as a confusing NameError).
# NOTE(review): `is_available` is read as an attribute, not called. If it is a
# plain method/classmethod, this check is always truthy — confirm in rbm.solvers.
if not GurobiSolver.is_available:
    # A different solver (e.g. SCIP) could be substituted here instead.
    raise RuntimeError("Gurobi not available. Install gurobipy to use this solver.")
solver = GurobiSolver(suppress_output=True)
print(f"Using {solver.name} solver")

# Plain SGD over the RBM parameters, learning rate taken from the config
optimizer = optim.SGD(model.parameters(), lr=config['training']['learning_rate'])

print(f"Model created with {model.n_visible} visible and {model.n_hidden} hidden units")

## Load Data

In [None]:
# Build the train and test loaders from the same config; per the original
# author the mnist_digit6 config filters MNIST down to digit 6.
train_loader, dataset_size = load_mnist_data(config, train=True)
test_loader, test_size = load_mnist_data(config, train=False)

# Report split sizes, then the configured image size
for label, value in (
    ("Training samples", dataset_size),
    ("Test samples", test_size),
):
    print(f"{label}: {value}")
print(f"Image size: {config['data']['image_size']}")

## Training (Short Demo)

In [None]:
# For demo purposes, train for just a few epochs on a handful of batches.
# Build demo_config with a *new* 'training' dict: a shallow `config.copy()`
# shares the nested dicts, so assigning into demo_config['training'] would
# silently overwrite config['training'] as well.
demo_config = {
    **config,
    'training': {
        **config['training'],
        'epochs': 2,
        'batch_limit': 5,  # limit batches for a quick demo — assumes Trainer reads this key; TODO confirm
    },
}

# Create trainer
trainer = Trainer(model, solver, optimizer, demo_config)

# Train the model
print("Starting training...")
results = trainer.train(train_loader)

print(f"Training completed! Final loss: {results['final_loss']:.6f}")

## Inference - Reconstruction

In [None]:
# Take element 0 of the first test batch (presumably the image tensor) and
# keep at most its first three images.
test_batch = next(iter(test_loader))[0]
test_images = test_batch[:3]

# Reconstruct each image independently with the solver, then stack the
# per-image results into one tensor along a new leading dimension.
reconstructions = torch.stack(
    [reconstruct_image(model, image, solver) for image in test_images]
)

# Plot the originals and their reconstructions
image_shape = tuple(config['data']['image_size'])
plot_reconstruction(
    original=test_images,
    reconstructed=reconstructions,
    image_shape=image_shape,
    title="RBM Reconstruction Demo"
)

## Inference - Generation

In [None]:
# Generate new samples from the trained model using the MAP solver.
NUM_GENERATED_SAMPLES = 5
GIBBS_STEPS = 100  # reduced for the demo

# Recompute the image shape here so this cell does not depend on the
# reconstruction cell having run successfully first.
image_shape = tuple(config['data']['image_size'])

print("Generating samples...")
generated = generate_samples(
    model=model,
    solver=solver,
    num_samples=NUM_GENERATED_SAMPLES,
    gibbs_steps=GIBBS_STEPS,
    verbose=True
)

# Plot generated samples
plot_generation(
    generated_samples=generated,
    image_shape=image_shape,
    title="RBM Generated Samples Demo"
)

## Package Usage Summary

This demo shows how the reorganized RBM package provides:

1. **Clean imports** - Import only what you need from specific modules
2. **Configuration management** - Easy loading and management of experiment configurations
3. **Modular design** - Separate concerns for models, solvers, training, and inference
4. **Extensible architecture** - Easy to add new solvers, models, or functionality

### Command Line Usage

You can also run the same workflows through the experiment scripts from the command line:

```bash
# Training
python experiments/train_rbm.py --config mnist_digit6 --epochs 10

# Inference
python experiments/run_inference.py checkpoint.pth --config mnist_digit6 --task both
```

Compared with the original scattered notebook code, this layout is cleaner and easier to maintain.