In [1]:
import torch
from torch.utils.data import DataLoader
import pandas as pd

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler

# ---------------------------------------------------------------------------
# Data: image metadata and the query/support partition of image ids.
# NOTE(review): both pickle files below are unpickled as-is; only load files
# from a trusted source (unpickling can execute arbitrary code).
# ---------------------------------------------------------------------------
img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, 'data/vindr_train_query_set.pkl')

IMG_PATH = 'datasets/vindr-cxr-png'
# presumably 10 samples for each of 14 classes — TODO confirm the intent of 10*14
batch_size = 10*14

# Both datasets use the 'train' entry of VINDR_SPLIT and the 'chestmnist'
# entry of MEAN_STDS.
# NOTE(review): confirm that the chestmnist statistics are the intended
# normalisation for the VinDr-CXR images.
query_dataset = Dataset(IMG_PATH, img_info, query_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(IMG_PATH, img_info, support_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
# The support loader draws indices through MultilabelBalancedRandomSampler,
# built from the dataset's class indicators, instead of plain shuffling.
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))

# ---------------------------------------------------------------------------
# Model: pretrained image/text encoder wrapped by the prototype model.
# ---------------------------------------------------------------------------
PROJ_SIZE = 512
# Replace with device = 'cpu' to force CPU execution.
device = get_device()

# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved from a GPU can still be loaded on a machine without that GPU.
# assumes get_device() returns a str or torch.device — TODO confirm.
# NOTE(review): the file is a whole pickled module; on PyTorch >= 2.6
# torch.load defaults to weights_only=True and will refuse it unless
# weights_only=False is passed.
encoder = torch.load('models/embedding/model/imgtext_model_trained.pth', map_location=device)
encoder.text_model.device = device
# presumably 8 is the number of attention heads — TODO confirm against
# LabelImagePrototypeModel's signature.
model = LabelImagePrototypeModel(encoder, 8, PROJ_SIZE, num_layers=4, cls_weight=0.5)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [2]:
# Debug aid: switch on autograd anomaly detection before training starts.
torch.autograd.set_detect_anomaly(True)

# Keyword options forwarded to Trainer.run_train.
train_kwargs = {'lr': 2e-5, 'full_training': False}
# The leading positional 2 is presumably the epoch count — TODO confirm
# against Trainer.run_train.
mtrainer.run_train(2, support_loader, query_loader, **train_kwargs)

  classes = torch.nonzero(label_inds)[:,1] # (Np,)


Batch 1: loss 11.806146621704102
Batch 2: loss 12.007411479949951
Batch 3: loss 12.039634386698404
Batch 4: loss 11.932925462722778
Batch 5: loss 11.890310096740723
Batch 6: loss 11.853495597839355
Batch 7: loss 11.849952697753906
Batch 8: loss 11.844443798065186
Batch 9: loss 11.81305980682373
Batch 10: loss 11.809188938140869
Batch 11: loss 11.803815234791148
Batch 12: loss 11.796727895736694
Batch 13: loss 11.746948608985313
Batch 14: loss 11.71504613331386
Batch 15: loss 11.693219248453776
Batch 16: loss 11.680330693721771
Batch 17: loss 11.701137711020078
Batch 18: loss 11.681871891021729
Batch 19: loss 11.655200908058568
Batch 20: loss 11.657014513015747
Batch 21: loss 11.635733740670341
Batch 22: loss 11.63504713231867
Batch 23: loss 11.62240335215693
Batch 24: loss 11.595344622929892
Batch 25: loss 11.595146560668946
Batch 26: loss 11.573901946728046
Batch 27: loss 11.571162011888292
Batch 28: loss 11.55950699533735
Batch 29: loss 11.54445378533725
Batch 30: loss 11.53104127248

In [None]:
# Evaluate the trainer's `best_model` on the query loader; the return value
# is shown as the cell output.
best_model = mtrainer.best_model
mtrainer.run_eval(best_model, query_loader)

In [12]:
# Evaluate the trainer's current `model` on the query loader; the return
# value is shown as the cell output.
current_model = mtrainer.model
mtrainer.run_eval(current_model, query_loader)

(97.82653061224491, 0.9992765783681979, 14.54466733932495)

In [3]:
# Persist the model's `attention` submodule. The module object itself is
# pickled (not a state_dict), so loading it later requires the same class
# definitions to be importable.
attention_module = mtrainer.model.attention
torch.save(attention_module, 'attention-model8h4l.pth')