# ⚡️ Energy Based Models

In this notebook, we'll walk through the steps required to train your own Energy Based Model to predict the distribution of a demo dataset

## Table of contents
0. [Parameters](#parameters)
1. [Prepare the Data](#prepare)
2. [Build the Energy Based Model](#build)
3. [Train the Energy Based Model](#train)
4. [Generate images](#generate)

In [15]:
%load_ext autoreload
%autoreload 2

Global seed set to 42


Device: cpu
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
## Standard libraries
import os
import json
import math
import numpy as np
import random

## Imports for plotting
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "./data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "./models"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

from notebooks.ebm.demo.classes import CNNModel, Sampler, DeepEnergyModel, GenerateCallback, SamplerCallback, OutlierCallback

  set_matplotlib_formats('svg', 'pdf') # For export
Global seed set to 42


Device: cpu


## 0. Parameters <a name="parameters"></a>

In [17]:
COUPLING_DIM = 256
COUPLING_LAYERS = 2
INPUT_DIM = 2
REGULARIZATION = 0.01
BATCH_SIZE = 256
EPOCHS = 300

In [18]:
# Transformations applied on each image => make them a tensor and normalize between -1 and 1
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))
                               ])

# Loading the training dataset. We need to split it into a training and validation part
train_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)

# Loading the test set
test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
# Note that for actually training a model, we will use different data loaders
# with a lower batch size.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True,  drop_last=True,  num_workers=4, pin_memory=True)
test_loader  = data.DataLoader(test_set,  batch_size=256, shuffle=False, drop_last=False, num_workers=4)


In [19]:
def train_model(**kwargs):
    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "MNIST"),
                         gpus=1 if str(device).startswith("cuda") else 0,
                         max_epochs=60,
                         gradient_clip_val=0.1,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor='val_contrastive_divergence'),
                                    GenerateCallback(every_n_epochs=5),
                                    SamplerCallback(every_n_epochs=5),
                                    OutlierCallback(),
                                    LearningRateMonitor("epoch")
                                   ],
                        progress_bar_refresh_rate=1)
    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "MNIST.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model, loading...")
        model = DeepEnergyModel.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything(42)
        model = DeepEnergyModel(**kwargs)
        trainer.fit(model, train_loader, test_loader)
        model = DeepEnergyModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    # No testing as we are more interested in other properties
    return model

In [20]:
img_shape=(1,28,28)
batch_size = 128
cnn = CNNModel()
sampler = Sampler(cnn, img_shape=img_shape, sample_size=batch_size)
example_input_array = torch.zeros(1, *img_shape)
input_imgs = sampler.sample_new_exmps()

In [21]:
sampler.generate_samples(sampler.model, input_imgs)

