<center>
    <p align="center">
        <img src="https://logodownload.org/wp-content/uploads/2017/09/mackenzie-logo-3.png" style="height: 7ch;"><br>
        <h1 align="center">Computer Systems Undergraduate Thesis</h1>
        <h2 align="center">Quantitative Analysis of the Impact of Image Pre-Processing on the Accuracy of Computer Vision Models Trained to Identify Dermatological Skin Diseases</h2>
        <h4 align="center">Gabriel Mitelman Tkacz</h4>
    </p>
</center>

<hr>

In [1]:
import math
import re
import tomllib
from functools import partial
from itertools import permutations
from pprint import pprint

import dill
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torchvision.transforms.functional as TF
from matplotlib.font_manager import FontProperties, fontManager
from PIL import Image
from pynimbar import loading_animation

from util import (
	ColorSpaceTransform,
	DenoiseTransform,
	EqualizationTransform,
	NormalizeTransform,
	evaluate_model,
	get_model_data,
)

In [2]:
# Load the experiment configuration. TOML files are UTF-8 by specification, so the
# file is opened in binary mode and parsed with tomllib.load; a text-mode open()
# would decode it with the platform's locale encoding instead.
with open("parameters.toml", "rb") as f:
	parameters = tomllib.load(f)

# loading_animation with its reporting options pre-bound.
loading_handler = partial(loading_animation, break_on_error=True, verbose_errors=True, time_it_live=True)

# Greek small letter alpha: used further down as the column name and colour-bar
# label for "precision minus base precision".
alpha = chr(0x03B1)

# Let pandas display up to 500 rows before truncating a table.
pd.set_option("display.max_rows", 500)

pprint(parameters)

