In [1]:
!nvidia-smi

Mon Aug 19 10:30:37 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Quadro RTX 6000                On  | 00000000:01:00.0 Off |                  Off |
| 46%   68C    P2             262W / 260W |  20529MiB / 24576MiB |    100%   E. Process |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Quadro RTX 6000                On  | 00000000:25:00.0 Off |  

In [7]:
# Install the third-party packages used by this notebook into the kernel's environment.
# NOTE(review): none of the packages below is version-pinned, so re-running this cell
# later may resolve different releases than the ones the notebook was executed with.
%pip install torch torchvision
%pip install torch_geometric
%pip install lightning
%pip install wandb
%pip install scikit-image
%pip install egnn-pytorch
%pip install matplotlib
%pip install seaborn

Collecting torch
  Downloading torch-2.4.0-cp312-cp312-manylinux1_x86_64.whl.metadata (26 kB)
Collecting filelock (from torch)
  Downloading filelock-3.15.4-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy (from torch)
  Downloading sympy-1.13.2-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.3-py3-none-any.whl.metadata (5.1 kB)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2024.6.1-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86

In [2]:
## Standard libraries
import os
import json
import math
import numpy as np
import random

## Imports for plotting
import matplotlib.pyplot as plt
from matplotlib import cm
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# Metrics
import wandb
# PyTorch Lightning
import lightning as pl
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# --- Filesystem locations -----------------------------------------------------
# Folder the datasets (e.g. CIFAR10) are read from / downloaded into
DATASET_PATH = "./data"
# Folder below which trained models are stored
CHECKPOINT_PATH = "./saved_models/ebm"

# --- Reproducibility ----------------------------------------------------------
# Seed every random number generator that Lightning knows about
pl.seed_everything(42)
# Ask cuDNN for deterministic kernels and switch off its auto-tuner
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Compute device -----------------------------------------------------------
# First CUDA device when one is available, CPU otherwise
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

  set_matplotlib_formats('svg', 'pdf') # For export
Seed set to 42


Device: cuda:0


