#Image generation with diffusion models
Implement and train unconditional diffusion models, such as DDPM (Denoising Diffusion Probabilistic Model) or DDIM (Denoising Diffusion Implicit Model), to generate realistic images. Evaluate the models on two different datasets, such as CelebA and Flowers102.

Related resources:
https://huggingface.co/blog/annotated-diffusion
https://github.com/huggingface/diffusers
https://keras.io/examples/generative/ddim/

Related papers:
https://arxiv.org/abs/2006.11239
https://arxiv.org/abs/2010.02502

In [1]:
!pip install -q datasets

In [2]:
import torch
import torchvision

#Downloading datasets

In [3]:
from huggingface_hub import login
login()

In [4]:
#CelebA Dataset (Using Hugging Face datasets Library):
from datasets import load_dataset

# Download the CelebA dataset
celeb_a = load_dataset("nielsr/CelebA-faces")

# Save the dataset to disk (Optional, for local use)
celeb_a.save_to_disk('/content/celeba')


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [5]:
#Flowers102 Dataset (Using torchvision):
from torchvision import datasets, transforms

# Download the Flowers102 dataset
flowers102 = datasets.Flowers102(root='/content/flowers102', download=True)

#Data Preprocessing
TODO: Hyperparameter settings

In [6]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Define transformations (e.g., resizing, normalizing)
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load the CelebA dataset
dataset = datasets.CelebA(root='/content', split='all', transform=transform, download=True)

# Split into training, validation, and test sets (80-10-10 split)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Create DataLoaders for each set
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


Files already downloaded and verified
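
The `Normalize(mean=0.5, std=0.5)` step in the transform above rescales pixel values from [0, 1] to [-1, 1], the range diffusion models are usually trained on. Generated samples then have to be mapped back before they can be displayed. The following is a minimal plain-Python sketch of that arithmetic on single pixel values; the helper names `normalize`, `denormalize` and `to_uint8` are hypothetical and are not part of torchvision.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize for one channel value in [0, 1]."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping: take a model-space value in [-1, 1] back to [0, 1]."""
    return value * std + mean


def to_uint8(value):
    """Convert a model-space value to an 8-bit intensity, clamping out-of-range samples."""
    clamped = min(max(denormalize(value), 0.0), 1.0)
    return int(round(clamped * 255))


print(normalize(0.0), normalize(1.0))  # the two ends of the input range
print(to_uint8(-1.0), to_uint8(1.0))   # the two ends of the 8-bit range
```

The same expressions apply elementwise to image tensors, for example `(x * 0.5 + 0.5).clamp(0, 1)` on a batch of generated samples.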


In [7]:
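# Reuse the root from the download cell so the files are not fetched twice.
# Note: split='train' holds only 1,020 images; the dataset also ships official 'val' and 'test' splits.
flowers_dataset = datasets.Flowers102(root='/content/flowers102', split='train', download=True, transform=transform)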

# Splitting into train, val, test
train_size = int(0.8 * len(flowers_dataset))
val_size = int(0.1 * len(flowers_dataset))
test_size = len(flowers_dataset) - train_size - val_size
train_flowers, val_flowers, test_flowers = random_split(flowers_dataset, [train_size, val_size, test_size])

# DataLoader for Flowers102
train_loader_flowers = DataLoader(train_flowers, batch_size=32, shuffle=True)
val_loader_flowers = DataLoader(val_flowers, batch_size=32, shuffle=False)
test_loader_flowers = DataLoader(test_flowers, batch_size=32, shuffle=False)
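
To make the 80-10-10 arithmetic used in both split cells concrete, here is a small sketch that reproduces the size computation without loading any data. The function name `split_sizes` is hypothetical. The item counts are the 202,599 CelebA examples reported in the save log above and the 1,020 images of the Flowers102 `train` split.

```python
def split_sizes(n_items, train_frac=0.8, val_frac=0.1):
    """Mirror the cells above: truncate train and val, give the remainder to test."""
    train = int(train_frac * n_items)
    val = int(val_frac * n_items)
    test = n_items - train - val
    return train, val, test


print(split_sizes(202599))  # CelebA: (162079, 20259, 20261)
print(split_sizes(1020))    # Flowers102 'train' split: (816, 102, 102)
```

With only 816 training images, the Flowers102 subset is very small for training a diffusion model from scratch; combining the official `train`, `val` and `test` splits before re-splitting is one option worth considering.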


#Define the Environment

In [8]:
!pip freeze > requirements.txt

#Training the Diffusion Models
TODO
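
As a starting point for this TODO, the sketch below walks through the core arithmetic of the two referenced papers in plain Python, so it runs without a GPU or any deep learning library. It assumes the linear noise schedule from the DDPM paper (1000 steps, beta from 1e-4 to 0.02) and treats an "image" as a short list of floats. All function names are hypothetical. In a real training loop a U-Net would predict the noise and the same formulas would operate on tensors.

```python
import math
import random


def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T (values from the DDPM paper)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def q_sample(x0, t, alpha_bars, rng):
    """Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a = math.sqrt(alpha_bars[t])
    b = math.sqrt(1.0 - alpha_bars[t])
    return [a * x + b * e for x, e in zip(x0, eps)], eps


def noise_prediction_loss(eps_true, eps_pred):
    """Simplified DDPM objective: mean squared error between true and predicted noise."""
    return sum((a - b) ** 2 for a, b in zip(eps_true, eps_pred)) / len(eps_true)


def ddim_step(x_t, eps_pred, t, t_prev, alpha_bars):
    """Deterministic DDIM update (eta = 0) from step t to an earlier step t_prev.

    t_prev = -1 stands for the fully denoised sample (alpha_bar = 1).
    """
    a_t = alpha_bars[t]
    a_prev = alpha_bars[t_prev] if t_prev >= 0 else 1.0
    x0_pred = [(x - math.sqrt(1.0 - a_t) * e) / math.sqrt(a_t)
               for x, e in zip(x_t, eps_pred)]
    return [math.sqrt(a_prev) * x0 + math.sqrt(1.0 - a_prev) * e
            for x0, e in zip(x0_pred, eps_pred)]


rng = random.Random(0)
alpha_bars = cumulative_alphas(linear_beta_schedule())
x0 = [0.5, -0.25, 1.0]                       # toy "image" already scaled to [-1, 1]
x_t, eps = q_sample(x0, 500, alpha_bars, rng)

# With a perfect noise prediction the loss is zero and one DDIM step recovers x0.
print(noise_prediction_loss(eps, eps))
print(ddim_step(x_t, eps, 500, -1, alpha_bars))
```

Training replaces the perfect prediction with a network output and minimizes `noise_prediction_loss` over random images and timesteps. DDIM sampling then chains `ddim_step` over a short, decreasing subsequence of timesteps instead of all 1000.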

In [9]:
!pip install diffusers



In [17]:
!pip show jax jaxlib diffusers

Name: jax
Version: 0.4.34
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, dopamine_rl, flax, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.34
Summary: XLA library for JAX
Home-page: https://github.com/jax-ml/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, dopamine_rl, jax, optax, orbax-checkpoint
---
Name: diffusers
Version: 0.7.2
Summary: Diffusers
Home-page: https://github.com/huggingface/diffusers
Author: The HuggingFace team
Author-email: patrick@huggingface.co
License: Apache
Location: /usr/local/lib/python3.10/dist-packages
Requires: filelock, huggingface-hub, importlib-metadata, numpy, Pillow, regex, req

In [10]:
!pip install --upgrade jax jaxlib



In [18]:
!pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax-cuda12-plugin<=0.4.34,>=0.4.34 (from jax-cuda12-plugin[with_cuda]<=0.4.34,>=0.4.34; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_plugin-0.4.34-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-pjrt==0.4.34 (from jax-cuda12-plugin<=0.4.34,>=0.4.34->jax-cuda12-plugin[with_cuda]<=0.4.34,>=0.4.34; extra == "cuda"->jax[cuda])
  Downloading jax_cuda12_pjrt-0.4.34-py3-none-manylinux2014_x86_64.whl.metadata (349 bytes)
Downloading jax_cuda12_plugin-0.4.34-cp310-cp310-manylinux2014_x86_64.whl (14.9 MB)
Downloading jax_cuda12_pjrt-0.4.34-py3-none-manylinux2014_x86_64.whl (100.3 MB)
Installing collected packages: jax-

In [12]:
!pip install diffusers torch accelerate



In [14]:
!pip install transformers



In [15]:
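from diffusers import DDPMPipeline

# Load a pretrained DDPM pipeline (a UNet plus a noise scheduler)
pipeline = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")

# Generate one image; the pipeline returns an output object whose .images is a list of PIL images
generated_image = pipeline(batch_size=1).images[0]

# A pipeline is an inference wrapper and has no .train() method. Training requires a custom
# loop over the underlying network, which is switched to training mode with:
pipeline.unet.train()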


AttributeError: module 'jax.random' has no attribute 'KeyArray'
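
The error above is raised while importing `diffusers`, before any model code runs. A likely cause, offered here as an assumption rather than a verified diagnosis, is a version mismatch: the `pip show` output reports diffusers 0.7.2 next to JAX 0.4.34, and older diffusers releases reference `jax.random.KeyArray`, which recent JAX releases no longer expose. Upgrading diffusers (`pip install -U diffusers`) and restarting the runtime is the usual first thing to try. The sketch below only reports which versions are installed, so the mismatch can be confirmed before and after the upgrade; the helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package name: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


print(installed_versions(["diffusers", "jax", "jaxlib"]))
```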

#Download Files Locally

In [None]:
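import shutil
from google.colab import files
shutil.make_archive('/content/celeba_preprocessed', 'zip', '/content/celeba')  # the zip must exist first; built here from the dataset saved in cell 4
files.download('/content/celeba_preprocessed.zip')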
