# Imports

In [None]:
# CLIP
from open_clip import create_model_from_pretrained, get_tokenizer

# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

# HuggingFace
from huggingface_hub import hf_hub_download
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
from datasets import arrow_dataset

# Metrics
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score, roc_auc_score)
from sklearn.preprocessing import label_binarize


# Others
import pandas as pd
import numpy as np
from PIL import Image

# Load Model

In [None]:
# BiomedCLIP base model, its image preprocessing transform and its tokenizer,
# all resolved from the same Hugging Face Hub identifier.
_biomedclip_hub_id = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
model, preprocess = create_model_from_pretrained(_biomedclip_hub_id)
tokenizer = get_tokenizer(_biomedclip_hub_id)

# Fine-tuned weights: download the checkpoint file (placeholders below) and
# load its "model_state_dict" entry into the base model.
repo_id = "to be replaced"
model_file = "to be replaced"
model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
# weights_only=True restricts unpickling to tensors and plain containers;
# map_location="cpu" maps every stored tensor onto the CPU.
state = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state["model_state_dict"])

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Running on", device)

# Helper functions

In [None]:
def get_preprocess_images_from_paths(image_paths: list[str], preprocess: transforms.Compose):
    """Load images from disk, apply ``preprocess`` to each, and stack the results.

    Args:
        image_paths: Paths of the image files to load, in output order.
        preprocess: Transform applied to each opened ``PIL.Image``; it must
            return a tensor so the results can be stacked.

    Returns:
        The per-image tensors stacked along a new leading dimension (dim 0).
    """
    processed = []
    for path in image_paths:
        # The context manager closes the underlying file once the transform
        # has consumed the pixels; a bare Image.open() leaves one open file
        # handle per image until garbage collection.
        with Image.open(path) as img:
            processed.append(preprocess(img))
    return torch.stack(processed, dim=0)

def get_preprocess_images_from_public_data(ds: arrow_dataset.Dataset, preprocess: transforms.Compose):
    """Apply ``preprocess`` to every item of ``ds["image"]`` and stack them.

    Args:
        ds: Dataset (or any mapping) whose ``"image"`` entry is an iterable of images.
        preprocess: Transform turning one image into a tensor.

    Returns:
        The transformed images stacked along a new leading dimension (dim 0).
    """
    tensors = []
    for picture in ds["image"]:
        tensors.append(preprocess(picture))
    return torch.stack(tensors, 0)

