# Main 

## Check Setup

In [None]:
from helpers_functions.setup import *

# Root folder of the dataset archive; every other path below is derived from it.
path_archive = "archive"

# Split files (one image file name per line).
path_train_val_list = f"{path_archive}/train_val_list_NIH.txt"
path_test_list = f"{path_archive}/test_list_NIH.txt"

# Metadata CSV used to build the label mappings.
path_all_data_csv = f"{path_archive}/Data_Entry_2017.csv"

# Image folder (presumably 224 px images, per the folder name).
path_folder_images = f"{path_archive}/images-224/images-224"

### Check structure

In [None]:
import json

# Print the archive's top-level layout as indented JSON to confirm the expected files exist.
path = path_archive
archive_tree = list_tree(path, max_depth=1)
print(json.dumps(archive_tree, indent=2))


## Data Preprocessing

### Create class-label linking

In [None]:
from helpers_functions.multi_hot import *

# Build the two lookup tables between label strings and integer indices from the
# metadata CSV (create_class_mappings comes from helpers_functions.multi_hot).
class_label_str_to_idx, class_label_idx_to_str = create_class_mappings(path_all_data_csv)

for mapping in (class_label_str_to_idx, class_label_idx_to_str):
    print(mapping)

### Multi-hot encoding

In [None]:
# Map every image file name to its multi-hot label vector, using the
# label-string -> index table built in the previous cell.
image_to_multihot = create_image_multihot_mapping_from_dicts(path_all_data_csv, class_label_str_to_idx)

# Peek at one entry as a sanity check; next(iter(...)) reads the first key
# without copying every key into a list.
first_image = next(iter(image_to_multihot))
print(first_image, image_to_multihot[first_image])


## Modeling

#### Make a DataLoader

In [None]:
from torchvision.models import ResNet18_Weights
from torchvision import transforms

# Pretrained ResNet-18 weights and their bundled preprocessing callable
# (the pretrained model's resize / ToTensor / Normalize as one step).
# Note: a later cell redefines train_transforms / val_transforms.
weights = ResNet18_Weights.DEFAULT
preprocess = weights.transforms()

# Light geometric augmentation, applied to the PIL image before `preprocess`.
_train_augmentations = [
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.02, 0.02), scale=(0.9, 1.1)),
    # transforms.RandomHorizontalFlip(p=0.5),  # only if medically safe
]
train_transforms = transforms.Compose(_train_augmentations + [preprocess])

# Validation/test: deterministic preprocessing only.
val_transforms = preprocess

In [None]:
from torch.utils.data import Subset

class SubsetWithTransform(Subset):
    """
    Subset that applies its own transform to samples of the underlying dataset.

    While one item is fetched, the underlying dataset's ``transform`` attribute is
    temporarily replaced by this subset's transform and then restored. If the
    dataset had no ``transform`` attribute beforehand, the temporary attribute is
    deleted again, so nothing is left behind on the dataset.

    Args:
        dataset: Dataset whose ``__getitem__`` reads ``self.transform``.
        indices: Indices into ``dataset`` that make up this subset.
        transform: Callable used for this subset's samples (``None`` disables the
            dataset's transform while fetching).
    """
    _MISSING = object()  # sentinel: the dataset had no ``transform`` attribute

    def __init__(self, dataset, indices, transform=None):
        super().__init__(dataset, indices)
        self._transform = transform

    def __getitem__(self, idx):
        # idx indexes into the subset; translate it to an index of the full dataset.
        dataset_idx = self.indices[idx]
        orig_transform = getattr(self.dataset, "transform", self._MISSING)
        self.dataset.transform = self._transform
        try:
            return self.dataset[dataset_idx]
        finally:
            # Restore the previous state exactly, even if loading the item raised.
            if orig_transform is self._MISSING:
                del self.dataset.transform
            else:
                self.dataset.transform = orig_transform

In [None]:
from torchvision import transforms

IMAGE_SIZE = 224  # change to your model input

