# Latent Space Optimization Results Visualization

Show, analyze, and visualize the latent space optimization (LSO) results.

## Image Grids

Visualize all available images of an optimization run.

### Optimized 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def grid_optimized(version, iters, idxs, image_root="/BS/optdif/work/results", scores_root="../results"):
    """
    Plot a grid of optimized images, each titled with its objective score.

    Args:
        version: string, version of the experiment (name of its results directory)
        iters: list of iteration numbers (rows)
        idxs:  list of sample indices (columns)
        image_root: directory containing ``<version>/data/samples``
        scores_root: directory containing ``<version>/results.npz``

    Raises:
        ValueError: if ``iters`` or ``idxs`` is empty.
    """
    if len(iters) == 0 or len(idxs) == 0:
        raise ValueError("iters and idxs must both be non-empty")

    def _load(it, kind, idx):
        """Read ``img_<kind>/<idx>.png`` of iteration ``it`` into memory and close the file."""
        path = f"{image_root}/{version}/data/samples/iter_{it}/img_{kind}/{idx}.png"
        with Image.open(path) as img:
            return img.copy()

    n_rows = len(iters)
    n_cols = len(idxs)
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(4 * n_cols, 4 * n_rows),  # ~4" per subplot
        squeeze=False,
    )
    fig.subplots_adjust(hspace=0.4)

    # Read the scores eagerly so the .npz archive is closed right away.
    with np.load(f"{scores_root}/{version}/results.npz", allow_pickle=True) as scores:
        opt_scores = scores['opt_point_properties']

    for i, it in enumerate(iters):
        for j, idx in enumerate(idxs):
            # NOTE(review): the flat score index is assumed to be ``it + idx``
            # (iteration directories named after the flat index of their first
            # sample) — confirm against the code that writes results.npz.
            ax_opt = axes[i][j]
            ax_opt.imshow(_load(it, "opt", idx))
            ax_opt.set_title(f"{opt_scores[it + idx].item():.2f}")
            ax_opt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Rows: iterations 0, 5 and 10; columns: the first five samples of each.
grid_optimized(
    version="ctrloralter_gbo_28",
    iters=[0, 5, 10],
    idxs=list(range(5)),
)

### Initial & Optimized

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def grid_initial_optimized(version, iters, idxs, image_root="/BS/optdif/work/results", scores_root="../results"):
    """
    Plot a grid of initially sampled and optimized images side by side.

    Each (iteration, sample) cell shows the initial image on the left and the
    optimized image on the right, each titled with its objective score.

    Args:
        version: string, version of the experiment (name of its results directory)
        iters: list of iteration numbers (rows)
        idxs:  list of sample indices (column pairs)
        image_root: directory containing ``<version>/data/samples``
        scores_root: directory containing ``<version>/results.npz``

    Raises:
        ValueError: if ``iters`` or ``idxs`` is empty.
    """
    if len(iters) == 0 or len(idxs) == 0:
        raise ValueError("iters and idxs must both be non-empty")

    def _load(it, kind, idx):
        """Read ``img_<kind>/<idx>.png`` of iteration ``it`` into memory and close the file."""
        path = f"{image_root}/{version}/data/samples/iter_{it}/img_{kind}/{idx}.png"
        with Image.open(path) as img:
            return img.copy()

    n_rows = len(iters)
    n_cols = len(idxs)
    fig, axes = plt.subplots(
        n_rows, n_cols * 2,
        figsize=(4 * n_cols * 2, 4 * n_rows),  # ~4" per subplot
        squeeze=False,
    )
    fig.subplots_adjust(hspace=0.4)

    # Read the scores eagerly so the .npz archive is closed right away.
    with np.load(f"{scores_root}/{version}/results.npz", allow_pickle=True) as scores:
        init_scores = scores['init_point_properties']
        opt_scores = scores['opt_point_properties']

    for i, it in enumerate(iters):
        for j, idx in enumerate(idxs):
            # NOTE(review): flat score index assumed to be ``it + idx`` — confirm.
            flat = it + idx

            # Left: initially sampled
            ax_init = axes[i][2 * j]
            ax_init.imshow(_load(it, "init", idx))
            ax_init.set_title(f"Initial ({init_scores[flat].item():.2f})")
            ax_init.axis('off')

            # Right: optimized
            ax_opt = axes[i][2 * j + 1]
            ax_opt.imshow(_load(it, "opt", idx))
            ax_opt.set_title(f"Optimized ({opt_scores[flat].item():.2f})")
            ax_opt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Rows: iterations 0, 5, 10 and 15; column pairs: the first five samples.
