In [1]:
import torch
import timm
import os
import glob
import torch.nn as nn
import yaml
import argparse
import platform
import open_clip

from accelerate import DistributedDataParallelKwargs
from MultiMEDal_multimodal_medical.src.plotting.plot_funcs import (
    plot_pr_curve,
    plot_roc_curve,
)
from MultiMEDal_multimodal_medical.src.plotting.plot_funcs import (
    plot_pr_curve_crossval,
    plot_roc_curve_crossval,
)

from MultiMEDal_multimodal_medical.src.evaluation.compute_metrics import (
    compute_binary_metrics, compute_multilabel_metrics,
    compute_binary_metrics_crossval,
)
from MultiMEDal_multimodal_medical.src.evaluation.compute_metrics import (
    compute_multiclass_metrics,
    compute_multiclass_metrics_crossval,
)

from MultiMEDal_multimodal_medical.src.datasets.data_transform import (
    build_transform_dict,
    build_transform_dict_mamm,
    build_transform_dict_blip2,
    build_transform_dict_openclip,
    build_transform_dict_pubmedclip
)
from MultiMEDal_multimodal_medical.src.datasets.data_loader import get_dataloaders
from MultiMEDal_multimodal_medical.src.datasets.dataset_init import get_datasets, get_combined_datasets
from MultiMEDal_multimodal_medical.src.datasets.custom_concat_dataset import CustomConcatDataset
from MultiMEDal_multimodal_medical.src.datasets.preprocessing.prompt_factory import tab2prompt_breast_lesion
from MultiMEDal_multimodal_medical.src.test import test_utils

from MultiMEDal_multimodal_medical.src.models.neural_net import Ann
from MultiMEDal_multimodal_medical.src.models.image_tabular_net import Image_Tabular_Concat_Model, Image_Tabular_Concat_ViT_Model, Image_Tabular_CrossAtt_ViT_Model

from MultiMEDal_multimodal_medical.src.models.open_clip import Clip_Image_Tabular_Model
from transformers import CLIPProcessor, CLIPModel

from libauc.losses import AUCMLoss 
from libauc.optimizers import PESG 

from accelerate.utils import set_seed
from torch.utils.data import ConcatDataset
from itertools import chain
from accelerate import Accelerator
from argparse import Namespace

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
patch_tabular_datasets = ["CBIS-DDSM-tfds-with-tabular-2classes", "CBIS-DDSM-tfds-with-tabular-methodist-mass-appearance",
                        "CBIS-DDSM-tfds-with-tabular-methodist-calc-morph", "CBIS-DDSM-tfds-with-tabular-methodist-calc-dist",
                        "CBIS-DDSM-tfds-with-tabular-mass-shape", "CBIS-DDSM-tfds-with-tabular-mass-margin",
                        "CBIS-DDSM-tfds-with-tabular-calc-morph", "CBIS-DDSM-tfds-with-tabular-calc-dist",
                          "EMBED-unique-mapping-tfds-with-tabular-2classes", "EMBED-unique-mapping-tfds-with-tabular-demography-only-2classes"]


In [3]:
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/cbis_ddsm/evaclip-vit-base-zeroshot_patches-224-tabular-ddsm_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/cbis_ddsm/biomedclip-vit-base-zeroshot_patches-224-tabular-ddsm_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/cbis_ddsm/pubmedclip-vit-base-zeroshot_patches-224-tabular-ddsm_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/cbis_ddsm/evaclip-vit-large-zeroshot_patches-224-tabular-ddsm_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/cbis_ddsm/evaclip-vit-giant-zeroshot_patches-224-tabular-ddsm_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/cbis_ddsm/evaclip-vit-bigG-zeroshot_patches-224-tabular-ddsm_2classes_datasets.yaml'

# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/embed/evaclip-vit-base-zeroshot_patches-224-tabular-embed_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/embed/biomedclip-vit-base-zeroshot_patches-224-tabular-embed_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/embed/pubmedclip-vit-base-zeroshot_patches-224-tabular-embed_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/embed/evaclip-vit-large-zeroshot_patches-224-tabular-embed_2classes_datasets.yaml'
# CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/embed/evaclip-vit-giant-zeroshot_patches-224-tabular-embed_2classes_datasets.yaml'
CONFIG_PATH = '/home/hqvo2/Projects/MultiMEDal_multimodal_medical/src/configs/paper_multimodal_config/patches_zero_shot_config/embed/evaclip-vit-bigG-zeroshot_patches-224-tabular-embed_2classes_datasets.yaml'

In [4]:
with open(CONFIG_PATH, 'r') as file:
    yaml_cfg = yaml.safe_load(file)

# %%
import psutil, time
p = psutil.Process(os.getppid())
dt_string = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(p.create_time()))

config_dict = yaml_cfg['hyperparams']
config_dict['save_root'] = os.path.join(config_dict['save_root'], str(os.getppid()) + '_' + dt_string)
if isinstance(config_dict['image_size'], list):
    config_dict['image_size'] = tuple(config_dict['image_size'])

In [5]:
CONTEXT_LENGTH = None
if config_dict.get('model_name') in ['open_clip']:
    
    transform_dict, txt_processors = build_transform_dict_openclip(config_dict.get('pretrain_model_name'), 
                                                                pretrained_data=config_dict.get('pretrain_data', None))
    if config_dict.get('pretrain_model_name') == "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224":
        CONTEXT_LENGTH = 256
elif config_dict.get('model_name') == 'pubmed_clip':
    transform_dict, txt_processors = build_transform_dict_pubmedclip()

In [6]:
dataset_name = config_dict['dataset']
data_dir = config_dict['datadir']

if isinstance(dataset_name, list):
    combined_datasets = get_combined_datasets(
        dataset_name[0],
        dataset_name[1],
        dataset_name[2],
        transform_dict,
        data_dir[0],
        data_dir[1],
        data_dir[2],
    )
    all_train_datasets, all_val_datasets, all_test_datasets = combined_datasets
    train_dataset = CustomConcatDataset(all_train_datasets)
    val_dataset = CustomConcatDataset(all_val_datasets)
    test_dataset = CustomConcatDataset(all_test_datasets)

    
    train_labels = train_dataset.get_all_labels()    


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 5610


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 1650


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 1650


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 5610


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 1650


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 1650


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 5610


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 1650


  df_clinical = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_clinical.csv"))
  df_metadata = pd.read_csv(os.path.join(self.data_root, "tables/EMBED_OpenData_metadata.csv"))


# Samples: (8485, 87)
# Lesions ID: (8485, 2)
#Samples 1650


In [7]:
train_sampler = None
val_sampler = None

train_dataloader, val_dataloader, test_dataloader = get_dataloaders(
    train_dataset,
    test_dataset,
    val_dataset,
    train_sampler,
    val_sampler,
    config_dict.get('batch_size'),
    config_dict.get('njobs'),
)

In [8]:
accelerator = Accelerator(mixed_precision='fp16', project_dir=[])



In [9]:
if config_dict.get('model_name') == "pubmed_clip":
    model = CLIPModel.from_pretrained(config_dict.get('pretrain_model_name'))
    processor = CLIPProcessor.from_pretrained("flaviagiammarino/pubmed-clip-vit-base-patch32")
    tokenizer = processor.tokenizer
else:
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(config_dict.get('pretrain_model_name'), \
                                                                                    pretrained=config_dict.get('pretrain_data'))
    tokenizer = open_clip.get_tokenizer(config_dict.get('pretrain_model_name'))

In [10]:
model = accelerator.prepare(
    model
)   