# Normalization statistics for 3-channel images (ImageNet values, as an example).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_transforms(augment):
    """Resize to IMAGE_SIZE, optionally add light augmentation, then tensorize and normalize."""
    steps = [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))]
    if augment:
        steps += [
            transforms.RandomRotation(degrees=10),
            transforms.RandomAffine(degrees=0, translate=(0.02, 0.02), scale=(0.9, 1.1)),
            # transforms.RandomHorizontalFlip(p=0.5),  # only if safe for your data
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(steps)


train_transforms = _build_transforms(augment=True)
val_transforms = _build_transforms(augment=False)

In [None]:
from torch.utils.data import Subset

class SubsetWithTransform(Subset):
    """
    Subset that applies its own transform to samples of the underlying dataset.

    While one item is fetched, the underlying dataset's ``transform`` attribute is
    temporarily replaced by this subset's transform and then restored. If the
    dataset had no ``transform`` attribute beforehand, the temporary attribute is
    deleted again, so nothing is left behind on the dataset.

    This is safe because we won't use train and val loaders concurrently on the same
    dataset object (reading the same dataset object from several threads at once
    would race on ``transform``).

    Args:
        dataset: Dataset whose ``__getitem__`` reads ``self.transform``.
        indices: Indices into ``dataset`` that make up this subset.
        transform: Callable used for this subset's samples (``None`` disables the
            dataset's transform while fetching).
    """
    _MISSING = object()  # sentinel: the dataset had no ``transform`` attribute

    def __init__(self, dataset, indices, transform=None):
        super().__init__(dataset, indices)
        self._transform = transform

    def __getitem__(self, idx):
        # idx indexes into the subset; translate it to an index of the full dataset.
        dataset_idx = self.indices[idx]
        orig_transform = getattr(self.dataset, "transform", self._MISSING)
        self.dataset.transform = self._transform
        try:
            return self.dataset[dataset_idx]
        finally:
            # Restore the previous state exactly, even if loading the item raised.
            if orig_transform is self._MISSING:
                del self.dataset.transform
            else:
                self.dataset.transform = orig_transform

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

class MedicalImageDataset(Dataset):
    """
    PyTorch Dataset for medical images with precomputed multi-hot labels.

    Each sample is a tuple ``(image, multi_hot, img_name)``:
        - image: ``transform(pil_rgb_image)`` when a transform is set, otherwise a
          float tensor of shape (3, H, W) scaled to [0, 1]
        - multi_hot: tensor made from ``image_to_multihot[img_name]``, or an
          all-zeros float32 vector when the name has no entry
        - img_name: the file name as listed in the split file

    Missing image files are skipped: the next listed image (wrapping around at the
    end) is returned instead and a warning is printed. If no listed image exists
    on disk, ``FileNotFoundError`` is raised.
    """
    def __init__(self, files_list_dir: str, img_dir: str, image_to_multihot: dict, transform=None):
        """
        Args:
            files_list_dir (str): Path to a .txt file with one image file name per line
                (blank lines are ignored).
            img_dir (str): Directory containing the images.
            image_to_multihot (dict): Precomputed mapping image name -> multi-hot vector
                (numpy arrays; they are converted with ``torch.from_numpy``).
            transform (callable, optional): Transform applied to each RGB PIL image.
        """
        with open(files_list_dir, "r") as f:
            self.image_names = [line.strip() for line in f if line.strip()]

        self.img_dir = img_dir
        self.image_to_multihot = image_to_multihot
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def _default_label(self):
        """All-zeros float32 label with the same length as the stored vectors."""
        num_classes = len(next(iter(self.image_to_multihot.values())))
        return np.zeros(num_classes, dtype=np.float32)

    def __getitem__(self, idx):
        n = len(self)
        # Behave like a list for out-of-range indices (this also covers an empty split).
        if not -n <= idx < n:
            raise IndexError(f"Index {idx} out of range for dataset of size {n}")

        # Use image idx, or the next existing image after it (wrapping around).
        # Each index is tried at most once, so a split whose files are all missing
        # raises FileNotFoundError instead of recursing until RecursionError.
        for offset in range(n):
            img_name = self.image_names[(idx + offset) % n]
            img_path = os.path.join(self.img_dir, img_name)
            if os.path.exists(img_path):
                break
            print(f"Warning: Image file '{img_path}' not found. Skipping to next index.")
        else:
            raise FileNotFoundError("No valid images.")

        # Open in a context manager so the file handle is closed immediately;
        # convert() returns a converted in-memory copy.
        with Image.open(img_path) as pil_image:
            image = pil_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)
        else:
            # Default conversion: HWC uint8 -> CHW float in [0, 1].
            image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0

        # Precomputed label; fall back to all zeros for names without an entry.
        label = self.image_to_multihot.get(img_name)
        if label is None:
            label = self._default_label()
        multi_hot = torch.from_numpy(label)

        return image, multi_hot, img_name


# Build the train/val Dataset (no transform yet) and preview one shuffled batch.
dataset = MedicalImageDataset(path_train_val_list, path_folder_images, image_to_multihot)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

batch = next(iter(loader))
imgs, multi_hot, names = batch

for caption, value in (
    ("Image batch shape:", imgs.shape),
    ("Label batch shape:", multi_hot.shape),
    ("First image name:", names[0]),
    ("First label vector:", multi_hot[0]),
):
    print(caption, value)

# Without a transform the dataset yields (3, H, W) floats in [0, 1]; imshow wants (H, W, 3).
first_img_hwc = imgs[0].permute(1, 2, 0)
plt.imshow(first_img_hwc)
plt.title(f"Image: {names[0]}")
plt.axis('off')
plt.show()


#### Create CNN

In [None]:
from models.models import *