grid_initial_optimized(
    version="sd_dngo_04",
    iters=[0, 5, 10, 15],
    idxs=list(range(5)),
)

### Original & Initial & Optimized

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def grid_original_initial_optimized(version, iters, idxs, image_root="/BS/optdif/work/results", scores_root="../results"):
    """
    Plot a grid of original, initially sampled, and optimized images.

    Each (iteration, sample) cell is a triple of panels (original, initial,
    optimized), each titled with its objective score; the original panel also
    names the iteration and sample index.

    Args:
        version: string, version of the experiment (name of its results directory)
        iters: list of iteration numbers (rows)
        idxs:  list of sample indices (column triples)
        image_root: directory containing ``<version>/data/samples``
        scores_root: directory containing ``<version>/results.npz``

    Raises:
        ValueError: if ``iters`` or ``idxs`` is empty.
    """
    if len(iters) == 0 or len(idxs) == 0:
        raise ValueError("iters and idxs must both be non-empty")

    def _load(it, kind, idx):
        """Read ``img_<kind>/<idx>.png`` of iteration ``it`` into memory and close the file."""
        path = f"{image_root}/{version}/data/samples/iter_{it}/img_{kind}/{idx}.png"
        with Image.open(path) as img:
            return img.copy()

    # (image sub-directory suffix, title label, score array key), left to right
    panels = (
        ("orig", "Original", "orig_point_properties"),
        ("init", "Initial", "init_point_properties"),
        ("opt", "Optimized", "opt_point_properties"),
    )

    n_rows = len(iters)
    n_cols = len(idxs)
    fig, axes = plt.subplots(
        n_rows, n_cols * len(panels),
        figsize=(4 * n_cols * len(panels), 4 * n_rows),  # ~4" per subplot
        squeeze=False,
    )
    fig.subplots_adjust(hspace=0.4)

    # Read the scores eagerly so the .npz archive is closed right away.
    with np.load(f"{scores_root}/{version}/results.npz", allow_pickle=True) as scores:
        panel_scores = {key: scores[key] for _, _, key in panels}

    for i, it in enumerate(iters):
        for j, idx in enumerate(idxs):
            # NOTE(review): flat score index assumed to be ``it + idx`` — confirm.
            flat = it + idx
            for k, (kind, label, key) in enumerate(panels):
                ax = axes[i][len(panels) * j + k]
                ax.imshow(_load(it, kind, idx))
                title = f"{label} ({panel_scores[key][flat].item():.2f})"
                if k == 0:
                    # The leftmost panel also identifies the grid cell.
                    title = f"Iteration {it}  Top {idx}\n{title}"
                ax.set_title(title)
                ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Rows: iterations 0, 5, 10 and 15; column triples: the first two samples.
grid_original_initial_optimized(
    version="ctrloralter_gbo_28",
    iters=[0, 5, 10, 15],
    idxs=list(range(2)),
)

## Smile Scores Visualization

Plot the available smile scores and their statistics (mean, min, max) for each iteration.

### Optimized

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def scores_optimized(version, results_root="../results"):
    """
    Plot the per-model-version mean and min-max range of the optimized scores.

    Args:
        version: string, version of the experiment (name of its results directory)
        results_root: directory containing ``<version>/results.npz``
    """
    # Copy the needed arrays into a DataFrame and close the .npz archive.
    with np.load(f"{results_root}/{version}/results.npz", allow_pickle=True) as results:
        df = pd.DataFrame({
            'opt': list(results['opt_point_properties']),
            'model_version': list(results['opt_model_version']),
        })

    # One row per model version with mean/min/max of the optimized scores
    df = df.groupby('model_version').agg({
        'opt': ['mean', 'min', 'max'],
    }).reset_index()

    # Plot at positions 0..n-1 so every point sits exactly on its tick label,
    # whatever the type/values of model_version are.
    positions = np.arange(len(df))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(positions, df['opt']['mean'], marker='o', label='Optimized Mean')
    ax.fill_between(positions, df['opt']['min'], df['opt']['max'], alpha=0.2, label='Optimized Min-Max Range')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Values')
    ax.set_title('Mean Objective Values of the Top Five Samples per Iteration')
    ax.set_ylim(0, 5)
    ax.axhline(y=2, color='gray', linestyle='--', label='Input Max')
    ax.legend()
    ax.set_xticks(positions)
    ax.set_xticklabels(df['model_version'], rotation=45)

    plt.tight_layout()
    plt.show()

