In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler
from utils.f1_loss import BalAccuracyLoss


# Dataset configurations keyed by split variant. Both entries use the same
# image folder and label list; they differ in the label/split pickle files and
# in the class-split mapping (VINDR_SPLIT vs VINDR_SPLIT2).
configs = {
    'vindr1': DatasetConfig(
        'datasets/vindr-cxr-png',
        'data/vindr_cxr_split_labels.pkl',
        'data/vindr_train_query_set.pkl',
        VINDR_CXR_LABELS,
        VINDR_SPLIT,
        MEAN_STDS['chestmnist'],
    ),
    'vindr2': DatasetConfig(
        'datasets/vindr-cxr-png',
        'data/vindr_cxr_split_labels2.pkl',
        'data/vindr_train_query_set2.pkl',
        VINDR_CXR_LABELS,
        VINDR_SPLIT2,
        MEAN_STDS['chestmnist'],
    ),
}

# Active configuration for every cell below.
config = configs['vindr2']

# NOTE(review): presumably 10 samples for each of 14 classes -- confirm.
batch_size = 10*14

# Split the training images into a query set and a support set, then build one
# dataset/loader pair for each. Both use the 'train' class split.
query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(
    config.img_path,
    config.img_info,
    query_image_ids,
    config.label_names_map,
    config.classes_split_map['train'],
    mean_std=config.mean_std,
)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(
    config.img_path,
    config.img_info,
    support_image_ids,
    config.label_names_map,
    config.classes_split_map['train'],
    mean_std=config.mean_std,
)
# The support loader draws indices through a multilabel-balanced sampler built
# from the dataset's class indicators instead of plain shuffling.
support_loader = DataLoader(
    dataset=support_dataset,
    batch_size=batch_size,
    sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()),
)


PROJ_SIZE = 512
device = get_device()

# Load the pre-trained image/text encoder. `map_location` remaps the stored
# tensors onto the device chosen above, so a checkpoint saved on a different
# device (e.g. a GPU) still loads on this machine.
# NOTE(security): torch.load unpickles a whole model object here; only load
# checkpoint files from a trusted source.
encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth', map_location=device)
# The text model keeps its own `device` attribute; point it at the current device.
encoder.text_model.device = device


In [2]:
# Build a fresh LabelImageMHAttention module (classification loss: BalAccuracyLoss,
# cls_weight=5) and wrap it together with the loaded encoder in a prototype model.
# NOTE(review): the positional `8` is presumably the number of attention heads
# (cf. the '-8h' checkpoint names) -- confirm against the model signature.
attn_model = LabelImageMHAttention(PROJ_SIZE, 8, cls_weight=5, cls_loss=BalAccuracyLoss(), device=device)
# Alternative: resume from a previously saved attention module instead.
# attn_model = torch.load('models/attention/model/vindr2/attention-8h.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
# The trainer is given the support dataset's class labels and the target device.
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
# Alternative model setup: let LabelImagePrototypeModel construct its own
# attention component from num_layers/cls_weight/cls_loss instead of passing a
# pre-built `attn_model`. Running this cell rebinds `model` and `mtrainer`,
# replacing whatever an earlier model-construction cell assigned.
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, num_layers=4, cls_weight=5, cls_loss=BalAccuracyLoss())
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [None]:
# Resume from a saved attention module. `map_location` remaps the stored
# tensors onto the current device so a checkpoint written on another device
# still loads here.
# NOTE(security): torch.load unpickles a whole module object; only load
# checkpoint files from a trusted source.
attn_model = torch.load('models/attention/model/vindr2/attention-model8h4l.pth', map_location=device)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

# Train for 10 epochs with the `encoder_only` flag set.
# NOTE(review): `encoder_only=True` presumably restricts updates to the encoder
# weights -- confirm in Trainer.run_train.
mtrainer.run_train(10, support_loader, query_loader, lr=2e-5, encoder_only=True)

In [17]:
# Reset the trainer's working model to an independent deep copy of the best
# model it has recorded, so further training does not mutate `best_model` itself.
import copy
mtrainer.model = copy.deepcopy(mtrainer.best_model)

In [18]:
# Earlier training schedules, kept for reference:
# mtrainer.run_train(6, support_loader, query_loader, lr=2e-5, full_training=True)
# mtrainer.run_train(10, support_loader, query_loader, lr=2e-5)

