In [1]:
import torch 
import torch.nn as nn
from typing import List, Union, Tuple, Dict, Optional

import numpy as np

from transformers import AutoTokenizer
from transformers import BertModel, AutoModel, ViTImageProcessor

import torchxrayvision as xrv
from fusion import Fusion
from image_processor import ae_image_processor, timm_image_processor
from utils import load_timm_model

import skimage

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Text encoders this notebook is set up to try; `text_pretrained` below must be one of them.
text_pretrained_available = [
    "bert-base-uncased",
    "dmis-lab/biobert-v1.1",
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "microsoft/BiomedVLP-CXR-BERT-general",
]

# Vision backbone: timm Swin-B at 224px, loaded with retain_head=False
# (presumably drops the classification head so it acts as a feature extractor).
# Alternative backbones previously tried, kept for reference:
#   - HF ViT: "google/vit-base-patch16-224"
#   - torchxrayvision autoencoder: xrv.autoencoders.ResNetAE(weights="101-elastic")
VISION_BACKBONE = "swin_base_patch4_window7_224"
VISION_OUTPUT_DIM = 1024  # feature width of swin_base (embed_dim 128 doubled over 3 merges -> 1024)
PROJECTION_DIM = 768      # presumably the shared text/image embedding size -- confirm against Fusion

vision_model = load_timm_model(VISION_BACKBONE, pretrained=True, retain_head=False)
# timm models expose `num_features`; if present, catch a backbone/dim mismatch here
# instead of as an opaque shape error inside Fusion.
if hasattr(vision_model, "num_features"):
    assert vision_model.num_features == VISION_OUTPUT_DIM, (
        f"{VISION_BACKBONE} outputs {vision_model.num_features} features, "
        f"but VISION_OUTPUT_DIM={VISION_OUTPUT_DIM}"
    )

text_pretrained = "microsoft/BiomedVLP-CXR-BERT-general"
assert text_pretrained in text_pretrained_available, f"unsupported text model: {text_pretrained}"
text_model = BertModel.from_pretrained(text_pretrained)

model = Fusion(
    text_model=text_model,
    vision_model=vision_model,
    vision_model_type='timm',
    vision_output_dim=VISION_OUTPUT_DIM,
    projection_dim=PROJECTION_DIM,
)

Some weights of the model checkpoint at microsoft/BiomedVLP-CXR-BERT-general were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
tokenizer = AutoTokenizer.from_pretrained(text_pretrained)

# Build the image processor once and reuse it; the previous lambda created a new
# ViTImageProcessor on every call. `return_dict` is not an image-processor option
# and was dropped -- the call already returns a dict-like BatchFeature.
# NOTE(review): a default-constructed ViTImageProcessor is believed to normalise with
# mean/std 0.5, whereas timm Swin checkpoints are usually trained with ImageNet
# mean/std. The imported `timm_image_processor` may be the intended preprocessing
# for a timm backbone -- confirm.
_vit_processor = ViTImageProcessor()


def image_processor(x):
    """Preprocess an image or a batch of images into PyTorch tensors.

    Args:
        x: image(s) accepted by ViTImageProcessor, e.g. an (N, H, W, 3) numpy batch.

    Returns:
        The BatchFeature produced by the processor, with tensors in PyTorch format.
    """
    return _vit_processor(x, return_tensors="pt")


img_path = 'sample.jpeg'
img = skimage.io.imread(img_path)
# X-rays are often stored as single-channel images, so imread returns a 2-D (H, W)
# array. Stacking two of those gives (2, H, W), which the processor cannot treat as
# a batch of colour images. Replicate grayscale to 3 channels, and drop any alpha channel.
if img.ndim == 2:
    img = np.stack([img] * 3, axis=-1)
elif img.ndim == 3 and img.shape[-1] == 4:
    img = img[..., :3]
imgs = np.stack([img, img])  # (2, H, W, 3): the same image twice, as a batch of 2
vision_inputs = image_processor(imgs)

In [4]:
# Candidate finding labels, one per image in `vision_inputs`.
texts = ["Nodule", "Lung Lesion"]
text_inputs = tokenizer(
    text=texts,
    return_tensors="pt",
    padding=True,
)

# Merging the two dicts below keeps only the last value for a duplicated key, so make
# sure the image and text features are disjoint before they are passed to the model.
assert vision_inputs.keys().isdisjoint(text_inputs.keys()), "vision/text input keys overlap"
# NOTE(review): return_loss=True presumably computes a loss over paired (image, text)
# examples, which needs equal batch sizes -- confirm against Fusion.
assert vision_inputs["pixel_values"].shape[0] == len(texts), "need exactly one text per image"

inputs = {**vision_inputs, **text_inputs}
outputs = model(**inputs, return_dict=True, return_loss=True)