In [None]:
# Optimized-score statistics of the Finetuned Style + HED run
scores_optimized("ctrloralter_gbo_28")

### Initial & Optimized

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def scores_initial_optimized(version, results_root="../results"):
    """
    Plot per-model-version optimized scores (mean and min-max range) together
    with the mean of the initial scores.

    Args:
        version: string, version of the experiment (name of its results directory)
        results_root: directory containing ``<version>/results.npz``
    """
    # Copy the needed arrays into a DataFrame and close the .npz archive.
    with np.load(f"{results_root}/{version}/results.npz", allow_pickle=True) as results:
        df = pd.DataFrame({
            'opt': list(results['opt_point_properties']),
            'init': list(results['init_point_properties']),
            'model_version': list(results['opt_model_version']),
        })

    # One row per model version with mean/min/max of each score column
    df = df.groupby('model_version').agg({
        'opt': ['mean', 'min', 'max'],
        'init': ['mean', 'min', 'max'],
    }).reset_index()

    # Plot at positions 0..n-1 so every point sits exactly on its tick label,
    # whatever the type/values of model_version are.
    positions = np.arange(len(df))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(positions, df['opt']['mean'], marker='o', label='Optimized Mean')
    ax.fill_between(positions, df['opt']['min'], df['opt']['max'], alpha=0.2, label='Optimized Min-Max Range')
    ax.plot(positions, df['init']['mean'], marker='x', label='Initial Mean', linestyle=':')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Values')
    ax.set_title('Mean Objective Values of the Top Five Samples per Iteration')
    ax.set_ylim(0, 5)
    ax.axhline(y=2, color='gray', linestyle='--', label='Input Max')
    ax.legend()
    ax.set_xticks(positions)
    ax.set_xticklabels(df['model_version'], rotation=45)

    plt.tight_layout()
    plt.show()

In [None]:
# Initial vs. optimized scores of the Finetuned Style + HED run
scores_initial_optimized("ctrloralter_gbo_28")

### Original & Initial & Optimized

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def scores_original_initial_optimized(version, results_root="../results"):
    """
    Plot per-model-version optimized scores (mean and min-max range) together
    with the means of the original and initial scores.

    Args:
        version: string, version of the experiment (name of its results directory)
        results_root: directory containing ``<version>/results.npz``
    """
    # Copy the needed arrays into a DataFrame and close the .npz archive.
    with np.load(f"{results_root}/{version}/results.npz", allow_pickle=True) as results:
        df = pd.DataFrame({
            'opt': list(results['opt_point_properties']),
            'init': list(results['init_point_properties']),
            'orig': list(results['orig_point_properties']),
            'model_version': list(results['opt_model_version']),
        })

    # One row per model version with mean/min/max of each score column
    df = df.groupby('model_version').agg({
        'opt': ['mean', 'min', 'max'],
        'init': ['mean', 'min', 'max'],
        'orig': ['mean', 'min', 'max'],
    }).reset_index()

    # Plot at positions 0..n-1 so every point sits exactly on its tick label,
    # whatever the type/values of model_version are.
    positions = np.arange(len(df))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(positions, df['opt']['mean'], marker='o', label='Optimized Mean')
    ax.fill_between(positions, df['opt']['min'], df['opt']['max'], alpha=0.2, label='Optimized Min-Max Range')
    ax.plot(positions, df['orig']['mean'], marker='s', label='Original Mean', linestyle='--')
    ax.plot(positions, df['init']['mean'], marker='x', label='Initial Mean', linestyle=':')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Values')
    ax.set_title('Mean Objective Values of the Top Five Samples per Iteration')
    ax.set_ylim(0, 5)
    ax.axhline(y=2, color='gray', linestyle='--', label='Input Max')
    ax.legend()
    ax.set_xticks(positions)
    ax.set_xticklabels(df['model_version'], rotation=45)

    plt.tight_layout()
    plt.show()

