# Imports

In [None]:
# CLIP
from open_clip import create_model_from_pretrained, get_tokenizer

# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

# HuggingFace
from huggingface_hub import hf_hub_download
from datasets import load_dataset
from datasets.dataset_dict import DatasetDict
from datasets import arrow_dataset

# Metrics
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score, roc_auc_score)
from sklearn.preprocessing import label_binarize


# Others
import pandas as pd
import numpy as np
from PIL import Image

# Helper functions

## Remember Model

In [None]:
class Remember(nn.Module):
    """Evidence-attention classifier.

    Each retrieved evidence vector is projected down to ``embed_dim``; the
    query image embedding attends over the projected evidence, and the
    concatenation ``[query, attended evidence]`` is classified by a small MLP.

    Args:
        out: Number of output logits.
        embed_dim: Width of the query embedding and of each evidence component.
            Defaults to 512, the value previously hard-coded.
        num_sources: Number of embeddings concatenated per evidence item
            (defaults to 4). Each evidence row is therefore
            ``embed_dim * num_sources * 2`` wide; the factor 2 covers the raw
            and the similarity-weighted copy built by ``EvidenceDataset``.
    """

    def __init__(self, out, embed_dim=512, num_sources=4):
        super().__init__()
        self.mlp = nn.Linear(embed_dim * num_sources * 2, embed_dim)
        self.predict = nn.Sequential(
            torch.nn.Linear(embed_dim * 2, embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embed_dim, out),
        )

    def forward(self, img_q, e_ast):
        """Compute class logits.

        Args:
            img_q: Query embeddings, shape ``(batch, embed_dim)``.
            e_ast: Evidence tensor, shape
                ``(batch, k, embed_dim * num_sources * 2)``.

        Returns:
            Logits of shape ``(batch, out)``.
        """
        img_q = img_q.to(torch.float32)
        e_ast = e_ast.to(torch.float32)

        # Project every evidence item to the query width: (batch, k, embed_dim).
        e_bar = self.mlp(e_ast)
        e_bar_n = e_bar / e_bar.norm(dim=-1, keepdim=True)

        # One score per evidence item: <normalized evidence, query> -> (batch, k).
        scores = torch.bmm(e_bar_n, img_q.unsqueeze(-1)).squeeze(-1)
        attn = torch.nn.functional.softmax(scores, dim=-1)

        # Attention-weighted sum of the (unnormalized) projected evidence: (batch, embed_dim).
        evidence_agg = torch.bmm(attn.unsqueeze(1), e_bar).squeeze(1)

        return self.predict(torch.cat([img_q, evidence_agg], dim=-1))

        
        

## Evidence Dataset

In [None]:
class EvidenceDataset(Dataset):
    """Dataset pairing each query image embedding with its top-k retrieved evidence.

    Each row of ``df`` must provide ``img_path`` (unless ``images`` is given),
    a ``label`` column, and for every ``i`` in ``range(top_k)`` the columns
    ``i_{i}``, ``abnormality_{i}``, ``dementia_{i}``, ``description_{i}``
    (keys into ``all_dicts["path"]``, ``["abnormality"]``, ``["dementia"]``,
    ``["description"]`` respectively) and ``sim_{i}`` (a similarity score).

    Args:
        df: Evidence table, one row per query.
        all_dicts: Lookup tables from the values above to 1-D embedding tensors.
        top_k: Number of evidence items used per query.
        images: Optional precomputed query embeddings indexed by row position;
            when ``None`` the query embedding is ``all_dicts["path"][img_path]``.
    """

    def __init__(self, df: pd.DataFrame, all_dicts: dict, top_k=3, images: "torch.Tensor | None" = None):
        self.df = df
        self.all_dicts = all_dicts
        self.top_k = top_k
        self.images = images
        self.label = torch.from_numpy(df["label"].values)

    def __len__(self):
        return len(self.df)

    def _stack_evidence(self, sel, dict_key, column_prefix):
        """Stack ``all_dicts[dict_key][sel[f"{column_prefix}_{i}"]]`` for ``i < top_k``."""
        lookup = self.all_dicts[dict_key]
        return torch.stack([lookup[sel[f"{column_prefix}_{i}"]] for i in range(self.top_k)])

    def __getitem__(self, idx):
        """Return ``(img_q, e_ast, label)`` for row ``idx``.

        ``e_ast`` has one row per evidence item: the concatenation of its image,
        abnormality, dementia and description vectors, followed by that same
        concatenation scaled by the item's ``sim_{i}`` score.
        """
        sel = self.df.iloc[idx]

        if self.images is not None:
            img_q = self.images[idx]
        else:
            img_q = self.all_dicts["path"][sel["img_path"]]

        e = torch.cat(
            [
                self._stack_evidence(sel, "path", "i"),
                self._stack_evidence(sel, "abnormality", "abnormality"),
                self._stack_evidence(sel, "dementia", "dementia"),
                self._stack_evidence(sel, "description", "description"),
            ],
            dim=-1,
        )
        sims = torch.tensor([sel[f"sim_{i}"] for i in range(self.top_k)])
        e_sim = e * sims.unsqueeze(-1)
        e_ast = torch.cat([e, e_sim], dim=-1)
        return img_q, e_ast, self.label[idx]

