# Import libraries

In [1]:
# CLIP
from open_clip import create_model_from_pretrained, get_tokenizer

# Torch
import torch

# Metrics
from sklearn.metrics import confusion_matrix, roc_auc_score, precision_score, recall_score, f1_score, accuracy_score
from imblearn.metrics import specificity_score

# FS
import os
import io

# Others
from PIL import Image
import pandas as pd
import numpy as np
from typing import List, Dict

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [2]:
DATA_PATH = "../data/"
DESC_PATH = os.path.join(DATA_PATH, "data_description.csv")
IMG_PATH = os.path.join(DATA_PATH, "images")

# Helpers

In [3]:
template = 'This is a photo of '
context_length = 512
SIM_THRESHOLD = 0.6
PROB_THRESHOLD = 50

ABNORMAL_TYPES = ['normal', 'mtl_atrophy', 'wmh', 'other_atrophy']

DEMENTIA_TYPES = {
    0: "no_dementia",
    1: "other_dementia",
    2: "AD"
}

def get_labels(model, preprocess, tokenizer, images: List[str], labels: List[str], device:str, top_k: int = -1, is_bytes: bool = False) -> Dict:
  # Get embeddings of images and texts
  model.eval()
  if not is_bytes:
    img_embs = torch.stack([preprocess(Image.open(os.path.join(IMG_PATH, img))) for img in images]).to(device)
  else:
    img_embs = torch.stack([preprocess(Image.open(io.BytesIO(img))) for img in images]).to(device)

  print(f'Generated embeddings of {len(img_embs)} images.')
  text_embs = tokenizer([template + l for l in labels], context_length=context_length).to(device)
  print(f'Generated embeddings of {len(text_embs)} labels.')

  # Calculate similarity
  with torch.no_grad():
    image_features, text_features, logit_scale = model(img_embs, text_embs)

    logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
    sorted_indices = torch.argsort(logits, dim=-1, descending=True)

    logits = logits.cpu().numpy()
    sorted_indices = sorted_indices.cpu().numpy()

  image_results = []
  for i, img in enumerate(images):
      pred = labels[sorted_indices[i][0]]
      top_k = len(labels) if top_k == -1 else top_k
      image_result = dict()
      for j in range(top_k):
          jth_index = sorted_indices[i][j]
          label = labels[jth_index].replace(template, "").replace(".", "").strip()
          score = logits[i][jth_index]
          image_result[label] = score
      image_results.append(image_result)

  return image_results

def replace_label(text: str, mapping_dict: dict):
    for key, value in mapping_dict.items():
        if key in text:
          text = text.replace(key, value)
    return text

def add_label_with_score(data: pd.DataFrame, result:dict, mapping_dict:dict):
  clone_data = data.copy()
  clone_data = clone_data.reset_index(drop=True)
  for idx, label in enumerate(result):
    threshold = min(max(label.values()), SIM_THRESHOLD)
    for lb_value, lb_score in label.items():
      lb_code = mapping_dict.get(lb_value, lb_value)
      clone_data.loc[idx, f'sim_score_{lb_code}'] = label[lb_value]

      clone_data.loc[idx, f'is_predicted_{lb_code}'] = 1 if label[lb_value] >= threshold else 0
  return clone_data

def is_correct_abnormality(row):
  for abnormality in row['abnormal_type'].split(","):
    if row[f'is_predicted_{abnormality}'] == 1:
      return 1
  return 0

def is_correct_dementia(row):
  true_dementia = row['label_text']
  return 1 if row[f'is_predicted_{true_dementia}'] == 1 else 0

def get_dementia_prob(row, dementia, diagnosis_prob:dict):
  dementia_prob = []
  for abnormal_type in ABNORMAL_TYPES:
    if row[f'is_predicted_{abnormal_type}'] == 1:
      dementia_prob.append(diagnosis_prob[f'is_{abnormal_type}'].get(dementia, 0))
  return max(dementia_prob) if len(dementia_prob) > 0 else 0

def add_predicted_dementia(data:pd.DataFrame, diagnosis_prob:dict, dementia_dict:dict):
  clone_data = data.copy()
  for dementia in dementia_dict.keys():
    clone_data[f'prob_{dementia_dict[dementia]}'] = clone_data.apply(lambda row: get_dementia_prob(row, dementia, diagnosis_prob), axis=1)
    clone_data[f'is_predicted_{dementia_dict[dementia]}'] = (clone_data[f'prob_{dementia_dict[dementia]}'] >= PROB_THRESHOLD).astype(int)
  return clone_data