In [None]:
# Original, initial and optimized scores of the Initial Style + Depth run
scores_original_initial_optimized(version="ctrloralter_gbo_23")

## Score Comparison

### Compare Optimized Multiple Versions

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def scores_compare_optimized_mult_versions(version_dict, results_root="../results", styles=None):
    """
    Compare the per-iteration mean optimized score of several experiment versions.

    Only the first ``n`` iterations are drawn, where ``n`` is the number of
    model versions of the shortest run, so all curves cover the same range.

    Args:
        version_dict: dict mapping version (results directory name) -> legend
            label; a falsy label falls back to ``"Version <version>"``.
        results_root: directory containing ``<version>/results.npz``
        styles: optional dict mapping version -> extra keyword arguments for
            ``ax.plot`` (e.g. ``color``, ``linestyle``); they override the defaults.

    Raises:
        ValueError: if ``version_dict`` is empty.
    """
    if not version_dict:
        raise ValueError("version_dict must contain at least one version")
    styles = styles or {}

    curves = []         # (legend label, mean optimized score per model version, plot kwargs)
    tick_labels = None  # model versions of the last loaded run, used as x tick labels

    for version, version_name in version_dict.items():
        # Copy the needed arrays into a DataFrame and close the .npz archive.
        with np.load(f"{results_root}/{version}/results.npz", allow_pickle=True) as results:
            df = pd.DataFrame({
                'opt': list(results['opt_point_properties']),
                'model_version': list(results['opt_model_version']),
            })

        # Mean optimized score per model version
        df = df.groupby('model_version').agg({'opt': 'mean'}).reset_index()

        label = version_name or f"Version {version}"
        # A list (not a dict keyed by label) keeps runs with duplicate labels.
        curves.append((label, df['opt'], {'marker': 'o', 'label': label, **styles.get(version, {})}))
        tick_labels = df['model_version']

    # Only the iterations every version reached are comparable.
    min_iterations = min(len(values) for _, values, _ in curves)
    positions = range(min_iterations)

    fig, ax = plt.subplots(figsize=(10, 6))

    for _, opt_values, plot_kwargs in curves:
        ax.plot(positions, opt_values.iloc[:min_iterations], **plot_kwargs)

    ax.set_xlabel('Iteration')
    ax.set_ylabel('Values')
    ax.set_title('Mean Objective Values of Top Five Samples per Iteration')
    ax.set_ylim(0, 5)
    ax.axhline(y=2, color='gray', linestyle='--', label='Input Max')
    ax.legend()
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels.iloc[:min_iterations], rotation=45)

    plt.tight_layout()
    plt.show()

In [None]:
# Compare the latent representations: LoRA style variants, SD VAE and LatentVQVAE.
scores_compare_optimized_mult_versions(version_dict={
    "ctrloralter_gbo_20": "LoRA Style",
    "ctrloralter_gbo_21": "LORA Style + Depth",
    "sd_gbo_03": "SD VAE",
    "latent_vqvae_gbo_06": "LatentVQVAE",
})

In [None]:
# Compare the optimizers on the same LatentVQVAE setup.
scores_compare_optimized_mult_versions(version_dict={
    "latent_vqvae_dngo_06": "Bayesian Optimization (DNGO)",
    "latent_vqvae_gbo_06": "Gradient-Based Optimization",
})

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Compare initial vs. finetuned style for each conditioning variant.
version_dict = {
    "ctrloralter_gbo_25": "Initial Style",
    "ctrloralter_gbo_26": "Finetuned Style",
    "ctrloralter_gbo_23": "Initial Style + Depth",
    "ctrloralter_gbo_24": "Finetuned Style + Depth",
    "ctrloralter_gbo_27": "Initial Style + HED",
    "ctrloralter_gbo_28": "Finetuned Style + HED",
}

