In [2]:
# Dataset
!gdown 194DqJkUjjtlUCp7Sd2DkXgwu7SuBFsrj
!unzip celeba_hq_256.zip

import os
import random

file_names = os.listdir("/content/celeba_hq_256")
img_paths = ["/content/celeba_hq_256/" + file_name for file_name in file_names]
sample_size = int(len(img_paths) * 0.9)

train_imgpaths = random.sample(img_paths, sample_size)
val_imgpaths = [img_path for img_path in img_paths if img_path not in train_imgpaths]

[1;30;43mKết quả truyền trực tuyến bị cắt bớt đến 5000 dòng cuối.[0m
  inflating: celeba_hq_256/25000.jpg  
  inflating: celeba_hq_256/25001.jpg  
  inflating: celeba_hq_256/25002.jpg  
  inflating: celeba_hq_256/25003.jpg  
  inflating: celeba_hq_256/25004.jpg  
  inflating: celeba_hq_256/25005.jpg  
  inflating: celeba_hq_256/25006.jpg  
  inflating: celeba_hq_256/25007.jpg  
  inflating: celeba_hq_256/25008.jpg  
  inflating: celeba_hq_256/25009.jpg  
  inflating: celeba_hq_256/25010.jpg  
  inflating: celeba_hq_256/25011.jpg  
  inflating: celeba_hq_256/25012.jpg  
  inflating: celeba_hq_256/25013.jpg  
  inflating: celeba_hq_256/25014.jpg  
  inflating: celeba_hq_256/25015.jpg  
  inflating: celeba_hq_256/25016.jpg  
  inflating: celeba_hq_256/25017.jpg  
  inflating: celeba_hq_256/25018.jpg  
  inflating: celeba_hq_256/25019.jpg  
  inflating: celeba_hq_256/25020.jpg  
  inflating: celeba_hq_256/25021.jpg  
  inflating: celeba_hq_256/25022.jpg  
  inflating: celeba_hq_256/25023

In [3]:
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

def bbox2mask(img_shape, bbox, dtype='uint8'):
    height, width = img_shape[:2]
    mask = np.zeros((height, width, 1), dtype=dtype)
    mask[bbox[0]:bbox[0] + bbox[2], bbox[1]:bbox[1] + bbox[3], :] = 1
    return mask

class InpaintingDataset(Dataset):
    def __init__(self, img_paths, image_size=[256, 256]):
        self.img_paths = img_paths
        self.tfs = transforms.Compose([
            transforms.Resize((image_size[0], image_size[1])),
            transforms.ToTensor()
        ])
        self.image_size = image_size

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        img = Image.open(img_path).convert('RGB')
        img = self.tfs(img)
        mask = self.get_mask()
        mask_img = img * (1. - mask) + mask
        return {
            "gt_image": img,
            "cond_image": mask_img,
            "mask": mask,
            "path": img_path
        }

    def __len__(self):
        return len(self.img_paths)

    def get_mask(self):
        h, w = self.image_size  # Center mask
        mask = bbox2mask(self.image_size, (h // 4, w // 4, h // 4, w // 4))
        return torch.from_numpy(mask).permute(2, 0, 1)

train_dataset = InpaintingDataset(train_imgpaths)
batch_size = 64  # (GPU 24 GB)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

In [1]:
# Model
!pip install -q torchcfm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m102.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m84.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m43.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [5]:
from torchcfm.models.unet import UNetModel
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "gpu")
model = UNetModel(dim=(3, 256, 256), num_channels=32, num_res_blocks=1).to(device)
optimizer = torch.optim.Adam(model.parameters())

n_epochs = 1000
for epoch in range(n_epochs):
    losses = []
    for i, data in tqdm(enumerate(train_loader)):
        optimizer.zero_grad()
        x1 = data["gt_image"].to(device)
        mask = data["mask"].to(device)
        x0 = torch.randn_like(x1).to(device)

        x_noise = (1.0 - mask) * x1 + mask * x0
        t = torch.rand(x0.shape[0], 1, 1, 1).to(device)
        xt = t * x1 + (1 - t) * x_noise
        ut = x1 - x_noise

        t = t.squeeze()
        x_cond = xt * mask + (1.0 - mask) * x1
        vt = model(t, x_cond)

        loss = torch.mean(((vt - ut) ** 2) * mask)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    avg_loss = sum(losses) / len(losses)
    print(f"epoch: {epoch}, loss: {avg_loss:.4f}")

0it [00:00, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 178.12 MiB is free. Process 2270 has 14.56 GiB memory in use. Of the allocated memory 14.28 GiB is allocated by PyTorch, and 167.24 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
model.eval()

def euler_method(model, cond_image, t_steps, dt, mask):
    y = cond_image
    y_values = [y]
    with torch.no_grad():
        for t in t_steps[1:]:
            t = t.reshape(-1,)
            dy = model(t.to(device), y)
            y = y + dy * dt
            y = cond_image * (1. - mask) + mask * y
            y_values.append(y)
    return torch.stack(y_values)

# Initial random image and class (optional)
sample = next(iter(train_loader))
gt_image = sample['gt_image'].to(device)
noise = torch.randn_like(gt_image, device=device)
mask = sample['mask'].to(device)
cond_image = gt_image * (1. - mask) + mask * noise

# Time parameters
t_steps = torch.linspace(0, 1, 50, device=device)  # From 0 to 1
dt = t_steps[1] - t_steps[0]

# Solve the ODE using Euler method
traj = euler_method(model, cond_image, t_steps, dt, mask)

In [None]:
# Cài đặt
!pip install streamlit

# Mở ứng dụng
!streamlit run app.py

In [None]:
import streamlit as st

st.title("Image Inpainting using Conditional Flow Matching")
st.write("Model: Conditional Flow Matching. Dataset: CelebA-HQ")

if st.button("Run Example Image"):
    # Chạy inference và hiển thị kết quả
    st.image(output_image_tensor)