tensor([[[[-0.2486, -0.8098,  0.9071,  ...,  0.6993,  0.9172,  0.4312],
          [-0.8259,  0.0522, -0.9367,  ..., -0.0095, -0.9852, -0.7568],
          [-0.4837, -0.3348,  0.6453,  ..., -0.8984,  0.2860,  0.4679],
          ...,
          [-0.5985, -0.5953,  0.6912,  ..., -0.2188, -0.1369, -0.6605],
          [ 0.5333, -0.3516,  0.6255,  ...,  0.9216,  0.6735,  0.8546],
          [ 0.9660, -0.3096,  0.7235,  ...,  0.7404,  0.9809,  0.3164]]],


        [[[-0.6689, -0.3632,  0.8941,  ...,  0.0585,  0.1784,  0.8069],
          [-0.1678,  0.5656,  0.9212,  ...,  0.4515, -1.0000,  0.2704],
          [ 0.7980, -0.6779,  0.0068,  ...,  0.8010, -0.1771, -0.7489],
          ...,
          [-0.3786, -0.4167,  0.1068,  ...,  0.3646, -0.9591,  0.6883],
          [-0.1401,  0.3082,  0.9009,  ..., -0.6789,  0.1785, -0.3770],
          [ 0.0663,  0.0209,  0.8108,  ...,  0.3384,  0.3455,  0.5107]]],


        [[[-0.9875, -0.7701,  0.5784,  ...,  0.6955,  0.2608, -0.1162],
          [ 0.2540,  0.043

In [22]:
len(sampler.examples)

256

In [23]:
model = train_model(img_shape=(1,28,28),
                    batch_size=train_loader.batch_size,
                    lr=1e-4,
                    beta1=0.0)

  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42
  rank_zero_deprecation(
Missing logger folder: models/MNIST/lightning_logs

  | Name | Type     | Params | In sizes       | Out sizes
---------------------------------------------------------------
0 | cnn  | CNNModel | 77.0 K | [1, 1, 28, 28] | [1]      
---------------------------------------------------------------
77.0 K    Trainable params
0         Non-trainable params
77.0 K    Total params
0.308     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


IsADirectoryError: [Errno 21] Is a directory: '/app/notebooks/ebm/demo'

## 4. Generate images <a name="generate"></a>

In [None]:
model.to(device)
pl.seed_everything(43)
callback = GenerateCallback(batch_size=4, vis_steps=8, num_steps=256)
imgs_per_step = callback.generate_imgs(model)
imgs_per_step = imgs_per_step.cpu()

In [None]:
for i in range(imgs_per_step.shape[1]):
    step_size = callback.num_steps // callback.vis_steps
    imgs_to_plot = imgs_per_step[step_size-1::step_size,i]
    imgs_to_plot = torch.cat([imgs_per_step[0:1,i],imgs_to_plot], dim=0)
    grid = torchvision.utils.make_grid(imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, range=(-1,1), pad_value=0.5, padding=2)
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(8,8))
    plt.imshow(grid)
    plt.xlabel("Generation iteration")
    plt.xticks([(imgs_per_step.shape[-1]+2)*(0.5+j) for j in range(callback.vis_steps+1)],
               labels=[1] + list(range(step_size,imgs_per_step.shape[0]+1,step_size)))
    plt.yticks([])
    plt.show()

In [None]:
with torch.no_grad():
    rand_imgs = torch.rand((128,) + model.hparams.img_shape).to(model.device)
    rand_imgs = rand_imgs * 2 - 1.0
    rand_out = model.cnn(rand_imgs).mean()
    print(f"Average score for random images: {rand_out.item():4.2f}")

In [None]:
with torch.no_grad():
    train_imgs,_ = next(iter(train_loader))
    train_imgs = train_imgs.to(model.device)
    train_out = model.cnn(train_imgs).mean()
    print(f"Average score for training images: {train_out.item():4.2f}")

In [None]:
@torch.no_grad()
def compare_images(img1, img2):
    imgs = torch.stack([img1, img2], dim=0).to(model.device)
    score1, score2 = model.cnn(imgs).cpu().chunk(2, dim=0)
    grid = torchvision.utils.make_grid([img1.cpu(), img2.cpu()], nrow=2, normalize=True, range=(-1,1), pad_value=0.5, padding=2)
    grid = grid.permute(1, 2, 0)
    plt.figure(figsize=(4,4))
    plt.imshow(grid)
    plt.xticks([(img1.shape[2]+2)*(0.5+j) for j in range(2)],
               labels=["Original image", "Transformed image"])
    plt.yticks([])
    plt.show()
    print(f"Score original image: {score1:4.2f}")
    print(f"Score transformed image: {score2:4.2f}")

In [None]:
test_imgs, _ = next(iter(test_loader))
exmp_img = test_imgs[0].to(model.device)

In [None]:
img_noisy = exmp_img + torch.randn_like(exmp_img) * 0.3
img_noisy.clamp_(min=-1.0, max=1.0)
compare_images(exmp_img, img_noisy)

In [None]:
img_flipped = exmp_img.flip(dims=(1,2))
compare_images(exmp_img, img_flipped)

In [None]:
img_tiny = torch.zeros_like(exmp_img)-1
img_tiny[:,exmp_img.shape[1]//2:,exmp_img.shape[2]//2:] = exmp_img[:,::2,::2]
compare_images(exmp_img, img_tiny)