In [20]:
import torch
from torch import nn
import torchvision

from transformers import AutoTokenizer, AutoModel

def load_pretrained_resnet(img_channels, num_classes, save_path, fc_bias=True):
    """Build a ResNet-50 matching a saved checkpoint and load its weights.

    The architecture is adapted before loading so the state dict fits:
    ``conv1`` is replaced to accept ``img_channels`` input channels, and when
    ``fc_bias`` is False the ``fc`` head is rebuilt without a bias term.

    Args:
        img_channels: Number of input image channels (e.g. 1 for grayscale).
        num_classes: Output size of the final ``fc`` layer in the checkpoint.
        save_path: Path to a state dict written with ``torch.save``.
        fc_bias: Whether the checkpoint's ``fc`` layer has a bias.

    Returns:
        The ResNet-50 with the checkpoint's weights loaded, on CPU.
    """
    model = torchvision.models.resnet50(num_classes=num_classes, weights=None)
    if not fc_bias:
        # Read the head's input width from the existing layer instead of hard-coding 2048.
        model.fc = nn.Linear(model.fc.in_features, num_classes, bias=False)
    model.conv1 = torch.nn.Conv2d(img_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only machines;
    # callers move the model to their target device afterwards.
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(save_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model

def resnet_backbone(model):
    """Return ``model`` without its last two top-level child modules.

    For a torchvision ResNet those two children are the global ``avgpool`` and
    the ``fc`` classifier, so the result outputs spatial feature maps.
    """
    children = list(model.children())
    feature_layers = children[:-2]
    return torch.nn.Sequential(*feature_layers)

def load_medclip_retrained_resnet(path):
    """Load the checkpoint at ``path`` and return it as a headless backbone.

    Expects a 1-channel ResNet-50 checkpoint whose ``fc`` layer has 512 outputs
    and no bias; the pooling and ``fc`` layers are stripped after loading.
    """
    full_model = load_pretrained_resnet(img_channels=1, num_classes=512, save_path=path, fc_bias=False)
    return resnet_backbone(full_model)

class ImageEncoder(nn.Module):
    """Run a convolutional backbone and linearly project its channels to ``embed_dims``.

    The projection is applied independently at every spatial location, so the
    spatial layout of the backbone output is preserved:
    ``(B, C, H, W) -> (B, embed_dims, H', W')``.

    Args:
        backbone: Module mapping images to ``(B, backbone_dims, H', W')`` feature maps.
        embed_dims: Size of the projected embedding.
        freeze_backbone: If True, backbone parameters get ``requires_grad=False``.
        backbone_dims: Channel count of the backbone output (2048 for a headless ResNet-50).
    """

    def __init__(self, backbone, embed_dims, freeze_backbone=False, backbone_dims=2048):
        super().__init__()
        self.backbone = backbone
        self.proj = nn.Linear(backbone_dims, embed_dims)
        if freeze_backbone:
            self.set_backbone_trainable(False)

    def set_backbone_trainable(self, trainable):
        """Set ``requires_grad`` to ``trainable`` on every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = trainable

    def forward(self, input):
        # (B, C, H, W) -> (B, backbone_dims, H', W')
        features = self.backbone(input)
        # nn.Linear acts on the last axis, so move channels last, project, then move them back:
        # (B, backbone_dims, H', W') -> (B, H', W', backbone_dims) -> (B, H', W', D) -> (B, D, H', W')
        return self.proj(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class TextEncoder(nn.Module):
    """Encode text with Bio_ClinicalBERT and project the pooled output to ``embed_dims``.

    Args:
        embed_dims: Size of the projected text embedding.
        device: Kept for backward compatibility (stored as ``self.device``). Tokens are
            moved to the device of the BERT weights, so the encoder follows ``.to(...)``.
        freeze_backbone: If True (default), BERT parameters get ``requires_grad=False``
            and only the projection is trained.
        max_length: Tokens per text; longer texts are truncated, shorter ones padded.
    """

    def __init__(self, embed_dims, device='cpu', freeze_backbone=True, max_length=77):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        self.backbone = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        # Read the width from the model config instead of hard-coding 768.
        self.proj = nn.Linear(self.backbone.config.hidden_size, embed_dims)
        self.device = device
        self.max_length = max_length
        if freeze_backbone:
            self.set_backbone_trainable(False)

    def set_backbone_trainable(self, trainable):
        """Set ``requires_grad`` to ``trainable`` on every BERT parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = trainable

    def forward(self, input):
        """Embed a string or list of strings into a ``(batch, embed_dims)`` tensor."""
        # Send tokens wherever the BERT weights currently live; a device fixed at
        # construction time goes stale once the module is moved with .to(...).
        target_device = next(self.backbone.parameters()).device
        tokens = self.tokenizer(
            input,
            max_length=self.max_length,
            padding='max_length',
            # Without truncation a text longer than max_length overflows the padded
            # length and the batch cannot be stacked into a single tensor.
            truncation=True,
            return_tensors='pt',
        ).to(target_device)
        out = self.backbone(**tokens)
        enc = out['pooler_output']
        # Alternative: the raw [CLS] hidden state, out['last_hidden_state'][:, 0]
        return self.proj(enc)

class ImageTextEmbedding(nn.Module):
    """CLIP-style joint image/text embedding with a symmetric image-label contrastive loss.

    Text (e.g. one prompt per class) and images are projected into a shared
    ``embed_dims``-dimensional space and L2-normalised. Their dot products are therefore
    cosine similarities, which are multiplied by a learned scale to form logits.

    Args:
        img_backbone: Headless CNN wrapped in ``ImageEncoder``; the encoder's projection
            expects 2048-channel feature maps (e.g. ResNet-50 without avgpool/fc).
        embed_dims: Size of the shared embedding space.
        logit_scale_init_value: Initial temperature; the learned log-scale starts at
            ``log(1 / logit_scale_init_value)``.
        device: Forwarded to ``TextEncoder``.
    """

    # Cap for the learned log-scale: 4.6052 ~= ln(100), so the logit scale stays in [1, 100].
    MAX_LOG_LOGIT_SCALE = 4.6052

    def __init__(self, img_backbone, embed_dims, logit_scale_init_value=0.1, device='cpu'):
        super().__init__()
        self.text_model = TextEncoder(embed_dims, device)
        self.img_model = ImageEncoder(img_backbone, embed_dims)
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/logit_scale_init_value)))
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten(start_dim=1)
        self.criterion = nn.CrossEntropyLoss()

    def embed_text(self, text):
        """Return L2-normalised text embeddings of shape ``(N, D)``."""
        text_emb = self.text_model(text)
        return text_emb / text_emb.norm(dim=-1, keepdim=True)

    def embed_image(self, image, pool=False):
        """Return L2-normalised image embeddings.

        The shape is ``(B, D, H', W')``, with one embedding per spatial location. When
        ``pool`` is True, global average pooling gives ``(B, D)``.
        """
        img_emb = self.img_model(image)  # B, D, H', W'
        if pool:
            img_emb = self.flatten(self.gap(img_emb))  # B, D
        # The embedding axis is dim 1 in both layouts. dim=-1 would normalise the
        # spatial W axis of un-pooled (B, D, H', W') maps instead of D.
        return img_emb / img_emb.norm(dim=1, keepdim=True)

    def compute_logits(self, text_emb, img_emb):
        """Scaled cosine-similarity logits between images and texts.

        Args:
            text_emb: ``(N, D)`` normalised text embeddings.
            img_emb: ``(B, D)`` pooled or ``(B, D, H, W)`` spatial normalised image embeddings.

        Returns:
            ``(logits_per_text, logits_per_image)``. These are ``(N, B)`` and ``(B, N)`` for
            pooled input, and ``(H, W, N, B)`` and ``(H, W, B, N)`` for spatial input.

        Side effect: clamps the learned log-scale in place to ``[0, MAX_LOG_LOGIT_SCALE]``.
        """
        # Clamp via .data so the operation is not recorded by autograd.
        self.logit_scale.data.clamp_(0, self.MAX_LOG_LOGIT_SCALE)
        logit_scale = self.logit_scale.exp()
        if img_emb.dim() == 4:
            # (B, D, H, W) -> (H, W, B, D) @ (D, N) -> (H, W, B, N)
            logits_per_image = logit_scale * torch.matmul(img_emb.permute(2, 3, 0, 1), text_emb.t())
            logits_per_text = logits_per_image.permute(0, 1, 3, 2)  # H, W, N, B
        else:
            logits_per_image = logit_scale * torch.matmul(img_emb, text_emb.t())
            logits_per_text = logits_per_image.t()
        return logits_per_text, logits_per_image

    def forward(self, text, img, pool=False):
        """Embed ``text`` and ``img``; returns ``(text_emb, img_emb)``, both normalised."""
        text_emb = self.embed_text(text)
        img_emb = self.embed_image(img, pool)
        return text_emb, img_emb

    def contrastive_logit_loss(self, logits_per_text, logits_per_image, labels):
        """Symmetric image-label contrastive loss.

        This is similar to a classification loss, but it is computed from the similarity
        logits in both directions (images over labels, labels over images) and then averaged.
        Only pooled (2-D) logits are supported; for 4-D logits CrossEntropyLoss would treat
        dim 1 as the class axis.

        Args:
            logits_per_text: ``(N, B)`` logits.
            logits_per_image: ``(B, N)`` logits.
            labels: ``(B, N)`` indicator matrix with 1 where image b has class n.
        """
        # CrossEntropyLoss reads integer targets as class indices of shape (B,), which a
        # (B, N) multi-hot matrix is not. As floats they are used as per-class soft targets.
        # NOTE(review): rows with several positives are not normalised to sum to 1, so the
        # loss grows with the number of positives per row -- confirm this is intended.
        targets = labels.to(logits_per_image.dtype)
        itl = self.criterion(logits_per_image, targets)
        til = self.criterion(logits_per_text, targets.t())
        return (itl + til) / 2

    def loss(self, text_emb, img_emb, labels):
        """Contrastive loss from embeddings.

        Args:
            text_emb: ``(N, D)``; row i is the text embedding of the i-th class.
            img_emb: ``(B, D)`` pooled image embeddings.
            labels: ``(B, N)`` indicator matrix with 1 for each class an image belongs to.
        """
        logits_per_text, logits_per_image = self.compute_logits(text_emb, img_emb)
        return self.contrastive_logit_loss(logits_per_text, logits_per_image, labels)

