In [6]:
import torch
import pickle
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from utils.device import get_device
from models.embedding.dataset import Dataset
from models.embedding.trainer import Trainer
from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, resnet_backbone
from utils.sampling import MultilabelBalancedRandomSampler

def get_query_and_support_ids(img_info, split_file):
    """Split the training images into query ids and support ids.

    Args:
        img_info: DataFrame with at least the columns 'image_id' and
            'meta_split'.
        split_file: Path to a pickle file holding a mapping whose values are
            iterables of image ids; all of them together form the query set.
            The file is unpickled, so it must come from a trusted source.

    Returns:
        A tuple ``(query_image_ids, support_image_ids)``. The first is every
        id found in the pickled mapping, flattened in mapping order. The
        second is the 'image_id' of every row whose 'meta_split' is 'train'
        and whose id is not in the query set, in frame order.
    """
    with open(split_file, 'rb') as handle:
        query_split = pickle.load(handle)

    # Flatten the per-key id lists into one list, preserving mapping order.
    query_image_ids = [image_id for ids in query_split.values() for image_id in ids]

    is_train = img_info['meta_split'] == 'train'
    in_query = img_info['image_id'].isin(query_image_ids)
    support_image_ids = img_info.loc[is_train & ~in_query, 'image_id'].to_list()
    return query_image_ids, support_image_ids

# --- Data split ---------------------------------------------------------------
# Per-image metadata table; the split helper reads its 'meta_split' and
# 'image_id' columns.
img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, 'data/vindr_train_query_set.pkl')

# --- Datasets and loaders -----------------------------------------------------
IMG_PATH = 'datasets/vindr-cxr-png'
# NOTE(review): 10*14 presumably means 10 samples for each of 14 classes — confirm.
batch_size = 10*14
query_dataset = Dataset(IMG_PATH, img_info, query_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(IMG_PATH, img_info, support_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
# Support batches are drawn through MultilabelBalancedRandomSampler, built from
# the dataset's class indicators, instead of plain `shuffle=True`.
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, sampler=MultilabelBalancedRandomSampler(support_dataset.get_class_indicators()))

# --- Model and trainer --------------------------------------------------------
PROJ_SIZE = 512
device = get_device()
# Alternative to resuming from the checkpoint: build a fresh embedding model on
# top of a pretrained backbone (kept for reference, not executed):
# backbone = resnet_backbone(load_pretrained_resnet(1, 14, 'models/backbone/pretrained/cxr_backbone_bal.pkl'))
# backbone = load_medclip_retrained_resnet('models/backbone/pretrained/medclip_resnet50.pkl')
# model = ImageTextEmbedding(backbone, PROJ_SIZE, device=device)

# Resume from the previously saved full-model checkpoint. `map_location` remaps
# the stored tensors onto the current device, so a checkpoint written on a GPU
# also loads on a machine where `get_device()` returns something else.
# NOTE: `torch.load` unpickles the file — only load checkpoints you trust. On
# PyTorch >= 2.6 the default is `weights_only=True`, which rejects a pickled
# full model; pass `weights_only=False` there.
model = torch.load('imgtext_model_trained.pth', map_location=device)

mtrainer = Trainer(model, support_dataset.class_labels(), device)


In [23]:
# Choose which parts of the embedding model are updated by the next training
# run. NOTE(review): `set_backbone_trainable` is a project helper not visible
# here; presumably it toggles gradient updates for the encoder's backbone —
# confirm against models.embedding.model.
mtrainer.model.text_model.set_backbone_trainable(False)
mtrainer.model.img_model.set_backbone_trainable(True)
# Turn off gradient tracking for logit_scale (assumes it is a tensor/parameter).
mtrainer.model.logit_scale.requires_grad = False

In [None]:
# Train using the support and query loaders with learning rate 1e-5.
# NOTE(review): the first positional argument (2) is presumably the number of
# epochs — confirm against Trainer.run_train.
mtrainer.run_train(2, support_loader, query_loader, lr=1e-5)

In [25]:
# Evaluate the trainer's `best_model` on the query loader.
# NOTE(review): the meaning of the returned 3-tuple is not visible here —
# confirm against Trainer.run_eval.
mtrainer.run_eval(mtrainer.best_model, query_loader)


(63.295918367346935, 0.802973279267677, 143.54310607910156)

In [26]:
# Evaluate the trainer's current `model` on the same query loader, for
# comparison with the `best_model` evaluation.
mtrainer.run_eval(mtrainer.model, query_loader)

(63.795918367346935, 0.8083159331393267, 141.8147216796875)

In [27]:
# Persist the whole model object (not just a state_dict) under the same
# filename the setup cell loads from, so the previous checkpoint is overwritten.
# NOTE(review): this saves `mtrainer.model`, not `mtrainer.best_model` —
# confirm that is intended.
torch.save(mtrainer.model, 'imgtext_model_trained.pth')