# Autograd anomaly detection is a debugging aid that slows training. Scope it
# to this run with the context manager so it is switched back off afterwards
# instead of staying enabled for the rest of the kernel session.
with torch.autograd.detect_anomaly():
    mtrainer.run_train(2, support_loader, query_loader, lr=1e-4, min_lr=5e-6)

Batch 1: loss 7.480624198913574
Batch 2: loss 7.522526741027832
Batch 3: loss 7.482800006866455
Batch 4: loss 7.477001667022705
Batch 5: loss 7.555340766906738
Batch 6: loss 7.448117733001709
Batch 7: loss 7.500673770904541
Batch 8: loss 7.479950904846191
Batch 9: loss 7.488985538482666
Batch 10: loss 7.4257402420043945
Batch 11: loss 7.4855756759643555
Batch 12: loss 7.478771209716797
Batch 13: loss 7.562317848205566
Batch 14: loss 7.489653587341309
Batch 15: loss 7.472197532653809
Batch 16: loss 7.457478046417236
Batch 17: loss 7.43039083480835
Batch 18: loss 7.437932014465332
Batch 19: loss 7.500885963439941
Batch 20: loss 7.49312686920166
Batch 21: loss 7.526444911956787
Batch 22: loss 7.493459701538086
Batch 23: loss 7.427828788757324
Batch 24: loss 7.429175853729248
Batch 25: loss 7.492195129394531
Batch 26: loss 7.5052571296691895
Batch 27: loss 7.477477073669434
Batch 28: loss 7.4575653076171875
Batch 29: loss 7.485163688659668
Batch 30: loss 7.45986270904541
Batch 31: loss 7.4

In [21]:
# Persist the `attention` component of the trainer's best model (whole object, pickled).
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/full-mh/attention-8h.pth')

In [22]:
# Persist the encoder of the trainer's *current* model (not `best_model`).
torch.save(mtrainer.model.encoder, 'models/attention/model/vindr2/full-mh/imgtxt-encoder.pth')

In [10]:
# Persist the best model's `attention` component under the 'attention-trans-8h4l' name.
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/attention-trans-8h4l.pth')

In [6]:
# Persist the `attention` component of the trainer's *current* model (not `best_model`).
torch.save(mtrainer.model.attention, 'models/attention/model/vindr2/full-mh/attention-model8h.pth')

In [4]:
from utils.metrics import AverageMeter, calculate_auc, multilabel_logit_accuracy
from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity, MultilabelPrecision, MultilabelF1Score
from models.attention.model import image_text_logits

def run_eval(model, dataloader, device, class_labels):
    """Evaluate `model` on every batch of `dataloader` and report averaged metrics.

    The model is switched to eval mode and moved to `device`. For each batch the
    model is called with `class_labels` and the images; image/text logits are
    derived from the returned text embeddings and prototypes, scaled by the
    encoder's logit scale. F1, accuracy, AUC, specificity, recall and precision
    are computed per batch and printed. All of them except precision are also
    fed into running meters, with the batch length passed as the second
    argument to `AverageMeter.update`.

    Args:
        model: model exposing `encoder.get_logit_scale()` and returning a
            3-tuple whose first and third items are text embeddings and prototypes.
        dataloader: iterable yielding `(images, class_inds)` batches.
        device: device the model, metrics and batches are moved to.
        class_labels: class labels passed to the model; its length sets the
            number of labels for the torchmetrics objects.

    Returns:
        Tuple of meter averages in the order
        (F1, accuracy, AUC, specificity, recall).
    """
    model.eval()
    model = model.to(device)

    # One running meter per metric that is returned (precision is print-only).
    reported = ('f1', 'acc', 'auc', 'spec', 'rec')
    meters = {key: AverageMeter() for key in reported}

    label_count = len(class_labels)
    spec_metric = MultilabelSpecificity(num_labels=label_count).to(device)
    rec_metric = MultilabelRecall(num_labels=label_count).to(device)
    prec_metric = MultilabelPrecision(num_labels=label_count).to(device)
    f1_metric = MultilabelF1Score(num_labels=label_count).to(device)

    with torch.no_grad():
        for images, class_inds in dataloader:
            images = images.to(device)
            class_inds = class_inds.to(device)
            batch_len = len(class_inds)

            text_embeddings, _, prototypes = model(class_labels, images)
            logits_per_image = image_text_logits(
                text_embeddings, prototypes, model.encoder.get_logit_scale()
            )

            f1 = f1_metric(logits_per_image, class_inds)
            meters['f1'].update(f1.item(), batch_len)

            auc = calculate_auc(logits_per_image, class_inds)
            meters['auc'].update(auc, batch_len)

            acc = multilabel_logit_accuracy(logits_per_image, class_inds)
            meters['acc'].update(acc, batch_len)

            spec = spec_metric(logits_per_image, class_inds)
            meters['spec'].update(spec.item(), batch_len)

            rec = rec_metric(logits_per_image, class_inds)
            meters['rec'].update(rec.item(), batch_len)

            prec = prec_metric(logits_per_image, class_inds)
            print(f"F1 {f1} | Accuracy {acc} | AUC {auc} | Specificity {spec} | Recall {rec} | Precision {prec}")

    return tuple(meters[key].average() for key in reported)

