In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel, LabelImageMHAttention
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.data import get_query_and_support_ids, DatasetConfig
from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT, VINDR_SPLIT2
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler
from utils.f1_loss import BalAccuracyLoss


configs = {
    'vindr1': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels.pkl', 'data/vindr_train_query_set.pkl', VINDR_CXR_LABELS, VINDR_SPLIT, MEAN_STDS['chestmnist']),
    'vindr2': DatasetConfig('datasets/vindr-cxr-png', 'data/vindr_cxr_split_labels2.pkl', 'data/vindr_train_query_set2.pkl', VINDR_CXR_LABELS, VINDR_SPLIT2, MEAN_STDS['chestmnist'])
}

config = configs['vindr2']

IMG_PATH = 'datasets/vindr-cxr-png'
batch_size = 10*14

query_image_ids, support_image_ids = get_query_and_support_ids(config.img_info, config.training_split_path)
query_dataset = Dataset(config.img_path, config.img_info, query_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(config.img_path, config.img_info, support_image_ids, config.label_names_map, config.classes_split_map['train'], mean_std=config.mean_std)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))


PROJ_SIZE = 512
device =  get_device()

encoder = torch.load('models/embedding/model/vindr2/imgtext_model_trained1.pth')
encoder.text_model.device = device


In [2]:
attn_model = LabelImageMHAttention(PROJ_SIZE, 8, cls_weight=1, cls_loss=BalAccuracyLoss(), device=device)
# attn_model = torch.load('models/attention/model/vindr2/attention-8h.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, num_layers=4, cls_weight=5, cls_loss=BalAccuracyLoss())
mtrainer = Trainer(model, support_dataset.class_labels(), device)

mtrainer.run_train(10, support_loader, query_loader, lr=2e-5, lr=1e-4, min_lr=5e-6)

In [None]:
attn_model = torch.load('models/attention/model/vindr2/attention-model8h4l.pth')
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, attn_model=attn_model)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

mtrainer.run_train(10, support_loader, query_loader, lr=2e-5, encoder_only=True)

In [20]:
mtrainer.best_loss = 8.106417350769043

In [None]:
torch.autograd.set_detect_anomaly(True)

# mtrainer.run_train(6, support_loader, query_loader, lr=2e-5, full_training=True)
# mtrainer.run_train(10, support_loader, query_loader, lr=2e-5)
mtrainer.run_train(10, support_loader, query_loader, lr=1e-4, min_lr=5e-6)

In [24]:
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/attention-8h.pth')

In [7]:
torch.save(mtrainer.best_model.attention, 'models/attention/model/vindr2/attention-model8h.pth')

In [6]:
torch.save(mtrainer.model.attention, 'models/attention/model/vindr2/full-mh/attention-model8h.pth')

In [16]:
torch.save(mtrainer.model.encoder, 'models/attention/model/vindr2/full-mh/imgtxt-encoder.pth')

In [3]:
from utils.metrics import AverageMeter, calculate_auc, multilabel_logit_accuracy
from torchmetrics.classification import MultilabelRecall, MultilabelSpecificity, MultilabelPrecision, MultilabelF1Score
from models.attention.model import image_text_logits

def run_eval(model, dataloader, device, class_labels):
    model.eval()
    model = model.to(device)
    
    auc_meter = AverageMeter()
    acc_meter = AverageMeter()
    spec_meter = AverageMeter()
    rec_meter = AverageMeter()
    f1_meter = AverageMeter()

    num_labels = len(class_labels)
    specificity = MultilabelSpecificity(num_labels=num_labels).to(device)
    recall = MultilabelRecall(num_labels=num_labels).to(device)
    precision = MultilabelPrecision(num_labels=num_labels).to(device)
    f1_func = MultilabelF1Score(num_labels=num_labels).to(device)
    with torch.no_grad():
         for images, class_inds in dataloader:
                images, class_inds = images.to(device), class_inds.to(device)

                text_embeddings, _, prototypes = model(class_labels, images)

                logits_per_image = image_text_logits(text_embeddings, prototypes, model.encoder.get_logit_scale())
                
                f1 = f1_func(logits_per_image, class_inds)
                f1_meter.update(f1.item(), len(class_inds))

                auc = calculate_auc(logits_per_image, class_inds)
                auc_meter.update(auc, len(class_inds))
            
                acc = multilabel_logit_accuracy(logits_per_image, class_inds)
                acc_meter.update(acc, len(class_inds))

                spec = specificity(logits_per_image, class_inds)
                spec_meter.update(spec.item(), len(class_inds))
                rec = recall(logits_per_image, class_inds)
                rec_meter.update(rec.item(), len(class_inds))
                prec = precision(logits_per_image, class_inds)
                print(f"F1 {f1} | Accuracy {acc} | AUC {auc} | Specificity {spec} | Recall {rec} | Precision {prec}")
            
    return f1_meter.average(), acc_meter.average(), auc_meter.average(), spec_meter.average(), rec_meter.average()

