## Training Image-Text embedding space

In [3]:
import torch
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from models.embedding.dataset import Dataset
from models.embedding.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, resnet_backbone, load_medclip_retrained_resnet
from utils.sampling import MultilabelBalancedRandomSampler

configs = {
    'vindr1': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl', VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist']),
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

batch_size = 10*14
query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))

PROJ_SIZE = 512
device = get_device()

In [None]:
# backbone = resnet_backbone(load_pretrained_resnet(1, 14, 'models/backbone/pretrained/cxr_backbone_bal.pkl'))
# backbone = load_medclip_retrained_resnet('models/backbone/pretrained/medclip_resnet50.pkl')
# model = ImageTextEmbedding(backbone, PROJ_SIZE, device=device)
model = torch.load('models/embedding/model/vindr2/imgtext_model_trained.pth')
# model = torch.load('models/embedding/model/imgtext_model_from_medclip.pth')

mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [None]:
model = torch.load('models/embedding/model/imgtext_model_from_medclip.pth')
model.text_model = TextEncoder(PROJ_SIZE, device=device, bert_pretrained_type='bert-base-uncased')
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
mtrainer.model.text_model.set_backbone_trainable(False)
mtrainer.model.img_model.set_backbone_trainable(False)
mtrainer.model.logit_scale.requires_grad = False

mtrainer.model.img_model.set_backbone_layer_trainable(True, -1)

In [None]:
# 5 epochs: training projection + image encoder
# 2 epochs: training projection
# 1 epoch: training projection + image encoder

# 2 epoch: projection + last image encoder layer 
mtrainer.run_train(2, support_loader, query_loader, lr=1e-6, min_lr=5e-7)

In [4]:
mtrainer.run_eval(mtrainer.model, query_loader, True)

Loss 141.581787109375 | Accuracy 0.6071428571428571 | AUC 0.6692567693823752 | Specificity 0.618476927280426 | Recall 0.5859165787696838
Loss 144.48428344726562 | Accuracy 0.6321428571428571 | AUC 0.6833994106105094 | Specificity 0.6308995485305786 | Recall 0.6337515711784363
Loss 138.49468994140625 | Accuracy 0.6228571428571429 | AUC 0.6700765917360793 | Specificity 0.6043102741241455 | Recall 0.6448074579238892
Loss 81.15347290039062 | Accuracy 0.60625 | AUC 0.6773642007916304 | Specificity 0.5987976789474487 | Recall 0.6153738498687744


(0.6184,
 0.6747434482107707,
 131.86156860351562,
 0.6148399186134338,
 0.6205129861831665)

In [None]:
torch.save(mtrainer.model, 'models/embedding/model/vindr2/imgtext_model_trained1.pth')

## Zero-shot evaluations on validation and test sets

In [1]:
from utils.baseline import run_baseline_eval
from models.metaclassifier.utils import set_seed
import numpy as np
import torch
import random

from utils.device import get_device
from utils.baseline import run_baseline_eval
from utils.data import DatasetConfig
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT2

from models.backbone.datasets import MEAN_STDS
from models.metaclassifier.trainer import create_baseline_eval_dataloaders

from utils.baseline import run_img_txt_zeroshot_eval


dataset_config = DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
device =  get_device()
NUM_LABELS = 7
NUM_QUERY = 10
set_seed(420)

query_loader, val_loader, test_loader = create_baseline_eval_dataloaders(dataset_config, NUM_LABELS, NUM_QUERY)

In [2]:
## Trained image-text encoder
model = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')

seeds = [42, 142, 321]
for s in seeds:
    print('seed', s)
    print('val')
    print(run_img_txt_zeroshot_eval(model, val_loader, NUM_LABELS, device, episodes=60))
    print('test')
    print(run_img_txt_zeroshot_eval(model, test_loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.4836735129356384 | F1 0.33595800399780273 | Specificity 0.4385283589363098 | Recall 0.6619229912757874 | Precision 0.22292044758796692 | Bal Acc 0.5502256751060486
Episode 2 | Raw Acc 0.4551020562648773 | F1 0.35351091623306274 | Specificity 0.38347795605659485 | Recall 0.7345018982887268 | Precision 0.24292980134487152 | Bal Acc 0.558989942073822
Episode 3 | Raw Acc 0.4959184229373932 | F1 0.3682864308357239 | Specificity 0.45832040905952454 | Recall 0.631923496723175 | Precision 0.2784240245819092 | Bal Acc 0.545121967792511
Episode 4 | Raw Acc 0.43877553939819336 | F1 0.3648960590362549 | Specificity 0.3544105887413025 | Recall 0.7564001083374023 | Precision 0.24291440844535828 | Bal Acc 0.5554053783416748
Episode 5 | Raw Acc 0.41224491596221924 | F1 0.3424657881259918 | Specificity 0.321372926235199 | Recall 0.7547261714935303 | Precision 0.21879617869853973 | Bal Acc 0.538049578666687
Episode 6 | Raw Acc 0.43877553939819336 | F1 0.36781609058380127 | Specific

In [2]:
## Medclip encoder
model = torch.load('models/embedding/model/imgtext_model_from_medclip.pth')

seeds = [42, 142, 321]
for s in seeds:
    print('seed', s)
    print('val')
    print(run_img_txt_zeroshot_eval(model, val_loader, NUM_LABELS, device, episodes=60))
    print('test')
    print(run_img_txt_zeroshot_eval(model, test_loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.6244897842407227 | F1 0.22689077258110046 | Specificity 0.7071453928947449 | Recall 0.24112556874752045 | Precision 0.07526132464408875 | Bal Acc 0.47413548827171326
Episode 2 | Raw Acc 0.6510204672813416 | F1 0.24000000953674316 | Specificity 0.7497860193252563 | Recall 0.24681022763252258 | Precision 0.08391053974628448 | Bal Acc 0.49829810857772827
Episode 3 | Raw Acc 0.21836735308170319 | F1 0.3584589660167694 | Specificity 0.0 | Recall 1.0 | Precision 0.21836735308170319 | Bal Acc 0.5
Episode 4 | Raw Acc 0.6367347240447998 | F1 0.2053571492433548 | Specificity 0.7575844526290894 | Recall 0.20574162900447845 | Precision 0.07088745385408401 | Bal Acc 0.4816630482673645
Episode 5 | Raw Acc 0.6571428775787354 | F1 0.25 | Specificity 0.7450080513954163 | Recall 0.2577873468399048 | Precision 0.08659908175468445 | Bal Acc 0.5013977289199829
Episode 6 | Raw Acc 0.6653060913085938 | F1 0.2612612545490265 | Specificity 0.7614418268203735 | Recall 0.2206493616104126 | 