In [None]:
def get_preprocess_images_from_paths(image_paths: list[str], preprocess: transforms.Compose):
    """Load each image file, apply ``preprocess`` and stack the results.

    Args:
        image_paths: Iterable of image file paths.
        preprocess: Transform mapping a PIL image to a tensor.

    Returns:
        A tensor stacking the preprocessed images along a new dim 0.
    """
    processed = []
    for path in image_paths:
        # Use a context manager so every file handle is closed right after its
        # image has been transformed, instead of lingering until garbage collection.
        with Image.open(path) as img:
            processed.append(preprocess(img))
    return torch.stack(processed, dim=0)

def get_preprocess_images_from_public_data(ds: arrow_dataset.Dataset, preprocess: transforms.Compose):
    """Apply ``preprocess`` to every entry of ``ds["image"]``.

    Returns:
        The transformed images stacked along a new leading dimension.
    """
    processed = list(map(preprocess, ds["image"]))
    return torch.stack(processed, dim=0)

def prepare_MINDSet(path_to_MINDSet: str, image_root: str = "vista/data/images/") -> pd.DataFrame:
    """Load the MINDSet CSV and derive label and description columns.

    Args:
        path_to_MINDSet: Path to a CSV with at least ``label`` (dementia code
            0/1/2), ``img_path`` and ``abnormal_type`` columns.
        image_root: Prefix prepended to every ``img_path`` value. Defaults to the
            previously hard-coded ``"vista/data/images/"``.

    Returns:
        The table with ``label`` renamed to ``dementia_label``, ``img_path``
        prefixed with ``image_root``, and new columns ``dementia_type``,
        ``binary_dementia_label`` (0 for code 0, else 1), ``abnormality_label``,
        ``abnormality_description`` and ``dementia_description``.

    Raises:
        KeyError: if an ``abnormal_type`` (or its first comma-separated part) is
            not one of the known abnormality names.
    """
    DEMENTIA_LABEL_MAP = {
        0: "no_dementia",
        1: "other_dementia",
        2: "AD",
    }
    ABNORMALITY_LABEL_MAP = {
        0: "normal",
        1: "mtl_atrophy",
        2: "wmh",
        3: "other_atrophy",
    }
    # Descriptions keyed by the raw ``abnormal_type`` string, including combined types.
    abnormality_description = {
        "normal": "MRI image shows normal brain structures without evidence of significant abnormalities or pathological changes.",
        "mtl_atrophy": "MRI image illustrates volume reduction and structural atrophy in the medial temporal lobes, including hippocampal shrinkage.",
        "wmh": "MRI image reveals hyperintense lesions within cerebral white matter regions, indicating white matter hyperintensities.",
        "mtl_atrophy,other_atrophy": "MRI image illustrates volume reduction and structural atrophy in the medial temporal lobes, including hippocampal shrinkage. MRI image indicates brain atrophy in cortical or subcortical regions other than medial temporal lobes, with notable structural volume loss.",
        "mtl_atrophy,wmh": "MRI image illustrates volume reduction and structural atrophy in the medial temporal lobes, including hippocampal shrinkage. MRI image reveals hyperintense lesions within cerebral white matter regions, indicating white matter hyperintensities.",
        "wmh,other_atrophy": "MRI image reveals hyperintense lesions within cerebral white matter regions, indicating white matter hyperintensities. MRI image indicates brain atrophy in cortical or subcortical regions other than medial temporal lobes, with notable structural volume loss.",
        "other_atrophy": "MRI image indicates brain atrophy in cortical or subcortical regions other than medial temporal lobes, with notable structural volume loss.",
    }
    dementia_description = {
        "no_dementia": "MRI image presents no evident dementia-related structural changes, reflecting a normal cognitive state.",
        "AD": "MRI image shows characteristic patterns of brain atrophy suggestive of Alzheimer's Disease pathology.",
        "other_dementia": "MRI image shows structural brain abnormalities indicative of dementia types other than Alzheimer's Disease, such as Vascular dementia or Dementia with Lewy bodies.",
    }

    df = pd.read_csv(path_to_MINDSet)
    df = df.rename(columns={"label": "dementia_label"})

    df["dementia_type"] = df["dementia_label"].map(DEMENTIA_LABEL_MAP)
    df["binary_dementia_label"] = np.where(df["dementia_label"] == 0, 0, 1)
    df["img_path"] = image_root + df["img_path"].astype(str)

    abnormality_to_label = {name: code for code, name in ABNORMALITY_LABEL_MAP.items()}
    # Combined types such as "mtl_atrophy,wmh" take the label of their first
    # component; single types are unaffected because they contain no comma.
    df["abnormality_label"] = df["abnormal_type"].map(
        lambda abnormal_type: abnormality_to_label[abnormal_type.split(",")[0]]
    )

    df["abnormality_description"] = df["abnormal_type"].map(abnormality_description)
    df["dementia_description"] = df["dementia_type"].map(dementia_description)
    return df

