<a href="https://colab.research.google.com/github/eric15342335/realfill/blob/main/train_realfill.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/eric15342335/realfill
%cd realfill

In [None]:
!curl -L https://github.com/eric15342335/realfill/releases/download/dataset/realfill_data_release_full.zip -o realfill_data_release_full.zip
!unzip -q realfill_data_release_full.zip

In [None]:
!curl -L https://github.com/eric15342335/realfill/releases/download/dataset/jensen_images.zip -o jensen_images.zip
!unzip -q jensen_images.zip

In [None]:
%pwd
%ls -lh

In [None]:
!uv pip install -U -r requirements.txt --no-progress

In [None]:
# Write a default, non-interactive Accelerate config so the `accelerate launch`
# cells below can run. mixed_precision="no" is spelled out explicitly; the
# per-run precision is passed to the training script via $PRECISION_ARG instead.
from accelerate.utils import write_basic_config
write_basic_config(mixed_precision="no")

In [None]:
import os
import shutil # Needed for file copying
import glob   # Needed for finding files matching a pattern
import re     # Needed for sorting filenames numerically

# --- Configuration Toggles ---
# Where outputs go: True -> Google Drive, False -> local disk.
USE_DRIVE_STORAGE = True
# Numeric precision: True -> full FP32 training, False -> FP16 mixed precision.
USE_FP32 = False
# True -> top up the reference set with a previous run's result images and
# tag this run's output folders with a '-generated' suffix.
USE_GENERATED_REF_IMAGE = True
# Desired TOTAL number of reference images once generated ones have been added.
TARGET_REF_IMAGE_COUNT = 5

# --- Base Environment Setup ---
# Root folder for runs stored on Google Drive.
DRIVE_BASE_DIR = "/content/drive/MyDrive/RealFill"

# Values exported to the environment so the `!` shell cells below can read them.
os.environ.update({
    "DATASET": "realfill_data_release_full",  # realfill_data_release_full, jensen_images
    "MODEL_NAME": "stabilityai/stable-diffusion-2-inpainting",
    "BENCHMARK": "RealBench",  # RealBench, Qualitative, Custom
    "DATASET_NUMBER": "25",  # scene number inside the benchmark
})

# --- Determine Base Output Prefix based on Storage ---
# Run label such as "RealBench-25": used on its own for local storage and
# placed under DRIVE_BASE_DIR for Drive storage.
run_label = f'{os.environ["BENCHMARK"]}-{os.environ["DATASET_NUMBER"]}'
drive_run_prefix = f'{DRIVE_BASE_DIR}/{run_label}'


def _prefix_if_drive_dir_exists():
    """Return the Drive output prefix when DRIVE_BASE_DIR already exists.

    Otherwise print a warning, switch USE_DRIVE_STORAGE off and return the
    local (relative) prefix instead.
    """
    global USE_DRIVE_STORAGE
    if os.path.exists(DRIVE_BASE_DIR):
        print("Google Drive path found, assuming mounted.")
        return drive_run_prefix
    print(f"Warning: Google Drive path '{DRIVE_BASE_DIR}' not found. Falling back to local paths.")
    USE_DRIVE_STORAGE = False
    return run_label


if not USE_DRIVE_STORAGE:
    # Local Colab storage (or a plain local checkout).
    base_output_prefix = run_label
    print("Using local storage.")
else:
    try:
        if 'google.colab' not in str(get_ipython()):
            # Jupyter outside Colab: use the Drive path only if it is already present.
            base_output_prefix = _prefix_if_drive_dir_exists()
        else:
            from google.colab import drive
            drive.mount('/content/drive', force_remount=True)  # remount even if already mounted
            print("Google Drive mounted successfully.")
            base_output_prefix = drive_run_prefix
    except NameError:
        # get_ipython is undefined when this code runs as a plain Python script.
        base_output_prefix = _prefix_if_drive_dir_exists()
    except Exception as e:
        print(f"Error checking/mounting Google Drive: {e}. Falling back to local paths.")
        USE_DRIVE_STORAGE = False
        base_output_prefix = run_label

# --- Construct Suffixes Conditionally ---
# '-generated' marks runs whose reference set was supplemented with a prior run's results.
generated_suffix = "-generated" if USE_GENERATED_REF_IMAGE else ""
# '-fp32' marks full-precision runs; it also selects which prior results are copied from.
precision_suffix = "-fp32" if USE_FP32 else ""

# --- Set CURRENT Run's Output Paths with Conditional Suffixes ---
os.environ["OUTPUT_DIR"] = f'{base_output_prefix}-model{generated_suffix}{precision_suffix}'
os.environ["OUTPUT_IMG_DIR"] = f'{base_output_prefix}-results{generated_suffix}{precision_suffix}'

