In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler
from utils.f1_loss import BalAccuracyLoss


configs = {
    'vindr1': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl', VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist']),
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

batch_size = 10*14

query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))


PROJ_SIZE = 512
device =  get_device()

encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')
encoder.text_model.device = device


In [2]:
attn_model = LabelImageMHAttention(PROJ_SIZE, 8, cls_weight=5, cls_loss=BalAccuracyLoss(), device=device)
# attn_model = torch.load('models/attention/model/vindr2/attention-8h.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, num_layers=4, cls_weight=5, cls_loss=BalAccuracyLoss())
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
attn_model = torch.load('models/attention/model/vindr2/attention-trans-8h4l.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [37]:
torch.autograd.set_detect_anomaly(True)

mtrainer.run_train(2, support_loader, query_loader, lr=1e-4, min_lr=5e-6, full_training=True, enc_weight=0.05)
# mtrainer.run_train(2, support_loader, query_loader, lr=1e-4, min_lr=5e-6)
# mtrainer.run_train(2, support_loader, query_loader, lr=2e-5, min_lr=5e-6, full_training=True, enc_weight=0.05)

Batch 1: loss 12.53753662109375
Batch 2: loss 12.347467422485352
Batch 3: loss 12.59437370300293
Batch 4: loss 12.551023483276367
Batch 5: loss 12.562271118164062
Batch 6: loss 12.814994812011719
Batch 7: loss 12.426860809326172
Batch 8: loss 13.130172729492188
Batch 9: loss 12.404001235961914
Batch 10: loss 12.426783561706543
Batch 11: loss 12.2003755569458
Batch 12: loss 12.732625961303711
Batch 13: loss 13.251863479614258
Batch 14: loss 13.03114128112793
Batch 15: loss 12.641311645507812
Batch 16: loss 12.84710693359375
Batch 17: loss 12.50998592376709
Batch 18: loss 12.866361618041992
Batch 19: loss 12.923351287841797
Batch 20: loss 12.81854248046875
Batch 21: loss 12.55801773071289
Batch 22: loss 12.690763473510742
Batch 23: loss 12.739566802978516
Batch 24: loss 12.877098083496094
Batch 25: loss 12.657647132873535
Batch 26: loss 12.888240814208984
Batch 27: loss 12.23486328125
Batch 28: loss 12.315044403076172
Batch 29: loss 12.533363342285156
Batch 30: loss 12.814544677734375
Ba

In [12]:
import copy
mtrainer.model = copy.deepcopy(mtrainer.best_model)

In [None]:
torch.autograd.set_detect_anomaly(True)

# mtrainer.run_train(6, support_loader, query_loader, lr=2e-5, full_training=True)
# mtrainer.run_train(10, support_loader, query_loader, lr=2e-5)
mtrainer.run_train(2, support_loader, query_loader, lr=1e-4, min_lr=5e-6)

In [39]:
torch.save(mtrainer.model.attention, 'models/attention/model/vindr2/full/attention-trans-8h4l-ft-full.pth')
torch.save(mtrainer.model.encoder, 'models/attention/model/vindr2/full/imgtxt-encoder-ft-full.pth')

In [31]:
from utils.metrics import AverageMeter, calculate_auc, multilabel_logit_accuracy
from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity, MultilabelPrecision, MultilabelF1Score
from models.attention.model import image_text_logits

def run_eval(model, dataloader, device, class_labels):
    model.eval()
    model = model.to(device)
    
    # auc_meter = AverageMeter()
    acc_meter = AverageMeter()
    spec_meter = AverageMeter()
    rec_meter = AverageMeter()
    f1_meter = AverageMeter()

    num_labels = len(class_labels)
    specificity = MultilabelSpecificity(num_labels=num_labels).to(device)
    recall = MultilabelRecall(num_labels=num_labels).to(device)
    precision = MultilabelPrecision(num_labels=num_labels).to(device)
    f1_func = MultilabelF1Score(num_labels=num_labels).to(device)
    with torch.no_grad():
         for images, class_inds in dataloader:
                images, class_inds = images.to(device), class_inds.to(device)

                text_embeddings, _, prototypes = model(class_labels, images)

                logits_per_image = image_text_logits(text_embeddings, prototypes, model.encoder.get_logit_scale())
                
                f1 = f1_func(logits_per_image, class_inds)
                f1_meter.update(f1.item(), len(class_inds))

                auc = calculate_auc(logits_per_image, class_inds)
            
                acc = multilabel_logit_accuracy(logits_per_image, class_inds)
                acc_meter.update(acc, len(class_inds))

                spec = specificity(logits_per_image, class_inds)
                spec_meter.update(spec.item(), len(class_inds))
                rec = recall(logits_per_image, class_inds)
                rec_meter.update(rec.item(), len(class_inds))
                prec = precision(logits_per_image, class_inds)
                print(f"F1 {f1} | Accuracy {acc} | Specificity {spec} | Recall {rec} | Precision {prec} | AUC {auc}")
            
    return f1_meter.average(), acc_meter.average(), spec_meter.average(), rec_meter.average()

In [16]:
val_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'val']['image_id'].to_list(), config.label_names_map, config.classes_split_map['val'], mean_std=config.mean_std)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)

