# Author: Seunghee Kim
Created on: 24.12.08

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import random
import torch
import numpy as np
import PIL
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt


from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
from tqdm.auto import tqdm

import piq

from diffusers import StableDiffusionInpaintPipeline
from diffusers import UNet2DConditionModel, AutoencoderKL, StableDiffusionInpaintPipeline
from transformers import CLIPTextModel, CLIPTokenizer


In [None]:
# Parameter settings

# Target resolution used for resizing/cropping the images and for the masks.
IMAGE_SIZE = (512, 512)
# PROMPT = "Fill the missing area with the sea background"
PROMPT = 'fill the area with ocean, remove people'
# Directory that receives the example masked/original/mask images.
SAVE_EXAMPLE_DIR = './masking_examples_v5'
os.makedirs(SAVE_EXAMPLE_DIR, exist_ok=True)

BATCH_SIZE = 9
EPOCHS = 3
# LEARNING_RATE = 1e-5
LEARNING_RATE = 5e-7
TRAIN_RATIO = 0.8
VALID_RATIO = 0.1
TEST_RATIO = 0.1
# Compare with a tolerance: exact float equality rejects valid settings such
# as 0.7 + 0.2 + 0.1, which sums to 0.9999999999999999. An explicit raise is
# used because `assert` statements are stripped when Python runs with -O.
if abs(TRAIN_RATIO + VALID_RATIO + TEST_RATIO - 1.0) > 1e-9:
    raise ValueError("TRAIN_RATIO + VALID_RATIO + TEST_RATIO must sum to 1.0")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Number of random-mask variants generated per source image
MASK_MULTIPLIER = 10

In [None]:
# Dataset loading (the masking helpers are defined right below)
beach_dataset = load_dataset('louiscklaw/beach_512', split='train', cache_dir='beach_dataset_huggingface')
# Use only the first 500 images. The cap is clamped to the dataset length so
# that a smaller split does not make `select` fail on out-of-range indices.
beach_dataset = beach_dataset.select(range(min(500, len(beach_dataset))))

