<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/MYEMB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## MNIST Digit Recognition

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Both MNIST splits share the storage directory, the download flag and the
# tensor conversion; only the `train` flag differs between them.
mnist_kwargs = dict(root="./data", download=True, transform=transforms.ToTensor())

# Training split, served in shuffled batches of 64.
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, shuffle=True, batch_size=64
)

# Test split, served in dataset order in batches of 64.
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, shuffle=False, batch_size=64
)

# Define the ImageEBM with mixed activations
class ImageEBM(nn.Module):
    """Small convolutional energy model for single-channel 28x28 images.

    Architecture: one 3x3 convolution (1 -> 8 channels, stride 1, padding 1,
    so height and width are preserved) with batch norm and ReLU, then a
    sigmoid-activated linear layer (8*28*28 -> 10) and a linear head
    (10 -> 1) that emits one scalar energy per image.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.fc1 = nn.Linear(8 * 28 * 28, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        """Return one energy value per image.

        Args:
            x: Tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 1).

        Raises:
            RuntimeError: if the flattened feature count per sample is not
                8*28*28 (i.e. the input is not 28x28).
        """
        x = torch.relu(self.bn1(self.conv1(x)))  # ReLU after convolution
        # Flatten per sample. The previous `view(-1, 8 * 28 * 28)` would fold
        # surplus pixels of a wrongly sized input into the batch dimension and
        # silently return more rows than images; keeping the batch dimension
        # fixed makes such a mismatch fail in fc1 instead.
        x = x.flatten(start_dim=1)
        x = torch.sigmoid(self.fc1(x))  # Sigmoid in fully connected layer
        return self.fc2(x)

# Fresh energy model trained with Adam (lr=1e-4). The step schedule multiplies
# the learning rate by 0.1 after every 30 calls to scheduler.step().
model = ImageEBM()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# NOTE(review): scheduler.step() is called once per epoch and only 5 epochs
# run, so the 30-step decay is never reached in this cell.
num_epochs = 5
alpha = 0.1  # weight of the penalty applied to negative energies

for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(train_loader):  # labels are ignored
        optimizer.zero_grad()
        energy = model(data)

        # Objective: mean energy of the batch plus alpha times the mean
        # magnitude of the negative energies.
        # NOTE(review): with alpha < 1 and no negative-sample term this
        # objective has no lower bound (for negative energies it reduces to
        # (1 - alpha) * mean(energy)) -- confirm this is intended.
        negative_part = torch.relu(-energy)
        loss = energy.mean() + alpha * negative_part.mean()

        loss.backward()
        optimizer.step()

        # Report every 100th batch, starting with the first one.
        if not batch_idx % 100:
            progress = f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}"
            print(progress)

    # One scheduler step per completed epoch.
    scheduler.step()

# Evaluate the model
def evaluate(model, test_loader):
    """Compute the mean energy the model assigns to the samples of a loader.

    The model is switched to evaluation mode (and left in it) and run over
    every batch with gradient tracking disabled. Energies are averaged over
    individual outputs rather than over batches, so a shorter final batch
    does not receive extra weight.

    Args:
        model: Module invoked as ``model(data)``; must return a tensor.
        test_loader: Iterable of ``(data, target)`` pairs; targets are ignored.

    Returns:
        float: Mean of all energies produced over the loader. Previously the
        value was computed and then discarded, making the call a no-op.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    energy_sum = 0.0
    n_outputs = 0
    with torch.no_grad():
        for data, _ in test_loader:
            energy = model(data)
            energy_sum += energy.sum().item()
            n_outputs += energy.numel()

    test_loss = energy_sum / n_outputs
    return test_loss

# Run the evaluation pass over the MNIST test loader with the model trained above.
evaluate(model=model, test_loader=test_loader)

