In [None]:
pip install diffusers

In [None]:
pip install torchvision


In [None]:
!pip install datasets

In [None]:
!pip install accelerate


In [None]:
!unzip dataset.zip

In [None]:
from dataclasses import dataclass
from datasets import load_dataset
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.utils import make_image_grid
import torch.nn.functional as F
from accelerate import Accelerator, notebook_launcher
from tqdm.auto import tqdm
from pathlib import Path
import os
import glob
from PIL import Image
import numpy as np

# Tile type -> RGB colour. map_to_nearest_color snaps generated pixels to the
# closest entry of this palette; on an exact tie np.argmin keeps the entry that
# appears first, so insertion order is significant.
# NOTE(review): key spelling is inconsistent ('path-level0' vs 'path_level1',
# 'rocklevel1' without underscore); kept verbatim in case other code or data
# refers to these exact names.
TILE_COLORS = {
    'grass':       (0, 128, 0),      # green
    'path-level0': (165, 42, 42),    # brown
    'path_level1': (210, 180, 140),  # tan
    'path_level2': (160, 82, 45),    # sienna
    'path_level3': (139, 69, 19),    # saddle brown
    'water':       (0, 0, 255),      # blue
    'forest':      (0, 100, 0),      # dark green
    'sand':        (255, 255, 0),    # yellow
    'bridge':      (255, 140, 0),    # dark orange
    'rocklevel1':  (128, 128, 128),  # grey
    'rocklevel2':  (169, 169, 169),  # dark grey
    'rocklevel3':  (211, 211, 211),  # light grey
}

@dataclass
class TrainingConfig:
    """Hyper-parameters for DDPM training on the tile-map dataset.

    Every field is type-annotated so that ``@dataclass`` treats it as a real
    field: defaults are unchanged (``TrainingConfig()``), but individual values
    can now be overridden from the constructor, e.g.
    ``TrainingConfig(num_epochs=10, output_dir="tmp")``.
    """

    image_size: tuple[int, int] = (40, 80)  # (height, width) of training and generated images
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 350
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500  # warm-up steps of the cosine LR schedule
    save_image_epochs: int = 10  # write a sample grid every N epochs (and on the last one)
    save_model_epochs: int = 30  # save the pipeline every N epochs (and on the last one)
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "ddpm-output"  # local directory for checkpoints and sample grids
    seed: int = 0  # seed of the generator used when sampling evaluation images

config = TrainingConfig()

# Load and preprocess dataset: every image under ./dataset becomes one example
# of the "train" split (Hugging Face "imagefolder" loader).
dataset = load_dataset("imagefolder", data_dir="./dataset", split="train")

# The images are presumably tile maps drawn in the discrete TILE_COLORS palette,
# so resize with nearest-neighbour sampling: bilinear/bicubic resampling would
# blend neighbouring tiles into colours that are not in the palette. The flip is
# data augmentation; Normalize maps ToTensor's [0, 1] range to [-1, 1].
preprocess = transforms.Compose(
    [
        transforms.Resize(
            config.image_size,
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

def transform(examples):
    """Batch transform for ``dataset.set_transform``.

    Forces every PIL image in ``examples["image"]`` to RGB, runs it through
    ``preprocess`` and returns the resulting tensors under the ``"images"`` key.
    """
    tensors = []
    for pil_image in examples["image"]:
        rgb = pil_image.convert("RGB")
        tensors.append(preprocess(rgb))
    return {"images": tensors}


# Register the transform so it is applied on the fly whenever rows are read,
# instead of preprocessing the whole dataset up front.
dataset.set_transform(transform)

# DataLoader: shuffled training mini-batches.
train_dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=config.train_batch_size,
)

# Define the model: an unconditional UNet taking and producing 3-channel images
# of config.image_size, with four resolution levels (64 -> 512 channels) of
# plain ResNet down/up blocks and two layers per block.
_level_count = 4
model = UNet2DModel(
    sample_size=config.image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 512),
    down_block_types=("DownBlock2D",) * _level_count,
    up_block_types=("UpBlock2D",) * _level_count,
)

# Sanity check: push one preprocessed image (batch of 1) through the untrained
# UNet and print input vs output shape; they should match.
sample_image = dataset[0]["images"].unsqueeze(0)
print("Input shape:", sample_image.shape)
with torch.no_grad():  # shape check only - no autograd graph needed
    print("Output shape:", model(sample_image, timestep=0).sample.shape)

# Noise Scheduler: 1000-step DDPM schedule, also handed to train_loop below.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

# Preview the image noised at t=50: map [-1, 1] back to [0, 255]. Clamp first -
# the added noise can push values outside that range, and casting out-of-range
# floats to uint8 is undefined (it typically wraps around).
# NOTE: a bare expression is only rendered if it is the last one in its cell.
Image.fromarray(
    ((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5)
    .clamp(0, 255)
    .type(torch.uint8)
    .numpy()[0]
)

# Define optimizer and learning rate scheduler: AdamW over all UNet parameters,
# then a linear warm-up followed by cosine decay. The schedule length is
# num_epochs * len(train_dataloader), i.e. one scheduler step per batch.
# NOTE(review): with gradient_accumulation_steps > 1 the prepared scheduler is
# stepped less often than once per batch - confirm the intended length.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=config.num_epochs * len(train_dataloader),
)

# Custom padding function
def custom_pad(x, multiple=8):
    """Pad the last two dimensions of ``x`` up to the next multiple of ``multiple``.

    Padding is added only on the bottom and right edges, so existing pixels keep
    their coordinates; the fill value is 0 (mid-grey for inputs normalised to
    [-1, 1]). Inputs whose height and width are already multiples are returned
    with the same shape. The default of 8 presumably matches the UNet's three
    2x downsampling stages (2**3) - adjust it if the architecture changes.

    Args:
        x: tensor whose last two dimensions are (height, width).
        multiple: positive integer both spatial sizes are rounded up to.

    Returns:
        The padded tensor.

    Raises:
        ValueError: if ``multiple`` is smaller than 1.
    """
    if multiple < 1:
        raise ValueError(f"multiple must be a positive integer, got {multiple}")
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=0)