In [8]:
# Validation set: every image whose `meta_split` is 'val', using the 'val' class split.
val_image_ids = config.img_info[config.img_info['meta_split'] == 'val']['image_id'].to_list()
val_dataset = Dataset(
    config.img_path,
    config.img_info,
    val_image_ids,
    config.label_names_map,
    config.classes_split_map['val'],
    mean_std=config.mean_std,
)
# run_eval averages per-batch metrics, so the batch composition affects the
# reported numbers. Keep a fixed order (no shuffling) so repeated evaluations
# of the same model give the same result.
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

run_eval(mtrainer.model, val_loader, device, val_dataset.class_labels())

F1 0.29421404004096985 | Accuracy 0.5387755102040817 | AUC 0.5364233972863903 | Specificity 0.529010534286499 | Recall 0.5449503064155579 | Precision 0.22477076947689056
F1 0.2881873548030853 | Accuracy 0.5704081632653061 | AUC 0.5492518288214842 | Specificity 0.5554752349853516 | Recall 0.549533486366272 | Precision 0.22559234499931335
F1 0.2743912935256958 | Accuracy 0.5265306122448979 | AUC 0.6043750009261554 | Specificity 0.5206221342086792 | Recall 0.6165170669555664 | Precision 0.21531684696674347
F1 0.24442556500434875 | Accuracy 0.536734693877551 | AUC 0.47709730714792276 | Specificity 0.551681399345398 | Recall 0.4351627230644226 | Precision 0.20781458914279938
F1 0.2748141288757324 | Accuracy 0.5673469387755102 | AUC 0.5204891878065561 | Specificity 0.5728052854537964 | Recall 0.4932914078235626 | Precision 0.22820404171943665
F1 0.29293304681777954 | Accuracy 0.573469387755102 | AUC 0.538747686447979 | Specificity 0.5742725729942322 | Recall 0.4904204308986664 | Precision 0.

(0.26910296070709305,
 0.5485450607401827,
 0.5204696436438535,
 0.548269678120874,
 0.48691349972135234)

In [20]:
# Evaluate the trainer's current model on the validation loader; the displayed
# tuple is (F1, accuracy, AUC, specificity, recall) as returned by run_eval.
run_eval(mtrainer.model, val_loader, device, val_dataset.class_labels())

F1 0.3052128553390503 | Accuracy 0.4673469387755102 | AUC 0.5405815330923319 | Specificity 0.44555407762527466 | Recall 0.6491703391075134 | Precision 0.23910662531852722
F1 0.33465278148651123 | Accuracy 0.4806122448979592 | AUC 0.6097163272350138 | Specificity 0.4510611295700073 | Recall 0.7075929641723633 | Precision 0.2729678153991699
F1 0.33510810136795044 | Accuracy 0.4479591836734694 | AUC 0.5603294442909382 | Specificity 0.39294689893722534 | Recall 0.706222653388977 | Precision 0.26464584469795227
F1 0.32385921478271484 | Accuracy 0.4642857142857143 | AUC 0.6043042554139143 | Specificity 0.4298592805862427 | Recall 0.7212198972702026 | Precision 0.2490430474281311
F1 0.28711625933647156 | Accuracy 0.4969387755102041 | AUC 0.565112697391663 | Specificity 0.48812243342399597 | Recall 0.6013959646224976 | Precision 0.22610297799110413
F1 0.2734333276748657 | Accuracy 0.49795918367346936 | AUC 0.5907128027494097 | Specificity 0.49308228492736816 | Recall 0.5784724950790405 | Preci