In [None]:
def get_model(model_name: str = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'):
    """Load an open_clip model together with its preprocessing transform and tokenizer.

    Args:
        model_name: open_clip model identifier. Defaults to the BiomedCLIP
            checkpoint on the Hugging Face Hub used throughout this notebook.

    Returns:
        ``(base_model, preprocess, tokenizer)``.
    """
    from open_clip import create_model_from_pretrained, get_tokenizer

    # Use one identifier for both calls so model and tokenizer cannot drift apart.
    base_model, preprocess = create_model_from_pretrained(model_name)
    tokenizer = get_tokenizer(model_name)
    return base_model, preprocess, tokenizer

def load_state_dict(model, model_name, model_file):
    """Download a checkpoint from ``henrytang785/VisTA_<model_name>`` and load it into ``model``.

    The file ``model_file`` is fetched with ``hf_hub_download``, read onto the CPU
    with ``weights_only=True``, and its ``"model_state_dict"`` entry is loaded.

    Returns:
        The same ``model`` object, now holding the downloaded weights.
    """
    from huggingface_hub import hf_hub_download

    checkpoint_path = hf_hub_download(
        repo_id=f"henrytang785/VisTA_{model_name}",
        filename=model_file,
    )
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model

def get_image_embeddings(model, images, batch_size=16, device="cuda"):
    """Encode ``images`` in batches with ``model.encode_image`` and return CPU features.

    NOTE(review): ``get_image_embeddings`` is defined a second time further down
    in this file; that later definition is the one in effect at runtime.

    Args:
        model: Module exposing ``encode_image``; it is put in eval mode, moved to
            ``device`` for encoding and moved back to the CPU before returning.
        images: Indexable collection of preprocessed image tensors.
        batch_size: Number of images encoded per forward pass.
        device: Device used for encoding.

    Returns:
        The features of all images concatenated along dim 0, on the CPU.
    """
    model.eval()
    model.to(device)

    loader = DataLoader(images, batch_size=batch_size, shuffle=False)
    feature_batches = []
    with torch.no_grad():
        for batch in loader:
            # Encode on ``device`` and keep only a CPU copy of the features.
            feature_batches.append(model.encode_image(batch.to(device)).cpu())

    # Park the model back on the CPU once encoding is done.
    model.to("cpu")
    return torch.cat(feature_batches)

## Training functions

In [None]:
def get_label(model: nn.Module, data, batch_size, device):
    """Run ``model`` over ``data`` and collect true labels, predictions and probabilities.

    Args:
        model: Called as ``model(img_q, e_ast)``; must return logits of shape
            ``(batch, num_classes)``. It is put in eval mode and left on ``device``.
        data: Dataset yielding ``(img_q, e_ast, label)`` triples.
        batch_size: Inference batch size.
        device: Device the model and inputs are moved to.

    Returns:
        ``(all_labels, predicted_labels, all_probs)`` — lists with one numpy entry
        per sample: the true label, the argmax prediction and the softmax
        probability vector.
    """
    model.eval()
    model.to(device)
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)

    all_labels = []
    predicted_labels = []
    all_probs = []

    with torch.no_grad():
        for img_q, e_ast, label in dataloader:
            logits = model(img_q.to(device), e_ast.to(device))
            # Fully-qualified call: the file's import block does not define ``F``.
            probs = torch.nn.functional.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)

            all_labels.extend(label.cpu().numpy())
            predicted_labels.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    return all_labels, predicted_labels, all_probs

