In [1]:
import json
from argparse import Namespace
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
import torch
import numpy as np
from datasets import load_dataset, DatasetDict
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop

In [2]:
# Device string handed to PMC-CLIP via ``args.device`` (also read by the VQA wrapper's constructor).
device = "cpu"

In [10]:
# Automatically reload edited project modules (e.g. Text_Enhanced_MedCLIP) before each cell runs.
%load_ext autoreload
%autoreload 2

## Check the provided checkpoint
Reference: https://huggingface.co/datasets/axiong/pmc_oa_beta/blob/main/checkpoint.pt

In [3]:
# Pretrained PMC-CLIP checkpoint (downloaded from the reference link).
checkpoint_path = "./checkpoints/pmc_clip/checkpoint.pt"
# map_location="cpu" loads every tensor onto the CPU regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

In [4]:
# Model weights live under the "state_dict" key; their names carry a "module." prefix stripped before loading.
state_dict = checkpoint["state_dict"]

In [None]:
# additional key in checkpoint
# module.text_encoder.embeddings.position_ids => delete this key, value

In [11]:
# state_dict["module.text_encoder.embeddings.position_ids"]

In [12]:
# for key in state_dict.keys():
#     print(key)

## Initialize model

In [5]:
# PMC-CLIP hyper-parameters; RN50_fusion4.json is the config matching the provided checkpoint.
config_path = "Text_Enhanced_MedCLIP/pmc_clip/model_configs/RN50_fusion4.json"
# Use a context manager so the config file handle is closed instead of leaked.
with open(config_path) as config_file:
    model_config = json.load(config_file)
model_config

{'embed_dim': 768,
 'clip_model': 'PMC_CLIP',
 'vision_cfg': {'image_size': 224,
  'layers': [3, 4, 6, 3],
  'width': 64,
  'patch_size': None},
 'text_cfg': {'context_length': 77,
  'vocab_size': 30522,
  'width': 768,
  'heads': 8,
  'layers': 12,
  'fusion_layers': 4,
  'bert_model_name': 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'}}

In [6]:
from Text_Enhanced_MedCLIP.pmc_clip.model import PMC_CLIP

In [7]:
# Namespace passed to PMC_CLIP as ``args``: BERT model name from the config, target device, MLM flag.
args = dict(bert_model_name=model_config['text_cfg']['bert_model_name'],
            device=device,
            mlm=True)
args = Namespace(**args)
model_config["args"] = args
# PMC_CLIP is constructed without the "clip_model" entry. The default keeps this cell re-runnable:
# after the first run the key is already gone, and a bare pop() would raise KeyError.
model_config.pop("clip_model", None)

'PMC_CLIP'

In [12]:
# Build PMC-CLIP from the config (including the args namespace); pretrained weights are loaded afterwards.
model = PMC_CLIP(**model_config)

Init3


In [9]:
# model.state_dict().keys()

In [13]:
# Every checkpoint key starts with "module." (presumably saved from a DataParallel/DDP wrapper); strip it.
sd = dict((key[len('module.'):], value) for key, value in state_dict.items())
# The checkpoint carries one extra buffer the freshly built model does not define; drop it when present.
sd.pop("text_encoder.embeddings.position_ids", None)
model.load_state_dict(sd)

<All keys matched successfully>

In [14]:
# sd["text_encoder.embeddings.position_ids"] # tensor arrange 0-511

## Data

In [14]:
# VQA-RAD from the Hugging Face Hub: train/test splits with image, question and answer columns.
dataset = load_dataset("flaviagiammarino/vqa-rad")
dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 1793
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 451
    })
})

### Data preprocessing

In [46]:
# Carve a validation split (12.5% of the original train split, fixed seed 123) out of the train split;
# the original test split is kept unchanged.
train_val_dataset = dataset["train"].train_test_split(test_size=0.125, seed=123)
train_val_test_dataset = DatasetDict({'train': train_val_dataset['train'],
                                        'val': train_val_dataset['test'],
                                        'test': dataset['test']})

