In [1]:
import sys
import torch
import torch.nn.functional as F
import argparse
import random
from tqdm import tqdm

sys.path.append('/pasteur/u/yuhuiz/archive/neurips_modality_gap/pull_figure/convirt/')
from data.pretrain_loader import PretrainDataset

In [2]:
# Load model
from medclip import MedCLIPModel, MedCLIPVisionModelViT

# Build MedCLIP with the MedCLIPVisionModelViT vision encoder and restore the
# pretrained weights from the local checkpoint directory.
model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
model.from_pretrained(input_dir='/pasteur/u/esui/data/medclip_pretrained')

# This notebook only extracts embeddings, so put every sub-module in inference
# mode explicitly: train-time layers such as Dropout must not be active, or the
# saved embeddings would not be deterministic.
model.eval()
model.to('cuda')

# vision_model = model.vision_model
# text_model = model.text_model

  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|██████████| 71.8k/71.8k [00:00<00:00, 1.15MB/s]
Downloading: 100%|██████████| 113M/113M [00:01<00:00, 109MB/s]
Some weights of the model checkpoint at microsoft/swin-tiny-patch4-window7-224 were not used when initializing SwinModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SwinModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SwinModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Downloading: 100%|██████████| 385/385 [00:00<00:00, 297kB/s]
Downloading: 100%|██████████| 436M/436M [00:03<00:00, 128MB/s]
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not 

load model weight from: /pasteur/u/esui/data/medclip_pretrained


