## Set up datasets and dataloaders

In [1]:
from utils.device import get_device
from utils.data import DatasetConfig
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS

# Data locations (relative paths, resolved against the notebook's working directory).
IMG_PATH = 'datasets/vindr-cxr-png'
SPLIT_LABELS_PATH = 'data/vindr_cxr_split_labels.pkl'
TRAIN_QUERY_SET_PATH = 'data/vindr_train_query_set.pkl'

# Episode settings passed to every ControlledMetaTrainer below.
NUM_SHOTS = 5
NUM_WAYS = 7
TRAIN_NUM_WAYS = 7

# NOTE(review): ChestMNIST normalization statistics are reused for VinDr-CXR images,
# presumably because the backbone was pretrained with them — confirm.
NORM_STATS = MEAN_STDS['chestmnist']

dataset_config = DatasetConfig(
    IMG_PATH,
    SPLIT_LABELS_PATH,
    TRAIN_QUERY_SET_PATH,
    VINDR_CXR_LABELS,
    VINDR_SPLIT,
    NORM_STATS,
)
device = get_device()

In [2]:
import torch
import random
import numpy as np

SEED = 42


def set_seed(seed: int) -> None:
    """Seed NumPy's global RNG, PyTorch and Python's `random` module.

    Per the PyTorch documentation, `torch.manual_seed` sets the seed on all devices.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


# set_seed was defined but never called anywhere in this notebook, so every RNG-dependent
# step (e.g. episode sampling, head initialisation) was unseeded. Seed once here, before
# any model or trainer is constructed.
set_seed(SEED)

## Models with attention
### Run experiments on proposed model

In [6]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.model import ClsModel


# Both checkpoints are whole pickled modules (not state_dicts), so they need
# weights_only=False; the default became True in PyTorch 2.6 and would refuse them.
# Only load trusted files this way. map_location keeps loading portable across hosts
# whose device (CPU/CUDA/MPS) differs from the one the checkpoint was saved on.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth',
                     map_location=device, weights_only=False)
# The text tower keeps its own `device` attribute; point it at the active device.
encoder.text_model.device = device
attention = torch.load('models/attention/model/attention-model-8h4l.pth',
                       map_location=device, weights_only=False)
# NOTE(review): 512 is presumably the embedding width; fc_hidden_size=16 matches the
# "-16" suffix of the weights saved further down — confirm.
model = ClsModel(encoder, attention, 512, class_prototype_inf, fc_hidden_size=16)
mtrainer = ControlledMetaTrainer(model, NUM_SHOTS, NUM_WAYS, dataset_config,
                                 train_n_ways=TRAIN_NUM_WAYS, device=device)

In [None]:
# Baseline: evaluate the freshly built ClsModel (before any training) on the test split.
# Judging by the verbose outputs in this notebook, run_eval returns
# (accuracy %, AUC, loss). NOTE(review): this cell was never executed (In [None]).
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader)


In [7]:
# Baseline: ClsModel on the validation split before any training.
mtrainer.run_eval(mtrainer.model, mtrainer.val_loader)

(20.361516034985424, 0.5202569099543621, 0.7382618546485901)

In [9]:
# Train for 2 epochs at lr=1e-5. The first argument is the epoch count, matching the
# "trained 3 epochs" note further down.
# TODO(review): the original note asked whether this should be 5 epochs — decide and
# record the final schedule.
mtrainer.run_train(2, lr=1e-5)

Batch 1: loss 0.6902595162391663 | Acc 53.46938775510204
Batch 2: loss 0.6968538761138916 | Acc 42.44897959183673
Batch 3: loss 0.7003634770711263 | Acc 37.95918367346939
Batch 4: loss 0.6974166184663773 | Acc 53.46938775510204
Batch 5: loss 0.6979597449302674 | Acc 44.08163265306123
Batch 6: loss 0.6973948379357656 | Acc 48.97959183673469
Batch 7: loss 0.6978198204721723 | Acc 44.08163265306123
Batch 8: loss 0.6980315819382668 | Acc 45.714285714285715
Batch 9: loss 0.6969953179359436 | Acc 53.87755102040816
Batch 10: loss 0.6979582905769348 | Acc 39.183673469387756
Batch 11: loss 0.6966771320863203 | Acc 57.55102040816327
Batch 12: loss 0.6973831256230673 | Acc 40.0
Batch 13: loss 0.6974497345777658 | Acc 46.53061224489796
Batch 14: loss 0.6974168930734906 | Acc 46.93877551020408
Batch 15: loss 0.697468622525533 | Acc 44.89795918367347
Batch 16: loss 0.6972164176404476 | Acc 49.38775510204081
Batch 17: loss 0.6974642136517692 | Acc 43.26530612244898
Batch 18: loss 0.6976404885450999 |

In [12]:
# Test-split evaluation of the trainer's best checkpoint. It is presumably selected on the
# validation split — confirm in ControlledMetaTrainer.
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader)

(19.93031358885018, 0.507083518313992, 0.7321400235338908)

In [13]:
# Make the attention module trainable, presumably so the next run_train call also
# updates it.
mtrainer.model.attn_model.set_trainable(True)

In [17]:
# Continue training for 3 epochs at lr=1e-5 with the attention module unfrozen.
mtrainer.run_train(3, lr=1e-5)

Batch 1: loss 0.6971045732498169 | Acc 42.44897959183673
Batch 2: loss 0.6954399049282074 | Acc 48.16326530612245
Batch 3: loss 0.6974401871363322 | Acc 39.183673469387756
Batch 4: loss 0.696072444319725 | Acc 50.61224489795918
Batch 5: loss 0.693632709980011 | Acc 60.0
Batch 6: loss 0.6950256526470184 | Acc 37.55102040816327
Batch 7: loss 0.6943262900624957 | Acc 51.83673469387755
Batch 8: loss 0.6948138996958733 | Acc 42.44897959183673
Batch 9: loss 0.6948209206263224 | Acc 46.93877551020408
Batch 10: loss 0.6951384782791138 | Acc 43.26530612244898
Batch 11: loss 0.6944867806001143 | Acc 52.6530612244898
Batch 12: loss 0.6939454972743988 | Acc 51.42857142857142
Batch 13: loss 0.6941679395162142 | Acc 44.89795918367347
Batch 14: loss 0.6944387938295092 | Acc 46.12244897959184
Batch 15: loss 0.6944648385047912 | Acc 48.16326530612245
Batch 16: loss 0.6943309046328068 | Acc 47.3469387755102
Batch 17: loss 0.6938948526101953 | Acc 55.51020408163265
Batch 18: loss 0.6942294736703237 | Acc

In [18]:
# Model trained 3 epochs with lr=5e-6, 3 epochs with lr=1e-6, 3 epochs with lr=1e-5
# NOTE(review): this schedule does not match the run_train cells visible above
# (2 epochs @ 1e-5, then 3 epochs @ 1e-5 with attention unfrozen). The extra runs were
# presumably in cells that have since been deleted, so Restart & Run All will not
# reproduce the recorded output.
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader)

(78.01891488302637, 0.4961423906653175, 0.6495074789698531)

In [19]:
# Same test evaluation, but for the best checkpoint instead of the latest weights.
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader)

(79.21353907416626, 0.5099220733360005, 0.6617257754977156)

In [20]:
from pathlib import Path

# Save only the classifier head of the best checkpoint (the encoder is loaded separately
# from its own .pth file). torch.save does not create missing directories, so make sure
# the output folder exists first.
CLS_WEIGHTS_PATH = Path('models/metaclassifier/model/comb3/cls_weights-16.pkl')
CLS_WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(mtrainer.best_model.cls.state_dict(), CLS_WEIGHTS_PATH)

In [21]:
from pathlib import Path

# Save the whole (pickled) attention module of the best checkpoint. Reloading it requires
# torch.load(..., weights_only=False). Create the output folder first, because torch.save
# does not create missing directories.
ATTENTION_MODEL_PATH = Path('models/metaclassifier/model/comb3/attention-model-8h4l.pth')
ATTENTION_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(mtrainer.best_model.attn_model, ATTENTION_MODEL_PATH)

### Prototypical Network with attention

In [3]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.base import euclidean_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.model import ProtoNetAttention

# `device` comes from the configuration cell at the top; querying it again here was
# redundant (this cell already depends on that cell for dataset_config / NUM_SHOTS).

# Whole pickled modules: need weights_only=False (default True since PyTorch 2.6);
# only load trusted files. map_location keeps the load portable across devices.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth',
                     map_location=device, weights_only=False)
encoder.text_model.device = device
attention = torch.load('models/attention/model/attention-model-8h4l.pth',
                       map_location=device, weights_only=False)
# imgtxt_encoder, attn_model, class_prototype_aggregator, distance_func
model = ProtoNetAttention(encoder, attention, class_prototype_inf, euclidean_distance)
# NOTE(review): unlike the ClsModel setup, train_n_ways is not passed here, so the
# trainer's default applies — confirm that is intended.
mtrainer = ControlledMetaTrainer(model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [4]:
# Make the image-text encoder trainable while keeping its logit scale fixed.
# NOTE(review): the meaning of the two positional True flags is not visible here
# (presumably the image and text towers) — confirm against the encoder's set_trainable.
mtrainer.model.encoder.set_trainable(True, True, include_logit_scale=False)

In [None]:
# Freeze the attention module so only the backbone is fine-tuned (cf. "training backbone"
# results below). NOTE(review): never executed in the saved run (In [None]); under
# Restart & Run All it will run and may change the recorded training results.
mtrainer.model.attn_model.set_trainable(False)

In [3]:
# Evaluate ProtoNet+attention on the test split before training; the trailing True
# prints per-episode metrics. Use the trainer's model handle, as the other cells do,
# rather than the generic `model` global that every setup cell rebinds.
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader, True)

  classes_count = torch.nonzero(label_inds)[:,1].bincount()


Loss 0.6663146018981934 | Accuracy 81.63265306122449 | AUC 0.5520890348476556
Loss 0.6675408482551575 | Accuracy 77.9591836734694 | AUC 0.592805177243311
Loss 0.6662786602973938 | Accuracy 80.81632653061224 | AUC 0.5031611300576818
Loss 0.6649683117866516 | Accuracy 80.0 | AUC 0.5192899818312419
Loss 0.6666518449783325 | Accuracy 80.0 | AUC 0.5661068620956023
Loss 0.665653645992279 | Accuracy 80.81632653061224 | AUC 0.5215153647665963
Loss 0.6630535125732422 | Accuracy 81.63265306122449 | AUC 0.48086569424007847
Loss 0.6627781987190247 | Accuracy 81.63265306122449 | AUC 0.5596055745021262
Loss 0.6620592474937439 | Accuracy 82.44897959183673 | AUC 0.45533174308684515
Loss 0.6656913161277771 | Accuracy 79.59183673469387 | AUC 0.5300442635824945
Loss 0.666347324848175 | Accuracy 81.22448979591836 | AUC 0.5316662740132128
Loss 0.6675103306770325 | Accuracy 79.18367346938776 | AUC 0.5914888911229517
Loss 0.6687719821929932 | Accuracy 77.9591836734694 | AUC 0.5190180989547633
Loss 0.66781353

(80.15928322548531, 0.5328198297988869, 0.6663182726720485)

In [4]:
# Validation performance before training. Use the trainer's model handle for consistency
# with the rest of the notebook instead of the rebinding `model` global.
mtrainer.run_eval(mtrainer.model, mtrainer.val_loader)

(79.37026239067053, 0.5161583489962621, 0.6703500832830157)

In [5]:
# Train for 3 epochs at lr=1e-5. Note: this run logs accuracy as a fraction (0.69),
# whereas the ClsModel runs above log percentages (53.5).
mtrainer.run_train(3, lr=1e-5)

  classes_count = torch.nonzero(label_inds)[:,1].bincount()


Batch 1: loss 0.6557514071464539 | Acc 0.6938775510204082
Batch 2: loss 0.6739354431629181 | Acc 0.4775510204081633
Batch 3: loss 0.6683227817217509 | Acc 0.6204081632653061
Batch 4: loss 0.6763685345649719 | Acc 0.4204081632653061
Batch 5: loss 0.680125379562378 | Acc 0.4775510204081633
Batch 6: loss 0.6733721892038981 | Acc 0.6204081632653061
Batch 7: loss 0.6762673003332955 | Acc 0.4897959183673469
Batch 8: loss 0.6740930899977684 | Acc 0.5673469387755102
Batch 9: loss 0.6698068777720133 | Acc 0.6122448979591837
Batch 10: loss 0.6712128102779389 | Acc 0.5224489795918368
Batch 11: loss 0.6722890626300465 | Acc 0.5183673469387755
Batch 12: loss 0.6703207542498907 | Acc 0.5959183673469388
Batch 13: loss 0.6714858687840976 | Acc 0.4897959183673469
Batch 14: loss 0.6678845116070339 | Acc 0.6163265306122448
Batch 15: loss 0.6657224218050639 | Acc 0.5714285714285714
Batch 16: loss 0.6668382249772549 | Acc 0.4897959183673469
Batch 17: loss 0.6636558841256535 | Acc 0.6244897959183674
Batch 1

In [7]:
# Inspect the learnable `scale` parameter of the best checkpoint (recorded value ≈ 1.0008).
print(mtrainer.best_model.scale)

Parameter containing:
tensor(1.0008, device='mps:0', requires_grad=True)


In [None]:
# Verbose evaluation of the current model on the 'train' query-evaluation episodes.
# NOTE(review): not executed in the saved run (In [None]).
mtrainer.run_eval(mtrainer.model, mtrainer.create_query_eval_dataloader('train'), True)

In [None]:
# Verbose validation evaluation of the best checkpoint.
# NOTE(review): not executed in the saved run (In [None]).
mtrainer.run_eval(mtrainer.best_model, mtrainer.val_loader, True)

In [7]:
# Verbose test evaluation of the best checkpoint. NOTE(review): no output is recorded
# for this cell, so the "Test set" entry in the results note below is still empty.
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader, True)

#### Results after training the backbone for 3 epochs
- Training query set: 0.5583506872136611 (TODO: state which metric this is)
- Test set: TODO (not yet recorded)
- Validation set: TODO (not yet recorded)


#### Results with the attention module trained for 3 epochs

TODO: not yet recorded.


## Run experiments on baseline models without attention
### RelationNet

In [32]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import RelationNet

# Whole pickled module: needs weights_only=False (default True since PyTorch 2.6);
# only load trusted files. map_location keeps the load portable across devices.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth',
                     map_location=device, weights_only=False)
encoder.text_model.device = device
# RelationNet baseline: same encoder, width argument (512) and fc_hidden_size as the
# ClsModel experiment, but no attention module.
base_model = RelationNet(encoder, 512, class_prototype_inf, fc_hidden_size=16)
btrainer = ControlledMetaTrainer(base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [33]:
# RelationNet baseline before training: validation split.
btrainer.run_eval(btrainer.model, btrainer.val_loader)

(21.01457725947522, 0.5016376709062274, 0.6972464016505651)

In [34]:
# RelationNet baseline before training: test split.
btrainer.run_eval(btrainer.model, btrainer.test_loader)

(20.079641612742655, 0.5146355623202179, 0.6978601042817278)

In [43]:
# Train the RelationNet baseline for 2 epochs at lr=1e-6.
# NOTE(review): per the execution counts, this ran (In [43]) *after* the validation eval
# (In [39]) and the weight save (In [41]) below, so those come from an earlier training
# state. A linear Restart & Run All will produce different numbers and weights.
btrainer.run_train(2, lr=1e-6)

Batch 1: loss 0.679679274559021 | Acc 68.16326530612244
Batch 2: loss 0.6872178614139557 | Acc 48.16326530612245
Batch 3: loss 0.6854563554128011 | Acc 63.6734693877551
Batch 4: loss 0.6872857213020325 | Acc 53.46938775510204
Batch 5: loss 0.6890603423118591 | Acc 52.6530612244898
Batch 6: loss 0.687427838643392 | Acc 66.12244897959184
Batch 7: loss 0.6865879552704948 | Acc 62.04081632653061
Batch 8: loss 0.6865909770131111 | Acc 61.224489795918366
Batch 9: loss 0.6880093084441291 | Acc 43.673469387755105
Batch 10: loss 0.6870020091533661 | Acc 68.57142857142857
Batch 11: loss 0.6873770952224731 | Acc 53.87755102040816
Batch 12: loss 0.6869907428820928 | Acc 61.63265306122449
Batch 13: loss 0.6864107847213745 | Acc 69.38775510204081
Batch 14: loss 0.6862963523183551 | Acc 58.77551020408164
Batch 15: loss 0.6861640095710755 | Acc 60.0
Batch 16: loss 0.6858026422560215 | Acc 64.08163265306122
Batch 17: loss 0.68585064831902 | Acc 57.14285714285714
Batch 18: loss 0.6858793861336179 | Acc 

In [39]:
# Validation performance of the best RelationNet checkpoint.
# NOTE(review): recorded (In [39]) before the In [43] training run above.
btrainer.run_eval(btrainer.best_model, btrainer.val_loader)

(72.12827988338194, 0.5262049313442498, 0.6789789761815752)

In [44]:
# Test performance of the best RelationNet checkpoint (recorded after the training run).
btrainer.run_eval(btrainer.best_model, btrainer.test_loader)

(70.60228969636637, 0.5078939691060346, 0.6737378544923736)

In [41]:
# Save the RelationNet head weights of the best checkpoint.
# NOTE(review): this writes to the working directory, unlike the ClsModel weights saved
# under models/metaclassifier/model/comb3/; consider a consistent location. It was also
# executed (In [41]) before the In [43] training run above.
torch.save(btrainer.best_model.cls.state_dict(), 'relnet_weights-16.pkl')

### Prototypical Network

In [2]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.base import euclidean_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import ProtoNet

# Whole pickled module: needs weights_only=False (default True since PyTorch 2.6);
# only load trusted files. map_location keeps the load portable across devices.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth',
                     map_location=device, weights_only=False)
encoder.text_model.device = device
# ProtoNet baseline without attention, using class_prototype_inf aggregation and a
# non-trainable base.
base_model = ProtoNet(encoder, class_prototype_inf, euclidean_distance, trainable_base=False)
btrainer = ControlledMetaTrainer(base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [None]:
# Verbose validation evaluation of the ProtoNet baseline.
# NOTE(review): not executed in the saved run (In [None]).
btrainer.run_eval(btrainer.model, btrainer.val_loader, True)

In [4]:
# Verbose test evaluation of the ProtoNet baseline (class_prototype_inf aggregation).
btrainer.run_eval(btrainer.model, btrainer.test_loader, True)

Loss 0.5569669604301453 | Accuracy 79.59183673469387 | AUC 0.6209134773272705
Loss 0.5399513244628906 | Accuracy 79.59183673469387 | AUC 0.6711944518066967
Loss 0.5567415356636047 | Accuracy 80.40816326530611 | AUC 0.5641938202283031
Loss 0.5585796236991882 | Accuracy 78.36734693877551 | AUC 0.6835645847945377
Loss 0.5578333735466003 | Accuracy 80.81632653061224 | AUC 0.6120503887997691
Loss 0.5560277104377747 | Accuracy 79.59183673469387 | AUC 0.5566583374979178
Loss 0.5344257950782776 | Accuracy 80.81632653061224 | AUC 0.6238864838864838
Loss 0.5641512870788574 | Accuracy 79.18367346938776 | AUC 0.5742744103968594
Loss 0.5355821251869202 | Accuracy 79.59183673469387 | AUC 0.6636724640595155
Loss 0.5311662554740906 | Accuracy 80.81632653061224 | AUC 0.6199343504515918
Loss 0.553301215171814 | Accuracy 79.59183673469387 | AUC 0.605984796864459
Loss 0.5594642162322998 | Accuracy 82.0408163265306 | AUC 0.6736303714032197
Loss 0.5388792753219604 | Accuracy 80.40816326530611 | AUC 0.589047

(79.95022399203586, 0.6175112298039078, 0.5496329301741065)

In [4]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf, class_prototype_mean
from models.metaclassifier.base import euclidean_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import ProtoNet

# Whole pickled module: needs weights_only=False (default True since PyTorch 2.6);
# only load trusted files. map_location keeps the load portable across devices.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth',
                     map_location=device, weights_only=False)
encoder.text_model.device = device
# Same ProtoNet baseline, but with mean prototype aggregation. This rebinds `btrainer`.
base_model = ProtoNet(encoder, class_prototype_mean, euclidean_distance, trainable_base=False)
btrainer = ControlledMetaTrainer(base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [5]:
# Verbose evaluation of the mean-prototype ProtoNet on the episodes produced by
# create_query_eval_dataloader() with its default arguments.
btrainer.run_eval(btrainer.model, btrainer.create_query_eval_dataloader(), True)
# mean prototype: 55.16034985422738 (acc), 0.7399913123640167 (auc)

['no finding', 'tuberculosis', 'other lesion', 'infiltration', 'pulmonary fibrosis', 'aortic enlargement', 'other diseases']
Loss 0.6143988966941833 | Accuracy 57.14285714285714 | AUC 0.8143151004762809
['consolidation', 'interstitial lung disease', 'pleural effusion', 'lung opacity', 'pneumonia', 'cardiomegaly', 'pleural thickening']
Loss 0.6604911684989929 | Accuracy 55.10204081632652 | AUC 0.7138259071509845
['consolidation', 'interstitial lung disease', 'tuberculosis', 'other lesion', 'infiltration', 'pneumonia', 'pulmonary fibrosis']
Loss 0.6800097823143005 | Accuracy 54.285714285714285 | AUC 0.6943492781728074
['no finding', 'pleural effusion', 'lung opacity', 'cardiomegaly', 'pleural thickening', 'aortic enlargement', 'other diseases']
Loss 0.607733964920044 | Accuracy 54.69387755102041 | AUC 0.7853427164454344
['interstitial lung disease', 'tuberculosis', 'other lesion', 'infiltration', 'pulmonary fibrosis', 'aortic enlargement', 'other diseases']
Loss 0.7186079621315002 | Accu

(55.16034985422738, 0.7399913123640167, 0.6413699126973444)

In [3]:
# Compare against inf-style prototypes (class_prototype_inf) on the same kind of
# query-evaluation episodes. The setup cell above rebinds `btrainer` to the
# mean-prototype ProtoNet, so reusing `btrainer` here would evaluate the mean model again
# under Restart & Run All. The recorded result came from an out-of-order session.
# Build the inf-prototype model explicitly instead.
# inf prototype: 55.826738858808845 (acc), 0.7317367810483341 (auc)
# NOTE(review): the recorded episodes (class sets) differ from the mean-prototype run,
# so the two numbers above are not from identical episodes.
inf_base_model = ProtoNet(encoder, class_prototype_inf, euclidean_distance, trainable_base=False)
inf_btrainer = ControlledMetaTrainer(inf_base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)
inf_btrainer.run_eval(inf_btrainer.model, inf_btrainer.create_query_eval_dataloader(), True)

['no finding', 'interstitial lung disease', 'tuberculosis', 'other lesion', 'lung opacity', 'aortic enlargement', 'other diseases']


  classes_count = torch.nonzero(label_inds)[:,1].bincount()


Loss 0.5694094300270081 | Accuracy 62.44897959183674 | AUC 0.7926244586226436
['consolidation', 'infiltration', 'pleural effusion', 'pneumonia', 'cardiomegaly', 'pleural thickening', 'pulmonary fibrosis']
Loss 0.6985782384872437 | Accuracy 50.204081632653065 | AUC 0.7035248454675698
['consolidation', 'pleural effusion', 'lung opacity', 'cardiomegaly', 'pleural thickening', 'pulmonary fibrosis', 'aortic enlargement']
Loss 0.7009211182594299 | Accuracy 46.93877551020408 | AUC 0.7360091733326218
['no finding', 'interstitial lung disease', 'tuberculosis', 'other lesion', 'infiltration', 'pneumonia', 'other diseases']
Loss 0.5711962580680847 | Accuracy 58.77551020408164 | AUC 0.811180126345821
['no finding', 'other lesion', 'infiltration', 'pleural effusion', 'pneumonia', 'cardiomegaly', 'pleural thickening']
Loss 0.5496819615364075 | Accuracy 60.816326530612244 | AUC 0.8068685701322694
['consolidation', 'interstitial lung disease', 'tuberculosis', 'lung opacity', 'pulmonary fibrosis', 'aor

(55.826738858808845, 0.7317367810483341, 0.6398802235418436)

In [None]:
# Verbose test evaluation. In linear order `btrainer` is the mean-prototype ProtoNet
# from the most recent setup cell. NOTE(review): not executed in the saved run (In [None]).
btrainer.run_eval(btrainer.model, btrainer.test_loader, True)