run_eval(mtrainer.model, val_loader, device, val_dataset.class_labels())

F1 0.2835038900375366 | Accuracy 0.523469387755102 | Specificity 0.5099503993988037 | Recall 0.5991833209991455 | Precision 0.2197861671447754
F1 0.28325724601745605 | Accuracy 0.5244897959183673 | Specificity 0.5086755752563477 | Recall 0.6203547716140747 | Precision 0.2166588455438614
F1 0.2844228148460388 | Accuracy 0.5122448979591837 | Specificity 0.4827204942703247 | Recall 0.6014989614486694 | Precision 0.21262170374393463
F1 0.2793104946613312 | Accuracy 0.513265306122449 | Specificity 0.49325329065322876 | Recall 0.5948232412338257 | Precision 0.2130517065525055
F1 0.2766838073730469 | Accuracy 0.5265306122448979 | Specificity 0.5137312412261963 | Recall 0.5460197925567627 | Precision 0.21771463751792908
F1 0.28565657138824463 | Accuracy 0.513265306122449 | Specificity 0.48942142724990845 | Recall 0.5607023239135742 | Precision 0.21876876056194305
F1 0.2999957799911499 | Accuracy 0.5173469387755102 | Specificity 0.499584436416626 | Recall 0.6280947327613831 | Precision 0.233812

(0.2816464371240171, 0.516277288395251, 0.4981172612141668, 0.593422002351316)

In [38]:
run_eval(mtrainer.model, val_loader, device, val_dataset.class_labels())

F1 0.28466761112213135 | Accuracy 0.4816326530612245 | Specificity 0.43771088123321533 | Recall 0.5836288928985596 | Precision 0.22124546766281128 | AUC 0.5081570134984049
F1 0.30008992552757263 | Accuracy 0.4897959183673469 | Specificity 0.44034457206726074 | Recall 0.6981704235076904 | Precision 0.22444313764572144 | AUC 0.5500166015801123
F1 0.29914581775665283 | Accuracy 0.48367346938775513 | Specificity 0.4301510453224182 | Recall 0.6670995950698853 | Precision 0.2211931347846985 | AUC 0.5345979288112909
F1 0.27294325828552246 | Accuracy 0.4683673469387755 | Specificity 0.42632099986076355 | Recall 0.5648905038833618 | Precision 0.21171772480010986 | AUC 0.49221297037974016
F1 0.3053981065750122 | Accuracy 0.49795918367346936 | Specificity 0.4555540382862091 | Recall 0.6629575490951538 | Precision 0.24609719216823578 | AUC 0.5427515004231926
F1 0.30428019165992737 | Accuracy 0.4887755102040816 | Specificity 0.438229501247406 | Recall 0.6809741258621216 | Precision 0.23644500970840

(0.29260882919020054,
 0.4860206817311375,
 0.4409013297059261,
 0.6473267725581459)

In [33]:
# val_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'val']['image_id'].to_list(), config.label_names_map, config.classes_split_map['val'], mean_std=config.mean_std)
# val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)

run_eval(mtrainer.best_model, val_loader, device, val_dataset.class_labels())

F1 0.3168458044528961 | Accuracy 0.5377551020408163 | Specificity 0.5009139776229858 | Recall 0.6474238634109497 | Precision 0.24894434213638306 | AUC 0.6040259751781537
F1 0.32223689556121826 | Accuracy 0.5326530612244897 | Specificity 0.4862861633300781 | Recall 0.7088021636009216 | Precision 0.24293038249015808 | AUC 0.596529608164046
F1 0.32525309920310974 | Accuracy 0.5061224489795918 | Specificity 0.45004671812057495 | Recall 0.6912938356399536 | Precision 0.2510389983654022 | AUC 0.5747561257302763
F1 0.2913922965526581 | Accuracy 0.4969387755102041 | Specificity 0.4636285901069641 | Recall 0.6402331590652466 | Precision 0.2520001530647278 | AUC 0.5648150129288486
F1 0.294315904378891 | Accuracy 0.5112244897959184 | Specificity 0.461722195148468 | Recall 0.6752387285232544 | Precision 0.21453696489334106 | AUC 0.6065127145629111
F1 0.2989193797111511 | Accuracy 0.48775510204081635 | Specificity 0.4344598948955536 | Recall 0.6043060421943665 | Precision 0.24122211337089539 | AUC 