# binary task: keep only examples whose answer is "yes" or "no" (case-insensitive)
train_val_test_dataset = train_val_test_dataset.filter(lambda example: example["answer"].lower() in ("yes", "no"))
train_val_test_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 821
    })
    val: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 119
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 251
    })
})

In [47]:
def preprocess(batch):
    """
        Add model inputs and targets to a batch produced by ``Dataset.map(batched=True)``.

        Adds:
            labels: 1 when the answer is "yes" (case-insensitive), otherwise 0
            bert_input / bert_label: copies of the raw question strings
        The same dict is mutated and returned.
    """
    answers = batch["answer"]
    questions = batch["question"]

    # binary target only; multi-class labels are not implemented
    batch["labels"] = list(map(lambda answer: int(answer.lower() == "yes"), answers))

    # PMC-CLIP tokenizes text inside its forward call, so raw strings are passed through
    batch["bert_input"] = list(questions)
    batch["bert_label"] = list(questions)

    return batch

In [48]:
# Add labels / bert_input / bert_label, then drop the raw text columns they were derived from.
processed_dataset = train_val_test_dataset.map(preprocess, batched=True)
processed_dataset = processed_dataset.remove_columns(["question", "answer"])
# processed_dataset = processed_dataset.rename_column("image", "images")
processed_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'labels', 'bert_input', 'bert_label'],
        num_rows: 821
    })
    val: Dataset({
        features: ['image', 'labels', 'bert_input', 'bert_label'],
        num_rows: 119
    })
    test: Dataset({
        features: ['image', 'labels', 'bert_input', 'bert_label'],
        num_rows: 251
    })
})

### Image transforms

