<a href="https://colab.research.google.com/github/jobellet/vlPFC_Visual_Geometry/blob/main/inpainting_landscape.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from getpass import getpass

# Read the private-link token without echoing it, so it is neither shown on
# screen nor persisted in the notebook's saved output.
private_link = getpass('Enter the private link token:')

In [None]:

#--- Imports ---
import os
import time
from pathlib import Path
import pandas as pd
import zipfile
import urllib.request
from PIL import Image
import numpy as np
import cv2
import torch
!pip install git+https://github.com/openai/CLIP.git
!pip install diffusers
!pip install transformers
import clip
from diffusers import StableDiffusionInpaintPipeline
from tqdm import tqdm
import threading
from IPython.display import display
import ipywidgets as widgets


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-3h9c843j
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-3h9c843j
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


2025-07-05 20:11:54.035706: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751746314.058023     294 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751746314.064901     294 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
import os, sys
import subprocess
from pathlib import Path

# detect environment
IN_COLAB   = 'google.colab' in str(get_ipython())
IN_KAGGLE  = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ

# choose clone location
if IN_COLAB:
    path_to_repo = Path('/content/vlPFC_Visual_Geometry')
elif IN_KAGGLE:
    path_to_repo = Path('/kaggle/working/vlPFC_Visual_Geometry')
else:
    path_to_repo = Path('.')  # assumes the notebook runs from inside a local clone

# clone if necessary
if not path_to_repo.exists():
    print(f"Cloning into {path_to_repo}…")
    # Argument list (no shell) and check=True: a failed clone stops here instead
    # of surfacing later as an ImportError for extract_and_download_data.
    subprocess.run(
        ["git", "clone", "https://github.com/jobellet/vlPFC_Visual_Geometry.git", str(path_to_repo)],
        check=True,
    )

# add the repo and its utils/ folder to the import path (skip duplicates on re-run)
for extra_path in (str(path_to_repo), str(path_to_repo / 'utils')):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

from extract_and_download_data import download_files, unzip

# Download metadata and stimuli
download_files(path_to_repo, ["hvm_public_extended_meta.csv", "background_variations_images.zip"], private_link=private_link)
# Extract images
DATA_DIR = Path("data")
EXTRACT_DIR = DATA_DIR / "background_variations_images"
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
unzip("downloads/background_variations_images.zip", EXTRACT_DIR)

# Sanity check: an empty directory means the extraction produced nothing.
if not any(EXTRACT_DIR.iterdir()):
    print(f"Warning: No files found in {EXTRACT_DIR}. Extraction may have failed.")

downloads/hvm_public_extended_meta.csv already exists.
downloads/background_variations_images.zip already exists.


In [None]:
#--- Load stimulus metadata ---
# Kaggle keeps downloads under its writable /kaggle/working directory; elsewhere
# they are relative to the current working directory.
meta = pd.read_csv(
    "/kaggle/working/downloads/hvm_public_extended_meta.csv"
    if IN_KAGGLE
    else "downloads/hvm_public_extended_meta.csv"
).assign(
    # Name of each stimulus inside the extraction directory: <image_id>.png
    extracted_filename=lambda df: df['image_id'] + '.png'
)

In [None]:

#--- Inpainting helper functions (from inpainting.py) ---
def stack_object_images(df_obj, convert_gray=True, max_images=None, image_dir=None):
    """Load the images listed in ``df_obj.filename`` and stack them into one array.

    Parameters
    ----------
    df_obj : pandas.DataFrame
        Rows to load; only the ``filename`` column is read. Each name is
        resolved relative to ``image_dir``. Rows whose file does not exist
        are skipped.
    convert_gray : bool
        Convert every image to single-channel ``"L"`` before stacking.
    max_images : int or None
        Stop once this many images have been *loaded*. Missing files do not
        count towards the limit. ``None`` loads every available image.
    image_dir : str or Path or None
        Directory holding the files; defaults to the global ``EXTRACT_DIR``.

    Returns
    -------
    numpy.ndarray
        uint8 array whose first axis indexes the loaded images, e.g.
        ``(n_images, H, W)`` for grayscale input. All images must share a size.

    Raises
    ------
    ValueError
        If no listed file exists (or ``max_images`` is 0), or if ``np.stack``
        rejects images of differing sizes.
    """
    base_dir = EXTRACT_DIR if image_dir is None else Path(image_dir)
    pixels = []
    for fname in df_obj.filename:
        if max_images is not None and len(pixels) >= max_images:
            break
        img_path = base_dir / fname
        if not img_path.exists():
            continue
        # The context manager closes the file handle as soon as the pixels are copied.
        with Image.open(img_path) as im:
            if convert_gray:
                im = im.convert("L")
            pixels.append(np.asarray(im, dtype=np.uint8))
    if not pixels:
        raise ValueError(f"No images from df_obj.filename were found in {base_dir}")
    return np.stack(pixels, axis=0)

def build_static_mask(df_obj, variance_thr=5, dilate_px=0, feather_px=12):
    """Build a soft inpainting mask from pixels that stay constant across ``df_obj``'s images.

    The images of ``df_obj`` are loaded as a grayscale stack. A pixel is "static"
    when its max-min range across the stack is below ``variance_thr``. The rows are
    presumably the same object in the same pose on different backgrounds, so static
    pixels mark the object (verify against the caller). Only static pixels inside
    the union of the rows' bounding boxes are kept. The result is dilated and
    Gaussian-feathered.

    Parameters
    ----------
    df_obj : pandas.DataFrame
        Rows passed to ``stack_object_images``. The ``axis_bb_left/top/right/bottom``
        columns are also read and used as pixel slice bounds.
    variance_thr : int
        Upper bound (exclusive) on the per-pixel range for a pixel to count as static.
    dilate_px : int
        ``dilate_px // 2`` iterations of a default-kernel dilation, with a minimum
        of one iteration. The default of 0 therefore still dilates once.
    feather_px : float
        Gaussian sigma, in pixels, used to soften the mask edge.

    Returns
    -------
    PIL.Image.Image
        8-bit single-channel mask scaled so its peak is 255. It is all zeros when
        no static pixel lies inside the bounding box.
    """
    img_stack = stack_object_images(df_obj)
    ptp = img_stack.max(axis=0) - img_stack.min(axis=0)
    static = (ptp < variance_thr).astype(np.uint8)

    height, width = static.shape[:2]
    # Clamp the union bounding box to the image. A negative coordinate would
    # otherwise wrap around in the slice and select the wrong region.
    x1 = min(max(int(df_obj.axis_bb_left.min()), 0), width)
    y1 = min(max(int(df_obj.axis_bb_top.min()), 0), height)
    x2 = min(max(int(df_obj.axis_bb_right.max()), 0), width)
    y2 = min(max(int(df_obj.axis_bb_bottom.max()), 0), height)

    mask = np.zeros_like(static, dtype=np.uint8)
    mask[y1:y2, x1:x2] = static[y1:y2, x1:x2]
    # kernel=None -> OpenCV's default structuring element; always at least one pass.
    mask = cv2.dilate(mask, None, iterations=dilate_px // 2 or 1)
    mask = cv2.GaussianBlur(mask.astype(np.float32), (0, 0), feather_px)

    peak = mask.max()
    if peak > 0:
        mask = mask / peak * 255
    # With peak == 0 the mask stays all zeros, instead of 0/0 -> NaN -> garbage uint8.
    # A 2-D uint8 array is converted to an "L" image without the deprecated mode= argument.
    return Image.fromarray(mask.astype(np.uint8))

def clip_similarity(clip_model, clip_preprocess, img, txt_feat):
    """Cosine similarity between an image and a pre-normalised CLIP text embedding.

    ``img`` is run through ``clip_preprocess`` and given a leading batch axis. It
    is moved to the global ``device`` and encoded with ``clip_model.encode_image``.
    The embedding is L2-normalised and multiplied with ``txt_feat.T``. ``txt_feat``
    should already be unit-norm and hold a single row, since ``.item()`` needs a
    one-element result.
    """
    with torch.no_grad():
        batch = clip_preprocess(img)[None].to(device)
        emb = clip_model.encode_image(batch)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    similarity = emb @ txt_feat.T
    return similarity.item()


In [None]:
#--- Initialize models ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Half precision is only dependable on the GPU; many CPU kernels have no float16
# implementation, so fall back to float32 when CUDA is unavailable.
pipe_dtype = torch.float16 if device.type == "cuda" else torch.float32

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=pipe_dtype,
    cache_dir=os.path.expanduser("~/.cache/huggingface")
).to(device)
# No safety checker: pipeline outputs are returned unfiltered.
pipe.safety_checker = None

clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
for p in clip_model.parameters():
    p.requires_grad_(False)

# Unit-normalised text embeddings used to score inpainting candidates.
with torch.no_grad():
    txt_feat_nat = clip_model.encode_text(clip.tokenize("a natural landscape").to(device))
    txt_feat_sal = clip_model.encode_text(clip.tokenize("a landscape with an object at the center").to(device))
    txt_feat_nat /= txt_feat_nat.norm(dim=-1, keepdim=True)
    txt_feat_sal /= txt_feat_sal.norm(dim=-1, keepdim=True)

POS_PROMPT = "empty landscape, photorealistic, consistent lighting, high detail, outdoor, natural scene"
NEG_PROMPT = "object, animal, person, vehicle, table, text, logo, watermark, focal subject, centered object, unrealistic"


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [None]:
#--- Function to inpaint a single filename ---
# Serialises use of the shared diffusion pipeline and CLIP model. The background
# batch thread and GUI callbacks both call inpaint_image. The existing lock is
# reused on cell re-runs so a still-running thread and new callers share it.
_INPAINT_PIPE_LOCK = globals().get("_INPAINT_PIPE_LOCK") or threading.Lock()


def inpaint_image(fname, n_candidates=3):
    """Inpaint the object out of one stimulus and save the best of several candidates.

    Looks up ``fname`` in ``meta.extracted_filename``. A static-pixel mask is built
    from every stimulus sharing its ``object_name`` and ``rxz``. The inpainting
    ``pipe`` then runs ``n_candidates`` times with time-based seeds. The kept
    candidate has the highest CLIP ratio
    sim(txt_feat_nat) / sim(txt_feat_sal). It is resized to the source image size,
    converted to grayscale and written to ``DATA_DIR / "inpainted_images" / fname``.

    Parameters
    ----------
    fname : str
        File name inside ``EXTRACT_DIR`` (``<image_id>.png``).
    n_candidates : int
        Number of diffusion samples to generate and score.

    Returns
    -------
    None if ``fname`` has no metadata row, True if the image was saved, and
    False if no candidate was produced or saving failed. Messages go to ``out``.
    """
    row_query = meta[meta["extracted_filename"] == fname]
    if row_query.empty:
        with out:  # Keep essential warning in GUI output
            print(f"Warning: Metadata not found for {fname}. Skipping inpainting.")
        return None
    row = row_query.iloc[0]

    # Boolean masks rather than an f-string query: an object name containing a
    # quote would otherwise break (or alter) the query expression.
    df_obj = meta[(meta["object_name"] == row.object_name) & (meta["rxz"] == row.rxz)]

    mask = build_static_mask(df_obj)
    with Image.open(EXTRACT_DIR / fname) as src:
        img = src.convert("RGB")

    best_img, best_score = None, -float('inf')
    for _ in tqdm(range(n_candidates), desc=f"Inpainting {fname}"):
        # Time-based seed: candidates differ between calls (not reproducible).
        seed = int(time.time_ns()) & 0xFFFFFFFF
        gen = torch.Generator(device).manual_seed(seed)
        # Only one caller at a time may drive the shared pipeline / CLIP model.
        with _INPAINT_PIPE_LOCK:
            out_pipe = pipe(  # named out_pipe to avoid clashing with the `out` widget
                prompt=POS_PROMPT,
                negative_prompt=NEG_PROMPT,
                image=img,
                mask_image=mask,
                num_inference_steps=30,
                guidance_scale=4.5,
                strength=1.0,
                generator=gen,
            )
            candidate = out_pipe.images[0]
            # The pipeline picks its own output height/width when none are given,
            # which may differ from the stimulus. Match the source size.
            if candidate.size != img.size:
                candidate = candidate.resize(img.size, Image.LANCZOS)
            score_nat = clip_similarity(clip_model, clip_preprocess, candidate, txt_feat_nat)
            score_sal = clip_similarity(clip_model, clip_preprocess, candidate, txt_feat_sal)
        ratio = score_nat / (score_sal + 1e-6)
        if ratio > best_score:
            best_score, best_img = ratio, candidate

    if best_img is None:
        with out:  # Keep essential failure message in GUI output
            print(f"Failed to generate a valid image for {fname}")
        return False

    inpaint_dir = DATA_DIR / "inpainted_images"
    inpaint_dir.mkdir(parents=True, exist_ok=True)
    out_path = inpaint_dir / fname
    # Write to a side file and rename into place. The GUI and the batch glob
    # ("*.png") then never see a half-written PNG.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        best_img.convert("L").save(tmp_path, format='PNG')  # explicit format: suffix is .tmp
        tmp_path.replace(out_path)
    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        with out:
            print(f"Error saving inpainted image {out_path}: {e}")
        return False

    if out_path.exists() and out_path.stat().st_size > 0:
        with out:
            print(f"Successfully saved inpainted image to {out_path}")
        return True
    with out:
        print(f"Warning: Inpainted image file {out_path} was not created or is empty.")
    return False