(0.3003918312455607,
 0.47509181655523125,
 0.5670647324289892,
 0.45663241151138284,
 0.6381473661490624)

In [15]:
# Validation dataset/loader construction is disabled here; this cell relies on
# `val_dataset` and `val_loader` already existing in the kernel.
# val_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'val']['image_id'].to_list(), config.label_names_map, config.classes_split_map['val'], mean_std=config.mean_std)
# val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)

# Evaluate the trainer's recorded best model on the validation loader.
run_eval(mtrainer.best_model, val_loader, device, val_dataset.class_labels())

F1 0.2756957411766052 | Accuracy 0.4714285714285714 | AUC 0.5545612258428514 | Specificity 0.47387582063674927 | Recall 0.5659193992614746 | Precision 0.22475971281528473
F1 0.3287630081176758 | Accuracy 0.45102040816326533 | AUC 0.5815920929528132 | Specificity 0.4068641662597656 | Recall 0.6907780766487122 | Precision 0.24743859469890594
F1 0.33262795209884644 | Accuracy 0.4704081632653061 | AUC 0.6001559678993604 | Specificity 0.4284785985946655 | Recall 0.7443388104438782 | Precision 0.2552799582481384
F1 0.33989763259887695 | Accuracy 0.47959183673469385 | AUC 0.5936748513652162 | Specificity 0.44547414779663086 | Recall 0.7318285703659058 | Precision 0.263365238904953
F1 0.29540395736694336 | Accuracy 0.4326530612244898 | AUC 0.5642711386554464 | Specificity 0.4070116877555847 | Recall 0.6624724864959717 | Precision 0.23915618658065796
F1 0.27027347683906555 | Accuracy 0.45102040816326533 | AUC 0.5315696919619594 | Specificity 0.4515884518623352 | Recall 0.590854823589325 | Preci

(0.3015518079345448,
 0.4526791599962332,
 0.5659573852865287,
 0.42707320558608325,
 0.6649322124712548)

In [19]:
# Evaluate the trainer's recorded best model on the training query loader
# (classes from the 'train' split), for comparison with the validation numbers.
run_eval(mtrainer.best_model, query_loader, device, query_dataset.class_labels())

F1 0.5936726331710815 | Accuracy 0.6457142857142857 | AUC 0.7024162758835758 | Specificity 0.6376651525497437 | Recall 0.7216821908950806 | Precision 0.5882301330566406
F1 0.5837500095367432 | Accuracy 0.6521428571428571 | AUC 0.7137412539207839 | Specificity 0.6523699164390564 | Recall 0.7104381918907166 | Precision 0.5872392058372498
F1 0.5713423490524292 | Accuracy 0.6235714285714286 | AUC 0.6894183467670707 | Specificity 0.6026982665061951 | Recall 0.7092002630233765 | Precision 0.5633648633956909
F1 0.6020265817642212 | Accuracy 0.645 | AUC 0.6994260169397626 | Specificity 0.6121586561203003 | Recall 0.7530618906021118 | Precision 0.5963531732559204


(0.5859784507751464,
 0.6412,
 0.7014694081503626,
 0.6279107189178467,
 0.7200596833229065)

In [14]:
# Evaluate the trainer's current model on the training query loader
# (classes from the 'train' split).
run_eval(mtrainer.model, query_loader, device, query_dataset.class_labels())

F1 0.600047767162323 | Accuracy 0.6485714285714286 | AUC 0.720599332865864 | Specificity 0.5798352360725403 | Recall 0.7575923204421997 | Precision 0.5843406319618225
F1 0.5760691165924072 | Accuracy 0.6307142857142857 | AUC 0.7187121760042463 | Specificity 0.5898844003677368 | Recall 0.728096604347229 | Precision 0.5598985552787781
F1 0.6033751964569092 | Accuracy 0.6507142857142857 | AUC 0.7523908916023636 | Specificity 0.5850102305412292 | Recall 0.7935389280319214 | Precision 0.5648530721664429
F1 0.6157442927360535 | Accuracy 0.6775 | AUC 0.7858159871491854 | Specificity 0.6618647575378418 | Recall 0.76849764585495 | Precision 0.6044007539749146


(0.5967768692970276,
 0.6487999999999999,
 0.7394072300761624,
 0.5972227239608765,
 0.76114342212677)