In [17]:
def args_to_str(args:str, var:dict={'SAVED_EMBEDDINGS_DIR':"/vol/bitbucket/jq619/individual-project/saved_embeddings"}):
    """Turn a shell command line into a quoted, comma-separated argument list.

    Every ``$NAME`` placeholder whose ``NAME`` is a key of ``var`` is replaced
    by the corresponding value, the first two whitespace-separated tokens
    (e.g. ``python3 ./script.py``) are dropped, and each remaining token is
    wrapped in double quotes. The quoted tokens are joined with ``", \\n"`` so
    the result can be pasted as the body of a list of strings.

    Args:
        args: The full command line, starting with the interpreter and script.
        var: Mapping of placeholder name (without the ``$``) to its string
            replacement. The mapping is only read, never modified.

    Returns:
        The quoted arguments, one per line; an empty string when the command
        has no tokens after the first two.
    """
    # Substitute the longest names first so that ``$FOO`` can never clobber
    # the leading part of ``$FOO_BAR``.
    for name in sorted(var, key=len, reverse=True):
        args = args.replace('$' + name, var[name])
    # split() without a separator collapses runs of whitespace and ignores
    # leading/trailing whitespace, so no empty "" arguments are produced and
    # the two-token offset stays correct.
    tokens = args.split()[2:]
    return ", \n".join(f'"{token}"' for token in tokens)

In [19]:
# Training command as it would be typed in a shell, one option per line.
# args_to_str drops the leading interpreter and script tokens and prints the
# remaining options as a quoted, comma-separated list.
args = (
    'python3 ./pretrain.py'
    ' --vision_model_type ae'
    ' --vision_pretrained 101-elastic'
    ' --text_pretrained dmis-lab/biobert-v1.1'
    ' --batch_size 128'
    ' --data_pct 0.01'
    ' --num_hidden_layers 1'
    ' --num_of_batches 100'
    ' --num_train_epochs 10'
    ' --seed 42'
    ' --text_embeds_raw_dir $SAVED_EMBEDDINGS_DIR/text_embeds/BioBERT'
    ' --image_embeds_raw_dir $SAVED_EMBEDDINGS_DIR/image_embeds/ResNetAE'
    ' --output_dir ./results/BioBERT_ResNetAE'
)
print(args_to_str(args))

"--vision_model_type", 
"ae", 
"--vision_pretrained", 
"101-elastic", 
"--text_pretrained", 
"dmis-lab/biobert-v1.1", 
"--batch_size", 
"128", 
"--data_pct", 
"0.01", 
"--num_hidden_layers", 
"1", 
"--num_of_batches", 
"100", 
"--num_train_epochs", 
"10", 
"--seed", 
"42", 
"--text_embeds_raw_dir", 
"/vol/bitbucket/jq619/individual-project/saved_embeddings/text_embeds/BioBERT", 
"--image_embeds_raw_dir", 
"/vol/bitbucket/jq619/individual-project/saved_embeddings/image_embeds/ResNetAE", 
"--output_dir", 
"./results/BioBERT_ResNetAE"


In [1]:
import torch 
import torch.nn as nn
from typing import List, Union, Tuple, Dict, Optional

import numpy as np

from transformers import AutoTokenizer
from transformers import BertModel, AutoModel, ViTImageProcessor

import torchxrayvision as xrv
from adaptor import Adaptor
from image_processor import ae_image_processor, timm_image_processor
from utils import load_timm_model

import skimage

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face model ids of candidate text encoders. The list is not read by
# any of the visible cells, and the checkpoint actually loaded below is a
# local path that is not one of these entries.
text_pretrained_available = [
    "bert-base-uncased", 
    "dmis-lab/biobert-v1.1", 
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", 
    "microsoft/BiomedVLP-CXR-BERT-general", 
]

# Alternative vision backbone (torchxrayvision ResNet autoencoder), disabled:
# vision_model = xrv.autoencoders.ResNetAE(weights="101-elastic")