# Instantiate the "Base" variant on a ResNet-18 backbone via the project's `models` factory.
model_name = "Base"
model = models(model_name, backbone_name="resnet18")
print(model)

#### Train the CNN

In [None]:
from models.train import *
from models.val import *
import torch
import torch.optim as optim
from torch.utils.data import random_split, Subset
import os
import time

# Train Parameters
epochs = 10
val_every = 5          # validate every `val_every` epochs (and always on the last one)
batch_size = 16
save = True
save_every = 1        # save checkpoint every `save_every` epochs
download_after_save = True   # set True in Colab to auto-download checkpoints to your local machine
device = "cuda" if torch.cuda.is_available() else "cpu"

# move model to device before creating optimizer
model = model.to(device)

# Number of label columns, read from any stored multi-hot vector.
num_label_classes = next(iter(image_to_multihot.values())).shape[0]
print("Label vector length =", num_label_classes)

# Replace a ResNet-style `.fc` head so its output size matches the label vector.
# `torch.nn` is referenced explicitly because `nn` is never imported in this notebook.
if hasattr(model, "fc") and isinstance(model.fc, torch.nn.Linear):
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_label_classes).to(device)
    print("Updated model.fc to output =", num_label_classes)
else:
    print("WARNING: model has no .fc; show me model structure if this fails.")

# Create the optimizer only after the head swap so the new fc parameters are included.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Split the dataset into train and val (80/20), seeded for reproducibility.
val_ratio = 0.2
total_size = len(dataset)
val_size = int(total_size * val_ratio)
train_size = total_size - val_size
torch.manual_seed(42)  # For reproducibility

# A manual permutation (instead of random_split) lets each half get its own
# transform through SubsetWithTransform.
perm = torch.randperm(total_size).tolist()
train_idx, val_idx = perm[:train_size], perm[train_size:]

train_dataset = SubsetWithTransform(dataset, train_idx, transform=train_transforms)
val_dataset = SubsetWithTransform(dataset, val_idx, transform=val_transforms)

# NOTE: the train()/val() calls in the loop below are given the datasets, not these loaders.
loader_opts = dict(num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, **loader_opts)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, **loader_opts)

print(f"Train dataset size: {len(train_dataset)}", flush=True)
print(f"Validation dataset size: {len(val_dataset)}", flush=True)

def optimizer_state_cpu(opt_state_dict):
    """
    Return a copy of an optimizer ``state_dict`` whose per-parameter tensors are on CPU.

    Only the first level of each per-parameter state entry is inspected: tensor
    values are moved with ``.cpu()`` and every other value is kept as-is. Entries
    that are not dicts are passed through unchanged. ``param_groups`` is carried
    over by reference.

    Args:
        opt_state_dict: Mapping shaped like ``torch.optim.Optimizer.state_dict()``;
            missing "state" / "param_groups" keys default to empty.

    Returns:
        dict with keys "state" and "param_groups".
    """
    def _to_cpu(value):
        return value.cpu() if isinstance(value, torch.Tensor) else value

    cpu_state = {}
    for param_id, entry in opt_state_dict.get("state", {}).items():
        if isinstance(entry, dict):
            cpu_state[param_id] = {name: _to_cpu(val) for name, val in entry.items()}
        else:
            cpu_state[param_id] = entry  # unexpected layout: keep untouched
    return {"state": cpu_state, "param_groups": opt_state_dict.get("param_groups", [])}

def save_checkpoint(model, optimizer, epoch, val_loss=None, out_dir="checkpoints", model_name="model", download=False):
    """
    Save model and optimizer state to ``<out_dir>/<model_name>_epoch<NNN>.pth``.

    Every tensor is moved to CPU first, so the file can be loaded without a GPU.
    Saved keys: "epoch", "model_state", "optimizer_state", "val_loss".

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose state is saved via ``optimizer_state_cpu``.
        epoch: epoch number, used in the file name (zero-padded to 3 digits).
        val_loss: optional validation loss stored alongside the weights.
        out_dir: target directory (created if needed).
        model_name: file name prefix.
        download: if True, also try ``google.colab.files.download``; any failure
            (e.g. not running in Colab) is printed and otherwise ignored.
    """
    os.makedirs(out_dir, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "model_state": {name: tensor.detach().cpu() for name, tensor in model.state_dict().items()},
        "optimizer_state": optimizer_state_cpu(optimizer.state_dict()),
        "val_loss": val_loss,
    }
    fname = os.path.join(out_dir, f"{model_name}_epoch{epoch:03d}.pth")
    torch.save(checkpoint, fname)
    print(f"Saved checkpoint -> {fname}", flush=True)

    if not download:
        return
    # Best-effort browser download; only works inside Colab.
    try:
        from google.colab import files
        files.download(fname)
    except Exception as e:
        print("Auto-download failed (not in Colab?). Reason:", e, flush=True)

