Aquí veremos qué pasa si calculamos la correlación por batches y cuánto cambia el resultado.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy.signal import convolve2d
from scipy.stats import pearsonr
from ast import literal_eval
from tqdm.notebook import tqdm

import pydicom as dcm
from pydicom.pixel_data_handlers.util import apply_voi_lut

import torch
from torch.nn.functional import interpolate
from torchvision.transforms import v2 as transforms

from FindClf import ImageIO
from FindClf.Models import create_efficientNetV2
from FindClf.DetectorOps import Detector
from CorrRELAX.Transforms import get_transforms
from CorrRELAX.WindowOps import get_windows
from CorrRELAX.Algorithm import CorrRelax

In [None]:
## PARAMETERS
# Class names; assumed to follow the order of the classifier outputs (TODO confirm).
LABEL_NAMES = [
    "No Finding",
    "Mass",
    "Suspicious Calcification",
    "Asymmetries",
    "Architectural Distortion",
    "Suspicious Lymph Node",
    "Skin Thickening",
    "Retractions",
]

# The masks (torch.rand) and the plot jitter (np.random.uniform) are random, so seed
# both generators once here to make Restart & Run All reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

batch_size = 32  # masks evaluated per forward pass
window_size = 256  # side length of the square windows
window_shape = (window_size, window_size)
stride = 32  # step passed to get_windows
mask_size = 8  # side length of the low-resolution random masks
mask_shape = (mask_size, mask_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Abrir la imagen de evaluación, buscando sus datos dentro de la tabla de anotaciones

In [None]:
# Read the annotation table and keep only the rows of the evaluation image.
df = pd.read_csv("finding_annotations_V2.csv")
image_id = "531bba59b58ee255662e46898934195e"

rows_by_image = df.groupby("image_id")
sample = rows_by_image.get_group(image_id)

# The study id is needed to build the path of the DICOM file.
study_id = sample["study_id"].to_numpy()[0]
sample.head()

In [None]:
# Build the DICOM path and fail early with a clear message if it is missing.
impath = os.path.join("path/to/dataset", study_id, image_id + ".dicom")
if not os.path.exists(impath):
    raise FileNotFoundError(f"Imagen {impath} no encontrada")

fig, ax = plt.subplots(ncols=2)

# Left panel: the image as returned by the loader.
raw_image = ImageIO.load_dicom(impath)
ax[0].imshow(raw_image, cmap="gray")
ax[0].set_title(f"Original ({raw_image.dtype})")

# Right panel: the CLAHE fusion of the 8-bit version; this is the `image` used downstream.
raw_image_u8 = (raw_image * 255).astype(np.uint8)
image = ImageIO.clahefusion(raw_image_u8, thresholds=[1.0, 2.0])
ax[1].imshow(image)
ax[1].set_title(f"CLAHE ({image.dtype})")


aplicamos el detector

In [None]:
# Build the ROI detector from its checkpoint file and get the bounding box of the image.
detector = Detector("ROIdetector/roidet_mammo.pth", device=device)
xmin, ymin, xmax, ymax = detector.get_roi(image)

# Draw the detected box on top of the image.
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
rect = Rectangle(
    (xmin, ymin), xmax - xmin, ymax - ymin, linewidth=1, edgecolor="r", facecolor="none"
)
ax.add_patch(rect)
ax.set_title("ROI Detection")

In [None]:
# Keep only the pixels inside the detected bounding box.
roi_rows = slice(ymin, ymax)
roi_cols = slice(xmin, xmax)
cropped = image[roi_rows, roi_cols]

fig, ax = plt.subplots()
ax.imshow(cropped, cmap="gray")
ax.set_title("Cropped Image")

Cargamos el clasificador

In [None]:
# Previous checkpoint, kept for reference: 'models/findclfV4.1_EfficientNetV2.pth'
modelpath = "checkpoints/20241106_214840/EfficientNetV2_final.pth"

# One output per entry in LABEL_NAMES.
model = create_efficientNetV2(len(LABEL_NAMES))

# The checkpoint is a dict; the weights live under "model_state_dict".
checkpoint = torch.load(modelpath, weights_only=True, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Inference only: move to the target device and switch to eval mode.
model = model.to(device).eval()

print("Modelo cargado")

In [None]:
# Preprocessing applied to every window before it is fed to the model.
# NOTE(review): this rebinds `transforms`, which the import cell binds to
# torchvision.transforms.v2; re-running the import cell afterwards breaks the
# `transforms(window_tensor)` calls below until this cell is run again.
transforms = get_transforms(window_shape)

Dividimos la imagen en ventanas

In [None]:
# Split the cropped image into windows; the second return value holds the two offset sequences.
windows, offsets = get_windows(cropped, window_size, stride)
dx, dy = offsets
ndx = len(dx)
ndy = len(dy)
print(f"Tenemos en total {len(windows)} ventanas ({ndx}x{ndy})")


In [None]:
# Window used for the single-window experiments below. The index is fixed;
# swap it for np.random.randint(len(windows)) to inspect a random window instead.
n = 2550
window = windows[n]

fig, ax = plt.subplots()
ax.imshow(window)

Vamos a construir CorRELAX desde cero, porque necesitamos obtener todas las distancias y ver cuánto cambia el resultado si tomamos promedios de submuestras.

In [None]:
# Random mask generator used by the CorRELAX experiments
def make_masks(
    n_iters, mask_shape, p=0.5, *, n_per_batch=None, out_shape=None, on_device=None
):
    """Yield batches of random masks upsampled to the window size.

    Each batch starts as a binary tensor of shape ``(n_per_batch, 1, *mask_shape)``
    whose cells are 1 where a uniform sample exceeds ``p`` and 0 otherwise. It is
    then resized to ``out_shape`` with bilinear interpolation, so the yielded
    values lie in [0, 1].

    Args:
        n_iters: Number of batches to yield.
        mask_shape: (height, width) of the low-resolution binary mask.
        p: Threshold on the uniform samples; cells with a sample <= p are zeroed.
        n_per_batch: Masks per batch. Defaults to the notebook-level ``batch_size``.
        out_shape: (height, width) of the yielded masks. Defaults to the
            notebook-level ``window_shape``.
        on_device: Device the masks are created on. Defaults to the
            notebook-level ``device``.

    Yields:
        Float tensor of shape ``(n_per_batch, 1, *out_shape)``.
    """
    for _ in range(n_iters):
        # Resolve the defaults on every iteration so that, as before, changes made to
        # the notebook-level values while the generator is alive are picked up.
        count = batch_size if n_per_batch is None else n_per_batch
        size = window_shape if out_shape is None else out_shape
        target = device if on_device is None else on_device

        masks = (torch.rand(count, 1, *mask_shape, device=target) > p).float()
        interp = interpolate(masks, size=size, mode="bilinear", align_corners=False)
        yield interp


In [None]:
# Total mask budget and the number of batches needed to cover it.
n_masks = 2560
n_iters = n_masks // batch_size
print(f"Generando {n_masks} mascaras de {mask_shape} en {n_iters} iteraciones")

# Accumulators: per-mask values for the whole population ...
full_results = {"distVect": [], "predVect": [], "prediction": []}
# ... and one correlation value (or vector) per batch.
batch_results = {"corrDist": [], "corrPred": []}

# NOTE: this is the cosine *similarity* along dim=1, although the notebook
# refers to it as a distance.
distance_fn = torch.nn.CosineSimilarity(dim=1)

In [None]:
batch_size = 32  # masks per forward pass; edit this to compare different batch sizes
# Recompute the iteration count here so it cannot go stale when batch_size is edited above.
# If batch_size does not divide n_masks, fewer than n_masks masks are evaluated.
n_iters = n_masks // batch_size

# Accumulate into fresh local lists so this cell can be re-run on its own: the result
# dicts end up holding numpy arrays, which cannot be appended to on a second run.
dist_chunks, pred_chunks, prediction_chunks = [], [], []
batch_corr_dist, batch_corr_pred = [], []

with torch.no_grad():
    # Preprocess the window
    window_tensor = torch.from_numpy(window).to(device)
    window_tensor = transforms(window_tensor)

    ## CorRELAX
    # Evaluate the unmasked image with the model
    model_output = model(window_tensor)
    image_pred = torch.sigmoid(model_output["Classifier"])  # prediction
    image_vect = model_output["Features"]  # internal features

    # Evaluate the masked versions of the image, one batch of masks at a time
    for masks in tqdm(make_masks(n_iters, mask_shape), total=n_iters):
        # Apply the masks
        masked_window = window_tensor * masks
        masked_output = model(masked_window)
        # Prediction and internal features for each mask in the batch
        masked_pred = torch.sigmoid(masked_output["Classifier"])
        masked_vect = masked_output["Features"]

        # Cosine similarity between the original and the masked features
        dist_vectors = distance_fn(image_vect, masked_vect).squeeze()
        dist_chunks.append(dist_vectors)

        # Cosine similarity between the original and the masked predictions
        pred_vectors = distance_fn(image_pred, masked_pred).squeeze()
        pred_chunks.append(pred_vectors)

        # Keep the per-mask predictions
        prediction_chunks.append(masked_pred)

        # -- Correlation computed inside each batch
        corr_dist = torch.corrcoef(torch.stack([dist_vectors, pred_vectors], dim=0))[
            0, 1
        ]
        # Row 0 is the feature similarity; the other rows are the class predictions,
        # so [0, 1:] holds the correlation of the similarity with each class.
        corr_pred = torch.corrcoef(
            torch.cat([dist_vectors.unsqueeze(1), masked_pred], dim=1).T
        )[0, 1:]

        batch_corr_dist.append(corr_dist.detach().item())
        batch_corr_pred.append(corr_pred.detach().tolist())

# Concatenate the per-batch chunks into population-level arrays
full_results = {
    "distVect": torch.cat(dist_chunks, dim=0).detach().cpu().numpy(),
    "predVect": torch.cat(pred_chunks, dim=0).detach().cpu().numpy(),
    "prediction": torch.cat(prediction_chunks, dim=0).detach().cpu().numpy(),
}

# NaN correlations are replaced by 0 before computing summary statistics.
batch_results = {
    "corrDist": np.nan_to_num(np.array(batch_corr_dist)),
    "corrPred": np.nan_to_num(np.array(batch_corr_pred)),
}

# Still pending: correlate distances and predictions to evaluate feature importance.

# print(full_results['distVect'].shape, full_results['predVect'].shape, full_results['prediction'].shape)

In [None]:
# Shapes of the per-batch correlation arrays.
tuple(batch_results[key].shape for key in ("corrDist", "corrPred"))

In [None]:
# Number of histogram bins from Sturges' rule (1 + log2(n)). It is computed inline
# because `k` is only assigned in the following cell, so a bare `k` here raises
# NameError on a fresh kernel.
int(1 + np.log2(len(batch_results["corrDist"])))

In [None]:
# Bin count from Sturges' rule, based on the number of batches.
n_batches = len(batch_results["corrDist"])
k = int(1 + np.log2(n_batches))

# Left: histogram of the per-batch CorrDist. Right: per-batch CorrPred as a matrix.
fig, ax = plt.subplot_mosaic([["meandist", "meanpred"]])
hist_ax = ax["meandist"]
matrix_ax = ax["meanpred"]
hist_ax.hist(batch_results["corrDist"], bins=3 * k, range=(0, 1))
matrix_ax.matshow(batch_results["corrPred"], cmap="seismic", vmin=-1, vmax=1)

In [None]:
# Mean and standard deviation of the per-batch correlations.
corr_dist_mean = batch_results["corrDist"].mean()
corr_dist_std = batch_results["corrDist"].std()
print(f"Media = {corr_dist_mean:.4f} ± {corr_dist_std:.4f}")

# Per-class statistics, computed once instead of once per printed line.
corr_pred_mean = batch_results["corrPred"].mean(axis=0)
corr_pred_std = batch_results["corrPred"].std(axis=0)
for i, name in enumerate(LABEL_NAMES):
    print(f"CorrPred({name:^24}) = {corr_pred_mean[i]:.4f} ± {corr_pred_std[i]:.4f}")

Correlación de toda la población

In [None]:
## Correlations over the whole population of masks

# Correlation between the feature similarity and the prediction similarity
full_corr_dist = pearsonr(full_results["distVect"], full_results["predVect"]).statistic

# Correlation between the feature similarity and each class prediction of the masks;
# the new axis lets the similarity vector broadcast against the prediction columns.
full_corr_pred = pearsonr(
    full_results["distVect"][..., np.newaxis], full_results["prediction"]
).statistic

# Report the number of masks actually used instead of a hard-coded 2560.
n_population = len(full_results["distVect"])
print(f"{n_population} mascaras\nCorrDist = {full_corr_dist:.4f}")
for i, name in enumerate(LABEL_NAMES):
    print(f"CorrPred [{name:^24}] = {full_corr_pred[i]:.4f}")

Ahora, ¿qué pasa si procesamos con batches más grandes o más pequeños? ¿Cuánto cambia el resultado?



| Batches  | CorrDist        | CorrPred NFind  | CorrPred Asymm  | Tiempo (it/s) | %datos/s |
|---------:|:---------------:|:---------------:|:---------------:|:-------------:|:-------:|
| Full     | 0.7725          | 0.9224          | 0.0539          | -             | -       |
| 96       | 0.7867 ± 0.0302 | 0.9238 ± 0.0130 | 0.0854 ± 0.0752 |  6.14         | 23.025  |
| 64       | 0.7879 ± 0.0412 | 0.9255 ± 0.0184 | 0.0638 ± 0.1327 |  8.64         | 21.600  |
| 48       | 0.7922 ± 0.0410 | 0.9281 ± 0.0152 | 0.0957 ± 0.1166 | 11.25         | 21.093  |
| 36       | 0.8026 ± 0.0483 | 0.9237 ± 0.0203 | 0.0676 ± 0.1433 | 15.48         | 21.768  |
| 32       | 0.7956 ± 0.0576 | 0.9267 ± 0.0225 | 0.0570 ± 0.1702 | 16.93         | 21.162  |
| 24       | 0.8060 ± 0.0567 | 0.9292 ± 0.0227 | 0.0847 ± 0.1917 | 22.16         | 20.775  |
| 16       | 0.8111 ± 0.0709 | 0.9272 ± 0.0343 | 0.0956 ± 0.2503 | 33.12         | 20.700  |
| 12       | 0.8277 ± 0.0860 | 0.9284 ± 0.0378 | 0.1106 ± 0.2750 | 44.74         | 20.972  |

In [None]:
batch_size = 96  # masks per forward pass; edit this to compare different batch sizes
n_iters = n_masks // batch_size
# Report the number of masks that will really be generated: when batch_size does not
# divide n_masks (e.g. 96), floor division yields fewer than n_masks masks.
print(
    f"Generando {n_iters * batch_size} mascaras de {mask_shape} en {n_iters} iteraciones"
)
batch_results = {name: [] for name in ["corrDist", "corrPred"]}

In [None]:
# Accumulate into fresh local lists so this cell can be re-run on its own: batch_results
# ends up holding numpy arrays, which cannot be appended to on a second run.
batch_corr_dist, batch_corr_pred = [], []

with torch.no_grad():
    # Preprocess the window
    window_tensor = torch.from_numpy(window).to(device)
    window_tensor = transforms(window_tensor)

    ## CorRELAX
    # Evaluate the unmasked image with the model
    model_output = model(window_tensor)
    image_pred = torch.sigmoid(model_output["Classifier"])  # prediction
    image_vect = model_output["Features"]  # internal features

    # Evaluate the masked versions of the image, one batch of masks at a time
    for masks in tqdm(make_masks(n_iters, mask_shape), total=n_iters):
        # Apply the masks
        masked_window = window_tensor * masks
        masked_output = model(masked_window)
        # Prediction and internal features for each mask in the batch
        masked_pred = torch.sigmoid(masked_output["Classifier"])
        masked_vect = masked_output["Features"]

        # Cosine similarity between the original and the masked features
        dist_vectors = distance_fn(image_vect, masked_vect).squeeze()

        # Cosine similarity between the original and the masked predictions
        pred_vectors = distance_fn(image_pred, masked_pred).squeeze()

        # -- Correlation computed inside each batch
        corr_dist = torch.corrcoef(torch.stack([dist_vectors, pred_vectors], dim=0))[
            0, 1
        ]
        # Row 0 is the feature similarity; the other rows are the class predictions,
        # so [0, 1:] holds the correlation of the similarity with each class.
        corr_pred = torch.corrcoef(
            torch.cat([dist_vectors.unsqueeze(1), masked_pred], dim=1).T
        )[0, 1:]

        batch_corr_dist.append(corr_dist.detach().item())
        batch_corr_pred.append(corr_pred.detach().tolist())

# NaN correlations are replaced by 0 before computing summary statistics.
batch_results = {
    "corrDist": np.nan_to_num(np.array(batch_corr_dist)),
    "corrPred": np.nan_to_num(np.array(batch_corr_pred)),
}

In [None]:
print(f"Prueba Batch {batch_size}")

# Mean and standard deviation of the per-batch correlations.
corr_dist_mean = batch_results["corrDist"].mean()
corr_dist_std = batch_results["corrDist"].std()
print(f"Media = {corr_dist_mean:.4f} ± {corr_dist_std:.4f}")

# Per-class statistics, computed once instead of once per printed line.
corr_pred_mean = batch_results["corrPred"].mean(axis=0)
corr_pred_std = batch_results["corrPred"].std(axis=0)
for i, name in enumerate(LABEL_NAMES):
    print(f"CorrPred({name:^24}) = {corr_pred_mean[i]:.4f} ± {corr_pred_std[i]:.4f}")

# Bin count from Sturges' rule, based on the number of batches.
n_batches = len(batch_results["corrDist"])
k = int(1 + np.log2(n_batches))

# Left: histogram of the per-batch CorrDist. Right: per-batch CorrPred as a matrix.
fig, ax = plt.subplot_mosaic([["meandist", "meanpred"]])
hist_ax = ax["meandist"]
matrix_ax = ax["meanpred"]
hist_ax.hist(batch_results["corrDist"], bins=3 * k, range=(0, 1))
matrix_ax.matshow(batch_results["corrPred"], cmap="seismic", vmin=-1, vmax=1)

In [None]:
# Growing prefixes of the mask population: 128, 256, ..., 2560 (the leading 0 is skipped).
subsamples = np.linspace(0, 2560, 21, dtype=int)
all_corrdists, all_corrpreds = [], []

for n_subsamples in subsamples[1:]:
    # Take the first n_subsamples values. They come from `full_results`, the population
    # computed above; `results` does not exist yet at this point of the notebook and
    # the dict later bound to that name has no "prediction" entry.
    sub_distvect = full_results["distVect"][:n_subsamples]
    sub_predvect = full_results["predVect"][:n_subsamples]
    sub_prediction = full_results["prediction"][:n_subsamples]

    # Correlation between the feature similarity and the prediction similarity
    corr_dist = pearsonr(sub_distvect, sub_predvect).statistic
    # Correlation between the feature similarity and each class prediction of the masks
    corr_pred = pearsonr(sub_distvect[..., np.newaxis], sub_prediction).statistic

    # Store the results
    all_corrdists.append(corr_dist)
    all_corrpreds.append(corr_pred)

In [None]:
# Absolute gap between each subsample correlation and the full-population value.
corrdist_gap = np.abs(np.asarray(all_corrdists) - full_corr_dist)
plt.plot(subsamples[1:], corrdist_gap, label="CorrDist")

Ahora vamos a evaluar todas las ventanas, calculando solo la correlación entre distancias (CorrDist).
Esto nos permitirá tener al menos una idea de cómo se comporta en promedio para una imagen.
Finalmente, si combinamos los resultados de múltiples imágenes, podríamos recrear la figura que hice antes.

In [None]:
# Subsample sizes (0, 128, ..., 2560) and the per-window list of correlation curves.
subsamples = np.linspace(start=0, stop=2560, num=21, dtype=int)
windows_corrdists = list()

In [None]:
# Forward passes needed to produce at least n_masks masks per window. Ceil division is
# required because batch_size may not divide n_masks (e.g. 96 -> 26 * 96 = 2496 < 2560);
# with floor division the largest subsamples silently use fewer masks than their label.
n_iters_full = -(-n_masks // batch_size)

for window in tqdm(windows):
    # Skip empty (all-zero) windows
    if np.sum(window) == 0:
        continue

    results = {name: [] for name in ["distVect", "predVect"]}
    with torch.no_grad():
        # Preprocess the window
        window_tensor = torch.from_numpy(window).to(device)
        window_tensor = transforms(window_tensor)

        model_output = model(window_tensor)
        image_pred = torch.sigmoid(model_output["Classifier"])  # prediction
        image_vect = model_output["Features"]  # internal features

        for masks in tqdm(
            make_masks(n_iters_full, mask_shape), total=n_iters_full, leave=False
        ):
            # Apply the masks
            masked_window = window_tensor * masks
            masked_output = model(masked_window)
            # Prediction and internal features for each mask in the batch
            masked_pred = torch.sigmoid(masked_output["Classifier"])
            masked_vect = masked_output["Features"]

            # Cosine similarity between the original and the masked features
            dist_vectors = distance_fn(image_vect, masked_vect).squeeze()
            results["distVect"].append(dist_vectors)

            # Cosine similarity between the original and the masked predictions
            pred_vectors = distance_fn(image_pred, masked_pred).squeeze()
            results["predVect"].append(pred_vectors)

        # Concatenate and trim to exactly n_masks values (the last batch may overshoot)
        results["distVect"] = (
            torch.cat(results["distVect"], dim=0)[:n_masks].detach().cpu().numpy()
        )
        results["predVect"] = (
            torch.cat(results["predVect"], dim=0)[:n_masks].detach().cpu().numpy()
        )

    # Correlation between feature similarity and prediction similarity for each
    # subsample size. The last entry uses every mask and serves as the reference later,
    # so the notebook-level `full_corr_dist` of the single-window experiment is left
    # untouched here.
    corr_dists = []
    for n_subsamples in subsamples[1:]:
        # Take the first n_subsamples values
        sub_distvect = results["distVect"][:n_subsamples]
        sub_predvect = results["predVect"][:n_subsamples]
        corr_dist = pearsonr(sub_distvect, sub_predvect).statistic
        corr_dists.append(corr_dist)

    windows_corrdists.append(corr_dists)

In [None]:
# Stack the per-window lists into a 2-D array: one row per processed window,
# one column per subsample size.
windows_corrdists = np.array(windows_corrdists)

plot

In [None]:
logscale = True

# The last column (largest subsample) is the reference of each window.
reference = windows_corrdists[:, -1][..., np.newaxis]
signed_diffs = windows_corrdists - reference
# A log axis cannot show negative values, so use absolute differences in that case.
diffs = np.abs(signed_diffs) if logscale else signed_diffs

boxwidth = 80
rng_width = 0.45 * boxwidth  # horizontal jitter range of the scatter points

In [None]:
fig, ax = plt.subplots(figsize=(15, 9))

# NaNs show up in the differences; drop them column by column before the boxplot.
nanmask = ~np.isnan(diffs)
filtered_diffs = [d[m] for d, m in zip(diffs.T, nanmask.T)]

bp = ax.boxplot(
    filtered_diffs,
    positions=subsamples[1:],
    showfliers=False,
    widths=boxwidth,
    boxprops=dict(color="darkblue", linewidth=1.5),
    whiskerprops=dict(color="darkblue", linewidth=1.5),
    medianprops=dict(color="crimson", linewidth=2),
    capprops=dict(color="darkblue", linewidth=2),
    zorder=1,
)

# Individual windows as jittered points behind each box.
for diff, subsamp in zip(diffs.T, subsamples[1:]):
    ax.scatter(
        np.random.uniform(subsamp - rng_width, subsamp + rng_width, size=len(diff)),
        diff,
        alpha=0.25,
        marker="x",
        s=1,
        color="slategray",
        zorder=-1,
    )

# Titles and labels
# English alternative: 'Correlation difference within image at different mask samples'
ax.set_title(
    "Diferencia de correlación dentro de imagen a diferentes muestras de mascaras",
    fontsize=22,
    fontname="Swis721 BT",
    wrap=True,
)

# English alternative: 'Number of masks'
ax.set_xlabel("Número de mascaras", fontsize=16, fontname="Swis721 BT")
ax.set_xticklabels(
    subsamples[1:], rotation=45, ha="center", fontsize=12
)  # rotate the x tick labels
# English alternative: 'Correlation difference with full set'
ax.set_ylabel(
    "Diferencia de correlación con el conjunto completo",
    fontsize=14,
    fontname="Swis721 BT",
    wrap=True,
)
ax.set_ylim(-0.2, 0.2)

# Reference lines (only the positive side is drawn)
ax.axhline(
    0.05,
    color="coral",
    linestyle="--",
    linewidth=2.5,
    alpha=0.75,
    zorder=0,
    label=r"$\pm 0.05$",
)
# ax.axhline(-0.05, color='coral', linestyle='--', alpha=0.75, zorder=0)

ax.axhline(
    0.025,
    color="firebrick",
    linestyle="--",
    linewidth=2.5,
    alpha=0.75,
    zorder=0,
    label=r"$\pm 0.025$",
)
# ax.axhline(-0.025, color='firebrick', linestyle='--', alpha=0.75, zorder=0)

# Log-scale variant: overrides the y label and limits set above.
if logscale:
    # English alternative: 'Correlation difference with full set (Log Scale)'
    ax.set_ylabel(
        "Diferencia de correlación con el conjunto completo (Escala Logarítmica)",
        fontsize=16,
        fontname="Swis721 BT",
        wrap=True,
    )
    ax.tick_params(axis="y", labelsize=12, labelfontfamily="Swis721 BT")
    ax.set_yscale("log")
    ax.set_ylim(0.0005, 1.001)

ax.legend()
# fig.savefig('06-LogCorrDiffs.jpg', dpi=300, bbox_inches='tight')
# Create the output folder first; savefig fails if the directory is missing.
os.makedirs("plots_tesis", exist_ok=True)
fig.savefig("plots_tesis/LogCorrDiffs.pdf", dpi=300, bbox_inches="tight")

In [None]:
# Mean absolute difference to the reference (last column) across windows. Some windows
# produce NaN correlations, so use nanmean; a plain mean would turn every affected
# column into NaN and leave the curve blank.
abs_diffs = np.abs(windows_corrdists - windows_corrdists[:, -1][..., np.newaxis])
mean_diff = np.nanmean(abs_diffs, axis=0)
plt.plot(subsamples[1:], mean_diff, label="MeanDiff")

In [None]:
# Mean CorrDist across windows for each subsample size. NaN entries (windows whose
# correlation is undefined) are ignored instead of propagating into the mean.
plt.plot(
    subsamples[1:],
    np.nanmean(np.array(windows_corrdists), axis=0),
    label="CorrDist (Ventanas)",
)