#--- Function to start background inpainting for all missing images ---
def start_batch_inpainting(update_display_func, current_index):
    """Inpaint, in a daemon thread, every extracted stimulus that has metadata but no inpainted PNG.

    Parameters
    ----------
    update_display_func : callable
        GUI refresh callback taking an index. It is called after each image is
        saved, when that image sorts at or after the viewer's position.
    current_index : int
        Fallback viewer position. The live global ``index`` is used instead
        whenever it exists, so the refresh follows the user's navigation.

    Returns immediately. Progress and errors are reported in the ``out`` widget.
    If a batch thread started by an earlier call is still alive, no second
    thread is started.
    """
    running = globals().get("_BATCH_INPAINT_THREAD")
    if running is not None and running.is_alive():
        with out:
            print("Background inpainting is already running.")
        return

    extracted = {f.name for f in EXTRACT_DIR.glob("*.png")}
    already_done = {f.name for f in INPAINT_DIR.glob("*.png")}
    # Metadata order is kept; the sets make each membership test O(1).
    to_inpaint = [
        fname for fname in meta['extracted_filename']
        if fname in extracted and fname not in already_done
    ]

    if not to_inpaint:
        with out:
            print("All images with metadata have already been inpainted.")
        return

    with out:
        print(f"Starting background inpainting for {len(to_inpaint)} missing images...")

    def inpaint_missing_images_thread(filename_list, update_display_func, current_index):
        global filenames
        for fname in filename_list:
            try:
                success = inpaint_image(fname)
            except Exception as e:
                # One failing stimulus (e.g. no images to build its mask) must not
                # end the whole batch.
                with out:
                    print(f"Error while inpainting {fname}: {e}")
                continue
            if not success:
                continue
            time.sleep(0.1)  # small pause so GUI refreshes don't pile up
            filenames = sorted(
                {f.name for f in EXTRACT_DIR.glob("*.png")} &
                {f.name for f in INPAINT_DIR.glob("*.png")}
            )
            # Refresh at the viewer's *current* position. The value captured when
            # the batch started would keep snapping the view back to that image.
            live_index = globals().get("index", current_index)
            if fname in filenames[live_index:]:
                update_display_func(live_index)

    thread = threading.Thread(
        target=inpaint_missing_images_thread,
        args=(to_inpaint, update_display_func, current_index),
        daemon=True,
    )
    globals()["_BATCH_INPAINT_THREAD"] = thread
    thread.start()

