## Set up datasets and dataloaders

In [1]:
from utils.device import get_device
from utils.data import DatasetConfig
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS

# --- Image location ----------------------------------------------------------
IMG_PATH = 'datasets/vindr-cxr-png'

# --- Episode configuration ---------------------------------------------------
# NUM_SHOTS / NUM_WAYS are passed positionally to every ControlledMetaTrainer below.
NUM_SHOTS = 5
NUM_WAYS = 7
N_QUERY = 10         # passed as n_query= to the ProtoNet baseline trainer
TRAIN_NUM_WAYS = 7   # passed as train_n_ways= to the proposed-model trainer

# --- Dataset split -----------------------------------------------------------
# Split "2" files are used. The split-1 variant, kept for reference, was:
#   DatasetConfig(IMG_PATH, 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl',
#                 VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist'])
SPLIT_LABELS_PKL = 'data/vindr_cxr_split_labels2.pkl'
TRAIN_QUERY_SET_PKL = 'data/vindr_train_query_set2.pkl'
dataset_config = DatasetConfig(
    IMG_PATH,
    SPLIT_LABELS_PKL,
    TRAIN_QUERY_SET_PKL,
    VINDR_CXR_LABELS,
    VINDR_SPLIT2,
    MEAN_STDS['chestmnist'],
)
device = get_device()


import torch
import random
import numpy as np

def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch's RNG with ``seed``.

    Call immediately before any stochastic step (episode sampling, evaluation)
    whose result should be reproducible. Returns None.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

## Models with attention
### Run experiments on proposed model

In [None]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.model import ClsModel


# Load the pretrained image-text encoder and attention module, both saved as full
# pickled modules. map_location=device keeps this cell runnable when the kernel's
# device differs from the one the checkpoints were saved from (e.g. CPU-only box).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): on torch >= 2.6 (weights_only defaults to True) loading full
# modules additionally needs weights_only=False — confirm the torch version used.
encoder = torch.load('models/embedding/model/vindr1/imgtext_model_trained.pth', map_location=device)
encoder.text_model.device = device
attention = torch.load('models/attention/model/vindr1/attention-model-8h4l.pth', map_location=device)
# 512 is presumably the shared embedding width; fc_hidden_size=16 matches the
# "-16" suffix of the saved classifier weights further down — TODO confirm.
model = ClsModel(encoder, attention, 512, class_prototype_inf, fc_hidden_size=16)
mtrainer = ControlledMetaTrainer(model, NUM_SHOTS, NUM_WAYS, dataset_config, train_n_ways=TRAIN_NUM_WAYS, device=device)

In [None]:
# Reference point: evaluate the freshly assembled ClsModel on the test and
# validation loaders before any call to run_train.
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader)


In [None]:
mtrainer.run_eval(mtrainer.model, mtrainer.val_loader)

In [None]:
# Phase 1: meta-train with lr=1e-5; the first argument (2) is presumably the
# number of epochs. The attention module's set_trainable(True) only happens in a
# later cell.
# TODO(review): the original note here read "5 times?" — decide whether this
# phase should run longer (or be repeated) and record the decision.
mtrainer.run_train(2, lr=1e-5)

In [None]:
# Evaluate the best model tracked by the trainer after phase 1.
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader)

In [None]:
# Phase 2: mark the attention module as trainable, then continue meta-training.
mtrainer.model.attn_model.set_trainable(True)

In [None]:
mtrainer.run_train(3, lr=1e-5)

In [None]:
# NOTE(review): the schedule recorded below does not match the cells in this
# section (run_train(2, lr=1e-5) followed by run_train(3, lr=1e-5)); it probably
# describes an earlier run. Confirm which schedule produced the reported numbers.
# Model trained 3 epochs with lr=5e-6, 3 epochs with lr=1e-6, 3 epochs with lr=1e-5
mtrainer.run_eval(mtrainer.model, mtrainer.test_loader)

In [None]:
# Same evaluation for best_model (presumably the checkpoint run_train selected,
# as opposed to the final-epoch model above — confirm in the trainer).
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader)

In [None]:
# Persist the best model: the `cls` sub-module as a state_dict (reload into a
# freshly built ClsModel), and the attention module as a full pickled module
# (reload with torch.load).
torch.save(mtrainer.best_model.cls.state_dict(), 'models/metaclassifier/model/comb3/cls_weights-16.pkl')

