In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler
from utils.f1_loss import BalAccuracyLoss


# Available dataset configurations, keyed by a short split name. Both entries
# point at the same image directory and differ in the label/query-set pickles
# and the class split (VINDR_SPLIT vs. VINDR_SPLIT2).
configs = {
    'vindr1': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl', VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist']),
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

# Active configuration for every cell below.
config = configs['vindr2']

# NOTE(review): IMG_PATH is not referenced by any visible cell (the datasets
# read config.img_path instead); kept in case other cells depend on it.
IMG_PATH = 'datasets/vindr-cxr-png'
batch_size = 10*14

# Split the training images into query and support id lists, then wrap each in
# a Dataset restricted to the 'train' classes of the active split.
query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
# The support loader draws indices through a multilabel-balanced sampler built
# from the dataset's class indicators instead of plain shuffling.
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))


PROJ_SIZE = 512
device =  get_device()

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints produced by this project or another trusted source.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device (e.g. a GPU) still loads on this machine.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects whole-module pickles like this one; pass
# weights_only=False explicitly if the installed version requires it.
encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth', map_location=device)
# The text model keeps its own `device` attribute; point it at the active device.
encoder.text_model.device = device


In [2]:
# Alternative starting point: a freshly initialised multi-head attention module
# instead of the saved checkpoint loaded below.
# attn_model = LabelImageMHAttention(PROJ_SIZE, 8, cls_weight=1, cls_loss=BalAccuracyLoss(), device=device)

# SECURITY: torch.load unpickles the file; only load trusted checkpoints.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device still loads on this machine.
attn_model = torch.load('models/attention/model/vindr2/attention-8h.pth', map_location=device)
# Wrap the encoder and the loaded attention module in the prototype model.
# NOTE(review): the literal 8 presumably is the number of attention heads
# (the checkpoint name says "8h") — confirm against LabelImagePrototypeModel.
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
# Build the prototype model without a pre-trained attention checkpoint: no
# attn_model is passed, only construction arguments (num_layers, cls_weight,
# cls_loss). Rebinds the notebook-level `model` and `mtrainer`, replacing
# whatever an earlier cell assigned to them.
# NOTE(review): the literal 8 presumably is the number of attention heads — confirm.
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, num_layers=4, cls_weight=1, cls_loss=BalAccuracyLoss())
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [None]:
# Resume from the saved 8-head / 4-layer attention checkpoint.
# Author's note: previously trained up to 10 epochs, epoch 6 was the best.
# SECURITY: torch.load unpickles the file; only load trusted checkpoints.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a different device still loads on this machine.
attn_model = torch.load('models/attention/model/vindr2/attention-model8h4l.pth', map_location=device)
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

# NOTE(review): the first argument is presumably the epoch count, and
# encoder_only=True presumably limits training to the encoder — confirm
# against Trainer.run_train.
mtrainer.run_train(10, support_loader, query_loader, lr=2e-5, encoder_only=True)

In [3]:
# Turn on autograd anomaly detection globally for the rest of the session.
# NOTE(review): this is a debugging aid that slows training; disable it once
# the issue being investigated is resolved.
torch.autograd.set_detect_anomaly(True)

# Earlier training runs, kept for reference:
# mtrainer.run_train(6, support_loader, query_loader, lr=2e-5, full_training=True)
# mtrainer.run_train(10, support_loader, query_loader, lr=2e-5)
# Current run. NOTE(review): the first argument is presumably the epoch count — confirm.
mtrainer.run_train(4, support_loader, query_loader, lr=5e-5)

  classes = torch.nonzero(label_inds)[:,1] # (Np,)


Batch 1: loss 6.230784893035889
Batch 2: loss 6.230344533920288
Batch 3: loss 6.243095715840657
Batch 4: loss 6.236725568771362
Batch 5: loss 6.236250114440918
Batch 6: loss 6.228386958440145
Batch 7: loss 6.232144560132708
Batch 8: loss 6.2259631752967834
Batch 9: loss 6.220034069485134
Batch 10: loss 6.217113447189331
Batch 11: loss 6.217128970406272
Batch 12: loss 6.216728091239929
Batch 13: loss 6.217156556936411
Batch 14: loss 6.2182431902204245
Batch 15: loss 6.215466626485189
Batch 16: loss 6.213330298662186
Batch 17: loss 6.215746262494256
Batch 18: loss 6.21423535876804
Batch 19: loss 6.21462345123291
Batch 20: loss 6.2128478527069095
Batch 21: loss 6.212313402266729
Batch 22: loss 6.2133310274644336
Batch 23: loss 6.212853307309358
Batch 24: loss 6.213221510251363
Batch 25: loss 6.212041397094726
Batch 26: loss 6.211079414074238
Batch 27: loss 6.209144609945792
Batch 28: loss 6.209561807768686
Batch 29: loss 6.211801611143967
Batch 30: loss 6.212913354237874
Batch 31: loss 6.

