In [None]:
import glob
import math
import os
from dataclasses import dataclass
from pathlib import Path

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate import notebook_launcher
from albumentations.pytorch import ToTensorV2
from diffusers import DDPMPipeline
from diffusers import DDPMScheduler
from diffusers import UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from huggingface_hub import create_repo, upload_folder
from huggingface_hub import login
from PIL import Image
from torch.multiprocessing import Pool, Process, set_start_method
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from tqdm.auto import tqdm

# Authenticate against the Hugging Face Hub with a token taken from the
# environment. The token must never be hard-coded in the notebook, and an empty
# string is not a valid credential.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("HF_TOKEN is not set; skipping Hugging Face login.")

# Ask torch.multiprocessing to use the "spawn" start method. set_start_method
# raises RuntimeError when a start method has already been fixed (for example
# when this cell is re-run), which is safe to ignore.
try:
    mp.set_start_method("spawn")
except RuntimeError:
    print(
        "multiprocessing start method is already set to "
        f"{mp.get_start_method()!r}; leaving it unchanged"
    )

@dataclass
class TrainingConfig:
    """Hyper-parameters and I/O settings for the DDPM training notebook.

    Every attribute is an annotated dataclass field, so individual values can be
    overridden at construction time, e.g. ``TrainingConfig(image_size=256)``.
    """

    image_size: int = 512  # resolution of the generated images
    train_batch_size: int = 1
    eval_batch_size: int = 16  # number of images sampled during evaluation
    num_epochs: int = 1000
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 100
    save_model_epochs: int = 30
    mixed_precision: str = "fp16"  # "no" for float32, "fp16" for automatic mixed precision
    output_dir: str = "test_ddpm"  # model name used locally and on the HF Hub

    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    # Hub repository id read by train_loop; when empty, train_loop falls back to
    # the name of output_dir.
    hub_model_id: str = ""
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the previous model when the notebook is re-run
    seed: int = 36

# Single configuration object shared by every later cell.
config = TrainingConfig()

# Every PNG file in the current working directory, ordered by path.
png_files = glob.glob(os.path.join('./', '*.png'))
image_path = np.sort(png_files)

