# ⚡️ Energy-Based Models

This notebook is an **unofficial PyTorch implementation** of the excellent [Keras example](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/blob/main/notebooks/07_ebm/01_ebm/ebm.ipynb) for Energy-Based Models, originally created by David Foster as part of the companion code for his book [Generative Deep Learning, 2nd Edition](https://www.oreilly.com/library/view/generative-deep-learning/9781098134174/).

_The original code is available [here](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition) and is licensed under the Apache License 2.0._
_This implementation is distributed under the Apache License 2.0. See the LICENSE file for details._

In this notebook, we'll walk through the steps required to train your own Energy-Based Model (EBM) in PyTorch to learn the distribution of the MNIST handwritten-digit dataset, and then to sample new images from it with a Langevin sampler.

In [None]:
%load_ext autoreload
%autoreload 2

import os

# Get the working directory and the current notebook directory.
# NOTE(review): exp_dir is built relative to the current working directory (and
# `notebooks.utils` is imported as a package below), so this assumes the kernel was
# started from the repository root — TODO confirm before running elsewhere.
working_dir = os.getcwd()
exp_dir = os.path.join(working_dir, "notebooks/07_ebm/01_ebm/")

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
from torch import nn
import torch.nn.functional as F
from torchinfo import summary
import math
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random 
from notebooks.utils import display

## 0. Parameters <a name="parameters"></a>

In [None]:
IMAGE_SIZE = 32        # MNIST images are padded by 2 pixels per side to 32x32 (see the data transform)
CHANNELS = 1           # grayscale input
STEP_SIZE = 10         # multiplier on the clipped score gradient in each Langevin step
STEPS = 60             # Langevin steps run on every training batch of buffer samples
NOISE = 0.005          # std of the Gaussian noise added at each Langevin step and to real images
ALPHA = 0.1            # weight of the L2 regularization on the model scores
GRADIENT_CLIP = 0.03   # element-wise clip applied to the score gradient during sampling
BATCH_SIZE = 256
BUFFER_SIZE = 8192     # maximum number of generated images kept in the replay buffer
LEARNING_RATE = 0.0001
EPOCHS = 60
LOAD_MODEL = False     # load weights from a saved checkpoint before training
SEED = 42              # seeds Python, NumPy and PyTorch RNGs so runs are reproducible

# Seed every RNG the notebook draws from (random.choices / np.random.binomial in the
# buffer, torch for initialisation, noise and sampling).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. Preparing the data <a name="preparing the data"></a>

In [None]:
# MNIST is downloaded to / read from <working_dir>/data.
data_dir = f"{working_dir}/data"

In [None]:
# ToTensor -> [0, 1]; Normalize(0.5, 0.5) -> [-1, 1]; then pad 2 pixels per side with -1,
# the normalized value of a black pixel, to reach IMAGE_SIZE x IMAGE_SIZE.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
    transforms.Pad(padding=(2, 2), fill=-1),
])

mnist_kwargs = dict(root=data_dir, download=True, transform=transform)
x_train = datasets.MNIST(train=True, **mnist_kwargs)
x_test = datasets.MNIST(train=False, **mnist_kwargs)

loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
train_data_loader = DataLoader(x_train, **loader_kwargs)
test_data_loader = DataLoader(x_test, **loader_kwargs)

In [None]:
train_count, test_count = len(x_train), len(x_test)
print(f"x_train: {train_count}")
print(f"x_test: {test_count}")

In [None]:
# Inspect one training batch: its shape, pixel range (expected [-1, 1]) and a preview.
data_iter = iter(train_data_loader)
sample, _ = next(data_iter)
pixel_min, pixel_max = torch.min(sample), torch.max(sample)
print(sample.shape)
print(f"min: {pixel_min}, max: {pixel_max}")
display(sample)

## 2. Build the EBM network <a name="train"></a>

