### General Settings

Adjust the paths and hardware settings below to match your environment before running the notebook.

Pass `limit_train_batches`, `limit_val_batches` and `limit_test_batches` to `train_and_test` to limit how many batches are used (e.g. for a quick smoke test).

In [1]:
# Paths are joined by plain string concatenation elsewhere, so every directory keeps its trailing slash.
project_dir = '/Users/rajjain/PycharmProjects/ADRL-Course-Work/'
bitmoji_data_dir = '/Users/rajjain/Desktop/CourseWork/Bitmoji/'
data_dir = f'{project_dir}data/'

# Hardware: set use_gpu = True to train on the devices passed to train_and_test via gpu_num.
use_gpu = False
num_cpus = 2

In [2]:
from torch.nn import init, Linear, Sequential, Conv2d, PReLU, BatchNorm2d, Flatten, MaxPool2d, InstanceNorm2d
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningModule
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from torchvision import transforms
from datetime import datetime
from torchinfo import summary
from torch.optim import Adam
from torch import autograd
import torch
import numpy
import os
import gc

In [3]:
def custom_collate_fn(batch):
    """Stack the first field of every sample into one tensor and return it in a one-element list.

    Any other per-sample fields (e.g. the class label that ImageFolder yields) are dropped,
    so a batch unpacks as ``xs, = batch``.
    """
    first_items = [sample[0] for sample in batch]
    stacked = torch.stack(first_items)
    return [stacked]

# Model Definition

In [4]:
class MALAConstEBM(LightningModule):
    """
    EBM for Bitmoji images.
    We assume a form p_theta(x) = exp(-E_theta(x)) / z_theta

    ``self.model`` maps a batch of images to energies E_theta(x) of shape (N, 1). Each
    training/validation/test step lowers the mean energy of the real batch and raises the
    mean energy of samples obtained by running ``t`` MALA steps starting from that batch.

    Args:
        bs: Batch size used by every dataloader.
        t: Number of MALA steps run per batch in ``_common_step``.
    """
    # MALA step size: proposals are x - tau * dE/dx + sqrt(2 * tau) * N(0, I).
    tau = 0.01

    # One entry per conv block: (in_channels, out_channels, conv_kernel, pool_stride).
    # Each block is Conv2d -> PReLU -> MaxPool2d(3x3) -> BatchNorm2d, so module k of block b
    # sits at Sequential index 4 * b + k; Flatten is index 40 and the output Linear index 41.
    _conv_blocks = (
        (3, 6, 3, 1),
        (6, 6, 3, 1),
        (6, 9, 3, 1),
        (9, 9, 3, 2),
        (9, 9, 3, 2),
        (9, 12, 3, 1),
        (12, 12, 3, 1),
        (12, 12, 4, 1),
        (12, 15, 4, 1),
        (15, 15, 4, 1),
    )

    def __init__(self, bs, t):
        super().__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.t = t

        layers = []
        for in_ch, out_ch, kernel, pool_stride in self._conv_blocks:
            layers += [
                Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(kernel, kernel), stride=(1, 1)),
                PReLU(num_parameters=out_ch, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(pool_stride, pool_stride)),
                BatchNorm2d(num_features=out_ch),
            ]
        # A 3x128x128 input leaves 15 channels of 3x3 after the last block -> 135 features.
        layers += [Flatten(), Linear(in_features=135, out_features=1)]
        self.model = Sequential(*layers)

        self.initialise()
        self.float()

    def initialise(self):
        """Re-seed globally, Kaiming-init every conv weight and Xavier-init the output layer weight."""
        seed_everything(0)
        for layer in self.model:
            if isinstance(layer, Conv2d):
                # a=0.25 matches the initial slope of the PReLU that follows each conv.
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model[-1].weight)

    def _energy_and_grad(self, xs: torch.Tensor):
        """Return ``(E(xs) reshaped to (N,), dE/dxs)``, both detached from any autograd graph.

        Works on a detached alias of ``xs`` so the caller's tensor is not modified.
        """
        xs = xs.detach().requires_grad_(True)
        energy = self.model(xs)
        grad, = autograd.grad(outputs=energy.sum(), inputs=xs)
        return energy.detach().reshape(-1), grad

    def _sampler(self, step_method: str):
        """Return the bound single-step sampler for ``step_method``.

        Raises:
            ValueError: if ``step_method`` is not a supported name.
        """
        samplers = {
            'langevin_const': self.one_langevin_step,
            'langevin_decay': self.one_langevin_decay_step,
            'mala': self.one_mala_step,
        }
        if step_method not in samplers:
            raise ValueError(f'Unknown step_method {step_method!r}; expected one of {sorted(samplers)}')
        return samplers[step_method]

    def gen_samples(self, num_samples, steps=100, seed=0, step_method: str = 'langevin_const'):
        """Run ``steps`` sampler steps from uniform noise in [-1, 1]; return (num_samples, 3, 128, 128).

        Raises:
            ValueError: if ``step_method`` is not 'langevin_const', 'langevin_decay' or 'mala'.
        """
        step = self._sampler(step_method)
        seed_everything(seed)
        xs = 2 * torch.rand((num_samples, 3, 128, 128), dtype=torch.float, device=self.device) - 1
        for t in range(steps):
            xs = step(xs, t)
        return xs.clone()

    def get_transition(self, steps=100, jump=1, seed=0, step_method: str = 'langevin_const'):
        """Run one chain for ``steps * jump`` steps and return every ``jump``-th state, stacked.

        The returned tensor holds ``steps`` images (the state after steps 0, jump, 2*jump, ...).

        Raises:
            ValueError: if ``step_method`` is not 'langevin_const', 'langevin_decay' or 'mala'.
        """
        step = self._sampler(step_method)
        seed_everything(seed)
        transitions = []
        xs = 2 * torch.rand((1, 3, 128, 128), dtype=torch.float, device=self.device) - 1
        for t in range(steps * jump):
            xs = step(xs, t)
            if t % jump == 0:
                transitions.append(xs[0])
        return torch.stack(transitions)

    def one_langevin_step(self, xs: torch.Tensor, t):
        """Fixed gradient step (0.95) plus small noise (0.0005), clipped to [-1, 1].

        ``t`` is unused; it is accepted so every sampler shares one signature. Note the noise
        scale does not follow the sqrt(2 * step) relation of textbook Langevin dynamics.
        """
        _, grad = self._energy_and_grad(xs)
        xs = xs.detach() - 0.95 * grad + 0.0005 * torch.randn_like(xs)
        return xs.clamp(min=-1, max=1)

    def one_langevin_decay_step(self, xs: torch.Tensor, t):
        """Langevin step with step size 1e-4 / (t + 1) ** 0.75, clipped to [-1, 1]."""
        _, grad = self._energy_and_grad(xs)
        step_size = 1e-4 / (t + 1) ** 0.75
        xs = xs.detach() - step_size * grad + numpy.sqrt(2 * step_size) * torch.randn_like(xs)
        return xs.clamp(min=-1, max=1)

    def one_mala_step(self, xs: torch.Tensor, t):
        """One Metropolis-adjusted Langevin step per sample, with the result clipped to [-1, 1].

        Each sample independently accepts its proposal with probability
        min(1, p(x') q(x | x') / (p(x) q(x' | x))). The comparison is done in log space so a
        large log-ratio cannot overflow ``exp``. Proposals are clipped to [-1, 1] before being
        scored, so the Gaussian proposal density is only approximate near the boundary.
        ``t`` is unused.
        """
        tau = self.tau
        xs = xs.detach()
        batch_size = xs.shape[0]

        energy, grad = self._energy_and_grad(xs)
        proposal = xs - tau * grad + numpy.sqrt(2 * tau) * torch.randn_like(xs)
        proposal = proposal.clamp(min=-1, max=1)
        proposal_energy, proposal_grad = self._energy_and_grad(proposal)

        xs_flat = xs.reshape(batch_size, -1)
        proposal_flat = proposal.reshape(batch_size, -1)
        # Squared distances inside log q(x | x') (reverse) and log q(x' | x) (forward).
        reverse_sq = (xs_flat - proposal_flat + tau * proposal_grad.reshape(batch_size, -1)).pow(2).sum(dim=1)
        forward_sq = (proposal_flat - xs_flat + tau * grad.reshape(batch_size, -1)).pow(2).sum(dim=1)

        log_alpha = -proposal_energy + energy - reverse_sq / (4 * tau) + forward_sq / (4 * tau)
        # u <= min(exp(log_alpha), 1)  <=>  log(u) <= log_alpha, since u < 1.
        u = torch.rand_like(log_alpha)
        accept = torch.log(u) <= log_alpha

        mask = accept.reshape(-1, *([1] * (xs.dim() - 1)))
        return torch.where(mask, proposal, xs).clamp(min=-1, max=1)

    def mala_steps(self, xs: torch.Tensor):
        """Apply ``self.t`` MALA steps to ``xs`` and return the final (detached) samples."""
        for t in range(self.t):
            xs = self.one_mala_step(xs, t)
        return xs.clone()

    def _common_step(self, batch, btype):
        """Compute and log mean E(real) - mean E(MALA samples) for split ``btype``."""
        not_training = btype != 'train'
        xs, = batch
        real_score = self.model(xs).mean()
        mala_samples = self.mala_steps(xs.clone())
        mala_score = self.model(mala_samples).mean()
        loss = real_score - mala_score
        self.log(f'{btype}/loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True,
                 sync_dist=not_training)
        self.log(f'{btype}/real_score', real_score, on_step=False, on_epoch=True, logger=True, sync_dist=not_training)
        self.log(f'{btype}/fake_score', mala_score, on_step=False, on_epoch=True, logger=True, sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        # MALA needs input gradients, but Lightning evaluates with gradients disabled;
        # re-enable them only for the duration of this step.
        with torch.enable_grad():
            self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        with torch.enable_grad():
            self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def _bitmoji_dataloader(self, split: str, shuffle: bool) -> DataLoader:
        """Build a DataLoader over ``<bitmoji_data_dir><split>/`` with pixels scaled to [-1, 1].

        num_workers is 0 because of an issue with Jupyter and multi-process loading; a plain
        script could use ``num_cpus`` instead.
        """
        dataset = ImageFolder(bitmoji_data_dir + f'{split}/',
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # scales to [0, 1]
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # to [-1, 1]
                              ]))
        return DataLoader(dataset, self.bs, shuffle=shuffle, num_workers=0, collate_fn=custom_collate_fn)

    def train_dataloader(self):
        return self._bitmoji_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._bitmoji_dataloader('val', shuffle=False)

    def test_dataloader(self):
        return self._bitmoji_dataloader('test', shuffle=False)

    def summary(self) -> str:
        """Return a torchinfo layer table for a (10, 3, 128, 128) CPU input."""
        _summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _bitmoji = torch.randn((10, 3, 128, 128), dtype=torch.float)
        _summary_string = str(summary(model=self.model, input_data=_bitmoji, **_summary_kwargs))
        return _summary_string