# Augmentation / preprocessing pipeline applied to every training image:
# two light photometric augmentations (each applied with probability 0.1),
# then normalisation with mean 0.5 and std 0.5 on pixel values scaled by 255,
# then conversion to a tensor.
# NOTE(review): no resize step is applied here — confirm the source images
# already match config.image_size.
transform = A.Compose(
    [
        A.RandomGamma(gamma_limit=(90, 110), p=0.1),
        A.RandomBrightnessContrast(
            brightness_limit=(-0.1, 0.1),
            contrast_limit=(-0.1, 0.1),
            p=0.1,
        ),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

class CustomDataset(Dataset):
    """Dataset that loads image files from disk as 3-channel images.

    Args:
        image_paths: Indexable, sized collection of image file paths.
        transform: Optional callable invoked as ``transform(image=array)`` that
            returns a mapping holding the processed image under ``'image'``.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx``, force it to three channels and transform it.

        Returns:
            The output of ``transform(image=...)['image']`` when a transform is
            set, otherwise the (H, W, 3) array itself.
        """
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the array.
        with Image.open(self.image_paths[idx]) as handle:
            image = np.array(handle)

        if image.ndim == 2:
            # Single-channel image: replicate it into three channels.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 2:
            # Grey + alpha: drop the alpha plane, then replicate the grey plane.
            grey = np.ascontiguousarray(image[..., 0])
            image = cv2.cvtColor(grey, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            # Colour + alpha: keep only the three colour planes.
            image = np.ascontiguousarray(image[..., :3])
        # A 3-channel image is already in the expected layout.

        if self.transform:
            # The transform receives a channel-last (H, W, C) array.
            image = self.transform(image=image)['image']

        return image

# Wrap the PNG paths in the dataset and expose them through a DataLoader.
# The batch size comes from config.train_batch_size and the iteration order is
# fixed (shuffle=False).
# NOTE(review): training loaders are usually shuffled — confirm the fixed order
# is intentional.
train_dataset = CustomDataset(image_paths=image_path, transform=transform)
dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config.train_batch_size,
    shuffle=False,
)

# UNet noise-prediction network: six resolution levels, with spatial
# self-attention only in the fifth down block and the matching second up block.
model = UNet2DModel(
    sample_size=config.image_size,  # target image resolution
    in_channels=3,  # number of input channels (3 for RGB images)
    out_channels=3,  # number of output channels
    layers_per_block=3,  # ResNet layers used per UNet block
    block_out_channels=(64, 64, 128, 128, 512, 512),  # output channels of each UNet block
    # Four plain ResNet downsampling blocks, one downsampling block with
    # spatial self-attention, then a final plain downsampling block.
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    # One plain ResNet upsampling block, one upsampling block with spatial
    # self-attention, then four plain upsampling blocks.
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# The learning-rate scheduler advances once per optimizer update, and an update
# happens only every `gradient_accumulation_steps` batches (plus once at the end
# of each epoch). The schedule length is therefore counted in updates rather
# than batches; otherwise the cosine decay would cover only a fraction of its
# length by the end of training.
num_update_steps_per_epoch = math.ceil(len(dataloader) / config.gradient_accumulation_steps)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=num_update_steps_per_epoch * config.num_epochs,
)

In [None]:
# Take the first batch from the dataloader and use its first image as the sample.
first_batch = next(iter(dataloader))
sample_image = first_batch[0].squeeze(0)

# Add noise for timestep 50 of a 1000-step DDPM schedule.
noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
# NOTE: this scheduler instance is also the one handed to train_loop in the
# launch cell, so this cell has to run before training starts.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)


def _to_display_gray(image):
    """Convert a channel-first tensor to a uint8 array of its first channel.

    Values are mapped with ``(x + 1) * 127.5`` (assumes the image is scaled to
    [-1, 1]) and clamped first, because a noisy image can leave that range and
    an unclamped cast to uint8 would wrap around.
    """
    scaled = (image.clamp(-1.0, 1.0) + 1.0) * 127.5
    return scaled.type(torch.uint8).numpy()[0]


# Two side-by-side subplots: original on the left, noisy image on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(_to_display_gray(sample_image), cmap='gray')
axes[0].set_title("Original Image")
axes[0].axis('off')  # hide the axes

axes[1].imshow(_to_display_gray(noisy_image), cmap='gray')
axes[1].set_title("Noisy Image")
axes[1].axis('off')  # hide the axes

plt.tight_layout()
plt.show()

In [None]:
def make_grid(images, rows, cols, size=(128, 128)):
    """Draw ``images`` on a ``rows`` x ``cols`` grid of matplotlib axes.

    Args:
        images: Sequence of channel-first image tensors; each one is permuted
            to channel-last before being displayed.
        rows: Number of grid rows.
        cols: Number of grid columns.
        size: Unused; kept so callers that pass it keep working.

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure is the grid, so
        callers can call ``.savefig(...)`` on the result.
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 or single-row grid,
    # so `.flat` is always available.
    _, axes = plt.subplots(rows, cols, figsize=(12, 12), squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # hide the axes, including cells that stay empty
        if i < len(images):
            # (C, H, W) -> (H, W, C)
            ax.imshow(images[i].permute(1, 2, 0))

    plt.tight_layout()

    return plt

def evaluate(config, epoch, pipeline):
    """Sample images from ``pipeline`` and save them as a single grid PNG.

    Args:
        config: Object providing ``eval_batch_size``, ``seed`` and ``output_dir``.
        epoch: Epoch index, used as the zero-padded file name.
        pipeline: Callable pipeline accepting ``batch_size`` and ``generator``
            and returning an object with an ``images`` list.

    The grid is written to ``<output_dir>/samples/<epoch:04d>.png``.
    """
    image_list = []
    for i in range(config.eval_batch_size):
        # A dedicated generator keeps each sample reproducible without
        # re-seeding the global RNG that the training loop draws its noise from.
        generator = torch.Generator(device="cpu").manual_seed(config.seed + i)
        image = pipeline(batch_size=1, generator=generator).images[0]
        image_list.append(ToTensor()(image))

    if not image_list:
        # Nothing was sampled (eval_batch_size == 0), so there is no grid to save.
        return

    # Derive a near-square grid from the number of samples instead of assuming
    # exactly 16 images; pad with blank images so every grid cell has content.
    n_images = len(image_list)
    cols = math.ceil(math.sqrt(n_images))
    rows = math.ceil(n_images / cols)
    blank = torch.ones_like(image_list[0])
    image_list.extend([blank] * (rows * cols - n_images))

    image_grid = make_grid(image_list, rows=rows, cols=cols, size=(128, 128))

    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.savefig(f"{test_dir}/{epoch:04d}.png")
    # Close the current (grid) figure so figures do not pile up across evaluations.
    plt.close()

In [None]:
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train ``model`` to predict the noise added by ``noise_scheduler``.

    Args:
        config: Training configuration (epochs, accumulation steps, output
            directory, hub settings, save intervals, ...).
        model: Network called as ``model(noisy_images, timesteps, return_dict=False)``.
        noise_scheduler: Scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer over the model parameters.
        train_dataloader: Iterable whose batches are image tensors.
        lr_scheduler: Learning-rate scheduler stepped together with the optimizer.

    Every ``config.save_image_epochs`` epochs (and after the last epoch) sample
    images are written via ``evaluate``; every ``config.save_model_epochs``
    epochs (and after the last epoch) the pipeline is saved to
    ``config.output_dir`` and, when ``config.push_to_hub`` is set, uploaded.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    repo_id = None
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        if config.push_to_hub:
            # hub_model_id is optional on the config; fall back to the output dir name.
            hub_model_id = getattr(config, "hub_model_id", None)
            repo_id = create_repo(
                repo_id=hub_model_id or Path(config.output_dir).name, exist_ok=True
            ).repo_id
        accelerator.init_trackers("train_example")

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            clean_images = batch

            # Noise to add, and one random timestep per image in the batch.
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
                dtype=torch.int64
            )

            # Forward diffusion: noise the clean images according to each timestep.
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                # accelerator.backward already scales the loss by the configured
                # gradient_accumulation_steps, so it must not be divided here too.
                accelerator.backward(loss)

                # Clip only when the accumulated gradients are about to be applied.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                # The prepared optimizer and scheduler skip these calls on
                # batches that only accumulate gradients.
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            save_images = (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch
            save_model = (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch

            if save_images or save_model:
                pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if save_images:
                evaluate(config, epoch, pipeline)

            if save_model:
                # Always write the weights locally first; uploading the folder
                # without saving would push a directory that has no model in it.
                pipeline.save_pretrained(config.output_dir)
                if config.push_to_hub:
                    upload_folder(
                        repo_id=repo_id,
                        folder_path=config.output_dir,
                        commit_message=f"Epoch {epoch}",
                        ignore_patterns=["step_*", "epoch_*"],
                    )

    # Flush and close the trackers started by init_trackers.
    accelerator.end_training()

In [None]:
# Launch training in a single process. noise_scheduler is the scheduler created
# in the noise preview cell, so that cell has to run before this one.
args = (config, model, noise_scheduler, optimizer, dataloader, lr_scheduler)
notebook_launcher(train_loop, args, num_processes=1)

# Visualise a generated image. The pipeline is loaded from the directory the
# training loop saved to, so the path follows config.output_dir.
pipeline = DDPMPipeline.from_pretrained(config.output_dir)
pipeline.to("cuda" if torch.cuda.is_available() else "cpu")

with torch.no_grad():
    generated_image = pipeline(num_inference_steps=1000).images[0]

plt.imshow(generated_image)
plt.axis("off")
plt.show()