In [None]:
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)`` (also known as SiLU)."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [None]:
class EbnModel(nn.Module):
    """Convolutional network that maps a batch of images to one scalar score per image.

    Four stride-2 convolutions (16, 32, 64, 64 output channels; kernel 5 for the first,
    3 for the others), each followed by Swish, then a 64-unit linear layer with Swish
    and a final linear layer with a single output.

    Args:
        input_channels: number of channels of the input images.
        image_shape: (height, width) of the input images. Padding for every conv layer
            is derived from the height only, assuming the side halves at each layer.
    """

    def __init__(self, input_channels, image_shape):
        super().__init__()
        self.input_channels = input_channels
        self.image_shape = image_shape

        height = self.image_shape[0]

        def strided_conv(in_channels, out_channels, kernel_size, downscale):
            # `downscale` is how much the input side has already shrunk before this layer.
            padding = self._get_padding_size(height / downscale, stride=2, kernal_size=kernel_size)
            return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, stride=2, padding=padding)

        self.conv2d_1 = strided_conv(self.input_channels, 16, 5, 1)
        self.conv2d_2 = strided_conv(16, 32, 3, 2)
        self.conv2d_3 = strided_conv(32, 64, 3, 2 * 2)
        self.conv2d_4 = strided_conv(64, 64, 3, 2 * 2 * 2)

        # Infer the flattened feature size with a dummy pass through the conv stack.
        with torch.no_grad():
            probe = torch.zeros((1, self.input_channels, *self.image_shape))
            for conv in self._convs():
                probe = conv(probe)
            fc_input_size = probe.view(1, -1).shape[1]

        self.fc_1 = nn.Linear(in_features=fc_input_size, out_features=64)
        self.fc_2 = nn.Linear(in_features=64, out_features=1)

        self.activation = Swish()

    def _convs(self):
        """Return the convolution layers in forward order."""
        return (self.conv2d_1, self.conv2d_2, self.conv2d_3, self.conv2d_4)

    @staticmethod
    def _get_padding_size(input_w, stride, kernal_size):
        """Return the symmetric padding that makes a strided conv output ``input_w / 2`` wide.

        Solves ``out = (in + 2p - k) / stride + 1`` for ``p`` with ``out = in / 2``
        and rounds the result up.
        """
        output_w = input_w / 2
        total_padding = (output_w - 1) * stride - input_w + kernal_size
        return math.ceil(total_padding / 2)

    def forward(self, x):
        for conv in self._convs():
            x = self.activation(conv(x))
        x = x.view(x.shape[0], -1)
        x = self.activation(self.fc_1(x))
        return self.fc_2(x)

In [None]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
image_hw = (IMAGE_SIZE, IMAGE_SIZE)
ebn_instance = EbnModel(CHANNELS, image_hw)
summary(ebn_instance, (1, CHANNELS, *image_hw))

## 3. Set up a Langevin sampler function <a name="sampler"></a>

In [None]:
def generate_samples(model, inp_imgs, steps, step_size, noise, device,
                     return_img_per_step=False, grad_clip=None):
    """Refine images by noisy gradient ascent on the model's score (Langevin-style sampling).

    Each step:
      1. adds Gaussian noise with standard deviation ``noise`` and clamps to [-1, 1];
      2. computes the gradient of the summed model score w.r.t. the images;
      3. clips that gradient element-wise to [-grad_clip, grad_clip];
      4. moves the images by ``step_size * grad`` (ascent) and clamps to [-1, 1].

    Args:
        model: network mapping an image batch to scores. Its parameters are frozen
            while sampling and restored to their previous ``requires_grad`` flags
            afterwards, even if sampling raises.
        inp_imgs: starting images; detached first, so the caller's tensor is unchanged.
        steps: number of sampling steps.
        step_size: multiplier applied to the clipped gradient.
        noise: standard deviation of the Gaussian noise added at each step.
        device: device the images are moved to for the forward pass only; returned
            images stay on the device of ``inp_imgs``.
        return_img_per_step: if True, return a list with the images after every step.
        grad_clip: element-wise gradient bound; when None, the notebook-level
            ``GRADIENT_CLIP`` is used.

    Returns:
        A detached tensor with the final images, or a list of ``steps`` detached
        tensors when ``return_img_per_step`` is True.
    """
    if grad_clip is None:
        grad_clip = GRADIENT_CLIP

    # Freeze the model so backward only produces gradients for the images, and remember
    # each parameter's flag so already-frozen parameters stay frozen afterwards.
    params = list(model.parameters())
    prev_flags = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)

    imgs = inp_imgs.detach()
    imgs_per_step = []
    try:
        for _ in range(steps):
            # randn has std 1, so scaling by `noise` yields noise with std `noise`.
            imgs = (imgs + torch.randn_like(imgs) * noise).clamp(min=-1.0, max=1.0)
            imgs.requires_grad_(True)

            out_score = model(imgs.to(device))
            # Gradient of the score with respect to the input images only.
            (grad,) = torch.autograd.grad(out_score.sum(), imgs)
            grad = grad.clamp(min=-grad_clip, max=grad_clip)

            # Gradient *ascent*: move the images towards a higher score.
            imgs = (imgs.detach() + step_size * grad).clamp(min=-1.0, max=1.0)

            if return_img_per_step:
                imgs_per_step.append(imgs)
    finally:
        for p, flag in zip(params, prev_flags):
            p.requires_grad_(flag)

    return imgs_per_step if return_img_per_step else imgs


