In [1]:
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from models.attention.model import LabelImageAttention, LabelImagePrototypeModel
from models.attention.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder

from utils.device import get_device
from utils.data import get_query_and_support_ids
from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from models.embedding.dataset import Dataset
from utils.sampling import MultilabelBalancedRandomSampler

# --- Reproducibility -------------------------------------------------------------
# The query loader shuffles, the support loader uses a random balanced sampler and the
# prototype model is constructed fresh below, so seed every RNG those may draw from
# before anything is built.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Data -------------------------------------------------------------------------
# NOTE(review): pd.read_pickle and torch.load below unpickle files, which can execute
# arbitrary code - only point them at trusted, locally produced artefacts.
img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, 'data/vindr_train_query_set.pkl')

IMG_PATH = 'datasets/vindr-cxr-png'
# NOTE(review): 10*14 - presumably 10 images per class for the VinDr-CXR labels;
# confirm against len(VINDR_CXR_LABELS).
batch_size = 10*14
# NOTE(review): normalisation uses the 'chestmnist' statistics rather than VinDr ones,
# presumably to match the pretrained backbone - confirm this is intended.
DATASET_MEAN_STD = MEAN_STDS['chestmnist']
TRAIN_SPLIT = VINDR_SPLIT['train']

# Both the query and the support set are built from the training split.
query_dataset = Dataset(IMG_PATH, img_info, query_image_ids, VINDR_CXR_LABELS, TRAIN_SPLIT, mean_std=DATASET_MEAN_STD)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(IMG_PATH, img_info, support_image_ids, VINDR_CXR_LABELS, TRAIN_SPLIT, mean_std=DATASET_MEAN_STD)
# Class-balanced sampling for the support set (DataLoader forbids combining a sampler with shuffle=True).
support_loader = DataLoader(
    dataset=support_dataset,
    batch_size=batch_size,
    sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()),
)

# --- Model ------------------------------------------------------------------------
PROJ_SIZE = 512
# Second positional argument of LabelImagePrototypeModel; the saved checkpoint name uses
# "8h", so presumably the number of attention heads - confirm against the model signature.
NUM_HEADS = 8
NUM_LAYERS = 2
# get_device() picks the compute device; assign device = 'cpu' here to force CPU.
device = get_device()

# map_location loads the checkpoint directly onto the selected device, so a checkpoint
# saved on a GPU still loads on a CPU-only or differently-equipped machine.
# NOTE(review): this unpickles a whole nn.Module; torch >= 2.6 defaults to
# weights_only=True and will reject it unless weights_only=False is passed.
encoder = torch.load('imgtext_model_trained.pth', map_location=device)
encoder.text_model.device = device
model = LabelImagePrototypeModel(encoder, NUM_HEADS, PROJ_SIZE, num_layers=NUM_LAYERS)
mtrainer = Trainer(model, support_dataset.class_labels(), device)

In [None]:
# Anomaly detection is a debugging aid: it records extra information for every autograd
# op and noticeably slows training. Called as a plain function it also stays enabled
# globally for every later cell (including evaluation). Scope it to the training call with
# the context-manager form and keep it off unless you are chasing a NaN/backward error.
DETECT_ANOMALY = False

with torch.autograd.set_detect_anomaly(DETECT_ANOMALY):
    # First positional argument: presumably the number of epochs - confirm against Trainer.run_train.
    mtrainer.run_train(2, support_loader, query_loader, lr=2e-5, full_training=False)

In [None]:
# Evaluate mtrainer.best_model (presumably the best weights kept by run_train - confirm).
# NOTE(review): query_loader is built from VINDR_SPLIT['train'], so these are not held-out
# metrics; evaluate on a validation/test split before reporting results.
best_model_metrics = mtrainer.run_eval(mtrainer.best_model, query_loader)
best_model_metrics

In [12]:
# Evaluate the trainer's current weights (mtrainer.model) on the same query loader.
# NOTE(review): run_eval returns a 3-tuple; what each element measures is not visible
# here - unpack it into named metrics once confirmed against Trainer.run_eval.
current_model_metrics = mtrainer.run_eval(mtrainer.model, query_loader)
current_model_metrics

(97.82653061224491, 0.9992765783681979, 14.54466733932495)

In [13]:
# Persist the attention sub-module of mtrainer.model (the current weights, not mtrainer.best_model).
# NOTE(review): saving a module object pickles the whole class instance, so loading it needs
# the same class importable at the same path; .state_dict() is more portable but would change
# the file format any downstream loader expects.
ATTENTION_CHECKPOINT = 'attention-model8h2l.pth'
attention_module = mtrainer.model.attention
torch.save(obj=attention_module, f=ATTENTION_CHECKPOINT)