# Distance vs. Complexity (Figs. 5 and 10)

In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from aeroblade.paper import configure_mpl, get_nice_name, set_figsize

# Shared paper plotting style, then the figure size used for every plot in
# this section (arguments as passed to aeroblade.paper.set_figsize).
configure_mpl()
set_figsize("single", ratio=1.0, factor=0.49)

# Destination for all PDFs written below; created (with parents) if missing.
output_dir = Path("output") / "02" / "default" / "figures"
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Per-image distance/complexity arrays, restricted to the "max" repo_id rows.
combined = pd.read_parquet("output/02/default/combined_dist_compl.parquet").query(
    "repo_id == 'max'"
)
# Swap raw directory / metric identifiers for display names; these are used
# below as plot labels and in the output file names.
combined[["dir", "distance_metric"]] = combined[["dir", "distance_metric"]].map(
    get_nice_name
)

# One 2-D density histogram of complexity vs. distance per dataset directory.
for nice_dir, group_df in combined.groupby("dir", observed=True):
    # Flatten each row's array and join them in row order. For equal-shaped
    # arrays this equals np.stack(...).flatten(); unlike np.stack it also
    # accepts rows whose arrays differ in size.
    complexity_values = np.concatenate([np.ravel(v) for v in group_df.complexity])
    distance_values = np.concatenate([np.ravel(v) for v in group_df.distance])
    # Fixed bin ranges and a fixed colour-scale cap (vmax) so every per-dataset
    # figure uses the same axes extent and colour scale.
    sns.histplot(
        x=complexity_values,
        y=distance_values,
        stat="density",
        bins=100,
        legend=False,
        binrange=((0.05, 0.4), (0, 0.07)),
        vmax=1000,
    )
    plt.xlabel("Complexity")
    plt.ylabel(group_df.iloc[0].distance_metric)
    plt.savefig(output_dir / f"dist_vs_compl_{nice_dir}.pdf")
    plt.close()

In [None]:
# Pool the complexity/distance arrays of every generated (non-"Real") dataset
# into a single long two-column table.
all_generated = combined.query("dir != 'Real'")
all_generated = pd.DataFrame(
    {
        column: np.concatenate(all_generated[column].to_numpy())
        for column in ("complexity", "distance")
    }
)

# Same binning and colour cap as the per-dataset figures.
sns.histplot(
    x=all_generated["complexity"],
    y=all_generated["distance"],
    stat="density",
    bins=100,
    legend=False,
    binrange=((0.05, 0.4), (0, 0.07)),
    vmax=1000,
)
plt.xlabel("Complexity")
plt.ylabel(combined["distance_metric"].iloc[0])
plt.savefig(output_dir / "dist_vs_compl_all_generated.pdf")
plt.close()

# Patches with High/Low Reconstruction Distance (Fig. 4)

In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import numpy as np
import pandas as pd
from aeroblade.data import ImageFolder
from aeroblade.image import extract_patches
from aeroblade.paper import DATASET_ORDER, configure_mpl, get_nice_name, set_figsize
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.transforms.v2.functional import to_pil_image

# pyplot is not imported by this section's import cell, but it is needed for
# the rcParams update below and for the figures in the next cell.
import matplotlib.pyplot as plt  # noqa: E402

configure_mpl()
set_figsize()
# Figure overrides for the patch grids: 600 dpi, no grid lines, thin axes
# frames and small axis labels.
plt.rcParams.update(
    {"figure.dpi": 600, "axes.grid": False, "axes.linewidth": 0.5, "axes.labelsize": 5}
)

# Destination for all PDFs written below; created (with parents) if missing.
output_dir = Path("output/02/default/figures")
output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Reload the per-image distance table with the same "max" repo_id filter used
# for the distance-vs-complexity figures, and attach display names so rows can
# be matched against DATASET_ORDER.
combined = pd.read_parquet("output/02/default/combined_dist_compl.parquet").query(
    "repo_id == 'max'"
)
combined["nice_dir"] = combined["dir"].apply(get_nice_name)


def _plot_extreme_patches(df, quantile, largest, out_path, n_examples=3, seed=42):
    """Save an image grid of random patches from one tail of the distance distribution.

    The grid has one column per entry of ``DATASET_ORDER`` (in that order) and
    ``n_examples`` rows. For each dataset, the candidate patches are those whose
    distance is strictly above (``largest=True``) or strictly below
    (``largest=False``) that dataset's ``quantile`` quantile. Up to
    ``n_examples`` candidates are drawn at random and shown top to bottom, and
    the bottom axis of the column is labelled with the dataset name.

    Args:
        df: table with ``dir``, ``nice_dir``, ``file`` and ``distance`` columns;
            each row refers to one image file inside its ``dir``.
        quantile: quantile in [0, 1] defining the tail threshold.
        largest: select the upper tail if True, the lower tail otherwise.
        out_path: path of the figure file to write.
        n_examples: number of patches (grid rows) shown per dataset.
        seed: RNG seed. A fresh generator is created per dataset, so each
            column's draw does not depend on the other columns.

    Raises:
        ValueError: if a dataset in DATASET_ORDER has no rows in ``df``.
    """
    fig = plt.figure()
    # Size the grid from DATASET_ORDER, the sequence the loop below fills.
    grid = ImageGrid(
        fig,
        111,
        nrows_ncols=(n_examples, len(DATASET_ORDER)),
        axes_pad=0.02,
        direction="column",
        share_all=True,
    )
    for col, nice_dir in enumerate(DATASET_ORDER):
        group_df = df[df["nice_dir"] == nice_dir]
        if group_df.empty:
            raise ValueError(f"dataset {nice_dir!r} from DATASET_ORDER has no rows")
        data_dir = Path(group_df.iloc[0].dir)

        # argwhere rows are (image row, patch index) pairs. NOTE(review): assumes
        # each row's `distance` is a 1-D per-patch array of equal length, so that
        # np.stack yields a 2-D array; TODO confirm.
        distances = np.stack(group_df.distance)
        threshold = np.quantile(distances, quantile)
        in_tail = distances > threshold if largest else distances < threshold
        candidates = np.argwhere(in_tail)

        rng = np.random.default_rng(seed=seed)
        column_axes = grid.axes_column[col]
        chosen = rng.permutation(candidates)[:n_examples]
        for ax, (image_idx, patch_idx) in zip(column_axes, chosen):
            image = ImageFolder(data_dir / group_df.iloc[image_idx].file)[0][0]
            patches = extract_patches(
                image.unsqueeze(0), size=128, stride=64
            ).squeeze(0)
            ax.imshow(np.array(to_pil_image(patches[patch_idx])))
        # Always label the bottom axis of the column, even when fewer than
        # n_examples candidates were available.
        column_axes[-1].set_xlabel(nice_dir)

    # With share_all=True the axes are shared, so clearing ticks on the first
    # axis clears them on every axis in the grid.
    grid[0].set_xticks([])
    grid[0].set_yticks([])
    fig.savefig(out_path)
    plt.close(fig)


# Fig. 4: patches with the lowest 1% and the highest 1% of distances.
_plot_extreme_patches(
    combined, quantile=0.01, largest=False, out_path=output_dir / "lowest_percent.pdf"
)
_plot_extreme_patches(
    combined, quantile=0.99, largest=True, out_path=output_dir / "highest_percent.pdf"
)