In [None]:
class Buffer():
    """Replay buffer of sampler outputs used as negative ("fake") examples during training.

    Each call to ``sample_new_examples`` builds a batch in which each image is, with
    probability 0.05, fresh uniform noise, and otherwise drawn (with replacement) from
    the buffer. It runs the Langevin sampler on that batch and prepends the results to
    the buffer. The buffer is capped at ``buffer_size`` single-image tensors, newest first.
    """

    def __init__(self, model, batch_size, buffer_size, channels, image_size):
        self.model = model
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.channels = channels
        self.image_size = image_size
        # Seed the buffer with one batch of uniform-noise images, each shaped (1, C, H, W).
        self.examples = [self._random_image() for _ in range(self.batch_size)]

    def _random_image(self):
        """Return one (1, C, H, W) image of uniform noise in [-1, 1]."""
        return (torch.rand((1, self.channels, self.image_size, self.image_size)) * 2) - 1

    def sample_new_examples(self, steps, step_size, noise, device):
        """Generate ``batch_size`` new negative samples, store them, and return them.

        Args:
            steps, step_size, noise, device: forwarded to ``generate_samples``.

        Returns:
            A detached tensor of ``batch_size`` generated images, on the same device as
            the buffer contents.
        """
        num_new_samples = np.random.binomial(self.batch_size, p=0.05)
        rand_images = [self._random_image() for _ in range(num_new_samples)]
        old_images = random.choices(self.examples, k=(self.batch_size - num_new_samples))
        input_images = torch.cat(rand_images + old_images, dim=0)

        input_images = generate_samples(self.model, input_images,
                                        steps=steps, step_size=step_size,
                                        noise=noise, device=device)
        # Detach so neither the buffer nor the caller keeps autograd history for these
        # images; otherwise the training loss would backpropagate into them and every
        # stored batch could accumulate a `.grad` tensor.
        input_images = input_images.detach()

        # Prepend the new images and truncate to the maximum buffer size.
        self.examples = list(torch.chunk(input_images, self.batch_size, dim=0)) + self.examples
        self.examples = self.examples[:self.buffer_size]

        return input_images


In [None]:
# Smoke-test the buffer with the untrained network: 10 Langevin steps of size 10.
buffer_test = Buffer(ebn_instance, BATCH_SIZE, BUFFER_SIZE, channels=CHANNELS, image_size=IMAGE_SIZE)
sample = buffer_test.sample_new_examples(steps=10, step_size=10, noise=NOISE, device=device)
print(sample.shape)