# Plot order and style per curve: colour encodes the conditioning variant,
# dashed = initial style, solid = finetuned style.
line_styles = {
    'Initial Style': {'color': 'tab:blue', 'linestyle': '--'},
    'Finetuned Style': {'color': 'tab:blue'},
    'Initial Style + Depth': {'color': 'tab:green', 'linestyle': '--'},
    'Finetuned Style + Depth': {'color': 'tab:green'},
    'Initial Style + HED': {'color': 'tab:orange', 'linestyle': '--'},
    'Finetuned Style + HED': {'color': 'tab:orange'},
}

result_dict = {}
min_iterations = np.inf

for version, version_name in version_dict.items():
    # Copy the needed arrays into a DataFrame and close the .npz archive.
    with np.load(f"../results/{version}/results.npz", allow_pickle=True) as results:
        df = pd.DataFrame({
            'opt': list(results['opt_point_properties']),
            'model_version': list(results['opt_model_version']),
        })

    # Mean optimized score per model version
    df = df.groupby('model_version').agg({'opt': 'mean'}).reset_index()

    version_name = version_name or f"Version {version}"
    result_dict[version_name] = df['opt']

    # Only the iterations every version reached are comparable.
    min_iterations = min(min_iterations, len(df['model_version']))

fig, ax = plt.subplots(figsize=(10, 6))

for curve_name, style in line_styles.items():
    ax.plot(range(min_iterations), result_dict[curve_name].iloc[:min_iterations], marker='o', label=curve_name, **style)

ax.set_xlabel('Iteration')
ax.set_ylabel('Values')
ax.set_title('Mean Objective Values of Top Five Samples per Iteration')
ax.set_ylim(0, 5)
ax.axhline(y=2, color='gray', linestyle='--', label='Input Max')
ax.legend()
# Tick labels come from the last loaded run (ctrloralter_gbo_28).
ax.set_xticks(range(min_iterations))
ax.set_xticklabels(df['model_version'].iloc[:min_iterations], rotation=45)

plt.tight_layout()
plt.show()

## Experiments

## Setup

### Results Directory

In [None]:
from pathlib import Path

BASE_DIR = Path("../results").expanduser().resolve()

def get_result_dir(version: str, seed: int, base_dir: Path = BASE_DIR) -> Path:
    """
    Return the results directory of run ``version`` with the given ``seed``.

    Allowed directory names (looked up under ``base_dir``):
      <version>_<seed>
      <version>_<seed>_<anything>    (e.g. a trailing job id)

    An exact ``<version>_<seed>`` directory wins; otherwise the
    lexicographically first matching directory is returned.

    Raises:
        FileNotFoundError: if no matching directory exists.
    """
    import glob  # only needed to escape glob metacharacters in the prefix

    base_dir = Path(base_dir)
    prefix = f"{version}_{seed}"

    # exact match first (no job-id)
    exact = base_dir / prefix
    if exact.is_dir():
        return exact

    # Wildcard for any trailing "_<suffix>". The prefix is escaped so '[' or '*'
    # in a version name match literally, and directories are filtered
    # explicitly because a trailing '/' in a glob pattern only restricts
    # matches to directories on Python >= 3.11.
    pattern = f"{glob.escape(prefix)}_*"
    matches = sorted(p for p in base_dir.glob(pattern) if p.is_dir())
    if not matches:
        raise FileNotFoundError(f"No result directory matching {prefix}_* under {base_dir}")
    return matches[0]

## Experiment 1 (SD-VAE)

In [None]:
# import matplotlib.pyplot as plt
# import matplotlib.gridspec as gridspec

# import numpy as np
# from PIL import Image

# np.random.seed(42)  # For reproducibility

# # Create a GridSpec with an extra column for spacing
# fig = plt.figure(constrained_layout=True, figsize=(8*2+0.1, 6*2+0.1*2))
# gs = gridspec.GridSpec(figure=fig, nrows=8, ncols=9, width_ratios=[1,1,1,1,0.1,1,1,1,1], height_ratios=[1,1,0.1,1,1,0.1,1,1])

# # Create a list to hold axes (ignoring the spacer column)
# spacer_rows  = {2, 5}
# spacer_col   = 4
# axes = []
# # keep only the data rows, then step through them two at a time
# data_rows = [r for r in range(8) if r not in spacer_rows]
# left_cols  = range(0, spacer_col)
# right_cols = range(spacer_col + 1, 9)

