In [1]:
import yaml
import torch
from ALBEF.models.model_pretrain import ALBEF
from ALBEF.models.tokenization_bert import BertTokenizer
from torchvision import transforms
import torch.nn.functional as F
from pathlib import Path
from PIL import Image

In [2]:
# Load the pretraining config and build the ALBEF model skeleton.
# Use a context manager so the YAML file handle is closed after parsing.
with open("configs/Pretrain.yaml") as config_file:
    config = yaml.safe_load(config_file)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# init_deit=False: presumably skips downloading ImageNet DeiT weights, since the
# fine-tuned checkpoint is loaded in the next cell — confirm in model_pretrain.py.
model = ALBEF(config=config, text_encoder="bert-base-uncased", tokenizer=tokenizer, init_deit=False)

In [None]:
# Load checkpoint weights into the model and move it to the GPU in eval mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt = torch.load("output_mimic/checkpoint_29.pth", map_location="cpu")
# strict=False tolerates key mismatches; report them so a wrong or partially
# matching checkpoint does not silently leave layers randomly initialized.
load_result = model.load_state_dict(ckpt["model"], strict=False)
print(f"Missing keys ({len(load_result.missing_keys)}):", load_result.missing_keys[:10])
print(f"Unexpected keys ({len(load_result.unexpected_keys)}):", load_result.unexpected_keys[:10])
model.cuda()
model.eval();  # trailing ';' suppresses the (very long) module repr in the notebook

In [None]:
# One test image from VinDr, preprocessed into a (1, 3, image_res, image_res) CUDA batch.
img_path = Path("Test") / "002a34c58c5b758217ed1f584ccbcfe9.png"
# Close the file handle deterministically; convert() returns an independent image
# that remains valid after the source file is closed.
# NOTE(review): if the PNG is 16-bit grayscale, convert("RGB") clips values above 255 —
# confirm the export bit depth matches what the model was trained on.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB")

# Mean/std constants; presumably identical to the training-time normalization — confirm.
normalize = transforms.Normalize(
    (0.48145466, 0.4578275, 0.40821073),
    (0.26862954, 0.26130258, 0.27577711),
)
# ALBEF's reference transforms resize with bicubic interpolation, whereas torchvision's
# default is bilinear; keep this in sync with the transform used to train the checkpoint.
transform = transforms.Compose([
    transforms.Resize(
        (config["image_res"], config["image_res"]),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    normalize,
])
img_t = transform(img).unsqueeze(0).cuda()  # (1, 3, image_res, image_res)

In [None]:
# Candidate text prompts to score against the image.
prompts = [
    "No finding",
    "Pneumothorax",
    "Pleural Effusion",
]
# The tokenizer returns a BatchEncoding, which supports .to(device) but has no .cuda().
text_inputs = tokenizer(
    prompts, padding=True, truncation=True, max_length=25, return_tensors="pt"
).to("cuda")

In [None]:
def get_image_text_features(model, images, text_inputs):
    """Compute L2-normalized contrastive embeddings for an image batch and a text batch.

    Follows the unimodal path of ALBEF's contrastive objective: take the [CLS]
    token of each encoder, project it, and L2-normalize it.

    Args:
        model: ALBEF model exposing ``visual_encoder``, ``vision_proj``,
            ``text_encoder`` (with a ``.bert`` submodule) and ``text_proj``.
        images: normalized image batch, presumably (B, 3, H, W) — TODO confirm.
        text_inputs: tokenizer output providing ``input_ids`` and ``attention_mask``.

    Returns:
        Tuple ``(image_feat, text_feat)`` of unit-norm projections, one row per
        image and one row per text respectively.
    """
    image_embeds = model.visual_encoder(images)
    image_feat = F.normalize(model.vision_proj(image_embeds[:, 0, :]), dim=-1)

    # In ALBEF's model_pretrain, text_encoder is presumably the masked-LM wrapper:
    # its output has no last_hidden_state and its default mode expects image
    # cross-attention inputs. Run only the text layers of the underlying BERT,
    # as ALBEF's own contrastive forward pass does.
    text_output = model.text_encoder.bert(
        input_ids=text_inputs.input_ids,
        attention_mask=text_inputs.attention_mask,
        return_dict=True,
        mode="text",
    )
    text_feat = F.normalize(model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

    return image_feat, text_feat

In [None]:
# Score each prompt against the image: cosine similarity plus a softmax over prompts.
with torch.no_grad():
    image_feat, text_feat = get_image_text_features(model, img_t, text_inputs)
    sims = image_feat @ text_feat.t()  # (1, num_prompts) cosine similarities
    # ALBEF's contrastive head divides similarities by its learned temperature
    # before the softmax; presumably exposed as model.temp — confirm in model_pretrain.py.
    probs = F.softmax(sims / model.temp, dim=-1)

# Index the single image row instead of squeeze(), so one prompt still yields a list.
sim_scores = sims[0].cpu().tolist()
prob_scores = probs[0].cpu().tolist()
print("Similarity scores:", sim_scores)
for prompt, score, prob in zip(prompts, sim_scores, prob_scores):
    print(f"  {prompt:<20s} cos={score:+.4f}  p={prob:.3f}")