def get_count_values(data: pd.DataFrame, column_name: str, is_ascending: bool = False):
  counts = data[column_name].value_counts(ascending=is_ascending)
  percentage = (data[column_name].value_counts(normalize=True, ascending=is_ascending) * 100).round(2)
  return pd.concat([counts, percentage], axis=1)

class EvalMetric:
  def __init__(self, labels:pd.Series, scores:pd.Series, predictions:pd.Series):
    self.labels = labels
    self.scores = scores
    # self.predictions = (scores >= threshold).astype(int)
    self.predictions = predictions

  def get_accuracy(self) -> float:
    return accuracy_score(self.labels, self.predictions)

  def get_precision(self) -> float:
    return precision_score(self.labels, self.predictions)

  def get_recall(self) -> float:
    # This metric is also sensitivity
    return recall_score(self.labels, self.predictions)

  def get_f1_score(self) -> float:
    return f1_score(self.labels, self.predictions)

  def get_specificity(self) -> float:
    return specificity_score(self.labels, self.predictions)

  def get_auc_score(self) -> float:
    return roc_auc_score(self.labels, self.scores)

  def get_overall_result(self) -> dict:
    return {
        'precision': self.get_precision(),
        'recall': self.get_recall(),
        'f1_score': self.get_f1_score(),
        'specificity': self.get_specificity(),
        'auc': self.get_auc_score(),
        'accuracy': self.get_accuracy()
    }

def get_evaluation(data, label_col:str, score_col_prefix:str, label_list:list) -> dict:
  result_dict = dict()
  clone_data = data.copy()
  for label_value in label_list:
    clone_data[f'is_{label_value}'] = clone_data[label_col].map(lambda val: label_value in val).astype(int)
    result_dict[label_value] = EvalMetric(labels=clone_data[f'is_{label_value}'], scores=clone_data[f'{score_col_prefix}_{label_value}'], predictions=clone_data[f'is_predicted_{label_value}']).get_overall_result()
  return result_dict

# Load model

In [4]:
from huggingface_hub import hf_hub_download
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')

