In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler
from utils.f1_loss import BalAccuracyLoss


# Two VinDr-CXR configurations. Both use the same image directory string,
# the same label constant and the same MEAN_STDS entry; they differ only in
# the two pickle paths and in the class-split constant.
configs = {
    'vindr1': DatasetConfig(
        'datasets/vindr-cxr-png',
        'data/vindr_cxr_split_labels.pkl',
        'data/vindr_train_query_set.pkl',
        VINDR_CXR_LABELS,
        VINDR_SPLIT,
        MEAN_STDS['chestmnist'],
    ),
    'vindr2': DatasetConfig(
        'datasets/vindr-cxr-png',
        'data/vindr_cxr_split_labels2.pkl',
        'data/vindr_train_query_set2.pkl',
        VINDR_CXR_LABELS,
        VINDR_SPLIT2,
        MEAN_STDS['chestmnist'],
    ),
}

# The rest of this cell works against the second configuration.
config = configs['vindr2']

# NOTE(review): 10*14 presumably encodes "10 samples x 14 classes" - confirm.
batch_size = 10 * 14

# Split the image ids into a query part and a support part.
query_image_ids, support_image_ids = get_query_and_support_ids(
    config.img_info,
    config.training_split_path,
)

# Query side: plain shuffled loader.
query_dataset = Dataset(
    config.img_path,
    config.img_info,
    query_image_ids,
    config.label_names_map,
    config.classes_split_map['train'],
    mean_std=config.mean_std,
)
query_loader = DataLoader(
    dataset=query_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Support side: same dataset arguments, but batches are drawn through
# MultilabelBalancedRandomSampler fed with the dataset's class indicators.
support_dataset = Dataset(
    config.img_path,
    config.img_info,
    support_image_ids,
    config.label_names_map,
    config.classes_split_map['train'],
    mean_std=config.mean_std,
)
support_loader = DataLoader(
    dataset=support_dataset,
    batch_size=batch_size,
    sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()),
)


# Size argument handed to the attention / prototype model constructors below.
PROJ_SIZE = 512
device = get_device()

# Load the pre-trained image/text encoder. map_location remaps the stored
# tensors onto `device`, so a checkpoint written on a GPU machine can still be
# opened on a host with a different (or no) accelerator.
# NOTE: the file is a fully pickled module - only load trusted checkpoints.
# PyTorch >= 2.6 defaults torch.load to weights_only=True and will reject such
# files unless weights_only=False is passed.
encoder = torch.load(
    'models/embedding/model/vindr2/imgtext_model_trained1.pth',
    map_location=device,
)
# The text sub-model carries its own `device` attribute; point it at the
# device selected above.
encoder.text_model.device = device


In [2]:
# Variant: build a fresh LabelImageMHAttention and hand it to the prototype model.
# NOTE(review): the literal 8 is presumably the number of attention heads - confirm.
attn_model = LabelImageMHAttention(
    PROJ_SIZE,
    8,
    cls_weight=5,
    cls_loss=BalAccuracyLoss(),
    device=device,
)
# Alternative kept for reference - resume from a saved attention module instead:
# attn_model = torch.load('models/attention/model/vindr2/attention-8h.pth')
model = LabelImagePrototypeModel(
    encoder,
    8,
    PROJ_SIZE,
    attn_model=attn_model,
)
mtrainer = Trainer(
    model,
    support_dataset.class_labels(),
    device,
)

In [2]:
# Variant: no pre-built attention module is passed; the prototype model gets
# num_layers / cls_weight / cls_loss directly. Running this cell rebinds
# `model` and `mtrainer`, replacing whatever an earlier variant cell created.
model = LabelImagePrototypeModel(
    encoder,
    8,
    PROJ_SIZE,
    num_layers=4,
    cls_weight=5,
    cls_loss=BalAccuracyLoss(),
)
mtrainer = Trainer(
    model,
    support_dataset.class_labels(),
    device,
)