# for top_idx in range(0, len(data_rows), 2):
#     top, bottom = data_rows[top_idx], data_rows[top_idx + 1]

#     # left half: top row, then bottom row
#     for c in left_cols:
#         axes.append(fig.add_subplot(gs[top,    c]))
#     for c in left_cols:
#         axes.append(fig.add_subplot(gs[bottom, c]))

#     # right half (skip spacer column): top row, then bottom row
#     for c in right_cols:
#         axes.append(fig.add_subplot(gs[top,    c]))
#     for c in right_cols:
#         axes.append(fig.add_subplot(gs[bottom, c]))

# # Sample 4 random images from the results and convert to iter and img number
# random_samples = np.random.choice(range(100), size=4, replace=False)
# iter_nums = [x - (x % 5) for x in random_samples]
# sample_nums = [x % 5 for x in random_samples]

# # Get the images
# samples = []
# for cell_idx, cell_name in enumerate(["dngo_train_lbfgsb", "dngo_normal_lbfgsb", "dngo_train_trustconstr", "dngo_normal_trustconstr", "gbo_train", "gbo_normal"]):

#     # Define data directory
#     data_dir = get_result_dir(f"ex1_sd15_{cell_name}", seed=42) / "data/samples"
    
#     for iter_num, sample_num in zip(iter_nums, sample_nums):

#         # Load initial image
#         init_file = data_dir / f"iter_{iter_num}/img_init/{sample_num}.png"
#         init_img = Image.open(init_file).convert("RGB")
#         samples.append(init_img)

#         # Load optimized image
#         opt_file = data_dir / f"iter_{iter_num}/img_opt/{sample_num}.png"
#         opt_img = Image.open(opt_file).convert("RGB")
#         samples.append(opt_img)

# # Plot images on the axes
# for i in range(48):
#     ax = axes[i]
#     ax.imshow(samples[i])
#     # Remove ticks and spines
#     ax.set_xticks([])
#     ax.set_yticks([])
#     for spine in ax.spines.values():
#         spine.set_visible(False)

# # Column titles
# fig.text(0.258, 1.02, "Training Data Samples", ha="center", va="center", fontsize=20)
# fig.text(0.755, 1.02, "Normal Distribution Samples", ha="center", va="center", fontsize=20)

# # Rot titles
# fig.text(-0.02, 0.84, "L-BFGS-B", ha="center", va="center", fontsize=20, rotation=90)
# fig.text(-0.02, 0.5, "trust-constr", ha="center", va="center", fontsize=20, rotation=90)
# fig.text(-0.02, 0.16, "GBO", ha="center", va="center", fontsize=20, rotation=90)

# # Vertical line
# fig.add_artist(plt.Line2D([0.5, 0.5], [0, 1], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))

# # Horizontal lines
# fig.add_artist(plt.Line2D([0, 1], [0.67, 0.67], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))
# fig.add_artist(plt.Line2D([0, 1], [0.33, 0.33], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))

# # plt.savefig("vis/latent_prior_samples_comparison.pdf", bbox_inches='tight')
# plt.show()

### SD1.5 Train-Data vs. Normal Distribution

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
from PIL import Image

import numpy as np
np.random.seed(42)  # For reproducibility

num_outer_rows = 8
num_outer_cols = 5
data_rows = [0, 1, 3, 4, 6, 7]  # real data rows (skip the spacer rows 2 and 5)

fig = plt.figure(figsize=(8*2+0.2, 6*2*1.05+0.2))

# Outer grid: two data columns | thin spacer column 2 | two data columns,
# and three pairs of data rows separated by thin spacer rows 2 and 5.
outer = gridspec.GridSpec(
    nrows=num_outer_rows,
    ncols=num_outer_cols,
    figure=fig,
    width_ratios=[1, 1, 0.01, 1, 1],
    height_ratios=[1, 1, 0.1, 1, 1, 0.1, 1, 1],
    wspace=0.1,
    hspace=0.2,
    bottom=0, top=1,
    left=0, right=1,
)