repo_id = "ddlinh/vista"
filename = "ft_model_abnormality_only__lr_1e-5__batch_size_8.pth"
model_path = hf_hub_download(repo_id=repo_id, filename=filename)
model.load_state_dict(torch.load(model_path, map_location="cpu"))

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
model.to(device)

  _torch_pytree._register_pytree_node(


Device: cpu


CustomTextCLIP(
  (visual): TimmModel(
    (trunk): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768

# Data Processing

In [5]:
data = pd.read_csv(DESC_PATH)
data['label_text'] = data['label'].map(DEMENTIA_TYPES)
data.head(3)

Unnamed: 0,description,label,abnormal_type,img_path,set,label_text
0,This coronal T1-weighted MRI image shows mildl...,2,mtl_atrophy,image_139.png,train,AD
1,This is a sagittal T1-weighted MRI image which...,1,other_atrophy,image_30.png,train,other_dementia
2,This coronal T1-weighted MRI image shows a mar...,2,mtl_atrophy,image_119.jpeg,train,AD


In [6]:
# Get Dementia probability of each abnormality type 
DIAGNOSIS_PROB = dict()

for abnormal_type in ABNORMAL_TYPES:
  column_name = f'is_{abnormal_type}'
  data[column_name] = data['abnormal_type'].apply(lambda x: abnormal_type in x).astype(int)
  abnorm_to_dementia = get_count_values(data.groupby(column_name), 'label').reset_index()
  DIAGNOSIS_PROB[column_name] = abnorm_to_dementia[abnorm_to_dementia[column_name] == 1][['label', 'proportion']].set_index('label')['proportion'].T.to_dict()

DIAGNOSIS_PROB

{'is_normal': {0: 100.0},
 'is_mtl_atrophy': {2: 81.01, 1: 18.99},
 'is_wmh': {1: 57.14, 2: 42.86},
 'is_other_atrophy': {1: 65.22, 2: 34.78}}

# Baseline Performance

In [7]:
column = 'abnormal_type'
mapping_dict = {
    'normal': 'brain on MRI, without signs of dementia',
    'mtl_atrophy': 'medial temporal lobe atrophy',
    'wmh': 'white matter hyperintensities',
    'other_atrophy': 'a type of brain atrophy',
}

template = "This is a photo of "

reversed_mapping_dict = {v:k for k, v in mapping_dict.items()}
labels = [template + label for label in reversed_mapping_dict.keys()]
labels

['This is a photo of brain on MRI, without signs of dementia',
 'This is a photo of medial temporal lobe atrophy',
 'This is a photo of white matter hyperintensities',
 'This is a photo of a type of brain atrophy']

## Generate embeddings

In [8]:
train_data, test_data = data.iloc[:120], data.iloc[120:]
train_images = train_data["img_path"].tolist()
test_images = test_data["img_path"].tolist()

train_image_labels = get_labels(model, preprocess, tokenizer, device=device, images=train_images, labels=labels, top_k=-1)
test_image_labels = get_labels(model, preprocess, tokenizer, device=device, images=test_images, labels=labels, top_k=-1)


Generated embeddings of 120 images.
Generated embeddings of 4 labels.
Generated embeddings of 50 images.
Generated embeddings of 4 labels.


## MINDSet - Train set

In [9]:
train_data_with_result = add_label_with_score(data=train_data, result=train_image_labels, mapping_dict=reversed_mapping_dict)
train_data_with_result.head(3)

Unnamed: 0,description,label,abnormal_type,img_path,set,label_text,is_normal,is_mtl_atrophy,is_wmh,is_other_atrophy,sim_score_mtl_atrophy,is_predicted_mtl_atrophy,sim_score_normal,is_predicted_normal,sim_score_wmh,is_predicted_wmh,sim_score_other_atrophy,is_predicted_other_atrophy
0,This coronal T1-weighted MRI image shows mildl...,2,mtl_atrophy,image_139.png,train,AD,0,1,0,0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,This is a sagittal T1-weighted MRI image which...,1,other_atrophy,image_30.png,train,other_dementia,0,0,0,1,0.0,0.0,5.3111749999999994e-34,0.0,6.2186489999999995e-30,0.0,1.0,1.0
2,This coronal T1-weighted MRI image shows a mar...,2,mtl_atrophy,image_119.jpeg,train,AD,0,1,0,0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


----
### Abnormality Retrieval

In [10]:
train_data_with_result['is_correct_abnormality'] = train_data_with_result.apply(lambda row: is_correct_abnormality(row), axis=1)
get_count_values(train_data_with_result, 'is_correct_abnormality')

Unnamed: 0_level_0,count,proportion
is_correct_abnormality,Unnamed: 1_level_1,Unnamed: 2_level_1
1,111,92.5
0,9,7.5


In [11]:
baseline_result = get_evaluation(train_data_with_result, label_col='abnormal_type', score_col_prefix='sim_score', label_list=ABNORMAL_TYPES)
for abnormal_type in baseline_result.keys():
  print(abnormal_type)
  print(baseline_result[abnormal_type])
  print('----')

normal
{'precision': 1.0, 'recall': 0.8260869565217391, 'f1_score': 0.9047619047619048, 'specificity': np.float64(1.0), 'auc': np.float64(0.9287315105333931), 'accuracy': 0.9666666666666667}
----
mtl_atrophy
{'precision': 0.9038461538461539, 'recall': 0.8703703703703703, 'f1_score': 0.8867924528301887, 'specificity': np.float64(0.9242424242424242), 'auc': np.float64(0.9281705948372616), 'accuracy': 0.9}
----
wmh
{'precision': 0.875, 'recall': 1.0, 'f1_score': 0.9333333333333333, 'specificity': np.float64(0.9811320754716981), 'auc': np.float64(0.9952830188679246), 'accuracy': 0.9833333333333333}
----
other_atrophy
{'precision': 0.9393939393939394, 'recall': 0.9393939393939394, 'f1_score': 0.9393939393939394, 'specificity': np.float64(0.9770114942528736), 'auc': np.float64(0.9531522117729014), 'accuracy': 0.9666666666666667}
----


----

### Dementia Prediction

In [12]:
DIAGNOSIS_PROB

{'is_normal': {0: 100.0},
 'is_mtl_atrophy': {2: 81.01, 1: 18.99},
 'is_wmh': {1: 57.14, 2: 42.86},
 'is_other_atrophy': {1: 65.22, 2: 34.78}}

In [13]:
data_with_predicted_dementia = add_predicted_dementia(train_data_with_result, diagnosis_prob=DIAGNOSIS_PROB, dementia_dict=DEMENTIA_TYPES)
data_with_predicted_dementia.head(3)

Unnamed: 0,description,label,abnormal_type,img_path,set,label_text,is_normal,is_mtl_atrophy,is_wmh,is_other_atrophy,...,is_predicted_wmh,sim_score_other_atrophy,is_predicted_other_atrophy,is_correct_abnormality,prob_no_dementia,is_predicted_no_dementia,prob_other_dementia,is_predicted_other_dementia,prob_AD,is_predicted_AD
0,This coronal T1-weighted MRI image shows mildl...,2,mtl_atrophy,image_139.png,train,AD,0,1,0,0,...,0.0,0.0,0.0,1,0.0,0,18.99,0,81.01,1
1,This is a sagittal T1-weighted MRI image which...,1,other_atrophy,image_30.png,train,other_dementia,0,0,0,1,...,0.0,1.0,1.0,1,0.0,0,65.22,1,34.78,0
2,This coronal T1-weighted MRI image shows a mar...,2,mtl_atrophy,image_119.jpeg,train,AD,0,1,0,0,...,0.0,0.0,0.0,1,0.0,0,18.99,0,81.01,1


In [16]:
data_with_predicted_dementia['is_dementia'] = (data_with_predicted_dementia['label'] != 0).astype(int)
data_with_predicted_dementia.head(3)

Unnamed: 0,description,label,abnormal_type,img_path,set,label_text,is_normal,is_mtl_atrophy,is_wmh,is_other_atrophy,...,is_predicted_no_dementia,prob_other_dementia,is_predicted_other_dementia,prob_AD,is_predicted_AD,is_correct_dementia_type,is_predicted_dementia,max_sim_score,predicted_score,is_dementia
0,This coronal T1-weighted MRI image shows mildl...,2,mtl_atrophy,image_139.png,train,AD,0,1,0,0,...,0,18.99,0,81.01,1,1,1,1.0,1.0,1
1,This is a sagittal T1-weighted MRI image which...,1,other_atrophy,image_30.png,train,other_dementia,0,0,0,1,...,0,65.22,1,34.78,0,1,1,1.0,1.0,1
2,This coronal T1-weighted MRI image shows a mar...,2,mtl_atrophy,image_119.jpeg,train,AD,0,1,0,0,...,0,18.99,0,81.01,1,1,1,1.0,1.0,1


In [14]:
data_with_predicted_dementia['is_correct_dementia_type'] = data_with_predicted_dementia.apply(lambda row: is_correct_dementia(row), axis=1)
get_count_values(data_with_predicted_dementia, 'is_correct_dementia_type')

Unnamed: 0_level_0,count,proportion
is_correct_dementia_type,Unnamed: 1_level_1,Unnamed: 2_level_1
1,88,73.33
0,32,26.67


In [17]:
data_with_predicted_dementia['is_predicted_dementia'] = (data_with_predicted_dementia[[f'sim_score_{ab_type}' for ab_type in ABNORMAL_TYPES]].idxmax(axis=1).str.replace('sim_score_', '') != 'normal').astype(int)
data_with_predicted_dementia['max_sim_score'] = data_with_predicted_dementia[[f'sim_score_{ab_type}' for ab_type in ABNORMAL_TYPES]].max(axis=1)

# Reverse the score of normal to get the score on dementia
data_with_predicted_dementia['predicted_score'] = data_with_predicted_dementia.apply(lambda row: row['max_sim_score'] if row['is_predicted_dementia'] == 1 else 1 - row['max_sim_score'], axis=1)

prediction_result = EvalMetric(data_with_predicted_dementia['is_dementia'], 
                               data_with_predicted_dementia['predicted_score'], 
                               data_with_predicted_dementia['is_predicted_dementia']).get_overall_result()
prediction_result

{'precision': 0.9603960396039604,
 'recall': 1.0,
 'f1_score': 0.9797979797979798,
 'specificity': np.float64(0.8260869565217391),
 'auc': np.float64(0.9094576423128642),
 'accuracy': 0.9666666666666667}

In [18]:
dementia_result = get_evaluation(data_with_predicted_dementia, label_col='label_text', score_col_prefix='prob', label_list=DEMENTIA_TYPES.values())
for dementia_type in dementia_result.keys():
  print(dementia_type)
  print(dementia_result[dementia_type])
  print('----')

no_dementia
{'precision': 1.0, 'recall': 0.8260869565217391, 'f1_score': 0.9047619047619048, 'specificity': np.float64(1.0), 'auc': np.float64(0.9130434782608696), 'accuracy': 0.9666666666666667}
----
other_dementia
{'precision': 0.6122448979591837, 'recall': 0.75, 'f1_score': 0.6741573033707865, 'specificity': np.float64(0.7625), 'auc': np.float64(0.79203125), 'accuracy': 0.7583333333333333}
----
AD
{'precision': 0.75, 'recall': 0.6842105263157895, 'f1_score': 0.7155963302752294, 'specificity': np.float64(0.7936507936507936), 'auc': np.float64(0.7942077415761626), 'accuracy': 0.7416666666666667}
----


------

## MINDSet - Test set

In [19]:
test_data_with_result = add_label_with_score(data=test_data, result=test_image_labels, mapping_dict=reversed_mapping_dict)
test_data_with_result.head(3)

Unnamed: 0,description,label,abnormal_type,img_path,set,label_text,is_normal,is_mtl_atrophy,is_wmh,is_other_atrophy,sim_score_normal,is_predicted_normal,sim_score_wmh,is_predicted_wmh,sim_score_other_atrophy,is_predicted_other_atrophy,sim_score_mtl_atrophy,is_predicted_mtl_atrophy
0,This is an axial T2-weighted MRI image showing...,0,normal,image_91.png,test,no_dementia,1,0,0,0,1.0,1.0,1.550419e-24,0.0,6.38489e-39,0.0,0.0,0.0
1,This MRI image shows a normal hippocampus but ...,2,other_atrophy,image_83.png,test,AD,0,0,0,1,1.1045689999999999e-26,0.0,3.709014e-34,0.0,1.0,1.0,0.0,0.0
2,The MRI scan shows evidence of atrophy in the ...,2,mtl_atrophy,image_161.jpg,test,AD,0,1,0,0,4.420747e-38,0.0,4.810306e-34,0.0,1.0,1.0,0.0,0.0


----
### Abnormality Retrieval

In [20]:
test_data_with_result['is_correct_abnormality'] = test_data_with_result.apply(lambda row: is_correct_abnormality(row), axis=1)
get_count_values(test_data_with_result, 'is_correct_abnormality')

Unnamed: 0_level_0,count,proportion
is_correct_abnormality,Unnamed: 1_level_1,Unnamed: 2_level_1
1,37,74.0
0,13,26.0


In [21]:
baseline_result = get_evaluation(test_data_with_result, label_col='abnormal_type', score_col_prefix='sim_score', label_list=ABNORMAL_TYPES)
for abnormal_type in baseline_result.keys():
  print(abnormal_type)
  print(baseline_result[abnormal_type])
  print('----')

normal
{'precision': 0.75, 'recall': 0.6, 'f1_score': 0.6666666666666666, 'specificity': np.float64(0.95), 'auc': np.float64(0.845), 'accuracy': 0.88}
----
mtl_atrophy
{'precision': 0.8333333333333334, 'recall': 0.6, 'f1_score': 0.6976744186046512, 'specificity': np.float64(0.88), 'auc': np.float64(0.7928), 'accuracy': 0.74}
----
wmh
{'precision': 0.7777777777777778, 'recall': 1.0, 'f1_score': 0.875, 'specificity': np.float64(0.9534883720930233), 'auc': np.float64(0.9883720930232558), 'accuracy': 0.96}
----
other_atrophy
{'precision': 0.6, 'recall': 0.6923076923076923, 'f1_score': 0.6428571428571429, 'specificity': np.float64(0.8378378378378378), 'auc': np.float64(0.8586278586278586), 'accuracy': 0.8}
----


----

### Dementia Prediction

In [22]:
data_with_predicted_dementia = add_predicted_dementia(test_data_with_result, diagnosis_prob=DIAGNOSIS_PROB, dementia_dict=DEMENTIA_TYPES)
data_with_predicted_dementia.head(3)

Unnamed: 0,description,label,abnormal_type,img_path,set,label_text,is_normal,is_mtl_atrophy,is_wmh,is_other_atrophy,...,is_predicted_other_atrophy,sim_score_mtl_atrophy,is_predicted_mtl_atrophy,is_correct_abnormality,prob_no_dementia,is_predicted_no_dementia,prob_other_dementia,is_predicted_other_dementia,prob_AD,is_predicted_AD
0,This is an axial T2-weighted MRI image showing...,0,normal,image_91.png,test,no_dementia,1,0,0,0,...,0.0,0.0,0.0,1,100.0,1,0.0,0,0.0,0
1,This MRI image shows a normal hippocampus but ...,2,other_atrophy,image_83.png,test,AD,0,0,0,1,...,1.0,0.0,0.0,1,0.0,0,65.22,1,34.78,0
2,The MRI scan shows evidence of atrophy in the ...,2,mtl_atrophy,image_161.jpg,test,AD,0,1,0,0,...,1.0,0.0,0.0,0,0.0,0,65.22,1,34.78,0


In [23]:
data_with_predicted_dementia['is_dementia'] = (data_with_predicted_dementia['label'] != 0).astype(int)
data_with_predicted_dementia.head(3)

Unnamed: 0,description,label,abnormal_type,img_path,set,label_text,is_normal,is_mtl_atrophy,is_wmh,is_other_atrophy,...,sim_score_mtl_atrophy,is_predicted_mtl_atrophy,is_correct_abnormality,prob_no_dementia,is_predicted_no_dementia,prob_other_dementia,is_predicted_other_dementia,prob_AD,is_predicted_AD,is_dementia
0,This is an axial T2-weighted MRI image showing...,0,normal,image_91.png,test,no_dementia,1,0,0,0,...,0.0,0.0,1,100.0,1,0.0,0,0.0,0,0
1,This MRI image shows a normal hippocampus but ...,2,other_atrophy,image_83.png,test,AD,0,0,0,1,...,0.0,0.0,1,0.0,0,65.22,1,34.78,0,1
2,The MRI scan shows evidence of atrophy in the ...,2,mtl_atrophy,image_161.jpg,test,AD,0,1,0,0,...,0.0,0.0,0,0.0,0,65.22,1,34.78,0,1


In [24]:
data_with_predicted_dementia['is_correct_dementia_type'] = data_with_predicted_dementia.apply(lambda row: is_correct_dementia(row), axis=1)
get_count_values(data_with_predicted_dementia, 'is_correct_dementia_type')

Unnamed: 0_level_0,count,proportion
is_correct_dementia_type,Unnamed: 1_level_1,Unnamed: 2_level_1
1,26,52.0
0,24,48.0


In [25]:
data_with_predicted_dementia['is_predicted_dementia'] = (data_with_predicted_dementia[[f'sim_score_{ab_type}' for ab_type in ABNORMAL_TYPES]].idxmax(axis=1).str.replace('sim_score_', '') != 'normal').astype(int)
data_with_predicted_dementia['max_sim_score'] = data_with_predicted_dementia[[f'sim_score_{ab_type}' for ab_type in ABNORMAL_TYPES]].max(axis=1)

# Reverse the score of normal to get the score on dementia
data_with_predicted_dementia['predicted_score'] = data_with_predicted_dementia.apply(lambda row: row['max_sim_score'] if row['is_predicted_dementia'] == 1 else 1 - row['max_sim_score'], axis=1)

prediction_result = EvalMetric(data_with_predicted_dementia['is_dementia'], 
                               data_with_predicted_dementia['predicted_score'], 
                               data_with_predicted_dementia['is_predicted_dementia']).get_overall_result()
prediction_result

{'precision': 0.9047619047619048,
 'recall': 0.95,
 'f1_score': 0.926829268292683,
 'specificity': np.float64(0.6),
 'auc': np.float64(0.8237500000000001),
 'accuracy': 0.88}

In [26]:
dementia_result = get_evaluation(data_with_predicted_dementia, label_col='label_text', score_col_prefix='prob', label_list=DEMENTIA_TYPES.values())
for dementia_type in dementia_result.keys():
  print(dementia_type)
  print(dementia_result[dementia_type])
  print('----')

no_dementia
{'precision': 0.75, 'recall': 0.6, 'f1_score': 0.6666666666666666, 'specificity': np.float64(0.95), 'auc': np.float64(0.775), 'accuracy': 0.88}
----
other_dementia
{'precision': 0.3333333333333333, 'recall': 0.5714285714285714, 'f1_score': 0.42105263157894735, 'specificity': np.float64(0.5555555555555556), 'auc': np.float64(0.5853174603174603), 'accuracy': 0.56}
----
AD
{'precision': 0.6666666666666666, 'recall': 0.46153846153846156, 'f1_score': 0.5454545454545454, 'specificity': np.float64(0.75), 'auc': np.float64(0.6722756410256411), 'accuracy': 0.6}
----


------

## HF Data
https://huggingface.co/datasets/Falah/Alzheimer_MRI

In [27]:
PUBLIC_DATASET_PATH = os.path.join(DATA_PATH, "public_HF_dataset")
public_test = pd.read_parquet(os.path.join(PUBLIC_DATASET_PATH, 'test.parquet'))
public_test.shape

(1280, 2)

In [28]:
public_images = [data["bytes"] for data in public_test["image"].tolist()]
public_test_result = get_labels(model=model, preprocess=preprocess, tokenizer=tokenizer, device=device, images=public_images, labels=labels, is_bytes=True)
len(public_test_result)

Generated embeddings of 1280 images.
Generated embeddings of 4 labels.


1280

In [37]:
test_public_with_result = add_label_with_score(public_test, result=public_test_result, mapping_dict=reversed_mapping_dict)
test_public_with_result.head(3)

Unnamed: 0,image,label,sim_score_other_atrophy,is_predicted_other_atrophy,sim_score_wmh,is_predicted_wmh,sim_score_normal,is_predicted_normal,sim_score_mtl_atrophy,is_predicted_mtl_atrophy
0,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,3,1.0,1.0,3.339516e-30,0.0,1.717683e-32,0.0,0.0,0.0
1,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,0,1.0,1.0,9.649233000000001e-33,0.0,7.615173000000001e-32,0.0,0.0,0.0
2,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,2,0.837178,1.0,5.364493e-16,0.0,0.1628221,0.0,0.0,0.0


In [38]:
test_public_with_result['is_dementia'] = (test_public_with_result['label'] != 2).astype(int)
test_public_with_result['is_predicted_dementia'] = (test_public_with_result[[f'sim_score_{abnormality}' for abnormality in ABNORMAL_TYPES]].idxmax(axis=1).str.replace("sim_score_", "") != 'normal').astype(int)
test_public_with_result['sim_score'] = test_public_with_result[[f'sim_score_{abnormality}' for abnormality in ABNORMAL_TYPES]].max(axis=1)
test_public_with_result['predicted_score'] = test_public_with_result.apply(lambda row: row['sim_score'] if row['is_predicted_dementia'] == 1 else 1 - row['sim_score'], axis=1)
test_public_with_result.head(3)

Unnamed: 0,image,label,sim_score_other_atrophy,is_predicted_other_atrophy,sim_score_wmh,is_predicted_wmh,sim_score_normal,is_predicted_normal,sim_score_mtl_atrophy,is_predicted_mtl_atrophy,is_dementia,is_predicted_dementia,sim_score,predicted_score
0,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,3,1.0,1.0,3.339516e-30,0.0,1.717683e-32,0.0,0.0,0.0,1,1,1.0,1.0
1,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,0,1.0,1.0,9.649233000000001e-33,0.0,7.615173000000001e-32,0.0,0.0,0.0,1,1,1.0,1.0
2,{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x...,2,0.837178,1.0,5.364493e-16,0.0,0.1628221,0.0,0.0,0.0,0,1,0.837178,0.837178


In [39]:
test_public_with_result['is_correct'] = (test_public_with_result['is_dementia'] == test_public_with_result['is_predicted_dementia']).astype(int)
get_count_values(test_public_with_result, 'is_correct')

Unnamed: 0_level_0,count,proportion
is_correct,Unnamed: 1_level_1,Unnamed: 2_level_1
1,661,51.64
0,619,48.36


In [36]:
public_result = EvalMetric(test_public_with_result['is_dementia'], test_public_with_result['predicted_score'], test_public_with_result['is_predicted_dementia']).get_overall_result()
public_result

{'precision': 0.5106719367588933,
 'recall': 1.0,
 'f1_score': 0.6760858189429618,
 'specificity': np.float64(0.02365930599369085),
 'auc': np.float64(0.539598206873651),
 'accuracy': 0.51640625}