In [25]:
import copy
from utils.metrics import AverageMeter, calculate_auc, multilabel_accuracy

class Trainer:
    """Train and evaluate an image/text embedding model against a fixed set of class labels.

    Every batch of images is scored against all of ``class_labels``. The model is called as
    ``model(class_labels, images, pool=True)``, and its ``loss``, ``compute_logits`` and
    ``contrastive_logit_loss`` methods are used.

    Attributes:
        model: The model being trained.
        best_model: Deep copy of ``model`` from the epoch with the best validation accuracy,
            or None until ``run_train`` has completed at least one epoch.
    """

    def __init__(self, model, class_labels, device='cpu'):
        self.model = model
        self.device = device
        self.class_labels = class_labels
        # Defined up front so it exists even if training is interrupted before an epoch ends.
        self.best_model = None

    def run_train(self, epochs, dataloader, val_dataloader, lr=1e-4):
        """Train with Adam and keep the model with the best validation accuracy.

        Args:
            epochs: Number of passes over ``dataloader``.
            dataloader: Yields ``(images, class_inds)`` training batches.
            val_dataloader: Evaluated with ``run_eval`` after every epoch.
            lr: Adam learning rate.
        """
        model = self.model.to(self.device)
        best_epoch = None
        best_acc = None
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        for epoch in range(epochs):
            model.train()
            loss_meter = AverageMeter()
            for i, (images, class_inds) in enumerate(dataloader):
                images, class_inds = images.to(self.device), class_inds.to(self.device)
                optimizer.zero_grad()

                text_embeddings, image_embeddings = model(self.class_labels, images, pool=True)
                loss = model.loss(text_embeddings, image_embeddings, class_inds)
                loss.backward()
                optimizer.step()

                loss_meter.update(loss.item(), len(class_inds))
                # Running average over the epoch so far, not this batch's loss alone.
                print(f"Batch {i+1}: loss {loss_meter.average()}")

            print(f"Epoch {epoch+1}: Training loss {loss_meter.average()}")
            val_acc, val_auc, val_loss = self.run_eval(model, val_dataloader)
            print(f"Epoch {epoch+1}: Validation loss {val_loss} | Accuracy {val_acc} | AUC {val_auc}")

            if best_acc is None or val_acc > best_acc:
                best_acc = val_acc
                self.best_model = copy.deepcopy(model)
                best_epoch = epoch
        if best_epoch is None:
            print('Best epoch: none (no epochs were run)')
        else:
            print('Best epoch: ', best_epoch+1)

    def run_eval(self, model, dataloader):
        """Evaluate ``model`` on ``dataloader`` without gradient tracking.

        Returns:
            ``(accuracy, auc, loss)``, each a batch-size-weighted average of per-batch values.

        Raises:
            ValueError: if ``model`` is None (e.g. ``best_model`` before any epoch finished).
        """
        if model is None:
            raise ValueError(
                "model is None; if this is trainer.best_model, run_train has not completed an epoch yet"
            )
        model = model.to(self.device)
        model.eval()

        loss_meter = AverageMeter()
        auc_meter = AverageMeter()
        acc_meter = AverageMeter()
        with torch.no_grad():
            for images, class_inds in dataloader:
                images, class_inds = images.to(self.device), class_inds.to(self.device)
                text_embeddings, image_embeddings = model(self.class_labels, images, pool=True)

                logits_per_text, logits_per_image = model.compute_logits(text_embeddings, image_embeddings)

                loss = model.contrastive_logit_loss(logits_per_text, logits_per_image, class_inds)
                loss_meter.update(loss.item(), len(class_inds))

                # NOTE(review): assuming calculate_auc / multilabel_accuracy score one batch,
                # the reported values are averages of per-batch metrics, which can differ from
                # the same metrics computed over the whole dataset (especially for AUC).
                auc = calculate_auc(logits_per_image, class_inds)
                auc_meter.update(auc, len(class_inds))

                acc = multilabel_accuracy(logits_per_image, class_inds)
                acc_meter.update(acc, len(class_inds))
        return acc_meter.average(), auc_meter.average(), loss_meter.average()