def evaluate(config, epoch, pipeline):
    """Sample a batch from ``pipeline``, snap it to the tile palette and save a grid.

    Args:
        config: training config; reads ``eval_batch_size``, ``seed``,
            ``image_size`` and ``output_dir``.
        epoch: epoch index used as the zero-padded file name (``NNNN.png``).
        pipeline: diffusers pipeline whose call result has an ``images`` list
            of PIL images (``DDPMPipeline`` in this notebook).

    The grid is written to ``<output_dir>/samples/<epoch:04d>.png``.

    Raises:
        ValueError: if the pipeline returns no images.
    """
    # Fixed CPU generator seed: every evaluation starts from the same noise, so
    # grids from different epochs are directly comparable.
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.Generator(device='cpu').manual_seed(config.seed),
    ).images
    if not images:
        raise ValueError("pipeline returned no images; check config.eval_batch_size")

    # config.image_size is (height, width), or one int for square images;
    # PIL's resize expects (width, height).
    size = config.image_size
    height, width = (size, size) if isinstance(size, int) else size

    # Resize each sample and map every pixel to the nearest palette colour.
    processed_images = [
        map_to_nearest_color(img.resize((width, height))) for img in images
    ]

    # make_image_grid requires rows * cols == number of images: use the widest
    # layout of at most 8 columns that tiles the batch exactly (2 x 8 for 16).
    max_cols = 8
    num_images = len(processed_images)
    cols = max(c for c in range(1, min(max_cols, num_images) + 1) if num_images % c == 0)
    rows = num_images // cols
    image_grid = make_image_grid(processed_images, rows=rows, cols=cols)

    # Save the image grid
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(os.path.join(test_dir, f"{epoch:04d}.png"))

def map_to_nearest_color(image, palette=None):
    """Replace every pixel of ``image`` with the closest colour of a palette.

    Closeness is squared Euclidean distance in RGB space; on an exact tie the
    palette entry that comes first wins (``np.argmin`` semantics).

    Args:
        image: PIL image (converted to RGB first, so RGBA / greyscale inputs
            work) or an array-like of shape (H, W, 3).
        palette: optional mapping of name -> (R, G, B); defaults to TILE_COLORS.

    Returns:
        A new RGB PIL image in which every pixel is a palette colour.
    """
    if palette is None:
        palette = TILE_COLORS
    if isinstance(image, Image.Image):
        image = image.convert("RGB")

    # Signed int32 so differences can go negative and squares cannot overflow
    # (max 3 * 255**2 = 195075).
    pixels = np.asarray(image, dtype=np.int32)
    colors = np.asarray(list(palette.values()), dtype=np.int32)

    # (H, W, 1, 3) - (1, 1, K, 3) -> (H, W, K) squared distances to each colour.
    distances = np.sum((pixels[:, :, None, :] - colors[None, None, :, :]) ** 2, axis=-1)
    nearest_color_indices = np.argmin(distances, axis=-1)

    # Fancy indexing looks up the chosen colour for every pixel at once,
    # giving an (H, W, 3) array.
    return Image.fromarray(colors[nearest_color_indices].astype(np.uint8))

# Training loop
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train ``model`` with the DDPM noise-prediction objective under Accelerate.

    Each batch is padded with ``custom_pad`` and noised at uniformly random
    timesteps; the model's prediction is regressed onto that noise with MSE. On
    the main process a sample grid is written via ``evaluate`` every
    ``config.save_image_epochs`` epochs, and the pipeline is saved to
    ``config.output_dir`` every ``config.save_model_epochs`` epochs (both also
    after the final epoch).

    Args:
        config: training config (mixed precision, accumulation, epochs, output
            and checkpoint cadence).
        model: UNet to train.
        noise_scheduler: scheduler used to add noise and to build the sampling
            pipeline.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: yields dicts with an ``"images"`` tensor batch.
        lr_scheduler: learning-rate scheduler stepped alongside the optimizer.
    """
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            clean_images = custom_pad(batch["images"])
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]
            # One random diffusion timestep per image in the batch.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,),
                device=clean_images.device, dtype=torch.int64,
            )
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
                # Clip only once gradients are fully accumulated, i.e. on the
                # step where the optimizer update actually happens.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # Release the per-epoch bar instead of leaving one open per epoch.
        progress_bar.close()

        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            save_images = (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch
            save_model = (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch
            # Only build the sampling pipeline on epochs that actually use it.
            if save_images or save_model:
                pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
                if save_images:
                    evaluate(config, epoch, pipeline)
                if save_model:
                    pipeline.save_pretrained(config.output_dir)

    # Finalise any trackers started by init_trackers.
    accelerator.end_training()

# Run train_loop through Accelerate's notebook launcher with a single process.
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
notebook_launcher(train_loop, args, num_processes=1)

# Show the most recent sample grid. Grid file names are zero-padded epoch
# numbers, so lexical order is chronological.
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
if not sample_images:
    raise FileNotFoundError(f"no sample grids found under {config.output_dir}/samples")
Image.open(sample_images[-1])