axes = []
# rows that belong together: (0,1), (3,4), (6,7) -> one optimizer per row pair
for upper_idx in range(0, len(data_rows), 2):
    top, bottom = data_rows[upper_idx], data_rows[upper_idx + 1]

    # columns 0,1 = left half (training samples), columns 3,4 = right half
    # (normal-dist samples); each cell holds an (init, opt) image pair.
    for pc in (0, 1, 3, 4):
        inner_top = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer[top, pc], wspace=0, hspace=0)
        inner_bottom = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer[bottom, pc], wspace=0, hspace=0)

        axes.append(fig.add_subplot(inner_top[0, 0]))     # init
        axes.append(fig.add_subplot(inner_top[0, 1]))     # opt
        axes.append(fig.add_subplot(inner_bottom[0, 0]))  # init
        axes.append(fig.add_subplot(inner_bottom[0, 1]))  # opt

# Pick 4 random samples and split each flat index into (iteration dir, sample index).
# NOTE(review): assumes every run stores 100 samples as iter_0, iter_5, ... with
# 5 samples per directory, and flat score index = iter_num + sample_num — confirm.
random_samples = np.random.choice(range(100), size=4, replace=False)
iter_nums = [x - (x % 5) for x in random_samples]
sample_nums = [x % 5 for x in random_samples]

# Runs in axes order: per optimizer, the training-data run (left) then the normal-dist run (right)
run_names = ["dngo_train_lbfgsb", "dngo_normal_lbfgsb", "dngo_train_trustconstr", "dngo_normal_trustconstr", "gbo_train", "gbo_normal"]

# Get the images and their score titles
samples = []
scores = []
for cell_name in run_names:
    result_dir = get_result_dir(f"ex1_sd15_{cell_name}", seed=42)
    data_dir = result_dir / "data/samples"
    # Read each score array once and close the .npz archive.
    with np.load(result_dir / "results.npz") as scores_file:
        init_props = scores_file['init_point_properties']
        opt_props = scores_file['opt_point_properties']

    for iter_num, sample_num in zip(iter_nums, sample_nums):
        flat = iter_num + sample_num

        # Initial image
        with Image.open(data_dir / f"iter_{iter_num}/img_init/{sample_num}.png") as img:
            samples.append(img.convert("RGB"))
        scores.append(f"Initial: {init_props[flat].item():.2f}")

        # Optimized image
        with Image.open(data_dir / f"iter_{iter_num}/img_opt/{sample_num}.png") as img:
            samples.append(img.convert("RGB"))
        scores.append(f"Optimized: {opt_props[flat].item():.2f}")

if len(samples) != len(axes):
    raise RuntimeError(f"Expected {len(axes)} images, loaded {len(samples)}")

# Plot images on the axes
for ax, img, title in zip(axes, samples, scores):
    ax.imshow(img)
    ax.set_title(title)
    # Remove ticks and spines
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# Column titles
fig.text(0.258, 1.04, "Training Data Samples", ha="center", va="center", fontsize=20)
fig.text(0.755, 1.04, "Normal Distribution Samples", ha="center", va="center", fontsize=20)

# Row titles
fig.text(-0.02, 0.84, "L-BFGS-B", ha="center", va="center", fontsize=20, rotation=90)
fig.text(-0.02, 0.5, "trust-constr", ha="center", va="center", fontsize=20, rotation=90)
fig.text(-0.02, 0.16, "GBO", ha="center", va="center", fontsize=20, rotation=90)

# Vertical line
fig.add_artist(plt.Line2D([0.5, 0.5], [0, 1], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))

# Horizontal lines
fig.add_artist(plt.Line2D([0, 1], [0.68, 0.68], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))
fig.add_artist(plt.Line2D([0, 1], [0.33, 0.33], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))

# savefig does not create missing directories.
Path("vis").mkdir(exist_ok=True)
plt.savefig("vis/ex1_sd15_train_normal.pdf", bbox_inches='tight')
plt.show()

### SD3.5 Original vs. Finetuned

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pathlib import Path
from PIL import Image

import numpy as np
np.random.seed(42)  # For reproducibility

num_outer_rows = 5
num_outer_cols = 5
data_rows = [0, 1, 3, 4]  # real data rows (skip the spacer row 2)

fig = plt.figure(figsize=(8*2+0.2, 4*2*1.05+0.2))

