In [None]:
# --- Imports ---
import os, glob, torch, cv2
import sys
import numpy as np
from PIL import Image
from math import sqrt
from tqdm import tqdm
from einops import rearrange
from torchvision.utils import make_grid
from torchvision import transforms as T
from types import SimpleNamespace as Namespace  # cleaner than argparse
from omegaconf import OmegaConf

# Make the repository's ./ldm directory importable before importing from it.
sys.path.append(os.getcwd() + "/ldm")
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

# Shared tensor -> PIL converter; mode=None lets torchvision infer the image mode.
transform_PIL = T.ToPILImage(mode=None)

# --- Helper Functions ---

def create_model(device, yaml_path, model_path):
    """Instantiate the model described by a YAML config and build a DDIM sampler for it.

    The checkpoint path is written into the config's ``model.params.ckpt_path``
    entry before instantiation.

    Returns:
        ``(model, sampler)``. The model has been moved to ``device``, and the
        sampler was constructed around that same model object.
    """
    cfg = OmegaConf.load(yaml_path)
    cfg.model['params']['ckpt_path'] = model_path

    net = instantiate_from_config(cfg.model)
    ddim = DDIMSampler(net)

    net = net.to(device)
    return net, ddim

def process_data(image_pth, mask_pth, kernel_size=2, size=512):
    """Load an image/mask pair and build the inpainting batch for the model.

    Args:
        image_pth: Path to the source image (converted to RGB).
        mask_pth: Path to the mask image (read as grayscale). Pixels at or above
            10% of full scale (>= 0.1 after dividing by 255) count as masked.
        kernel_size: Side length of the square kernel used to dilate the mask.
        size: Side length that the image and mask are resized to before being
            fed to the model. Defaults to 512, the previously hard-coded value.

    Returns:
        A tuple ``(batch, original_size, original_image, original_mask, imagename)``:
          * ``batch``: dict of float32 tensors scaled to [-1, 1].
              - ``"image"``: shape (1, 3, size, size).
              - ``"masked_image"``: shape (1, 3, size, size); masked pixels are
                zeroed before scaling, so they become -1.
              - ``"mask"``: the binarized, dilated mask at the mask file's
                ORIGINAL resolution, shape (1, 1, H, W).
          * ``original_size``: the mask's (H, W).
          * ``original_image`` / ``original_mask``: the full-resolution PIL
            images (modes "RGB" and "L").
          * ``imagename``: the image file name without directory or extension.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``mask_pth``.
    """
    # cv2.imread reports failure by returning None rather than raising; fail
    # clearly here instead of with an AttributeError on `.shape`.
    mask_gray = cv2.imread(mask_pth, cv2.IMREAD_GRAYSCALE)
    if mask_gray is None:
        raise FileNotFoundError(f"Could not read mask image: {mask_pth}")
    original_size = mask_gray.shape

    # Dilated, binarized mask at the original resolution -> (1, 1, H, W).
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated = cv2.dilate(mask_gray, kernel, iterations=1).astype(np.float32) / 255.0
    dilated = (dilated >= 0.1).astype(np.float32)
    dilated_mask = torch.from_numpy(dilated[None, None])

    # Open each file once. The context manager closes the handle, and convert()
    # returns an independent, fully loaded image that stays valid afterwards.
    with Image.open(image_pth) as img:
        original_image = img.convert("RGB")
    with Image.open(mask_pth) as msk:
        original_mask = msk.convert("L")

    image = np.asarray(original_image.resize((size, size)), dtype=np.float32) / 255.0
    image = torch.from_numpy(image.transpose(2, 0, 1)[None])

    mask = np.asarray(original_mask.resize((size, size)), dtype=np.float32) / 255.0
    mask = (mask >= 0.1).astype(np.float32)
    mask = torch.from_numpy(mask[None, None])

    # The masked image uses the resized, UN-dilated mask; only the mask that
    # goes into the batch is dilated.
    masked_image = (1 - mask) * image

    batch = {"image": image * 2.0 - 1.0,
             "mask": dilated_mask * 2.0 - 1.0,
             "masked_image": masked_image * 2.0 - 1.0}

    imagename = os.path.splitext(os.path.basename(image_pth))[0]

    return batch, original_size, original_image, original_mask, imagename

