In [1]:
!uv pip install torch torchvision
!uv pip install lightning
!uv pip install wandb
!uv pip install scikit-image
!uv pip install jax
!uv pip install matplotlib
!uv pip install seaborn
!uv pip install equinox
!uv pip install jaxtyping
!uv pip install optax

[2K[2mResolved [1m12 packages[0m in 98ms[0m                                                 [0m
[2K[2mInstalled [1m11 packages[0m in 398ms[0m                                      [0m
 [32m+[39m [1mfilelock[0m[2m==3.15.4[0m
 [32m+[39m [1mfsspec[0m[2m==2024.6.1[0m
 [32m+[39m [1mjinja2[0m[2m==3.1.4[0m
 [32m+[39m [1mmarkupsafe[0m[2m==2.1.5[0m
 [32m+[39m [1mmpmath[0m[2m==1.3.0[0m
 [32m+[39m [1mnetworkx[0m[2m==3.3[0m
 [32m+[39m [1mnumpy[0m[2m==2.1.0[0m
 [32m+[39m [1mpillow[0m[2m==10.4.0[0m
 [32m+[39m [1msympy[0m[2m==1.13.2[0m
 [32m+[39m [1mtorch[0m[2m==2.4.0[0m
 [32m+[39m [1mtorchvision[0m[2m==0.19.0[0m
[2K[2mResolved [1m26 packages[0m in 192ms[0m                                                [0m
[2K[2mInstalled [1m15 packages[0m in 41ms[0mning==2.4.0                            [0m
 [32m+[39m [1maiohappyeyeballs[0m[2m==2.4.0[0m
 [32m+[39m [1maiohttp[0m[2m==3.10.5[0m
 [32m+[39m [1maiosign

In [1]:
import wandb

import torch
import torch.utils.data as data_utils
import torchvision

import jax
import jax.numpy as jnp
import equinox as eqx
import optax


from torchvision import transforms
from torchvision.datasets import CIFAR10

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "./data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "./saved_models/ebm"
NOISE_LEVEL = 0.1
SEED = 5678

key = jax.random.PRNGKey(SEED)

In [2]:
wandb.login()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwc5118[0m ([33miclac[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))
                               ])

# Loading the training dataset. We need to split it into a training and validation part
train_set = CIFAR10(root=DATASET_PATH, train=True, transform=transform, download=True)

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=transform, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
train_loader = data_utils.DataLoader(train_set, batch_size=64, shuffle=True,  drop_last=True,  num_workers=4, pin_memory=True)
test_loader  = data_utils.DataLoader(test_set,  batch_size=128, shuffle=False, drop_last=False, num_workers=4)

In [5]:
dummy_x, dummy_y = next(iter(train_loader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x3x32x32
print(dummy_y.shape)  # 64
print(dummy_y)

(64, 3, 32, 32)
(64,)
[1 2 8 0 3 7 7 3 9 9 7 2 5 4 4 9 8 4 5 9 4 8 6 5 2 8 5 5 6 6 6 7 0 1 6 6 4
 7 3 0 9 3 1 5 3 9 9 6 3 4 3 1 1 1 6 2 6 1 6 3 2 2 7 3]


In [6]:
class CNN(eqx.Module):
    layers: list
    input_channels: int
    hidden_features: int
    depth: int
    out_dim: int
    activation_fn: callable
    pool_type: str
    pool_every: int
    kernel_size: int
    stride: int
    padding: int
    final_pooling: bool
    input_size: tuple

    def __init__(self, key, 
                 input_channels=3, 
                 hidden_features=64, 
                 depth=4, 
                 out_dim=1, 
                 activation_fn=jax.nn.swish, 
                 pool_type='max', 
                 pool_every=2, 
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 final_pooling=True,
                 input_size=(32, 32)):
        
        self.input_channels = input_channels
        self.hidden_features = hidden_features
        self.depth = depth
        self.out_dim = out_dim
        self.activation_fn = activation_fn
        self.pool_type = pool_type
        self.pool_every = pool_every
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.final_pooling = final_pooling
        self.input_size = input_size

        keys = jax.random.split(key, depth + 2)  # +2 for final linear layers
        
        self.layers = []
        in_channels = input_channels
        current_height, current_width = input_size
        
        for i in range(depth):
            out_channels = hidden_features * (2 ** i)
            self.layers.append(eqx.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, key=keys[i]))
            self.layers.append(activation_fn)
            in_channels = out_channels

            current_height = (current_height - kernel_size + 2 * padding) // stride + 1
            current_width = (current_width - kernel_size + 2 * padding) // stride + 1

            if (i + 1) % pool_every == 0:
                if pool_type == 'max':
                    self.layers.append(eqx.nn.MaxPool2d(kernel_size=2, stride=2))
                elif pool_type == 'avg':
                    self.layers.append(eqx.nn.AvgPool2d(kernel_size=2, stride=2))
                current_height //= 2
                current_width //= 2

        if final_pooling:
            if pool_type == 'max':
                self.layers.append(eqx.nn.MaxPool2d(kernel_size=2, stride=2))
            elif pool_type == 'avg':
                self.layers.append(eqx.nn.AvgPool2d(kernel_size=2, stride=2))
            current_height //= 2
            current_width //= 2

        self.layers.append(jnp.ravel)

        flattened_size = in_channels * current_height * current_width
        self.layers.append(eqx.nn.Linear(flattened_size, in_channels, key=keys[-2]))
        self.layers.append(activation_fn)
        self.layers.append(eqx.nn.Linear(in_channels, out_dim, key=keys[-1]))

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x.squeeze(axis=-1)

In [7]:
key, subkey = jax.random.split(key, 2)
model = CNN(subkey)

In [9]:
print(model)

CNN(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[64,3,3,3],
      bias=f32[64,1,1],
      in_channels=3,
      out_channels=64,
      kernel_size=(3, 3),
      stride=(1, 1),
      padding=((1, 1), (1, 1)),
      dilation=(1, 1),
      groups=1,
      use_bias=True,
      padding_mode='ZEROS'
    ),
    <wrapped function silu>,
    Conv2d(
      num_spatial_dims=2,
      weight=f32[128,64,3,3],
      bias=f32[128,1,1],
      in_channels=64,
      out_channels=128,
      kernel_size=(3, 3),
      stride=(1, 1),
      padding=((1, 1), (1, 1)),
      dilation=(1, 1),
      groups=1,
      use_bias=True,
      padding_mode='ZEROS'
    ),
    <wrapped function silu>,
    MaxPool2d(
      init=-inf,
      operation=<function max>,
      num_spatial_dims=2,
      kernel_size=(2, 2),
      stride=(2, 2),
      padding=((0, 0), (0, 0)),
      use_ceil=False
    ),
    Conv2d(
      num_spatial_dims=2,
      weight=f32[256,128,3,3],
      bias=f32[256,1,1],
      in_channel

In [14]:
def stein_score(model, x):
    score_fn = jax.grad(model, argnums=0)
    return score_fn(x)

def denoising_score_matching_loss(model, x, key):
    noise = jax.random.normal(key, x.shape) * NOISE_LEVEL
    x_noisy = x + noise
    scores = jax.vmap(stein_score, in_axes=(None, 0))(model, x_noisy)
    loss = 0.5 * jnp.mean(jnp.linalg.norm(jnp.reshape(scores + (noise / NOISE_LEVEL), (scores.shape[0], -1)), axis=1, ord=2) ** 2)

    return loss

def sample_langevin_dynamics(key, model, x, num_steps=10, step_size=0.01):
    def step(i, x):
        noise = jax.random.normal(jax.random.fold_in(key, i), shape=x.shape)
        grad = stein_score(model, x)
        x = x - (step_size / 2.) * grad + jnp.sqrt(step_size) * noise
        return x
    
    return jax.lax.fori_loop(0, num_steps, step, x)

vectorized_sample_langevin_dynamics = eqx.filter_jit(jax.vmap(sample_langevin_dynamics, in_axes=(0, None, 0)))


def sample_images(key, model, num_samples=32):
    z = jax.random.normal(key, (num_samples, 3, 32, 32))
    samples = vectorized_sample_langevin_dynamics(jax.random.split(key, num_samples), model, z)
    return samples

In [None]:
batch_size, channels, height, width = x.shape  # Unpack dimensions
input_dim = channels * height * width

# Repeat x for num_slices slices
x = jnp.tile(x[:, None, :, :, :], (1, num_slices, 1, 1, 1)).reshape(
    -1, channels, height, width
)

# Generate random projection vectors
key, subkey = jax.random.split(key)
vectors = jax.random.normal(subkey, shape=x.shape)
vectors = vectors / (jnp.linalg.norm(vectors, axis=(1, 2, 3), keepdims=True) + 1e-8)

h_x, h_x_v = eqx.filter_jvp(jax.vmap(model), (x,), (vectors,))

# Compute loss components
loss_1 = jnp.sum(h_x_v * vectors, axis=(-3, -2, -1))
loss_2 = 0.5 * jnp.sum(h_x * vectors, axis=(-3, -2, -1)) ** 2

loss = loss_1 + loss_2

In [27]:
batch_size, channels, height, width = dummy_x.shape
input_dim = channels * height * width

x = jnp.tile(dummy_x[:, None, :, :, :], (1, 5, 1, 1, 1)).reshape(
    -1, channels, height, width
)

key, subkey = jax.random.split(key)
vectors = jax.random.normal(subkey, shape=x.shape)
vectors = vectors / (jnp.sqrt(jnp.sum(vectors ** 2, axis=(1, 2, 3), keepdims=True)) + 1e-8)

In [29]:
eqx.filter_jvp(model, (x[0],), (vectors[0],))

(Array(-0.00144907, dtype=float32), Array(3.3390643e-07, dtype=float32))

In [26]:
jnp.sqrt(jnp.sum(vectors ** 2, axis=(1, 2, 3), keepdims=True)).shape

(320, 1, 1, 1)

In [31]:
value, grads = eqx.filter_value_and_grad(denoising_score_matching_loss)(model, dummy_x, key)

(64, 3, 32, 32)


In [11]:
vectorized_sample_langevin_dynamics(key, model, dummy_x[:5]).shape

(5, 3, 32, 32)

In [15]:
samples = sample_images(key, model, num_samples=12)

(32, 3, 32, 32)

In [12]:
def train(
    model: CNN,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    optim: optax.GradientTransformation,
    key,
    epochs: int = 100,
    print_every: int = 100
) -> CNN:
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def make_step(
        model: CNN,
        opt_state,
        x,
        key
    ):
        loss_value, grads = eqx.filter_value_and_grad(denoising_score_matching_loss)(model, x, key)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    step = 0
    for epoch in range(epochs):
        for x, _ in train_loader:
            x = x.numpy()
            subkey, key = jax.random.split(key)
            model, opt_state, train_loss = make_step(model, opt_state, x, subkey)
            step += 1
            if step % print_every == 0:
                print(f"{step=}, train_loss={train_loss.item()}")
    
    return model

In [13]:
optim = optax.adam(1e-3)

train(model, train_loader, test_loader, optim, key, 1000, 100)

step=0, train_loss=1526.641357421875
step=1, train_loss=1536.654052734375
step=2, train_loss=1526.6640625
step=3, train_loss=1525.7783203125
step=4, train_loss=1485.645263671875
step=5, train_loss=1422.9149169921875
step=6, train_loss=1352.9351806640625
step=7, train_loss=1238.6435546875
step=8, train_loss=1187.0799560546875
step=9, train_loss=1146.7607421875
step=10, train_loss=1072.43115234375
step=11, train_loss=1043.5189208984375
step=12, train_loss=1028.2183837890625
step=13, train_loss=1025.9178466796875
step=14, train_loss=1002.3946533203125
step=15, train_loss=994.133544921875


KeyboardInterrupt: 