{'PREPROCESS': {'colorspace': {'high_con': {'source_space': 'RGB',
                                            'target_space': 'LAB'},
                               'source_space': 'RGB',
                               'target_space': 'HSV'},
                'denoise': {'high_con': {'search_window_size': 24,
                                         'template_window_size': 6},
                            'search_window_size': 19,
                            'template_window_size': 5},
                'normalize': {'high_con': {'mean': 0.3, 'std': 0.75},
                              'mean': 0.4,
                              'std': 0.2}},
 'TRAINING': {'batch_size': 128,
              'diseased_skin_path': './dataset/diseased/',
              'healthy_skin_path': './dataset/healthy/',
              'learning_rate': 0.0001,
              'num_epochs': 3,
              'num_workers': 12,
              'pin_memory': True,
              'precision_threshold': 0.8,
              'resize_dim

In [3]:
# The four pre-processing steps under study. Three of them are configured from the
# "high_con" parameter set; EqualizationTransform is built without arguments.
preprocesses = (
	ColorSpaceTransform(**parameters["PREPROCESS"]["colorspace"]["high_con"]),
	DenoiseTransform(**parameters["PREPROCESS"]["denoise"]["high_con"]),
	EqualizationTransform(),
	NormalizeTransform(**parameters["PREPROCESS"]["normalize"]["high_con"]),
)

# Every ordered arrangement of i distinct steps, keyed by i (1..4).
# itertools.permutations returns a one-shot iterator, so each entry is materialised
# into a tuple: the dictionary is read by several cells, and an iterator would be
# empty for every reader after the first.
preprocess_combinations = {i: tuple(permutations(preprocesses, i)) for i in range(1, len(preprocesses) + 1)}

# Short label per transform class: the upper-case letters of the class name with the
# last one dropped (e.g. "ColorSpaceTransform" -> "CST" -> "CS").
preprocess_labels = {s.__class__.__name__: re.sub(r"[^A-Z]", "", s.__class__.__name__)[:-1] for s in preprocesses}

# Sanity check: list every 4-step pipeline.
for i in preprocess_combinations[4]:
	print(i)

(ColorSpaceTransform(), DenoiseTransform(), EqualizationTransform(), <util.preprocessing.NormalizeTransform object at 0x7bf653b36930>)
(ColorSpaceTransform(), DenoiseTransform(), <util.preprocessing.NormalizeTransform object at 0x7bf653b36930>, EqualizationTransform())
(ColorSpaceTransform(), EqualizationTransform(), DenoiseTransform(), <util.preprocessing.NormalizeTransform object at 0x7bf653b36930>)
(ColorSpaceTransform(), EqualizationTransform(), <util.preprocessing.NormalizeTransform object at 0x7bf653b36930>, DenoiseTransform())
(ColorSpaceTransform(), <util.preprocessing.NormalizeTransform object at 0x7bf653b36930>, DenoiseTransform(), EqualizationTransform())
(ColorSpaceTransform(), <util.preprocessing.NormalizeTransform object at 0x7bf653b36930>, EqualizationTransform(), DenoiseTransform())
(DenoiseTransform(), ColorSpaceTransform(), EqualizationTransform(), <util.preprocessing.NormalizeTransform object at 0x7bf653b36930>)
(DenoiseTransform(), ColorSpaceTransform(), <util.prepr

In [4]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
	device = torch.device("cuda")
else:
	device = torch.device("cpu")
device

device(type='cuda')

In [5]:
training_ratio = parameters["TRAINING"]["training_dataset_ratio"]
# Whatever is not used for training is split evenly between testing and validation.
# The rounding is applied to the halves and only strips floating-point noise
# (1 - 0.8 is 0.19999999999999996). Rounding the remainder to one decimal first
# would silently drop data for ratios such as 0.75 (0.25 -> 0.2).
testing_ratio = validation_ratio = round((1 - training_ratio) / 2, 10)

print(f"Training ratio: {training_ratio * 100}%")
print(f"Testing ratio: {testing_ratio * 100}%")
print(f"Validation ratio: {validation_ratio * 100}%")

# Fixed seed handed to every get_model_data call below.
seed = 47
print(f"\nSeed: {seed}")

Training ratio: 80.0%
Testing ratio: 10.0%
Validation ratio: 10.0%

Seed: 47


In [6]:
# Register the bundled font file with matplotlib so it can be selected by name.
font_path = "./fonts/Times-New-Roman.otf"
fontManager.addfont(font_path)
prop = FontProperties(fname=font_path)

# Global seaborn look for every figure in the notebook.
sns.set_context("notebook")
sns.set_theme(
	style="whitegrid",
	palette="deep",
	font=prop.get_name(),
	rc={
		"font.size": 12,
		"axes.titlesize": 20,
		"axes.labelsize": 18,
		"xtick.labelsize": 16,
		"ytick.labelsize": 16,
	},
)

In [None]:
IMAGE_PATH = "./dataset/diseased/ISIC_0024319.jpg"
N_COLS = 3

base_pil = Image.open(IMAGE_PATH).convert("RGB")
base_tensor = TF.to_tensor(base_pil)

# Start with the empty pipeline (the untouched image), then add every ordered
# pipeline of each length. The permutations are generated here from `preprocesses`
# rather than read out of `preprocess_combinations`, so this cell neither depends on
# nor alters the state of that shared dictionary and always shows the full set.
all_perms = [()]
for n_steps in preprocess_combinations:
	all_perms.extend(permutations(preprocesses, n_steps))

# Apply each pipeline to the sample image and build its caption from the short labels.
images, captions = [], []
for perm in all_perms:
	img_t = base_tensor
	for proc in perm:
		img_t = proc(img_t)

	images.append(TF.to_pil_image(img_t))

	if perm:
		labels = [preprocess_labels[p.__class__.__name__] for p in perm]
		captions.append(" → ".join(labels))
	else:
		captions.append("Original")

n_images = len(images)
n_rows = math.ceil(n_images / N_COLS)
# squeeze=False keeps `axes` a 2-D array for every grid shape, so flatten() is always valid.
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(4 * N_COLS, 4 * n_rows), squeeze=False)
fig.patch.set_alpha(0)
axes = axes.flatten()

for ax, img, title in zip(axes, images, captions, strict=False):
	ax.imshow(img)
	ax.axis("off")
	# The caption is drawn below the image, in axes coordinates, instead of as a title.
	ax.text(
		0.5,
		-0.1,
		title,
		size=25,
		ha="center",
		transform=ax.transAxes,
	)


# Hide the unused cells of the last grid row.
for ax in axes[n_images:]:
	ax.axis("off")

plt.tight_layout()
plt.savefig("preprocesses.png", transparent=True)
plt.show()

## Class 0 Model: Images with no pre-processing

In [None]:
# Reference model: no transform list is passed to get_model_data, so the images are
# used without any of the pre-processing steps studied here.
base_train_loader, base_test_loader, base_validation_loader = get_model_data(
	seed=seed,
	training_ratio=training_ratio,
	testing_ratio=testing_ratio,
	validation_ratio=validation_ratio,
)

base_precision, base_confusion_matrix, base_training_time = evaluate_model(
	device,
	base_train_loader,
	base_test_loader,
	base_validation_loader,
)

print(f"Base precision: {base_precision * 100:.1f}%")

# Every later result is reported relative to this model, so stop here when the
# baseline itself is below the configured threshold.
required_precision = parameters["TRAINING"]["precision_threshold"]
if base_precision < required_precision:
	raise ValueError("The base model did not meet the precision threshold.")

base_confusion_matrix

## Class 1 Models: Images with only one pre-process

### Class 1.1 Models: Normalizing the image

In [None]:
# Class 1.1: normalisation is the only pre-processing step.
normalize_train_loader, normalize_test_loader, normalize_validation_loader = get_model_data(
	[NormalizeTransform(**parameters["PREPROCESS"]["normalize"]["high_con"])],
	seed=seed,
	training_ratio=training_ratio,
	testing_ratio=testing_ratio,
	validation_ratio=validation_ratio,
)

normalize_precision, normalize_confusion_matrix, normalize_training_time = evaluate_model(
	device,
	normalize_train_loader,
	normalize_test_loader,
	normalize_validation_loader,
)

# Signed change in precision relative to the base model.
normalize_precision_diff = normalize_precision - base_precision

print(
	f"\n\nNormalized precision: {normalize_precision * 100:.1f}%",
	f"That is an {'upgrade' if normalize_precision_diff > 0 else 'downgrade'} of {normalize_precision_diff * 100:.1f}%.",
	sep="\n",
)

In [None]:
df = pd.read_json("./params/normalize_high_con.json")

# Keep the first element of each stored "precision" entry and express it as a
# difference from the base model's precision.
df["precision"] = df["precision"].map(lambda entry: entry[0]) - base_precision

# std on the rows (largest at the top), mean on the columns (ascending).
pivot_table = (
	df.pivot(index="std", columns="mean", values="precision")
	.sort_index()
	.sort_index(axis=1)
	.iloc[::-1]
)
pivot_table_pct = pivot_table * 100

# Symmetric colour limits so that zero falls in the middle of the diverging palette.
vabs = max(abs(pivot_table_pct.min().min()), abs(pivot_table_pct.max().max()))

plt.figure(figsize=(8, 6))
ax = sns.heatmap(
	pivot_table_pct,
	vmin=-vabs,
	vmax=vabs,
	cmap="RdYlGn",
	annot=True,
	fmt=".1f",
	cbar_kws={"label": alpha, "format": "%.0f%%"},
)

plt.xticks(rotation=0)
plt.yticks(rotation=0)
# Show the parameter values with one decimal instead of their raw float repr.
ax.set_xticklabels([f"{x:.1f}" for x in pivot_table_pct.columns])
ax.set_yticklabels([f"{y:.1f}" for y in pivot_table_pct.index])

ax.set_title("Correlation between Normalization Parameters and Model Precision")
ax.set_xlabel("Mean")
ax.set_ylabel("Standard Deviation")

plt.show()

### Class 1.2 Models: Denoising the image

In [None]:
# Class 1.2: denoising is the only pre-processing step.
denoise_train_loader, denoise_test_loader, denoise_validation_loader = get_model_data(
	[DenoiseTransform(**parameters["PREPROCESS"]["denoise"]["high_con"])],
	seed=seed,
	training_ratio=training_ratio,
	testing_ratio=testing_ratio,
	validation_ratio=validation_ratio,
)

denoise_precision, denoise_confusion_matrix, denoise_training_time = evaluate_model(
	device,
	denoise_train_loader,
	denoise_test_loader,
	denoise_validation_loader,
)

# Signed change in precision relative to the base model.
denoise_precision_diff = denoise_precision - base_precision

print(
	f"\n\nDenoised precision: {denoise_precision * 100:.1f}%",
	f"That is an {'upgrade' if denoise_precision_diff > 0 else 'downgrade'} of {denoise_precision_diff * 100:.1f}%.",
	sep="\n",
)

In [None]:
df = pd.read_json("./params/denoise_high_con.json")

# Keep the first element of each stored "precision" entry and express it as a
# difference from the base model's precision.
df["precision"] = df["precision"].map(lambda entry: entry[0]) - base_precision

# Search window on the rows (largest at the top), template window on the columns.
pivot_table = (
	df.pivot(index="search_window_size", columns="template_window_size", values="precision")
	.sort_index()
	.sort_index(axis=1)
	.iloc[::-1]
)
pivot_table_pct = pivot_table * 100

# Symmetric colour limits so that zero falls in the middle of the diverging palette.
vabs = max(abs(pivot_table_pct.min().min()), abs(pivot_table_pct.max().max()))

plt.figure(figsize=(8, 6))
ax = sns.heatmap(
	pivot_table_pct,
	vmin=-vabs,
	vmax=vabs,
	cmap="RdYlGn",
	annot=True,
	fmt=".1f",
	cbar_kws={"label": alpha, "format": "%.0f%%"},
)

plt.xticks(rotation=0)
plt.yticks(rotation=0)
# Window sizes are shown without decimals.
ax.set_xticklabels([f"{x:.0f}" for x in pivot_table_pct.columns])
ax.set_yticklabels([f"{y:.0f}" for y in pivot_table_pct.index])

ax.set_title("Correlation between Denoising Parameters and Model Precision")
ax.set_xlabel("Template Window Size")
ax.set_ylabel("Search Window Size")

plt.show()

### Class 1.3 Models: Equalizing the image

In [None]:
# Class 1.3: equalisation is the only pre-processing step.
equalized_train_loader, equalized_test_loader, equalized_validation_loader = get_model_data(
	[EqualizationTransform()],
	seed=seed,
	training_ratio=training_ratio,
	testing_ratio=testing_ratio,
	validation_ratio=validation_ratio,
)

equalized_precision, equalized_confusion_matrix, equalized_training_time = evaluate_model(
	device,
	equalized_train_loader,
	equalized_test_loader,
	equalized_validation_loader,
)

# Signed change in precision relative to the base model.
equalized_precision_diff = equalized_precision - base_precision

print(
	f"\n\nEqualized precision: {equalized_precision * 100:.1f}%",
	f"That is an {'upgrade' if equalized_precision_diff > 0 else 'downgrade'} of {equalized_precision_diff * 100:.1f}%.",
	sep="\n",
)

### Class 1.4 Models: Changing the colorspace

In [None]:
# Class 1.4: the colour-space conversion is the only pre-processing step.
colorspace_train_loader, colorspace_test_loader, colorspace_validation_loader = get_model_data(
	[ColorSpaceTransform(**parameters["PREPROCESS"]["colorspace"]["high_con"])],
	seed=seed,
	training_ratio=training_ratio,
	testing_ratio=testing_ratio,
	validation_ratio=validation_ratio,
)

colorspace_precision, colorspace_confusion_matrix, colorspace_training_time = evaluate_model(
	device,
	colorspace_train_loader,
	colorspace_test_loader,
	colorspace_validation_loader,
)

# Signed change in precision relative to the base model.
colorspace_precision_diff = colorspace_precision - base_precision

print(
	f"\n\nColorspaced precision: {colorspace_precision * 100:.1f}%",
	f"That is an {'upgrade' if colorspace_precision_diff > 0 else 'downgrade'} of {colorspace_precision_diff * 100:.1f}%.",
	sep="\n",
)

In [None]:
df = pd.read_json("./params/colorspace_high_con.json")

# Keep the first element of each stored "precision" entry and express it as a
# difference from the base model's precision.
df["precision"] = df["precision"].map(lambda entry: entry[0]) - base_precision

# NOTE(review): every column other than "target_space" stays in the frame and is
# scaled and plotted below - this presumes the JSON holds only numeric columns
# besides the target space; confirm against the file.
pivot_table = df.set_index("target_space").sort_index().sort_index(axis=1).iloc[::-1]
pivot_table_pct = pivot_table * 100

# Symmetric colour limits so that zero falls in the middle of the diverging palette.
vabs = max(abs(pivot_table_pct.min().min()), abs(pivot_table_pct.max().max()))

plt.figure(figsize=(8, 6))
ax = sns.heatmap(
	pivot_table_pct,
	vmin=-vabs,
	vmax=vabs,
	cmap="RdYlGn",
	annot=True,
	fmt=".1f",
	cbar_kws={"label": alpha, "format": "%.0f%%"},
)

plt.xticks(rotation=0)
plt.yticks(rotation=0)

ax.set_title("Correlation between Color Space Parameters and Model Precision")
ax.set_ylabel("Target Color Space")

plt.show()

In [None]:
# (precision, training time) of each single-transform model, keyed by class name.
class1_precisions = {
	NormalizeTransform.__name__: (normalize_precision, normalize_training_time),
	EqualizationTransform.__name__: (equalized_precision, equalized_training_time),
	DenoiseTransform.__name__: (denoise_precision, denoise_training_time),
	ColorSpaceTransform.__name__: (colorspace_precision, colorspace_training_time),
}

# One record per model; alpha is the precision change against the base model and
# training_time_ratio the training time relative to the base model's.
class1_df_data = [
	{
		"transform_1": names[0],
		"precision": precision,
		alpha: precision - base_precision,
		"training_time": seconds,
		"training_time_ratio": seconds / base_training_time,
	}
	for names, (precision, seconds) in ((key.split(", "), result) for key, result in class1_precisions.items())
]

# Best improvement first; class names are then replaced by their short labels.
class1_df = pd.DataFrame(class1_df_data).sort_values(alpha, ascending=False).reset_index(drop=True)
class1_df["transform_1"] = class1_df["transform_1"].map(preprocess_labels.__getitem__)
class1_df

## Class 2 Models: Images with two pre-processes

In [None]:
# (precision, training time) of every two-step pipeline, keyed by the class names of
# its steps joined with ", ".
class2_precisions = {}

# A fresh permutations() generator is used instead of the shared
# preprocess_combinations[2] object: that object may be a one-shot iterator that an
# earlier cell has already consumed, in which case this loop would train nothing.
for idx, combination in enumerate(permutations(preprocesses, 2)):
	(
		class2_train_loader,
		class2_test_loader,
		class2_validation_loader,
	) = get_model_data(
		combination,
		training_ratio=training_ratio,
		testing_ratio=testing_ratio,
		validation_ratio=validation_ratio,
		seed=seed,
	)

	curr_precision, curr_confusion_matrix, curr_training_time = evaluate_model(
		device, class2_train_loader, class2_test_loader, class2_validation_loader, verbose=False,
	)

	# Key of this pipeline; the separator must match the one the results table splits on.
	uuid = ", ".join([str(t.__class__.__name__) for t in combination])

	class2_precisions[uuid] = (curr_precision, curr_training_time)

	# Signed change in precision relative to the base model.
	curr_precision_diff = curr_precision - base_precision

	print(f"\n\nClass 2.{idx + 1} {uuid} precision: {curr_precision * 100:.1f}%")
	print(f"That is an {'upgrade' if curr_precision_diff > 0 else 'downgrade'} of {curr_precision_diff * 100:.1f}%.")

In [None]:
# One record per two-step pipeline; the key is split back into its two class names.
class2_df_data = [
	{
		"transform_1": names[0],
		"transform_2": names[1],
		"precision": precision,
		alpha: precision - base_precision,
		"training_time": seconds,
		"training_time_ratio": seconds / base_training_time,
	}
	for names, (precision, seconds) in ((key.split(", "), result) for key, result in class2_precisions.items())
]

# Best improvement first; class names are then replaced by their short labels.
class2_df = pd.DataFrame(class2_df_data).sort_values(alpha, ascending=False).reset_index(drop=True)
for column in ("transform_1", "transform_2"):
	class2_df[column] = class2_df[column].map(preprocess_labels.__getitem__)
class2_df

In [None]:
# Mean of every numeric column per (first, second) transform pair; alpha is rescaled
# to percent for display.
grouped = class2_df.groupby(["transform_1", "transform_2"]).mean()
grouped[alpha] *= 100

# First transform on the rows, second transform on the columns.
pivot = grouped[alpha].unstack()

# Symmetric colour limits taken from the plotted alpha values only. Using min/max
# over all of `grouped` would also take precision, training_time and
# training_time_ratio into account, and their magnitudes are unrelated to alpha.
vabs = max(abs(pivot.min().min()), abs(pivot.max().max()))

plt.figure(figsize=(8, 6))
sns.heatmap(
	pivot, annot=True, fmt=".1f", cmap="RdYlGn", cbar_kws={"label": alpha, "format": "%.0f%%"}, vmax=vabs, vmin=-vabs,
)
plt.title("Model Precision by Transform Combinations")
plt.ylabel("Transform 1")
plt.xlabel("Transform 2")
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.show()

## Class 3 Models: Images with three pre-processes

In [None]:
# (precision, training time) of every three-step pipeline, keyed by the class names
# of its steps joined with "→".
class3_precisions = {}

# A fresh permutations() generator is used instead of the shared
# preprocess_combinations[3] object: that object may be a one-shot iterator that an
# earlier cell has already consumed, in which case this loop would train nothing.
for idx, combination in enumerate(permutations(preprocesses, 3)):
	(
		class3_train_loader,
		class3_test_loader,
		class3_validation_loader,
	) = get_model_data(
		combination,
		training_ratio=training_ratio,
		testing_ratio=testing_ratio,
		validation_ratio=validation_ratio,
		seed=seed,
	)

	curr_precision, curr_confusion_matrix, curr_training_time = evaluate_model(
		device, class3_train_loader, class3_test_loader, class3_validation_loader, verbose=False,
	)

	# Key of this pipeline; the separator must match the one the results table splits on.
	uuid = "→".join([str(t.__class__.__name__) for t in combination])

	class3_precisions[uuid] = (curr_precision, curr_training_time)

	# Signed change in precision relative to the base model.
	curr_precision_diff = curr_precision - base_precision

	print(f"\n\nClass 3.{idx + 1} {uuid} precision: {curr_precision * 100:.1f}%")
	print(f"That is an {'upgrade' if curr_precision_diff > 0 else 'downgrade'} of {curr_precision_diff * 100:.1f}%.")

In [None]:
# One record per three-step pipeline; the key is split back into its three class names.
class3_df_data = [
	{
		"transform_1": names[0],
		"transform_2": names[1],
		"transform_3": names[2],
		"precision": precision,
		alpha: precision - base_precision,
		"training_time": seconds,
		"training_time_ratio": seconds / base_training_time,
	}
	for names, (precision, seconds) in ((key.split("→"), result) for key, result in class3_precisions.items())
]

# Best improvement first; class names are then replaced by their short labels.
class3_df = pd.DataFrame(class3_df_data).sort_values(alpha, ascending=False).reset_index(drop=True)
for column in ("transform_1", "transform_2", "transform_3"):
	class3_df[column] = class3_df[column].map(preprocess_labels.__getitem__)
class3_df

In [None]:
# Column order shared by the three panels: the short labels in definition order.
order = preprocess_labels.values()

# Mean alpha of the class 3 models grouped by the transform found at each pipeline
# position, transposed so that every table is a single row.
pivot_t1, pivot_t2, pivot_t3 = (
	class3_df.pivot_table(values=alpha, index=f"transform_{position}", aggfunc="mean").reindex(order).T
	for position in (1, 2, 3)
)

plt.figure(figsize=(18, 6))

# One single-row heatmap per pipeline position, side by side.
for position, position_means in enumerate((pivot_t1, pivot_t2, pivot_t3), start=1):
	plt.subplot(1, 3, position)
	sns.heatmap(position_means, annot=True, cmap="RdYlGn", cbar=False)
	plt.title(f"Impact of Transform {position} on Alpha")
	plt.xlabel(f"Transform {position}")
	plt.yticks([])  # the single row needs no y tick label

plt.tight_layout()
plt.show()

# Additional Heatmaps for Combined Transforms

# Heatmap for Transform 1 and Transform 2
# plt.figure(figsize=(12, 6))
# pivot_t1_t2 = class3_df.pivot_table(values=alpha, index='transform_1', columns='transform_2', aggfunc='mean').reindex(index=order, columns=order)
# sns.heatmap(pivot_t1_t2, annot=True, cmap='RdYlGn', linewidths=0.5, linecolor='gray')
# plt.title('Impact of Transform 1 and Transform 2 on Alpha')
# plt.xlabel('Transform 2')
# plt.ylabel('Transform 1')
# plt.show()

# # Heatmap for Transform 1 and Transform 3
# plt.figure(figsize=(12, 6))
# pivot_t1_t3 = class3_df.pivot_table(values=alpha, index='transform_1', columns='transform_3', aggfunc='mean').reindex(index=order, columns=order)
# sns.heatmap(pivot_t1_t3, annot=True, cmap='RdYlGn', linewidths=0.5, linecolor='gray')
# plt.title('Impact of Transform 1 and Transform 3 on Alpha')
# plt.xlabel('Transform 3')
# plt.ylabel('Transform 1')
# plt.show()

# # Heatmap for Transform 2 and Transform 3
# plt.figure(figsize=(12, 6))
# pivot_t2_t3 = class3_df.pivot_table(values=alpha, index='transform_2', columns='transform_3', aggfunc='mean').reindex(index=order, columns=order)
# sns.heatmap(pivot_t2_t3, annot=True, cmap='RdYlGn', linewidths=0.5, linecolor='gray')
# plt.title('Impact of Transform 2 and Transform 3 on Alpha')
# plt.xlabel('Transform 3')
# plt.ylabel('Transform 2')
# plt.show()

## Class 4 Models: Images with four pre-processes

In [None]:
# (precision, training time) of every four-step pipeline, keyed by the class names
# of its steps joined with "➔".
class4_precisions = {}

# A fresh permutations() generator is used instead of the shared
# preprocess_combinations[4] object: that object may be a one-shot iterator that an
# earlier cell has already consumed, in which case this loop would train nothing.
for idx, combination in enumerate(permutations(preprocesses, 4)):
	(
		class4_train_loader,
		class4_test_loader,
		class4_validation_loader,
	) = get_model_data(
		combination,
		training_ratio=training_ratio,
		testing_ratio=testing_ratio,
		validation_ratio=validation_ratio,
		seed=seed,
	)

	curr_precision, curr_confusion_matrix, curr_training_time = evaluate_model(
		device, class4_train_loader, class4_test_loader, class4_validation_loader, verbose=False,
	)

	# Key of this pipeline; the separator must match the one the results table splits on.
	uuid = "➔".join([str(t.__class__.__name__) for t in combination])

	class4_precisions[uuid] = (curr_precision, curr_training_time)

	# Signed change in precision relative to the base model.
	curr_precision_diff = curr_precision - base_precision

	print(f"\n\nClass 4.{idx + 1} {uuid} precision: {curr_precision * 100:.1f}%")
	print(f"That is an {'upgrade' if curr_precision_diff > 0 else 'downgrade'} of {curr_precision_diff * 100:.1f}%.")

In [None]:
# One record per four-step pipeline; the key is split back into its four class names.
class4_df_data = [
	{
		"transform_1": names[0],
		"transform_2": names[1],
		"transform_3": names[2],
		"transform_4": names[3],
		"precision": precision,
		alpha: precision - base_precision,
		"training_time": seconds,
		"training_time_ratio": seconds / base_training_time,
	}
	for names, (precision, seconds) in ((key.split("➔"), result) for key, result in class4_precisions.items())
]

# Best improvement first; class names are then replaced by their short labels.
class4_df = pd.DataFrame(class4_df_data).sort_values(alpha, ascending=False).reset_index(drop=True)
for column in ("transform_1", "transform_2", "transform_3", "transform_4"):
	class4_df[column] = class4_df[column].map(preprocess_labels.__getitem__)
class4_df

In [None]:
# Stack the result tables of all four classes and keep only the columns needed for
# the ranking. The explicit copy makes the frame independent of the concat result
# before its columns are updated below.
analysis_df = (
	pd.concat([class1_df, class2_df, class3_df, class4_df], axis=0)
	.reset_index(drop=True)
	.loc[:, ["transform_1", "transform_2", "transform_3", "transform_4", alpha, "training_time_ratio"]]
	.copy()
)

# Shorter pipelines have no transform at the later positions; mark those with "-".
# Only the transform columns are filled, so a missing numeric value stays NaN
# instead of turning a numeric column into strings.
position_columns = ["transform_1", "transform_2", "transform_3", "transform_4"]
analysis_df[position_columns] = analysis_df[position_columns].fillna("-")

# alpha in percent, and alpha weighted by how much longer training took than for the
# base model.
analysis_df[alpha] *= 100
analysis_df[f"w{alpha}"] = analysis_df[alpha] / analysis_df["training_time_ratio"]

# Sort first and truncate afterwards, so the display holds the 100 best rows rather
# than the first 100 rows in concatenation order.
analysis_df.sort_values(f"w{alpha}", ascending=False).head(100)

In [None]:
IMAGE_PATH = "./dataset/diseased/ISIC_0024319.jpg"
N_COLS = 3

base_pil = Image.open(IMAGE_PATH).convert("RGB")
base_tensor = TF.to_tensor(base_pil)

# Short label -> transform instance, to rebuild a pipeline from a results row.
preprocess_map = {preprocess_labels[proc.__class__.__name__]: proc for proc in preprocesses}
transform_columns = ["transform_1", "transform_2", "transform_3", "transform_4"]

# The untouched image comes first, followed by the five pipelines with the highest alpha.
images, captions = [base_pil], ["Original"]

top_pipelines = analysis_df.sort_values(alpha, ascending=False).head(5)[transform_columns]
for names in top_pipelines.itertuples(index=False, name=None):
	# The "-" placeholder of shorter pipelines is not a key of preprocess_map and is skipped.
	perm = [preprocess_map[name] for name in names if name in preprocess_map]

	img_t = base_tensor
	for proc in perm:
		img_t = proc(img_t)
	images.append(TF.to_pil_image(img_t))

	labels = [preprocess_labels[p.__class__.__name__] for p in perm]
	captions.append(" → ".join(labels) if perm else "Original")

n_images = len(images)
n_rows = math.ceil(n_images / N_COLS)

fig, axes = plt.subplots(n_rows, N_COLS, figsize=(4 * N_COLS, 4 * n_rows))
fig.patch.set_alpha(0)
axes = axes.flatten()

# Fill the grid in order; cells beyond the last image are only switched off.
for position, ax in enumerate(axes):
	if position < n_images:
		ax.imshow(images[position])
		ax.text(0.5, -0.1, captions[position], size=20, ha="center", transform=ax.transAxes)
	ax.axis("off")

plt.tight_layout()
plt.savefig("top5.png", transparent=True)
plt.show()

In [None]:
# Whole-session checkpointing with dill: the dump line is kept commented out and the
# load line restores a previously saved session.
# NOTE(review): the dump line calls datetime.now(), but the import cell has no
# `datetime` import - add `from datetime import datetime` before re-enabling it.
# dill.dump_session(datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_globalsave.pkl")
# SECURITY: load_session unpickles the file, and unpickling can execute arbitrary
# code - only load checkpoint files that you created yourself.
# NOTE(review): loading presumably replaces same-named variables computed by the
# cells above with the values stored in the checkpoint - confirm that running this
# cell last is intended.
dill.load_session("checkpoints/2025-04-20_20-31-28_globalsave.pkl")