Docs: Fix errors and omissions in Getting Started
cabralpinto committed Aug 29, 2023
1 parent 4b84e02 commit d28060b
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions docs/src/pages/guides/getting-started.mdx
@@ -31,7 +31,7 @@ import torch
 from torchvision.datasets import MNIST
 from torchvision.transforms import ToTensor

-x, _ = zip(*MNIST("data", download=True, transform=ToTensor()))
+x, _ = zip(*MNIST(str(input), transform=ToTensor(), download=True))
 x = torch.stack(x) * 2 - 1
 ```

@@ -47,11 +47,11 @@ from diffusion.schedule import Linear

 model = diffusion.Model(
     data=Identity(x, batch=128, shuffle=True),
-    schedule=Linear(steps=1000, start=0.9999, end=0.98),
+    schedule=Linear(1000, 0.9999, 0.98),
     noise=Gaussian(parameter="epsilon", variance="fixed"),
     net=UNet(channels=(1, 64, 128, 256)),
     loss=Simple(parameter="epsilon"),
-    device="cuda",
+    device="cuda" if torch.cuda.is_available() else "cpu",
 )
 ```

@@ -62,13 +62,17 @@ losses = [*model.train(epochs=20)]
 z = model.sample(batch=10)
 ```

+> Tip
+>
+> If you are getting a `Process killed` message when training your model, try reducing the batch size in the data module. This error is caused by running out of RAM.
+
 The `sample` function returns a tensor with the same shape as the dataset tensor, but with an extra diffusion time dimension. In this case, the dataset has shape `[b, c, h, w]`, so our output `z` has shape `[t, b, c, h, w]`. Now we just need to rearrange the dimensions of the output tensor to produce one final image.

 ```python
 from einops import rearrange
 from torchvision.utils import save_image

-z = z[torch.linspace(0, z.shape[0] - 1, 10).long()]
+z = z[torch.linspace(0, z.shape[0] - 1, 10).int()]
 z = rearrange(z, "t b c h w -> c (b h) (t w)")
 save_image((z + 1) / 2, "output.png")
 ```
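
To make the indexing and the `rearrange` pattern above concrete, here is a minimal pure-Python sketch (not part of the commit) of the two bookkeeping steps: picking 10 evenly spaced time indices, and working out the shape of the tiled image. The helper names `evenly_spaced_indices` and `tiled_shape` are hypothetical, and the count of 1000 stored time steps is assumed purely for illustration.

```python
def evenly_spaced_indices(length: int, count: int) -> list[int]:
    """Pick `count` indices spread evenly over range(length), endpoints included.

    Mirrors the idea behind torch.linspace(0, length - 1, count).int():
    compute evenly spaced floats, then truncate them to integers.
    """
    if count == 1:
        return [0]
    step = (length - 1) / (count - 1)
    return [int(i * step) for i in range(count)]


def tiled_shape(t: int, b: int, c: int, h: int, w: int) -> tuple[int, int, int]:
    """Shape produced by the pattern "t b c h w -> c (b h) (t w)".

    Samples are stacked vertically (b * h rows) and time steps are laid out
    horizontally (t * w columns), so each row shows one sample being denoised.
    """
    return (c, b * h, t * w)


# Assumed sizes: 1000 stored time steps, 10 MNIST-sized samples.
indices = evenly_spaced_indices(1000, 10)
print(indices)
print(tiled_shape(len(indices), 10, 1, 28, 28))
```

Under these assumed sizes the saved image is a single-channel 280 by 280 grid: one row per sample, one column per selected time step.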
@@ -77,11 +81,25 @@ And that's it! The image we just saved should look something like this:

 ![Random numbers being generated from noise.](/modular-diffusion/images/guides/getting-started/unconditional-linear.png)

+### Add a validation loop
+
+You might have noticed that the `train` method returns a generator object. This is to allow you to validate the model between epochs inside a `for` loop. For instance, you can see how your model is coming along by sampling from it between each training epoch, rather than only at the end.
+
+```python
+for epoch, loss in enumerate(model.train(epochs=20)):
+    z = model.sample(batch=10)
+    z = z[torch.linspace(0, z.shape[0] - 1, 10).int()]
+    z = rearrange(z, "t b c h w -> c (b h) (t w)")
+    save_image((z + 1) / 2, f"{epoch}.png")
+```
+
 > Tip
 >
-> If you're only interested in the final results, sample the model with the following syntax: `*_, z = model.sample(batch=10)`. In this example, this will yield a tensor with shape `[b, c, h, w]` containing only the final images.
+> If you're only interested in seeing the final results, sample the model with the following syntax: `*_, z = model.sample(batch=10)`. In this example, this will yield a tensor with shape `[b, c, h, w]` containing only the generated images.

-The beauty in Modular Diffusion is how easy it is to make changes to an existing model. To showcase this, let's plug in the `Cosine` schedule introduced in [Nichol & Dhariwal (2021)](https://arxiv.org/abs/2102.09672). All it does is destroy information at a slower rate in the diffusion process, which was shown to improve sample quality.
+### Swap modules
+
+The beauty in Modular Diffusion is how easy it is to make changes to an existing model. To showcase this, let's plug in the `Cosine` schedule introduced in [Nichol & Dhariwal (2021)](https://arxiv.org/abs/2102.09672). All it does is destroy information at a slower rate in the forward diffusion process, which was shown to improve sample quality.

 ```python
 from diffusion.schedule import Cosine
@@ -92,7 +110,7 @@ model = diffusion.Model(
     noise=Gaussian(parameter="epsilon", variance="fixed"),
     net=UNet(channels=(1, 64, 128, 256)),
     loss=Simple(parameter="epsilon"),
-    device="cuda",
+    device="cuda" if torch.cuda.is_available() else "cpu",
 )
 ```
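
To see what destroying information "at a slower rate" means in numbers, the sketch below compares how much signal survives under each schedule. It is not the library's implementation. It assumes the `Linear` arguments are per-step alpha values (so 0.9999 and 0.98 would correspond to betas of 1e-4 and 0.02), and it uses the closed-form cosine schedule from Nichol & Dhariwal (2021) with offset `s = 0.008`. The function names are hypothetical.

```python
import math


def linear_alpha_bar(step: int, steps: int = 1000,
                     start: float = 0.9999, end: float = 0.98) -> float:
    """Cumulative product of per-step alphas interpolated linearly from start to end."""
    product = 1.0
    for i in range(step):
        alpha = start + (end - start) * i / (steps - 1)
        product *= alpha
    return product


def cosine_alpha_bar(step: int, steps: int = 1000, s: float = 0.008) -> float:
    """Closed-form cumulative signal level of the cosine schedule."""
    def f(t: int) -> float:
        return math.cos((t / steps + s) / (1 + s) * math.pi / 2) ** 2
    return f(step) / f(0)


# Print the remaining signal level at a few points of the forward process.
for step in (250, 500, 750):
    print(step, round(linear_alpha_bar(step), 4), round(cosine_alpha_bar(step), 4))
```

Under these assumptions, the cosine schedule retains more signal than the linear one at each of the three printed steps.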

@@ -118,16 +136,16 @@ from diffusion.guidance import ClassifierFree

 model = diffusion.Model(
     data=Identity(x, y, batch=128, shuffle=True), # added y in here!
-    schedule=Cosine(1000),
-    noise=Normal(parameter="epsilon", variance="fixed"),
-    net=UNet(x.shape[2], labels=10), # added labels here!
+    schedule=Cosine(steps=1000),
+    noise=Gaussian(parameter="epsilon", variance="fixed"),
+    net=UNet(channels=(1, 64, 128, 256), labels=10), # added labels here!
+    guidance=ClassifierFree(dropout=0.1, strength=2), # added classifier guidance!
     loss=Simple(parameter="epsilon"),
-    guidance=ClassifierFree(weight=0.8, dropout=0.1), # added classifier guidance!
-    device="cuda",
+    device="cuda" if torch.cuda.is_available() else "cpu",
 )
 ```

-One final change we will be making compared to our previous example is to provide the labels of the images we wish to generate to the `sample` function. As an example, let's request one image of each digit by replacing `model.sample(batch=10)` with `model.sample(y=torch.arange(10))`. We then end up with the following image:
+One final change we will be making compared to our previous example is to provide the labels of the images we wish to generate to the `sample` function. As an example, let's request one image of each digit by replacing `model.sample(batch=10)` with `model.sample(y=torch.arange(1, 11))`. We then end up with the following image:

 ![Numbers 0 through 9 being generated from noise.](/modular-diffusion/images/guides/getting-started/conditional.png)
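
For readers unfamiliar with classifier-free guidance, the following pure-Python sketch illustrates the general technique from Ho & Salimans (2022), not this library's code. It makes two labeled assumptions: that label `0` is reserved as the "no label" value (the shift to `torch.arange(1, 11)` hints at this, but the commit does not confirm it), and that guidance blends predictions as `(1 + w) * conditional - w * unconditional`. How the library's `strength` argument maps onto `w` is not stated here, and the helper names `drop_labels` and `guide` are hypothetical.

```python
import random

NULL_LABEL = 0  # assumption: 0 stands for "no label"


def drop_labels(labels, dropout, rng):
    """Training side: replace each label with NULL_LABEL with probability `dropout`."""
    return [NULL_LABEL if rng.random() < dropout else y for y in labels]


def guide(conditional, unconditional, weight):
    """Sampling side: blend predictions as (1 + w) * conditional - w * unconditional."""
    return [(1 + weight) * c - weight * u
            for c, u in zip(conditional, unconditional)]


rng = random.Random(0)
labels = list(range(1, 11)) * 100
dropped = drop_labels(labels, dropout=0.1, rng=rng)

# Fraction of labels that were replaced by the null label.
print(sum(y == NULL_LABEL for y in dropped) / len(dropped))
# Guided prediction for two toy values.
print(guide([1.0, 2.0], [0.5, 2.5], weight=2.0))
```

With `weight=0` the blend reduces to the conditional prediction, and larger weights push the result further away from the unconditional one.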
