In [17]:
from tqdm import tqdm

import torch
from datasets import load_dataset, DatasetDict
from transformers import CLIPProcessor, CLIPModel, AutoModel, AutoTokenizer,  AutoImageProcessor, CLIPImageProcessor

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [4]:
# Load VQA-RAD and hold out 12.5% of the official train split (fixed seed) as validation.
dataset = load_dataset("flaviagiammarino/vqa-rad")
train_val_dataset = dataset["train"].train_test_split(test_size=0.125, seed=123)

# The official test split is reused as-is; "val" is the held-out part of train.
train_val_test_dataset = DatasetDict({
    'train': train_val_dataset['train'],
    'val': train_val_dataset['test'],
    'test': dataset['test'],
})


def _is_close_ended(example):
    """Return True when the example's answer is "yes" or "no", ignoring case."""
    return example["answer"].lower() in ("yes", "no")


# Keep only the yes/no questions in every split.
close_ended_train_val_test_dataset = train_val_test_dataset.filter(_is_close_ended)
close_ended_train_val_test_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 821
    })
    val: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 119
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 251
    })
})

In [6]:
# Peek at the first ten close-ended training questions, one per line.
preview_questions = (
    close_ended_train_val_test_dataset["train"][row]["question"] for row in range(10)
)
print(*preview_questions, sep="\n")

is there evidence of large calcified lesions in the lung fields?
is there evidence of midlight shift of structures on this mri?
are the colon walls thickened?
is there cardiac enlargement?
is there a pneumothorax?
is there a mass demonstrated?
is the jejunal wall enlarged?
is there an aortic aneurysm?
is the liver normal?
are the sulci visible in this image?


In [7]:
# The model, processor and tokenizer are all loaded from the same pretrained checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CLIP_CHECKPOINT)

  return self.fget.__get__(instance, owner)()


In [8]:
# Take the first training example as a sample; fetch the row once and unpack its fields.
first_example = close_ended_train_val_test_dataset["train"][0]
question = first_example["question"]
answer = first_example["answer"]
image = first_example["image"]

In [9]:
def template(question):
    """Return two candidate captions for `question`: the "Yes." one first, the "No." one second."""
    return [question + suffix for suffix in (" Yes.", " No.")]

template(question)

['is there evidence of large calcified lesions in the lung fields? Yes.',
 'is there evidence of large calcified lesions in the lung fields? No.']

In [15]:
# Score the sample image against the two candidate captions.
# template() puts the "Yes." caption at index 0 and the "No." caption at index 1.
inputs = processor(text=template(question), images=image, return_tensors="pt", padding=True)
# Keep the inputs on whatever device the model lives on. This is a no-op while both
# are on the CPU, and avoids a device-mismatch error if the model is moved to the GPU.
inputs = inputs.to(model.device)

with torch.no_grad():
    outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # image-text similarity scores
probs = logits_per_image.softmax(dim=1)  # softmax over dim 1 gives the label probabilities
pred_id = probs.argmax(dim=1).item()  # a single image was passed, so exactly one index comes back
prediction = "yes" if pred_id == 0 else "no"
print(answer, prediction)

no no


In [19]:
# Sanity check: every remaining answer must be exactly the lowercase string "yes" or "no".
# Only the "answer" column is read, so no row (and therefore no image) has to be built
# just to validate a text field.
for split in close_ended_train_val_test_dataset:
    print(split)
    # A dedicated loop variable keeps the sample `answer` from the earlier cell intact.
    for split_answer in tqdm(close_ended_train_val_test_dataset[split]["answer"]):
        assert split_answer in ("yes", "no"), (
            f"unexpected answer {split_answer!r} in split {split!r}"
        )

train


100%|██████████████████████████████████████████████████████████| 821/821 [00:12<00:00, 63.20it/s]


val


100%|█████████████████████████████████████████████████████████| 119/119 [00:00<00:00, 156.42it/s]


test


100%|█████████████████████████████████████████████████████████| 251/251 [00:01<00:00, 190.84it/s]


## Zero-shot inference on the whole dataset