# Vision backbone built by the project helper utils.load_timm_model.
# NOTE(review): retain_head=False presumably strips the classification head
# so the model returns features -- confirm in utils.load_timm_model.
vision_model = load_timm_model('swin_base_patch4_window7_224', pretrained=True, retain_head=False)

# Alternative text encoder (Hugging Face hub id), disabled:
# text_pretrained = "microsoft/BiomedVLP-CXR-BERT-general"
# Relative path: resolved against the notebook's working directory.
text_pretrained = "./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint"
text_model = BertModel.from_pretrained(text_pretrained)

# Combine both encoders in the project's Adaptor model.
# NOTE(review): vision_output_dim=1024 presumably matches the Swin backbone's
# feature width and projection_dim=768 the BERT hidden size -- TODO confirm
# against adaptor.Adaptor.
model = Adaptor(
    text_model=text_model,
    vision_model=vision_model,
    vision_model_type='timm', 
    vision_output_dim=1024,
    projection_dim=768,
)

Some weights of the model checkpoint at ./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Tokenizer matching the text encoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(text_pretrained)

# Build the processor once instead of re-instantiating it on every call.
# NOTE(review): this uses the default ViTImageProcessor settings although the
# vision backbone is a timm Swin model and timm_image_processor is imported
# but unused -- confirm the resize/normalization matches what the backbone
# expects.
_vit_image_processor = ViTImageProcessor()


def image_processor(x):
    """Preprocess image(s) ``x`` with the shared ViTImageProcessor.

    Args:
        x: Image or batch of images accepted by ViTImageProcessor.

    Returns:
        The processor output with PyTorch tensors (``return_tensors="pt"``).
    """
    return _vit_image_processor(x, return_tensors="pt", return_dict=True)


img_path = 'sample.jpeg'
img = skimage.io.imread(img_path)
# Stack the same image twice to form a batch of two.
imgs = np.stack([img, img])
vision_inputs = image_processor(imgs)

In [4]:
# Tokenize the two text prompts as one padded batch of PyTorch tensors.
text_inputs = tokenizer(
    text=["Nodule", "Lung Lesion"],
    padding=True,
    return_tensors="pt",
)

# Merge both modalities into one keyword dict; on a key clash the text
# entries overwrite the vision entries.
inputs = dict(vision_inputs)
inputs.update(text_inputs)

# Single forward pass that also computes the loss.
outputs = model(**inputs, return_dict=True, return_loss=True)

In [5]:
# Show the complete model output (loss, logits and embeddings) as the cell
# result. This prints full tensors, so the rendered output is long.
outputs

CLIPOutput(loss=tensor(0.6970, grad_fn=<DivBackward0>), logits_per_image=tensor([[5.6564, 5.7795],
        [5.4444, 5.5670]], grad_fn=<PermuteBackward0>), logits_per_text=tensor([[5.6564, 5.4444],
        [5.7795, 5.5670]], grad_fn=<MulBackward0>), text_embeds=tensor([[[-0.0491,  0.0511,  0.0189,  ...,  0.0209, -0.0117, -0.0364],
         [-0.0354,  0.0553,  0.0241,  ...,  0.0145, -0.0324, -0.0109],
         [-0.0534,  0.0547,  0.0104,  ...,  0.0183, -0.0111, -0.0345],
         [-0.0395,  0.0689,  0.0027,  ..., -0.0001, -0.0232, -0.0713],
         [-0.0561,  0.0393,  0.0314,  ...,  0.0209, -0.0105, -0.0294]],

        [[-0.0546,  0.0490,  0.0094,  ...,  0.0255, -0.0133, -0.0398],
         [-0.0497,  0.0555,  0.0149,  ...,  0.0104, -0.0076, -0.0357],
         [-0.0549,  0.0562,  0.0238,  ...,  0.0292, -0.0048, -0.0395],
         [-0.0454,  0.0566,  0.0048,  ...,  0.0393, -0.0093, -0.0353],
         [-0.0412,  0.0719, -0.0031,  ...,  0.0033, -0.0262, -0.0612]]],
       grad_fn=<DivBackwa