def _select_device():
    """Pick the inference device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def _save_samples(samples, logpath, original_size, original_image, original_mask, replace):
    """Save each sample in `samples` as sample_NNNNN.png at the original resolution.

    Numbering continues after the files in `logpath` that already match 'sample*'.
    When `replace` is true, pixels where the full-resolution mask equals 0 are
    copied back from the original image before saving.
    """
    base_count = len(glob.glob(os.path.join(logpath, 'sample*')))
    image_array = np.array(original_image)
    mask_array = np.array(original_mask)
    keep_original = mask_array == 0  # same for every sample, so compute once

    for sample in samples:
        output_PIL = transform_PIL(sample)
        # original_size is (H, W); PIL's resize expects (W, H).
        output_PIL = output_PIL.resize((original_size[1], original_size[0]))

        if replace:
            out_array = np.array(output_PIL)
            out_array[keep_original] = image_array[keep_original]
            output_PIL = Image.fromarray(out_array)

        output_PIL.save(os.path.join(logpath, f"sample_{base_count:05}.png"))
        base_count += 1


def _save_grid(samples, logpath, batchsize):
    """Tile all samples into a single image and save it as grid-NNNN.png in `logpath`."""
    grid_count = len(glob.glob(os.path.join(logpath, 'grid*')))
    grid = make_grid(samples, nrow=int(sqrt(batchsize)))

    grid_img = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    grid_img = Image.fromarray(grid_img.astype(np.uint8))
    grid_img.save(os.path.join(logpath, f'grid-{grid_count:04}.png'))


def run_inference(args):
    """Run DDIM inpainting for one image/mask pair and save the results.

    Fields read from `args`:
    - `yaml_path`, `model_path`: model config and checkpoint.
    - `log_path`: root output folder.
    - `image`, `mask`: input file paths.
    - `dilate_kernel`: mask dilation kernel size.
    - `batchsize`: number of samples to draw.
    - `Steps`: DDIM step count.
    - `isReplace`: if true, paste original pixels back wherever the mask is 0.

    Outputs are written to `<log_path>/<image file name with '.' -> '_'>/`:
    one sample_NNNNN.png per sample plus one grid-NNNN.png. Numbering continues
    after any files already present there.
    """
    device = _select_device()
    print(f"Using device: {device}")

    os.makedirs(args.log_path, exist_ok=True)

    model, sampler = create_model(device, args.yaml_path, args.model_path)
    model.eval()

    # os.path.basename is portable, unlike splitting on '/'.
    logpath = os.path.join(args.log_path, os.path.basename(args.image).replace('.', '_'))
    os.makedirs(logpath, exist_ok=True)

    batch, original_size, original_image, original_mask, _ = process_data(
        args.image, args.mask, args.dilate_kernel)

    # Pure inference: disable autograd so no graph or activations are kept.
    with torch.no_grad():
        c = model.cond_stage_model.encode(batch["masked_image"].to(device))
        # Resize the mask to the encoded feature map and append it as extra channel(s).
        cc = torch.nn.functional.interpolate(batch["mask"].to(device), size=c.shape[-2:])
        c = torch.cat((c, cc), dim=1)

        # Latent to sample has one channel fewer than the conditioning; this
        # assumes the appended mask is single-channel (it is in process_data).
        shape = (c.shape[1] - 1,) + c.shape[2:]
        cond = c.expand(args.batchsize, -1, -1, -1)

        samples_ddim, _ = sampler.sample(
            S=args.Steps,
            conditioning=cond,
            batch_size=cond.shape[0],
            shape=shape,
            verbose=False
        )
        x_samples_ddim = model.decode_first_stage(samples_ddim)

    # Map decoder output from [-1, 1] to [0, 1].
    predicted_image_clamped = torch.clamp((x_samples_ddim + 1.0) / 2.0, 0.0, 1.0)

    _save_samples(predicted_image_clamped, logpath, original_size,
                  original_image, original_mask, args.isReplace)
    _save_grid(predicted_image_clamped, logpath, args.batchsize)

    print(f"Results saved to {logpath}")

# --- Example run with hand-written arguments (stands in for argparse in Jupyter) ---
# Paths are relative to the current working directory.
args = Namespace(
    # Model definition and weights.
    yaml_path="ldm/models/ldm/inpainting_big/config_LAKERED.yaml",
    model_path="ckpt/LAKERED.ckpt",
    # Output root; results go into a per-image subfolder.
    log_path="demo_res",
    # Input image and its mask.
    image="demo/src/COD_CAMO_camourflage_00018.jpg",
    mask="demo/src/COD_CAMO_camourflage_00018.png",
    # Sampling / post-processing options.
    batchsize=1,
    isReplace=False,
    dilate_kernel=2,
    Steps=50,
)

run_inference(args)
