In [2]:
import os

import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from app.model import get_model
from app.trainer import train_ebm

# Configuration
BATCH_SIZE = 64
EPOCHS = 1
LEARNING_RATE = 1e-4
# Sampler settings forwarded verbatim to `train_ebm`. Their exact meaning is
# defined in app.trainer; the names suggest an MCMC (e.g. Langevin) sampler:
# number of steps, step size, and injected noise scale — confirm there.
MCMC_STEPS = 5
MCMC_STEP_SIZE = 10
MCMC_NOISE = 0.005
MODEL_SAVE_PATH = "app/ebm_model.pth"
# Forwarded to train_ebm as `samples_path`; how it is used is defined there.
SAMPLES_PATH = "ebm_samples"

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 1. Load Data
# ToTensor converts PIL images to float CHW tensors scaled to [0, 1].
# NOTE(review): no normalization is applied — confirm this matches the input
# range the EBM / sampler in app.trainer expects.
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = CIFAR10(root="./data", train=True, download=True, transform=transform)
# Pinned host memory speeds up host-to-GPU transfers; it only helps on CUDA,
# so it is enabled conditionally.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=device.type == "cuda",
)

print(f"Loaded CIFAR-10 with {len(train_dataset)} training images.")

# 2. Initialize Model
model = get_model('EBM')
model.to(device)

# 3. Initialize Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 4. Train Model
# train_ebm is expected to return the trained model, whose weights are saved below.
trained_model = train_ebm(
    model,
    train_loader,
    optimizer,
    device=device,
    epochs=EPOCHS,
    mcmc_steps=MCMC_STEPS,
    mcmc_step_size=MCMC_STEP_SIZE,
    mcmc_noise=MCMC_NOISE,
    samples_path=SAMPLES_PATH
)

# 5. Save Model Weights
# Ensure the destination directory exists so a long training run is not lost
# to a FileNotFoundError at the very last step.
save_dir = os.path.dirname(MODEL_SAVE_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
torch.save(trained_model.state_dict(), MODEL_SAVE_PATH)
print(f"EBM model weights saved to {MODEL_SAVE_PATH}")


Using device: cpu
Files already downloaded and verified
Loaded CIFAR-10 with 50000 training images.
Starting EBM training for 1 epochs on cpu...


Epoch 1: 100%|████████████████████████████████| 782/782 [24:47<00:00,  1.90s/it]

Epoch 1 completed in 1487.88s | Loss: -0.0513 | E_real: -1.3530 | E_fake: -1.3017
EBM Training Finished.
EBM model weights saved to app/ebm_model.pth