Epoch [1/5], Batch [1/938], Loss: 0.4417
Epoch [1/5], Batch [101/938], Loss: -0.3597
Epoch [1/5], Batch [201/938], Loss: -0.4088
Epoch [1/5], Batch [301/938], Loss: -0.4551
Epoch [1/5], Batch [401/938], Loss: -0.5005
Epoch [1/5], Batch [501/938], Loss: -0.5460
Epoch [1/5], Batch [601/938], Loss: -0.5912
Epoch [1/5], Batch [701/938], Loss: -0.6365
Epoch [1/5], Batch [801/938], Loss: -0.6816
Epoch [1/5], Batch [901/938], Loss: -0.7267
Epoch [2/5], Batch [1/938], Loss: -0.7439
Epoch [2/5], Batch [101/938], Loss: -0.7890
Epoch [2/5], Batch [201/938], Loss: -0.8341
Epoch [2/5], Batch [301/938], Loss: -0.8791
Epoch [2/5], Batch [401/938], Loss: -0.9242
Epoch [2/5], Batch [501/938], Loss: -0.9692
Epoch [2/5], Batch [601/938], Loss: -1.0143
Epoch [2/5], Batch [701/938], Loss: -1.0593
Epoch [2/5], Batch [801/938], Loss: -1.1044
Epoch [2/5], Batch [901/938], Loss: -1.1494
Epoch [3/5], Batch [1/938], Loss: -1.1665
Epoch [3/5], Batch [101/938], Loss: -1.2115
Epoch [3/5], Batch [201/938], Loss: -1.

In [10]:
# Rebuild the MNIST test split (train=False) and a loader that serves it in
# dataset order, 64 images per batch.
test_dataset = datasets.MNIST(
    transform=transforms.ToTensor(),
    download=True,
    train=False,
    root="./data",
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=64,
)