# Training + validation loop (1 epoch at a time so we can checkpoint each epoch)
start_time = time.perf_counter()
best_val = float("inf")
checkpoint_dir = f"checkpoints/{model_name}"

for epoch in range(1, epochs + 1):
    # Train exactly one epoch by asking train() to run from epoch-1 to epoch.
    train(
        model_name,
        model,
        optimizer,
        train_dataset,
        start_epoch=epoch - 1,
        end_epoch=epoch,
        batch_size=batch_size,
        device=device,
        save=False,            # checkpoints are written by this loop, not by train()
        save_every=999,        # just in case; keep train from saving its own files
        checkpoint_dir=checkpoint_dir
    )

    # Validate every `val_every` epochs and always on the final epoch.
    val_loss = None
    saved_this_epoch = False   # True once this epoch's checkpoint file has been written
    if (epoch % val_every == 0) or (epoch == epochs):
        val_loss = val(
            model_name,
            model,
            val_dataset,
            epoch=epoch,
            batch_size=batch_size,
            device=device
        )
        # If val() returns None, no best-model tracking happens for this epoch.

        if val_loss is not None and val_loss < best_val:
            best_val = val_loss
            print(f"New best val {best_val:.6f} at epoch {epoch}", flush=True)
            # Full checkpoint (model + optimizer) for the new best epoch.
            save_checkpoint(model, optimizer, epoch, val_loss=val_loss, out_dir=checkpoint_dir, model_name=model_name, download=download_after_save)
            saved_this_epoch = True
            # Weights-only copy under a fixed name. Despite "latest" in the file
            # name, it is only rewritten when validation improves.
            latest_path = os.path.join(checkpoint_dir, f"{model_name}_latest.pth")
            torch.save({"epoch": epoch, "model_state": {k: v.detach().cpu() for k, v in model.state_dict().items()}}, latest_path)

    # Periodic save every `save_every` epochs and on the final epoch. Skipped when the
    # best-model save above already wrote this epoch's file: both calls produce the
    # same file name and content, so repeating it only duplicated the write and download.
    if save and not saved_this_epoch and ((epoch % save_every == 0) or (epoch == epochs)):
        save_checkpoint(model, optimizer, epoch, val_loss=val_loss, out_dir=checkpoint_dir, model_name=model_name, download=download_after_save)

end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"{epochs} Epochs model {model_name} train/val time: {elapsed_time:.4f} seconds", flush=True)


#### Classifier Performance

In [None]:
# Build the test-split Dataset and preview one batch from it.
test_dataset = MedicalImageDataset(path_test_list, path_folder_images, image_to_multihot)

# The preview loader must iterate the test split, not the train/val `dataset`.
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=True)

# Get the first batch
imgs, multi_hot, names = next(iter(test_loader))

print("Image batch shape:", imgs.shape)
print("Label batch shape:", multi_hot.shape)
print("First image name:", names[0])
print("First label vector:", multi_hot[0])

# Display the first image; with no transform it is (3, H, W) in [0, 1], so move channels last.
plt.imshow(imgs[0].permute(1, 2, 0))
plt.title(f"Image: {names[0]}")
plt.axis('off')
plt.show()


In [None]:
from helpers_functions.metrics import *
from models.pred import *

from torch.utils.data import Subset
from io import StringIO
import sys

import contextlib

# TODO: DELETE ONCE READY TO FULLY UTILIZE
# Take first 100 elements
#test_subset = Subset(test_dataset, indices=list(range(100)))

save_dir = f"results/{model_name}"
os.makedirs(save_dir, exist_ok=True)

# Capture everything printed during evaluation so it can be echoed and saved.
# redirect_stdout restores sys.stdout even if evaluate() raises; a manual swap
# would leave the notebook's output redirected into the buffer on any error.
buffer = StringIO()
with contextlib.redirect_stdout(buffer):
    df, cm = evaluate(
        model_name=model_name,
        model=model,
        pred=pred,
        dataset=test_dataset)

    print(json.dumps(df.to_dict(), indent=2))
    print("Confusion Matrix\n", cm)

captured = buffer.getvalue()

# Print to console normally
print(captured)

# Save captured text to file
log_path = os.path.join(save_dir, "metrics_output.txt")
with open(log_path, "w") as f:
    f.write(captured)

print(f"\nSaved printed metrics to:\n  {log_path}")

# Prediction

## Threshold Tuning

In [None]:
from tqdm import tqdm
from sklearn.metrics import f1_score, average_precision_score, roc_auc_score
from torch.utils.data import DataLoader

print("Device:", device)

# Checkpoint whose thresholds are tuned below.
# NOTE(review): this path points at an ASL checkpoint while `out_dir` and the saved
# file names use `model_name` (set to "Base" in the Create CNN cell) - confirm they match.
path_checkpoint = "checkpoints/ASL/ASL_epoch005 - best.pth"
out_dir = f"results/{model_name}"
os.makedirs(out_dir, exist_ok=True)

