In [None]:
import torch

# Report which accelerator this Colab session will run on.
if not torch.cuda.is_available():
    print("CUDA is not available. Using CPU instead.")
else:
    # Index of the GPU that torch will use by default.
    device = torch.cuda.current_device()
    print(f"Current GPU device: {device}")
    print(f"GPU name: {torch.cuda.get_device_name(device)}")

Current GPU device: 0
GPU name: Tesla T4


In [None]:
# ===============================
# Mount Google Drive
# ===============================
# Colab-only: the input/output image folders used below live under
# /content/drive/MyDrive, so Drive must be mounted before they are accessed.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# ===============================
# Step 1: Load the InstructPix2Pix pipeline
# ===============================
# Import here as well: this cell runs before the cell that imports diffusers,
# so without these lines a fresh kernel (Restart & Run All) raises NameError.
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

model_id = "timbrooks/instruct-pix2pix"

# Half precision only on the GPU; on CPU fall back to float32 so the
# "Using CPU instead" path reported by the device check actually runs (slowly).
if torch.cuda.is_available():
    pipe_device, pipe_dtype = "cuda", torch.float16
else:
    pipe_device, pipe_dtype = "cpu", torch.float32

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id,
    torch_dtype=pipe_dtype,
    safety_checker=None,
).to(pipe_device)

try:
    # Optional optimisation: if it cannot be enabled (e.g. xformers missing or
    # running on CPU) we just report it and keep using the pipeline as loaded.
    pipe.enable_xformers_memory_efficient_attention()
    print("Enabled memory-efficient attention with xformers.")
except Exception as e:
    print("Could not enable xformers memory-efficient attention:", e)

In [None]:
import os
import math
import random
import torch
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from tqdm import tqdm
from diffusers import StableDiffusionInstructPix2PixPipeline

# Alternative dataset (full person-small image set):
# INPUT_FOLDER = "/content/drive/MyDrive/UC_Berkeley/CS280A_Final_Project/person-small/images"
# OUTPUT_FOLDER = "/content/drive/MyDrive/UC_Berkeley/CS280A_Final_Project/"
INPUT_FOLDER = "/content/drive/MyDrive/UC_Berkeley/CS280A_Final_Project/wrong_images"
OUTPUT_FOLDER = "/content/drive/MyDrive/UC_Berkeley/CS280A_Final_Project/wrong_images_edited"

os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# ===============================
# Step 2: Select Images
# ===============================

# Maximum number of images to edit, taken in filename order; None = all of them.
MAX_IMAGES = None
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Sorted for a stable order (the per-image seed later depends on the position);
# the isfile check skips sub-directories that happen to end in an image suffix.
image_files = sorted(
    name
    for name in os.listdir(INPUT_FOLDER)
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(INPUT_FOLDER, name))
)
top_images = image_files[:MAX_IMAGES]
print(f"Selected top {len(top_images)} images for editing.")

# ===============================
# Step 3: Define Prompt
# ===============================

PROMPT = (
    "Change the person's outfit to a well-tailored black suit with matching black pants, "
    "a crisp white shirt, and a navy blue tie. Ensure the outfit is formal and elegant, "
    "with all clothing items properly fitted. Maintain the same pose and lighting, photograph style, high resolution."
)

# We will not use a negative prompt to replicate the platform approach more closely.
# If needed, you can experiment by adding a negative prompt later.

# ===============================
# Step 4: Editing Function Mirroring the Platform Approach
# ===============================