In [None]:
class EBM(nn.Module):
    """Energy-based model trainer: wraps the scoring network, a replay buffer and TensorBoard writers.

    Training minimises a contrastive-divergence loss, ``mean(fake) - mean(real)``, which
    pushes real-image scores up and buffer-sample scores down. An L2 penalty on both sets
    of scores, weighted by ``reg_alpha``, keeps the scores bounded.
    """

    def __init__(self, channels, image_size, batch_size, buffer_size, log_dir="./log"):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size

        self.train_log_writter = SummaryWriter(log_dir=log_dir+"/train/")
        self.test_log_writter = SummaryWriter(log_dir=log_dir+"/val/")
        self.model = EbnModel(self.channels, (self.image_size, self.image_size))
        self.buffer = Buffer(self.model, self.batch_size, self.buffer_size,
                             self.channels, self.image_size)

    def fit(self, optimizer, train_data_loader, test_data_loader, epochs, device,
            steps, step_size, callbacks=None, reg_alpha=0.1, noise=0.005):
        """Train for ``epochs`` epochs, validating and logging after each one.

        Args:
            optimizer: optimizer over ``self.model``'s parameters.
            train_data_loader, test_data_loader: iterables yielding ``(images, labels)``.
            epochs: number of epochs.
            device: device used for the model forward passes.
            steps, step_size: Langevin sampler settings for the buffer samples.
            callbacks: optional objects with ``on_epoch_end(epoch=..., logs=...)``.
            reg_alpha: weight of the L2 score regularisation.
            noise: std of the noise added to real images and at each sampler step.
        """
        self.optimizer = optimizer
        self.device = device
        self.reg_alpha = reg_alpha
        self.noise = noise
        self.steps = steps
        self.step_size = step_size

        for epoch in range(epochs):
            acc_train_loss = self._average_losses(train_data_loader, self.train_step)
            self._log_losses(self.train_log_writter, acc_train_loss, epoch)

            # running validation
            acc_test_loss = self._average_losses(test_data_loader, self.test_step)
            self._log_losses(self.test_log_writter, acc_test_loss, epoch)

            print(f"epoch {epoch}/{epochs}: train_cdiv_loss:{acc_train_loss['cdiv_loss']}, "
                  f"train_reg_loss:{acc_train_loss['reg_loss']}, val_cdiv_loss:{acc_test_loss['cdiv_loss']}, "
                  f"val_reg_loss:{acc_test_loss['reg_loss']}")

            if callbacks is not None:
                logs = {"epoch": epoch,
                        "model_state_dict": self.model.state_dict(),
                        "loss": acc_train_loss,
                        "examples": torch.cat(random.choices(self.buffer.examples, k=10), dim=0),
                        "model": self.model,
                        "device": device}

                for callback in callbacks:
                    callback.on_epoch_end(epoch=epoch, logs=logs)

    @staticmethod
    def _average_losses(data_loader, step_fn):
        """Run ``step_fn`` on every batch of ``data_loader`` and return each loss averaged over batches."""
        totals = {"cdiv_loss": 0, "reg_loss": 0}
        num_batches = 0
        for images, _ in data_loader:
            for key, value in step_fn(images).items():
                totals[key] += value
            num_batches += 1
        if num_batches:  # an empty loader leaves the totals at 0 instead of dividing by zero
            for key in totals:
                totals[key] /= num_batches
        return totals

    @staticmethod
    def _log_losses(writer, losses, epoch):
        """Write the epoch's cdiv/reg/total losses to a TensorBoard writer and flush it."""
        writer.add_scalar("cdiv_loss", losses["cdiv_loss"], global_step=epoch)
        writer.add_scalar("reg_loss", losses["reg_loss"], global_step=epoch)
        writer.add_scalar("total_loss", losses["cdiv_loss"] + losses["reg_loss"], global_step=epoch)
        writer.flush()

    @staticmethod
    def contrastive_divergence_loss(real_images_scores, fake_images_scores):
        return (torch.mean(fake_images_scores) - torch.mean(real_images_scores))

    def regularization_loss(self, real_images_scores, fake_images_scores):
        # Mean of each set separately, so real and fake batches may differ in size
        # (equal to mean(real**2 + fake**2) when the sizes match).
        return self.reg_alpha * (torch.mean(real_images_scores ** 2) + torch.mean(fake_images_scores ** 2))

    def _split_scores(self, scores, num_real):
        """Split model scores into the first ``num_real`` (real) rows and the rest (fake)."""
        return scores[:num_real], scores[num_real:]

    def train_step(self, train_real_images):
        """Run one optimisation step on a batch of real images and return its losses as floats."""
        self.model.train()
        self.optimizer.zero_grad()

        # Perturb the real images with the same noise level the sampler uses.
        train_real_images = train_real_images.add(torch.randn_like(train_real_images) * self.noise)
        train_real_images = train_real_images.clamp(min=-1.0, max=1.0).to(self.device)

        # Negative samples from the replay buffer. There are always `batch_size` of them,
        # which can exceed the size of the last (partial) batch of real images.
        fake_images = self.buffer.sample_new_examples(
            self.steps, self.step_size, self.noise, self.device).detach().to(self.device)

        num_real = train_real_images.shape[0]
        scores = self.model(torch.cat((train_real_images, fake_images), dim=0))
        # Split by the actual real-batch size, not in half, so a partial batch does not
        # mix fake scores into the real ones.
        real_scores, fake_scores = self._split_scores(scores, num_real)

        cdiv_loss = self.contrastive_divergence_loss(real_scores, fake_scores)
        reg_loss = self.regularization_loss(real_scores, fake_scores)
        loss = cdiv_loss + reg_loss

        loss.backward()
        self.optimizer.step()

        return {"cdiv_loss": cdiv_loss.item(), "reg_loss": reg_loss.item()}

    def test_step(self, test_real_images):
        """Score a batch of real images against uniform-noise images and return the losses as floats."""
        self.model.eval()

        # Uniform noise in [-1, 1] as negatives, the same distribution the sampler starts from.
        fake_images = ((torch.rand_like(test_real_images) * 2) - 1).to(self.device)
        test_real_images = test_real_images.to(self.device)

        input_images = torch.cat((test_real_images, fake_images), dim=0)

        with torch.no_grad():
            scores = self.model(input_images)
            real_scores, fake_scores = self._split_scores(scores, test_real_images.shape[0])

            cdiv_loss = self.contrastive_divergence_loss(real_scores, fake_scores)
            reg_loss = self.regularization_loss(real_scores, fake_scores)

            test_loss = {"cdiv_loss": cdiv_loss.item(), "reg_loss": reg_loss.item()}

        return test_loss