## Summary

In [5]:
# bs and t do not affect the summary (it always feeds a fixed (10, 3, 128, 128) input).
print(MALAConstEBM(t=1, bs=1).summary())

Global seed set to 0


Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #
Sequential (Sequential)                  [10, 3, 128, 128]         [10, 1]                   --
├─Conv2d (0): 1-1                        [10, 3, 128, 128]         [10, 6, 126, 126]         168
├─PReLU (1): 1-2                         [10, 6, 126, 126]         [10, 6, 126, 126]         6
├─MaxPool2d (2): 1-3                     [10, 6, 126, 126]         [10, 6, 124, 124]         --
├─BatchNorm2d (3): 1-4                   [10, 6, 124, 124]         [10, 6, 124, 124]         12
├─Conv2d (4): 1-5                        [10, 6, 124, 124]         [10, 6, 122, 122]         330
├─PReLU (5): 1-6                         [10, 6, 122, 122]         [10, 6, 122, 122]         6
├─MaxPool2d (6): 1-7                     [10, 6, 122, 122]         [10, 6, 120, 120]         --
├─BatchNorm2d (7): 1-8                   [10, 6, 120, 120]         [10, 6, 120, 120]         12
├─Conv2d (8): 1-9                  

# Training Utilities

In [6]:
def train_and_test(max_epochs: int, tags: list[str], gpu_num: list[int],
                   model_class, model_kwargs: dict, model_desc: str,
                   limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0):
    """Train ``model_class(**model_kwargs)``, test its best checkpoint and save a description.

    Every call creates a fresh ``<project_dir>ebm/results/run_<UTC timestamp>/`` directory
    containing the ``best``/``last`` checkpoints (best = lowest ``val/loss``), TensorBoard logs
    under ``tf_logs`` and ``model_desc.md`` (the model's layer summary followed by ``model_desc``).

    Args:
        max_epochs: Number of training epochs.
        tags: Not used by this function; kept for call compatibility.
        gpu_num: Device indices given to the Trainer when the global ``use_gpu`` is True.
        model_class: LightningModule subclass to instantiate.
        model_kwargs: Keyword arguments for ``model_class``.
        model_desc: Free text appended to the saved model summary.
        limit_train_batches, limit_val_batches, limit_test_batches: Forwarded to ``Trainer``.
    """
    from datetime import timezone

    seed_everything(0, workers=True)

    # Naive UTC timestamp in the same format datetime.utcnow() produced (utcnow is deprecated in 3.12).
    folder_name = datetime.now(timezone.utc).replace(tzinfo=None).isoformat(sep="T", timespec="microseconds")
    results_dir = project_dir + f'ebm/results/run_{folder_name}/'
    os.makedirs(results_dir, exist_ok=False)  # refuse to reuse an existing run directory

    checkpoint_callback = ModelCheckpoint(monitor='val/loss', mode='min', dirpath=results_dir, filename='best',
                                          save_last=True)

    trainer_kwargs = dict(accelerator="gpu", devices=gpu_num) if use_gpu else dict()

    model = model_class(**model_kwargs)

    tf_logger = TensorBoardLogger(save_dir=results_dir, version='tf_logs', default_hp_metric=False)
    trainer = Trainer(default_root_dir=results_dir, max_epochs=max_epochs, callbacks=[checkpoint_callback],
                      logger=[tf_logger], log_every_n_steps=1, num_sanity_val_steps=0, deterministic=True,
                      limit_train_batches=limit_train_batches, limit_val_batches=limit_val_batches,
                      limit_test_batches=limit_test_batches,
                      **trainer_kwargs)
    trainer.fit(model)
    trainer.test(model, ckpt_path='best')

    # Named `description` to avoid shadowing the imported torchinfo `summary` function.
    description = model.summary() + '\n' + model_desc
    # The torchinfo table contains box-drawing characters ('├─'), so write UTF-8 explicitly
    # rather than relying on the platform's default encoding.
    with open(results_dir + 'model_desc.md', 'w', encoding='utf-8') as f:
        f.write(description)

    gc.collect()