# --- Set Dataset/Input Paths ---
os.environ["TRAIN_DIR"] = f'{os.environ["DATASET"]}/{os.environ["BENCHMARK"]}/{os.environ["DATASET_NUMBER"]}'
os.environ["VAL_IMG"] = f'{os.environ["TRAIN_DIR"]}/target/target.png'
os.environ["VAL_MASK"] = f'{os.environ["TRAIN_DIR"]}/target/mask.png'
# Reference images used for training live in <TRAIN_DIR>/ref.
REF_DIR = os.path.join(os.environ["TRAIN_DIR"], "ref")

# --- Sanity-check the dataset location ---
# The makedirs calls below would silently create an EMPTY dataset tree if
# DATASET/BENCHMARK/DATASET_NUMBER are wrong or the unzip step failed, and the
# problem would only surface later inside training/inference. Warn up front.
for required_path in (os.environ["VAL_IMG"], os.environ["VAL_MASK"]):
    if not os.path.isfile(required_path):
        print(f"Warning: expected dataset file '{required_path}' not found. "
              "Check DATASET/BENCHMARK/DATASET_NUMBER and that the dataset was unzipped.")

# --- Create Necessary Directories ---
# Ensure the dataset's ref dir exists before anything is copied into it.
os.makedirs(os.environ["TRAIN_DIR"], exist_ok=True)
os.makedirs(REF_DIR, exist_ok=True)

# Output directories for the CURRENT run.
os.makedirs(os.environ["OUTPUT_DIR"], exist_ok=True)
os.makedirs(os.environ["OUTPUT_IMG_DIR"], exist_ok=True)


# --- Conditional Reference Image Supplementing Logic ---
# NOTE: images are copied INTO the dataset's ref/ folder, so the change persists:
# a later run with USE_GENERATED_REF_IMAGE = False still trains on the copies
# unless the dataset is re-extracted.

# File extensions counted as reference images (case-insensitive). The dataset's
# own reference images are not guaranteed to be PNG (TODO confirm per dataset);
# counting only '*.png' would under-count them and over-supplement the folder.
REF_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Reset per-run state so the summary below never reports a stale index left
# over from an earlier execution of this cell.
next_available_index = None


def _list_numbered_pngs(directory):
    """Return full paths of files named like '0.png', '1.png', ... in `directory`, sorted numerically.

    Prints an error and returns [] if the directory cannot be listed.
    """
    try:
        names = os.listdir(directory)
    except OSError:  # FileNotFoundError, PermissionError, NotADirectoryError, ...
        print(f"Error: Tried to list files in '{directory}', but it disappeared or access denied.")
        return []
    numbered = sorted(
        (name for name in names if re.match(r"^\d+\.png$", name)),
        key=lambda name: int(os.path.splitext(name)[0]),
    )
    return [os.path.join(directory, name) for name in numbered]