(0.3061240766866597, 0.5095748755266181, 0.4628647585017432, 0.651173084414996)

In [34]:
test_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'test']['image_id'].to_list(), config.label_names_map, config.classes_split_map['test'], mean_std=config.mean_std)
test_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)

run_eval(mtrainer.model, test_loader, device, test_dataset.class_labels())

# (0.21328095488509927,
#  0.4927230945997703,
#  0.5013440553368577,
#  0.5060264142205184)

F1 0.2427985966205597 | Accuracy 0.5673469387755102 | Specificity 0.5913485884666443 | Recall 0.506176233291626 | Precision 0.18456903100013733 | AUC 0.524497932764934
F1 0.2210523635149002 | Accuracy 0.5112244897959184 | Specificity 0.5201810598373413 | Recall 0.48085007071495056 | Precision 0.16113218665122986 | AUC 0.5275451708371257
F1 0.2002246230840683 | Accuracy 0.5010204081632653 | Specificity 0.5285993814468384 | Recall 0.4356781542301178 | Precision 0.1515847146511078 | AUC 0.5245683634473359
F1 0.19169631600379944 | Accuracy 0.513265306122449 | Specificity 0.545072615146637 | Recall 0.43255993723869324 | Precision 0.1460404396057129 | AUC 0.5006397961552599
F1 0.19519956409931183 | Accuracy 0.503061224489796 | Specificity 0.5268279314041138 | Recall 0.39201992750167847 | Precision 0.1536368876695633 | AUC 0.4972520023913885
F1 0.21040309965610504 | Accuracy 0.55 | Specificity 0.5807123780250549 | Recall 0.4601671099662781 | Precision 0.16845354437828064 | AUC 0.5243045957048

(0.20526467239728882,
 0.5244159325928763,
 0.5495938714643586,
 0.44161645640636576)

In [35]:
run_eval(mtrainer.best_model, test_loader, device, test_dataset.class_labels())

# (0.22087397899768307,
#  0.4117196476445806,
#  0.39774811515539643,
#  0.6270336436841827)

F1 0.1965721845626831 | Accuracy 0.5306122448979592 | Specificity 0.5559758543968201 | Recall 0.46655580401420593 | Precision 0.15712666511535645 | AUC 0.5584646248623403
F1 0.21839505434036255 | Accuracy 0.5346938775510204 | Specificity 0.5573030114173889 | Recall 0.4745098352432251 | Precision 0.16633214056491852 | AUC 0.5407466953886637
F1 0.20195996761322021 | Accuracy 0.5214285714285715 | Specificity 0.5496224164962769 | Recall 0.4608705937862396 | Precision 0.15436065196990967 | AUC 0.4915952867996541
F1 0.19545826315879822 | Accuracy 0.5204081632653061 | Specificity 0.5526095628738403 | Recall 0.4603337049484253 | Precision 0.1460273265838623 | AUC 0.5267423984203196
F1 0.21582238376140594 | Accuracy 0.5265306122448979 | Specificity 0.5333627462387085 | Recall 0.5011880993843079 | Precision 0.155477374792099 | AUC 0.5614997706100538
F1 0.19238397479057312 | Accuracy 0.5153061224489796 | Specificity 0.5435122847557068 | Recall 0.5054166913032532 | Precision 0.14201249182224274 | 

(0.2044379652824862,
 0.5244159325928762,
 0.5496255138604315,
 0.45490357351686617)

## Zero-shot evaluations on validation and test sets

In [1]:
from utils.baseline import run_baseline_eval
from models.metaclassifier.utils import set_seed
import numpy as np
import torch
import random

from utils.device import get_device
from utils.baseline import run_baseline_eval
from utils.data import DatasetConfig
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT2

from models.backbone.datasets import MEAN_STDS
from models.metaclassifier.trainer import create_baseline_eval_dataloaders

from utils.baseline import run_attention_zeroshot_eval
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention



dataset_config = DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
device =  get_device()
NUM_LABELS = 7
NUM_QUERY = 10
set_seed(420)
PROJ_SIZE = 512

query_loader, val_loader, test_loader = create_baseline_eval_dataloaders(dataset_config, NUM_LABELS, NUM_QUERY)

