In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler
from utils.f1_loss import BalAccuracyLoss


configs = {
    'vindr1': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl', VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist']),
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

batch_size = 10*14

query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))


PROJ_SIZE = 512
device =  get_device()

encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')
encoder.text_model.device = device


In [2]:
attn_model = LabelImageMHAttention(PROJ_SIZE, 8, cls_weight=5, cls_loss=BalAccuracyLoss(), device=device)
# attn_model = torch.load('models/attention/model/vindr2/attention-8h.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, num_layers=4, cls_weight=5, cls_loss=BalAccuracyLoss())
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [None]:
attn_model = torch.load('models/attention/model/vindr2/attention-model8h4l.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

mtrainer.run_train(10, support_loader, query_loader, lr=2e-5, encoder_only=True)

In [6]:
torch.autograd.set_detect_anomaly(True)

# mtrainer.run_train(6, support_loader, query_loader, lr=2e-5, full_training=True)
# mtrainer.run_train(10, support_loader, query_loader, lr=2e-5)
mtrainer.run_train(10, support_loader, query_loader, lr=1e-4, min_lr=5e-6)

Batch 1: loss 8.660057067871094
Batch 2: loss 8.663658618927002
Batch 3: loss 8.644646644592285
Batch 4: loss 8.623297691345215
Batch 5: loss 8.601280212402344
Batch 6: loss 8.568521658579508
Batch 7: loss 8.551353318350655
Batch 8: loss 8.525346159934998
Batch 9: loss 8.504292594061958
Batch 10: loss 8.486051654815673
Batch 11: loss 8.464292006059127
Batch 12: loss 8.448318401972452
Batch 13: loss 8.429268250098595
Batch 14: loss 8.41521692276001
Batch 15: loss 8.401623344421386
Batch 16: loss 8.39047235250473
Batch 17: loss 8.377981802996468
Batch 18: loss 8.365034421284994
Batch 19: loss 8.354037033884149
Batch 20: loss 8.342796993255615
Batch 21: loss 8.331245376950218
Batch 22: loss 8.321566061540084
Batch 23: loss 8.313089204871137
Batch 24: loss 8.304989377657572
Batch 25: loss 8.29705753326416
Batch 26: loss 8.288236764761118
Batch 27: loss 8.279079896432382
Batch 28: loss 8.266863226890564
Batch 29: loss 8.255518156906653
Batch 30: loss 8.244403266906739
Batch 31: loss 8.23330

In [24]:
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/attention-8h.pth')

In [10]:
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/attention-trans-8h4l.pth')

In [6]:
torch.save(mtrainer.model.attention, 'models/attention/model/vindr2/full-mh/attention-model8h.pth')

In [16]:
torch.save(mtrainer.model.encoder, 'models/attention/model/vindr2/full-mh/imgtxt-encoder.pth')

In [7]:
from utils.metrics import AverageMeter, calculate_auc, multilabel_logit_accuracy
from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity, MultilabelPrecision, MultilabelF1Score
from models.attention.model import image_text_logits

def run_eval(model, dataloader, device, class_labels):
    model.eval()
    model = model.to(device)
    
    auc_meter = AverageMeter()
    acc_meter = AverageMeter()
    spec_meter = AverageMeter()
    rec_meter = AverageMeter()
    f1_meter = AverageMeter()

    num_labels = len(class_labels)
    specificity = MultilabelSpecificity(num_labels=num_labels).to(device)
    recall = MultilabelRecall(num_labels=num_labels).to(device)
    precision = MultilabelPrecision(num_labels=num_labels).to(device)
    f1_func = MultilabelF1Score(num_labels=num_labels).to(device)
    with torch.no_grad():
         for images, class_inds in dataloader:
                images, class_inds = images.to(device), class_inds.to(device)

                text_embeddings, _, prototypes = model(class_labels, images)

                logits_per_image = image_text_logits(text_embeddings, prototypes, model.encoder.get_logit_scale())
                
                f1 = f1_func(logits_per_image, class_inds)
                f1_meter.update(f1.item(), len(class_inds))

                auc = calculate_auc(logits_per_image, class_inds)
                auc_meter.update(auc, len(class_inds))
            
                acc = multilabel_logit_accuracy(logits_per_image, class_inds)
                acc_meter.update(acc, len(class_inds))

                spec = specificity(logits_per_image, class_inds)
                spec_meter.update(spec.item(), len(class_inds))
                rec = recall(logits_per_image, class_inds)
                rec_meter.update(rec.item(), len(class_inds))
                prec = precision(logits_per_image, class_inds)
                print(f"F1 {f1} | Accuracy {acc} | AUC {auc} | Specificity {spec} | Recall {rec} | Precision {prec}")
            
    return f1_meter.average(), acc_meter.average(), auc_meter.average(), spec_meter.average(), rec_meter.average()

In [8]:
test_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'test']['image_id'].to_list(), config.label_names_map, config.classes_split_map['test'], mean_std=config.mean_std)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

