In [1]:
from mgca.datasets.transforms import DataTransforms
from datasets.dataset import MultimodalPretrainingDatasetForAdaptor, multimodal_collator

import torch 
import torch.nn as nn
from torch.utils.data import DataLoader

from typing import List, Union, Tuple, Dict, Optional

import numpy as np

from transformers import AutoTokenizer
from transformers import BertModel, AutoModel, ViTImageProcessor

from models.adaptor import Adaptor
from utils.utils import load_timm_model, freeze_encoder

from utils.dataset_utils import ae_image_processor, timm_image_processor

import skimage

from transformers import TrainingArguments, Trainer

In [2]:
# --- Experiment configuration -------------------------------------------
seed = 1117
batch_size = 48
num_workers = 16
data_pct = 0.01   # presumably the fraction of each split to load -- TODO confirm against the dataset class
crop_size = 224

# Actually apply the seed: previously `seed` was declared but never handed to
# any RNG, so runs were not reproducible. Both modules are imported in the
# first cell of this notebook.
torch.manual_seed(seed)
np.random.seed(seed)

# --- Datasets -----------------------------------------------------------
# The first positional argument of DataTransforms is True for the training
# split and False for validation (looks like a train/augmentation switch --
# confirm against mgca.datasets.transforms).
train_dataset = MultimodalPretrainingDatasetForAdaptor(
    split='train',
    transform=DataTransforms(True, crop_size),
    data_pct=data_pct,
)

val_dataset = MultimodalPretrainingDatasetForAdaptor(
    split='valid',
    transform=DataTransforms(False, crop_size),
    data_pct=data_pct,
)

232421it [03:12, 1209.60it/s]
232421it [02:13, 1746.07it/s] 


In [3]:
# Hugging Face text checkpoints kept here for reference; this list is not
# read by any of the code in this cell.
text_pretrained_available = [
    "bert-base-uncased",
    "dmis-lab/biobert-v1.1",
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "microsoft/BiomedVLP-CXR-BERT-general",
]

### Load vision model
# vision_model = xrv.autoencoders.ResNetAE(weights="101-elastic")
vision_model = load_timm_model('swin_base_patch4_window7_224', pretrained=True, retain_head=False)

### Load text model
# text_pretrained = "microsoft/BiomedVLP-CXR-BERT-general"
text_pretrained = "./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint"
text_model = BertModel.from_pretrained(text_pretrained)

# Tokenizer is loaded from the same checkpoint path as the text model.
tokenizer = AutoTokenizer.from_pretrained(text_pretrained)

# Build the processor a single time. The previous lambda constructed a new
# ViTImageProcessor on every call, which is wasted work once this is used
# per batch.
_vit_image_processor = ViTImageProcessor()


def image_processor(x):
    """Run the shared default-configured ViTImageProcessor on ``x``.

    Args:
        x: Image or batch of images accepted by ``ViTImageProcessor``.

    Returns:
        Whatever the processor returns when called with
        ``return_tensors="pt"`` and ``return_dict=True`` (same keyword
        arguments as the original lambda).
    """
    return _vit_image_processor(x, return_tensors="pt", return_dict=True)


### Load sample input
img_path = 'sample.jpeg'
img = skimage.io.imread(img_path)
# Duplicate the single sample image to form a batch of two.
imgs = np.stack([img, img])
vision_inputs = image_processor(imgs)
text_inputs = tokenizer(
    text=["Nodule", "Lung Lesion"],
    return_tensors="pt", padding=True,
)

Some weights of the model checkpoint at ./weights/ClinicalBERT_checkpoint/ClinicalBERT_pretraining_pytorch_checkpoint were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
from utils.utils import (
    load_timm_model, freeze_encoder, 
    get_text_embeds_raw, get_image_embeds_raw,
    get_dataloader,
)

In [5]:
# Validation loader: batches of 8, assembled by the project's multimodal collator.
val_dataloader = get_dataloader(
    val_dataset, batch_size=8, num_workers=num_workers, collate_fn=multimodal_collator
)

# Image-embedding pass is currently disabled; kept for reference only.
# get_image_embeds_raw(
#     val_dataloader,
#     vision_model=vision_model,
#     vision_model_type='timm',
#     save_path='./saved_embeddings/image_embeds',
#     split='valid',
# )

# Text-embedding pass over the validation loader (presumably writes its
# results under `save_path` -- confirm in utils.utils).
get_text_embeds_raw(
    val_dataloader, text_model=text_model, save_path='./saved_embeddings/text_embeds', split='valid'
)


  3%|▎         | 17/619 [00:17<10:03,  1.00s/it]


KeyboardInterrupt: 