In [22]:
import pickle
import pandas as pd
from torch.utils.data import DataLoader

from utils.labels import VINDR_CXR_LABELS, VINDR_SPLIT
from models.backbone.datasets import MEAN_STDS
from utils.device import get_device
from models.embedding.dataset import Dataset

# with open('data/vindr_train_query_set.pkl', 'rb') as fp:
#     cxr_train_query = pickle.load(fp)

# query_image_ids = []
# for ids in cxr_train_query.values():
#     query_image_ids.extend(ids)

def get_query_and_support_ids(img_info, split_file):
    """Split training images into a query set and a disjoint support set.

    Args:
        img_info: DataFrame with at least ``image_id`` and ``meta_split`` columns.
        split_file: Pickle holding a dict whose values are lists of query image ids
            (presumably one list per class -- confirm against the file's producer).

    Returns:
        ``(query_image_ids, support_image_ids)``.
        - ``query_image_ids``: ids from the pickle with duplicates removed, keeping the
          first occurrence.
        - ``support_image_ids``: every ``meta_split == 'train'`` id in ``img_info`` that is
          not a query id, in ``img_info`` row order.
    """
    # NOTE: pickle.load can execute arbitrary code -- only open split files you created.
    with open(split_file, 'rb') as fp:
        query_sets = pickle.load(fp)
    # An id listed under several keys (e.g. a multi-label image sampled for two classes)
    # would otherwise appear several times in the query dataset and be double-counted.
    query_image_ids = list(dict.fromkeys(
        image_id for ids in query_sets.values() for image_id in ids
    ))
    is_support = (img_info['meta_split'] == 'train') & ~img_info['image_id'].isin(query_image_ids)
    support_image_ids = img_info.loc[is_support, 'image_id'].to_list()
    return query_image_ids, support_image_ids

