In [None]:
# Make modules that live one directory up (presumably `config`, imported in
# the next cell) importable by putting that directory first on the search path.
import os
import sys

sys.path[0:0] = [os.path.abspath(os.pardir)]  # os.pardir == '..'; same as insert(0, ...)

In [None]:
import os
import math 
from config import NYUV2, MAKE3D

import shlex  # used to quote filenames safely inside the generated bash script

# ============================================================================
# CONFIGURATION: Choose your dataset here
# ============================================================================
DATASET = "Make3D"   # Options: "NYUv2" or "Make3D"
split = "test"       # Options: "train" or "test"
proc_fraction = 0.43 # Fraction of processors to use on machine
# ============================================================================

# Dataset-specific configuration.
# `depth_dir` is the directory whose image files are enumerated below; each
# matching file becomes one argument to `python_script` in the bash script.
if DATASET == "NYUv2":
    depth_dir = os.path.join(NYUV2.data_dir, "NYUv2", split + "_rgb")
    bash_file = f"run_all_nyuv2_{split}.sh"
    file_extension = ".png"
    python_script = f"python run_coordinate_descent_nyuv2.py {split}"

elif DATASET == "Make3D":
    # Make3D has only a train and a test folder; any other value would
    # silently fall through to the test folder, so reject it explicitly.
    if split not in ("train", "test"):
        raise ValueError(f"Unknown split for Make3D: {split}. Choose 'train' or 'test'")
    depth_dir = os.path.join(MAKE3D.data_dir, "Make3D", "Train400Img" if split == "train" else "Test134Img")
    bash_file = f"run_all_make3d_{split}.sh"
    file_extension = ".jpg"
    python_script = f"python run_coordinate_descent_make3d.py {split}"

else:
    raise ValueError(f"Unknown dataset: {DATASET}. Choose 'NYUv2' or 'Make3D'")

print(f"Dataset: {DATASET}")
print(f"Split: {split}")
print(f"Output: {bash_file}")
print()

# os.cpu_count() returns None when the count cannot be determined.
num_procs = os.cpu_count() or 1
print("Number of processors:", num_procs)

# Always allow at least one job, even for tiny fractions or small machines.
max_parallel = max(1, int(proc_fraction * num_procs))

# List all images of the selected split and derive each job's argument.
if DATASET == "NYUv2":
    # For NYUv2, pass the filename with its extension stripped.
    files = sorted(f[:-len(file_extension)] for f in os.listdir(depth_dir) if f.endswith(file_extension))
else:
    # For Make3D, keep the full filename.
    files = sorted(f for f in os.listdir(depth_dir) if f.endswith(file_extension))

print(f"Found {len(files)} files")
if not files:
    print(f"Warning: no '{file_extension}' files found in {depth_dir}")

num_batches = math.ceil(len(files) / max_parallel)

# Write the bash script. Each batch launches up to `max_parallel` background
# jobs and then `wait`s for all of them before the next batch starts.
# newline="\n" forces LF line endings so bash can run the script even when it
# is generated on Windows.
with open(bash_file, "w", newline="\n", encoding="utf-8") as f:
    f.write("#!/bin/bash\n\n")
    f.write(f"# Generated script for {DATASET} {split} set\n")
    f.write(f"# Processing {len(files)} files in {num_batches} batches\n")
    f.write(f"# Running {max_parallel} jobs in parallel\n\n")

    for batch, start in enumerate(range(0, len(files), max_parallel)):
        f.write(f"echo 'Starting batch {batch+1}/{num_batches}'\n")
        for name in files[start:start + max_parallel]:
            # Quote so names with spaces/shell metacharacters stay one argument.
            f.write(f"{python_script} {shlex.quote(name)} &\n")
        # Block until every job of this batch has finished.
        f.write("wait\n\n")

    f.write(f"echo 'Completed all {len(files)} jobs for {DATASET} {split} set'\n")

print(f"\nWrote {bash_file} with {len(files)} jobs in {num_batches} parallel batches.")