# Train & Test

In [7]:
# Short smoke-test run: bs is the batch size, t the number of MALA steps per batch,
# and only 2 batches of each split are used.
train_and_test(
    max_epochs=2,
    tags=[],
    gpu_num=[],
    model_class=MALAConstEBM,
    model_kwargs={'bs': 10, 't': 1},
    model_desc='EBM Model with MALA with constant step size',
    limit_train_batches=2,
    limit_val_batches=2,
    limit_test_batches=2,
)

Global seed set to 0
Global seed set to 0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 14.0 K
-------------------------------------
14.0 K    Trainable params
0         Non-trainable params
14.0 K    Total params
0.056     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Restoring states from the checkpoint path at /Users/rajjain/PycharmProjects/ADRL-Course-Work/ebm/results/run_2022-11-09T14:05:28.298565/best.ckpt
Loaded model weights from checkpoint at /Users/rajjain/PycharmProjects/ADRL-Course-Work/ebm/results/run_2022-11-09T14:05:28.298565/best.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test/fake_score        -1.0585254430770874
        test/loss           0.10533779859542847
     test/real_score        -0.9531876444816589
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


# Plots & Analysis

In [8]:
def plot_side_by_side(title, samples, fname, size):
    """Plot the images in ``samples`` on a ``size`` x ``size`` grid and save it as a PNG.

    Args:
        title: Figure title.
        samples: Array indexed by image along axis 0; each ``samples[i]`` is passed to
            ``imshow`` (e.g. an (H, W, 3) integer image from ``convert_to_image``).
        fname: File name without extension, written under ``<project_dir>ebm/img_results/``.
        size: Number of rows and columns of the grid.

    Raises:
        ValueError: if ``samples`` contains more images than the grid has cells.
    """
    # `plt` is not imported anywhere else in this notebook; import it here so the function
    # does not fail with NameError.
    import matplotlib.pyplot as plt

    num_samples = samples.shape[0]
    if num_samples > size * size:
        raise ValueError(f'{num_samples} samples do not fit on a {size}x{size} grid')

    # squeeze=False keeps `axes` a 2-D array even when size == 1.
    fig, axes = plt.subplots(size, size, figsize=(8, 8), squeeze=False)
    fig.subplots_adjust(wspace=0.01, hspace=0.01, left=0, bottom=0, right=1, top=0.95)

    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()  # also hides grid cells left empty when there are fewer samples than cells
        if i < num_samples:
            ax.imshow(samples[i])

    fig.suptitle(title)
    out_dir = project_dir + 'ebm/img_results/'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(out_dir + f'{fname}.png')


def convert_to_image(ndarray):
    """Map values in [-1, 1] to rounded integer pixel intensities in [0, 255].

    This undoes the ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` transform used by the
    dataloaders and rescales to the 8-bit range. Inputs outside [-1, 1] are clipped first so
    the result is always a valid RGB intensity. The input array is not modified.
    """
    clipped = numpy.clip(ndarray, -1, 1)
    pixels = (clipped * 0.5 + 0.5) * 255  # [-1, 1] -> [0, 1] -> [0, 255]
    return numpy.round(pixels).astype(int)


def see_some_generations(model, model_type, step_method):
    """Draw 100 samples (100 sampler steps each) from ``model`` and save them as a 10x10 grid."""
    cls_name = type(model).__name__
    generated = model.gen_samples(num_samples=100, steps=100, step_method=step_method)
    # (N, C, H, W) tensor -> (N, H, W, C) array, the channel-last layout imshow expects.
    channels_last = generated.detach().numpy().transpose(0, 2, 3, 1)
    images = convert_to_image(channels_last)
    title = f'{cls_name} EBM Bitmoji - t = {model.t} - {model_type} - {step_method}'
    fname = f'{cls_name}_ebm_gen_t={model.t}_{model_type}_{step_method}'
    plot_side_by_side(title, images, fname, size=10)


def see_some_transitions(model, model_type, jump, step_method):
    """Save a 15x15 grid of one chain's state, sampled every ``jump`` steps (225 frames)."""
    cls_name = type(model).__name__
    frames = model.get_transition(steps=225, jump=jump, seed=0, step_method=step_method)
    frames = frames.detach().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC for imshow
    images = convert_to_image(frames)
    plot_side_by_side(
        f'{cls_name} Bitmoji Transitions - t = {model.t}, jump = {jump} - {model_type} - {step_method}',
        images,
        f'{cls_name}_ebm_trans_t={model.t}_jump={jump}_{model_type}_{step_method}',
        size=15,
    )

# Other Models Tried

