## Interactive evaluation of classification model

This notebook checks where the existing models fail and tries to understand why.

In [None]:
import torch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from protein_classification.config import AlgorithmConfig, DataConfig, DataAugmentationConfig
from protein_classification.data import InMemoryDataset, ZarrDataset
from protein_classification.data.cellatlas import get_cellatlas_filepaths_and_labels
from protein_classification.data.preprocessing import ZarrPreprocessor
from protein_classification.data.utils import train_test_split, collate_test_time_crops
from protein_classification.model import BioStructClassifier
from protein_classification.utils.evaluation import compute_classification_metrics
from protein_classification.utils.io import load_config, load_checkpoint

# Select the "medium" internal precision for float32 matrix multiplications.
# Per the PyTorch docs this trades some numerical precision for speed on
# hardware that supports it.
torch.set_float32_matmul_precision('medium')

In [None]:
# Directory of the training run to evaluate. The algorithm config, the data
# config and the model checkpoint are all loaded from this location below.
CKPT_DIR = "/group/jug/federico/classification_training/2507/DenseNet121_5Cl_Mitochondria/4"

Get configs

In [None]:
# Rebuild the algorithm config stored in the checkpoint directory.
algo_config_dict = load_config(config_fpath=CKPT_DIR, config_type="algorithm")
algo_config = AlgorithmConfig(**algo_config_dict)
# Evaluate one sample at a time
algo_config.training_config.batch_size = 1

# Rebuild the data config stored in the same directory.
data_config_dict = load_config(config_fpath=CKPT_DIR, config_type="data")
data_config = DataConfig(**data_config_dict)

# Override the test-time augmentation: no transform, and crops of the same
# size as the ones used for training.
train_crop_size = data_config.train_augmentation_config.crop_size
data_config.test_augmentation_config = DataAugmentationConfig(
    transform=None,
    crop_size=train_crop_size,
    random_crop=True,
    strategy="background",
    metrics=["std"],
    bg_threshold=3.0,  # Default threshold for background crops
)

Get Data

In [None]:
# Collect the file paths and labels for the classes listed in the data config.
input_data, curr_labels = get_cellatlas_filepaths_and_labels(
    data_dir=data_config.data_dir,
    protein_labels=data_config.labels,
)

# Only the held-out part of the deterministic 90/10 split is evaluated here.
_train_input_data, test_input_data = train_test_split(
    input_data, train_ratio=0.9, deterministic=True
)

dataset_summary = (
    "--------------Dataset Info--------------",
    f"Number test samples: {len(test_input_data)}",
    f"Labels: {curr_labels}",
    "----------------------------------------\n",
)
print(*dataset_summary, sep="\n")


In [None]:
# Intensity handling is forwarded unchanged from the data config that was
# loaded from the checkpoint directory.
intensity_kwargs = dict(
    bit_depth=data_config.bit_depth,
    normalize=data_config.normalize,
    dataset_stats=data_config.dataset_stats,
)

# Test split with labels, using the test-time augmentation config set above.
test_dataset = InMemoryDataset(
    inputs=test_input_data,
    split="test",
    return_label=True,
    img_size=data_config.img_size,
    augmentation_config=data_config.test_augmentation_config,
    **intensity_kwargs,
)

In [None]:
# `collate_test_time_crops` is only used with the "overlap" cropping strategy;
# any other strategy keeps the DataLoader default (collate_fn=None).
uses_overlap_crops = data_config.test_augmentation_config.strategy == "overlap"
test_collate_fn = collate_test_time_crops if uses_overlap_crops else None

# Keep the dataset order (no shuffling) and keep every sample (no dropped batch).
test_dloader = DataLoader(
    test_dataset,
    batch_size=algo_config.training_config.batch_size,
    shuffle=False,
    num_workers=3,
    pin_memory=True,
    drop_last=False,
    collate_fn=test_collate_fn,
)

Setup the model

In [None]:
# Build the classifier from the stored algorithm config, then restore weights
# from the checkpoint requested with `best=True`.
model = BioStructClassifier(config=algo_config)
ckpt = load_checkpoint(ckpt_dir=CKPT_DIR, best=True)
checkpoint_weights = ckpt["state_dict"]
# strict=True: checkpoint keys and model keys must match exactly.
model.load_state_dict(checkpoint_weights, strict=True)

Get predictions

In [None]:
# The Trainer is only used to run inference (`trainer.predict`) in this notebook.
# accelerator="auto" lets Lightning pick whichever accelerator is available,
# so the notebook also runs on machines without a GPU instead of failing at
# Trainer construction; a GPU is still used when one is present.
trainer = Trainer(
    accelerator="auto",
    enable_progress_bar=True,
    precision=32,
)

In [None]:
# Run inference; each element of `outputs` is one batch, unpacked below as
# (predictions, probabilities, labels, inputs).
outputs = trainer.predict(model=model, dataloaders=test_dloader)

# Split the per-batch tuples into four parallel lists, one entry per batch.
preds = [batch_preds for batch_preds, _, _, _ in outputs]
probs = [batch_probs for _, batch_probs, _, _ in outputs]
labels = [batch_labels for _, _, batch_labels, _ in outputs]
inputs = [batch_inputs for _, _, _, batch_inputs in outputs]

Compute metrics

In [None]:
# Concatenate the per-batch results so the metrics see the whole test set.
all_preds = torch.cat(preds)
all_labels = torch.cat(labels)
all_probs = torch.cat(probs)

# Macro averaging over the classes returned by the data-loading cell.
metrics = compute_classification_metrics(
    preds=all_preds,
    gts=all_labels,
    probs=all_probs,
    num_classes=len(curr_labels),
    average="macro",
)

In [None]:
# (printed title, key in `metrics`) pairs, in display order.
metric_report = [
    ("Accuracy:", "accuracy"),
    ("F1 (macro):", "f1"),
    ("Precision:", "precision"),
    ("Recall:", "recall"),
    ("Confusion Matrix:\n", "confusion_matrix"),
]

print("\n------------------------------------------")
for title, metric_key in metric_report:
    print(title, metrics[metric_key])

### Debug: understand why the model is struggling in some cases

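Compare class 1 (Mitochondria) samples that are classified correctly with those misclassified as class 3 (ER), alongside correctly classified class 3 samples.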

In [None]:
# Class ids compared in this section (1: Mitochondria, 3: ER).
MITO_CLASS, ER_CLASS = 1, 3

# Flatten the per-batch outputs into one entry per sample with torch.cat, the
# same way the metrics cell does. Unlike torch.tensor(list_of_tensors) this does
# not require one-element batches, and the tensors are built only once.
# NOTE: the indices below address samples; they line up with the per-batch
# `inputs`/`probs` lists only because batch_size is 1 in this notebook.
labels_flat = torch.cat(labels).flatten()
preds_flat = torch.cat(preds).flatten()

is_mito = labels_flat == MITO_CLASS
is_er = labels_flat == ER_CLASS

# Mitochondria predicted as Mitochondria
right_idxs = torch.where(is_mito & (preds_flat == MITO_CLASS))[0]
# Mitochondria predicted as ER
wrong_idxs = torch.where(is_mito & (preds_flat == ER_CLASS))[0]
# ER predicted as ER
right_idxs_cl3 = torch.where(is_er & (preds_flat == ER_CLASS))[0]
len(right_idxs), len(wrong_idxs), len(right_idxs_cl3)

In [None]:
# Permute indexes to visualize the samples
right_idxs = right_idxs[torch.randperm(len(right_idxs))]
wrong_idxs = wrong_idxs[torch.randperm(len(wrong_idxs))]
right_idxs_cl3 = right_idxs_cl3[torch.randperm(len(right_idxs_cl3))]

In [None]:
import matplotlib.pyplot as plt

# Number of samples shown per row.
N_SAMPLES_PER_ROW = 6


def _show_samples(row_axes, sample_idxs):
    """Draw one row of samples, each annotated with its predicted probabilities.

    Reads the notebook-level `inputs` and `probs` lists filled by the
    prediction cell.

    Args:
        row_axes: Matplotlib axes of one grid row.
        sample_idxs: Indices into `inputs`/`probs`; at most one per axis is drawn.
    """
    for ax, idx in zip(row_axes, sample_idxs):
        idx = int(idx)  # index tensor element -> plain int for list indexing
        ax.imshow(inputs[idx].squeeze(), cmap="gray")
        sample_probs = [round(p, 2) for p in probs[idx].squeeze().tolist()]
        ax.text(
            0.95, 0.95, f"probs: {sample_probs}",
            transform=ax.transAxes, fontsize=14,
            verticalalignment='top', horizontalalignment='right',
            bbox=dict(facecolor='white', alpha=0.5, edgecolor='none')
        )


fig, axes = plt.subplots(
    3, N_SAMPLES_PER_ROW, figsize=(30, 15), constrained_layout=True
)
fig.patch.set_facecolor("black")

# Row 0: class 1 predicted as 1, row 1: class 1 predicted as 3,
# row 2: class 3 predicted as 3.
_show_samples(axes[0], right_idxs[:N_SAMPLES_PER_ROW])
_show_samples(axes[1], wrong_idxs[:N_SAMPLES_PER_ROW])
_show_samples(axes[2], right_idxs_cl3[:N_SAMPLES_PER_ROW])

# Hide the frame of every panel, including panels left empty because a group
# has fewer samples than columns.
for ax in axes.flat:
    ax.axis("off")

Visualize some whole images: ER vs. Mitochondria

In [None]:
import tifffile as tiff

In [None]:
# Only look at the first 40 (file path, label) entries of the dataset.
first_entries = input_data[:40]

# Load the full-size images of the two classes being compared.
mito_imgs = [
    tiff.imread(fpath) for fpath, label in first_entries if label == 1  # Mitochondria
]
er_imgs = [
    tiff.imread(fpath) for fpath, label in first_entries if label == 3  # ER
]

print(f"Number of ER images: {len(er_imgs)}")
print(f"Number of Mitochondria images: {len(mito_imgs)}")

In [None]:
# Top row: ER images, bottom row: Mitochondria images (up to 5 each).
fig, axes = plt.subplots(2, 5, figsize=(25, 10), constrained_layout=True)
fig.patch.set_facecolor("black")

for row_axes, class_imgs in zip(axes, (er_imgs, mito_imgs)):
    for ax, img in zip(row_axes, class_imgs[:5]):
        # Show only the top-left 512x512 region of each image.
        ax.imshow(img.squeeze()[:512, :512], cmap="gray")
        ax.axis("off")