# Thresholds are tuned on the validation split created in the training cell.
if "val_dataset" not in globals():
    raise RuntimeError("val_dataset not found in notebook.")

val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

def load_checkpoint_to_model(model, ckpt_path, device):
    """
    Load weights from ``ckpt_path`` into ``model``, move it to ``device`` and set eval mode.

    Accepted checkpoint layouts:
      - a dict holding the weights under "model_state_dict", "state_dict", "model",
        "net" or "model_state" (the key written by this notebook's save_checkpoint
        and *_latest.pth files);
      - a bare state dict;
      - a pickled ``nn.Module`` (its ``state_dict()`` is used).
    A leading "module." prefix (from DataParallel) is stripped from every key.

    Loading is non-strict, but missing/unexpected keys are reported, and
    ``RuntimeError`` is raised when not a single model key is present in the
    checkpoint (that case used to leave the model silently untouched).

    Returns:
        The same ``model`` object, on ``device`` and in eval mode.
    """
    ckpt = torch.load(ckpt_path, map_location=device)

    state = ckpt
    if isinstance(ckpt, dict):
        for key in ("model_state_dict", "state_dict", "model", "net", "model_state"):
            if key in ckpt:
                state = ckpt[key]
                break
    if isinstance(state, torch.nn.Module):
        state = state.state_dict()

    prefix = "module."
    new_state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}

    expected_keys = set(model.state_dict())
    if expected_keys and not expected_keys & set(new_state):
        raise RuntimeError(
            f"No keys in checkpoint '{ckpt_path}' match the model's state_dict "
            f"(first checkpoint keys: {list(new_state)[:5]})"
        )

    result = model.load_state_dict(new_state, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"Warning: loaded '{ckpt_path}' with {len(result.missing_keys)} missing and "
              f"{len(result.unexpected_keys)} unexpected keys")
    model.to(device)
    model.eval()
    return model

def extract_logits(output):
    """
    Pull the logits tensor out of a model's forward output.

    Resolution order:
      1. a tensor is returned as-is;
      2. a non-empty tuple/list -> recurse into its first element
         (an empty one raises ``ValueError``);
      3. a dict -> recurse into the first present key among
         'logits', 'out', 'pred', 'prediction', 'score'; failing that, return the
         first value that is a tensor.
    Anything else raises ``TypeError``.
    """
    if torch.is_tensor(output):
        return output

    if isinstance(output, (tuple, list)):
        if not output:
            raise ValueError("Model returned empty tuple/list")
        return extract_logits(output[0])

    if isinstance(output, dict):
        preferred = next(
            (k for k in ("logits", "out", "pred", "prediction", "score") if k in output),
            None,
        )
        if preferred is not None:
            return extract_logits(output[preferred])
        first_tensor = next((v for v in output.values() if torch.is_tensor(v)), None)
        if first_tensor is not None:
            return first_tensor

    raise TypeError(f"Cannot extract logits from model output of type {type(output)}")

@torch.no_grad()
def get_probs_and_labels(model, loader, device):
    """
    Run ``model`` over ``loader`` and collect sigmoid probabilities and targets.

    The loader must yield ``(images, targets, names)`` batches. Returns
    ``(scores, targets, filenames)``: scores and targets are numpy arrays stacked
    row-wise over all batches (row i belongs to ``filenames[i]``), and filenames is a list.
    """
    model.eval()
    score_batches, target_batches, filenames = [], [], []
    for imgs, targets, names in tqdm(loader, desc="Predicting on val"):
        logits = extract_logits(model(imgs.to(device)))
        score_batches.append(torch.sigmoid(logits).detach().cpu().numpy())
        target_batches.append(targets.numpy())
        filenames.extend(names)
    return np.vstack(score_batches), np.vstack(target_batches), filenames

def find_best_thresholds_by_f1(y_true, y_score, steps=101):
    """
    Grid-search, for each class, the decision threshold that maximizes that class's F1.

    Args:
        y_true: (N, C) binary ground-truth array.
        y_score: (N, C) scores / probabilities.
        steps: number of evenly spaced thresholds in [0, 1] to try.

    Returns:
        (best_thresh, best_f1): two length-C arrays. Classes without any positive
        sample get threshold 0.5 and F1 0.0. On ties the lowest threshold is kept.
    """
    candidates = np.linspace(0.0, 1.0, steps)
    num_classes = y_true.shape[1]
    best_thresh = np.zeros(num_classes)
    best_f1 = np.zeros(num_classes)

    for cls in range(num_classes):
        truths, scores = y_true[:, cls], y_score[:, cls]
        if truths.sum() == 0:
            best_thresh[cls], best_f1[cls] = 0.5, 0.0
            continue

        top_f1, top_t = -1.0, 0.5
        for t in candidates:
            f = f1_score(truths, (scores >= t).astype(int), zero_division=0)
            if f > top_f1:
                top_f1, top_t = f, t
        best_thresh[cls], best_f1[cls] = top_t, top_f1

    return best_thresh, best_f1

