In [1]:
import torch
from timm.models import create_model
import utils
from PIL import Image
from transformers import XLMRobertaTokenizer
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
import torchvision
import modeling


def xlm_tokenizer(tokens, tokenizer, max_len=100):
    """Encode a text string into fixed-length token ids plus a padding mask.

    Args:
        tokens: the raw text to encode (the caller in this notebook passes a
            prompt string, despite the parameter name).
        tokenizer: object providing ``encode`` and the ``bos_token_id``,
            ``eos_token_id`` and ``pad_token_id`` attributes (here an
            ``XLMRobertaTokenizer``).
        max_len: total output length, BOS and EOS included.

    Returns:
        ``(text_tokens, padding_mask)``: two lists of length ``max_len``.
        ``padding_mask`` is 0 at real-token positions and 1 at padding positions.

    Raises:
        ValueError: if ``max_len < 2``; BOS and EOS alone would not fit and the
            outputs would silently exceed ``max_len``.
    """
    if max_len < 2:
        raise ValueError(f"max_len must be at least 2 to hold BOS and EOS, got {max_len}")

    # encode() is assumed to wrap the ids in BOS/EOS; strip them so truncation
    # only ever removes content tokens, then re-add them around the kept body.
    body = tokenizer.encode(tokens)[1:-1][:max_len - 2]

    text_tokens = [tokenizer.bos_token_id, *body, tokenizer.eos_token_id]
    num_tokens = len(text_tokens)
    num_pad = max_len - num_tokens
    padding_mask = [0] * num_tokens + [1] * num_pad
    text_tokens += [tokenizer.pad_token_id] * num_pad
    return text_tokens, padding_mask


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# NOTE(review): hardcoded to the first CUDA GPU; the fp16 cast below assumes a GPU is available.
device = torch.device("cuda:0")

# >>>>>>>>>>>> load model >>>>>>>>>>>> #
# Build the MUSK architecture, then load weights from the local checkpoint into it.
model_config = "musk_large_patch16_384"
model_path = "./models/musk.pth"
# vocab_size presumably has to match the checkpoint's text embedding table — confirm.
model = create_model(model_config, vocab_size=64010)
model.eval()
utils.load_model_and_may_interpolate(model_path, model, 'model|module', '')
# Module.to / Module.eval return the module itself, so this rebinds the same object:
# move to the GPU in half precision and keep it in inference mode.
model = model.to(device, dtype=torch.float16).eval()
# <<<<<<<<<<<< load model <<<<<<<<<<<< #

# >>>>>>>>>>>> process image >>>>>>>>>>> #
# load an image and process it
img_size = 384
transform = torchvision.transforms.Compose([
    # Resize the shorter side to img_size with bicubic resampling (the enum
    # equivalent of the legacy integer code 3), then take a square center crop.
    torchvision.transforms.Resize(
        img_size,
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
        antialias=True,
    ),
    torchvision.transforms.CenterCrop((img_size, img_size)),
    torchvision.transforms.ToTensor(),
    # Inception-style normalisation constants imported from timm
    torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
])

# Context manager closes the file handle once the pixels have been copied
# into the in-memory RGB image.
with Image.open('./assets/lungaca1014.jpeg') as src_img:
    img = src_img.convert("RGB")  # input image
img_tensor = transform(img).unsqueeze(0)  # add a leading batch dimension
with torch.inference_mode():
    # forward returns (vision_cls, text_cls); keep the image embedding
    image_embeddings = model(
        image=img_tensor.to(device, dtype=torch.float16),
        with_head=True,
        out_norm=True,
    )[0]
# <<<<<<<<<<< process image <<<<<<<<<<< #

# >>>>>>>>>>> process language >>>>>>>>> #
# load tokenizer for language input
tokenizer = XLMRobertaTokenizer("./models/tokenizer.spm")
labels = ["lung adenocarcinoma",
          "benign lung tissue",
          "lung squamous cell carcinoma"]

# Zero-shot prompts: one text per candidate class, in the same order as `labels`.
texts = ['histopathology image of ' + item for item in labels]
TEXT_MAX_LEN = 64  # fixed token length of every prompt, BOS/EOS included

# Tokenize every prompt, then build the (num_texts, TEXT_MAX_LEN) id and
# padding-mask batches with a single tensor constructor each.
encoded = [xlm_tokenizer(txt, tokenizer, max_len=TEXT_MAX_LEN) for txt in texts]
text_ids = torch.tensor([ids for ids, _ in encoded])
paddings = torch.tensor([pad for _, pad in encoded])

with torch.inference_mode():
    # forward returns (vision_cls, text_cls); keep the text embeddings
    text_embeddings = model(
        text_description=text_ids.to(device),
        padding_mask=paddings.to(device),
        with_head=True,
        out_norm=True,
    )[1]
# <<<<<<<<<<<< process language <<<<<<<<<<< #

# >>>>>>>>>>>>> calculate similarity >>>>>>> #
with torch.inference_mode():
    # Zero-shot classification: scale the image embedding, take dot products with
    # every text embedding, and softmax over the candidate labels.
    # expected prob:[0.3784, 0.3250, 0.2964]  --> lung adenocarcinoma
    # NOTE(review): CLIP/BEiT-3-style models often store logit_scale as a log value
    # and apply logit_scale.exp(); confirm against `modeling`. The expected
    # probabilities above come from using the raw parameter, as done here.
    scaled_image = model.logit_scale * image_embeddings
    sim = scaled_image @ text_embeddings.T
    prob = torch.softmax(sim, dim=-1)
    print(prob)

Load ckpt from ./models/musk.pth
tensor([[0.3784, 0.3250, 0.2964]], device='cuda:0', dtype=torch.float16)
