# In [None]:
import os

# Select the Weights & Biases project name for this run via the environment.
# NOTE(review): presumably read by wandb when a tracker is initialized later
# (not visible here) — confirm.
os.environ.update(WANDB_PROJECT="unconditional_image_generation")

import argparse
import inspect
import logging
import math
import shutil
from datetime import timedelta
from pathlib import Path

import accelerate
import datasets
import torch
import torch.nn.functional as F
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm

import diffusers
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import (
    check_min_version,
    is_accelerate_version,
    is_tensorboard_available,
    is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available

from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    resolution: int = 64
    ddpm_num_steps: int = 1000
    ddpm_beta_schedule: str = "linear"
    learning_rate: float = 1e-4
    adam_beta1: float = 0.95
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-6
    adam_epsilon: float = 1e-08
    dataset_name: str = "huggan/flowers-102-categories"
    dataset_config_name: str = None
    cache_dir: str = None
    ema_max_decay: float = 0.9999
    ema_inv_gamma: float = 1.0
    ema_power: float = 3 / 4


args = Args()

# Denoising U-Net with six resolution levels. AttnDownBlock2D sits at the
# fifth down level, mirrored by AttnUpBlock2D at the second up level; all
# other levels use plain Down/UpBlock2D.
model = UNet2DModel(
    sample_size=args.resolution,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 512, 512),
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

# Exponential moving average of the U-Net weights. Warmup is enabled and
# shaped by inv_gamma/power; model_cls/model_config describe the tracked model.
ema_model = EMAModel(
    model.parameters(),
    model_cls=UNet2DModel,
    model_config=model.config,
    decay=args.ema_max_decay,
    use_ema_warmup=True,
    inv_gamma=args.ema_inv_gamma,
    power=args.ema_power,
)

# Forward-noising schedule used during training.
noise_scheduler = DDPMScheduler(
    beta_schedule=args.ddpm_beta_schedule,
    num_train_timesteps=args.ddpm_num_steps,
)

# AdamW over all U-Net parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    eps=args.adam_epsilon,
    weight_decay=args.adam_weight_decay,
)

# Training split of the configured Hugging Face dataset; the config name is
# passed positionally as load_dataset's second argument.
dataset = load_dataset(
    args.dataset_name,
    args.dataset_config_name,
    split="train",
    cache_dir=args.cache_dir,
)

# [cell output] from .autonotebook import tqdm as notebook_tqdm
