Skip to content

Commit

Permalink
Final touch
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Apr 6, 2024
1 parent 858cb32 commit b7c9b39
Show file tree
Hide file tree
Showing 19 changed files with 466 additions and 352 deletions.
40 changes: 32 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
## About

This repository implements the method proposed in [Score-based Diffusion Models in Function Space](https://arxiv.org/abs/2302.07400), i.e.,
This repository implements the method, denoising diffusion operator (DDO), proposed in [Score-based Diffusion Models in Function Space](https://arxiv.org/abs/2302.07400), i.e.,
a function-space version of diffusion probabilistic models, using JAX and Flax.

> [!IMPORTANT]
Expand All @@ -25,20 +25,44 @@ The `experiments` folder contains a use case on MNIST-SDF. For training on 32x32

```bash
cd experiments/mnist_sdf
python main.py --mode=train --model={unet/uno} --epochs=1000
python main.py \
--config=config.py \
--mode=train \
--model=<uno|unet> \
--dataset=mnist_sdf \
--workdir=<dir>
```

Pre-trained weights can be found in `experiments/mnist_sdf/checkpoints/`
(which is also where they should be left, since the checkpoint manager looks for checkpoints in this folder).

Then, in order to sample, call:
Then, sample images viaL

```bash
cd experiments/mnist_sdf
python main.py --mode=sample --model={unet/uno}
python main.py \
--config=config.py \
--mode=sample \
--model=<uno|unet> \
--dataset=mnist_sdf \
--workdir=<dir>
```

This samples 32x32-, 64x64- and 128x128-dimensional images and creates some figures in `experiments/mnist_sdf/figures`.

Below are DDIM-sampled images from the DDO when either a UNet or a UNO is used as score model (a DDO with a UNet is just a DDPM). The UNet parameterization yields high-quality results already after
20 epochs or so. The UNO works worse, but fine, than the UNet when 32x32-dimensional images are sampled. When sampling 64x64-dimensional images it mainly produces noise

<div align="center">
<div>UNet 32x32</div>
<img src="fig/mnist_sdf-unet-32x32.png" width="750">
</div>

<div align="center">
<div>UNO 32x32</div>
<img src="fig/mnist_sdf-uno-32x32.png" width="750">
</div>

<div align="center">
<div>UNO 64x64</div>
<img src="fig/mnist_sdf-uno-64x64.png" width="750">
</div>

## Installation

Expand Down
31 changes: 7 additions & 24 deletions ddo/ddo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class DenoisingDiffusionOperator(nn.Module):

def setup(self):
self.n_diffusions = len(self.alpha_schedule)
self._alphas = self.alpha_schedule
self._betas = jnp.asarray(1.0 - self._alphas)
self._alphas = jnp.asarray(self.alpha_schedule)
self._betas = 1.0 - self._alphas
self._alphas_bar = jnp.cumprod(self._alphas)
self._sqrt_alphas_bar = jnp.sqrt(self._alphas_bar)
self._sqrt_1m_alphas_bar = jnp.sqrt(1.0 - self._alphas_bar)
Expand All @@ -29,7 +29,7 @@ def loss(self, y0, is_training):
time_key, rng_key = jr.split(rng_key)
times = jr.randint(
key=time_key,
minval=1,
minval=0,
maxval=self.n_diffusions,
shape=(y0.shape[0],),
)
Expand All @@ -52,32 +52,15 @@ def q_pred_reparam(self, y0, t, noise):
scale = self._sqrt_1m_alphas_bar[t].reshape(shape) * noise
return mean + scale

def sample(self, sample_shape=(32, 32, 32, 1), **kwargs):
init_key, rng_key = jr.split(self.make_rng("sample"))

yt = jr.normal(init_key, sample_shape)
for t in reversed(range(self.n_diffusions)):
z = jr.normal(jr.fold_in(rng_key, t), sample_shape)
eps = self.score_model(
yt, jnp.full(yt.shape[0], t), is_training=False
)
yn = self._betas[t] / self._sqrt_1m_alphas_bar[t] * eps
yn = yt - yn
yn = yn / jnp.sqrt(self._alphas[t])
yt = yn + jnp.sqrt(self._betas[t]) * z

return yt

def sample_ddim(self, sample_shape=(32, 32, 32, 1), n=100, **kwargs):
init_key, rng_key = jr.split(self.make_rng("sample"))
def sample(self, sample_shape=(32, 32, 32, 1), n=100, **kwargs):
timesteps = np.arange(0, self.n_diffusions, self.n_diffusions // n)
yt = jr.normal(init_key, sample_shape)
yt = jr.normal(self.make_rng("sample"), sample_shape)
for t in reversed(np.arange(1, n)):
tprev, tcurr = timesteps[(t - 1) : (t + 1)]
yt = self.denoise_ddim(jr.fold_in(rng_key, tcurr), yt, tcurr, tprev)
yt = self._denoise(yt, tcurr, tprev)
return yt

def denoise_ddim(self, rng_key, yt, t, tprev):
def _denoise(self, yt, t, tprev):
eps = self.score_model(yt, jnp.full(yt.shape[0], t), is_training=False)
lhs = (yt - eps * self._sqrt_1m_alphas_bar[t]) / self._sqrt_alphas_bar[
t
Expand Down
2 changes: 1 addition & 1 deletion ddo/noise_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def cosine_alpha_schedule(timesteps, s=0.008):
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
alphas = alphas_cumprod[1:] / alphas_cumprod[:-1]
alphas = np.clip(alphas, a_min=0.001, a_max=0.9999)
alphas = np.clip(alphas, a_min=0.0001, a_max=0.9999)
return alphas


Expand Down
80 changes: 42 additions & 38 deletions ddo/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,23 @@ def _timestep_embedding(timesteps, embedding_dim: int, dtype=jnp.float32):
return emb


class UNetBlock(nn.Module):
class ResidualBlock(nn.Module):
"""UNet block for diffusion models.
Does two convolutions and adds the time embedding in between. Also does
group normalisation and dropout
"""

n_out_channels: int
n_groups: int = 8
dropout_rate: float = 0.1
dropout_rate: float

@nn.compact
def __call__(self, inputs, times, is_training):
time_embedding = nn.Dense(self.n_out_channels, use_bias=False)(
nn.swish(times)
)
time_embedding = nn.Dense(self.n_out_channels)(nn.swish(times))
hidden = inputs

# convolution with pre-layer norm
hidden = nn.GroupNorm(self.n_groups)(hidden)
hidden = nn.BatchNorm()(hidden, use_running_average=not is_training)
hidden = nn.swish(hidden)
hidden = nn.Conv(
self.n_out_channels,
Expand All @@ -54,8 +51,8 @@ def __call__(self, inputs, times, is_training):
# time conditioning
hidden = hidden + time_embedding[:, None, None, :]

# convolution with pre-layer norm
hidden = nn.GroupNorm(self.n_groups)(hidden)
# convolution with pre-layer norm and dropout
hidden = nn.BatchNorm()(hidden, use_running_average=not is_training)
hidden = nn.swish(hidden)
hidden = nn.Dropout(self.dropout_rate)(
hidden, deterministic=not is_training
Expand All @@ -68,13 +65,15 @@ def __call__(self, inputs, times, is_training):
)(hidden)

if inputs.shape[-1] != self.n_out_channels:
inputs = nn.Conv(
residual = nn.Conv(
self.n_out_channels,
kernel_size=(3, 3),
kernel_size=(1, 1),
strides=(1, 1),
padding="SAME",
)(inputs)
return hidden + inputs
else:
residual = inputs
return (hidden + residual) / 1.414213


class UNet(nn.Module):
Expand All @@ -85,17 +84,17 @@ class UNet(nn.Module):

n_blocks: int
n_channels: int
dim_embedding: int
channel_multipliers: Sequence[int]
n_groups: int = 8
dropout_rate: float

def time_embedding(self, times):
times = _timestep_embedding(times, self.dim_embedding)
times = _timestep_embedding(times, self.n_channels * 2)
times = nn.Sequential(
[
nn.Dense(self.dim_embedding * 2),
nn.Dense(self.n_channels * 8),
nn.swish,
nn.Dense(self.n_channels * 8),
nn.swish,
nn.Dense(self.dim_embedding * 2),
]
)(times)
return times
Expand All @@ -109,43 +108,48 @@ def __call__(self, inputs, times, is_training, **kwargs):
hidden = inputs
# lift data
hidden = nn.Conv(
self.n_channels, kernel_size=(3, 3), strides=(1, 1), padding="SAME"
self.n_channels, kernel_size=(1, 1), strides=(1, 1), padding="SAME"
)(hidden)

hs = []
# left block of UNet
for i, channel_mult in enumerate(self.channel_multipliers):
for channel_mult in self.channel_multipliers[:-1]:
n_outchannels = channel_mult * self.n_channels
for _ in range(self.n_blocks):
hidden = UNetBlock(n_outchannels)(hidden, times, is_training)
hs.append(hidden)
hidden = nn.max_pool(hidden, window_shape=(2, 2), strides=(2, 2))
hidden = ResidualBlock(n_outchannels, self.dropout_rate)(
hidden, times, is_training
)
hs.append(hidden)
hidden = nn.avg_pool(hidden, window_shape=(2, 2), strides=(2, 2))

# middle block of UNet
for _ in range(self.n_blocks):
hidden = UNetBlock(n_out_channels=hidden.shape[-1])(
n_outchannels = self.channel_multipliers[-1] * self.n_channels
hidden = ResidualBlock(n_outchannels, self.dropout_rate)(
hidden, times, is_training
)
hs.append(hidden)

hidden = hs.pop()
# right block of UNet
for i, channel_mult in enumerate(reversed(self.channel_multipliers)):
for channel_mult in reversed(self.channel_multipliers[:-1]):
n_outchannels = channel_mult * self.n_channels
hidden = nn.ConvTranspose(
n_outchannels, kernel_size=(2, 2), strides=(2, 2)
)(hidden)
for bl in range(self.n_blocks):
hidden = (
jnp.concatenate([hidden, hs.pop()], axis=-1)
if bl == 0
else hidden
)
hidden = UNetBlock(n_out_channels=n_outchannels)(
hidden = jax.image.resize(
hidden,
(B, hidden.shape[1] * 2, hidden.shape[2] * 2, hidden.shape[3]),
method="bilinear",
)
for _ in range(self.n_blocks):
hidden = jnp.concatenate([hidden, hs.pop()], axis=-1)
hidden = ResidualBlock(n_outchannels, self.dropout_rate)(
hidden, times, is_training
)

hidden = nn.GroupNorm(self.n_groups)(hidden)
hidden = nn.BatchNorm()(hidden, use_running_average=not is_training)
hidden = nn.swish(hidden)
outputs = nn.Conv(C, kernel_size=(1, 1))(hidden)
outputs = nn.Conv(
C,
kernel_size=(1, 1),
strides=(1, 1),
padding="SAME",
kernel_init=nn.initializers.zeros,
)(hidden)
return outputs
Loading

0 comments on commit b7c9b39

Please sign in to comment.