In [5]:
test_dataset = Dataset(config.img_path, config.img_info, config.img_info[config.img_info['meta_split'] == 'test']['image_id'].to_list(), config.label_names_map, config.classes_split_map['test'], mean_std=config.mean_std)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

run_eval(mtrainer.model, test_loader, device, test_dataset.class_labels())

F1 0.28933292627334595 | Accuracy 0.42653061224489797 | AUC 0.5012760106407754 | Specificity 0.34877878427505493 | Recall 0.6482062339782715 | Precision 0.2067118138074875
F1 0.28000032901763916 | Accuracy 0.45 | AUC 0.5244511868726545 | Specificity 0.39517742395401 | Recall 0.6888952255249023 | Precision 0.2086719423532486
F1 0.28447121381759644 | Accuracy 0.45 | AUC 0.5205382476454347 | Specificity 0.4083195924758911 | Recall 0.5929675698280334 | Precision 0.20806168019771576
F1 0.29833951592445374 | Accuracy 0.47653061224489796 | AUC 0.5744529014172269 | Specificity 0.4363011419773102 | Recall 0.6751984357833862 | Precision 0.2179604172706604
F1 0.28733593225479126 | Accuracy 0.47244897959183674 | AUC 0.5298072579951354 | Specificity 0.4340851306915283 | Recall 0.5941333174705505 | Precision 0.2122020423412323
F1 0.31361648440361023 | Accuracy 0.47551020408163264 | AUC 0.5771737606285515 | Specificity 0.41940829157829285 | Recall 0.7063462734222412 | Precision 0.22470252215862274
F1

(0.28909819139488135,
 0.46510970901214793,
 0.5371338871001969,
 0.4200109924111778,
 0.6311609967020522)

In [25]:
run_eval(mtrainer.model, test_loader, device, test_dataset.class_labels())

F1 0.283608078956604 | Accuracy 0.5510204081632653 | AUC 0.5591138550574418 | Specificity 0.5776880979537964 | Recall 0.5120512247085571 | Precision 0.24715235829353333
F1 0.3002997934818268 | Accuracy 0.5489795918367347 | AUC 0.6044194470160384 | Specificity 0.5632289052009583 | Recall 0.59834223985672 | Precision 0.24182595312595367
F1 0.2862740457057953 | Accuracy 0.5163265306122449 | AUC 0.5949728348230874 | Specificity 0.5229647159576416 | Recall 0.6066675782203674 | Precision 0.23028656840324402
F1 0.3166683316230774 | Accuracy 0.5683673469387756 | AUC 0.580749794828874 | Specificity 0.5908064246177673 | Recall 0.5454789400100708 | Precision 0.26926490664482117
F1 0.24924519658088684 | Accuracy 0.5275510204081633 | AUC 0.4699966260166683 | Specificity 0.5696976184844971 | Recall 0.3806300759315491 | Precision 0.21931084990501404
F1 0.2658153176307678 | Accuracy 0.5285714285714286 | AUC 0.5439945277405026 | Specificity 0.5573750138282776 | Recall 0.4719454348087311 | Precision 0.2

(0.28474679646243684,
 0.5383746115453434,
 0.5571112070759975,
 0.557079474668245,
 0.5215828832074381)

In [22]:
run_eval(mtrainer.model, query_loader, device, query_dataset.class_labels())

F1 0.5260613560676575 | Accuracy 0.58 | AUC 0.6544703604938192 | Specificity 0.5661594867706299 | Recall 0.6639364957809448 | Precision 0.4823448956012726
F1 0.4983893930912018 | Accuracy 0.5385714285714286 | AUC 0.6108061858041806 | Specificity 0.49644139409065247 | Recall 0.6740500330924988 | Precision 0.45193716883659363
F1 0.5306516885757446 | Accuracy 0.5778571428571428 | AUC 0.6377309784179848 | Specificity 0.5418899059295654 | Recall 0.6695412993431091 | Precision 0.48143211007118225
F1 0.527908205986023 | Accuracy 0.5475 | AUC 0.6262019690082068 | Specificity 0.4911428987979889 | Recall 0.682529091835022 | Precision 0.48489272594451904


(0.5198939955234527,
 0.5626,
 0.6330344219617888,
 0.5278402841091157,
 0.6713124465942383)