def random_mask(image_size, num_shapes, max_shape_size=100):
    """Draw a random binary mask made of ellipses, rectangles and triangles.

    Args:
        image_size: ``(width, height)`` of the mask in pixels (PIL order).
        num_shapes: Upper bound on the number of shapes; the actual count is
            drawn uniformly from ``1..num_shapes``.
        max_shape_size: Maximum width/height of a shape's bounding box.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width)`` holding 255
        inside the drawn shapes and 0 elsewhere.

    Raises:
        ValueError: If ``num_shapes`` or ``max_shape_size`` is smaller than 1.
    """
    if num_shapes < 1:
        raise ValueError("num_shapes must be at least 1")
    if max_shape_size < 1:
        raise ValueError("max_shape_size must be at least 1")

    mask = Image.new("L", image_size, 0)
    draw = ImageDraw.Draw(mask)
    width, height = image_size

    # Keep every randint() range non-empty: the origin range collapses to 0
    # when the image is smaller than max_shape_size, and the minimum extent
    # (normally 20 px) shrinks when max_shape_size itself is below 20.
    max_x1 = max(width - max_shape_size, 0)
    max_y1 = max(height - max_shape_size, 0)
    min_extent = min(20, max_shape_size)

    for _ in range(random.randint(1, num_shapes)):
        shape_type = random.choice(["ellipse", "rectangle", "polygon"])

        x1 = random.randint(0, max_x1)
        y1 = random.randint(0, max_y1)

        # Bounding box of the shape, clipped to the image borders.
        x2 = min(x1 + random.randint(min_extent, max_shape_size), width)
        y2 = min(y1 + random.randint(min_extent, max_shape_size), height)

        if shape_type == "ellipse":
            draw.ellipse([x1, y1, x2, y2], fill=255)
        elif shape_type == "rectangle":
            draw.rectangle([x1, y1, x2, y2], fill=255)
        elif shape_type == "polygon":
            # Triangle anchored at the box origin with two jittered corners.
            points = [
                (x1, y1),
                (x2, y1 + random.randint(0, max_shape_size // 2)),
                (x1 + random.randint(0, max_shape_size // 2), y2),
            ]
            draw.polygon(points, fill=255)

    return np.array(mask)

# Default image preprocessing: bicubic resize to IMAGE_SIZE followed by a
# center crop of the same size.
resize_transform = transforms.Compose(
    [
        transforms.Resize(
            IMAGE_SIZE,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.CenterCrop(IMAGE_SIZE),
    ]
)

class InpaintingDataset(Dataset):
    """Dataset of (original image, masked image, mask, prompt) samples.

    Every source image is exposed ``mask_multiplier`` times. A fresh random
    mask is drawn on each ``__getitem__`` call, so repeated reads of the same
    index return different masks.

    Args:
        hf_dataset: Indexable dataset whose items provide a PIL image under
            the ``'image'`` key.
        prompt: Text prompt attached to every sample.
        transform: Optional callable applied to the PIL image before masking.
        mask_multiplier: Number of dataset entries per source image.

    Raises:
        ValueError: If ``mask_multiplier`` is smaller than 1.
    """

    # Ratio between image resolution and the resolution of the returned mask.
    # The original code hard-coded a 64x64 mask for 512x512 images, i.e. a
    # factor of 8 (presumably the VAE downscale factor — verify if the model
    # changes).
    MASK_DOWNSCALE = 8

    def __init__(self, hf_dataset, prompt=PROMPT, transform=resize_transform, mask_multiplier=1):
        if mask_multiplier < 1:
            raise ValueError("mask_multiplier must be at least 1")
        self.dataset = hf_dataset
        self.prompt = prompt
        self.transform = transform
        self.mask_multiplier = mask_multiplier
        # Built once instead of on every __getitem__ call.
        self._to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.dataset) * self.mask_multiplier

    def __getitem__(self, idx):
        # Consecutive indices map onto the same source image.
        real_idx = idx // self.mask_multiplier

        image = self.dataset[real_idx]['image']
        if self.transform is not None:
            image = self.transform(image)
        image = image.convert("RGB")

        # Draw the mask at the image's own (width, height) so that it always
        # lines up with the pixel array, even without the default transform.
        width, height = image.size
        mask = random_mask((width, height), num_shapes=5)
        mask_pil = Image.fromarray(mask)

        # Paint the masked region white in a copy of the image.
        # NOTE(review): the diffusers inpainting pipeline appears to zero the
        # masked region in [-1, 1] space instead of filling it with white —
        # confirm whether the white fill is intended for fine-tuning.
        masked_image_np = np.array(image)
        masked_image_np[mask == 255] = 255

        image_t = self._to_tensor(image)
        masked_image_t = self._to_tensor(Image.fromarray(masked_image_np))
        mask_t = self._to_tensor(mask_pil).float()  # [1, H, W]
        # Downscale the mask; 512x512 inputs give the original 64x64 mask.
        mask_size = (height // self.MASK_DOWNSCALE, width // self.MASK_DOWNSCALE)
        mask_t = transforms.Resize(mask_size, interpolation=transforms.InterpolationMode.BILINEAR)(mask_t)

        return {
            'original_image': image_t,
            'masked_image': masked_image_t,
            'mask': mask_t,
            'prompt': self.prompt
        }

full_dataset = InpaintingDataset(beach_dataset, mask_multiplier=MASK_MULTIPLIER)

# Save masking examples: for each of the first (up to) 100 samples write the
# masked image, the original image and the mask as PNG files.
to_pil_image = transforms.ToPILImage()
# Clamp the count so a dataset with fewer than 100 entries is not indexed
# past its end.
for i in range(min(100, len(full_dataset))):
    sample = full_dataset[i]
    masked_img = to_pil_image(sample['masked_image'])
    original_img = to_pil_image(sample['original_image'])
    mask_img = to_pil_image(sample['mask'])

    masked_img.save(os.path.join(SAVE_EXAMPLE_DIR, f"masked_{i}.png"))
    original_img.save(os.path.join(SAVE_EXAMPLE_DIR, f"original_{i}.png"))
    mask_img.save(os.path.join(SAVE_EXAMPLE_DIR, f"mask_{i}.png"))

# Train/valid/test split and DataLoader construction.
#
# The split is made over the SOURCE IMAGES, not over the mask-augmented
# indices. InpaintingDataset maps MASK_MULTIPLIER consecutive indices onto the
# same image, so splitting the augmented indices at random would place the
# same image in the train, validation and test sets (data leakage).
num_images = len(beach_dataset)
num_train_images = int(num_images * TRAIN_RATIO)
num_valid_images = int(num_images * VALID_RATIO)

# Seeded permutation of the image indices for a reproducible split.
image_order = torch.randperm(num_images, generator=torch.Generator().manual_seed(42)).tolist()
train_image_ids = image_order[:num_train_images]
valid_image_ids = image_order[num_train_images:num_train_images + num_valid_images]
test_image_ids = image_order[num_train_images + num_valid_images:]

train_dataset = InpaintingDataset(beach_dataset.select(train_image_ids), mask_multiplier=MASK_MULTIPLIER)
valid_dataset = InpaintingDataset(beach_dataset.select(valid_image_ids), mask_multiplier=MASK_MULTIPLIER)
test_dataset = InpaintingDataset(beach_dataset.select(test_image_ids), mask_multiplier=MASK_MULTIPLIER)

# Sample counts of the three subsets (after mask augmentation).
train_length = len(train_dataset)
valid_length = len(valid_dataset)
test_length = len(test_dataset)
dataset_length = train_length + valid_length + test_length

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model fine-tuning setup
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    cache_dir='./cache_dir'
).to(DEVICE)

# Keep all weights in float32.
pipe.unet.to(torch.float32)
pipe.text_encoder.to(torch.float32)
pipe.vae.to(torch.float32)

unet = pipe.unet
text_encoder = pipe.text_encoder
vae = pipe.vae
tokenizer = pipe.tokenizer

# Only the UNet's parameters are handed to the optimizer. Freeze the VAE and
# the text encoder so that backward passes do not allocate and accumulate
# gradients for weights that are never updated or zeroed.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

optimizer = torch.optim.AdamW(unet.parameters(), lr=LEARNING_RATE)
# Loss scaling is only meaningful on CUDA; disable it explicitly on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

def _denoising_loss(batch):
    """Run one forward pass of the inpainting UNet and return its loss.

    Args:
        batch: Mapping with ``'prompt'``, ``'original_image'``,
            ``'masked_image'`` and ``'mask'`` entries as produced by the
            data loaders.

    Returns:
        Tuple ``(loss, batch_size)`` where ``loss`` is the MSE between the
        UNet output and the sampled noise.
    """
    # The VAE and the text encoder are not optimized (the optimizer only
    # holds the UNet's parameters), so encode without building a graph.
    with torch.no_grad():
        input_ids = tokenizer(
            batch['prompt'],
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).input_ids.to(DEVICE)
        text_embeddings = text_encoder(input_ids)[0]
        # Images arrive in [0, 1]; rescale to [-1, 1] before encoding.
        # 0.18215 is presumably the VAE latent scaling factor — confirm
        # against vae.config if the checkpoint changes.
        latents = vae.encode(batch['original_image'].to(DEVICE) * 2 - 1).latent_dist.sample() * 0.18215
        masked_latents = vae.encode(batch['masked_image'].to(DEVICE) * 2 - 1).latent_dist.sample() * 0.18215
    mask = batch['mask'].to(DEVICE)

    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0,
        pipe.scheduler.config.num_train_timesteps,
        (latents.shape[0],),
        device=DEVICE,
        dtype=torch.long,
    )
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

    # Channel order must match what the inpainting pipeline feeds the UNet at
    # inference time: noisy latents, then the mask, then the masked-image
    # latents. Any other order scrambles the pretrained input convolution.
    unet_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
    model_pred = unet(unet_input, timesteps, encoder_hidden_states=text_embeddings, return_dict=False)[0]

    # NOTE(review): using the noise as target assumes the scheduler is
    # configured for epsilon prediction — confirm via
    # pipe.scheduler.config.prediction_type.
    loss = torch.nn.functional.mse_loss(model_pred, noise)
    return loss, latents.size(0)


def train_epoch():
    """Train the UNet for one pass over ``train_loader``.

    Batches whose loss is NaN are skipped (with a warning) and excluded from
    the average.

    Returns:
        Mean per-sample training loss over the batches that were used.
    """
    unet.train()
    loss_sum = 0.0
    samples_used = 0
    for batch in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss, batch_size = _denoising_loss(batch)

        if torch.isnan(loss):
            print("Warning: NaN loss detected, skipping this batch.")
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_sum += loss.item() * batch_size
        samples_used += batch_size
    # Normalize by the samples that actually contributed, so skipped batches
    # do not deflate the reported mean.
    return loss_sum / max(samples_used, 1)


def evaluate(loader):
    """Compute the mean per-sample denoising loss of the UNet on ``loader``.

    Args:
        loader: DataLoader yielding batches in the same format as training.

    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    unet.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            with torch.cuda.amp.autocast():
                loss, batch_size = _denoising_loss(batch)
            val_loss += loss.item() * batch_size
    return val_loss / len(loader.dataset)

# Main fine-tuning loop: one training pass followed by a validation pass per
# epoch, with both losses reported.
for epoch in range(EPOCHS):
    train_loss, valid_loss = train_epoch(), evaluate(valid_loader)
    print(f"Epoch {epoch}, Train Loss: {train_loss}, Valid Loss: {valid_loss}")

# SSIM, LPIPS Metric Code

In [None]:
def compute_metrics(loader, pipeline=pipe):
    """Inpaint every sample of ``loader`` and score it against the original.

    For each sample the masked image and its mask are passed to ``pipeline``;
    the generated image is compared with the original image using SSIM and
    LPIPS over the whole image.

    Args:
        loader: DataLoader yielding batches with ``'original_image'``,
            ``'masked_image'``, ``'mask'`` and ``'prompt'`` entries.
        pipeline: Inpainting pipeline used for generation.

    Returns:
        Tuple ``(mean_ssim, mean_lpips)`` averaged over all samples.
    """
    lpips_metric = piq.LPIPS().to(DEVICE)
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    ssim_scores = []
    lpips_scores = []

    # Switch the UNet of the pipeline being evaluated (not a global one) to
    # eval mode.
    pipeline.unet.eval()
    with torch.no_grad():
        for batch in tqdm(loader, desc="Computing metrics"):
            batch_size = batch['original_image'].shape[0]
            prompts = batch['prompt']
            if isinstance(prompts, str):
                # A single shared prompt: repeat it for every sample.
                prompts = [prompts] * batch_size
            else:
                # Any other sequence (list, tuple, ...) holds one prompt per
                # sample; normalize it to a list.
                prompts = list(prompts)

            for i in range(batch_size):
                original_img_t = batch['original_image'][i].unsqueeze(0).to(DEVICE)
                masked_image_pil = to_pil(batch['masked_image'][i].cpu())
                mask_pil = to_pil(batch['mask'][i].cpu()).convert("L")

                result_image = pipeline(prompt=prompts[i], image=masked_image_pil, mask_image=mask_pil).images[0]
                gen_image_t = to_tensor(result_image).unsqueeze(0).to(DEVICE)

                ssim_val = piq.ssim(gen_image_t, original_img_t, data_range=1.0)
                ssim_scores.append(ssim_val.item())

                lpips_val = lpips_metric(gen_image_t, original_img_t)
                lpips_scores.append(lpips_val.mean().item())

    mean_ssim = np.mean(ssim_scores)
    mean_lpips = np.mean(lpips_scores)
    return mean_ssim, mean_lpips

# Score the fine-tuned pipeline on the validation split first, then on the
# test split.
mean_ssim, mean_lpips = compute_metrics(loader=valid_loader)
test_ssim, test_lpips = compute_metrics(loader=test_loader, pipeline=pipe)

In [None]:
# Report SSIM and LPIPS for the validation split, then for the test split.
for header, ssim_value, lpips_value in (
    ("Validation Metrics:", mean_ssim, mean_lpips),
    ("Test Metrics:", test_ssim, test_lpips),
):
    print(header)
    print("SSIM:", ssim_value)
    print("LPIPS:", lpips_value)

# Saving the fine-tuned pipeline is left disabled:
# os.makedirs('./finetuned_model', exist_ok=True)
# pipe.save_pretrained('./finetuned_model')

In [None]:
# Inputs for the qualitative comparison: a photo, its inpainting mask, and
# the same prompt that was used for fine-tuning.
input_image = Image.open("two-people-sea.jpg").convert("RGB")
input_mask = Image.open("two-people-sea-mask.png").convert("L")
test_prompt = PROMPT

# Separate, freshly loaded copy of the pretrained pipeline to compare against.
# Alternative checkpoints left commented out by the author:
#   "runwayml/stable-diffusion-inpainting"
#   "botp/stable-diffusion-v1-5-inpainting"
original_pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    cache_dir='./cache_dir',
)
original_pipe = original_pipe.to(DEVICE)

# Keep all three sub-models in float32.
for _component in (original_pipe.unet, original_pipe.text_encoder, original_pipe.vae):
    _component.to(torch.float32)

def run_inference(pipeline, image, mask, prompt):
    """Inpaint ``image`` under ``mask`` with ``pipeline`` and return the result.

    Args:
        pipeline: Callable accepting ``prompt``, ``image`` and ``mask_image``
            keyword arguments and returning an object with an ``images``
            sequence.
        image: Image to inpaint.
        mask: Mask passed to the pipeline as ``mask_image``.
        prompt: Text prompt guiding the generation.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    # Gradients are not needed for generation.
    with torch.no_grad():
        output = pipeline(prompt=prompt, image=image, mask_image=mask)
        return output.images[0]

num_generations = 10
original_results = []
finetuned_results = []

# For every round, generate with the original pipeline first and with the
# fine-tuned pipeline second; each image is saved and kept for plotting.
for i in range(num_generations):
    for tag, model, collected in (
        ("original", original_pipe, original_results),
        ("finetuned", pipe, finetuned_results),
    ):
        generated = run_inference(model, input_image, input_mask, test_prompt)
        generated.save(f"{tag}_result_{i+1}.png")
        collected.append(generated)

In [None]:


# Two-row comparison grid: original-pipeline outputs on the top row,
# fine-tuned outputs on the bottom row, one column per generation.
plt.figure(figsize=(20, 10))
_plot_rows = (('Original', original_results), ('Fine-tuned', finetuned_results))
for col in range(num_generations):
    for row, (label, images) in enumerate(_plot_rows):
        plt.subplot(2, num_generations, row * num_generations + col + 1)
        plt.imshow(images[col])
        plt.axis('off')
        plt.title(f'{label} {col + 1}')

plt.tight_layout()
plt.show()

In [None]:
# import matplotlib.pyplot as plt

# epochs = [1, 2, 3]
# train_loss = [0.19729973109904678, 0.14593217349145562, 0.13217759937467052]
# valid_loss = [0.16069402199983596, 0.1383479192107916, 0.11289855517819523]



# plt.figure(figsize=(8, 6))
# plt.plot(epochs, train_loss, label='Train Loss', marker='o', linewidth=2)
# plt.plot(epochs, valid_loss, label='Validation Loss', marker='s', linewidth=2)

# plt.xlabel('Epochs', fontsize=14)
# plt.ylabel('Loss', fontsize=14)
# plt.title('Train and Validation Loss', fontsize=16)
# plt.legend(fontsize=12)
# plt.grid(True)

# plt.savefig('train_validation_loss.png', dpi=300)
# plt.show()