Create the callbacks that `EBM.fit` runs at the end of every epoch.

In [None]:
class Callback:
    """Base class for training callbacks; subclasses override ``on_epoch_end``."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook invoked by ``EBM.fit`` after each epoch; the default does nothing."""
        return None

In [None]:
class SaveCheckpoint(Callback):
    """Callback that saves epoch, model weights and loss every ``save_every`` epochs.

    Epochs are counted from 0, so the first checkpoint is written after epoch 0.
    Files are named ``<save_dir>/checkpoint_<epoch>.pth``.
    """

    def __init__(self, save_dir, save_every=10):
        super().__init__()
        self.save_dir = save_dir
        self.save_every = save_every

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.save_every != 0:
            return
        checkpoint_path = self.save_dir + f"/checkpoint_{epoch}.pth"
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": logs["model_state_dict"],
                "loss": logs["loss"],
            },
            checkpoint_path,
        )

In [None]:
class GenerateImages(Callback):
    """Callback that samples fresh images from the model and saves them next to buffer examples.

    At the end of each epoch it runs ``generate_samples`` from uniform noise (one start image
    per buffer example in ``logs["examples"]``). It then saves two image grids via ``display``:
    the generated images and the buffer examples.
    """

    def __init__(self, num_images, noise, step_size, steps, save_dir="./gen_examples"):
        super().__init__()
        self.num_images = num_images
        self.save_dir = save_dir
        self.noise = noise
        self.steps = steps
        self.step_size = step_size

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            raise ValueError("GenerateImages requires logs with 'device', 'model', 'examples' and 'epoch'")
        device = logs["device"]
        model = logs["model"]
        # Buffer examples and sampler outputs may still require grad; detach them before
        # handing them to `display` so any numpy conversion there cannot fail.
        examples = logs["examples"].detach()
        epoch = logs["epoch"]

        initial_images = (torch.rand(examples.shape) * 2) - 1

        generated_images = generate_samples(model, initial_images,
                                            steps=self.steps, step_size=self.step_size,
                                            noise=self.noise, device=device).detach()

        display(generated_images, self.num_images, save_to=self.save_dir+f"/generated_img_epoch_{epoch}.png")

        display(examples, self.num_images, save_to=self.save_dir+f"/example_img_epoch_{epoch}.png")
        
        

