In [5]:
!pip install -qq -U diffusers datasets transformers accelerate ftfy
!pip install -qq wandb
!pip install -qq torchinfo
!pip install -qq matplotlib

In [None]:
from pathlib import Path

import torch
import wandb
from datasets import load_dataset
from diffusers import DDPMScheduler, UNet2DModel, DDPMPipeline
from huggingface_hub import notebook_login
from PIL import Image


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Authenticate with Weights & Biases before any run is created.
wandb.login()

In [7]:
def make_grid(images, size=64):
    """Lay PIL images out left-to-right on a single RGB canvas.

    Args:
        images: Sequence of PIL images (must support ``len``).
        size: Edge length, in pixels, that every image is resized to.

    Returns:
        A new RGB image ``size * len(images)`` wide and ``size`` tall, with
        the resized inputs pasted side by side in their original order.
    """
    tile = (size, size)
    strip = Image.new("RGB", (len(images) * size, size))
    x_offset = 0
    for picture in images:
        # Each thumbnail starts where the previous one ended.
        strip.paste(picture.resize(tile), (x_offset, 0))
        x_offset += size
    return strip

Create (or copy) a Hugging Face access token at https://huggingface.co/settings/tokens, then paste it into the login prompt below.

In [None]:
# Authenticate with the Hugging Face Hub (interactive token prompt).
notebook_login()

In [None]:
# Load the 'huggan/pokemon' dataset; the same dataset name is passed to the training script below.
dataset = load_dataset('huggan/pokemon')

In [10]:
# Show the dataset's repr to check which splits and features were loaded.
dataset

DatasetDict({
    train: Dataset({
        features: ['image'],
        num_rows: 7357
    })
})

In [11]:
# Names shared across the notebook:
#   PROJECT_NAME - W&B project, also passed to the training script as --project_name.
#   MODEL_NAME   - local checkpoint folder, passed to the training script as --output_dir.
MODEL_NAME = 'sd-pokemon-64'
PROJECT_NAME = 'sd-pokemon'

In [13]:
# Launch the training script through `accelerate`. IPython substitutes {PROJECT_NAME}
# and {MODEL_NAME} from the notebook namespace before the shell sees the command.
# NOTE(review): fp16 is requested twice (launcher flag and script argument) - confirm
# that train_unconditional_v3.py needs both.
# NOTE(review): if --dataset_sample_size=100 caps the training set at 100 images, then
# 30 epochs at batch size 64 give roughly 60 optimizer steps, fewer than
# --lr_warmup_steps=500, so the learning rate would never leave warmup - confirm
# against train_unconditional_v3.py.
!accelerate launch --mixed_precision=fp16 train_unconditional_v3.py \
    --project_name={PROJECT_NAME} \
    --ddpm_beta_schedule='squaredcos_cap_v2' \
    --dataset_name='huggan/pokemon' \
    --dataset_sample_size=100 \
    --seed=2077 \
    --resolution=64 \
    --output_dir={MODEL_NAME} \
    --train_batch_size=64 \
    --eval_batch_size=4 \
    --num_epochs=30 \
    --gradient_accumulation_steps=1 \
    --learning_rate=4e-4 \
    --lr_scheduler="cosine" \
    --lr_warmup_steps=500 \
    --dataloader_num_workers=12 \
    --mixed_precision="fp16" \
    --save_images_epochs=10 \
    --save_model_epochs=30 \
    --push_to_wandb

The following values were not passed to `accelerate launch` and had defaults used instead:
	`--num_processes` was set to a value of `1`
	`--num_machines` was set to a value of `1`
	`--dynamo_backend` was set to a value of `'no'`
Using custom data configuration huggan--pokemon-8faf296a1351d650
Found cached dataset parquet (/root/.cache/huggingface/datasets/huggan___parquet/huggan--pokemon-8faf296a1351d650/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/huggan___parquet/huggan--pokemon-8faf296a1351d650/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-1713741c4c5cdf2b.arrow
[34m[1mwandb[0m: Currently logged in as: [33mmatt24[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.13.6
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/workspace/diffusion-models-class/unit1/wandb/run-20221211

In [None]:
# Upload the trained checkpoint directory to W&B as a model artifact.
#
# References:
#   https://docs.wandb.ai/guides/track/advanced/resuming
#   https://docs.wandb.ai/ref/python/artifact

RUN_ID = '1tbe0g9n'
MODEL_PATH = './' + MODEL_NAME

# Initialise W&B with the training run's id and resume enabled, then log the
# checkpoint folder as an artifact named after that run.
with wandb.init(id=RUN_ID, project=PROJECT_NAME, resume=True) as run:
    artifact = wandb.Artifact(name=RUN_ID, type='model')
    artifact.add_dir(local_path=MODEL_PATH)
    run.log_artifact(artifact)

In [None]:
# Fetch the checkpoint artifact back from W&B into a local directory.

RUN_ID = '1tbe0g9n'
with wandb.init(id=RUN_ID, project=PROJECT_NAME, resume=True) as run:
    # Version v0 of the artifact named after the run, under the 'matt24' entity.
    artifact = run.use_artifact(f'matt24/{PROJECT_NAME}/{RUN_ID}:v0', type='model')
    artifact_dir = artifact.download()

print(artifact_dir)

In [None]:
# Rebuild the generative pipeline from the downloaded checkpoint.

ckpt = Path(artifact_dir)

# The UNet and the scheduler are read from separate sub-folders of the checkpoint.
model = UNet2DModel.from_pretrained(ckpt / 'unet')
noise_scheduler = DDPMScheduler.from_pretrained(ckpt / 'scheduler')

pipeline = DDPMPipeline(scheduler=noise_scheduler, unet=model)
pipeline = pipeline.to(device)

In [None]:
# Sample 64 images with a fixed seed so the result is reproducible, then preview
# them as a single strip of 32x32 thumbnails.
# `Image` (required by make_grid) is imported in the notebook's import cell, not here,
# so make_grid no longer depends on this cell having been run first.
generator = torch.Generator(device=pipeline.device).manual_seed(0)
images = pipeline(batch_size=64, generator=generator).images
make_grid(images, size=32)

In [None]:
# Log the generated samples to W&B under the run identified by RUN_ID.
with wandb.init(id=RUN_ID, project=PROJECT_NAME, resume=True) as run:
    run.log({"Generated images": list(map(wandb.Image, images))})

#### Another example: sampling from a pretrained checkpoint

In [None]:
# Load the 'anton-l/ddpm-ema-pokemon-64' checkpoint and move it to the chosen device.
# This rebinds `pipeline`, replacing the one built from the W&B artifact.
model_ckpt = 'anton-l/ddpm-ema-pokemon-64'
pipeline = DDPMPipeline.from_pretrained(model_ckpt)
pipeline = pipeline.to(device)

In [None]:
# Seed the sampler, as the earlier sampling cell does, so these 32 images are
# reproducible across kernel restarts instead of changing on every run.
generator = torch.Generator(device=pipeline.device).manual_seed(0)
images = pipeline(batch_size=32, generator=generator).images

In [None]:
# Render every generated image inline, one output per image, in generation order.
for image in images:
    display(image)