def prepare_MINDSet(path_to_MINDSet: str, image_root: str = "vista/data/images/") -> pd.DataFrame:
    """Load the MINDSet CSV and derive label, image-path and description columns.

    The CSV must provide the columns ``label`` (presumably the integer dementia
    class 0/1/2, per ``DEMENTIA_LABEL_MAP``), ``img_path`` and ``abnormal_type``
    (one abnormality name, or several joined with commas).

    Args:
        path_to_MINDSet: Path to the CSV file.
        image_root: Prefix prepended (plain string concatenation) to every
            ``img_path`` value. Defaults to the previously hard-coded location.

    Returns:
        The DataFrame with ``label`` renamed to ``dementia_label``, ``img_path``
        prefixed with ``image_root``, and these added columns:
        ``dementia_type``, ``binary_dementia_label`` (0 for label 0, else 1),
        ``abnormality_label`` (id of the exact abnormality name, or of the
        first name in a comma-separated list), ``abnormality_description`` and
        ``dementia_description`` (NaN when no description exists for the value).

    Raises:
        KeyError: if an ``abnormal_type`` value, or its first comma-separated
            component, is not a known abnormality name. Non-string values
            (e.g. missing entries read as NaN) fail with ``AttributeError``.
    """
    df = pd.read_csv(path_to_MINDSet)

    DEMENTIA_LABEL_MAP = {
        0: "no_dementia",
        1: "other_dementia",
        2: "AD",
    }
    ABNORMALITY_LABEL_MAP = {
        0: "normal",
        1: "mtl_atrophy",
        2: "wmh",
        3: "other_atrophy",
    }
    abnormality_ids = {name: idx for idx, name in ABNORMALITY_LABEL_MAP.items()}

    def _abnormality_id(abnormal_type):
        # Combined findings such as "mtl_atrophy,wmh" are labelled by their
        # first listed abnormality.
        if abnormal_type in abnormality_ids:
            return abnormality_ids[abnormal_type]
        return abnormality_ids[abnormal_type.split(",")[0]]

    df.rename(columns={'label': 'dementia_label'}, inplace=True)

    df['dementia_type'] = df['dementia_label'].map(DEMENTIA_LABEL_MAP)
    df['binary_dementia_label'] = np.where(df['dementia_label'] == 0, 0, 1)

    df['img_path'] = df['img_path'].apply(lambda x: image_root + str(x))
    df["abnormality_label"] = df['abnormal_type'].map(_abnormality_id)

    abnormality_description = {
        "normal": "MRI image shows normal brain structures without evidence of significant abnormalities or pathological changes.",
        "mtl_atrophy": "MRI image illustrates volume reduction and structural atrophy in the medial temporal lobes, including hippocampal shrinkage.",
        "wmh": "MRI image reveals hyperintense lesions within cerebral white matter regions, indicating white matter hyperintensities.",
        # Combined entries are the single-finding sentences joined by one space.
        "mtl_atrophy,other_atrophy": "MRI image illustrates volume reduction and structural atrophy in the medial temporal lobes, including hippocampal shrinkage. MRI image indicates brain atrophy in cortical or subcortical regions other than medial temporal lobes, with notable structural volume loss.",
        "mtl_atrophy,wmh": "MRI image illustrates volume reduction and structural atrophy in the medial temporal lobes, including hippocampal shrinkage. MRI image reveals hyperintense lesions within cerebral white matter regions, indicating white matter hyperintensities.",
        "wmh,other_atrophy": "MRI image reveals hyperintense lesions within cerebral white matter regions, indicating white matter hyperintensities. MRI image indicates brain atrophy in cortical or subcortical regions other than medial temporal lobes, with notable structural volume loss.",
        "other_atrophy": "MRI image indicates brain atrophy in cortical or subcortical regions other than medial temporal lobes, with notable structural volume loss.",
    }
    dementia_description = {
        "no_dementia": "MRI image presents no evident dementia-related structural changes, reflecting a normal cognitive state.",
        "AD": "MRI image shows characteristic patterns of brain atrophy suggestive of Alzheimer's Disease pathology.",
        "other_dementia": "MRI image shows structural brain abnormalities indicative of dementia types other than Alzheimer's Disease, such as Vascular dementia or Dementia with Lewy bodies.",
    }
    df["abnormality_description"] = df["abnormal_type"].map(abnormality_description)
    df["dementia_description"] = df["dementia_type"].map(dementia_description)
    return df