def _list_ref_images(directory):
    """Return paths of non-hidden files in `directory` whose extension is in REF_IMAGE_EXTENSIONS.

    Returns [] if the directory does not exist.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if not name.startswith(".")  # mirror glob's '*', which skips dotfiles
        and os.path.splitext(name)[1].lower() in REF_IMAGE_EXTENSIONS
    ]


if USE_GENERATED_REF_IMAGE:
    print("-" * 30)
    print("Attempting to SUPPLEMENT reference directory with prior results images...")

    # Prior results live where this run's results would, minus the '-generated' suffix.
    source_results_dir_for_copy = f'{base_output_prefix}-results{precision_suffix}'
    target_ref_dir = REF_DIR

    print(f"Source Results Dir for Copy: {source_results_dir_for_copy}")
    print(f"Target Reference Dir: {target_ref_dir}")

    if not os.path.isdir(source_results_dir_for_copy):
        print(f"Warning: Source results directory for copy '{source_results_dir_for_copy}' does not exist. Cannot copy images.")
    else:
        source_image_files = _list_numbered_pngs(source_results_dir_for_copy)

        if not source_image_files:
            print(f"Warning: No numbered PNG images (like 0.png, 1.png) found in source directory '{source_results_dir_for_copy}'.")
        else:
            print(f"Found {len(source_image_files)} potential source images in {source_results_dir_for_copy}.")

            # --- Count current images in the target ref directory (all image types) ---
            current_ref_images = _list_ref_images(target_ref_dir)
            current_ref_count = len(current_ref_images)
            print(f"Found {current_ref_count} existing reference images in {target_ref_dir}.")

            # --- Next free index: one past the highest N among files named "N.<ext>" ---
            existing_numbers = [
                int(match.group(1))
                for match in (re.match(r"^(\d+)\.[A-Za-z]+$", os.path.basename(p)) for p in current_ref_images)
                if match
            ]
            next_available_index = max(existing_numbers, default=-1) + 1
            print(f"Next available numerical index for copied files: {next_available_index}")

            # --- Calculate how many images to copy ---
            num_needed = TARGET_REF_IMAGE_COUNT - current_ref_count
            print(f"Target total reference image count: {TARGET_REF_IMAGE_COUNT}")

            if num_needed <= 0:
                print("Reference directory already has target number of images or more. No copy needed.")
            else:
                print(f"Number of additional images needed: {num_needed}")
                # source_image_files is non-empty here, so num_to_copy >= 1.
                num_to_copy = min(num_needed, len(source_image_files))
                if num_to_copy < num_needed:
                    print(f"Warning: Only {num_to_copy} source images available, but needed {num_needed} more.")

                print(f"Will attempt to copy {num_to_copy} images.")
                copied_count = 0
                for offset, src_path in enumerate(source_image_files[:num_to_copy]):
                    dst_path = os.path.join(target_ref_dir, f"{next_available_index + offset}.png")

                    # Defensive: indices start past every existing number, so this should never trigger.
                    if os.path.exists(dst_path):
                        print(f"Error: Calculated destination path '{dst_path}' already exists! Skipping copy. Check logic.")
                        continue

                    try:
                        shutil.copy2(src_path, dst_path)  # copy2 preserves metadata
                        print(f"Copied '{src_path}' to '{dst_path}' (Added as new file)")
                        copied_count += 1
                    except OSError as e:
                        print(f"Error copying '{src_path}' to '{dst_path}': {e}")

                print(f"Finished copying. Added {copied_count} new images to the reference directory.")
                print(f"Reference directory now contains {len(_list_ref_images(target_ref_dir))} reference images.")
    print("-" * 30)


# --- Set Shell Environment Variable for Precision ---
# Empty for FP32 so that `$PRECISION_ARG` expands to no argument in the training cell.
os.environ['PRECISION_ARG'] = '--mixed_precision=fp16' if not USE_FP32 else ''


# --- Final Print Statements ---
print(f"Using {'Google Drive' if USE_DRIVE_STORAGE else 'local'} storage")
print(f"USE_GENERATED_REF_IMAGE flag set to: {USE_GENERATED_REF_IMAGE}")
if USE_GENERATED_REF_IMAGE:
    print(f" -> Will attempt to ADD reference images from: {source_results_dir_for_copy}")
    print(f" -> Copied images will be named starting from index {next_available_index if next_available_index is not None else 'N/A'}")
    print(" -> Current run outputs will have '-generated' suffix.")
print(f"Using FP32 precision: {USE_FP32}")
print(f"Model output directory (current run): {os.environ['OUTPUT_DIR']}")
print(f"Results output directory (current run): {os.environ['OUTPUT_IMG_DIR']}")
print(f"Dataset train directory: {os.environ['TRAIN_DIR']}")
print(f"Reference image directory: {REF_DIR}")
print(f"Validation image: {os.environ['VAL_IMG']}")
print(f"Validation mask: {os.environ['VAL_MASK']}")
if os.environ['PRECISION_ARG']:
    print(f"Precision argument for scripts: {os.environ['PRECISION_ARG']}")
else:
    print("Precision argument for scripts: (Using default, likely FP32 if not fp16)")

In [None]:
!accelerate launch train_realfill.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --output_dir=$OUTPUT_DIR \
  --resolution=512 \
  --train_batch_size=16 \
  --gradient_accumulation_steps=1 \
  --unet_learning_rate=2e-4 \
  --text_encoder_learning_rate=4e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=100 \
  --max_train_steps=2000 \
  --lora_rank=8 \
  --lora_dropout=0.1 \
  --lora_alpha=16 \
  --resume_from_checkpoint="latest" \
  --report_to tensorboard \
  --validation_steps 100 \
  --checkpointing_steps 100 \
  $PRECISION_ARG \
  --use_8bit_adam \
  --set_grads_to_none \
  --enable_xformers_memory_efficient_attention

In [None]:
!accelerate launch infer.py \
    --model_path=$OUTPUT_DIR \
    --validation_image=$VAL_IMG \
    --validation_mask=$VAL_MASK \
    --output_dir=$OUTPUT_IMG_DIR

In [None]:
# Zip final inference results
!zip -r9j $OUTPUT_IMG_DIR.zip $OUTPUT_IMG_DIR
# Zip tensorboard logs
!zip -r9D $OUTPUT_DIR-tensorboard.zip $OUTPUT_DIR/logs
%ls