# Seed torch's global RNG. It drives the random init of the projection layers and the
# DataLoader shuffling order (no explicit generator is passed), so Restart & Run All is reproducible.
SEED = 0
torch.manual_seed(SEED)

LABELS_PATH = 'data/vindr_cxr_split_labels.pkl'
QUERY_SPLIT_PATH = 'data/vindr_train_query_set.pkl'
IMG_PATH = 'datasets/vindr-cxr-png'
BACKBONE_PATH = 'models/backbone/pretrained/medclip_resnet50.pkl'
batch_size = 10*14  # 140 images per batch; presumably 10 per class x 14 classes -- TODO confirm
PROJ_SIZE = 512
# NOTE(review): VinDr-CXR images are normalised with the ChestMNIST mean/std -- confirm
# these stats match the preprocessing the image backbone was trained with.
VINDR_MEAN_STD = MEAN_STDS['chestmnist']

img_info = pd.read_pickle(LABELS_PATH)
query_image_ids, support_image_ids = get_query_and_support_ids(img_info, QUERY_SPLIT_PATH)

# The query set is used for validation and the disjoint support set for training.
query_dataset = Dataset(IMG_PATH, img_info, query_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=VINDR_MEAN_STD)
query_loader = DataLoader(dataset=query_dataset, batch_size=batch_size, shuffle=True)
support_dataset = Dataset(IMG_PATH, img_info, support_image_ids, VINDR_CXR_LABELS, VINDR_SPLIT['train'], mean_std=VINDR_MEAN_STD)
support_loader = DataLoader(dataset=support_dataset, batch_size=batch_size, shuffle=True)