def eval_with_thresholds(y_true, y_score, thresholds):
    """
    Binarize ``y_score`` with ``thresholds`` and score the predictions against ``y_true``.

    Args:
        y_true: (N, C) binary ground-truth array.
        y_score: (N, C) scores / probabilities.
        thresholds: a scalar shared by all classes, or a length-C sequence of
            per-class thresholds (broadcast across rows).

    Returns:
        dict with "per_class_f1", "micro_f1" (F1 over all flattened cells),
        "mean_ap" / "mean_auroc" (NaN-ignoring means), "ap_per_class",
        "auroc_per_class" (NaN where the sklearn metric raised) and the binary
        "preds" array.
    """
    if np.isscalar(thresholds):
        cutoff = thresholds
    else:
        cutoff = np.array(thresholds)[None, :]
    preds = (y_score >= cutoff).astype(int)
    class_ids = range(y_true.shape[1])

    def _or_nan(metric, truths, scores):
        # sklearn can raise on degenerate columns (e.g. AUROC when only one label
        # value is present); record NaN instead of aborting the evaluation.
        try:
            return metric(truths, scores)
        except Exception:
            return np.nan

    per_class_f1 = np.array([f1_score(y_true[:, c], preds[:, c], zero_division=0) for c in class_ids])
    micro_f1 = f1_score(y_true.ravel(), preds.ravel(), zero_division=0)
    ap = [_or_nan(average_precision_score, y_true[:, c], y_score[:, c]) for c in class_ids]
    auroc = [_or_nan(roc_auc_score, y_true[:, c], y_score[:, c]) for c in class_ids]

    return {
        "per_class_f1": per_class_f1,
        "micro_f1": micro_f1,
        "mean_ap": np.nanmean(ap),
        "mean_auroc": np.nanmean(auroc),
        "ap_per_class": np.array(ap),
        "auroc_per_class": np.array(auroc),
        "preds": preds,
    }

# ---- Run: load checkpoint, predict on validation, tune thresholds ----
print("Loading checkpoint into model...")
model = load_checkpoint_to_model(model, path_checkpoint, device)  # reuses the model object created earlier

print("Running predictions on validation set...")
y_score, y_true, filenames = get_probs_and_labels(model, val_loader, device)
print("Shapes: scores", y_score.shape, "truths", y_true.shape)

# Persist the raw arrays so the confusion cell can reload them later.
for array_kind, array in (("scores", y_score), ("truths", y_true)):
    np.save(os.path.join(out_dir, f"val_{array_kind}_{model_name}.npy"), array)

# Per-class thresholds that maximize each class's F1.
best_th, best_f1s = find_best_thresholds_by_f1(y_true, y_score, steps=101)
print("Best per-class thresholds:", best_th)
print("Best per-class F1s:", best_f1s)

# One threshold shared by all classes, chosen by micro-F1 (first best wins on ties).
grid = np.linspace(0.0, 1.0, 101)
best_micro, best_global_t = -1.0, 0.5
for t in grid:
    micro = eval_with_thresholds(y_true, y_score, t)["micro_f1"]
    if micro > best_micro:
        best_micro, best_global_t = micro, t
print(f"Best global threshold: {best_global_t:.2f} (micro-F1 = {best_micro:.4f})")

# Summarize the per-class tuning and save it as JSON.
res_perclass = eval_with_thresholds(y_true, y_score, best_th)
out = {
    "best_per_class_thresholds": best_th.tolist(),
    "best_per_class_f1s": best_f1s.tolist(),
    "best_global_threshold": float(best_global_t),
    "best_global_micro_f1": float(best_micro),
    "ap_per_class": res_perclass["ap_per_class"].tolist(),
    "auroc_per_class": res_perclass["auroc_per_class"].tolist()
}
with open(os.path.join(out_dir, f"thresholds_{model_name}.json"), "w") as fh:
    json.dump(out, fh, indent=2)
print("Saved thresholds & metrics to", out_dir)

# Binary predictions under the per-class thresholds.
preds_perclass = (y_score >= best_th[None, :]).astype(int)
np.save(os.path.join(out_dir, f"val_preds_perclass_{model_name}.npy"), preds_perclass)

# Report each class's threshold and F1, with its name when the mapping exists.
have_class_names = 'class_label_idx_to_str' in globals()
for i, (t, f) in enumerate(zip(best_th, best_f1s)):
    label = class_label_idx_to_str[i] if have_class_names else str(i)
    print(f"Class {i:2d} ({label}): thresh={t:.2f}, F1={f:.3f}")

In [None]:
import pandas as pd

out_dir = f"results/{model_name}"
os.makedirs(out_dir, exist_ok=True)