In [None]:
class LangConstEBM(LightningModule):
    """
    EBM for Bitmoji images.
    We assume a form p_theta(x) = exp(-E_theta(x)) / z_theta

    ``self.model`` maps a batch of images to energies of shape (N, 1). Each step lowers the
    mean energy of the real batch and raises the mean energy of samples obtained by running
    ``t`` fixed-step Langevin-style updates starting from that batch.

    Args:
        bs: Batch size used by every dataloader.
        t: Number of sampler steps run per batch in ``_common_step``.
    """
    # The only sampler this class implements; `step_method` arguments must equal it.
    _supported_step_method = 'langevin_const'

    # One entry per conv block: (in_channels, out_channels, conv_kernel, pool_stride).
    # Each block is Conv2d -> PReLU -> MaxPool2d(3x3) -> BatchNorm2d, so module k of block b
    # sits at Sequential index 4 * b + k; Flatten is index 40 and the output Linear index 41.
    _conv_blocks = (
        (3, 6, 3, 1),
        (6, 6, 3, 1),
        (6, 9, 3, 1),
        (9, 9, 3, 2),
        (9, 9, 3, 2),
        (9, 12, 3, 1),
        (12, 12, 3, 1),
        (12, 12, 4, 1),
        (12, 15, 4, 1),
        (15, 15, 4, 1),
    )

    def __init__(self, bs, t):
        super().__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.t = t

        layers = []
        for in_ch, out_ch, kernel, pool_stride in self._conv_blocks:
            layers += [
                Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(kernel, kernel), stride=(1, 1)),
                PReLU(num_parameters=out_ch, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(pool_stride, pool_stride)),
                BatchNorm2d(num_features=out_ch),
            ]
        # A 3x128x128 input leaves 15 channels of 3x3 after the last block -> 135 features.
        layers += [Flatten(), Linear(in_features=135, out_features=1)]
        self.model = Sequential(*layers)

        self.initialise()
        self.float()

    def initialise(self):
        """Re-seed globally, Kaiming-init every conv weight and Xavier-init the output layer weight."""
        seed_everything(0)
        for layer in self.model:
            if isinstance(layer, Conv2d):
                # a=0.25 matches the initial slope of the PReLU that follows each conv.
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model[-1].weight)

    def _check_step_method(self, step_method: str):
        """Raise ValueError unless ``step_method`` names this class's only sampler."""
        if step_method != self._supported_step_method:
            raise ValueError(f'{type(self).__name__} only supports step_method='
                             f'{self._supported_step_method!r}, got {step_method!r}')

    def gen_samples(self, num_samples, steps=100, seed=0, step_method: str = 'langevin_const'):
        """Run ``steps`` sampler steps from uniform noise in [-1, 1]; return (num_samples, 3, 128, 128).

        ``step_method`` exists so the shared plotting helpers can call every model the same way;
        only 'langevin_const' is accepted.

        Raises:
            ValueError: for any other ``step_method``.
        """
        self._check_step_method(step_method)
        seed_everything(seed)
        xs = 2 * torch.rand((num_samples, 3, 128, 128), dtype=torch.float, device=self.device) - 1
        for t in range(steps):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def get_transition(self, steps=100, jump=1, seed=0, step_method: str = 'langevin_const'):
        """Run one chain for ``steps * jump`` steps and return every ``jump``-th state, stacked.

        Raises:
            ValueError: if ``step_method`` is not 'langevin_const'.
        """
        self._check_step_method(step_method)
        seed_everything(seed)
        transitions = []
        xs = 2 * torch.rand((1, 3, 128, 128), dtype=torch.float, device=self.device) - 1
        for t in range(steps * jump):
            xs = self.one_langevin_step(xs, t)
            if t % jump == 0:
                transitions.append(xs[0])
        return torch.stack(transitions)

    def one_langevin_step(self, xs: torch.Tensor, t):
        """Fixed gradient step (0.95) plus small noise (0.0005), clipped to [-1, 1].

        ``t`` is unused. The caller's tensor is not modified. Note the noise scale does not
        follow the sqrt(2 * step) relation of textbook Langevin dynamics.
        """
        xs = xs.detach().requires_grad_(True)
        energy = self.model(xs)
        grad, = autograd.grad(outputs=energy.sum(), inputs=xs)
        xs = xs.detach() - 0.95 * grad + 0.0005 * torch.randn_like(xs)
        return xs.clamp(min=-1, max=1)

    def langevin_steps(self, xs: torch.Tensor):
        """Apply ``self.t`` sampler steps to ``xs`` and return the final (detached) samples."""
        for t in range(self.t):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def _common_step(self, batch, btype):
        """Compute and log mean E(real) - mean E(samples) for split ``btype``."""
        not_training = btype != 'train'
        xs, = batch
        real_score = self.model(xs).mean()
        langevin_samples = self.langevin_steps(xs.clone())
        langevin_score = self.model(langevin_samples).mean()
        loss = real_score - langevin_score
        self.log(f'{btype}/loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True,
                 sync_dist=not_training)
        # Same per-term metrics as MALAConstEBM so runs can be compared side by side.
        self.log(f'{btype}/real_score', real_score, on_step=False, on_epoch=True, logger=True, sync_dist=not_training)
        self.log(f'{btype}/fake_score', langevin_score, on_step=False, on_epoch=True, logger=True,
                 sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        # The sampler needs input gradients, but Lightning evaluates with gradients disabled;
        # re-enable them only for the duration of this step.
        with torch.enable_grad():
            self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        with torch.enable_grad():
            self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def _bitmoji_dataloader(self, split: str, shuffle: bool) -> DataLoader:
        """Build a DataLoader over ``<bitmoji_data_dir><split>/`` with pixels scaled to [-1, 1]."""
        dataset = ImageFolder(bitmoji_data_dir + f'{split}/',
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # scales to [0, 1]
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # to [-1, 1]
                              ]))
        return DataLoader(dataset, self.bs, shuffle=shuffle, num_workers=0, collate_fn=custom_collate_fn)

    def train_dataloader(self):
        return self._bitmoji_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._bitmoji_dataloader('val', shuffle=False)

    def test_dataloader(self):
        return self._bitmoji_dataloader('test', shuffle=False)

    def summary(self) -> str:
        """Return a torchinfo layer table for a (10, 3, 128, 128) CPU input."""
        _summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _bitmoji = torch.randn((10, 3, 128, 128), dtype=torch.float)
        _summary_string = str(summary(model=self.model, input_data=_bitmoji, **_summary_kwargs))
        return _summary_string