Prepare for EBM training: create the output directories, the callbacks, the model, and the optimizer.

In [None]:
log_dir = exp_dir + "/log"
sample_dir = exp_dir + "/sample_gen"
checkpoint_dir = exp_dir + "/checkpoints"

# Create every output directory up front; existing directories are left as they are.
for output_dir in (log_dir, sample_dir, checkpoint_dir):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# NOTE(review): preview images use step_size=1000 rather than STEP_SIZE. With
# GRADIENT_CLIP=0.03, one such step can move a pixel by up to 30, i.e. straight to ±1 —
# confirm this is intentional.
sample_callback = GenerateImages(10, noise=NOISE, step_size=1000, steps=STEPS, save_dir=sample_dir)
checkpoint_callback = SaveCheckpoint(save_dir=checkpoint_dir, save_every=30)
callbacks = [sample_callback, checkpoint_callback]

In [None]:
ebm = EBM(
    channels=CHANNELS,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    buffer_size=BUFFER_SIZE,
    log_dir=log_dir,
).to(device)
# Only the scoring network's parameters are optimised.
optimizer = Adam(ebm.model.parameters(), lr=LEARNING_RATE)

In [None]:
# Optionally restore the scoring network's weights from a checkpoint before training.
# SaveCheckpoint stores `ebm.model.state_dict()` (keys without the "model." prefix), so the
# weights must be loaded into `ebm.model`, not into the `ebm` wrapper.
if LOAD_MODEL:
    # NOTE(review): the epoch in the file name is hard-coded; with EPOCHS=60 and
    # save_every=30 this notebook only writes checkpoint_0.pth and checkpoint_30.pth.
    checkpoint_file = checkpoint_dir + "/checkpoint_270.pth"
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    ebm.model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
ebm.fit(
    optimizer=optimizer,
    train_data_loader=train_data_loader,
    test_data_loader=test_data_loader,
    epochs=EPOCHS,
    device=device,
    steps=STEPS,
    step_size=STEP_SIZE,
    callbacks=callbacks,
    reg_alpha=ALPHA,
    noise=NOISE,
)

## 4. Generate images <a name="generate"></a>

In [None]:
# Ten uniform-noise images in [-1, 1] for the sampler to start from.
start_images = torch.rand((10, CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).mul(2).sub(1).to(device)

In [None]:
# Show the noise images the sampler starts from.
display(start_images, n=10)

In [None]:
# Run 1000 sampler steps from the noise images, keeping the images after every step.
generated_images = generate_samples(
    ebm.model,
    start_images,
    steps=1000,
    step_size=STEP_SIZE,
    noise=NOISE,
    device=device,
    return_img_per_step=True,
)

In [None]:
num_saved_steps = len(generated_images)
print(num_saved_steps)
# Shape of one image (index 2) at the first saved step.
print(generated_images[0][2].shape)

In [None]:
# Show the images after the last sampler step.
display(generated_images[-1], n=10)

In [None]:
# Follow image index 6 through a selection of sampler steps (0 = after the first step).
snapshot_steps = [0, 1, 3, 5, 10, 30, 50, 100, 300, 999]
imgs = [generated_images[step][6] for step in snapshot_steps]

display(torch.stack(imgs))

In [None]:
# Compare the score of the first real training image with that of the first noise start
# image. The sampler ascends this score, and training pushes real-image scores up.
data_iter = iter(train_data_loader)
sample, _ = next(data_iter)
# Pure inference: no autograd graph is needed just to read the scores.
with torch.no_grad():
    print(ebm.model(sample.to(device))[0].item())
    print(ebm.model(start_images.to(device))[0].item())