In [2]:
# Variant: start from a previously saved attention module.
# map_location places the stored tensors on `device`, so the checkpoint opens
# regardless of the device it was saved from.
# NOTE: fully pickled module - only load trusted files; PyTorch >= 2.6 needs
# weights_only=False for this kind of checkpoint.
attn_model = torch.load(
    'models/attention/model/vindr2/attention-trans-8h4l.pth',
    map_location=device,
)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [37]:
# Autograd anomaly detection is a debugging aid that slows execution. The
# context manager enables it only for this training call and restores the
# previous setting afterwards, instead of flipping the process-wide flag for
# every later cell.
with torch.autograd.detect_anomaly():
    # NOTE(review): the leading 2 is presumably the epoch count - confirm in Trainer.run_train.
    mtrainer.run_train(2, support_loader, query_loader, lr=1e-4, min_lr=5e-6, full_training=True, enc_weight=0.05)

# Other schedules that were tried:
# mtrainer.run_train(2, support_loader, query_loader, lr=1e-4, min_lr=5e-6)
# mtrainer.run_train(2, support_loader, query_loader, lr=2e-5, min_lr=5e-6, full_training=True, enc_weight=0.05)

Batch 1: loss 12.53753662109375
Batch 2: loss 12.347467422485352
Batch 3: loss 12.59437370300293
Batch 4: loss 12.551023483276367
Batch 5: loss 12.562271118164062
Batch 6: loss 12.814994812011719
Batch 7: loss 12.426860809326172
Batch 8: loss 13.130172729492188
Batch 9: loss 12.404001235961914
Batch 10: loss 12.426783561706543
Batch 11: loss 12.2003755569458
Batch 12: loss 12.732625961303711
Batch 13: loss 13.251863479614258
Batch 14: loss 13.03114128112793
Batch 15: loss 12.641311645507812
Batch 16: loss 12.84710693359375
Batch 17: loss 12.50998592376709
Batch 18: loss 12.866361618041992
Batch 19: loss 12.923351287841797
Batch 20: loss 12.81854248046875
Batch 21: loss 12.55801773071289
Batch 22: loss 12.690763473510742
Batch 23: loss 12.739566802978516
Batch 24: loss 12.877098083496094
Batch 25: loss 12.657647132873535
Batch 26: loss 12.888240814208984
Batch 27: loss 12.23486328125
Batch 28: loss 12.315044403076172
Batch 29: loss 12.533363342285156
Batch 30: loss 12.814544677734375
Ba

In [12]:
import copy

# Replace the trainer's live model with an independent deep copy of whatever
# the trainer stored as `best_model`, so further training does not mutate
# that stored object.
best_snapshot = copy.deepcopy(mtrainer.best_model)
mtrainer.model = best_snapshot

In [None]:
# Other schedules that were tried:
# mtrainer.run_train(6, support_loader, query_loader, lr=2e-5, full_training=True)
# mtrainer.run_train(10, support_loader, query_loader, lr=2e-5)

# Anomaly detection is for debugging and slows execution; the context manager
# limits it to this call and restores the previous setting on exit rather
# than leaving a process-wide flag switched on.
with torch.autograd.detect_anomaly():
    mtrainer.run_train(2, support_loader, query_loader, lr=1e-4, min_lr=5e-6)

In [39]:
import os

# Persist the trained attention module and encoder as whole pickled objects
# (the evaluation cells load them back with torch.load).
attention_ckpt = 'models/attention/model/vindr2/full/attention-trans-8h4l-ft-full.pth'
encoder_ckpt = 'models/attention/model/vindr2/full/imgtxt-encoder-ft-full.pth'

# torch.save does not create missing parent directories; make sure they exist
# so a finished training run is not lost to a missing-folder error.
for ckpt_path in (attention_ckpt, encoder_ckpt):
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

torch.save(mtrainer.model.attention, attention_ckpt)
torch.save(mtrainer.model.encoder, encoder_ckpt)

## Zero-shot evaluations on validation and test sets

In [1]:
from utils.baseline import run_baseline_eval
from models.metaclassifier.utils import set_seed
import numpy as np
import torch
import random

from utils.device import get_device
from utils.baseline import run_baseline_eval
from utils.data import DatasetConfig
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT2

from models.backbone.datasets import MEAN_STDS
from models.metaclassifier.trainer import create_baseline_eval_dataloaders