In [6]:
# Persist the attention component of the trainer's best model. This is the
# path that the "attention-8h.pth" loading cell reads back.
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/attention-8h.pth')

In [7]:
# Persist the attention component of the trainer's best model under a second file name.
# NOTE(review): saves the same attribute as the "attention-8h.pth" cell — confirm both copies are wanted.
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/attention-model8h.pth')

In [None]:
# Evaluate the trainer's current model (not best_model) on the query split using
# the Trainer's own evaluation routine; the cell displays whatever it returns.
mtrainer.run_eval(mtrainer.model, query_loader, additional_stats=True)

In [6]:
# Persist the attention component of the trainer's current model (not
# best_model) into the "full-mh" checkpoint directory.
torch.save(mtrainer.model.attention, 'models/attention/model/vindr2/full-mh/attention-model8h.pth')

In [16]:
# Persist the encoder of the trainer's current model alongside the attention
# checkpoint in the "full-mh" directory.
torch.save(mtrainer.model.encoder, 'models/attention/model/vindr2/full-mh/imgtxt-encoder.pth')

In [6]:
from utils.metrics import AverageMeter, calculate_auc, multilabel_logit_accuracy
from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity, MultilabelPrecision, MultilabelF1Score
from models.attention.model import image_text_logits

def run_eval(model, dataloader, device, class_labels, verbose=True):
    """Evaluate ``model`` on ``dataloader`` and return averaged multilabel metrics.

    Each batch is scored on its own; the per-batch values are accumulated in
    ``AverageMeter`` objects, passing the batch size to ``update``, and the
    meters' averages are returned.

    Args:
        model: Model called as ``model(class_labels, images)``. It must return
            a 3-tuple whose first and last items are fed to
            ``image_text_logits``, and expose ``encoder.get_logit_scale()``.
            It is switched to eval mode and moved to ``device``.
        dataloader: Iterable yielding ``(images, class_inds)`` batches.
        device: Device the model, batches and metric modules are moved to.
        class_labels: Class labels passed to the model; ``len(class_labels)``
            sets ``num_labels`` for the torchmetrics metrics.
        verbose: When true (the default, matching the previous behaviour),
            print one line of metrics per batch, including precision. When
            false, nothing is printed and precision is not computed.

    Returns:
        Tuple ``(f1, accuracy, auc, specificity, recall)`` of meter averages.
        Precision is only reported in the per-batch printout, not returned.
    """
    model.eval()
    model = model.to(device)

    num_labels = len(class_labels)
    # The torchmetrics modules receive the raw logits directly.
    # NOTE(review): this relies on torchmetrics converting/thresholding
    # out-of-range predictions itself — confirm for the installed version.
    specificity = MultilabelSpecificity(num_labels=num_labels).to(device)
    recall = MultilabelRecall(num_labels=num_labels).to(device)
    precision = MultilabelPrecision(num_labels=num_labels).to(device)
    f1_func = MultilabelF1Score(num_labels=num_labels).to(device)

    f1_meter = AverageMeter()
    acc_meter = AverageMeter()
    auc_meter = AverageMeter()
    spec_meter = AverageMeter()
    rec_meter = AverageMeter()

    with torch.no_grad():
        for images, class_inds in dataloader:
            images, class_inds = images.to(device), class_inds.to(device)
            n_samples = len(class_inds)

            text_embeddings, _, prototypes = model(class_labels, images)
            logits_per_image = image_text_logits(text_embeddings, prototypes, model.encoder.get_logit_scale())

            f1 = f1_func(logits_per_image, class_inds)
            f1_meter.update(f1.item(), n_samples)

            auc = calculate_auc(logits_per_image, class_inds)
            auc_meter.update(auc, n_samples)

            acc = multilabel_logit_accuracy(logits_per_image, class_inds)
            acc_meter.update(acc, n_samples)

            spec = specificity(logits_per_image, class_inds)
            spec_meter.update(spec.item(), n_samples)

            rec = recall(logits_per_image, class_inds)
            rec_meter.update(rec.item(), n_samples)

            if verbose:
                # Precision is only ever displayed, so skip it when quiet.
                prec = precision(logits_per_image, class_inds)
                print(f"F1 {f1} | Accuracy {acc} | AUC {auc} | Specificity {spec} | Recall {rec} | Precision {prec}")

    return f1_meter.average(), acc_meter.average(), auc_meter.average(), spec_meter.average(), rec_meter.average()