device = get_device()
# Alternative image backbone: the supervised 14-class CXR ResNet at
# 'models/backbone/pretrained/cxr_backbone_bal.pkl' (load with load_pretrained_resnet(1, 14, ...)).
backbone = load_medclip_retrained_resnet(BACKBONE_PATH)
# The image backbone is fine-tuned; call model.img_model.set_backbone_trainable(False) to freeze it.
model = ImageTextEmbedding(backbone, PROJ_SIZE, device=device)
mtrainer = Trainer(model, support_dataset.class_labels(), device)


Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
# Baseline for comparison: the same ImageTextEmbedding, but built on a freshly constructed
# (not checkpoint-loaded) 1-channel ResNet-50 instead of the MedCLIP backbone.
bb = torchvision.models.resnet50()
bb.conv1 = torch.nn.Conv2d(
    in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
)
model_baseline = ImageTextEmbedding(img_backbone=resnet_backbone(bb), embed_dims=PROJ_SIZE, device=device)

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [23]:
# Evaluate on the query (validation) set before any contrastive training; returns (accuracy, AUC, loss).
mtrainer.run_eval(model=mtrainer.model, dataloader=query_loader)

(65.24489822387696, 0.46013339092210115, 127.54764556884766)

In [24]:
# Train on the support set for 10 epochs, validating on the query set after each epoch.
mtrainer.run_train(epochs=10, dataloader=support_loader, val_dataloader=query_loader)

Batch 1: loss 33.69576644897461
Batch 2: loss 35.70352745056152
Batch 3: loss 36.39114507039388
Batch 4: loss 37.13524913787842
Batch 5: loss 37.716374206542966
Batch 6: loss 37.22264162699381
Batch 7: loss 36.68990707397461
Batch 8: loss 36.2630820274353
Batch 9: loss 35.733789655897354
Batch 10: loss 35.663933181762694
Batch 11: loss 35.68782286210494
Batch 12: loss 35.78476587931315
Batch 13: loss 35.8788701570951
Batch 14: loss 35.878697531563894
Batch 15: loss 35.65770772298177
Batch 16: loss 35.810307025909424
Batch 17: loss 35.901333304012525
Batch 18: loss 36.104226430257164
Batch 19: loss 35.97945122969778
Batch 20: loss 35.83519229888916
Batch 21: loss 35.72511182512556
Batch 22: loss 35.65146879716353
Batch 23: loss 35.60612769748854
Batch 24: loss 35.58307123184204
Batch 25: loss 35.528944702148436
Batch 26: loss 35.56490604694073
Batch 27: loss 35.5359984503852
Batch 28: loss 35.60614095415388
Batch 29: loss 35.529578504891234
Batch 30: loss 35.58520647684733
Batch 31: los

KeyboardInterrupt: 

In [None]:
# best_model is only set once run_train completes an epoch. If training was interrupted
# earlier, fail with a clear message instead of an AttributeError.
best_model = getattr(mtrainer, 'best_model', None)
if best_model is None:
    raise RuntimeError('mtrainer has no best_model yet: let run_train complete at least one epoch.')
mtrainer.run_eval(best_model, query_loader)