from utils.baseline import run_attention_zeroshot_eval
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention



# Constants used by the evaluation cells below.
NUM_LABELS = 7
NUM_QUERY = 10
PROJ_SIZE = 512

# Same arguments as the 'vindr2' training configuration.
dataset_config = DatasetConfig(
    'datasets/vindr-cxr-png',
    'data/vindr_cxr_split_labels2.pkl',
    'data/vindr_train_query_set2.pkl',
    VINDR_CXR_LABELS,
    VINDR_SPLIT2,
    MEAN_STDS['chestmnist'],
)
device = get_device()

# Seed before the loaders are created.
set_seed(420)

query_loader, val_loader, test_loader = create_baseline_eval_dataloaders(
    dataset_config,
    NUM_LABELS,
    NUM_QUERY,
)

In [2]:
# Checkpoint pair: pre-trained encoder + 'attention-8h' attention module.
# map_location places the stored tensors on `device`, so the files open
# regardless of the device they were saved from.
# NOTE: fully pickled modules - only load trusted files; PyTorch >= 2.6 needs
# weights_only=False for this kind of checkpoint.
encoder = torch.load(
    'models/embedding/model/vindr2/imgtext_model_trained1.pth',
    map_location=device,
)
attn_model = torch.load(
    'models/attention/model/vindr2/attention-8h.pth',
    map_location=device,
)
# Use the shared PROJ_SIZE constant (512) instead of a bare literal, matching
# the other model-construction cells.
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [3]:
# Zero-shot evaluation under three RNG seeds; for every seed the validation
# loader is reported first, then the test loader (60 episodes each).
seeds = [42, 142, 321]
for seed in seeds:
    print('seed', seed)
    set_seed(seed)
    for split_name, loader in (('val', val_loader), ('test', test_loader)):
        print(split_name)
        print(run_attention_zeroshot_eval(model, loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5530612468719482 | F1 0.32198143005371094 | Specificity 0.5547138452529907 | Recall 0.5399669408798218 | Precision 0.2254353016614914 | Bal Acc 0.5473403930664062
Episode 2 | Raw Acc 0.5306122899055481 | F1 0.33139535784721375 | Specificity 0.5171219110488892 | Recall 0.5765469074249268 | Precision 0.2458820343017578 | Bal Acc 0.546834409236908
Episode 3 | Raw Acc 0.5795918703079224 | F1 0.38323354721069336 | Specificity 0.580405592918396 | Recall 0.5460691452026367 | Precision 0.28861716389656067 | Bal Acc 0.5632373690605164
Episode 4 | Raw Acc 0.5102040767669678 | F1 0.375 | Specificity 0.4610194265842438 | Recall 0.6753779649734497 | Precision 0.2609666585922241 | Bal Acc 0.5681986808776855
Episode 5 | Raw Acc 0.5 | F1 0.3287671208381653 | Specificity 0.468084454536438 | Recall 0.6140708923339844 | Precision 0.22463200986385345 | Bal Acc 0.5410776734352112
Episode 6 | Raw Acc 0.5244898200035095 | F1 0.3786666691303253 | Specificity 0.47696834802627563 | Recall 

In [2]:
# Checkpoint pair stored under vindr2/full-mh.
# map_location places the stored tensors on `device`, so the files open
# regardless of the device they were saved from.
# NOTE: fully pickled modules - only load trusted files; PyTorch >= 2.6 needs
# weights_only=False for this kind of checkpoint.
encoder = torch.load(
    'models/attention/model/vindr2/full-mh/imgtxt-encoder.pth',
    map_location=device,
)
attn_model = torch.load(
    'models/attention/model/vindr2/full-mh/attention-8h.pth',
    map_location=device,
)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [3]:
# Zero-shot evaluation under three RNG seeds; for every seed the validation
# loader is reported first, then the test loader (60 episodes each).
seeds = [42, 142, 321]
for seed in seeds:
    print('seed', seed)
    set_seed(seed)
    for split_name, loader in (('val', val_loader), ('test', test_loader)):
        print(split_name)
        print(run_attention_zeroshot_eval(model, loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5448979735374451 | F1 0.3343283534049988 | Specificity 0.5388146042823792 | Recall 0.5750263929367065 | Precision 0.23949380218982697 | Bal Acc 0.5569205284118652
Episode 2 | Raw Acc 0.5408163666725159 | F1 0.38016530871391296 | Specificity 0.5070095658302307 | Recall 0.6919474601745605 | Precision 0.277873158454895 | Bal Acc 0.5994784832000732
Episode 3 | Raw Acc 0.56734699010849 | F1 0.369047611951828 | Specificity 0.5785291790962219 | Recall 0.551344633102417 | Precision 0.3047148883342743 | Bal Acc 0.5649368762969971
Episode 4 | Raw Acc 0.5204082131385803 | F1 0.352617084980011 | Specificity 0.5080944895744324 | Recall 0.6487962007522583 | Precision 0.267536461353302 | Bal Acc 0.578445315361023
Episode 5 | Raw Acc 0.47755104303359985 | F1 0.31182795763015747 | Specificity 0.45249027013778687 | Recall 0.6196742057800293 | Precision 0.22267144918441772 | Bal Acc 0.5360822677612305
Episode 6 | Raw Acc 0.48367348313331604 | F1 0.3217158317565918 | Specificity 0.46

In [2]:
# Checkpoint pair: pre-trained encoder + 'attention-trans-8h4l' attention module.
# map_location places the stored tensors on `device`, so the files open
# regardless of the device they were saved from.
# NOTE: fully pickled modules - only load trusted files; PyTorch >= 2.6 needs
# weights_only=False for this kind of checkpoint.
encoder = torch.load(
    'models/embedding/model/vindr2/imgtext_model_trained1.pth',
    map_location=device,
)
attn_model = torch.load(
    'models/attention/model/vindr2/attention-trans-8h4l.pth',
    map_location=device,
)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [3]:
# Zero-shot evaluation under three RNG seeds; for every seed the validation
# loader is reported first, then the test loader (60 episodes each).
seeds = [42, 142, 321]
for seed in seeds:
    print('seed', seed)
    set_seed(seed)
    for split_name, loader in (('val', val_loader), ('test', test_loader)):
        print(split_name)
        print(run_attention_zeroshot_eval(model, loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5653061866760254 | F1 0.27796611189842224 | Specificity 0.6041334271430969 | Recall 0.448091983795166 | Precision 0.22017502784729004 | Bal Acc 0.5261126756668091
Episode 2 | Raw Acc 0.5897959470748901 | F1 0.3699059784412384 | Specificity 0.5960823893547058 | Recall 0.5970055460929871 | Precision 0.31974154710769653 | Bal Acc 0.5965439677238464
Episode 3 | Raw Acc 0.5959184169769287 | F1 0.3571428656578064 | Specificity 0.6301459074020386 | Recall 0.48106446862220764 | Precision 0.29574093222618103 | Bal Acc 0.5556051731109619
Episode 4 | Raw Acc 0.5612245202064514 | F1 0.32176655530929565 | Specificity 0.588890552520752 | Recall 0.5079894065856934 | Precision 0.2506626844406128 | Bal Acc 0.5484399795532227
Episode 5 | Raw Acc 0.5510204434394836 | F1 0.31677019596099854 | Specificity 0.564183235168457 | Recall 0.5474938154220581 | Precision 0.2425869107246399 | Bal Acc 0.5558385252952576
Episode 6 | Raw Acc 0.5693877935409546 | F1 0.3215433955192566 | Specificity

In [2]:
# Checkpoint pair written by the torch.save cell of the training section
# (the '-ft-full' files under vindr2/full).
# map_location places the stored tensors on `device`, so the files open
# regardless of the device they were saved from.
# NOTE: fully pickled modules - only load trusted files; PyTorch >= 2.6 needs
# weights_only=False for this kind of checkpoint.
encoder = torch.load(
    'models/attention/model/vindr2/full/imgtxt-encoder-ft-full.pth',
    map_location=device,
)
attn_model = torch.load(
    'models/attention/model/vindr2/full/attention-trans-8h4l-ft-full.pth',
    map_location=device,
)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [3]:
# Zero-shot evaluation under three RNG seeds; for every seed the validation
# loader is reported first, then the test loader (60 episodes each).
seeds = [42, 142, 321]
for seed in seeds:
    print('seed', seed)
    set_seed(seed)
    for split_name, loader in (('val', val_loader), ('test', test_loader)):
        print(split_name)
        print(run_attention_zeroshot_eval(model, loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5020408630371094 | F1 0.3333333134651184 | Specificity 0.46862730383872986 | Recall 0.6173743009567261 | Precision 0.242111474275589 | Bal Acc 0.5430008172988892
Episode 2 | Raw Acc 0.4612245261669159 | F1 0.3465346395969391 | Specificity 0.3955380320549011 | Recall 0.6895924806594849 | Precision 0.2627190351486206 | Bal Acc 0.5425652265548706
Episode 3 | Raw Acc 0.5224490165710449 | F1 0.3675675690174103 | Specificity 0.5028469562530518 | Recall 0.5801961421966553 | Precision 0.2890571355819702 | Bal Acc 0.5415215492248535
Episode 4 | Raw Acc 0.4653061628341675 | F1 0.3671497404575348 | Specificity 0.3906916081905365 | Recall 0.7264384031295776 | Precision 0.2705789804458618 | Bal Acc 0.5585650205612183
Episode 5 | Raw Acc 0.4183673858642578 | F1 0.326241135597229 | Specificity 0.3451482653617859 | Recall 0.7029538750648499 | Precision 0.22527645528316498 | Bal Acc 0.5240510702133179
Episode 6 | Raw Acc 0.4632653295993805 | F1 0.36626505851745605 | Specificity 0.

In [2]:
# Checkpoint pair: the '-ft' files under vindr2/full.
# map_location places the stored tensors on `device`, so the files open
# regardless of the device they were saved from.
# NOTE: fully pickled modules - only load trusted files; PyTorch >= 2.6 needs
# weights_only=False for this kind of checkpoint.
encoder = torch.load(
    'models/attention/model/vindr2/full/imgtxt-encoder-ft.pth',
    map_location=device,
)
attn_model = torch.load(
    'models/attention/model/vindr2/full/attention-trans-8h4l-ft.pth',
    map_location=device,
)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [3]:
# Seed explicitly before evaluating. Without this the run depends on whatever
# RNG state earlier cells left behind. Seed 42 matches the "seed 42" banner in
# this cell's recorded output and complements the following cell, which
# covers seeds 142 and 321.
seed = 42
print('seed', seed)
set_seed(seed)

print('val')
print(run_attention_zeroshot_eval(model, val_loader, NUM_LABELS, device, episodes=60))
print('test')
print(run_attention_zeroshot_eval(model, test_loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5081632733345032 | F1 0.3397260308265686 | Specificity 0.4720962345600128 | Recall 0.6276638507843018 | Precision 0.2528010904788971 | Bal Acc 0.5498800277709961
Episode 2 | Raw Acc 0.4632653594017029 | F1 0.35061728954315186 | Specificity 0.39494025707244873 | Recall 0.6954464912414551 | Precision 0.27105703949928284 | Bal Acc 0.5451933741569519
Episode 3 | Raw Acc 0.548979640007019 | F1 0.3844011127948761 | Specificity 0.5313406586647034 | Recall 0.6014472842216492 | Precision 0.3104356527328491 | Bal Acc 0.5663939714431763
Episode 4 | Raw Acc 0.47755104303359985 | F1 0.3663366138935089 | Specificity 0.4125639498233795 | Recall 0.7194840908050537 | Precision 0.30700069665908813 | Bal Acc 0.5660240054130554
Episode 5 | Raw Acc 0.44285720586776733 | F1 0.3546099364757538 | Specificity 0.35959339141845703 | Recall 0.7614930868148804 | Precision 0.2682323157787323 | Bal Acc 0.5605432391166687
Episode 6 | Raw Acc 0.4632653594017029 | F1 0.3441396653652191 | Specifici

In [3]:
# Evaluation for the two remaining seeds: seed the RNGs, announce the seed,
# then report the validation loader followed by the test loader.
seeds = [142, 321]
for seed in seeds:
    set_seed(seed)
    print('seed', seed)
    for split_name, loader in (('val', val_loader), ('test', test_loader)):
        print(split_name)
        print(run_attention_zeroshot_eval(model, loader, NUM_LABELS, device, episodes=60))

seed 142
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.47755104303359985 | F1 0.37560975551605225 | Specificity 0.40819215774536133 | Recall 0.7292343378067017 | Precision 0.3149920701980591 | Bal Acc 0.5687132477760315
Episode 2 | Raw Acc 0.48367348313331604 | F1 0.3324538469314575 | Specificity 0.444195419549942 | Recall 0.5974209904670715 | Precision 0.25814661383628845 | Bal Acc 0.520808219909668
Episode 3 | Raw Acc 0.4591836929321289 | F1 0.3552311360836029 | Specificity 0.38685503602027893 | Recall 0.6862245798110962 | Precision 0.26260530948638916 | Bal Acc 0.5365397930145264
Episode 4 | Raw Acc 0.4306122660636902 | F1 0.32445523142814636 | Specificity 0.36008352041244507 | Recall 0.7497295141220093 | Precision 0.24204161763191223 | Bal Acc 0.5549064874649048
Episode 5 | Raw Acc 0.5204081535339355 | F1 0.3766578137874603 | Specificity 0.47640126943588257 | Recall 0.6472428441047668 | Precision 0.2946084141731262 | Bal Acc 0.5618220567703247
Episode 6 | Raw Acc 0.4959183931350708 | F1 0.35170602798461914 | Speci

In [2]:
# Checkpoint pair: the '-ft1' files under vindr2/full.
# map_location places the stored tensors on `device`, so the files open
# regardless of the device they were saved from.
# NOTE: fully pickled modules - only load trusted files; PyTorch >= 2.6 needs
# weights_only=False for this kind of checkpoint.
encoder = torch.load(
    'models/attention/model/vindr2/full/imgtxt-encoder-ft1.pth',
    map_location=device,
)
attn_model = torch.load(
    'models/attention/model/vindr2/full/attention-trans-8h4l-ft1.pth',
    map_location=device,
)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)

In [3]:
# Zero-shot evaluation under two RNG seeds; for every seed the validation
# loader is reported first, then the test loader (60 episodes each).
seeds = [42, 142]
for seed in seeds:
    print('seed', seed)
    set_seed(seed)
    for split_name, loader in (('val', val_loader), ('test', test_loader)):
        print(split_name)
        print(run_attention_zeroshot_eval(model, loader, NUM_LABELS, device, episodes=60))

seed 42
val


  output, inverse_indices, counts = torch._unique2(


Episode 1 | Raw Acc 0.5224490165710449 | F1 0.31976744532585144 | Specificity 0.510280191898346 | Recall 0.5605933666229248 | Precision 0.2342972755432129 | Bal Acc 0.535436749458313
Episode 2 | Raw Acc 0.48775514960289 | F1 0.35475578904151917 | Specificity 0.4328639507293701 | Recall 0.6853774189949036 | Precision 0.26729899644851685 | Bal Acc 0.5591206550598145
Episode 3 | Raw Acc 0.5530612468719482 | F1 0.3830986022949219 | Specificity 0.5405319333076477 | Recall 0.5841643810272217 | Precision 0.29749512672424316 | Bal Acc 0.5623481273651123
Episode 4 | Raw Acc 0.5122449398040771 | F1 0.366047739982605 | Specificity 0.47028547525405884 | Recall 0.6736032962799072 | Precision 0.2734237313270569 | Bal Acc 0.5719443559646606
Episode 5 | Raw Acc 0.44693878293037415 | F1 0.33086422085762024 | Specificity 0.3862629532814026 | Recall 0.6943430304527283 | Precision 0.228959321975708 | Bal Acc 0.5403029918670654
Episode 6 | Raw Acc 0.4979591965675354 | F1 0.3819095492362976 | Specificity 0.