In [None]:
from pathlib import Path

# Point the data directories at the unzip location (Kaggle works under /kaggle/working).
DATA_DIR = Path("/kaggle/working/data") if IN_KAGGLE else Path("data")
EXTRACT_DIR = DATA_DIR / "background_variations_images"
INPAINT_DIR = DATA_DIR / "inpainted_images"

# Normally present already; creating them again is harmless.
EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
INPAINT_DIR.mkdir(parents=True, exist_ok=True)

# Stimuli that have both an original and an inpainted PNG, sorted by name.
filenames = sorted(
    {p.name for p in EXTRACT_DIR.glob("*.png")}.intersection(
        p.name for p in INPAINT_DIR.glob("*.png")
    )
)
# Position of the image currently shown by the viewer.
index = 0


In [None]:

#--- Interactive GUI with ipywidgets ---
import io  # in-memory PNG encoding for widgets.Image

# Shared output area: the GUI callbacks and the inpainting helpers print here.
out = widgets.Output()
# Status line for the image currently shown.
label = widgets.Label()

def _png_bytes(path):
    """Decode the image at ``path`` and return it re-encoded as PNG bytes, closing the file."""
    with Image.open(path) as im, io.BytesIO() as buf:
        im.save(buf, format='PNG')
        return buf.getvalue()


def update_display(idx):
    """Show the original and inpainted versions of ``filenames[idx]`` side by side.

    Recomputes the global ``filenames`` (names present in both EXTRACT_DIR and
    INPAINT_DIR) and clears ``out``. The pair is rendered as two 400-px-wide PNG
    widgets and ``label`` shows the position. An empty list, an out-of-range
    ``idx``, or missing/unreadable files are reported in ``out`` and ``label``
    rather than raised.
    """
    out.clear_output(wait=True)
    # Refresh the list: the background thread keeps adding inpainted images.
    global filenames
    filenames = sorted(
        set(f.name for f in EXTRACT_DIR.glob("*.png")) &
        set(f.name for f in INPAINT_DIR.glob("*.png"))
    )
    if not filenames:
        with out:
            print("No inpainted images available to display.")
        label.value = "No images"
        return

    if idx < 0 or idx >= len(filenames):
        with out:
            print(f"Index {idx} is out of bounds for filenames list of size {len(filenames)}")
        label.value = f"Index {idx} out of range (1-{len(filenames)})"
        return

    fname = filenames[idx]
    orig_path = EXTRACT_DIR / fname
    inp_path = INPAINT_DIR / fname

    if not orig_path.exists():
        with out:
            print(f"Original image not found: {orig_path}")
        label.value = f"Error: Original image not found for {fname}"
        return

    if not inp_path.exists():
        with out:
            print(f"Inpainted image not found: {inp_path}")
        # Keep the label informative even though nothing is drawn.
        label.value = f"{idx+1}/{len(filenames)}: Inpainted image missing for {fname}"
        return

    # Decode and re-encode inside one try. A truncated or corrupt file only
    # fails once pixel data is read, so the PNG encoding must be covered too.
    try:
        data1 = _png_bytes(orig_path)
        data2 = _png_bytes(inp_path)
    except Exception as e:
        with out:
            print(f"Error loading images for {fname}: {e}")
        label.value = f"Error loading images for {fname}"
        return

    with out:
        display(widgets.HBox([
            widgets.Image(value=data1, format='png', width=400),
            widgets.Image(value=data2, format='png', width=400),
        ]))
    label.value = f"{idx+1}/{len(filenames)}: {fname}"