# y_score: (N, C) predicted probabilities, y_true: (N, C) binary truths.
# Reuse them if the tuning cell ran in this session, otherwise reload from disk.
if 'y_score' not in globals() or 'y_true' not in globals():
    score_path = os.path.join(out_dir, f"val_scores_{model_name}.npy")
    truth_path = os.path.join(out_dir, f"val_truths_{model_name}.npy")
    if os.path.exists(score_path) and os.path.exists(truth_path):
        y_score = np.load(score_path)
        y_true = np.load(truth_path)
        print(f"Loaded y_score ({y_score.shape}) and y_true ({y_true.shape}) from {out_dir}")
    else:
        raise RuntimeError("y_score / y_true not found in memory or results/. Run the tuning cell first.")

# best_th: per-class thresholds
if 'best_th' not in globals():
    # The tuning cell saves thresholds_{model_name}.json; the previously hard-coded
    # ManifoldMixup file name is kept as a fallback for existing result folders.
    thr_candidates = [
        os.path.join(out_dir, f"thresholds_{model_name}.json"),
        os.path.join(out_dir, "thresholds_ManifoldMixup.json"),
    ]
    thr_path = next((p for p in thr_candidates if os.path.exists(p)), None)
    if thr_path is not None:
        with open(thr_path, "r") as fh:
            thr_data = json.load(fh)
        best_th = np.array(thr_data.get("best_per_class_thresholds", [0.5] * y_true.shape[1]))
        print(f"Loaded thresholds from {thr_path}")
    else:
        # fallback to 0.5 if nothing found
        best_th = np.full(y_true.shape[1], 0.5)
        print("No thresholds found on disk; using 0.5 for all classes.")

num_classes = len(best_th)

# Class labels
if 'class_label_idx_to_str' in globals():
    labels = [class_label_idx_to_str[i] for i in range(num_classes)]
else:
    labels = [str(i) for i in range(num_classes)]

# Apply thresholds to get binary predictions
preds = (y_score >= best_th[None, :]).astype(int)  # (N, C)

# Compute TP, FP, FN, TN per class
TP = np.sum((preds == 1) & (y_true == 1), axis=0)
FP = np.sum((preds == 1) & (y_true == 0), axis=0)
FN = np.sum((preds == 0) & (y_true == 1), axis=0)
TN = np.sum((preds == 0) & (y_true == 0), axis=0)

# Build DataFrame for display
df = pd.DataFrame({
    "label_idx": list(range(num_classes)),
    "label": labels,
    "threshold": best_th,
    "TP": TP,
    "FP": FP,
    "FN": FN,
    "TN": TN
})

# Precision / recall / F1 from the confusion counts; a metric is 0 when its
# denominator is 0 (the zero_division=0 convention).
prec, rec, f1 = [], [], []
for tp, fp, fn in zip(TP.tolist(), FP.tolist(), FN.tolist()):
    p = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    r = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    prec.append(p)
    rec.append(r)
    f1.append(2 * p * r / (p + r) if (p + r) > 0 else 0.0)

df["precision"] = np.round(prec, 4)
df["recall"] = np.round(rec, 4)
df["f1_from_confusion"] = np.round(f1, 4)

# Print nicely
pd.set_option("display.max_rows", None)
print("\nPer-class confusion matrix and derived metrics:\n")
display(df)

# Save results (named after the model, like the other files in this folder)
csv_path = os.path.join(out_dir, f"per_class_confusion_{model_name}.csv")
json_path = os.path.join(out_dir, f"per_class_confusion_{model_name}.json")
df.to_csv(csv_path, index=False)
df.to_json(json_path, orient="records", indent=2)
print(f"\nSaved per-class confusion to:\n  {csv_path}\n  {json_path}")

# Print totals and micro/macro metrics
total_TP = int(TP.sum()); total_FP = int(FP.sum()); total_FN = int(FN.sum()); total_TN = int(TN.sum())
micro_precision = total_TP / (total_TP + total_FP) if (total_TP + total_FP) > 0 else 0.0
micro_recall = total_TP / (total_TP + total_FN) if (total_TP + total_FN) > 0 else 0.0
micro_f1 = 2 * micro_precision * micro_recall / (micro_precision + micro_recall) if (micro_precision + micro_recall) > 0 else 0.0

# Macro averages use the unrounded per-class values (the DataFrame columns are rounded).
macro_precision = float(np.mean(prec))
macro_recall = float(np.mean(rec))
macro_f1 = float(np.mean(f1))

print("\nAggregated totals:")
print(f"  TP={total_TP}, FP={total_FP}, FN={total_FN}, TN={total_TN}")
print(f"Micro precision={micro_precision:.4f}, micro recall={micro_recall:.4f}, micro F1={micro_f1:.4f}")
print(f"Macro precision (mean of per-class) ~ {macro_precision:.4f}, macro recall ~ {macro_recall:.4f}, macro F1 ~ {macro_f1:.4f}")