def edit_image(pipe, instruction, image, steps=50, guidance_scale=7.5, image_guidance_scale=1.5, seed=None, resolution=512):
    """Edit ``image`` with an InstructPix2Pix pipeline, resizing it like the platform code.

    The image is rescaled so its longer side is about ``resolution`` px. The shorter
    side is rounded *up* and the longer side rounded *down* to a multiple of 64.
    The result is then resized/cropped to exactly that size with ``ImageOps.fit``.

    Args:
        pipe: a loaded StableDiffusionInstructPix2PixPipeline (anything with the
            same call signature and a result exposing ``.images``).
        instruction: text edit instruction.
        image: PIL image to edit.
        steps: number of inference steps passed to the pipeline.
        guidance_scale: text guidance scale passed to the pipeline.
        image_guidance_scale: image guidance scale passed to the pipeline.
        seed: if not None, noise is drawn from a private generator seeded with it,
            making the edit reproducible without resetting torch's global RNG.
        resolution: target length of the longer side before 64-px snapping.

    Returns:
        The first image produced by the pipeline.

    Raises:
        ValueError: if the image has zero width or height.
    """
    width, height = image.size
    short_side = min(width, height)
    if short_side <= 0:
        raise ValueError(f"Cannot edit an empty image of size {image.size}")
    long_side = max(width, height)

    # Same formula as the platform code, but evaluated in exact integer arithmetic.
    # The float version can produce e.g. 511.9999... and floor-snap a side down by a
    # whole 64 px (a 49x49 input ended up 448x448 instead of 512x512).
    target_short = -(-short_side * resolution // (long_side * 64)) * 64  # ceil to multiple of 64
    new_width = (width * target_short) // (short_side * 64) * 64
    new_height = (height * target_short) // (short_side * 64) * 64
    image = ImageOps.fit(image, (new_width, new_height), method=Image.Resampling.LANCZOS)

    # A dedicated CPU generator yields the same stream as torch.manual_seed(seed)
    # (which also returns the default CPU generator) without mutating global state.
    generator = torch.Generator(device="cpu").manual_seed(seed) if seed is not None else None

    result = pipe(
        instruction,
        image=image,
        guidance_scale=guidance_scale,
        image_guidance_scale=image_guidance_scale,
        num_inference_steps=steps,
        generator=generator,
    )
    return result.images[0]

# ===============================
# Step 5: Process Images
# ===============================

base_seed = 42
failed_images = []

for idx, img_name in enumerate(tqdm(top_images, desc="Editing Images")):
    input_path = os.path.join(INPUT_FOLDER, img_name)
    output_path = os.path.join(OUTPUT_FOLDER, img_name)

    try:
        print(f"\nEditing {img_name}...")
        # The context manager closes the source file once the pixels are converted.
        with Image.open(input_path) as source:
            image = source.convert("RGB")

        # One deterministic seed per image (its position in the sorted list), so
        # re-running the notebook reproduces the same edits.
        seed = base_seed + idx

        # Edit the image using the platform-like method
        edited_image = edit_image(
            pipe=pipe,
            instruction=PROMPT,
            image=image,
            steps=50,                 # Platform default steps
            guidance_scale=7.5,       # Platform default text CFG
            image_guidance_scale=1.5, # Platform default image CFG
            seed=seed,                # must be passed through, otherwise the seed has no effect
        )

        edited_image.save(output_path)
        print(f"Saved edited image to {output_path}")

    except Exception as e:
        # Best effort: keep going with the remaining images, but remember the failure.
        failed_images.append(img_name)
        print(f"Failed to edit {img_name}: {e}")

print("\nEditing completed")
if failed_images:
    print(f"{len(failed_images)} image(s) failed: {', '.join(failed_images)}")

# ===============================
# Step 6: Visualization (Optional)
# ===============================

def display_images(original, edited, title="Comparison"):
    """Show the original and edited images side by side in a single figure.

    Args:
        original: image drawn in the left panel (anything ``Axes.imshow`` accepts).
        edited: image drawn in the right panel.
        title: figure-level title.
    """
    fig, panels = plt.subplots(1, 2, figsize=(12, 6))
    labelled = (("Original Image", original), ("Edited Image", edited))
    for ax, (label, picture) in zip(panels, labelled):
        ax.imshow(picture)
        ax.set_title(label)
        ax.axis("off")
    fig.suptitle(title)
    plt.show()


Selected top 1 images for editing.


Editing Images:   0%|          | 0/1 [00:00<?, ?it/s]


Editing frame_00269.png...


  0%|          | 0/50 [00:00<?, ?it/s]

Editing Images: 100%|██████████| 1/1 [00:07<00:00,  7.04s/it]

Saved edited image to /content/drive/MyDrive/UC_Berkeley/CS280A_Final_Project/wrong_images_edited/frame_00269.png

Editing completed