MedCLIPModel(
  (vision_model): MedCLIPVisionModelViT(
    (model): SwinModel(
      (embeddings): SwinEmbeddings(
        (patch_embeddings): SwinPatchEmbeddings(
          (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
        )
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): SwinEncoder(
        (layers): ModuleList(
          (0): SwinStage(
            (blocks): ModuleList(
              (0-1): 2 x SwinLayer(
                (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                (attention): SwinAttention(
                  (self): SwinSelfAttention(
                    (query): Linear(in_features=96, out_features=96, bias=True)
                    (key): Linear(in_features=96, out_features=96, bias=True)
                    (value): Linear(in_features=96, out_features=96, bias=True)
                    (dropout): Dropout(p=0.0, inplace=F

In [3]:
# Root of the resized MIMIC-CXR-JPG data; used to derive the default paths below.
base_dir = "/u/scr/zyh/develop/data-open/mimic-cxr-jpg-resized"

def parse_args(argv=None):
    """Build the option namespace for this notebook.

    Args:
        argv: Optional list of command-line style tokens used to override
            defaults, e.g. ``['--imsize', '256']``. When ``None`` (the default)
            an empty list is parsed, so every option takes its default value.
            ``sys.argv`` is deliberately never consulted: inside a notebook
            kernel it holds the kernel launcher's own flags.

    Returns:
        argparse.Namespace holding all options. The notebook converts it to a
        dict with ``vars(...)`` and hands that dict to ``PretrainDataset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, help="Data directory with train and valid indexed report files.", default=None)
    parser.add_argument('--meta_file', type=str, default=f'{base_dir}/meta.json', help="Dataset meta file.")
    parser.add_argument('--img_dir', type=str, default=f'{base_dir}/files/', help="Directory to load image data from.")
    parser.add_argument('--local_img_dir', type=str, default=f'{base_dir}/files/', help="Directory to load image data from.")
    parser.add_argument('--image_encoder', type=str, default='resnet50', help="Name of the model architecture.")
    parser.add_argument('--bert_name', type=str, default='emilyalsentzer/Bio_ClinicalBERT', help="Name of the pretrained BERT model.")
    parser.add_argument('--imsize', type=int, default=224, help="Size of image.")
    parser.add_argument('--augment_p', type=float, default=0., help="Probability for image augmentation.")
    parser.add_argument('--dropout', type=float, default=0.2, help="Dropout rate.")
    # store_false with dest=freeze_text_encoder: the encoder is frozen unless the flag is given.
    parser.add_argument('--finetune_text_encoder', dest='freeze_text_encoder', action='store_false', help="Whether to finetune text encoder.")

    parser.add_argument('--num_clf_layer', type=int, default=2, help="Number of layers to use for NN classifier.")
    parser.add_argument('--clf_hidden_dim', type=int, default=512, help="Number of hidden dims for NN classifier.")
    parser.add_argument('--pool', choices=['cls', 'mean', 'max'], default='mean', help="Type of pooling to use for text encoder.")

    # store_false with dest=amp: mixed precision stays on unless --fp is given.
    parser.add_argument('--fp', dest='amp', action='store_false', help="Use full precision training; by default use mixed precision.")
    parser.add_argument('--rih', action='store_true', help="Train on the RIH data; use corresponding data loaders.")

    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_epoch', type=int, default=200)
    parser.add_argument('--steps_per_epoch', type=int, default=5000)
    parser.add_argument('--optim', type=str, default='adam', help='sgd, adagrad, adam, adamw or adamax.')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=1e-6)
    parser.add_argument('--patience', type=int, default=5)
    parser.add_argument('--annealing_factor', type=float, default=0.5)
    parser.add_argument('--log_interval', type=int, default=100)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--save_dir', type=str, default=None, help="Directory to save the trained model; if None will use id to look up")
    parser.add_argument('--root_dir', type=str, default='saved_models/pretrain', help="Root directory for model saving.")
    parser.add_argument('--id', type=int, default=0, help="An id of the training run")
    parser.add_argument('--seed', type=int, default=1234)

    # argparse expects a list of tokens; an explicit empty list yields pure defaults.
    args = parser.parse_args([] if argv is None else argv)
    return args

args = parse_args()
opt = vars(args)

In [4]:
# Use the tokenizer attached to the MedCLIP text encoder
# (alternative: AutoTokenizer.from_pretrained(opt['bert_name'])).
tokenizer = model.text_model.tokenizer

# Image/report pairs read from the indexed report file under base_dir.
# Image size and augmentation probability come from the parsed options.
dataset = PretrainDataset(
    opt=opt,
    tokenizer=tokenizer,
    indexed_file=f"{base_dir}/reports_indexed.json",
    meta_file=f"{base_dir}/meta.json",
    img_dir=f"{base_dir}/files/",
    imsize=opt['imsize'],
    augment_p=opt['augment_p'],
    evaluation=False,
)

In [None]:
# Randomly sample up to NUM_SAMPLES dataset indices, seeded so the same
# subset is drawn on every run.
SAMPLE_SEED = 1234
NUM_SAMPLES = 10000

random.seed(SAMPLE_SEED)
# Cap at the dataset size: random.sample raises ValueError when asked for more
# items than the population contains.
idxs = random.sample(range(len(dataset)), min(NUM_SAMPLES, len(dataset)))
all_img_v, all_text_v = [], []

# One record per sampled pair: raw image, decoded report text and both embeddings.
data = []

with torch.no_grad():
    for idx in tqdm(idxs):
        image, text_ids = dataset[idx]

        # as_tensor accepts tensors, arrays and lists alike and avoids the
        # copy-construct UserWarning that torch.tensor() emits on a tensor input.
        image = torch.as_tensor(image).unsqueeze(0).cuda()
        text_ids = torch.as_tensor(text_ids).unsqueeze(0).cuda()
        # All-ones mask: assumes text_ids carries no padding tokens -- TODO confirm
        # against PretrainDataset.
        text_attention_mask = torch.ones(len(text_ids[0])).unsqueeze(0).cuda()

        img_v = model.vision_model(image)
        text_v = model.text_model(text_ids, text_attention_mask)

        # img_v = model.image_proj(img_v) # batch_size, dim
        # text_v = model.text_proj(text_v) # batch_size, dim

        # normalize for cosine similarity
        img_v = F.normalize(img_v, dim=1)
        text_v = F.normalize(text_v, dim=1)

        all_img_v.append(img_v)
        all_text_v.append(text_v)
        data.append({
            'x': image.cpu().numpy(),
            'y': tokenizer.decode(text_ids.squeeze(), skip_special_tokens=True),
            'x_embed': img_v.squeeze().cpu().numpy(),
            'y_embed': text_v.squeeze().cpu().numpy()
        })

  image = torch.tensor(image).unsqueeze(0).cuda()
100%|██████████| 10000/10000 [25:23<00:00,  6.56it/s]


In [10]:
import pickle

# Persist the collected records (images, decoded texts and both embeddings)
# as a single pickle file.
with open('/pasteur/u/esui/data/c3/data_medclip_no_aug_10k.pkl', 'wb') as pkl_file:
    pickle.dump(data, pkl_file)