def on_prev(b):
    """'Previous' button callback: move the global ``index`` back one image and redraw.

    The global ``filenames`` is refreshed first. A message is written to ``out``
    when there is nothing to navigate or the viewer is already at the start.
    """
    global index, filenames
    filenames = sorted(
        {f.name for f in EXTRACT_DIR.glob("*.png")}
        & {f.name for f in INPAINT_DIR.glob("*.png")}
    )
    if not filenames:
        with out:
            print("No images to navigate.")
        return

    if index <= 0:
        with out:
            print("Already at the first image.")
        return

    index -= 1
    update_display(index)


def on_next(b):
    """'Next' button callback: advance the global ``index`` and redraw.

    The global ``filenames`` is refreshed first. At the last inpainted image,
    inpainting progress is reported while stimuli are still pending; otherwise
    an end-of-list message is shown. Every message goes to the ``out`` widget,
    because plain ``print`` output from widget callbacks may not be displayed.
    """
    global index, filenames
    extracted = {f.name for f in EXTRACT_DIR.glob("*.png")}
    filenames = sorted(extracted & {f.name for f in INPAINT_DIR.glob("*.png")})
    if not filenames:
        with out:
            print("No images to navigate.")
        return

    # Expected number of inpainted images. Only extracted stimuli with metadata
    # are queued for inpainting (this replaces a hard-coded 640).
    total = len(extracted & set(meta['extracted_filename'])) or len(extracted)

    if index < len(filenames) - 1:
        index += 1
        update_display(index)
        # Printed after the redraw, because update_display clears `out`.
        with out:
            print(f"{100 * index / total:.1f} %")
    elif len(filenames) < total:
        with out:
            print(f"Please wait for all images to be inpainted {100 * len(filenames) / total:.1f} %")
    else:
        with out:
            print("Already at the last image.")


def on_delete(b):
    """'Delete & Inpaint' callback: discard the current inpainted image and regenerate it.

    ``inpaint_image`` runs synchronously, so the UI waits for it to finish. The
    viewer is then redrawn at the (clamped) global ``index``.
    """
    global filenames, index
    filenames = sorted(
        set(f.name for f in EXTRACT_DIR.glob("*.png")) &
        set(f.name for f in INPAINT_DIR.glob("*.png"))
    )
    if not filenames:
        with out:
            print("No images to delete.")
        return

    # The list can shrink after index was set (e.g. an earlier regeneration
    # failed). Clamp so we act on a real entry instead of raising IndexError.
    index = min(max(index, 0), len(filenames) - 1)
    fname = filenames[index]

    path = INPAINT_DIR / fname
    if path.exists():
        path.unlink()
        with out:
            print(f"Deleted inpainted image: {fname}")
    else:
        with out:
            print(f"Inpainted image not found for deletion: {fname}")

    with out:
        print(f"Attempting to regenerate {fname}...")
    regenerated = inpaint_image(fname)  # blocking until complete
    update_display(index)
    if not regenerated:
        # update_display cleared inpaint_image's own messages; keep the outcome visible.
        with out:
            print(f"Regeneration of {fname} failed; its inpainted image is no longer available.")

btn_prev = widgets.Button(description='⟵ Previous')
btn_next = widgets.Button(description="Next ⟶")
btn_del = widgets.Button(description='Delete & Inpaint', button_style='danger')

btn_prev.on_click(on_prev)
btn_next.on_click(on_next)
btn_del.on_click(on_delete)

ui = widgets.VBox([widgets.HBox([btn_prev, btn_del, btn_next]), label, out])
display(ui)

# Draw the first pair *before* starting the background batch. update_display
# clears `out`, so calling it afterwards would wipe the batch's start message.
update_display(index)
start_batch_inpainting(update_display, index)

VBox(children=(HBox(children=(Button(description='⟵ Previous', style=ButtonStyle()), Button(button_style='dang…

Inpainting 3853f344d0b167c6ea29361cc8e72a4f362fbbc1.png:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]