run_eval(mtrainer.model, test_loader, device, test_dataset.class_labels())

F1 0.29421404004096985 | Accuracy 0.5387755102040817 | AUC 0.5364233972863903 | Specificity 0.529010534286499 | Recall 0.5449503064155579 | Precision 0.22477076947689056
F1 0.2881873548030853 | Accuracy 0.5704081632653061 | AUC 0.5492518288214842 | Specificity 0.5554752349853516 | Recall 0.549533486366272 | Precision 0.22559234499931335
F1 0.2743912935256958 | Accuracy 0.5265306122448979 | AUC 0.6043750009261554 | Specificity 0.5206221342086792 | Recall 0.6165170669555664 | Precision 0.21531684696674347
F1 0.24442556500434875 | Accuracy 0.536734693877551 | AUC 0.47709730714792276 | Specificity 0.551681399345398 | Recall 0.4351627230644226 | Precision 0.20781458914279938
F1 0.2748141288757324 | Accuracy 0.5673469387755102 | AUC 0.5204891878065561 | Specificity 0.5728052854537964 | Recall 0.4932914078235626 | Precision 0.22820404171943665
F1 0.29293304681777954 | Accuracy 0.573469387755102 | AUC 0.538747686447979 | Specificity 0.5742725729942322 | Recall 0.4904204308986664 | Precision 0.

(0.26910296070709305,
 0.5485450607401827,
 0.5204696436438535,
 0.548269678120874,
 0.48691349972135234)

In [9]:
run_eval(mtrainer.best_model, test_loader, device, test_dataset.class_labels())

F1 0.33683517575263977 | Accuracy 0.5867346938775511 | AUC 0.5775634642396389 | Specificity 0.5719772577285767 | Recall 0.5820599794387817 | Precision 0.2731553614139557
F1 0.23812535405158997 | Accuracy 0.523469387755102 | AUC 0.4520744551409893 | Specificity 0.5198322534561157 | Recall 0.37458890676498413 | Precision 0.20049630105495453
F1 0.26939091086387634 | Accuracy 0.5173469387755102 | AUC 0.4940022823974166 | Specificity 0.5003655552864075 | Recall 0.5163915157318115 | Precision 0.2047646939754486
F1 0.2377161681652069 | Accuracy 0.5316326530612245 | AUC 0.5195210574333502 | Specificity 0.5455586910247803 | Recall 0.45703208446502686 | Precision 0.18662647902965546
F1 0.24718846380710602 | Accuracy 0.5540816326530612 | AUC 0.47868969583608384 | Specificity 0.5614677667617798 | Recall 0.36818647384643555 | Precision 0.2032890021800995
F1 0.2704866826534271 | Accuracy 0.4959183673469388 | AUC 0.5424808905209917 | Specificity 0.4695279002189636 | Recall 0.562262773513794 | Precisi

(0.2708613712809631,
 0.5380920990677088,
 0.5195715479458692,
 0.531355715575938,
 0.49139137352935247)

In [11]:
run_eval(mtrainer.model, query_loader, device, query_dataset.class_labels())

F1 0.551537275314331 | Accuracy 0.675 | AUC 0.7134491798187556 | Specificity 0.6937583088874817 | Recall 0.630710244178772 | Precision 0.5674383640289307
F1 0.5394936800003052 | Accuracy 0.6928571428571428 | AUC 0.7143775906323018 | Specificity 0.7368800640106201 | Recall 0.605082094669342 | Precision 0.5787002444267273
F1 0.5987014770507812 | Accuracy 0.6964285714285714 | AUC 0.7165412844172835 | Specificity 0.7004305124282837 | Recall 0.6797438263893127 | Precision 0.6326056122779846
F1 0.6119371652603149 | Accuracy 0.7025 | AUC 0.756248347628764 | Specificity 0.7451474666595459 | Recall 0.6682404279708862 | Precision 0.6369891166687012


(0.5710350275039673,
 0.6904,
 0.7214227909837376,
 0.7159228825569153,
 0.6432685947418213)