In [1]:
import shutil
from pathlib import Path

import torch

from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, resnet_backbone

In [2]:
import pickle
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from utils.device import get_device
from models.embedding.dataset import Dataset
from models.embedding.trainer import Trainer

def get_query_and_support_ids(img_info, split_file):
    """Split the training images into a query set and a support set.

    Args:
        img_info: DataFrame with at least ``image_id`` and ``meta_split`` columns.
        split_file: path to a pickle holding a mapping whose values are
            iterables of query image ids.

    Returns:
        ``(query_image_ids, support_image_ids)`` as two lists. Query ids are the
        mapping's values concatenated in iteration order (an id listed under
        several keys appears several times). Support ids are every
        ``meta_split == 'train'`` row of ``img_info`` whose id is not a query
        id, in row order.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted files.
    with open(split_file, 'rb') as handle:
        query_by_key = pickle.load(handle)
    query_ids = [image_id for group in query_by_key.values() for image_id in group]

    in_train_split = img_info['meta_split'] == 'train'
    is_query = img_info['image_id'].isin(query_ids)
    support_ids = img_info.loc[in_train_split & ~is_query, 'image_id'].to_list()
    return query_ids, support_ids

# Load the labelled image table and split its train images into query / support ids.
img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, 'data/vindr_train_query_set.pkl')

IMG_PATH = 'datasets/vindr-cxr-png'
SEED = 0
# NOTE(review): 10 * 14 presumably means 10 images for each of 14 VinDr-CXR labels — confirm.
batch_size = 10*14

# Both loaders shuffle; seed torch so batch order (and other torch randomness) repeats across runs.
torch.manual_seed(SEED)


def _make_dataset_and_loader(image_ids):
    """Build a train-split Dataset over ``image_ids`` and a shuffled DataLoader for it."""
    # NOTE(review): normalises with the ChestMNIST mean/std, not VinDr-CXR statistics — confirm intended.
    dataset = Dataset(IMG_PATH, img_info, image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    return dataset, loader


query_dataset, query_loader = _make_dataset_and_loader(query_image_ids)
support_dataset, support_loader = _make_dataset_and_loader(support_image_ids)

PROJ_SIZE = 512
device = get_device()
# To rebuild the model from a pretrained image backbone instead of resuming, use one of
# (note: these loader functions are not imported in this notebook):
#   backbone = resnet_backbone(load_pretrained_resnet(1, 14, 'models/backbone/pretrained/cxr_backbone_bal.pkl'))
#   backbone = load_medclip_retrained_resnet('models/backbone/pretrained/medclip_resnet50.pkl')
#   model = ImageTextEmbedding(backbone, PROJ_SIZE, device=device)
#
# This checkpoint is a whole pickled model object (the final cell saves mtrainer.best_model
# to this path), not a state_dict. Torch >= 2.6 therefore needs weights_only=False — only
# load checkpoints you trust. map_location lets a checkpoint saved on one device load on
# whatever device get_device() selected.
model = torch.load('imgtext_model_trained1-newlib.pth', map_location=device, weights_only=False)

mtrainer = Trainer(model, support_dataset.class_labels(), device)


In [3]:
# Freeze both encoders' backbones and the logit_scale parameter so that run_train only
# updates whatever remains trainable (presumably the projection layers — confirm).
for encoder in (mtrainer.model.text_model, mtrainer.model.img_model):
    encoder.set_backbone_trainable(False)
mtrainer.model.logit_scale.requires_grad = False

In [4]:
# NOTE(review): the leading 2 is presumably the number of epochs — confirm against Trainer.run_train.
FINETUNE_LR = 1e-5
mtrainer.run_train(2, support_loader, query_loader, lr=FINETUNE_LR)

Batch 1: loss 28.469074249267578
Batch 2: loss 29.466032028198242
Batch 3: loss 29.55094846089681
Batch 4: loss 29.260549068450928
Batch 5: loss 29.87935600280762
Batch 6: loss 29.490957578023274
Batch 7: loss 29.538506371634348
Batch 8: loss 29.57067632675171
Batch 9: loss 29.427893108791775
Batch 10: loss 29.408737564086913
Batch 11: loss 29.4022967598655
Batch 12: loss 29.30565357208252
Batch 13: loss 29.34943727346567
Batch 14: loss 29.5397915158953
Batch 15: loss 29.41028429667155
Batch 16: loss 29.50651216506958
Batch 17: loss 29.42297699872185
Batch 18: loss 29.336332427130806
Batch 19: loss 29.348928150377777
Batch 20: loss 29.349716663360596
Batch 21: loss 29.540258589245024
Batch 22: loss 29.593233108520508
Batch 23: loss 29.482493773750637
Batch 24: loss 29.557597796122234
Batch 25: loss 29.48684310913086
Batch 26: loss 29.467873353224533
Batch 27: loss 29.41236729092068
Batch 28: loss 29.47128200531006
Batch 29: loss 29.590852671656116
Batch 30: loss 29.52819569905599
Batch

In [5]:
# NOTE(review): query_loader is also passed to run_train above, so these metrics are
# measured on training data, not on a held-out split.
# Result from an earlier run, kept for comparison: (62.38775510204083, 0.8102664557552944, 140.3484375)
best_model_metrics = mtrainer.run_eval(mtrainer.best_model, query_loader)
best_model_metrics

(54.377551020408156, 0.8127540973974094, 133.65570678710938)

In [18]:
# Evaluate the trainer's current model (presumably its weights after the last training step)
# for comparison with best_model. NOTE(review): this cell ran out of order (In [18]);
# re-run the notebook top to bottom before trusting the comparison.
current_model_metrics = mtrainer.run_eval(mtrainer.model, query_loader)
current_model_metrics

(62.31632653061224, 0.8074608790767269, 140.04942321777344)

In [6]:
# Persist the best model from fine-tuning. This is the same path the model was loaded from
# at the top of the notebook, so the existing checkpoint is first copied aside to
# "imgtext_model_trained1-newlib.bak.pth". Otherwise a worse best_model would permanently
# replace the starting weights.
CKPT_PATH = Path('imgtext_model_trained1-newlib.pth')
if CKPT_PATH.exists():
    shutil.copy2(CKPT_PATH, CKPT_PATH.with_suffix('.bak.pth'))
torch.save(mtrainer.best_model, CKPT_PATH)