In [5]:
# Collect the ids of every image whose meta_split is 'test', then build the
# held-out dataset over the 'test' classes of the active split.
is_test_row = config.img_info['meta_split'] == 'test'
test_image_ids = config.img_info[is_test_row]['image_id'].to_list()
test_dataset = Dataset(
    config.img_path,
    config.img_info,
    test_image_ids,
    config.label_names_map,
    config.classes_split_map['test'],
    mean_std=config.mean_std,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

# Score the trainer's current model on the test split; the returned tuple is
# shown as the cell output.
run_eval(
    model=mtrainer.model,
    dataloader=test_loader,
    device=device,
    class_labels=test_dataset.class_labels(),
)

  output, inverse_indices, counts = torch._unique2(


F1 0.1788504421710968 | Accuracy 0.30918367346938774 | AUC 0.4892710192307759 | Specificity 0.2857142984867096 | Recall 0.7142857313156128 | Precision 0.11530613154172897
F1 0.19723308086395264 | Accuracy 0.3346938775510204 | AUC 0.47060236046163706 | Specificity 0.2857142984867096 | Recall 0.7142857313156128 | Precision 0.12959183752536774
F1 0.17511191964149475 | Accuracy 0.30612244897959184 | AUC 0.511020381339966 | Specificity 0.2857142984867096 | Recall 0.7142857313156128 | Precision 0.11122449487447739
F1 0.1742706298828125 | Accuracy 0.30612244897959184 | AUC 0.46802414855294805 | Specificity 0.2857142984867096 | Recall 0.7142857313156128 | Precision 0.11122449487447739
F1 0.19095423817634583 | Accuracy 0.32653061224489793 | AUC 0.4958032233303745 | Specificity 0.2857142984867096 | Recall 0.7142857313156128 | Precision 0.12551021575927734
F1 0.19806042313575745 | Accuracy 0.32755102040816325 | AUC 0.41697510452881664 | Specificity 0.2857142984867096 | Recall 0.7142857313156128 |

(0.18860228517846048,
 0.322158395329127,
 0.4720930012424481,
 0.28571429573633045,
 0.7142857313156128)

In [15]:
# Score the trainer's best checkpointed model on the test split.
run_eval(
    model=mtrainer.best_model,
    dataloader=test_loader,
    device=device,
    class_labels=test_dataset.class_labels(),
)

F1 0.19492429494857788 | Accuracy 0.3948979591836735 | AUC 0.44718458887647267 | Specificity 0.35591596364974976 | Recall 0.5714285969734192 | Precision 0.12244898080825806
F1 0.20341134071350098 | Accuracy 0.4010204081632653 | AUC 0.5006839339776653 | Specificity 0.3451247215270996 | Recall 0.581632673740387 | Precision 0.13001494109630585
F1 0.19422617554664612 | Accuracy 0.386734693877551 | AUC 0.48136104522629214 | Specificity 0.3431149423122406 | Recall 0.5904762148857117 | Precision 0.12210884690284729
F1 0.20395749807357788 | Accuracy 0.3877551020408163 | AUC 0.5019968995514166 | Specificity 0.34141138195991516 | Recall 0.6181318759918213 | Precision 0.1273934245109558
F1 0.17863449454307556 | Accuracy 0.373469387755102 | AUC 0.4428214562747441 | Specificity 0.354105144739151 | Recall 0.5714285969734192 | Precision 0.11122448742389679
F1 0.1938256323337555 | Accuracy 0.3836734693877551 | AUC 0.4724027450705705 | Specificity 0.3388245701789856 | Recall 0.5714285969734192 | Precis

(0.19852096105612024,
 0.38459365288633574,
 0.47138756152101796,
 0.34555131530950317,
 0.5927569414725332)

In [14]:
# Score the trainer's current model (not best_model) on the test split.
run_eval(
    model=mtrainer.model,
    dataloader=test_loader,
    device=device,
    class_labels=test_dataset.class_labels(),
)

F1 0.262796014547348 | Accuracy 0.30306122448979594 | AUC 0.471879121071541 | Specificity 0.21807709336280823 | Recall 0.7181841135025024 | Precision 0.19198301434516907
F1 0.2434573471546173 | Accuracy 0.31326530612244896 | AUC 0.490890122745222 | Specificity 0.24420768022537231 | Recall 0.7156593799591064 | Precision 0.19905515015125275
F1 0.24523988366127014 | Accuracy 0.30612244897959184 | AUC 0.4906162395933992 | Specificity 0.24129928648471832 | Recall 0.7169082760810852 | Precision 0.19371946156024933
F1 0.2599810063838959 | Accuracy 0.3408163265306122 | AUC 0.5007050690120927 | Specificity 0.25909245014190674 | Recall 0.747619092464447 | Precision 0.20544257760047913
F1 0.2756465673446655 | Accuracy 0.30714285714285716 | AUC 0.4967361411829624 | Specificity 0.2088400423526764 | Recall 0.810031533241272 | Precision 0.20025786757469177
F1 0.28784602880477905 | Accuracy 0.31326530612244896 | AUC 0.5232660702535715 | Specificity 0.21574905514717102 | Recall 0.8339827656745911 | Pre

(0.25941179013990934,
 0.3147189000847537,
 0.48343970417718274,
 0.23450510232324417,
 0.7454920730546815)

In [None]:
# Score the trainer's best checkpointed model on the training-time query split.
run_eval(
    model=mtrainer.best_model,
    dataloader=query_loader,
    device=device,
    class_labels=query_dataset.class_labels(),
)