In [49]:
# Input resolution expected by the PMC-CLIP vision tower (224 in RN50_fusion4.json).
image_size = model.visual.image_size
crop_scale = 0.9 # follow pmc-clip pre-training
mean = (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
std = (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std

def _convert_to_rgb(image):
    """Convert an image to 3-channel RGB (assumes a PIL image) so ToTensor/Normalize see three channels."""
    return image.convert('RGB')

# Training: random resized crop covering crop_scale..100% of the image area, then normalize.
train_image_transform =  Compose([
                                RandomResizedCrop(image_size, scale=(crop_scale, 1.0), interpolation=InterpolationMode.BICUBIC),
                                _convert_to_rgb,
                                ToTensor(),
                                Normalize(mean=mean, std=std),
                            ])
# Evaluation: deterministic resize + center crop, then the same normalization.
test_image_transform = Compose([
                            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
                            CenterCrop(image_size),
                            _convert_to_rgb,
                            ToTensor(),
                            Normalize(mean=mean, std=std)
                        ])

In [54]:
def train_transform(batch):
    batch['image'] = [train_image_transform(img) for img in batch['image']]
    return batch
def test_transform(batch):
    batch['image'] = [test_image_transform(img) for img in batch['image']]

In [55]:
# Transforms are applied on access: the train split gets random augmentation, val/test the deterministic pipeline.
processed_dataset['train'].set_transform(train_transform)
processed_dataset['val'].set_transform(test_transform)
processed_dataset['test'].set_transform(test_transform)

In [56]:
# Sanity check on two transformed training examples.
# Slicing returns each column as a list, so 'image' is a list of tensors and gets printed in full, not as a shape.
batch = processed_dataset['train'][:2]
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)
    else:
        print(k, v)

image [tensor([[[-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
         [-1.7923, -1.7923, -1.7923,  ..., -1.7777, -1.7923, -1.7923],
         [-1.7923, -1.7923, -1.7923,  ..., -1.4857, -1.7923, -1.7923],
         ...,
         [-1.7923, -1.7923, -1.7923,  ..., -1.7923, -1.7923, -1.7923],
         [-1.7777, -1.7777, -1.7777,  ..., -1.7777, -1.7777, -1.7777],
         [-1.7777, -1.7777, -1.7777,  ..., -1.7777, -1.7777, -1.7777]],

        [[-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
         [-1.7521, -1.7521, -1.7521,  ..., -1.7371, -1.7521, -1.7521],
         [-1.7521, -1.7521, -1.7521,  ..., -1.4369, -1.7521, -1.7521],
         ...,
         [-1.7521, -1.7521, -1.7521,  ..., -1.7521, -1.7521, -1.7521],
         [-1.7371, -1.7371, -1.7371,  ..., -1.7371, -1.7371, -1.7371],
         [-1.7371, -1.7371, -1.7371,  ..., -1.7371, -1.7371, -1.7371]],

        [[-1.4802, -1.4802, -1.4802,  ..., -1.4802, -1.4802, -1.4802],
         [-1.4802, -1.4802, -1.4802,  

In [53]:
# batch key
# images, bert_input, bert_label

### Data Collator

In [27]:
def torch_images_and_label_data_collator(features: List[Any]) -> Dict[str, Any]:
    """
        Collate images and label into tensors,
        leave bert_input and bert_label as list of strings,
        which will be tokenized and collated in PMC-CLIP.

        Tensors are stacked, numpy arrays are stacked then converted, other values go through
        torch.tensor; text columns and keys whose first value is None stay Python lists.
        An empty feature list yields an empty dict.
    """
    # A fresh dict per call; the previous version wrote into whatever global `batch` existed.
    batch: Dict[str, Any] = {}
    if not features:
        return batch

    # Dispatch on the type of the first example's value for each key.
    for key, first_value in features[0].items():
        values = [f[key] for f in features]
        if key in ("bert_input", "bert_label") or first_value is None:
            batch[key] = values
        elif isinstance(first_value, torch.Tensor):
            batch[key] = torch.stack(values)
        elif isinstance(first_value, np.ndarray):
            batch[key] = torch.tensor(np.stack(values))
        else:
            batch[key] = torch.tensor(values)
    return batch

In [81]:
from torch.utils.data import DataLoader

# Pull one collated batch to check shapes: image -> (4, 3, 224, 224), labels -> (4,),
# while bert_input / bert_label stay lists of strings for PMC-CLIP to tokenize.
data_loader = DataLoader(processed_dataset['train'], collate_fn=torch_images_and_label_data_collator, batch_size=4)
data_iter = iter(data_loader)
batch = next(data_iter)
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)
    else:
        print(k, v)

image torch.Size([4, 3, 224, 224])
labels torch.Size([4])
bert_input ['is there evidence of large calcified lesions in the lung fields?', 'is there evidence of midlight shift of structures on this mri?', 'are the colon walls thickened?', 'is there cardiac enlargement?']
bert_label ['is there evidence of large calcified lesions in the lung fields?', 'is there evidence of midlight shift of structures on this mri?', 'are the colon walls thickened?', 'is there cardiac enlargement?']


In [40]:
# satisfy pmc-clip input format
with torch.no_grad():
    output = model(batch)
# output

last_token_index: tensor([[0, 0],
        [1, 0],
        [2, 0],
        [3, 0]])


In [45]:
# model

In [69]:
# Token id the VQA wrapper's "cls" pooling searches for in input_ids (2 for this tokenizer).
model.tokenizer.cls_token_id

2

## PMC-CLIP for VQA

In [99]:
from argparse import Namespace
from typing import Any, Optional, Tuple, Union
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from transformers.modeling_outputs import ImageClassifierOutput

from Text_Enhanced_MedCLIP.pmc_clip.model import PMC_CLIP

In [100]:
class PMC_CLIPforVQA(nn.Module):
    """
        Apply a fully connected network on the pooled_output of the fusion module
        Predict a scalar score for binary/multi-class classification
        
        Inputs:
            checkpoint_path: download checkpoint from https://huggingface.co/datasets/axiong/pmc_oa_beta/blob/main/checkpoint.pt
            config_path: RN50_fusion4.json for the provided checkpoint, download configs from PMC-CLIP repo
            text_model_path
            num_labels
            pool_type: "average" or "cls"
            
    """
    def __init__(self, checkpoint_path, config_path, text_model_path=None, num_labels=2, pool_type="average"):
        super().__init__()
        
        model_config = json.load(open(config_path))
        args = dict(bert_model_name=model_config['text_cfg']['bert_model_name'],
                    device=device,
                    mlm=True)
        args = Namespace(**args)
        model_config["args"] = args
        model_config.pop("clip_model")
        self.base_model = PMC_CLIP(**model_config)
        checkpoint = torch.load(checkpoint_path)
        state_dict = checkpoint["state_dict"]
        sd = {k[len('module.'):]: v for k, v in state_dict.items()}
        if "text_encoder.embeddings.position_ids" in sd:
            del sd["text_encoder.embeddings.position_ids"]
        self.base_model.load_state_dict(sd)
        self.cls_id = self.base_model.tokenizer.cls_token_id
        
        if text_model_path:
            print("Load text model weight from:", text_model_path)
            text_model_dict = torch.load(text_model_path)
            text_model_dict = {k: v for k, v in text_model_dict.items() if k.startswith("text_model")}
            base_model_dict = self.base_model.state_dict()
            base_model_dict.update(text_model_dict)
            self.base_model.load_state_dict(base_model_dict)
        
        self.num_labels = num_labels
        projection_dim = self.base_model.transformer_width # 768
        output_dim = 1 if num_labels == 2 else num_labels # scalar output for binary classification
        self.MLP = nn.Sequential(nn.Linear(projection_dim, 512),
                                 nn.ReLU(),
                                 nn.Linear(512, 128),
                                 nn.ReLU(),
                                 nn.Linear(128, output_dim))
        
        self.pool_type = pool_type
        
    def forward(self,
                bert_input: Optional[List[str]] = None,
                bert_label: Optional[List[str]] = None,
                image: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:

        image_features = self.base_model.encode_image(image)
        image_features = F.normalize(image_features['image_features'], dim=-1)  # (bs, 768)

        batch = dict(bert_input=bert_input, bert_label=bert_label)
        text_output = self.base_model.encode_text(batch, image_features)

        fusion_features = text_output["fusion_features"] # (bs, 79, 768)

        if self.pool_type == "average":
            pooled_feature = fusion_features.mean(dim=1) # (bs, 768)
        elif self.pool_type == "cls":
            last_token_index = torch.nonzero((text_output["encoded_input"]['input_ids'] == self.cls_id).squeeze())
            pooled_feature = fusion_features[torch.arange(fusion_features.shape[0]), last_token_index[:, 1]] # the 0-index of each row

        print("pooled_feature:", pooled_feature.shape)
        logits = self.MLP(pooled_feature).squeeze() # (N,) or (N, C)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.num_labels == 2:
                # binary classification
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits.squeeze(), labels.squeeze().float())
            else:
                # multi-class classification
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
        )
        

In [106]:
checkpoint_path = "./checkpoints/pmc_clip/checkpoint.pt"
config_path = "Text_Enhanced_MedCLIP/pmc_clip/model_configs/RN50_fusion4.json"
# NOTE(review): the dataset is filtered to yes/no answers (labels 0/1), yet this builds a 5-way head;
# num_labels=2 matches the binary task — 5 presumably just exercises the multi-class path.
# NOTE(review): this rebinds `model`, so earlier cells that read model.visual / model.tokenizer
# no longer see the plain PMC_CLIP instance if re-run afterwards.
model = PMC_CLIPforVQA(checkpoint_path, config_path, pool_type="cls", num_labels=5)

In [102]:
# batch

In [107]:
# Forward the collated batch; its keys match forward's keyword arguments (bert_input, bert_label, image, labels).
model(**batch)

pooled_feature: torch.Size([4, 768])


ImageClassifierOutput(loss=tensor(2.0616, grad_fn=<NllLossBackward0>), logits=tensor([[ 0.0969, -0.3604, -0.2607,  0.5581, -0.5485],
        [ 0.3941, -0.2613,  0.0454,  0.6127, -0.7037],
        [ 0.6128, -0.4206, -0.1830,  0.7622, -0.0140],
        [ 0.3079, -0.8749, -0.1913,  0.5312, -0.4160]],
       grad_fn=<SqueezeBackward0>), hidden_states=None, attentions=None)