def get_image_embeddings(model: nn.Module, images: torch.Tensor, batch_size=16, device="cuda"):
    """Encode ``images`` in batches with ``model.encode_image`` and return CPU features.

    Args:
        model: Module exposing ``encode_image``; it is put in eval mode, moved to
            ``device`` for encoding and moved back to the CPU before returning.
        images: Preprocessed images, batched along dim 0.
        batch_size: Number of images encoded per forward pass.
        device: Device used for encoding.

    Returns:
        The features of all images concatenated along dim 0, on the CPU.
    """
    model.eval()
    model.to(device)

    loader = DataLoader(images, batch_size=batch_size, shuffle=False)
    feature_batches = []
    with torch.no_grad():
        for batch in loader:
            # Encode on ``device`` and keep only a CPU copy of the features.
            feature_batches.append(model.encode_image(batch.to(device)).cpu())

    # Park the model back on the CPU once encoding is done.
    model.to("cpu")
    return torch.cat(feature_batches)

def normalize_embeddings(embeddings):
    """Return ``embeddings`` L2-normalized along the last dimension.

    Uses ``torch.nn.functional.normalize`` (the file never imports ``F``). Its
    default ``eps`` clamps the norm from below, so all-zero rows stay zero
    rather than becoming NaN.
    """
    return torch.nn.functional.normalize(embeddings, p=2, dim=-1)

