Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… into main
  • Loading branch information
cabralpinto committed Oct 15, 2023
2 parents ec58809 + 9215a1e commit 3aea1a5
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![PyPI version](https://badge.fury.io/py/modular-diffusion.svg)](https://badge.fury.io/py/modular-diffusion)
[![Documentation](https://img.shields.io/badge/docs-stable-blue.svg)](https://cabralpinto.github.io/modular-diffusion/)
[![MIT license](https://img.shields.io/badge/license-MIT-blue.svg)](https://lbesson.mit-license.org/)
[![Discord](https://dcbadge.vercel.app/api/server/zUpYzQBm?style=flat&compact=true)](https://discord.gg/zUpYzQBm)
[![Discord](https://dcbadge.vercel.app/api/server/mYJWQATfTV?style=flat&compact=true)](https://discord.gg/mYJWQATfTV)

Modular Diffusion provides an easy-to-use modular API to design and train custom Diffusion Models with PyTorch. Whether you're an enthusiast exploring Diffusion Models or a hardcore ML researcher, **this framework is for you**.

Expand Down Expand Up @@ -67,7 +67,7 @@ Check out the [Getting Started Guide](https://cabralpinto.github.io/modular-diff

## Contributing

We appreciate your support and welcome your contributions! Please fell free to submit pull requests if you found a bug or typo you want to fix. If you want to contribute a new prebuilt module or feature, please start by opening an issue and discussing it with us. If you don't know where to begin, take a look at the [open issues](https://github.com/cabralpinto/modular-diffusion/issues). Please read our [Contributing Guide](https://github.com/cabralpinto/modular-diffusion/blob/main/CONTRIBUTING.md) for more details.
We appreciate your support and welcome your contributions! Please feel free to submit pull requests if you found a bug or typo you want to fix. If you want to contribute a new prebuilt module or feature, please start by opening an issue and discussing it with us. If you don't know where to begin, take a look at the [open issues](https://github.com/cabralpinto/modular-diffusion/issues). Please read our [Contributing Guide](https://github.com/cabralpinto/modular-diffusion/blob/main/CONTRIBUTING.md) for more details.

## License

Expand Down
3 changes: 2 additions & 1 deletion docs/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/src/pages/guides/custom-modules.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ The `schedule` method precomputes `alpha` and `delta` (cumulative product of `al

## Denoiser neural network

Modular Diffusion comes with general-use `UNet` and `Transformer` classes, which have proven to be effective denoising networks in the context of Diffusion Models. However, it is not uncommon to see authors make modifications to these networks to achieve even better results. To design your own original network, extend the base abstract `Net` class. This class acts as only a thin wrapper over the standard PyTorch `nn.Module` class, meaning you can use it exactly the same way. The `forward` method should take three tensor arguments: the noisy input `x`, the conditioning matrix `y`, and the diffusion time steps `t`.
Modular Diffusion comes with general-use `UNet` and `Transformer` classes, which have proven to be effective denoising networks in the context of Diffusion Models. However, it is not uncommon to see authors make modifications to these networks to achieve even better results. To design your own original network, extend the base abstract `Net` class. This class acts as only a thin wrapper over the standard Pytorch `nn.Module` class, meaning you can use it exactly the same way. The `forward` method should take three tensor arguments: the noisy input `x`, the conditioning matrix `y`, and the diffusion time steps `t`.

> Network output shape
>
Expand Down
4 changes: 2 additions & 2 deletions docs/src/pages/guides/getting-started.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ Before you start, please install Modular Diffusion in your local Python environm
python -m pip install modular-diffusion
```

Additionally, ensure you've installed the correct [PyTorch distribution](https://pytorch.org/get-started/locally/) for your system.
Additionally, ensure you've installed the correct [Pytorch distribution](https://pytorch.org/get-started/locally/) for your system.

## Train a simple model

The first step before training a Diffusion Model is to load your dataset. In this example, we will be using [MNIST](http://yann.lecun.com/exdb/mnist/), which includes 70,000 grayscale images of handwritten digits, and is a great simple dataset to prototype your image models. We are going to load MNIST with [PyTorch Vision](https://pytorch.org/vision/stable/index.html), but you can load your dataset any way you like, as long as it results in a `torch.Tensor` object. We are also going to discard the labels and scale the data to the commonly used $[-1, 1]$ range.
The first step before training a Diffusion Model is to load your dataset. In this example, we will be using [MNIST](http://yann.lecun.com/exdb/mnist/), which includes 70,000 grayscale images of handwritten digits, and is a great simple dataset to prototype your image models. We are going to load MNIST with [Pytorch Vision](https://pytorch.org/vision/stable/index.html), but you can load your dataset any way you like, as long as it results in a `torch.Tensor` object. We are also going to discard the labels and scale the data to the commonly used $[-1, 1]$ range.

```python
import torch
Expand Down
2 changes: 1 addition & 1 deletion docs/src/pages/modules/denoising-network.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ visualizations: maybe

# {frontmatter.title}

The backbone of Diffusion Models is a denoising network, which is trained to gradually denoise data. While earlier works used a **U-Net** architecture, newer research has shown that **Transformers** can be used to achieve comparable or superior results. Modular Diffusion ships with both types of denoising network. Both are implemented in PyTorch and thinly wrapped in a `Net` module.
The backbone of Diffusion Models is a denoising network, which is trained to gradually denoise data. While earlier works used a **U-Net** architecture, newer research has shown that **Transformers** can be used to achieve comparable or superior results. Modular Diffusion ships with both types of denoising network. Both are implemented in Pytorch and thinly wrapped in a `Net` module.

> Future warning
>
Expand Down
4 changes: 4 additions & 0 deletions docs/src/pages/modules/diffusion-model.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ In Modular Diffusion, the `Model` class is a high-level interface that allows yo
- `net` -> Denoising network module.
- `loss` -> Loss function module.
- `guidance` (Default: `None`) -> Optional guidance module.
- `optimizer` (Default: `partial(Adam, lr=1e-4)`) -> Pytorch optimizer constructor function.
- `device` (Default: `"cpu"`) -> Device to train the model on.
- `compile` (Default: `true`) -> Whether to compile the model with `torch.compile` for faster training.

Expand All @@ -28,6 +29,8 @@ from diffusion.loss import Simple
from diffusion.net import UNet
from diffusion.noise import Gaussian
from diffusion.schedule import Cosine
from torch.optim import AdamW
from functools import partial

model = diffusion.Model(
data=Identity(x, y, batch=128, shuffle=True),
Expand All @@ -36,6 +39,7 @@ model = diffusion.Model(
net=UNet(channels=(1, 64, 128, 256), labels=10),
loss=Simple(parameter="epsilon"),
guidance=ClassifierFree(dropout=0.1, strength=2),
optimizer=partial(AdamW, lr=3e-4),
device="cuda" if torch.cuda.is_available() else "cpu",
)
```
Expand Down

0 comments on commit 3aea1a5

Please sign in to comment.