In [None]:
def get_zeroshot_predictions(model: nn.Module, images: torch.tensor, texts: torch.tensor, batch_size: int = 16, device=None):
    """Zero-shot classify ``images`` against the tokenized prompts ``texts``.

    Args:
        model: CLIP-style module; ``model(images, texts)`` must return a
            3-tuple whose first two items are image and text features (the
            third item is ignored).
        images: Preprocessed images, batched along dim 0.
        texts: Tokenized prompts, one row per class; prompt row ``i`` yields
            predicted label ``i``.
        batch_size: Number of images per forward pass.
        device: Device to run on. ``None`` (the default) selects CUDA when it
            is available and CPU otherwise. The old default, ``"cuda"``, failed
            on CPU-only machines; passing ``"cuda"`` explicitly still works.

    Returns:
        ``(predicted_labels, logits)`` as NumPy arrays. ``predicted_labels``
        has shape ``(N,)`` and holds the argmax over prompts. ``logits`` has
        shape ``(N, num_prompts)`` and holds the raw scores
        ``image_features @ text_features.T``, with no softmax and no logit
        scale applied.

    Side effects:
        Puts ``model`` in eval mode and moves it back to CPU before returning,
        including when an exception is raised.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model.eval()
    model.to(device)
    try:
        # Tensor.to() is not in-place; the device copy stays local, so the
        # caller's ``texts`` tensor is never moved.
        texts = texts.to(device)
        batch_logits = []
        with torch.no_grad():
            for batch_images in DataLoader(images, batch_size=batch_size, shuffle=False):
                image_features, text_features, _ = model(batch_images.to(device), texts)
                batch_logits.append((image_features @ text_features.t()).cpu().numpy())
    finally:
        model.to("cpu")

    if not batch_logits:
        # No images: return correctly shaped empty arrays instead of letting
        # np.concatenate fail on an empty list.
        return np.empty((0,), dtype=np.int64), np.empty((0, len(texts)), dtype=np.float32)

    all_logits = np.concatenate(batch_logits)
    # Softmax is monotonic, so the argmax of the raw scores equals the
    # argmax of their softmax.
    predicted_labels = all_logits.argmax(axis=-1)
    return predicted_labels, all_logits


## Evaluation Metrics

In [None]:
class EvalMetric:
    """Compute standard classification metrics for one set of predictions.

    The problem is treated as binary when ``labels`` holds at most two distinct
    values, and as multiclass with macro averaging otherwise.

    Attributes:
        labels: Ground-truth labels, as passed in.
        predictions: Predicted labels, as passed in (aligned with ``labels``).
        n_classes: Number of distinct values in ``labels`` (predictions are
            not counted).
        average_method: ``'binary'`` if ``n_classes <= 2`` else ``'macro'``;
            forwarded as ``average=`` to the scikit-learn scorers.
    """

    def __init__(self, labels: pd.Series, predictions: pd.Series):
        self.labels = labels
        self.predictions = predictions
        # Binary vs. multiclass is decided from the true labels only.
        self.n_classes = len(np.unique(labels))
        assert self.n_classes > 1, "EvalMetric requires at least two distinct classes in `labels`."
        self.average_method = 'binary' if self.n_classes <= 2 else 'macro'

    def get_accuracy(self) -> float:
        """Return the fraction of predictions that equal the true label."""
        return accuracy_score(self.labels, self.predictions)

    def get_precision(self) -> float:
        """Return precision using ``average_method``; 0 where it is undefined."""
        return precision_score(self.labels, self.predictions, average=self.average_method, zero_division=0)

    def get_recall(self) -> float:
        """Return recall using ``average_method``; 0 where it is undefined."""
        return recall_score(self.labels, self.predictions, average=self.average_method, zero_division=0)

    def get_f1_score(self) -> float:
        """Return F1 using ``average_method``; 0 where it is undefined."""
        return f1_score(self.labels, self.predictions, average=self.average_method, zero_division=0)

    def binary_specificity(self, y_true, y_pred):
        """Return TN / (TN + FP) for a two-class problem, or NaN if TN + FP == 0.

        ``confusion_matrix`` orders labels ascending, so the smaller label is
        the negative class. The 4-way unpacking requires a 2x2 matrix, i.e. at
        most two distinct values across ``y_true`` and ``y_pred`` combined.
        """
        cm = confusion_matrix(y_true, y_pred)
        tn, fp, fn, tp = cm.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else np.nan
        return specificity

    def multiclass_specificity(self, y_true, y_pred):
        """Return the macro average of one-vs-rest specificities.

        For class ``i``, TN is the sum of the confusion matrix with row ``i``
        and column ``i`` removed, and FP is the column-``i`` total minus the
        diagonal entry. A class with TN + FP == 0 contributes NaN, which
        propagates through ``np.mean``.
        """
        cm = confusion_matrix(y_true, y_pred)
        num_classes = cm.shape[0]
        specificities = []

        for i in range(num_classes):
            tn = np.sum(np.delete(np.delete(cm, i, axis=0), i, axis=1))
            fp = np.sum(cm[:, i]) - cm[i, i]
            specificity = tn / (tn + fp) if (tn + fp) > 0 else np.nan
            specificities.append(specificity)
        return np.mean(specificities)

    def get_specificity(self):
        """Return binary or macro specificity, depending on ``average_method``."""
        if self.average_method == "binary":
            return self.binary_specificity(self.labels, self.predictions)
        return self.multiclass_specificity(self.labels, self.predictions)

    def get_overall_result(self) -> dict:
        """Return accuracy, precision, recall, F1 and specificity, each rounded to 4 decimals."""
        return {
            'accuracy': round(self.get_accuracy(), 4),
            'precision': round(self.get_precision(), 4),
            'recall': round(self.get_recall(), 4),
            'f1_score': round(self.get_f1_score(), 4),
            'specificity': round(self.get_specificity(), 4)
        }


# Prepare Data

In [None]:
# Token length passed to the tokenizer. BiomedCLIP-PubMedBERT_256 is configured
# with a 256-token text context (its model card tokenizes with
# context_length = 256), so tokenize to 256 rather than 512.
context_length = 256

# Abnormality class id -> class name.
ABNORMALITY_LABEL_MAP = {
                    0 : "normal",
                    1: "mtl_atrophy",
                    2: "wmh",
                    3: "other_atrophy"
                }


# One text prompt per abnormality class. Keep the entries in the same order as
# the ids above: zero-shot prediction index i is interpreted as class i.
abnormality_description = {
    "normal": "MRI image shows normal brain structures without evidence of significant abnormalities or pathological changes.",
    "mtl_atrophy": "MRI image illustrates volume reduction and structural atrophy in the medial temporal lobes, including hippocampal shrinkage.",
    "wmh": "MRI image reveals hyperintense lesions within cerebral white matter regions, indicating white matter hyperintensities.",
    "other_atrophy": "MRI image indicates brain atrophy in cortical or subcortical regions other than medial temporal lobes, with notable structural volume loss."
}


## Prepare MINDSet

In [None]:
# Location of the MINDSet CSV (placeholder to fill in).
path_to_MINDSet = "to be replaced"
MINDSet = prepare_MINDSet(path_to_MINDSet=path_to_MINDSet)
MINDSet.head(n=3)

# Data preprocessing

In [None]:
# Preprocess every MINDSet image into one stacked tensor (row i = MINDSet row i).
MINDSet_images = get_preprocess_images_from_paths(
    image_paths=MINDSet["img_path"].tolist(),
    preprocess=preprocess,
)
MINDSet_images.shape

In [None]:
# One prompt per abnormality class, ordered explicitly by class id so that
# token row i (and therefore zero-shot prediction i) always means class i,
# independent of the dict's insertion order.
abnormality_prompts = [abnormality_description[ABNORMALITY_LABEL_MAP[i]] for i in sorted(ABNORMALITY_LABEL_MAP)]
abnormality_texts = tokenizer(abnormality_prompts, context_length)
abnormality_texts.shape

# Zero-shot Diagnosis 

In [None]:
# Zero-shot evaluation uses the rows from position 120 onward. Positional
# slicing (.iloc) keeps this aligned with the positional tensor slice
# MINDSet_images[120:], even if MINDSet's index is not a 0..N-1 RangeIndex.
zeroshot_df = MINDSet[['img_path', "binary_dementia_label"]].iloc[120:].copy().reset_index(drop=True)
zeroshot_df.head(3)

## Binary Dementia Classification on MINDSet

In [None]:
# Run on the device selected when the model was loaded, rather than relying
# on the function's default device (which may be unavailable on this machine).
abnormality_predictions, abnormality_logits = get_zeroshot_predictions(
    model, MINDSet_images[120:], abnormality_texts, device=device
)

In [None]:
# Collapse the abnormality prediction to binary: prompt 0 ("normal") -> 0,
# any abnormality prompt -> 1.
_predicted_normal = abnormality_predictions == 0
zeroshot_df["binary_dementia_prediction"] = np.where(_predicted_normal, 0, 1)
# Score built from abnormality_logits (similarity scores, not softmax
# probabilities): 1 - score of the "normal" prompt when predicted normal,
# otherwise the highest prompt score; capped just below 1.
# NOTE(review): the column name is misspelled ("probbability"); it is kept
# unchanged in case later cells read it.
_dementia_score = np.where(_predicted_normal, 1 - abnormality_logits[:, 0], abnormality_logits.max(axis=-1))
zeroshot_df["binary_dementia_probbability"] = np.minimum(_dementia_score, 1 - 1e-9)
zeroshot_df.head(3)

In [None]:
# Score the binary (dementia vs. no dementia) zero-shot predictions.
binary_dementia_metric = EvalMetric(
    labels=zeroshot_df["binary_dementia_label"],
    predictions=zeroshot_df["binary_dementia_prediction"],
)
binary_dementia_metric.get_overall_result()