In [5]:
import torch

from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, resnet_backbone

In [6]:
import pickle
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from utils.device import get_device
from models.embedding.dataset import Dataset
from models.embedding.trainer import Trainer

def get_query_and_support_ids(img_info, split_file, split='train'):
    """Partition the images of one metadata split into query and support ids.

    The query ids are read from ``split_file``, a pickled mapping whose values
    are iterables of image ids; all values are concatenated in iteration order
    (duplicates, if any, are kept). Every row of ``img_info`` whose
    ``meta_split`` equals ``split`` and whose ``image_id`` is not a query id
    becomes a support image.

    Args:
        img_info: DataFrame with at least ``image_id`` and ``meta_split`` columns.
        split_file: Path to the pickled query-set mapping.
            SECURITY: ``pickle.load`` can execute arbitrary code, so only pass
            trusted files.
        split: ``meta_split`` value that support images are drawn from.
            Defaults to ``'train'``, the previously hard-coded value.

    Returns:
        ``(query_image_ids, support_image_ids)`` as two lists.
    """
    with open(split_file, 'rb') as fp:
        query_sets = pickle.load(fp)
    # Flatten the per-key id lists into a single ordered list.
    query_image_ids = [image_id for ids in query_sets.values() for image_id in ids]
    in_split = img_info['meta_split'] == split
    is_query = img_info['image_id'].isin(query_image_ids)
    support_image_ids = img_info.loc[in_split & ~is_query, 'image_id'].to_list()
    return query_image_ids, support_image_ids

# Per-image split/label table and the pickled train query set.
# SECURITY: both files are unpickled, which can execute arbitrary code; use trusted data only.
SPLIT_LABELS_PATH = 'data/vindr_cxr_split_labels.pkl'
TRAIN_QUERY_SET_PATH = 'data/vindr_train_query_set.pkl'

img_info = pd.read_pickle(SPLIT_LABELS_PATH)
# Support ids are derived inside get_query_and_support_ids: 'train' images that are not query images.
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, TRAIN_QUERY_SET_PATH)

IMG_PATH = 'datasets/vindr-cxr-png'
# 10 * 14: presumably 10 images per class for the 14 VinDr-CXR labels. TODO confirm.
batch_size = 10*14


def _make_split_loader(image_ids, shuffle=True):
    """Build a Dataset over ``image_ids`` and wrap it in a DataLoader.

    Query and support sets share the image root, label list, label split and
    normalisation statistics, so they are built through this single helper.

    Returns:
        ``(dataset, loader)``.
    """
    dataset = Dataset(IMG_PATH, img_info, image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'],
                      mean_std=MEAN_STDS['chestmnist'])
    # NOTE(review): shuffle=True without a seeded generator makes batch order non-reproducible.
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


query_dataset, query_loader = _make_split_loader(query_image_ids)
support_dataset, support_loader = _make_split_loader(support_image_ids)

PROJ_SIZE = 512
device = get_device()
# Alternative: build the model from scratch instead of loading a checkpoint (kept for reference).
# NOTE(review): load_pretrained_resnet and load_medclip_retrained_resnet are not imported by the
# import cells above, so these lines would raise NameError if uncommented as-is.
# backbone = resnet_backbone(load_pretrained_resnet(1, 14, 'models/backbone/pretrained/cxr_backbone_bal.pkl'))

# backbone = load_medclip_retrained_resnet('models/backbone/pretrained/medclip_resnet50.pkl')
# model = ImageTextEmbedding(backbone, PROJ_SIZE, device=device)

# Load a fully pickled model object. map_location remaps stored tensors onto the selected device,
# so a checkpoint saved on GPU still loads on a machine without one
# (assumes get_device() returns a torch.device or device string).
# SECURITY: torch.load unpickles arbitrary objects; only load trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects full-model
# checkpoints like this one; pass weights_only=False there.
model = torch.load('imgtext_model_trained1-newlib.pth', map_location=device)

mtrainer = Trainer(model, support_dataset.class_labels(), device)


In [17]:
# Freeze the text and image encoder backbones and the logit_scale parameter
# (presumably the CLIP-style temperature), so that only the remaining trainable
# parameters are updated by the next run_train call.
for encoder in (mtrainer.model.text_model, mtrainer.model.img_model):
    encoder.set_backbone_trainable(False)
mtrainer.model.logit_scale.requires_grad = False

In [18]:
# Fine-tune with a small learning rate on the support/query loaders.
# The leading positional 2 is presumably an epoch/round count. TODO confirm against Trainer.run_train.
LEARNING_RATE = 1e-5
mtrainer.run_train(2, support_loader, query_loader, lr=LEARNING_RATE)

Batch 1: loss 27.081090927124023
Batch 2: loss 29.039323806762695
Batch 3: loss 28.183367411295574
Batch 4: loss 28.33365774154663
Batch 5: loss 28.183285522460938
Batch 6: loss 27.916765530904133
Batch 7: loss 28.167078835623606
Batch 8: loss 28.556424856185913
Batch 9: loss 28.786416583591038
Batch 10: loss 28.737309837341307
Batch 11: loss 28.55980821089311
Batch 12: loss 28.452989101409912


In [9]:
# Evaluate the trainer's best model on the query loader.
# NOTE(review): query_loader is built from the *train* query set (vindr_train_query_set.pkl)
# and is also passed to run_train, so this is presumably not a held-out score.
# The meaning of the returned tuple's fields is not visible here
# (presumably accuracy %, AUC, loss; TODO confirm against Trainer.run_eval).
# Result recorded from an earlier run: (62.38775510204083, 0.8102664557552944, 140.3484375)
best_model = mtrainer.best_model
mtrainer.run_eval(best_model, query_loader)

(58.214285714285715, 0.8097899599929083, 137.38421936035155)

In [10]:
# Evaluate the trainer's current (presumably most recently updated) model on the same loader,
# for comparison with the best-model result.
current_model = mtrainer.model
mtrainer.run_eval(current_model, query_loader)

(58.2142857142857, 0.8094589455247915, 137.3100372314453)

In [13]:
# Persist the trainer's best model. The whole object is pickled (not a state_dict),
# so reloading it requires the same model classes to be importable.
# NOTE(review): this path differs from the checkpoint loaded earlier
# ('imgtext_model_trained1-newlib.pth') and overwrites any existing file of this name.
BEST_MODEL_PATH = 'imgtext_model_trained1.pth'
torch.save(mtrainer.best_model, BEST_MODEL_PATH)