def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: Called as ``model(img_q, e_ast)``; returns logits ``(batch, num_classes)``.
        dataloader: Yields ``(img_q, e_ast, label)`` batches.
        loss_fn: Loss taking ``(logits, label)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device inputs are moved to.

    Returns:
        ``(mean_batch_loss, accuracy, correct)`` where ``mean_batch_loss`` is the
        mean of the per-batch losses and ``accuracy = correct / len(dataset)``.
        An empty loader returns ``(0.0, 0.0, 0)`` instead of dividing by zero.
    """
    model.train()
    total_loss = 0.0
    correct = 0

    for img_q, e_ast, label in dataloader:
        img_q = img_q.to(device)
        e_ast = e_ast.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        pred = model(img_q, e_ast)
        loss = loss_fn(pred, label)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Accuracy is measured on the logits computed before this step's update.
        correct += (pred.detach().argmax(dim=1) == label).sum().item()

    num_batches = len(dataloader)
    num_samples = len(dataloader.dataset)
    if num_batches == 0 or num_samples == 0:
        return 0.0, 0.0, 0
    return total_loss / num_batches, correct / num_samples, correct

def train_loop(params):
    """Train a model for a fixed number of epochs, log metrics and save the final weights.

    Recognized ``params`` keys (defaults in brackets):
      * required: ``model``, ``train_data``, ``test_data``, ``loss_fn``, ``optimizer``
      * optional: ``num_epoch`` [100], ``save_path`` ["checkpoints"],
        ``device`` ["cuda"], ``num_workers`` [4], ``log_file`` ["log.csv"],
        ``batch_size`` [32].
    Any other key (e.g. ``patience``) is ignored: there is no early stopping and
    no best-model selection.

    ``train_data`` and ``test_data`` are wrapped in new ``DataLoader`` objects
    (train shuffled, test not), so pass datasets rather than loaders.

    ``log_file`` is rewritten with a CSV header, then one row is appended per
    epoch. After the last epoch ``{"epoch": num_epoch, "model_state_dict": ...}``
    is saved to ``<save_path>/model_epoch_<num_epoch>.pt``.

    Returns:
        ``num_epoch``, the number of epochs run.
    """
    import os

    model = params["model"]
    train_data = params["train_data"]
    test_data = params["test_data"]
    num_epoch = params.get("num_epoch", 100)
    loss_fn = params["loss_fn"]
    optimizer = params["optimizer"]
    save_path = params.get("save_path", "checkpoints")
    device = params.get("device", "cuda")
    num_workers = params.get("num_workers", 4)
    log_file = params.get("log_file", "log.csv")
    batch_size = params.get("batch_size", 32)

    model.to(device)
    os.makedirs(save_path, exist_ok=True)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    with open(log_file, "w") as f:
        f.write("Epoch,Train Loss,Train Accuracy,Train Correct,Val Loss,Val Accuracy,Val Correct\n")

    for epoch in range(num_epoch):
        print("*" * 100)
        print("Epoch", epoch)
        train_loss, train_accuracy, train_correct = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
        print(f"Train Loss: {train_loss}, Train Accuracy: {train_accuracy}, Correct: {train_correct}")
        val_loss, val_accuracy, val_correct = evaluate(model, test_loader, loss_fn, device)
        print(f"Val Loss: {val_loss}, Val Accuracy: {val_accuracy}, Correct: {val_correct}")

        # Re-open in append mode each epoch so completed rows are flushed to disk
        # even if training is interrupted later.
        with open(log_file, "a") as f:
            f.write(f"{epoch},{train_loss},{train_accuracy},{train_correct},{val_loss},{val_accuracy},{val_correct}\n")

    checkpoint = {
        "epoch": num_epoch,
        "model_state_dict": model.state_dict(),
    }
    model_filename = f"{save_path}/model_epoch_{num_epoch}.pt"
    torch.save(checkpoint, model_filename)
    # Report num_epoch, not the loop variable: ``epoch`` is unbound when num_epoch == 0
    # and would otherwise be one less than the epoch in the file name.
    print(f"Model saved after {num_epoch} epochs to {model_filename}")
    return num_epoch



def evaluate(model, dataloader, loss_fn, device):
    """Compute loss and accuracy of ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Called as ``model(img_q, e_ast)``; returns logits ``(batch, num_classes)``.
            It is put in eval mode and moved to ``device``.
        dataloader: Yields ``(img_q, e_ast, label)`` batches.
        loss_fn: Loss taking ``(logits, label)``.
        device: Device the model and inputs are moved to.

    Returns:
        ``(mean_batch_loss, accuracy, correct)``; an empty loader returns
        ``(0.0, 0.0, 0)`` instead of dividing by zero.
    """
    model.eval()
    model.to(device)
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for img_q, e_ast, label in dataloader:
            img_q = img_q.to(device)
            e_ast = e_ast.to(device)
            label = label.to(device)

            pred = model(img_q, e_ast)
            total_loss += loss_fn(pred, label).item()
            total += label.size(0)
            correct += (pred.argmax(dim=1) == label).sum().item()

    if total == 0:
        return 0.0, 0.0, 0
    return total_loss / len(dataloader), correct / total, correct

## Eval Metrics

In [None]:
class EvalMetric:
    """Classification metrics for paired true labels and predictions.

    The task is treated as binary when ``labels`` contains at most two distinct
    values (scikit-learn ``average='binary'``), otherwise as multiclass with
    macro averaging.
    """

    def __init__(self, labels: pd.Series, predictions: pd.Series):
        # Any array-like works here (e.g. the lists returned by ``get_label``).
        self.labels = labels
        self.predictions = predictions
        # Binary vs multiclass is decided from the true labels only.
        self.n_classes = len(np.unique(labels))
        assert self.n_classes > 1
        self.average_method = 'binary' if self.n_classes <= 2 else 'macro'

    def get_accuracy(self) -> float:
        return accuracy_score(self.labels, self.predictions)

    def get_precision(self) -> float:
        return precision_score(self.labels, self.predictions, average=self.average_method, zero_division=0)

    def get_recall(self) -> float:
        return recall_score(self.labels, self.predictions, average=self.average_method, zero_division=0)

    def get_f1_score(self) -> float:
        return f1_score(self.labels, self.predictions, average=self.average_method, zero_division=0)

    def binary_specificity(self, y_true, y_pred):
        """Specificity TN / (TN + FP) for binary classification; NaN when undefined.

        The negative class is the first (smallest) label of the confusion matrix.
        """
        cm = confusion_matrix(y_true, y_pred)
        tn = cm[0, 0]
        # Every prediction of a non-negative class for a negative sample is a false
        # positive. For a 2x2 matrix this is cm[0, 1]; unlike unpacking ``ravel()``
        # it also works if ``y_pred`` contains a label absent from ``y_true``.
        fp = cm[0].sum() - tn
        return tn / (tn + fp) if (tn + fp) > 0 else np.nan

    def multiclass_specificity(self, y_true, y_pred):
        """Macro-averaged one-vs-rest specificity for multiclass classification."""
        cm = confusion_matrix(y_true, y_pred)
        specificities = []
        for i in range(cm.shape[0]):
            # True negatives: everything outside row i and column i.
            tn = np.sum(np.delete(np.delete(cm, i, axis=0), i, axis=1))
            # False positives: predicted i but truly another class.
            fp = np.sum(cm[:, i]) - cm[i, i]
            specificities.append(tn / (tn + fp) if (tn + fp) > 0 else np.nan)
        return np.mean(specificities)

    def get_specificity(self):
        """Return binary or macro multiclass specificity depending on the task."""
        if self.average_method == "binary":
            return self.binary_specificity(self.labels, self.predictions)
        return self.multiclass_specificity(self.labels, self.predictions)

    def get_overall_result(self) -> dict:
        """Return accuracy, precision, recall, F1 and specificity rounded to 4 decimals."""
        return {
            'accuracy': round(self.get_accuracy(), 4),
            'precision': round(self.get_precision(), 4),
            'recall': round(self.get_recall(), 4),
            'f1_score': round(self.get_f1_score(), 4),
            'specificity': round(self.get_specificity(), 4)
        }


# Load Model

In [None]:
# BiomedCLIP backbone, its image preprocessing transform and text tokenizer,
# loaded through the shared helper so the hub id lives in one place.
model, preprocess, tokenizer = get_model()

# Fine-tuned backbone weights; both placeholders must be filled in before running.
repo_id = "to be replaced"
model_file = "to be replaced"
model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
state = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state["model_state_dict"])
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Running on", device)

# Prepare dataset

## Prepare MINDSet

In [None]:
# Path to the MINDSet CSV; the placeholder must be filled in before running.
path_to_MINDSet = "to be replaced"
MINDSet = prepare_MINDSet(path_to_MINDSet = path_to_MINDSet)
# Fixed positional split: the first 120 rows train, the remainder test.
# NOTE(review): there is no shuffling, so the split depends on CSV row order.
# reset_index(drop=True) gives each split a fresh 0..n-1 index; later cells pair
# rows with embeddings and evidence tables through that index.
mindset_train = MINDSet.iloc[:120].copy().reset_index(drop=True)
mindset_test = MINDSet.iloc[120:].copy().reset_index(drop=True)

print("Train:", len(mindset_train), "Test:", len(mindset_test))

# Get evidence

In [None]:
# Retrieved top-k evidence for the training split (placeholder path).
ms_train_evidence_path = "to be replaced"
# The CSV carries an extra "Unnamed: 0" column (presumably a saved index); drop it.
evidences_mindset_train = pd.read_csv(ms_train_evidence_path).drop(columns="Unnamed: 0")
evidences_mindset_train.head(3)

In [None]:
# Retrieved top-k evidence for the test split (placeholder path).
ms_test_evidence_path = "to be replaced"
# The CSV carries an extra "Unnamed: 0" column (presumably a saved index); drop it.
evidences_mindset_test = pd.read_csv(ms_test_evidence_path).drop(columns="Unnamed: 0")
evidences_mindset_test.head(3)

# Generate embeddings

In [None]:
# Preprocess
# Load every MINDSet image from disk and apply the BiomedCLIP preprocessing transform.
mindset_images_train_raw =  get_preprocess_images_from_paths(mindset_train["img_path"].values, preprocess)
mindset_images_test_raw =  get_preprocess_images_from_paths(mindset_test["img_path"].values, preprocess)

# Get embeddings
# Encode on ``device``; get_image_embeddings returns CPU tensors and moves the
# backbone back to the CPU afterwards.
mindset_images_train_embs = get_image_embeddings(model, mindset_images_train_raw, device=device)
mindset_images_test_embs = get_image_embeddings(model, mindset_images_test_raw, device=device)

# Normalize
# L2-normalize each embedding row so dot products between embeddings are cosine similarities.
mindset_images_train_embs = normalize_embeddings(mindset_images_train_embs)
mindset_images_test_embs = normalize_embeddings(mindset_images_test_embs)

In [None]:
# Map each image path to its normalized image embedding, for both splits.
# Rows and embeddings are paired positionally: each split was reset to a
# 0..n-1 index, and the embeddings were computed in the same row order.
path_embeddings_map = dict(zip(mindset_train["img_path"], mindset_images_train_embs))
path_embeddings_map.update(zip(mindset_test["img_path"], mindset_images_test_embs))

In [None]:
# Build lookup tables from each distinct text value to a vector, used as the
# "description", "abnormality" and "dementia" evidence features by EvidenceDataset.
# NOTE(review): ``tokenizer(...)`` presumably returns integer token IDs (hence the
# float32 cast); these vectors are L2-normalized token-ID sequences, and the
# backbone's text encoder is never called. Confirm this is intended rather than
# e.g. ``model.encode_text(...)``.
# NOTE(review): context_length=512 presumably yields 512-long vectors, matching the
# 512-wide evidence components Remember expects — verify against the tokenizer.
# NOTE(review): prepare_MINDSet does not create a "description" column (it adds
# abnormality_description / dementia_description); it must already be in the CSV.
all_des = MINDSet["description"].unique()
all_des_tok = tokenizer(all_des, context_length=512).to(torch.float32)
all_des_tok = [normalize_embeddings(tok) for tok in all_des_tok]
description_embeddings_map = { all_des[i] : all_des_tok[i] for i in range(len(all_des)) }

# Keys are the raw abnormal_type strings, including combined types like "mtl_atrophy,wmh".
all_ab = MINDSet["abnormal_type"].unique()
all_ab_tok = tokenizer(all_ab, context_length=512).to(torch.float32)
all_ab_tok = [normalize_embeddings(tok) for tok in all_ab_tok]
abnormality_embeddings_map = { all_ab[i] : all_ab_tok[i] for i in range(len(all_ab)) }

# Keys are dementia_type strings ("no_dementia", "other_dementia", "AD").
all_dem = MINDSet["dementia_type"].unique()
all_dem_tok = tokenizer(all_dem, context_length=512).to(torch.float32)
all_dem_tok = [normalize_embeddings(tok) for tok in all_dem_tok]
dementia_embeddings_map = { all_dem[i] : all_dem_tok[i] for i in range(len(all_dem)) }


In [None]:
# Every lookup table EvidenceDataset needs, keyed by evidence kind.
all_dicts = dict(
    path=path_embeddings_map,
    abnormality=abnormality_embeddings_map,
    dementia=dementia_embeddings_map,
    description=description_embeddings_map,
)

# Training process

## Hyperparameters

In [None]:
# Number of epochs train_loop runs (it always runs all of them).
num_epoch = 100
# Adam learning rate.
lr = 5e-5
# Batch size for both training and validation loaders.
batch_size = 4
# Passed to train_loop via params, but train_loop never reads it (no early stopping).
patience = 5
# DataLoader worker processes.
num_workers = 4

## Binary Dementia Classification

In [None]:
# Binary dementia task: 0 = no dementia, 1 = any dementia code.
label_column = "dementia"
# Assign positionally (.to_numpy()) so a row-count mismatch between an evidence
# table and its MINDSet split raises, instead of index alignment silently
# producing NaN labels that the binarization would turn into 1.
# int64 keeps labels compatible with CrossEntropyLoss on every platform.
evidences_mindset_train["label"] = (mindset_train[f"{label_column}_label"].to_numpy() != 0).astype("int64")
evidences_mindset_test["label"] = (mindset_test[f"{label_column}_label"].to_numpy() != 0).astype("int64")

In [None]:
# Datasets yielding (query image embedding, evidence tensor, label) with the top 3
# evidence items per query. No ``images`` tensor is passed, so query embeddings
# are looked up in all_dicts["path"] by img_path.
train_data = EvidenceDataset(evidences_mindset_train , all_dicts, top_k=3)
test_data = EvidenceDataset(evidences_mindset_test, all_dicts, top_k = 3)
len(train_data), len(test_data)

In [None]:
# Evidence-attention classifier with 2 output logits, optimized with Adam.
m = Remember(2)
optimizer = torch.optim.Adam(m.parameters(), lr=lr)

In [None]:
# Configuration consumed by train_loop.
params = {
    "model": m,  # Remember instance to train
    "train_data": train_data,  # a Dataset; train_loop wraps it in its own shuffled DataLoader
    "test_data": test_data,    # a Dataset used for per-epoch validation
    "num_epoch": num_epoch,              # total number of training epochs
    "loss_fn": torch.nn.CrossEntropyLoss(),  # or any other loss function
    "optimizer": optimizer,  # or another optimizer
    "save_path": "bin_dem",   # directory where the final-epoch checkpoint is written
    "patience": patience,                # NOTE: ignored by train_loop (no early stopping)
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "num_workers": num_workers,             # number of workers for data loading
    "log_file": "bin_dem_log.csv",        # per-epoch train/validation metrics CSV
    "batch_size": batch_size              # batch size for training and validation
}

In [None]:
# train_loop returns the number of epochs run (not a model or a score); the trained
# weights stay in ``m`` and are also saved under params["save_path"].
best_dementia_bin = train_loop(params)

# Evaluate

In [None]:
# Load the trained classifier onto the CPU first: train_loop saves the state dict
# from whatever device training used (e.g. CUDA), which may be unavailable here.
checkpoint = torch.load("to be replaced", map_location="cpu", weights_only=True)
m.load_state_dict(checkpoint["model_state_dict"])
labels, predictions, _ = get_label(m, test_data, 16, device)
metric = EvalMetric(labels, predictions)
pd.DataFrame([metric.get_overall_result()])