In [1]:
import torch
# from torch import nn
# import torchvision

from models.embedding.model import ImageTextEmbedding, TextEncoder, ImageEncoder, resnet_backbone

In [2]:
import pickle
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from utils.device import get_device
from models.embedding.dataset import Dataset
from models.embedding.trainer import Trainer

def get_query_and_support_ids(img_info, split_file):
    """Split the training image ids into a query list and a support list.

    Args:
        img_info: DataFrame providing at least the ``image_id`` and
            ``meta_split`` columns.
        split_file: Path to a pickle file containing a mapping whose values
            are iterables of query image ids.

    Returns:
        A ``(query_image_ids, support_image_ids)`` tuple of lists. The query
        list is every id from the mapping's values, flattened in mapping
        order (duplicates are kept). The support list is the ``image_id`` of
        every row whose ``meta_split`` is ``'train'`` and whose id is not in
        the query list.
    """
    # pickle can execute arbitrary code while loading: only use trusted files.
    with open(split_file, 'rb') as handle:
        ids_by_key = pickle.load(handle)

    query_image_ids = [image_id for ids in ids_by_key.values() for image_id in ids]

    in_train_split = img_info['meta_split'] == 'train'
    in_query_set = img_info['image_id'].isin(query_image_ids)
    support_rows = img_info[in_train_split & ~in_query_set]
    return query_image_ids, support_rows['image_id'].to_list()

# --- Data -------------------------------------------------------------------
# Per-image metadata table; the 'meta_split' and 'image_id' columns are read
# by get_query_and_support_ids.
img_info = pd.read_pickle('data/vindr_cxr_split_labels.pkl')
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, 'data/vindr_train_query_set.pkl')

IMG_PATH = 'datasets/vindr-cxr-png'
batch_size = 10*14
# The query and support datasets are built identically except for the id list.
query_dataset = Dataset(IMG_PATH, img_info, query_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(IMG_PATH, img_info, support_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=MEAN_STDS['chestmnist'])
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, shuffle=True)

# --- Model ------------------------------------------------------------------
PROJ_SIZE = 512
device = get_device()

# Alternative: build the model from a pretrained backbone instead of loading
# the checkpoint below (the load_* helpers are not imported in this notebook).
# backbone = resnet_backbone(load_pretrained_resnet(1, 14, 'models/backbone/pretrained/cxr_backbone_bal.pkl'))
# backbone = load_medclip_retrained_resnet('models/backbone/pretrained/medclip_resnet50.pkl')
# model = ImageTextEmbedding(backbone, PROJ_SIZE, device=device)

MODEL_CHECKPOINT = 'imgtext_model_trained.pth'
# The loaded object is used directly as the model, i.e. the file holds a whole
# pickled module rather than a state_dict.
#  - map_location=device remaps stored tensors so a checkpoint written on a
#    different device (e.g. CUDA) still loads on this machine.
#  - weights_only=False is required for full-object checkpoints; PyTorch >= 2.6
#    defaults it to True and would refuse to unpickle the model class.
#    (The keyword exists from torch 1.13 onwards.)
# Unpickling can execute arbitrary code: only load checkpoints you trust.
model = torch.load(MODEL_CHECKPOINT, map_location=device, weights_only=False)

mtrainer = Trainer(model, support_dataset.class_labels(), device)


In [15]:
# Call set_backbone_trainable(False) on the text and image sub-models (text
# first, then image) and switch off gradients for the logit scale.
for _encoder in (mtrainer.model.text_model, mtrainer.model.img_model):
    _encoder.set_backbone_trainable(False)
mtrainer.model.logit_scale.requires_grad = False

In [9]:
# Run Trainer.run_train with the support and query loaders and lr=1e-5.
# NOTE(review): the first positional argument (2) is presumably the number of
# epochs - confirm against Trainer.run_train.
# NOTE(review): this cell's execution count (9) is lower than that of the
# freeze cell above it (15), so the saved log below was produced before the
# freeze cell's most recent execution. Re-run top to bottom on a fresh kernel
# before trusting the output.
mtrainer.run_train(2, support_loader, query_loader, lr=1e-5)

Batch 1: loss 26.7219181060791
Batch 2: loss 28.940580368041992
Batch 3: loss 29.15361785888672
Batch 4: loss 29.56848907470703
Batch 5: loss 29.39850196838379
Batch 6: loss 28.922122637430828
Batch 7: loss 29.026561737060547
Batch 8: loss 28.973043203353882
Batch 9: loss 29.296273761325413
Batch 10: loss 29.4564640045166
Batch 11: loss 29.32164226878773
Batch 12: loss 29.078742186228435
Batch 13: loss 29.135898883526142
Batch 14: loss 29.11204787663051
Batch 15: loss 29.288188934326172
Batch 16: loss 29.281124711036682
Batch 17: loss 29.217458724975586
Batch 18: loss 29.2834259668986
Batch 19: loss 29.12468599018298
Batch 20: loss 29.078123378753663
Batch 21: loss 29.029097420828684
Batch 22: loss 29.09607947956432
Batch 23: loss 29.01464371059252
Batch 24: loss 29.087071816126507
Batch 25: loss 29.007728118896484
Batch 26: loss 28.967255665705753
Batch 27: loss 29.072203247635453
Batch 28: loss 29.029563426971436
Batch 29: loss 28.959361438093513
Batch 30: loss 28.96729475657145
Batc

In [17]:
# Evaluate mtrainer.best_model on the query loader; the returned tuple is the
# cell output.
# NOTE(review): the meaning of the three tuple fields is defined by
# Trainer.run_eval, which is not shown here - confirm before interpreting them.
mtrainer.run_eval(mtrainer.best_model, query_loader)

(62.38775510204083, 0.8102664557552944, 140.3484375)

In [18]:
# Same evaluation on mtrainer.model (rather than mtrainer.best_model), so the
# two result tuples can be compared side by side.
mtrainer.run_eval(mtrainer.model, query_loader)

(62.31632653061224, 0.8074608790767269, 140.04942321777344)

In [19]:
# Persist mtrainer.best_model to 'imgtext_model_trained1.pth'.
# torch.save is given the model object itself (not a state_dict), so the file
# is a pickle of the whole object; loading it back needs the same model classes
# to be importable. This matches how the notebook loads its starting checkpoint
# with a bare torch.load.
torch.save(mtrainer.best_model, 'imgtext_model_trained1.pth')