# Evaluate the model
def evaluate(model, test_loader):
    """Print and return the mean energy the model assigns to a loader's samples.

    The model is switched to evaluation mode (and left in it) and run over
    every batch with gradient tracking disabled. Energies are averaged over
    individual outputs rather than over batches, so a shorter final batch
    does not receive extra weight.

    Args:
        model: Module invoked as ``model(data)``; must return a tensor.
        test_loader: Iterable of ``(data, target)`` pairs; targets are ignored.

    Returns:
        float: Mean of all energies produced over the loader (the same value
        that is printed as ``Eval Loss``).

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    energy_sum = 0.0
    n_outputs = 0
    with torch.no_grad():
        for data, _ in test_loader:
            energy = model(data)
            energy_sum += energy.sum().item()
            n_outputs += energy.numel()

    test_loss = energy_sum / n_outputs
    print(f"Eval Loss: {test_loss:.4f}")
    return test_loss

# Report the evaluation result for the current model on the MNIST test loader.
evaluate(model=model, test_loader=test_loader)

Eval Loss: -3.1242


## CIFAR-10 Image Classification

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Preprocessing: convert each image to a tensor, then normalize every channel
# with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# Both CIFAR-10 splits share the storage directory, download flag and
# preprocessing; only the `train` flag differs between them.
cifar_kwargs = dict(root="./data", download=True, transform=transform)

# Training split, served in shuffled batches of 64.
train_dataset = datasets.CIFAR10(train=True, **cifar_kwargs)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, shuffle=True, batch_size=64
)

# Test split, served in dataset order in batches of 64.
test_dataset = datasets.CIFAR10(train=False, **cifar_kwargs)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, shuffle=False, batch_size=64
)

# Define the ImageEBM for CIFAR-10 (color images)
class ImageEBM(nn.Module):
    """Convolutional energy model for three-channel 32x32 (CIFAR-10 sized) images.

    Architecture: two 3x3 convolution stages (3 -> 16 -> 32 channels, stride 1,
    padding 1), each followed by batch norm, ReLU and a 2x2 max pool with
    stride 2, so a 32x32 input becomes 32 feature maps of 8x8. A
    sigmoid-activated linear layer (32*8*8 -> 128) and a linear head
    (128 -> 1) then produce one scalar energy per image.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  # Input channels = 3 for color
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)  # 32 maps of 8x8 after two poolings
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return one energy value per image.

        Args:
            x: Tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 1).

        Raises:
            RuntimeError: if the flattened feature count per sample is not
                32*8*8 (i.e. the input does not pool down to 8x8).
        """
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.max_pool2d(x, 2, 2)
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.max_pool2d(x, 2, 2)
        # Flatten per sample. The previous `view(-1, 32 * 8 * 8)` would fold
        # surplus features of a wrongly sized input into the batch dimension
        # and silently return more rows than images; keeping the batch
        # dimension fixed makes such a mismatch fail in fc1 instead.
        x = x.flatten(start_dim=1)
        x = torch.sigmoid(self.fc1(x))
        return self.fc2(x)

# Fresh CIFAR-10 energy model trained with Adam (lr=1e-4). The step schedule
# multiplies the learning rate by 0.1 after every 30 calls to scheduler.step().
model = ImageEBM()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# NOTE(review): scheduler.step() is called once per epoch and only 5 epochs
# run, so the 30-step decay is never reached in this cell.
num_epochs = 5  # You might need more epochs for CIFAR-10
alpha = 0.1  # weight of the penalty applied to negative energies

for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(train_loader):  # labels are ignored
        optimizer.zero_grad()
        energy = model(data)

        # Objective: mean batch energy plus alpha times the mean magnitude of
        # the negative energies.
        # NOTE(review): with alpha < 1 and no negative-sample term this
        # objective has no lower bound (for negative energies it reduces to
        # (1 - alpha) * mean(energy)) -- confirm this is intended.
        mean_energy = energy.mean()
        negative_penalty = torch.relu(-energy).mean()
        loss = mean_energy + alpha * negative_penalty

        loss.backward()
        optimizer.step()

        # Report every 100th batch, starting with the first one.
        should_log = batch_idx % 100 == 0
        if should_log:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}"
            )

    # One scheduler step per completed epoch.
    scheduler.step()

Files already downloaded and verified
Files already downloaded and verified
Epoch [1/5], Batch [1/782], Loss: -0.3356
Epoch [1/5], Batch [101/782], Loss: -3.6440
Epoch [1/5], Batch [201/782], Loss: -4.3777
Epoch [1/5], Batch [301/782], Loss: -5.0892
Epoch [1/5], Batch [401/782], Loss: -5.8079
Epoch [1/5], Batch [501/782], Loss: -6.5292
Epoch [1/5], Batch [601/782], Loss: -7.2430
Epoch [1/5], Batch [701/782], Loss: -7.9535
Epoch [2/5], Batch [1/782], Loss: -8.5346
Epoch [2/5], Batch [101/782], Loss: -9.2422
Epoch [2/5], Batch [201/782], Loss: -9.9488
Epoch [2/5], Batch [301/782], Loss: -10.6550
Epoch [2/5], Batch [401/782], Loss: -11.3972
Epoch [2/5], Batch [501/782], Loss: -12.1229
Epoch [2/5], Batch [601/782], Loss: -12.8440
Epoch [2/5], Batch [701/782], Loss: -13.5626
Epoch [3/5], Batch [1/782], Loss: -14.1661
Epoch [3/5], Batch [101/782], Loss: -14.9152
Epoch [3/5], Batch [201/782], Loss: -15.6515
Epoch [3/5], Batch [301/782], Loss: -16.4029
Epoch [3/5], Batch [401/782], Loss: -17.1

## Evaluation on CIFAR-10

In [14]:
import torch
from torchvision import datasets, transforms

# Preprocessing for the CIFAR-10 test split: tensor conversion followed by
# per-channel normalization with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
)

# Test split (train=False), served in dataset order, 64 images per batch.
test_dataset = datasets.CIFAR10(
    transform=transform,
    download=True,
    train=False,
    root="./data",
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=64,
)


# Evaluate the model
def evaluate(model, test_loader):
    """Print and return the mean energy the model assigns to a loader's samples.

    The model is switched to evaluation mode (and left in it) and run over
    every batch with gradient tracking disabled. Energies are averaged over
    individual outputs rather than over batches, so a shorter final batch
    does not receive extra weight.

    Args:
        model: Module invoked as ``model(data)``; must return a tensor.
        test_loader: Iterable of ``(data, target)`` pairs; targets are ignored.

    Returns:
        float: Mean of all energies produced over the loader (the same value
        that is printed as ``Eval Loss``).

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    total_energy = 0.0
    total_count = 0
    with torch.no_grad():
        for data, _ in test_loader:
            batch_energy = model(data)
            total_energy += batch_energy.sum().item()
            total_count += batch_energy.numel()

    test_loss = total_energy / total_count
    print(f"Eval Loss: {test_loss:.4f}")
    return test_loss

# Report the evaluation result for the current model on the CIFAR-10 test loader.
evaluate(model=model, test_loader=test_loader)

Files already downloaded and verified
Eval Loss: -34.9541