# Show top classes by FP / FN
print("\nTop 5 classes by FP (descending):")
display(df.sort_values("FP", ascending=False).head(5)[["label_idx","label","FP","TP","FN"]])

print("\nTop 5 classes by FN (descending):")
display(df.sort_values("FN", ascending=False).head(5)[["label_idx","label","FN","TP","FP"]])

## Run Prediction

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "ManiFold_Mixup"
model = models(model_name, backbone_name='resnet18')

checkpoint_path = "checkpoints/ManiFold_Mixup/ManiFold_Mixup_epoch005-best.pth"
print("Loading checkpoint:", checkpoint_path)

# Load checkpoint safely
ckpt = torch.load(checkpoint_path, map_location=device)

# Locate the weights. "model_state" is the key written by this notebook's
# save_checkpoint(); the other keys cover common external layouts.
state = ckpt
if isinstance(ckpt, dict):
    for key in ("model_state_dict", "state_dict", "model", "model_state"):
        if key in ckpt:
            state = ckpt[key]
            break

# Strip only a leading "module." (DataParallel) prefix; str.replace would also
# rewrite "module." appearing in the middle of a key.
prefix = "module."
new_state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}

# strict=False tolerates partial matches, but loading zero keys is always a mistake.
if not set(model.state_dict()) & set(new_state):
    raise RuntimeError(f"No keys in '{checkpoint_path}' match the model's state_dict")
incompatible = model.load_state_dict(new_state, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Warning: {len(incompatible.missing_keys)} missing and "
          f"{len(incompatible.unexpected_keys)} unexpected keys while loading {checkpoint_path}")
model.to(device)
model.eval()

In [None]:
from models.pred import *
# Per-class decision thresholds tuned for the epoch-5 checkpoint loaded above.
threshold_file = "results/ManiFold_Mixup/thresholds_ManifoldMixup.json"
with open(threshold_file, "r") as f:
    data = json.load(f)

thresholds = np.asarray(data["best_per_class_thresholds"], dtype=np.float32)
print("Loaded thresholds for checkpoint 5:", thresholds)

In [None]:
# Predict on the test split with the tuned per-class thresholds.
# Use the detected `device` instead of hard-coding "cuda" so this also runs on CPU-only machines.
probs, labels, preds = pred(
    "ManiFold_Mixup",
    model,
    test_dataset,
    batch_size=32,
    device=device,
    thresholds=thresholds
)

## Evaluate threshold-tuned model

In [None]:
# precision_score / recall_score were never imported in the visible notebook code.
from sklearn.metrics import (
    average_precision_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Assumes `labels`, `preds` (binary) and `probs` are (N, C) arrays returned by pred() - TODO confirm.
micro_f1 = f1_score(labels.ravel(), preds.ravel(), zero_division=0)
macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
macro_precision = precision_score(labels, preds, average="macro", zero_division=0)
macro_recall = recall_score(labels, preds, average="macro", zero_division=0)

print("\n=== GLOBAL METRICS (WITH TUNED THRESHOLDS) ===")
print(f"Micro F1:   {micro_f1:.4f}")
print(f"Macro F1:   {macro_f1:.4f}")
print(f"Macro Prec: {macro_precision:.4f}")
print(f"Macro Rec:  {macro_recall:.4f}")

# ---- Per-class confusion matrix ----
TP = ((preds == 1) & (labels == 1)).sum(axis=0)
FP = ((preds == 1) & (labels == 0)).sum(axis=0)
FN = ((preds == 0) & (labels == 1)).sum(axis=0)
TN = ((preds == 0) & (labels == 0)).sum(axis=0)

C = labels.shape[1]  # number of classes

F1_perclass = []
AP_perclass = []
AUROC_perclass = []

for c in range(C):
    F1_perclass.append(f1_score(labels[:, c], preds[:, c], zero_division=0))
    # Ranking metrics raise ValueError on degenerate columns (e.g. a single label value).
    try:
        AP_perclass.append(average_precision_score(labels[:, c], probs[:, c]))
    except ValueError:
        AP_perclass.append(np.nan)
    try:
        AUROC_perclass.append(roc_auc_score(labels[:, c], probs[:, c]))
    except ValueError:
        AUROC_perclass.append(np.nan)

df_eval = pd.DataFrame({
    "Class": np.arange(C),
    "TP": TP,
    "FP": FP,
    "FN": FN,
    "TN": TN,
    "F1": np.round(F1_perclass, 4),
    "AP": np.round(AP_perclass, 4),
    "AUROC": np.round(AUROC_perclass, 4),
})

print("\n=== PER-CLASS METRICS ===")
display(df_eval)

# Save output (make sure the results folder exists first)
os.makedirs("results", exist_ok=True)
df_eval.to_csv("results/final_eval_with_thresholds.csv", index=False)
print("\nSaved: results/final_eval_with_thresholds.csv")