## Training Image-Text embedding space

In [3]:
import torch
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from models.embedding.dataset import Dataset
from models.embedding.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, resnet_backbone, load_medclip_retrained_resnet
from utils.sampling import MultilabelBalancedRandomSampler

configs = {
    'vindr1': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl', VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist']),
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

batch_size = 10*14
query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))

PROJ_SIZE = 512
device = get_device()
# backbone = resnet_backbone(load_pretrained_resnet(1, 14, 'models/backbone/pretrained/cxr_backbone_bal.pkl'))
# backbone = load_medclip_retrained_resnet('models/backbone/pretrained/medclip_resnet50.pkl')
# model = ImageTextEmbedding(backbone, PROJ_SIZE, device=device)
model = torch.load('models/embedding/model/vindr2/imgtext_model_trained.pth')
# model = torch.load('models/embedding/model/imgtext_model_trained.pth')

mtrainer = Trainer(model, support_dataset.class_labels(), device)


In [2]:
mtrainer.model.text_model.set_backbone_trainable(False)
mtrainer.model.img_model.set_backbone_trainable(False)
mtrainer.model.logit_scale.requires_grad = False

mtrainer.model.img_model.set_backbone_layer_trainable(True, -1)

In [None]:
# 5 epochs: training projection + image encoder
# 2 epochs: training projection
# 1 epoch: training projection + image encoder

# 2 epoch: projection + last image encoder layer 
mtrainer.run_train(2, support_loader, query_loader, lr=1e-6, min_lr=5e-7)

In [4]:
mtrainer.run_eval(mtrainer.model, query_loader, True)

Loss 141.581787109375 | Accuracy 0.6071428571428571 | AUC 0.6692567693823752 | Specificity 0.618476927280426 | Recall 0.5859165787696838
Loss 144.48428344726562 | Accuracy 0.6321428571428571 | AUC 0.6833994106105094 | Specificity 0.6308995485305786 | Recall 0.6337515711784363
Loss 138.49468994140625 | Accuracy 0.6228571428571429 | AUC 0.6700765917360793 | Specificity 0.6043102741241455 | Recall 0.6448074579238892
Loss 81.15347290039062 | Accuracy 0.60625 | AUC 0.6773642007916304 | Specificity 0.5987976789474487 | Recall 0.6153738498687744


(0.6184,
 0.6747434482107707,
 131.86156860351562,
 0.6148399186134338,
 0.6205129861831665)

In [None]:
torch.save(mtrainer.model, 'models/embedding/model/vindr2/imgtext_model_trained1.pth')

In [6]:
from utils.metrics import AverageMeter, calculate_auc, multilabel_logit_accuracy
from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity, MultilabelPrecision

def run_eval(model, dataloader, device, class_labels):
    model.eval()
    model = model.to(device)
    
    loss_meter = AverageMeter()
    auc_meter = AverageMeter()
    acc_meter = AverageMeter()

    specificity = MultilabelSpecificity(num_labels=len(class_labels)).to(device)
    spec_meter = AverageMeter()
    recall = MultilabelRecall(num_labels=len(class_labels)).to(device)
    rec_meter = AverageMeter()
    precision = MultilabelPrecision(num_labels=len(class_labels)).to(device)
    with torch.no_grad():
        for images, class_inds in dataloader:
            images, class_inds = images.to(device), class_inds.to(device)
            text_embeddings, image_embeddings = model(class_labels, images, pool=True)

            logits_per_text, logits_per_image = model.compute_logits(text_embeddings, image_embeddings)
            loss = model.contrastive_logit_loss(logits_per_text, logits_per_image, class_inds)
            loss_meter.update(loss.item(), len(class_inds))

            auc = calculate_auc(logits_per_image, class_inds)
            auc_meter.update(auc, len(class_inds))
            
            acc = multilabel_logit_accuracy(logits_per_image, class_inds)
            acc_meter.update(acc, len(class_inds))

            spec = specificity(logits_per_image, class_inds)
            spec_meter.update(spec.item(), len(class_inds))
            rec = recall(logits_per_image, class_inds)
            rec_meter.update(rec.item(), len(class_inds))
            prec = precision(logits_per_image, class_inds)
            print(f"Loss {loss} | Accuracy {acc} | AUC {auc} | Specificity {spec} | Recall {rec} | Precision {prec}")
            
    return acc_meter.average(), auc_meter.average(), loss_meter.average(), spec_meter.average(), rec_meter.average()

In [8]:
test_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'test']['image_id'].to_list(), config.label_names_map, config.classes_split_map['test'], mean_std=config.mean_std)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

run_eval(mtrainer.model, test_loader, device, test_dataset.class_labels())

Loss 71.04728698730469 | Accuracy 0.4479591836734694 | AUC 0.5890015108344036 | Specificity 0.42932575941085815 | Recall 0.6371333003044128 | Precision 0.20914509892463684
Loss 75.24569702148438 | Accuracy 0.4887755102040816 | AUC 0.6385256910915584 | Specificity 0.45598727464675903 | Recall 0.6956271529197693 | Precision 0.2463836818933487
Loss 78.05477905273438 | Accuracy 0.4959183673469388 | AUC 0.5532589315705337 | Specificity 0.4835784435272217 | Recall 0.5971997380256653 | Precision 0.2374926656484604
Loss 79.44345092773438 | Accuracy 0.4357142857142857 | AUC 0.5604148251867628 | Specificity 0.3963735103607178 | Recall 0.6791589856147766 | Precision 0.2589060366153717
Loss 76.35698699951172 | Accuracy 0.4806122448979592 | AUC 0.5677557253138265 | Specificity 0.46071845293045044 | Recall 0.6279104351997375 | Precision 0.22663699090480804
Loss 71.93860626220703 | Accuracy 0.4479591836734694 | AUC 0.5371337493866897 | Specificity 0.44221925735473633 | Recall 0.5372446775436401 | Pre

(0.46784066296261423,
 0.572297766822187,
 74.54131226611908,
 0.44420735527200766,
 0.6445646255363661)

In [15]:
run_eval(mtrainer.best_model, test_loader, device, test_dataset.class_labels())

Loss 78.56269836425781 | Accuracy 0.49387755102040815 | AUC 0.5851240751768196 | Specificity 0.4700332283973694 | Recall 0.6316505670547485 | Precision 0.2542865574359894
Loss 72.25923919677734 | Accuracy 0.49795918367346936 | AUC 0.6096031818152909 | Specificity 0.4781932234764099 | Recall 0.6903427839279175 | Precision 0.23942965269088745
Loss 74.80560302734375 | Accuracy 0.49183673469387756 | AUC 0.6209842205929144 | Specificity 0.47085481882095337 | Recall 0.711128294467926 | Precision 0.23181715607643127
Loss 72.50208282470703 | Accuracy 0.4387755102040816 | AUC 0.5273632507381782 | Specificity 0.41236233711242676 | Recall 0.5766076445579529 | Precision 0.2051064372062683
Loss 73.56822967529297 | Accuracy 0.47551020408163264 | AUC 0.5694476210302147 | Specificity 0.4573631286621094 | Recall 0.6312494277954102 | Precision 0.22090637683868408
Loss 74.90245056152344 | Accuracy 0.4785714285714286 | AUC 0.5932214505670862 | Specificity 0.4474429488182068 | Recall 0.6691378355026245 | P

(0.4729258875600339,
 0.573232993531073,
 74.3792878932362,
 0.45160263533174166,
 0.6348146872058383)