In [1]:
import torch 
import torch.nn as nn
from typing import List, Union, Tuple, Dict, Optional

import numpy as np

from transformers import AutoTokenizer
from transformers import BertModel, AutoModel, ViTImageProcessor

import torchxrayvision as xrv
from adaptor import Adaptor
from utils import load_timm_model, freeze_encoder

from image_processor import ae_image_processor, timm_image_processor

import skimage

from transformers import TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Candidate Hugging Face text checkpoints. The list is informational only:
# nothing else in this cell reads it.
text_pretrained_available = [
    "bert-base-uncased",
    "dmis-lab/biobert-v1.1",
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "microsoft/BiomedVLP-CXR-BERT-general",
]

### Load vision model
# vision_model = xrv.autoencoders.ResNetAE(weights="101-elastic")
vision_model = load_timm_model('swin_base_patch4_window7_224', pretrained=True, retain_head=False)

### Load text model
# text_pretrained = "microsoft/BiomedVLP-CXR-BERT-general"
text_pretrained = "./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint"
text_model = BertModel.from_pretrained(text_pretrained)

# The tokenizer is loaded from the same checkpoint as the text model.
tokenizer = AutoTokenizer.from_pretrained(text_pretrained)

### Image preprocessing
# Build the processor once and reuse it for every call; previously a new
# ViTImageProcessor was constructed each time `image_processor` was invoked.
_vit_image_processor = ViTImageProcessor()


def image_processor(images):
    """Preprocess ``images`` with a default-configured ViTImageProcessor.

    Args:
        images: Image or batch of images accepted by ``ViTImageProcessor``.

    Returns:
        The processor output with PyTorch tensors (``return_tensors="pt"``).
    """
    # `return_dict` is not a ViTImageProcessor preprocessing option, so it is
    # no longer passed here.
    return _vit_image_processor(images, return_tensors="pt")


### Load sample input
img_path = 'sample.jpeg'
img = skimage.io.imread(img_path)
# Duplicate the single sample image to form a batch of two.
imgs = np.stack([img, img])

### Define model
model = Adaptor(
    text_model=text_model,
    vision_model=vision_model,
    vision_model_type='timm',
    vision_output_dim=1024,
    projection_dim=768,
)

### Obtain inputs
# The vision inputs are computed once here (they used to be computed twice
# with identical arguments).
vision_inputs = image_processor(imgs)
text_inputs = tokenizer(
    text=["Nodule", "Lung Lesion"],
    return_tensors="pt", padding=True,
)
# Merge both modalities into one keyword-argument dict for the model.
inputs = {**vision_inputs, **text_inputs}

### Forward to get output
outputs = model(**inputs, return_dict=True, return_loss=True)

Some weights of the model checkpoint at ./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Training hyper-parameters, collected in one mapping so they are easy to scan
# and tweak before being handed to TrainingArguments.
training_options = {
    "output_dir": "./results",
    "per_device_eval_batch_size": 16,
    "num_train_epochs": 1,
    "save_strategy": "epoch",
    "learning_rate": 2e-5,
    "seed": 1117,
    "push_to_hub": False,
}
arguments = TrainingArguments(**training_options)

# NOTE(review): no `train_dataset` (or `eval_dataset`) is passed to the Trainer
# here, so it is constructed with the model and arguments only.
trainer = Trainer(args=arguments, model=model)

In [4]:
# NOTE(review): `trainer` is constructed in the previous cell without a
# `train_dataset`, so this call raises the ValueError recorded in the output
# below. Pass a `train_dataset` to `Trainer(...)` before running this cell.
trainer.train()


ValueError: Trainer: training requires a train_dataset.