model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 1664, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0): ResidualAttentionBlock(
          (ln_1): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1664, out_features=1664, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((1664,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=1664, out_features=8192, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=8192, out_features=1664, bias=True)
          )
          (ls_2): Identity()
        )
        (1): ResidualAttentionBlock(
          (ln_1): LayerNorm((1664,), ep

In [11]:
PROMPT_TYPE = 'with_context' 
assert PROMPT_TYPE in ['normal', 'with_context']

In [12]:

for dataloader in [val_dataloader, test_dataloader]:
    dataloader = accelerator.prepare(
        dataloader
    )

    all_outputs = torch.tensor([], device=accelerator.device)
    all_labels = torch.tensor([], device=accelerator.device, dtype=torch.int32)


    for batch_data in dataloader:
        image_samples = batch_data['image']

        if PROMPT_TYPE == 'normal':
            if config_dict.get('model_name') == "pubmed_clip":
                tokenized_text_samples = tokenizer(["an image of a benign lesion", "an image of a malignant lesion"], return_tensors="pt", padding=True)
                tokenized_text_samples = tokenized_text_samples['input_ids']        
            else:
                if CONTEXT_LENGTH is not None:
                    tokenized_text_samples = tokenizer(["an image of a benign lesion", "an image of a malignant lesion"], \
                                                        context_length=CONTEXT_LENGTH)
                else:
                    tokenized_text_samples = tokenizer(["an image of a benign lesion", "an image of a malignant lesion"])
        
        elif PROMPT_TYPE == 'with_context':
            _, text_samples = tab2prompt_breast_lesion(config_dict.get('model_name'), 'val',
                                                    batch_data, txt_processors, _context_length=CONTEXT_LENGTH, _group_age=config_dict.get('group_age'))
            
        
            all_processed_texts = list(map(lambda text: (
                                            text.replace(\
                                                'mass lesion', \
                                                'benign mass lesion').replace(\
                                                    'calcification lesion', \
                                                    'benign calcification lesion'),
                                            text.replace(\
                                                'mass lesion', \
                                                'malignant mass lesion').replace(\
                                                    'calcification lesion', \
                                                    'malignant calcification lesion'),
                                            ), text_samples))

            unravel_processed_texts = list(sum(all_processed_texts, ()))   


            if config_dict.get('model_name') == "pubmed_clip":
                tokenized_text_samples = tokenizer(unravel_processed_texts, return_tensors="pt", padding=True)
                tokenized_text_samples = tokenized_text_samples['input_ids']        
            else:
                if CONTEXT_LENGTH is not None:
                    tokenized_text_samples = tokenizer(unravel_processed_texts, \
                                                        context_length=CONTEXT_LENGTH)
                else:
                    tokenized_text_samples = tokenizer(unravel_processed_texts)



        # Forward Images and Texts
        with torch.no_grad(), torch.cuda.amp.autocast():
            if config_dict.get('model_name') == "pubmed_clip":
                image_features = model.visual_projection(model.vision_model(image_samples)['pooler_output'])
                text_features = model.text_projection(model.text_model(tokenized_text_samples.to(accelerator.device)[:, :77])['pooler_output'])
            else:
                image_features = model.encode_image(image_samples)
                text_features = model.encode_text(tokenized_text_samples.to(accelerator.device))
                image_features /= image_features.norm(dim=-1, keepdim=True)
                text_features /= text_features.norm(dim=-1, keepdim=True)


        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        if PROMPT_TYPE == 'with_context':
            text_probs = torch.gather(text_probs, 1, torch.arange(text_probs.shape[1]).reshape(-1, 2).to(accelerator.device))
        


        # Concatenate all outpus and labels
        all_outputs = torch.cat((all_outputs, text_probs), dim=0)
        all_labels = torch.cat((all_labels, batch_data['label']), dim=0)
 



    eval_log = compute_binary_metrics(all_outputs, all_labels, accelerator.device)

    print(eval_log['acc'], eval_log['auroc'], eval_log['ap'])

tensor(0.6606, device='cuda:0') tensor(0.4788, device='cuda:0') tensor(0.3232, device='cuda:0')
tensor(0.6606, device='cuda:0') tensor(0.4788, device='cuda:0') tensor(0.3232, device='cuda:0')