class LangDecayEBM(LightningModule):
    """
    EBM for Bitmoji images.
    We assume a form p_theta(x) = exp(-E_theta(x)) / z_theta

    ``self.model`` maps a batch of images to energies of shape (N, 1). Each step lowers the
    mean energy of the real batch and raises the mean energy of samples obtained by running
    ``t`` Langevin steps with a decaying step size, starting from that batch.

    Args:
        bs: Batch size used by every dataloader.
        t: Number of sampler steps run per batch in ``_common_step``.
    """
    # The only sampler this class implements; `step_method` arguments must equal it.
    _supported_step_method = 'langevin_decay'

    # One entry per conv block: (in_channels, out_channels, conv_kernel, pool_stride).
    # Each block is Conv2d -> PReLU -> MaxPool2d(3x3) -> BatchNorm2d, so module k of block b
    # sits at Sequential index 4 * b + k; Flatten is index 40 and the output Linear index 41.
    _conv_blocks = (
        (3, 6, 3, 1),
        (6, 6, 3, 1),
        (6, 9, 3, 1),
        (9, 9, 3, 2),
        (9, 9, 3, 2),
        (9, 12, 3, 1),
        (12, 12, 3, 1),
        (12, 12, 4, 1),
        (12, 15, 4, 1),
        (15, 15, 4, 1),
    )

    def __init__(self, bs, t):
        super().__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.t = t

        layers = []
        for in_ch, out_ch, kernel, pool_stride in self._conv_blocks:
            layers += [
                Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(kernel, kernel), stride=(1, 1)),
                PReLU(num_parameters=out_ch, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(pool_stride, pool_stride)),
                BatchNorm2d(num_features=out_ch),
            ]
        # A 3x128x128 input leaves 15 channels of 3x3 after the last block -> 135 features.
        layers += [Flatten(), Linear(in_features=135, out_features=1)]
        self.model = Sequential(*layers)

        self.initialise()
        self.float()

    def initialise(self):
        """Re-seed globally, Kaiming-init every conv weight and Xavier-init the output layer weight."""
        seed_everything(0)
        for layer in self.model:
            if isinstance(layer, Conv2d):
                # a=0.25 matches the initial slope of the PReLU that follows each conv.
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
        init.xavier_normal_(self.model[-1].weight)

    def _check_step_method(self, step_method: str):
        """Raise ValueError unless ``step_method`` names this class's only sampler."""
        if step_method != self._supported_step_method:
            raise ValueError(f'{type(self).__name__} only supports step_method='
                             f'{self._supported_step_method!r}, got {step_method!r}')

    def gen_samples(self, num_samples, steps=100, seed=0, step_method: str = 'langevin_decay'):
        """Run ``steps`` sampler steps from uniform noise in [-1, 1]; return (num_samples, 3, 128, 128).

        ``step_method`` exists so the shared plotting helpers can call every model the same way;
        only 'langevin_decay' is accepted.

        Raises:
            ValueError: for any other ``step_method``.
        """
        self._check_step_method(step_method)
        seed_everything(seed)
        xs = 2 * torch.rand((num_samples, 3, 128, 128), dtype=torch.float, device=self.device) - 1
        for t in range(steps):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def get_transition(self, steps=100, jump=1, seed=0, step_method: str = 'langevin_decay'):
        """Run one chain for ``steps * jump`` steps and return every ``jump``-th state, stacked.

        Raises:
            ValueError: if ``step_method`` is not 'langevin_decay'.
        """
        self._check_step_method(step_method)
        seed_everything(seed)
        transitions = []
        xs = 2 * torch.rand((1, 3, 128, 128), dtype=torch.float, device=self.device) - 1
        for t in range(steps * jump):
            xs = self.one_langevin_step(xs, t)
            if t % jump == 0:
                transitions.append(xs[0])
        return torch.stack(transitions)

    def one_langevin_step(self, xs: torch.Tensor, t):
        """Langevin step with step size 1e-4 / (t + 1) ** 0.75, clipped to [-1, 1].

        The caller's tensor is not modified.
        """
        xs = xs.detach().requires_grad_(True)
        energy = self.model(xs)
        grad, = autograd.grad(outputs=energy.sum(), inputs=xs)
        step_size = 1e-4 / (t + 1) ** 0.75
        xs = xs.detach() - step_size * grad + numpy.sqrt(2 * step_size) * torch.randn_like(xs)
        return xs.clamp(min=-1, max=1)

    def langevin_steps(self, xs: torch.Tensor):
        """Apply ``self.t`` sampler steps to ``xs`` and return the final (detached) samples."""
        for t in range(self.t):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def _common_step(self, batch, btype):
        """Compute and log mean E(real) - mean E(samples) for split ``btype``."""
        not_training = btype != 'train'
        xs, = batch
        real_score = self.model(xs).mean()
        langevin_samples = self.langevin_steps(xs.clone())
        langevin_score = self.model(langevin_samples).mean()
        loss = real_score - langevin_score
        self.log(f'{btype}/loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True,
                 sync_dist=not_training)
        # Same per-term metrics as MALAConstEBM so runs can be compared side by side.
        self.log(f'{btype}/real_score', real_score, on_step=False, on_epoch=True, logger=True, sync_dist=not_training)
        self.log(f'{btype}/fake_score', langevin_score, on_step=False, on_epoch=True, logger=True,
                 sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        # The sampler needs input gradients, but Lightning evaluates with gradients disabled;
        # re-enable them only for the duration of this step.
        with torch.enable_grad():
            self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        with torch.enable_grad():
            self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def _bitmoji_dataloader(self, split: str, shuffle: bool) -> DataLoader:
        """Build a DataLoader over ``<bitmoji_data_dir><split>/`` with pixels scaled to [-1, 1]."""
        dataset = ImageFolder(bitmoji_data_dir + f'{split}/',
                              transform=transforms.Compose([
                                  transforms.ToTensor(),  # scales to [0, 1]
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # to [-1, 1]
                              ]))
        return DataLoader(dataset, self.bs, shuffle=shuffle, num_workers=0, collate_fn=custom_collate_fn)

    def train_dataloader(self):
        return self._bitmoji_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._bitmoji_dataloader('val', shuffle=False)

    def test_dataloader(self):
        return self._bitmoji_dataloader('test', shuffle=False)

    def summary(self) -> str:
        """Return a torchinfo layer table for a (10, 3, 128, 128) CPU input."""
        _summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _bitmoji = torch.randn((10, 3, 128, 128), dtype=torch.float)
        _summary_string = str(summary(model=self.model, input_data=_bitmoji, **_summary_kwargs))
        return _summary_string


class LangConstEBMReg(LightningModule):
    """
    EBM for Bitmoji images.
    We assume a form p_theta(x) = exp(-E_theta(x)) / z_theta

    ``self.model`` maps a (N, 3, 128, 128) batch to (N, 1) energies E_theta(x).
    Negative samples come from ``t`` fixed-step Langevin-style updates started at
    the data batch. The only explicit regularisation is Adam's weight decay.

    Args:
        bs: batch size used by every dataloader.
        t: number of Langevin steps run per training/validation/test batch.
    """

    # One entry per Conv -> PReLU -> MaxPool(3x3) -> BatchNorm stage:
    # (in_channels, out_channels, conv_kernel, pool_stride).  Modules are appended in
    # the same order as the original hand-written Sequential, so the Sequential
    # indices ('0'..'41') and therefore checkpoint/state_dict keys are unchanged.
    _STAGES = (
        (3, 6, 3, 1),
        (6, 6, 3, 1),
        (6, 9, 3, 1),
        (9, 9, 3, 2),
        (9, 9, 3, 2),
        (9, 12, 3, 1),
        (12, 12, 3, 1),
        (12, 12, 4, 1),
        (12, 15, 4, 1),
        (15, 15, 4, 1),
    )

    def __init__(self, bs, t):
        super(LangConstEBMReg, self).__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.t = t

        layers = []
        for c_in, c_out, kernel, pool_stride in self._STAGES:
            layers += [
                Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(kernel, kernel), stride=(1, 1)),
                PReLU(num_parameters=c_out, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(pool_stride, pool_stride)),
                BatchNorm2d(num_features=c_out),
            ]
        # A 128x128 input leaves 15 channels x 3 x 3 after the last stage = 135 features.
        layers += [Flatten(), Linear(in_features=135, out_features=1)]
        self.model = Sequential(*layers)

        self.initialise()
        self.float()

    def initialise(self):
        """Deterministically (seed 0) re-initialise the conv and linear weights.

        Every Conv2d gets Kaiming-normal init matched to the PReLU slope 0.25 and
        the final Linear gets Xavier-normal init.  Layers are visited in order, so
        the RNG draws match the previous index-based loop.  Biases, PReLU and
        norm parameters keep their PyTorch defaults.
        """
        seed_everything(0)
        for layer in self.model:
            if isinstance(layer, Conv2d):
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
            elif isinstance(layer, Linear):
                init.xavier_normal_(layer.weight)

    def _uniform_start(self, num_samples):
        """Return ``num_samples`` images drawn from U[-1, 1) on the model's device."""
        return 2 * torch.rand((num_samples, 3, 128, 128), dtype=torch.float, device=self.device) - 1

    def gen_samples(self, num_samples, steps=100, seed=0):
        """Run ``steps`` Langevin steps from uniform noise and return the final batch.

        NOTE(review): the model's current train/eval mode is used as-is; in train
        mode BatchNorm uses batch statistics and updates its running stats.
        """
        seed_everything(seed)
        xs = self._uniform_start(num_samples)
        for t in range(steps):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def get_transition(self, steps=100, jump=1, seed=0):
        """Run one chain for ``steps * jump`` Langevin steps and record every ``jump``-th state.

        Returns a (steps, 3, 128, 128) tensor; entry k is the state after step k * jump.
        """
        seed_everything(seed)
        transitions = []
        xs = self._uniform_start(1)
        for t in range(steps * jump):
            xs = self.one_langevin_step(xs, t)
            if t % jump == 0:
                transitions.append(xs[0])
        return torch.stack(transitions)

    def _energy_and_grad(self, xs):
        """Return ``(E(xs) reshaped to (N,), dE/dxs)``, both detached.

        Gradients are taken w.r.t. a detached alias of ``xs``, so the caller's
        tensor is not modified.  Grad mode is enabled locally so this also works
        when called under ``torch.no_grad()``.
        """
        with torch.enable_grad():
            xs_leaf = xs.detach().requires_grad_(True)
            energy = self.model(xs_leaf)
            grad = autograd.grad(outputs=energy.sum(), inputs=xs_leaf, only_inputs=True)[0]
        return energy.detach().reshape((-1,)), grad

    def one_langevin_step(self, xs: torch.Tensor, t):
        """One fixed-step update: ``clip(x - 0.95 * dE/dx + 0.0005 * noise, -1, 1)``.

        ``t`` (the step index) is accepted for interface compatibility but unused.
        Returns a new detached tensor; ``xs`` itself is left untouched.
        """
        _, grad = self._energy_and_grad(xs)
        next_xs = xs.detach() - 0.95 * grad + 0.0005 * torch.randn_like(xs)
        return next_xs.clip_(min=-1, max=1)

    def langevin_steps(self, xs: torch.Tensor):
        """Apply ``self.t`` Langevin steps to ``xs`` and return a fresh copy of the result."""
        for t in range(self.t):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def _common_step(self, batch, btype):
        """Compute and log ``mean E(real) - mean E(langevin samples)`` under ``{btype}/loss``."""
        not_training = btype != 'train'
        xs, = batch
        real_score = self.model(xs).mean()
        langevin_samples = self.langevin_steps(xs.clone())
        langevin_score = self.model(langevin_samples).mean()
        loss = real_score - langevin_score
        self.log(f'{btype}/loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True,
                 sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        # The Langevin steps enable gradients locally (see _energy_and_grad), so the
        # global grad mode no longer needs to be flipped here.
        # NOTE(review): enable_grad cannot override torch.inference_mode; if the
        # Trainer evaluates under inference mode, construct it with inference_mode=False.
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        # See validation_step: sampling enables gradients on its own.
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3, weight_decay=0.01)

    def _make_dataloader(self, split, shuffle):
        """Build a DataLoader over ``bitmoji_data_dir + split + '/'`` with pixels mapped to [-1, 1]."""
        bitmoji_dataset = ImageFolder(bitmoji_data_dir + split + '/',
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                      ]))
        # custom_collate_fn drops the ImageFolder labels and keeps only the stacked images.
        return DataLoader(bitmoji_dataset, self.bs, shuffle=shuffle, num_workers=0,
                          collate_fn=custom_collate_fn)

    def train_dataloader(self):
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader('val', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader('test', shuffle=False)

    def summary(self) -> str:
        """Return a torchinfo table of ``self.model`` for a random (10, 3, 128, 128) CPU batch."""
        _summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _bitmoji = torch.randn((10, 3, 128, 128), dtype=torch.float)
        _summary_string = str(summary(model=self.model, input_data=_bitmoji, **_summary_kwargs))
        return _summary_string


class MALAConstOPRegEBM(LightningModule):
    """
    EBM for Bitmoji images.
    We assume a form p_theta(x) = exp(-E_theta(x)) / z_theta

    ``self.model`` maps a (N, 3, 128, 128) batch to (N, 1) energies E_theta(x).
    Training negatives come from ``t`` MALA steps started at the data batch, and
    the loss adds a 0.1 * mean(E^2) penalty on real and sampled energies.
    Validation/test compare real images against uniform noise without MCMC.

    Args:
        bs: batch size used by every dataloader.
        t: number of MALA steps run per training batch.
    """
    tau = 0.01  # MALA step size

    # One entry per Conv -> PReLU -> MaxPool(3x3) -> BatchNorm stage:
    # (in_channels, out_channels, conv_kernel, pool_stride).  Modules are appended in
    # the same order as the original hand-written Sequential, so the Sequential
    # indices ('0'..'41') and therefore checkpoint/state_dict keys are unchanged.
    _STAGES = (
        (3, 6, 3, 1),
        (6, 6, 3, 1),
        (6, 9, 3, 1),
        (9, 9, 3, 2),
        (9, 9, 3, 2),
        (9, 12, 3, 1),
        (12, 12, 3, 1),
        (12, 12, 4, 1),
        (12, 15, 4, 1),
        (15, 15, 4, 1),
    )

    def __init__(self, bs, t):
        super(MALAConstOPRegEBM, self).__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.t = t

        layers = []
        for c_in, c_out, kernel, pool_stride in self._STAGES:
            layers += [
                Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(kernel, kernel), stride=(1, 1)),
                PReLU(num_parameters=c_out, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(pool_stride, pool_stride)),
                BatchNorm2d(num_features=c_out),
            ]
        # A 128x128 input leaves 15 channels x 3 x 3 after the last stage = 135 features.
        layers += [Flatten(), Linear(in_features=135, out_features=1)]
        self.model = Sequential(*layers)

        self.initialise()
        self.float()

    def initialise(self):
        """Deterministically (seed 0) re-initialise the conv and linear weights.

        Every Conv2d gets Kaiming-normal init matched to the PReLU slope 0.25 and
        the final Linear gets Xavier-normal init.  Layers are visited in order, so
        the RNG draws match the previous index-based loop.  Biases, PReLU and
        norm parameters keep their PyTorch defaults.
        """
        seed_everything(0)
        for layer in self.model:
            if isinstance(layer, Conv2d):
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
            elif isinstance(layer, Linear):
                init.xavier_normal_(layer.weight)

    def _uniform_start(self, num_samples):
        """Return ``num_samples`` images drawn from U[-1, 1) on the model's device."""
        return 2 * torch.rand((num_samples, 3, 128, 128), dtype=torch.float, device=self.device) - 1

    def gen_samples(self, num_samples, steps=100, seed=0):
        """Run ``steps`` fixed-step Langevin updates (not MALA) from uniform noise.

        NOTE(review): the model's current train/eval mode is used as-is; in train
        mode BatchNorm uses batch statistics and updates its running stats.
        """
        seed_everything(seed)
        xs = self._uniform_start(num_samples)
        for t in range(steps):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def get_transition(self, steps=100, jump=1, seed=0):
        """Run one Langevin chain for ``steps * jump`` steps and record every ``jump``-th state.

        Returns a (steps, 3, 128, 128) tensor; entry k is the state after step k * jump.
        """
        seed_everything(seed)
        transitions = []
        xs = self._uniform_start(1)
        for t in range(steps * jump):
            xs = self.one_langevin_step(xs, t)
            if t % jump == 0:
                transitions.append(xs[0])
        return torch.stack(transitions)

    def _energy_and_grad(self, xs):
        """Return ``(E(xs) reshaped to (N,), dE/dxs)``, both detached.

        Gradients are taken w.r.t. a detached alias of ``xs``, so the caller's
        tensor is not modified.  Grad mode is enabled locally so this also works
        when called under ``torch.no_grad()``.
        """
        with torch.enable_grad():
            xs_leaf = xs.detach().requires_grad_(True)
            energy = self.model(xs_leaf)
            grad = autograd.grad(outputs=energy.sum(), inputs=xs_leaf, only_inputs=True)[0]
        return energy.detach().reshape((-1,)), grad

    def one_langevin_step(self, xs: torch.Tensor, t):
        """One fixed-step update: ``clip(x - 0.95 * dE/dx + 0.0005 * noise, -1, 1)``.

        ``t`` is unused.  Returns a new detached tensor; ``xs`` is left untouched.
        """
        _, grad = self._energy_and_grad(xs)
        next_xs = xs.detach() - 0.95 * grad + 0.0005 * torch.randn_like(xs)
        return next_xs.clip_(min=-1, max=1)

    def one_mala_step(self, xs: torch.Tensor, t):
        """One Metropolis-adjusted Langevin (MALA) step with step size ``tau``.

        Proposal: ``x' = clip(x - tau * dE/dx + sqrt(2 tau) * noise, -1, 1)``.
        Each sample accepts ``x'`` with probability
        ``min(1, exp(-E(x') + E(x) - |x - x' + tau dE/dx'|^2 / (4 tau)
                                    + |x' - x + tau dE/dx|^2 / (4 tau)))``
        and otherwise keeps ``x``.  ``t`` is unused.  Returns a new detached
        tensor clipped to [-1, 1]; ``xs`` itself is left untouched.

        NOTE(review): clipping the proposal is not reflected in the proposal
        density, so the chain is only approximately MALA near the [-1, 1] bounds.
        """
        batch_size = xs.shape[0]
        xs = xs.detach()

        xs_energy, xs_grad = self._energy_and_grad(xs)
        proposal_xs = xs - self.tau * xs_grad + numpy.sqrt(2 * self.tau) * torch.randn_like(xs)
        proposal_xs.clip_(min=-1, max=1)
        proposal_energy, proposal_grad = self._energy_and_grad(proposal_xs)

        # Per-sample flattened views, so squared norms are taken over each whole image.
        xs_flat = xs.reshape((batch_size, -1))
        proposal_flat = proposal_xs.reshape((batch_size, -1))
        xs_grad_flat = xs_grad.reshape((batch_size, -1))
        proposal_grad_flat = proposal_grad.reshape((batch_size, -1))

        # log q(x | x') and log q(x' | x) up to the shared constant.
        reverse_sq = (xs_flat - proposal_flat + self.tau * proposal_grad_flat).norm(p=2, dim=1) ** 2
        forward_sq = (proposal_flat - xs_flat + self.tau * xs_grad_flat).norm(p=2, dim=1) ** 2

        log_ratio = -proposal_energy + xs_energy - reverse_sq / (4 * self.tau) + forward_sq / (4 * self.tau)
        prob = torch.exp(log_ratio)  # may overflow to inf; the min() below caps it at 1
        alpha = torch.minimum(prob, torch.ones_like(prob))

        # Drawn on alpha's device; on CPU this consumes the RNG exactly like torch.rand(alpha.shape).
        u = torch.rand_like(alpha)
        # Vectorised accept/reject.  A NaN alpha compares False and keeps x, as before.
        accept = (u <= alpha).reshape((batch_size,) + (1,) * (xs.dim() - 1))
        next_xs = torch.where(accept, proposal_xs, xs)
        return next_xs.clip_(min=-1, max=1)

    def mala_steps(self, xs: torch.Tensor):
        """Apply ``self.t`` MALA steps to ``xs`` and return a fresh copy of the result."""
        for t in range(self.t):
            xs = self.one_mala_step(xs, t)
        return xs.clone()

    def training_step(self, batch, batch_idx):
        """Return ``mean E(real) - mean E(mala) + 0.1 * mean(E(real)^2 + E(mala)^2)`` and log its parts."""
        xs, = batch

        real_outputs = self.model(xs)
        mala_samples = self.mala_steps(xs.clone())
        mala_outputs = self.model(mala_samples)

        # Penalise large energy magnitudes on both real and sampled images.
        reg_loss = 0.1 * (real_outputs ** 2 + mala_outputs ** 2).mean()

        real_score = real_outputs.mean()
        mala_score = mala_outputs.mean()
        div_loss = real_score - mala_score
        loss = div_loss + reg_loss

        self.log('train/loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=False)
        self.log('train/real_score', real_score, on_step=False, on_epoch=True, logger=True, sync_dist=False)
        self.log('train/fake_score', mala_score, on_step=False, on_epoch=True, logger=True, sync_dist=False)
        self.log('train/div_loss', div_loss, on_step=False, on_epoch=True, logger=True, sync_dist=False)
        self.log('train/reg_loss', reg_loss, on_step=False, on_epoch=True, logger=True, sync_dist=False)
        return loss

    def _common_eval_step(self, batch, btype):
        """Log real vs. uniform-noise energies and ``|mean E(real) - mean E(noise)|`` (no MCMC)."""
        xs, = batch
        fakes = 2 * torch.rand_like(xs) - 1

        real_score = self.model(xs).mean()
        fake_score = self.model(fakes).mean()
        div_loss = real_score - fake_score
        loss = torch.abs(div_loss)

        self.log(f'{btype}/real_score', real_score, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(f'{btype}/fake_score', fake_score, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(f'{btype}/div_loss', div_loss, on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(f'{btype}/loss', loss, on_step=False, on_epoch=True, logger=True, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        self._common_eval_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        self._common_eval_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def _make_dataloader(self, split, shuffle):
        """Build a DataLoader over ``bitmoji_data_dir + split + '/'`` with pixels mapped to [-1, 1]."""
        bitmoji_dataset = ImageFolder(bitmoji_data_dir + split + '/',
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                      ]))
        # custom_collate_fn drops the ImageFolder labels and keeps only the stacked images.
        return DataLoader(bitmoji_dataset, self.bs, shuffle=shuffle, num_workers=0,
                          collate_fn=custom_collate_fn)

    def train_dataloader(self):
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader('val', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader('test', shuffle=False)

    def summary(self) -> str:
        """Return a torchinfo table of ``self.model`` for a random (10, 3, 128, 128) CPU batch."""
        _summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _bitmoji = torch.randn((10, 3, 128, 128), dtype=torch.float)
        _summary_string = str(summary(model=self.model, input_data=_bitmoji, **_summary_kwargs))
        return _summary_string


class MALAConstInstanceEBM(LightningModule):
    """
    EBM for Bitmoji images.
    We assume a form p_theta(x) = exp(-E_theta(x)) / z_theta

    Same conv energy net as the BatchNorm variants but with affine InstanceNorm2d.
    ``self.model`` maps a (N, 3, 128, 128) batch to (N, 1) energies.  Negatives
    for training, validation and test come from ``t`` MALA steps started at the
    data batch.

    Args:
        bs: batch size used by every dataloader.
        t: number of MALA steps run per batch.
    """
    tau = 0.01  # MALA step size

    # One entry per Conv -> PReLU -> MaxPool(3x3) -> InstanceNorm stage:
    # (in_channels, out_channels, conv_kernel, pool_stride).  Modules are appended in
    # the same order as the original hand-written Sequential, so the Sequential
    # indices ('0'..'41') and therefore checkpoint/state_dict keys are unchanged.
    _STAGES = (
        (3, 6, 3, 1),
        (6, 6, 3, 1),
        (6, 9, 3, 1),
        (9, 9, 3, 2),
        (9, 9, 3, 2),
        (9, 12, 3, 1),
        (12, 12, 3, 1),
        (12, 12, 4, 1),
        (12, 15, 4, 1),
        (15, 15, 4, 1),
    )

    def __init__(self, bs, t):
        super(MALAConstInstanceEBM, self).__init__()
        self.save_hyperparameters()
        self.bs = bs
        self.t = t

        layers = []
        for c_in, c_out, kernel, pool_stride in self._STAGES:
            layers += [
                Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(kernel, kernel), stride=(1, 1)),
                PReLU(num_parameters=c_out, init=0.25),
                MaxPool2d(kernel_size=(3, 3), stride=(pool_stride, pool_stride)),
                InstanceNorm2d(num_features=c_out, affine=True),
            ]
        # A 128x128 input leaves 15 channels x 3 x 3 after the last stage = 135 features.
        layers += [Flatten(), Linear(in_features=135, out_features=1)]
        self.model = Sequential(*layers)

        self.initialise()
        self.float()

    def initialise(self):
        """Deterministically (seed 0) re-initialise the conv and linear weights.

        Every Conv2d gets Kaiming-normal init matched to the PReLU slope 0.25 and
        the final Linear gets Xavier-normal init.  Layers are visited in order, so
        the RNG draws match the previous index-based loop.  Biases, PReLU and
        norm parameters keep their PyTorch defaults.
        """
        seed_everything(0)
        for layer in self.model:
            if isinstance(layer, Conv2d):
                init.kaiming_normal_(layer.weight, a=0.25, nonlinearity='leaky_relu')
            elif isinstance(layer, Linear):
                init.xavier_normal_(layer.weight)

    def _uniform_start(self, num_samples):
        """Return ``num_samples`` images drawn from U[-1, 1) on the model's device."""
        return 2 * torch.rand((num_samples, 3, 128, 128), dtype=torch.float, device=self.device) - 1

    def gen_samples(self, num_samples, steps=100, seed=0):
        """Run ``steps`` fixed-step Langevin updates (not MALA) from uniform noise."""
        seed_everything(seed)
        xs = self._uniform_start(num_samples)
        for t in range(steps):
            xs = self.one_langevin_step(xs, t)
        return xs.clone()

    def get_transition(self, steps=100, jump=1, seed=0):
        """Run one Langevin chain for ``steps * jump`` steps and record every ``jump``-th state.

        Returns a (steps, 3, 128, 128) tensor; entry k is the state after step k * jump.
        """
        seed_everything(seed)
        transitions = []
        xs = self._uniform_start(1)
        for t in range(steps * jump):
            xs = self.one_langevin_step(xs, t)
            if t % jump == 0:
                transitions.append(xs[0])
        return torch.stack(transitions)

    def _energy_and_grad(self, xs):
        """Return ``(E(xs) reshaped to (N,), dE/dxs)``, both detached.

        Gradients are taken w.r.t. a detached alias of ``xs``, so the caller's
        tensor is not modified.  Grad mode is enabled locally so this also works
        when called under ``torch.no_grad()``.
        """
        with torch.enable_grad():
            xs_leaf = xs.detach().requires_grad_(True)
            energy = self.model(xs_leaf)
            grad = autograd.grad(outputs=energy.sum(), inputs=xs_leaf, only_inputs=True)[0]
        return energy.detach().reshape((-1,)), grad

    def one_langevin_step(self, xs: torch.Tensor, t):
        """One fixed-step update: ``clip(x - 0.95 * dE/dx + 0.0005 * noise, -1, 1)``.

        ``t`` is unused.  Returns a new detached tensor; ``xs`` is left untouched.
        """
        _, grad = self._energy_and_grad(xs)
        next_xs = xs.detach() - 0.95 * grad + 0.0005 * torch.randn_like(xs)
        return next_xs.clip_(min=-1, max=1)

    def one_mala_step(self, xs: torch.Tensor, t):
        """One Metropolis-adjusted Langevin (MALA) step with step size ``tau``.

        Proposal: ``x' = clip(x - tau * dE/dx + sqrt(2 tau) * noise, -1, 1)``.
        Each sample accepts ``x'`` with probability
        ``min(1, exp(-E(x') + E(x) - |x - x' + tau dE/dx'|^2 / (4 tau)
                                    + |x' - x + tau dE/dx|^2 / (4 tau)))``
        and otherwise keeps ``x``.  ``t`` is unused.  Returns a new detached
        tensor clipped to [-1, 1]; ``xs`` itself is left untouched.

        NOTE(review): clipping the proposal is not reflected in the proposal
        density, so the chain is only approximately MALA near the [-1, 1] bounds.
        """
        batch_size = xs.shape[0]
        xs = xs.detach()

        xs_energy, xs_grad = self._energy_and_grad(xs)
        proposal_xs = xs - self.tau * xs_grad + numpy.sqrt(2 * self.tau) * torch.randn_like(xs)
        proposal_xs.clip_(min=-1, max=1)
        proposal_energy, proposal_grad = self._energy_and_grad(proposal_xs)

        # Per-sample flattened views, so squared norms are taken over each whole image.
        xs_flat = xs.reshape((batch_size, -1))
        proposal_flat = proposal_xs.reshape((batch_size, -1))
        xs_grad_flat = xs_grad.reshape((batch_size, -1))
        proposal_grad_flat = proposal_grad.reshape((batch_size, -1))

        # log q(x | x') and log q(x' | x) up to the shared constant.
        reverse_sq = (xs_flat - proposal_flat + self.tau * proposal_grad_flat).norm(p=2, dim=1) ** 2
        forward_sq = (proposal_flat - xs_flat + self.tau * xs_grad_flat).norm(p=2, dim=1) ** 2

        log_ratio = -proposal_energy + xs_energy - reverse_sq / (4 * self.tau) + forward_sq / (4 * self.tau)
        prob = torch.exp(log_ratio)  # may overflow to inf; the min() below caps it at 1
        alpha = torch.minimum(prob, torch.ones_like(prob))

        # Drawn on alpha's device; on CPU this consumes the RNG exactly like torch.rand(alpha.shape).
        u = torch.rand_like(alpha)
        # Vectorised accept/reject.  A NaN alpha compares False and keeps x, as before.
        accept = (u <= alpha).reshape((batch_size,) + (1,) * (xs.dim() - 1))
        next_xs = torch.where(accept, proposal_xs, xs)
        return next_xs.clip_(min=-1, max=1)

    def mala_steps(self, xs: torch.Tensor):
        """Apply ``self.t`` MALA steps to ``xs`` and return a fresh copy of the result."""
        for t in range(self.t):
            xs = self.one_mala_step(xs, t)
        return xs.clone()

    def _common_step(self, batch, btype):
        """Compute ``mean E(real) - mean E(mala samples)`` and log it with both scores under ``btype/``."""
        not_training = btype != 'train'
        xs, = batch
        real_score = self.model(xs).mean()
        mala_samples = self.mala_steps(xs.clone())
        mala_score = self.model(mala_samples).mean()
        loss = real_score - mala_score
        self.log(f'{btype}/loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True,
                 sync_dist=not_training)
        self.log(f'{btype}/real_score', real_score, on_step=False, on_epoch=True, logger=True, sync_dist=not_training)
        self.log(f'{btype}/fake_score', mala_score, on_step=False, on_epoch=True, logger=True, sync_dist=not_training)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        # The MALA steps enable gradients locally (see _energy_and_grad), so the
        # global grad mode no longer needs to be flipped here.
        # NOTE(review): enable_grad cannot override torch.inference_mode; if the
        # Trainer evaluates under inference mode, construct it with inference_mode=False.
        self._common_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        # See validation_step: sampling enables gradients on its own.
        self._common_step(batch, 'test')

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def _make_dataloader(self, split, shuffle):
        """Build a DataLoader over ``bitmoji_data_dir + split + '/'`` with pixels mapped to [-1, 1]."""
        bitmoji_dataset = ImageFolder(bitmoji_data_dir + split + '/',
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),  # Gives a scaled version i.e., 0 to 1
                                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                      ]))
        # custom_collate_fn drops the ImageFolder labels and keeps only the stacked images.
        return DataLoader(bitmoji_dataset, self.bs, shuffle=shuffle, num_workers=0,
                          collate_fn=custom_collate_fn)

    def train_dataloader(self):
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader('val', shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader('test', shuffle=False)

    def summary(self) -> str:
        """Return a torchinfo table of ``self.model`` for a random (10, 3, 128, 128) CPU batch."""
        _summary_kwargs = dict(dtypes=[torch.float], depth=4, col_names=['input_size', 'output_size', 'num_params'],
                               row_settings=['depth', 'var_names'], verbose=0, device=torch.device('cpu'))
        _bitmoji = torch.randn((10, 3, 128, 128), dtype=torch.float)
        _summary_string = str(summary(model=self.model, input_data=_bitmoji, **_summary_kwargs))
        return _summary_string
