In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler


configs = {
    'vindr1': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl', VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist']),
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

IMG_PATH = 'datasets/vindr-cxr-png'
batch_size = 10*14

query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))


PROJ_SIZE = 512
device =  get_device()

encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')
encoder.text_model.device = device
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, num_layers=4, cls_weight=0.1)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
torch.autograd.set_detect_anomaly(True)

mtrainer.run_train(6, support_loader, query_loader, lr=2e-5, full_training=True)

  classes = torch.nonzero(label_inds)[:,1] # (Np,)


Batch 1: loss 72.8689956665039
Batch 2: loss 72.80627059936523
Batch 3: loss 71.58649190266927
Batch 4: loss 71.17413139343262
Batch 5: loss 70.92964019775391
Batch 6: loss 70.54846700032552
Batch 7: loss 70.33934020996094
Batch 8: loss 70.03028297424316
Batch 9: loss 69.64961496988933
Batch 10: loss 69.72439651489258
Batch 11: loss 69.76450070467862
Batch 12: loss 69.60139274597168
Batch 13: loss 69.37254450871394
Batch 14: loss 69.36271340506417
Batch 15: loss 69.18506266276042
Batch 16: loss 69.04267692565918
Batch 17: loss 69.0490789974437
Batch 18: loss 69.01732677883572
Batch 19: loss 68.91317347476357
Batch 20: loss 68.88681983947754
Batch 21: loss 68.79013897123791
Batch 22: loss 68.6995058926669
Batch 23: loss 68.62465236497962
Batch 24: loss 68.50263595581055
Batch 25: loss 68.34710876464844
Batch 26: loss 68.30223054152269
Batch 27: loss 68.21710261592159
Batch 28: loss 68.03733021872384
Batch 29: loss 67.97193237830852
Batch 30: loss 67.89395001729329
Batch 31: loss 67.8639

In [6]:
mtrainer.run_eval(mtrainer.best_model, query_loader)

(0.7689795918367346, 1.0, 13.967295932769776)

In [10]:
torch.save(mtrainer.model.attention, 'models/attention/model/vindr2/full/attention-model8h4l.pth')

In [11]:
torch.save(mtrainer.model.encoder, 'models/attention/model/vindr2/full/imgtxt-encoder.pth')

In [7]:
from utils.metrics import AverageMeter, calculate_auc, multilabel_logit_accuracy
from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity, MultilabelPrecision
from models.attention.model import image_text_logits

def run_eval(model, dataloader, device, class_labels):
    model.eval()
    model = model.to(device)
    
    loss_meter = AverageMeter()
    auc_meter = AverageMeter()
    acc_meter = AverageMeter()

    specificity = MultilabelSpecificity(num_labels=len(class_labels)).to(device)
    spec_meter = AverageMeter()
    recall = MultilabelRecall(num_labels=len(class_labels)).to(device)
    rec_meter = AverageMeter()
    precision = MultilabelPrecision(num_labels=len(class_labels)).to(device)
    with torch.no_grad():
         for images, class_inds in dataloader:
                images, class_inds = images.to(device), class_inds.to(device)

                text_embeddings, _, prototypes = model(class_labels, images, class_inds)

                logits_per_image = image_text_logits(text_embeddings, prototypes, model.encoder.get_logit_scale())
                loss = model.attention.contrastive_loss(prototypes, class_inds).item()
                loss += model.attention.classification_loss(logits_per_image, class_inds).item()
        
                loss_meter.update(loss, len(class_inds))

                auc = calculate_auc(logits_per_image, class_inds)
                auc_meter.update(auc, len(class_inds))
            
                acc = multilabel_logit_accuracy(logits_per_image, class_inds)
                acc_meter.update(acc, len(class_inds))

                spec = specificity(logits_per_image, class_inds)
                spec_meter.update(spec.item(), len(class_inds))
                rec = recall(logits_per_image, class_inds)
                rec_meter.update(rec.item(), len(class_inds))
                prec = precision(logits_per_image, class_inds)
                print(f"Loss {loss} | Accuracy {acc} | AUC {auc} | Specificity {spec} | Recall {rec} | Precision {prec}")
            
    return acc_meter.average(), auc_meter.average(), loss_meter.average(), spec_meter.average(), rec_meter.average()

In [8]:
test_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'test']['image_id'].to_list(), config.label_names_map, config.classes_split_map['test'], mean_std=config.mean_std)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

run_eval(mtrainer.model, test_loader, device, test_dataset.class_labels())

Loss 6.035398006439209 | Accuracy 0.9163265306122449 | AUC 0.9601097442520172 | Specificity 0.8863561153411865 | Recall 1.0 | Precision 0.7190194129943848
Loss 6.016775846481323 | Accuracy 0.9020408163265307 | AUC 0.950413918427783 | Specificity 0.8675279021263123 | Recall 1.0 | Precision 0.6775742173194885
Loss nan | Accuracy 0.9061224489795918 | AUC 0.9621225231408158 | Specificity 0.8780972361564636 | Recall 1.0 | Precision 0.6373937726020813
Loss 5.877500772476196 | Accuracy 0.9051020408163265 | AUC 0.9520963027023325 | Specificity 0.8736120462417603 | Recall 1.0 | Precision 0.6980288028717041
Loss 5.879352807998657 | Accuracy 0.9153061224489796 | AUC 0.9580760437786499 | Specificity 0.888192355632782 | Recall 1.0 | Precision 0.6855992674827576
Loss 5.763930082321167 | Accuracy 0.9122448979591836 | AUC 0.9568867816190046 | Specificity 0.8859713673591614 | Recall 1.0 | Precision 0.7168556451797485
Loss 5.975746512413025 | Accuracy 0.9071428571428571 | AUC 0.9517384316075806 | Specif

(0.9079009322911763, 0.956451851131788, nan, 0.8774896472712135, 1.0)

In [9]:
run_eval(mtrainer.model, query_loader, device, query_dataset.class_labels())

Loss 11.984139442443848 | Accuracy 0.9978571428571429 | AUC 0.999432338137726 | Specificity 0.991599440574646 | Recall 1.0 | Precision 0.9951183795928955
Loss 11.039484977722168 | Accuracy 0.9957142857142857 | AUC 0.9983705482192878 | Specificity 0.9866666793823242 | Recall 1.0 | Precision 0.9876745939254761
Loss 12.2591233253479 | Accuracy 0.9978571428571429 | AUC 0.995954538069212 | Specificity 0.98682701587677 | Recall 1.0 | Precision 0.9947980046272278
Loss 11.675161838531494 | Accuracy 1.0 | AUC 1.0 | Specificity 1.0 | Recall 1.0 | Precision 1.0


(0.9976000000000002,
 0.9982520788393433,
 11.747195262908935,
 0.9902260780334473,
 1.0)