In [3]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mwc5118[0m ([33miclac[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# Convert images to tensors and shift/scale them with mean 0.5 and std 0.5
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# CIFAR10 training split followed by the test split, both downloaded into
# DATASET_PATH when missing. No separate validation split is created here.
train_set, test_set = (
    CIFAR10(root=DATASET_PATH, train=use_train_split, transform=transform, download=True)
    for use_train_split in (True, False)
)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Training batches: reshuffled every epoch, incomplete last batch dropped, pinned host memory
train_loader = data_utils.DataLoader(
    dataset=train_set,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)
# Evaluation batches: fixed order, every sample kept
test_loader = data_utils.DataLoader(
    dataset=test_set,
    batch_size=512,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

In [6]:
class Swish(nn.Module):
    """Swish activation: every element of the input is multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
    
class ConfigurableCNNModel(nn.Module):
    """Configurable stack of convolutions followed by a two-layer fully connected head.

    Convolution ``i`` (0-based) has ``hidden_features * 2**i`` output channels and is
    followed by the activation. A 2x2 pooling layer (stride 2) is inserted after every
    ``pool_every``-th convolution and, if ``final_pooling`` is set, once more after the
    last one. The feature map is then flattened and passed through
    ``Linear -> activation -> Linear(out_dim)``.

    Args:
        input_channels: Number of channels of the input images.
        hidden_features: Output channels of the first convolution (doubled per layer).
        depth: Number of convolution layers.
        out_dim: Output size of the final linear layer.
        activation_fn: Zero-argument callable returning the activation module. It is
            called once and the resulting instance is reused at every position.
        pool_type: ``'max'`` or ``'avg'``.
        pool_every: Insert a pooling layer after every ``pool_every``-th convolution.
        kernel_size: Kernel size of every convolution (used as an integer in the
            output-size arithmetic below).
        stride: Stride of every convolution.
        padding: Padding of every convolution.
        final_pooling: Whether to append one more pooling layer after the last convolution.
        input_size: ``(height, width)`` of the input images; needed to size the first
            linear layer.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``pool_type`` is unknown, or if the configured convolutions and
            poolings would shrink the feature map below 1x1 for ``input_size``.
    """

    def __init__(self, 
                 input_channels=3, 
                 hidden_features=64, 
                 depth=4, 
                 out_dim=1, 
                 activation_fn=Swish, 
                 pool_type='max', 
                 pool_every=2, 
                 kernel_size=3,
                 stride=1,  # Adjusting stride for maintaining dimensions
                 padding=1, # Adjusting padding for maintaining dimensions
                 final_pooling=True,
                 input_size=(32, 32),  # Input size of the image
                 **kwargs):
        super().__init__()

        # List to hold the layers
        cnn_layers = []
        in_channels = input_channels

        # Activation function (a single instance shared by all positions)
        self.activation_fn = activation_fn()

        # Pooling layer setup
        if pool_type == 'max':
            pooling_layer = nn.MaxPool2d(kernel_size=2, stride=2)
        elif pool_type == 'avg':
            pooling_layer = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            raise ValueError("pool_type must be either 'max' or 'avg'")

        # Track the current spatial dimensions of the feature map
        current_height, current_width = input_size

        # Generate CNN layers based on the specified depth
        for i in range(depth):
            out_channels = hidden_features * (2 ** i)
            cnn_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding))
            cnn_layers.append(self.activation_fn)
            in_channels = out_channels  # Update in_channels for the next layer

            # Calculate the output size after this convolution
            current_height = (current_height - kernel_size + 2 * padding) // stride + 1
            current_width = (current_width - kernel_size + 2 * padding) // stride + 1
            self._require_spatial(current_height, current_width, f"convolution {i + 1}", input_size)

            # Apply pooling if needed
            if (i + 1) % pool_every == 0:
                cnn_layers.append(pooling_layer)
                current_height //= 2
                current_width //= 2
                self._require_spatial(current_height, current_width,
                                      f"the pooling after convolution {i + 1}", input_size)

        # Optionally add a final pooling layer to reduce the spatial dimension further
        if final_pooling:
            cnn_layers.append(pooling_layer)
            current_height //= 2
            current_width //= 2
            self._require_spatial(current_height, current_width, "the final pooling", input_size)

        # Flatten the output and add fully connected layers
        cnn_layers.append(nn.Flatten())

        # Determine the flattened size after convolution and pooling
        flattened_size = in_channels * current_height * current_width

        cnn_layers.append(nn.Linear(flattened_size, in_channels))
        cnn_layers.append(self.activation_fn)
        cnn_layers.append(nn.Linear(in_channels, out_dim))

        # Wrap the list of layers into an nn.Sequential module
        self.cnn_layers = nn.Sequential(*cnn_layers)

    @staticmethod
    def _require_spatial(height, width, stage, input_size):
        """Fail at construction time when the feature map has no pixels left after ``stage``."""
        if height < 1 or width < 1:
            raise ValueError(
                f"Feature map of size {height}x{width} after {stage} is empty for input_size={tuple(input_size)}; "
                "reduce depth, use fewer pooling stages, or increase input_size"
            )

    def forward(self, x):
        """Run the layer stack on ``x`` and drop the last dimension if it has size 1."""
        x = self.cnn_layers(x).squeeze(dim=-1)
        return x

In [7]:
from torchvision.models import resnet18

class ConfigurableResNetModel(nn.Module):
    """ResNet-18 with ImageNet weights, no stem max-pooling, and a small fully connected head.

    Args:
        input_channels: Channels of the input images. For any value other than 3 the
            first convolution is replaced by a new 3x3, stride-1 convolution.
        hidden_features: Width of the hidden layer in the replacement head.
        out_dim: Output size of the replacement head.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_channels=3, hidden_features=512, out_dim=1, **kwargs):
        super().__init__()

        backbone = resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)

        # The loaded first convolution is kept for 3-channel input; other channel
        # counts get a freshly created one.
        if input_channels != 3:
            backbone.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)

        # No max-pooling in the stem (CIFAR-10 images are small)
        backbone.maxpool = nn.Identity()
        # Replacement head: 512 -> hidden_features -> out_dim
        backbone.fc = nn.Sequential(
            nn.Linear(512, hidden_features),
            Swish(),
            nn.Linear(hidden_features, out_dim),
        )

        self.resnet = backbone

    def forward(self, x):
        """Apply the network to ``x`` and drop the last dimension if it has size 1."""
        return self.resnet(x).squeeze(dim=-1)

In [8]:
class Sampler:
    """Buffer of previously sampled images plus the MCMC routine that refines them.

    Args:
        model: Network whose output is maximised during sampling (its negated output is
            what the update steps descend on).
        img_shape: Shape of a single image, e.g. ``(channels, height, width)``.
        sample_size: Number of images returned by every call to ``sample_new_exmps``.
        max_len: Maximum number of images kept in the buffer.
    """

    def __init__(self, model, img_shape, sample_size, max_len=8192):
        self.model = model
        self.img_shape = tuple(img_shape)  # tuple so it can be concatenated with a batch dimension
        self.sample_size = sample_size
        self.max_len = max_len
        # The buffer starts with `sample_size` single images drawn uniformly from [-1, 1)
        self.examples = [(torch.rand((1,) + self.img_shape) * 2 - 1) for _ in range(self.sample_size)]

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if it has no parameters)."""
        first_param = next(iter(self.model.parameters()), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def sample_new_exmps(self, steps=60, step_size=10):
        """
        Function for getting a new batch of "fake" images.
        Inputs:
            steps - Number of iterations in the MCMC algorithm
            step_size - Factor the (clamped) input gradient is multiplied with in every update
        Returns:
            Tensor with `sample_size` images on the model's device, detached from autograd.
        """
        # Choose on average 95% of the batch from the buffer, 5% generate from scratch
        n_new = np.random.binomial(self.sample_size, 0.05)
        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1
        n_old = self.sample_size - n_new
        parts = [rand_imgs]
        if n_old > 0:
            # torch.cat rejects an empty list, so the buffer draw is skipped when the
            # whole batch happens to be freshly generated.
            parts.append(torch.cat(random.choices(self.examples, k=n_old), dim=0))
        # Move the batch to wherever the model currently lives
        inp_imgs = torch.cat(parts, dim=0).detach().to(self._model_device())

        # Perform MCMC sampling
        inp_imgs = Sampler.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)
        # generate_samples marks its input as requiring grad; detach so neither the
        # buffer nor the caller carries that flag along.
        inp_imgs = inp_imgs.detach()

        # Add new images to the buffer and remove old ones if needed
        self.examples = list(inp_imgs.to(torch.device("cpu")).chunk(self.sample_size, dim=0)) + self.examples
        self.examples = self.examples[:self.max_len]
        return inp_imgs

    @staticmethod
    def generate_samples(model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):
        """
        Refine `inp_imgs` in place with `steps` noisy gradient updates on the model input.
        Inputs:
            model - Network used for sampling
            inp_imgs - Start images; modified in place and marked as requiring grad
            steps - Number of update iterations
            step_size - Factor the (clamped) input gradient is multiplied with
            return_img_per_step - If True, return a stack with the images after every step
        Returns:
            `inp_imgs` itself, or a tensor of shape (steps, *inp_imgs.shape) if
            `return_img_per_step` is True.
        The model's train/eval mode, the `requires_grad` flag of each parameter and the
        global grad mode are the same after the call as before it, also if an error occurs.
        """
        is_training = model.training
        params = list(model.parameters())
        # Remember the individual flags so frozen parameters stay frozen afterwards
        grad_flags = [p.requires_grad for p in params]

        imgs_per_step = []
        try:
            # Only gradients w.r.t. the input are needed during sampling
            model.eval()
            for p in params:
                p.requires_grad_(False)

            # Refilled in every iteration below; allocated with randn_like as before so
            # the sequence of random draws is unchanged.
            noise = torch.randn_like(inp_imgs)
            inp_imgs.requires_grad_(True)

            # Enable gradient calculation if not already the case
            with torch.enable_grad():
                for _ in range(steps):
                    # Part 1: add noise to the input.
                    noise.normal_(0, 0.005)
                    inp_imgs.data.add_(noise.data)
                    inp_imgs.data.clamp_(min=-1.0, max=1.0)

                    # Part 2: calculate gradients for the current input.
                    out_imgs = -model(inp_imgs)
                    out_imgs.sum().backward()
                    inp_imgs.grad.data.clamp_(-0.03, 0.03) # For stabilizing and preventing too high gradients

                    # Apply gradients to our current samples
                    inp_imgs.data.add_(-step_size * inp_imgs.grad.data)
                    inp_imgs.grad.detach_()
                    inp_imgs.grad.zero_()
                    inp_imgs.data.clamp_(min=-1.0, max=1.0)

                    if return_img_per_step:
                        imgs_per_step.append(inp_imgs.clone().detach())
        finally:
            for p, flag in zip(params, grad_flags):
                p.requires_grad_(flag)
            model.train(is_training)

        if return_img_per_step:
            return torch.stack(imgs_per_step, dim=0)
        else:
            return inp_imgs

In [9]:
class DeepEnergyModel(pl.LightningModule):
    """Lightning module that trains a network by contrasting real images with sampled ones.

    Args:
        model_class: Callable building the network; it receives ``**model_args``.
        img_shape: Shape of a single image, used for the sampler and the example input.
        batch_size: Number of images the sampler produces per training step.
        alpha: Weight of the regularization term on the squared network outputs.
        lr: Learning rate of the Adam optimizer.
        beta1: First beta coefficient of the Adam optimizer.
        **model_args: Forwarded to ``model_class``.
    """

    def __init__(self, model_class, img_shape, batch_size, alpha=0.1, lr=1e-4, beta1=0.0, **model_args):
        super().__init__()
        self.save_hyperparameters()

        network = model_class(**model_args)
        self.cnn = network
        self.sampler = Sampler(network, img_shape=img_shape, sample_size=batch_size)
        self.example_input_array = torch.zeros(1, *img_shape)

    def forward(self, x):
        """Return the network output for ``x``."""
        return self.cnn(x)

    def configure_optimizers(self):
        """Adam with a configurable first beta, decayed by a factor of 0.97 at every scheduler step."""
        # Momentum can hurt energy models because the loss surface moves with the
        # parameters, so beta1 defaults to 0.
        hp = self.hparams
        adam = optim.Adam(self.parameters(), lr=hp.lr, betas=(hp.beta1, 0.999))
        decay = optim.lr_scheduler.StepLR(adam, 1, gamma=0.97)
        return [adam], [decay]

    def training_step(self, batch, batch_idx):
        """Compute the regularized contrastive loss between real and sampled images."""
        real, _ = batch
        # A little noise on the real images keeps the model from relying on perfectly "clean" inputs
        real.add_(0.005 * torch.randn_like(real))
        real.clamp_(min=-1.0, max=1.0)

        # Fresh batch of sampled images
        fake = self.sampler.sample_new_exmps(steps=60, step_size=10)

        # One forward pass over both halves, then split the outputs again
        scores = self.cnn(torch.cat([real, fake], dim=0))
        real_out, fake_out = torch.chunk(scores, 2, dim=0)
        real_mean = real_out.mean()
        fake_mean = fake_out.mean()

        # Loss terms
        reg_loss = self.hparams.alpha * (real_out.pow(2) + fake_out.pow(2)).mean()
        cdiv_loss = fake_mean - real_mean
        loss = reg_loss + cdiv_loss

        # Logging
        to_log = (
            ('loss', loss),
            ('loss_regularization', reg_loss),
            ('loss_contrastive_divergence', cdiv_loss),
            ('metrics_avg_real', real_mean),
            ('metrics_avg_fake', fake_mean),
        )
        for key, value in to_log:
            self.log(key, value)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the difference between the outputs on uniform-random images and on unseen real images."""
        # What a validation step should measure for an energy-based model depends on
        # what we want to know about it; here real images are contrasted with pure noise.
        real, _ = batch
        noise_imgs = torch.rand_like(real) * 2 - 1

        scores = self.cnn(torch.cat([real, noise_imgs], dim=0))
        real_out, fake_out = torch.chunk(scores, 2, dim=0)
        real_mean = real_out.mean()
        fake_mean = fake_out.mean()

        self.log('val_contrastive_divergence', fake_mean - real_mean)
        self.log('val_fake_out', fake_mean)
        self.log('val_real_out', real_mean)

In [10]:
class GenerateCallback(pl.Callback):
    """Generates images from random noise at the end of training epochs and logs them.

    For every generated image, one grid showing intermediate sampling steps is sent to
    the trainer's logger through ``log_image``.
    """

    def __init__(self, batch_size=8, vis_steps=8, num_steps=256, every_n_epochs=5):
        super().__init__()
        self.batch_size = batch_size         # Number of images to generate
        self.vis_steps = vis_steps           # Number of steps within generation to visualize
        self.num_steps = num_steps           # Number of steps to take during generation
        self.every_n_epochs = every_n_epochs # Only log those images every N epochs (keeps the logged data small)

    def on_train_epoch_end(self, trainer, pl_module):
        # Skip for all other epochs, and when there is nothing to log to
        if trainer.current_epoch % self.every_n_epochs != 0 or trainer.logger is None:
            return
        # Generate images
        imgs_per_step = self.generate_imgs(pl_module)
        # Distance between visualized steps; at least 1 so the slice below stays valid
        # when vis_steps exceeds num_steps
        step_size = max(1, self.num_steps // self.vis_steps)
        # Plot and send to the logger, one grid per generated image
        for i in range(imgs_per_step.shape[1]):
            imgs_to_plot = imgs_per_step[step_size-1::step_size,i]
            grid = torchvision.utils.make_grid(imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, value_range=(-1,1))
            trainer.logger.log_image(f"generation_{i}", [grid], step=trainer.current_epoch)

    def generate_imgs(self, pl_module):
        """Run ``num_steps`` sampling steps from uniform noise and return the images after every step.

        The module's train/eval mode and the global grad mode are the same after the
        call as before it.
        """
        was_training = pl_module.training
        pl_module.eval()
        try:
            start_imgs = torch.rand((self.batch_size,) + tuple(pl_module.hparams["img_shape"])).to(pl_module.device)
            start_imgs = start_imgs * 2 - 1
            # Tracking gradients for sampling necessary; the context manager puts the
            # previous grad mode back instead of switching gradients off globally
            with torch.enable_grad():
                imgs_per_step = Sampler.generate_samples(pl_module.cnn, start_imgs, steps=self.num_steps, step_size=10, return_img_per_step=True)
        finally:
            pl_module.train(was_training)
        return imgs_per_step

In [11]:
class SamplerCallback(pl.Callback):
    """Logs a grid of randomly chosen images from the module's sampler buffer."""

    def __init__(self, num_imgs=32, every_n_epochs=5):
        super().__init__()
        self.every_n_epochs = every_n_epochs  # Log only on epochs divisible by this value
        self.num_imgs = num_imgs              # How many buffer images go into one grid

    def on_train_epoch_end(self, trainer, pl_module):
        # Nothing to do on the epochs in between
        if trainer.current_epoch % self.every_n_epochs != 0:
            return
        # Draw (with replacement) from the buffer and arrange the images four per row
        picked = random.choices(pl_module.sampler.examples, k=self.num_imgs)
        grid = torchvision.utils.make_grid(torch.cat(picked, dim=0), nrow=4, normalize=True, value_range=(-1,1))
        trainer.logger.log_image("sampler", [grid], step=trainer.current_epoch)

In [12]:
class OutlierCallback(pl.Callback):
    """Logs the mean network output on uniform-random images at the end of every training epoch.

    Args:
        batch_size: Number of random images to evaluate.
    """

    def __init__(self, batch_size=1024):
        super().__init__()
        self.batch_size = batch_size

    # `on_epoch_end` is not a hook of the `lightning` Callback API any more, so a method
    # with that name is never invoked by the Trainer; `on_train_epoch_end` is.
    def on_train_epoch_end(self, trainer, pl_module):
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                rand_imgs = torch.rand((self.batch_size,) + tuple(pl_module.hparams["img_shape"])).to(pl_module.device)
                rand_imgs = rand_imgs * 2 - 1.0
                rand_out = pl_module.cnn(rand_imgs).mean()
        finally:
            # Put the module back into the mode it was in
            pl_module.train(was_training)

        # `log_metrics` belongs to the generic logger interface, whereas
        # `experiment.add_scalar` only exists on TensorBoard's writer.
        if trainer.logger is not None:
            trainer.logger.log_metrics({"rand_out": rand_out.item()}, step=trainer.current_epoch)


In [13]:
def train_model(**kwargs):
    """Train a ``DeepEnergyModel`` on CIFAR10 and return it.

    Args:
        **kwargs: Forwarded unchanged to ``DeepEnergyModel``.

    Returns:
        The model instance that was passed to ``trainer.fit``.

    ``train_loader`` is used for training and ``test_loader`` as the validation loader.
    The Weights & Biases run is finished when the function exits, also if training
    raises or is interrupted.
    """
    wandb_logger = WandbLogger(project='ebm-cifar10', log_model="all")
    try:
        # Create a PyTorch Lightning trainer with the generation callback
        trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "CIFAR10"),
                             accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                             devices=1,
                             max_epochs=120,
                             gradient_clip_val=0.1,
                             logger=wandb_logger,
                             callbacks=[ModelCheckpoint(save_weights_only=True, mode="min", monitor='val_contrastive_divergence'),
                                        GenerateCallback(every_n_epochs=1),
                                        SamplerCallback(every_n_epochs=1),
                                        OutlierCallback(),
                                        LearningRateMonitor("epoch")
                                       ])
        pl.seed_everything(55)
        model = DeepEnergyModel(**kwargs)
        trainer.fit(model, train_loader, test_loader)
        # No testing as we are more interested in other properties
    finally:
        # Always close the run so a failed or interrupted training does not leave it open
        wandb.finish()
    return model

In [14]:
# Start a training run with the plain CNN as network.
model = train_model(
    # Arguments consumed by DeepEnergyModel itself
    model_class=ConfigurableCNNModel, img_shape=(3,32,32), batch_size=train_loader.batch_size,
    # Optimizer settings
    lr=5e-4, beta1=0.0,
    # Remaining keywords are handed on to the network constructor
    hidden_features=64, depth=3,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 55


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name | Type                 | Params | Mode  | In sizes       | Out sizes
-----------------------------------------------------------------------------------
0 | cnn  | ConfigurableCNNModel | 4.6 M  | train | [1, 3, 32, 32] | [1]      
-----------------------------------------------------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.263    Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Epoch 16:  56%|█████▌    | 109/195 [02:41<02:07,  0.67it/s, v_num=23sb]    

In [None]:
model

ConfigurableCNNModel(
  (activation_fn): Swish()
  (cnn_layers): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Swish()
    (2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Swish()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Swish()
    (7): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): Swish()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Swish()
    (12): Conv2d(2048, 4096, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): Swish()
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (16): Flat

In [20]:
wandb.finish()

0,1
lr-Adam,▁
trainer/global_step,▁

0,1
lr-Adam,0.001
trainer/global_step,0.0