In [None]:
torch.save(mtrainer.best_model.attn_model, 'models/metaclassifier/model/comb3/attention-model-8h4l.pth')

### Prototypical Network with attention

In [2]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.base import euclidean_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.model import ProtoNetAttention

device = get_device()

# Load the pretrained image-text encoder and attention module (full pickled
# modules). map_location=device lets the cell run on a different device than the
# one the checkpoints were saved from.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files;
# on torch >= 2.6 full-module loading also needs weights_only=False.
encoder = torch.load('models/embedding/model/vindr1/imgtext_model_trained.pth', map_location=device)
encoder.text_model.device = device
attention = torch.load('models/attention/model/vindr1/attention-model-8h4l.pth', map_location=device)
# imgtxt_encoder, attn_model, class_prototype_aggregator, distance_func
model = ProtoNetAttention(encoder, attention, class_prototype_inf, euclidean_distance)
# Unlike the ClsModel trainer, train_n_ways is not passed here, so the trainer's
# own default applies.
mtrainer = ControlledMetaTrainer(model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [None]:
# Mark the encoder as trainable for fine-tuning. The meaning of the two positional
# flags is defined by the encoder's set_trainable (presumably image and text
# towers — confirm); include_logit_scale=False is passed explicitly.
mtrainer.model.encoder.set_trainable(True, True, include_logit_scale=False)

In [None]:
# Mark the attention module as trainable as well.
mtrainer.model.attn_model.set_trainable(True)

In [None]:
mtrainer.run_eval(model, mtrainer.test_loader, True)

In [None]:
mtrainer.run_eval(model, mtrainer.val_loader)

In [None]:
# Meta-train ProtoNetAttention with lr=1e-5 (first argument presumably epochs).
mtrainer.run_train(3, lr=1e-5)

In [None]:
# Inspect the `scale` attribute of the best model after training.
print(mtrainer.best_model.scale)

In [None]:
# Evaluate the best model on the 'train' query set, then on validation and test.
# NOTE(review): unlike the ProtoNet baseline evaluation, these cells do not call
# set_seed first; seed them if run_eval/episode sampling is stochastic so the
# comparison between models is like-for-like.
mtrainer.run_eval(mtrainer.best_model, mtrainer.create_query_eval_dataloader('train'), True)

In [None]:
mtrainer.run_eval(mtrainer.best_model, mtrainer.val_loader, True)

In [None]:
mtrainer.run_eval(mtrainer.best_model, mtrainer.test_loader, True)

## Run experiments on baseline models without attention
### RelationNet

In [None]:
import torch
from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.prototype import class_prototype_inf
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import RelationNet

# Load the pretrained image-text encoder (a full pickled module). map_location=device
# lets the cell run on a different device than the checkpoint was saved from.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files;
# on torch >= 2.6 full-module loading also needs weights_only=False.
encoder = torch.load('models/embedding/model/vindr1/imgtext_model_trained.pth', map_location=device)
encoder.text_model.device = device
# RelationNet baseline: same encoder and prototype aggregator, no attention module.
base_model = RelationNet(encoder, 512, class_prototype_inf, fc_hidden_size=16)
btrainer = ControlledMetaTrainer(base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device)

In [None]:
# Reference point: evaluate the RelationNet baseline before training.
btrainer.run_eval(btrainer.model, btrainer.val_loader)

In [None]:
btrainer.run_eval(btrainer.model, btrainer.test_loader)

In [None]:
# Meta-train the baseline with lr=1e-6 (first argument presumably epochs).
btrainer.run_train(2, lr=1e-6)

In [None]:
# Evaluate the best RelationNet model tracked by the trainer.
btrainer.run_eval(btrainer.best_model, btrainer.val_loader)

In [None]:
btrainer.run_eval(btrainer.best_model, btrainer.test_loader)

In [None]:
# Save the `cls` sub-module's state_dict.
# NOTE(review): unlike the proposed-model checkpoints (models/metaclassifier/model/...),
# this writes to the current working directory — consider moving it alongside them.
torch.save(btrainer.best_model.cls.state_dict(), 'relnet_weights-16.pkl')

### Prototypical Network

In [2]:
import torch
from torchvision.models import resnet50

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, ImageOnlyEmbedding, resnet_backbone, load_pretrained_resnet, adapt_resnet_input_channels

from utils.prototype import class_prototype_inf, class_prototype_mean, class_prototype_rrp
from models.metaclassifier.base import euclidean_distance, cosine_distance
from models.metaclassifier.trainer import ControlledMetaTrainer
from models.metaclassifier.baselines import ProtoNet

# Image-only encoder for the ProtoNet baseline.
# Alternatives that were tried are listed for reference. The active backbone's
# annotation (5.5300) equals the `scale` passed to ProtoNet, so the numbers appear
# to record the scale used with each option — confirm.
#   resnet_backbone(adapt_resnet_input_channels(resnet50(weights=None), 1))                       # 13.0714
#   resnet_backbone(torch.load('models/backbone/pretrained/vindr2/trained-backbone.pth'))         # 3.4861
#   resnet_backbone(torch.load('models/backbone/pretrained/vindr2/trained-backbone-balacc.pth'))  # 3.6244
#   encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')              # 0.9843
#   encoder = torch.load('models/attention/model/vindr2/full/imgtxt-encoder.pth')
#   (the two image-text encoders also need: encoder.text_model.device = device)
PROTONET_SCALE = 5.5300

pretrained_resnet = load_pretrained_resnet(1, 14, 'models/backbone/pretrained/cxr_backbone_bal.pkl')
img_backbone = resnet_backbone(pretrained_resnet)
encoder = ImageOnlyEmbedding(img_backbone, 512)

base_model = ProtoNet(
    encoder,
    class_prototype_inf,
    euclidean_distance,
    trainable_base=False,
    scale=PROTONET_SCALE,
)
btrainer = ControlledMetaTrainer(
    base_model, NUM_SHOTS, NUM_WAYS, dataset_config, device=device, n_query=N_QUERY
)

In [None]:
mtrainer.run_train(3, lr=0.05, min_lr=1e-4, lr_change_step=8)

In [None]:
set_seed(42)
mtrainer.run_eval(btrainer.best_model, btrainer.create_query_eval_dataloader(), True)

In [4]:
# Evaluate the ProtoNet baseline's current model (btrainer.model, not best_model)
# on the query, validation and test loaders, once per seed.
seeds = [42]
# seeds = [42, 142, 321]
eval_results = {}  # (seed, split name) -> value returned by run_eval
for seed in seeds:
    print(f"Running seed: {seed}")
    # Seed before building the query loader, as the single-seed cell above does, so
    # any sampling done while constructing it is reproducible for this seed.
    set_seed(seed)
    dataloaders = {
        'query': btrainer.create_query_eval_dataloader(),
        'val': btrainer.val_loader,
        'test': btrainer.test_loader
    }
    for k, d in dataloaders.items():
        # Re-seed per split so every evaluation starts from the same RNG state.
        set_seed(seed)
        print(k)
        eval_results[(seed, k)] = btrainer.run_eval(btrainer.model, d, True)
        print(eval_results[(seed, k)])

Running seed: 42
query


  classes_count = torch.nonzero(label_inds)[:,1].bincount()


Loss 0.5017755031585693 | F1 0.5352621674537659 | AUC 0.5025913566060083 | Specificity 0.398845911026001 | Recall 0.6332180500030518 | Bal Acc 0.5160319805145264
Loss 0.5897843837738037 | F1 0.6261849403381348 | AUC 0.4740823387334818 | Specificity 0.23998019099235535 | Recall 0.7194490432739258 | Bal Acc 0.47971463203430176
Loss 0.5285687446594238 | F1 0.49893999099731445 | AUC 0.4637256728778468 | Specificity 0.27443066239356995 | Recall 0.7666667699813843 | Bal Acc 0.5205487012863159
Loss 0.49999916553497314 | F1 0.5323529243469238 | AUC 0.5195971267587426 | Specificity 0.39143839478492737 | Recall 0.6073631048202515 | Bal Acc 0.4994007349014282
Loss 0.5379074811935425 | F1 0.5855317115783691 | AUC 0.5170592909442527 | Specificity 0.23362255096435547 | Recall 0.8241515755653381 | Bal Acc 0.5288870334625244
Loss 0.5325268507003784 | F1 0.6284292936325073 | AUC 0.4969740219522829 | Specificity 0.19544607400894165 | Recall 0.7958242297172546 | Bal Acc 0.49563515186309814
Loss 0.5106482