In [2]:
encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')
attn_model = torch.load('models/attention/model/vindr2/attention-8h.pth')
model = LabelImagePrototypeModel(encoder, 8, 512, attn_model=attn_model)

In [3]:
seeds = [42, 142, 321]
for s in seeds:
    print('seed', s)
    print('val')
    print(run_attention_zeroshot_eval(model, val_loader, NUM_LABELS, device, episodes=60))
    print('test')
    print(run_attention_zeroshot_eval(model, test_loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5530612468719482 | F1 0.32198143005371094 | Specificity 0.5547138452529907 | Recall 0.5399669408798218 | Precision 0.2254353016614914 | Bal Acc 0.5473403930664062
Episode 2 | Raw Acc 0.5306122899055481 | F1 0.33139535784721375 | Specificity 0.5171219110488892 | Recall 0.5765469074249268 | Precision 0.2458820343017578 | Bal Acc 0.546834409236908
Episode 3 | Raw Acc 0.5795918703079224 | F1 0.38323354721069336 | Specificity 0.580405592918396 | Recall 0.5460691452026367 | Precision 0.28861716389656067 | Bal Acc 0.5632373690605164
Episode 4 | Raw Acc 0.5102040767669678 | F1 0.375 | Specificity 0.4610194265842438 | Recall 0.6753779649734497 | Precision 0.2609666585922241 | Bal Acc 0.5681986808776855
Episode 5 | Raw Acc 0.5 | F1 0.3287671208381653 | Specificity 0.468084454536438 | Recall 0.6140708923339844 | Precision 0.22463200986385345 | Bal Acc 0.5410776734352112
Episode 6 | Raw Acc 0.5244898200035095 | F1 0.3786666691303253 | Specificity 0.47696834802627563 | Recall 

In [None]:
encoder = torch.load('models/attention/model/vindr2/full-mh/imgtxt-encoder.pth')
attn_model = torch.load('models/attention/model/vindr2/full-mh/attention-8h.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [None]:
seeds = [42, 142, 321]
for s in seeds:
    print('seed', s)
    print('val')
    print(run_img_txt_zeroshot_eval(model, val_loader, NUM_LABELS, device, episodes=60))
    print('test')
    print(run_img_txt_zeroshot_eval(model, test_loader, NUM_LABELS, device, episodes=60))

In [2]:
encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')
attn_model = torch.load('models/attention/model/vindr2/attention-trans-8h4l.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [3]:
seeds = [42, 142, 321]
for s in seeds:
    print('seed', s)
    print('val')
    print(run_attention_zeroshot_eval(model, val_loader, NUM_LABELS, device, episodes=60))
    print('test')
    print(run_attention_zeroshot_eval(model, test_loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5653061866760254 | F1 0.27796611189842224 | Specificity 0.6041334271430969 | Recall 0.448091983795166 | Precision 0.22017502784729004 | Bal Acc 0.5261126756668091
Episode 2 | Raw Acc 0.5897959470748901 | F1 0.3699059784412384 | Specificity 0.5960823893547058 | Recall 0.5970055460929871 | Precision 0.31974154710769653 | Bal Acc 0.5965439677238464
Episode 3 | Raw Acc 0.5959184169769287 | F1 0.3571428656578064 | Specificity 0.6301459074020386 | Recall 0.48106446862220764 | Precision 0.29574093222618103 | Bal Acc 0.5556051731109619
Episode 4 | Raw Acc 0.5612245202064514 | F1 0.32176655530929565 | Specificity 0.588890552520752 | Recall 0.5079894065856934 | Precision 0.2506626844406128 | Bal Acc 0.5484399795532227
Episode 5 | Raw Acc 0.5510204434394836 | F1 0.31677019596099854 | Specificity 0.564183235168457 | Recall 0.5474938154220581 | Precision 0.2425869107246399 | Bal Acc 0.5558385252952576
Episode 6 | Raw Acc 0.5693877935409546 | F1 0.3215433955192566 | Specificity

In [None]:
encoder = torch.load('models/attention/model/vindr2/full/imgtxt-encoder-ft-full.pth')
attn_model = torch.load('models/attention/model/vindr2/full/attention-trans-8h4l-ft-full.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [None]:
seeds = [42, 142, 321]
for s in seeds:
    print('seed', s)
    print('val')
    print(run_attention_zeroshot_eval(model, val_loader, NUM_LABELS, device, episodes=60))
    print('test')
    print(run_attention_zeroshot_eval(model, test_loader, NUM_LABELS, device, episodes=60))