# Outer grid: two data columns | thin spacer column 2 | two data columns,
# and two pairs of data rows separated by the thin spacer row 2.
outer = gridspec.GridSpec(
    nrows=num_outer_rows,
    ncols=num_outer_cols,
    figure=fig,
    width_ratios=[1, 1, 0.01, 1, 1],
    height_ratios=[1, 1, 0.1, 1, 1],
    wspace=0.1,
    hspace=0.2,
    bottom=0, top=1,
    left=0, right=1,
)

axes = []
# rows that belong together: (0,1), (3,4) -> one optimizer per row pair
for upper_idx in range(0, len(data_rows), 2):
    top, bottom = data_rows[upper_idx], data_rows[upper_idx + 1]

    # columns 0,1 = left half (original SD3.5), columns 3,4 = right half
    # (finetuned SD3.5); each cell holds an (init, opt) image pair.
    for pc in (0, 1, 3, 4):
        inner_top = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer[top, pc], wspace=0.02, hspace=0)
        inner_bottom = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer[bottom, pc], wspace=0.02, hspace=0)

        axes.append(fig.add_subplot(inner_top[0, 0]))     # init
        axes.append(fig.add_subplot(inner_top[0, 1]))     # opt
        axes.append(fig.add_subplot(inner_bottom[0, 0]))  # init
        axes.append(fig.add_subplot(inner_bottom[0, 1]))  # opt

# Pick 4 random samples and split each flat index into (iteration dir, sample index).
# NOTE(review): assumes every run stores 100 samples as iter_0, iter_5, ... with
# 5 samples per directory, and flat score index = iter_num + sample_num — confirm.
random_samples = np.random.choice(range(100), size=4, replace=False)
iter_nums = [x - (x % 5) for x in random_samples]
sample_nums = [x % 5 for x in random_samples]

# Runs in axes order: per optimizer, the original run (left) then the finetuned run (right)
run_names = ["sd35_dngo_train_lbfgsb", "sd35f_dngo_train_lbfgsb", "sd35_dngo_train_trustconstr", "sd35f_dngo_train_trustconstr"]

# Get the images and their score titles
samples = []
scores = []
for cell_name in run_names:
    result_dir = get_result_dir(f"ex1_{cell_name}", seed=42)
    data_dir = result_dir / "data/samples"
    # Read each score array once and close the .npz archive.
    with np.load(result_dir / "results.npz") as scores_file:
        init_props = scores_file['init_point_properties']
        opt_props = scores_file['opt_point_properties']

    for iter_num, sample_num in zip(iter_nums, sample_nums):
        flat = iter_num + sample_num

        # Initial image
        with Image.open(data_dir / f"iter_{iter_num}/img_init/{sample_num}.png") as img:
            samples.append(img.convert("RGB"))
        scores.append(f"Initial: {init_props[flat].item():.2f}")

        # Optimized image
        with Image.open(data_dir / f"iter_{iter_num}/img_opt/{sample_num}.png") as img:
            samples.append(img.convert("RGB"))
        scores.append(f"Optimized: {opt_props[flat].item():.2f}")

if len(samples) != len(axes):
    raise RuntimeError(f"Expected {len(axes)} images, loaded {len(samples)}")

# Plot images on the axes
for ax, img, title in zip(axes, samples, scores):
    ax.imshow(img)
    ax.set_title(title)
    # Remove ticks and spines
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# Column titles
fig.text(0.248, 1.06, "Original SD3.5", ha="center", va="center", fontsize=20)
fig.text(0.758, 1.06, "Finetuned SD3.5", ha="center", va="center", fontsize=20)

# Row titles
fig.text(-0.02, 0.77, "L-BFGS-B", ha="center", va="center", fontsize=20, rotation=90)
fig.text(-0.02, 0.25, "trust-constr", ha="center", va="center", fontsize=20, rotation=90)

# Vertical line
fig.add_artist(plt.Line2D([0.5, 0.5], [0, 1], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))

# Horizontal line
fig.add_artist(plt.Line2D([0, 1], [0.51, 0.51], transform=fig.transFigure, color='gray', linewidth=2, linestyle=':'))

# savefig does not create missing directories.
Path("vis").mkdir(exist_ok=True)
plt.savefig("vis/ex1_sd35_original_finetuned.pdf", bbox_inches='tight')
plt.show()