In [22]:
def zero_shot_eval(dataset, template):
    """Compute the zero-shot yes/no accuracy of the global CLIP `model` on `dataset`.

    Args:
        dataset: Indexable collection with a length whose items are mappings with
            "question", "answer" and "image" entries.
        template: Callable mapping a question string to a list of two candidate
            captions, the "yes" caption first and the "no" caption second.

    Returns:
        Fraction of examples whose answer (compared case-insensitively) equals
        the prediction, or NaN when `dataset` is empty.
    """
    total = len(dataset)
    if total == 0:
        # Accuracy is undefined without examples; return NaN instead of dividing by zero.
        return float("nan")

    correct = 0
    for i in tqdm(range(total)):
        # Fetch the row once: each dataset[i] lookup builds the whole example again
        # (for Hugging Face datasets that includes decoding the image).
        example = dataset[i]

        inputs = processor(
            text=template(example["question"]),
            images=example["image"],
            return_tensors="pt",
            padding=True,
        )
        # Keep the inputs on the model's device (a no-op while both are on the CPU).
        inputs = inputs.to(model.device)

        with torch.no_grad():
            outputs = model(**inputs)
        # logits_per_image holds the image-text similarity scores. Softmax is monotonic,
        # so taking the argmax of the raw logits selects the same caption.
        pred_id = outputs.logits_per_image.argmax(dim=1).item()
        prediction = "yes" if pred_id == 0 else "no"
        # Lowercase the reference answer so a "Yes"/"No" label still counts as a match.
        if example["answer"].lower() == prediction:
            correct += 1
    return correct / total

In [23]:
def template(question):
    """Candidate captions: the question followed by a bare "Yes." or "No."."""
    return [question + suffix for suffix in (" Yes.", " No.")]

zero_shot_eval(close_ended_train_val_test_dataset["val"], template)

100%|██████████████████████████████████████████████████████████| 119/119 [00:18<00:00,  6.38it/s]


0.48739495798319327

In [24]:
def template(question):
    """Candidate captions in a "Q: ... A: ..." layout, "Yes." first and "No." second."""
    yes_caption = f"Q: {question} A: Yes."
    no_caption = f"Q: {question} A: No."
    return [yes_caption, no_caption]

zero_shot_eval(close_ended_train_val_test_dataset["val"], template)

100%|██████████████████████████████████████████████████████████| 119/119 [00:19<00:00,  6.15it/s]


0.4369747899159664

In [25]:
def template(question):
    """Candidate captions phrased as a full sentence, "Yes" first and "No" second."""
    yes_caption = f"The answer of the question: {question} is Yes"
    no_caption = f"The answer of the question: {question} is No"
    return [yes_caption, no_caption]

zero_shot_eval(close_ended_train_val_test_dataset["val"], template)

100%|██████████████████████████████████████████████████████████| 119/119 [00:19<00:00,  6.25it/s]


0.5714285714285714

In [26]:
def template(question):
    """Candidate captions phrased as a full sentence, "Yes" first and "No" second."""
    yes_caption = f"The answer of the question: {question} is Yes"
    no_caption = f"The answer of the question: {question} is No"
    return [yes_caption, no_caption]

# Same prompt as the validation run above, now scored on the test split.
zero_shot_eval(close_ended_train_val_test_dataset["test"], template)

100%|██████████████████████████████████████████████████████████| 251/251 [00:40<00:00,  6.15it/s]


0.50199203187251

## Zero-shot inference on a general-domain image

In [35]:
def eval_one_example(text, image):
    """Return CLIP's probabilities over the candidate captions in `text` for `image`.

    Uses the global `processor` and `model`.

    Args:
        text: Candidate caption(s) handed to the processor.
        image: Image handed to the processor.

    Returns:
        `logits_per_image` with a softmax applied over dim 1, so each row sums to 1.
    """
    inputs = processor(text=text, images=image, return_tensors="pt", padding=True)
    # Keep the inputs on the model's device. This is a no-op while both are on the CPU,
    # and avoids a device-mismatch error if the model is moved to the GPU.
    inputs = inputs.to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity scores
    return logits_per_image.softmax(dim=1)

In [32]:
from PIL import Image 
 
# Open the cat picture; the relative path is resolved against the working directory.
CAT_IMAGE_PATH = "cat.png"
image = Image.open(CAT_IMAGE_PATH)

In [36]:
question = "Is this a cat?"


def template(question):
    """Candidate captions phrased as a full sentence, "Yes" first and "No" second."""
    yes_caption = f"The answer of the question: {question} is Yes"
    no_caption = f"The answer of the question: {question} is No"
    return [yes_caption, no_caption]


text = template(question)
eval_one_example(text, image)

tensor([[0.5688, 0.4312]])

In [37]:
# Two competing class captions: cat versus dog.
text = [
    "This is a cat.",
    "This is a dog.",
]
eval_one_example(text=text, image=image)

tensor([[0.9855, 0.0145]])

In [38]:
# A caption against its own negation.
text = [
    "This is a cat.",
    "This is not a cat.",
]
eval_one_example(text=text, image=image)

tensor([[0.5849, 0.4151]])