### Code to fine-tune the vision encoder for LLaVA-Med

Reference: Hugging Face image processor documentation — https://huggingface.co/docs/transformers/en/main_classes/image_processor

In [1]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

In [40]:
import pandas as pd
import os
import time

In [3]:
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


[2025-03-04 16:31:14,802] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/r11kaijun/anaconda3/envs/llava-med/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/r11kaijun/anaconda3/envs/llava-med/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status
  _torch_pytree._register_pytree_node(


### Dataset and Dataloaders

In [4]:
class CustomImageDataset(Dataset):
    '''Dataset of (image file path, comma-joined label string) pairs read from a CSV.

    The annotations CSV must contain an "image_path" column; every column from
    position `label_start_col` onwards is treated as a label column.

    annotations_file: path (or file-like object) accepted by `pd.read_csv`
    img_dir: directory that each "image_path" value is joined onto
    label_start_col: positional index of the first label column (defaults to 5)
    '''

    def __init__(self, annotations_file, img_dir, label_start_col=5):
        self.df = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.label_start_col = label_start_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        '''Returns the image file path (joined onto img_dir) and the row's labels as one "a,b,c" string'''
        # Both lookups are positional so they always address the same row,
        # regardless of what the DataFrame index looks like.
        image = os.path.join(self.img_dir, self.df["image_path"].iloc[idx])
        labels = self.df.iloc[idx, self.label_start_col:].values

        # Labels travel as a single string so the default DataLoader collate
        # keeps one string per sample.
        labels_str = ",".join(str(label) for label in labels)

        return image, labels_str
    

In [5]:
def convert_label_str(label_str_list):
    '''
    Convert comma-joined label strings into flat one-hot encoded lists.

    label_str_list: list of strings. Each string corresponds to each image
        and holds comma-separated labels drawn from '-1.0', '0.0', '1.0', '2.0'.
    output: list of list of numbers. Each "sublist" contains the One-Hot encoded
        labels for each image (4 entries per label, concatenated in order).
    raises: ValueError if any label is not one of the four accepted strings.
    '''
    # One slot per accepted label value, in the order -1.0, 0.0, 1.0, 2.0.
    one_hot_by_label = {
        '-1.0': [1, 0, 0, 0],
        '0.0': [0, 1, 0, 0],
        '1.0': [0, 0, 1, 0],
        '2.0': [0, 0, 0, 1],
    }

    label_list_list = []

    for label_str in label_str_list:
        label_list = []
        for label in label_str.split(","):
            if label not in one_hot_by_label:
                # The offending value is interpolated so bad rows can be traced.
                raise ValueError(f"Invalid label: {label}")
            label_list.extend(one_hot_by_label[label])

        label_list_list.append(label_list)

    return label_list_list

In [6]:
# Training split: labels come from the processed CheXpert CSV, and each row's
# "image_path" is joined onto img_dir by CustomImageDataset.
training_data = CustomImageDataset(
    annotations_file="/home/r11kaijun/MIMIC-CXR/processed_data/processed_mimic-cxr-2.0.0-chexpert_train.csv",
    img_dir="/home/r11kaijun/physionet.org/files/mimic-cxr-jpg/2.1.0"
)

In [7]:
# Tiny shuffled batches (2 samples) — enough to eyeball what the dataset yields.
train_dataloader = DataLoader(training_data, batch_size=2, shuffle=True)

In [8]:
# Look at the first batch only: raw label strings plus the one-hot encoding of
# the second sample in the batch (index 1).
for i, (batch_image, batch_labels_str) in enumerate(train_dataloader):
    print("i:", i)
    print("length of batch_image:", len(batch_image))
    print("batch_labels_str:", batch_labels_str)
    # Encode once and reuse the result instead of re-running the conversion per print.
    second_sample_labels = convert_label_str(batch_labels_str)[1]
    print("batch_labels:", second_sample_labels)
    print("length of batch_labels:", len(second_sample_labels))
    break

i: 0
length of batch_image: 2
batch_labels_str: ('1.0,-1.0,2.0,2.0,-1.0,2.0,2.0,2.0,2.0,0.0,2.0,2.0,2.0,2.0', '2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,1.0')
batch_labels: [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0]
length of batch_labels: 56


### Finetune the Vision Encoder
- Note: the vision encoder lives inside the vision tower (the `vision_tower` attribute of `CustomCLIPVisionTower`).

In [9]:
from llava.mm_utils import process_images
from PIL import Image

In [43]:
class CustomCLIPVisionTower(nn.Module):
    """Wrapper around a CLIP vision model plus its image processor.

    vision_tower: model name or path passed to the `from_pretrained` loaders.
    args: any object; `mm_vision_select_layer` (default -2),
        `mm_vision_select_feature` (default "patch") and `image_aspect_ratio`
        (default "pad") are read from it with `getattr`.
    delay_load: when True only the CLIP vision config is fetched; call
        `load_model()` later to load the weights and the image processor.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.vision_tower_name = vision_tower

        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        print("self.select_feature:", self.select_feature)
        self.image_aspect_ratio = getattr(args, "image_aspect_ratio", "pad")
        self.is_loaded = False
        # Only reported by `device` while no weights are loaded.
        self._default_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        if not delay_load:
            self.load_model()  # sets is_loaded = True on success
        else:
            # Weights are deliberately not loaded here, so is_loaded stays False
            # and `config` falls back to cfg_only.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
            print("self.cfg_only:", self.cfg_only)

    def load_model(self):
        """Load the image processor and the CLIP vision model; no-op if already loaded."""
        if self.is_loaded:
            print(
                "{} is already loaded, `load_model` called again, skipping.".format(
                    self.vision_tower_name
                )
            )
            return

        print("self.vision_tower_name:", self.vision_tower_name)
        self.image_processor = CLIPImageProcessor.from_pretrained(
            self.vision_tower_name
        )
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick features from `hidden_states[self.select_layer]`.

        "patch" drops the first token and returns the remaining tokens;
        "cls" returns only the first token. Any other `select_feature`
        raises ValueError.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
            # TODO: Add additional processing methods to pool the results in each of the patch embeddings
        elif self.select_feature == "cls":
            image_features = image_features[:, 0]
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def preprocess(self, image_paths):
        """Open each path as an RGB PIL image and run them through `process_images`.

        Raises ValueError if the image processor has not been loaded yet.
        """
        if not self.is_loaded:
            raise ValueError("Image processor is not loaded yet")
        images = []
        for image_path in image_paths:
            # convert() returns a new in-memory image, so the file handle can be
            # released as soon as the conversion is done.
            with Image.open(image_path) as image:
                images.append(image.convert("RGB"))

        # NOTE(review): the CLIP vision config is what gets passed as the model
        # config; self.image_aspect_ratio is stored in __init__ but not forwarded
        # here — confirm whether padding is intended.
        return process_images(images, self.image_processor, self.config)

    # @torch.no_grad()
    def forward(self, images):
        """Encode images and return the selected features.

        images: either a tensor batch, or a list of per-image tensors (each is
            given a leading batch dimension with `unsqueeze(0)`).
        Returns a tensor for tensor input, or a list of tensors for list input,
        cast back to the dtype of the corresponding input.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        """A (1, hidden_size) zero tensor on this tower's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype of the wrapped vision model."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device of the wrapped vision model.

        Follows `.to(...)` / `.cuda()` calls so inputs are always moved to
        wherever the weights actually live. Before the weights are loaded, the
        default device (CUDA if available, else CPU) is reported.
        """
        if self.is_loaded:
            return self.vision_tower.device
        return self._default_device

    @property
    def config(self):
        """Config of the loaded model, or the config-only object when loading was delayed."""
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        """`hidden_size` from the config."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """(image_size // patch_size) squared, from the config."""
        return (self.config.image_size // self.config.patch_size) ** 2

In [16]:
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List

In [17]:
@dataclass
class ModelArguments:
    """Default model settings; field order is part of the positional constructor."""

    # Language-model checkpoint and prompt/conversation version
    model_name_or_path: Optional[str] = "microsoft/llava-med-v1.5-mistral-7b"
    version: Optional[str] = "mistral_instruct"
    freeze_backbone: bool = False
    tune_mm_mlp_adapter: bool = False

    # Vision tower
    vision_tower: Optional[str] = "openai/clip-vit-large-patch14-336"
    # Used as an index into the encoder's hidden_states; -2 is the second-to-last entry.
    mm_vision_select_layer: Optional[int] = -2

    # Multimodal projector / token options
    pretrain_mm_mlp_adapter: Optional[str] = None
    mm_projector_type: Optional[str] = 'mlp2x_gelu'
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = True
    mm_patch_merge_type: Optional[str] = 'flat'

    # Which tokens the vision tower returns: "cls" or "patch"
    mm_vision_select_feature: Optional[str] = "cls"

In [18]:
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build a CustomCLIPVisionTower from a config object.

    vision_tower_cfg: object whose `mm_vision_tower` attribute (or, failing
        that, `vision_tower`) names the tower; it is also passed on as `args`.
    **kwargs: forwarded to CustomCLIPVisionTower (e.g. `delay_load`).
    Raises ValueError if no tower name is configured, or if the name is neither
    an existing path nor prefixed with "openai" / "laion".
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    print("vision_tower:", vision_tower)
    if vision_tower is None:
        # os.path.exists(None) would raise an unhelpful TypeError.
        raise ValueError("Unknown vision tower: None (set `mm_vision_tower` or `vision_tower`)")

    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith(("openai", "laion")):
        return CustomCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    # Fail loudly instead of silently handing None back to the caller.
    raise ValueError(f"Unknown vision tower: {vision_tower}")


In [None]:
# Build the tower from the default ModelArguments (loads the weights immediately, since delay_load is not passed).
vision_tower = build_vision_tower(ModelArguments())

vision_tower: openai/clip-vit-large-patch14-336
self.select_feature: cls
self.vision_tower_name: openai/clip-vit-large-patch14-336


  return torch.load(checkpoint_file, map_location=map_location)


In [52]:
# Re-create the 2-sample shuffled loader for the smoke test in the next cell.
train_dataloader = DataLoader(training_data, batch_size=2, shuffle=True)

In [None]:
# Smoke test: push a single batch through preprocessing and the vision tower.
image_path, labels = next(iter(train_dataloader))
print("image_path:", image_path[0])
print("labels:", labels)

image_tensors = vision_tower.preprocess(list(image_path))
# print("image_tensors:", image_tensors)

# Call the module itself rather than .forward() so nn.Module hooks run, and skip
# gradient tracking because this is inspection only.
with torch.no_grad():
    image_features = vision_tower(image_tensors)
print("image_features:", image_features, image_features.shape)

# global average pooling
# pooled_features = image_features.mean(dim=1)
# print("pooled_features:", pooled_features, pooled_features.shape)
  

In [38]:
# Freezing experiment on a bare CLIPVisionModel.
# Note that this rebinds `vision_tower` to the raw Hugging Face model.
vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")

# Start from a fully frozen model ...
vision_tower.requires_grad_(False)

# ... then make only the last 4 encoder layers trainable again.
encoder_layers = vision_tower.vision_model.encoder.layers
num_layers = len(encoder_layers)  # Total layers
for encoder_layer in encoder_layers[-4:]:
    encoder_layer.requires_grad_(True)

vision_tower

CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(577, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=

In [58]:
class CLIPDiseaseClassifier(nn.Module):
    """MLP head mapping image features to `output_neurons` logits.

    Layout of `self.mlp`: two Linear -> ReLU -> Dropout(0.5) stages
    (input_neurons -> hidden_dim -> hidden_dim) followed by a final Linear to
    output_neurons.
    """

    def __init__(self, input_neurons=1024, hidden_dim=1024, output_neurons=56):
        super().__init__()
        widths = [input_neurons, hidden_dim, hidden_dim]

        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.5)]
        # Final projection produces raw logits, shape (B, output_neurons).
        stages.append(nn.Linear(hidden_dim, output_neurons))

        self.mlp = nn.Sequential(*stages)

    def forward(self, image_features):
        """Return the logits for `image_features`, with no reshape applied."""
        return self.mlp(image_features)


def train_clip_classifier(
    vision_tower_instance: CustomCLIPVisionTower,
    classifier: CLIPDiseaseClassifier,
    train_loader,
    val_loader,
    output_dir,
    unfreeze_layers=4,
    epochs=2,
    lr=1e-4,
):
    """Fine-tune the last CLIP encoder layers together with a classification head.

    vision_tower_instance: tower whose `vision_tower.vision_model.encoder.layers`
        are partially unfrozen; also used for preprocessing and encoding.
    classifier: head applied to the tower's output features.
    train_loader / val_loader: loaders yielding (image_paths, label_strings)
        batches; label strings are decoded with `convert_label_str`.
    output_dir: directory that receives the two state-dict checkpoints
        (`vision_tower-epoch-…` and `classifier-epoch-…`).
    unfreeze_layers: number of trailing encoder layers to train (0 keeps the
        whole encoder frozen).
    epochs: number of passes over `train_loader`.
    lr: AdamW learning rate.

    Returns the `(vision_tower_instance, classifier)` pair that was trained.
    """
    # send the models to the gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vision_tower_instance.to(device)
    classifier.to(device)

    # Freeze every parameter of the tower that was passed in (not whatever
    # notebook-level `vision_tower` happens to exist).
    for param in vision_tower_instance.parameters():
        param.requires_grad = False

    # Unfreeze the trailing encoder layers. Guarded because `layers[-0:]` is the
    # whole list, which would unfreeze everything when 0 layers are requested.
    if unfreeze_layers > 0:
        encoder_layers = vision_tower_instance.vision_tower.vision_model.encoder.layers
        for param in encoder_layers[-unfreeze_layers:].parameters():
            param.requires_grad = True

    # NOTE(review): targets are concatenated 4-way one-hot groups, and
    # BCEWithLogitsLoss scores every output independently — confirm this is
    # preferred over a per-group cross-entropy.
    criterion = nn.BCEWithLogitsLoss()  # Multi-label classification loss

    # Frozen parameters never receive gradients, so only trainable ones are
    # handed to the optimizer.
    trainable_params = [
        param for param in vision_tower_instance.parameters() if param.requires_grad
    ]
    trainable_params += list(classifier.parameters())
    optimizer = optim.AdamW(trainable_params, lr=lr)

    def batch_loss(image_paths, labels_str):
        """Loss for one batch: decode labels, preprocess, encode, classify."""
        ground_truths = torch.tensor(
            convert_label_str(labels_str), dtype=torch.float32, device=device
        )
        images = vision_tower_instance.preprocess(image_paths).to(device)
        # Modules are called directly (not via .forward) so hooks are honoured.
        image_features = vision_tower_instance(images)
        predicted_classes = classifier(image_features)
        return criterion(predicted_classes, ground_truths)

    for epoch in range(epochs):
        print("epoch:", epoch)
        start_time = time.time()
        vision_tower_instance.train()
        classifier.train()
        total_loss = 0

        for batch, (image_paths, labels_str) in enumerate(train_loader):
            print("batch:", batch)
            optimizer.zero_grad()

            loss = batch_loss(image_paths, labels_str)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError on an empty loader.
        avg_train_loss = total_loss / max(len(train_loader), 1)

        # Validation Loop
        vision_tower_instance.eval()
        classifier.eval()
        val_loss = 0
        with torch.no_grad():
            for image_paths, labels_str in val_loader:
                val_loss += batch_loss(image_paths, labels_str).item()

        avg_val_loss = val_loss / max(len(val_loader), 1)
        epoch_time = time.time() - start_time
        print(
            f"Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {epoch_time:.2f} seconds"
        )

    # Save the models separately, each under its own file name so the
    # classifier checkpoint no longer overwrites the vision tower checkpoint.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    last_epoch = epochs - 1  # same index the loop variable ends on
    torch.save(
        vision_tower_instance.state_dict(),
        os.path.join(output_dir, f"vision_tower-epoch-{last_epoch}-lr-{lr}.pth"),
    )
    torch.save(
        classifier.state_dict(),
        os.path.join(output_dir, f"classifier-epoch-{last_epoch}-lr-{lr}.pth"),
    )

    return vision_tower_instance, classifier

In [59]:
# Machine-specific data locations, named once so both splits stay in sync.
MIMIC_LABELS_DIR = "/home/r11kaijun/MIMIC-CXR/processed_data"
MIMIC_IMAGE_DIR = "/home/r11kaijun/physionet.org/files/mimic-cxr-jpg/2.1.0"

vision_tower = build_vision_tower(ModelArguments())
classifier = CLIPDiseaseClassifier()

training_data = CustomImageDataset(
    annotations_file=os.path.join(
        MIMIC_LABELS_DIR, "processed_mimic-cxr-2.0.0-chexpert_train.csv"
    ),
    img_dir=MIMIC_IMAGE_DIR,
)
validation_data = CustomImageDataset(
    annotations_file=os.path.join(
        MIMIC_LABELS_DIR, "processed_mimic-cxr-2.0.0-chexpert_validate.csv"
    ),
    img_dir=MIMIC_IMAGE_DIR,
)

train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
# Validation only averages the loss, so shuffling adds nothing; a fixed order
# keeps runs comparable.
validation_dataloader = DataLoader(validation_data, batch_size=16, shuffle=False)

# Bind the result so the cell does not echo two full module reprs as its output.
vision_tower, classifier = train_clip_classifier(
    vision_tower,
    classifier,
    train_loader=train_dataloader,
    val_loader=validation_dataloader,
    output_dir=".",
)

vision_tower: openai/clip-vit-large-patch14-336
self.select_feature: cls
self.vision_tower_name: openai/clip-vit-large-patch14-336


  return torch.load(checkpoint_file, map_location=map_location)


epoch: 0
batch: 0
batch: 1
batch: 2
batch: 3
batch: 4
batch: 5
batch: 6
batch: 7
batch: 8
batch: 9
batch: 10
batch: 11
batch: 12
batch: 13
batch: 14
batch: 15
batch: 16
batch: 17
batch: 18
batch: 19
batch: 20


KeyboardInterrupt: 

### Ground Truth (using the LLaVA-Med model and the CLIP model as reference)

In [None]:
# Reference: the stock CLIP image processor settings for the checkpoint used above.
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
print(image_processor)

In [None]:
# Reference: the stock CLIP vision model. NOTE: this rebinds `vision_tower` to the bare Hugging Face model.
vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
print(vision_tower)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import LlavaMistralForCausalLM
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

In [11]:
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
    """Load a tokenizer and model, plus the image processor for LLaVA models.

    model_path: checkpoint to load (or the LoRA weights when `model_base` is
        given for a non-LLaVA model).
    model_base: base checkpoint for the PEFT/LoRA path; only consulted when
        "llava" is not in `model_name`.
    model_name: selects the loading path by substring ("llava", "mistral", "mpt").
    load_8bit / load_4bit: request quantized loading; otherwise float16 is used.
    device_map: accepted for signature compatibility; not used by this function.
    device: target device; any value other than "cuda" is also passed to
        `from_pretrained` as `device_map={"": device}`.

    Returns `(tokenizer, model, image_processor, context_len)`;
    `image_processor` is None for non-LLaVA models and `context_len` falls back
    to 2048 when the config has no `max_sequence_length`.
    Raises ValueError for a LLaVA model name that is not Mistral-based.
    """
    name_lower = model_name.lower()
    is_llava = 'llava' in name_lower

    kwargs = {}

    if device != "cuda":
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if is_llava:
        # Only the Mistral variant is wired up. Fail here rather than hitting an
        # UnboundLocalError on `model` further down.
        if 'mistral' not in name_lower:
            raise ValueError(
                f"Unsupported LLaVA variant: {model_name!r} (only Mistral-based LLaVA models are handled)"
            )
        # Load LLaVA model
        print("model_name:", model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        print("initialised tokenizer:")
        model = LlavaMistralForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            use_flash_attention_2=False,
            **kwargs
        )
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        elif 'mpt' in name_lower:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if is_llava: # or 'mistral' in model_name.lower():
        # Register the image placeholder tokens the config asks for, then grow
        # the embedding table to match the tokenizer.
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch.float16)
        model.model.mm_projector.to(device=device, dtype=torch.float16)
        model.to(device=device, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len

In [14]:
def load_model(model_path, model_base, model_name):
    """Load tokenizer, model, image processor and context length onto CUDA.

    Thin wrapper over `load_pretrained_model(..., device="cuda")`. That function
    already loads the vision tower, moves the tower, projector and model to the
    device in float16, and derives `context_len`, so none of it is repeated here.

    Returns `(tokenizer, model, image_processor, context_len)`.
    """
    return load_pretrained_model(
        model_path=model_path,
        model_base=model_base,
        model_name=model_name,
        device="cuda",
    )


def load_base_model():
    """Load the stock LLaVA-Med Mistral checkpoint via `load_model`.

    The same identifier is used as both the model path and the model name, and
    an empty string is passed as the base model. Quantization flags are not
    forwarded, because `load_model` does not accept them.
    """
    checkpoint = "microsoft/llava-med-v1.5-mistral-7b"
    return load_model(model_path=checkpoint, model_base="", model_name=checkpoint)

In [None]:
# Load the reference LLaVA-Med model; this rebinds `image_processor` to the one owned by its vision tower.
tokenizer, model, image_processor, context_len = load_base_model()

In [None]:
# Image processor returned by load_base_model(), for comparison with the stock CLIP processor printed earlier.
print("image_processor:", image_processor)

In [None]:
# LLaVA-Med overall architecture (prints the full module tree)
print("model:", model)

In [None]:
# Only the `layers` submodule of the inner model is printed here, although the label reads "model:".
print("model:", model.model.layers)

In [None]:
# LLaVA-Med multimodal projector (`mm_projector`) module
print(model.model.mm_projector)

In [None]:
# Inner `vision_model` of the CLIP model held by LLaVA-Med's vision tower.
print(model.get_vision_tower().vision_tower.vision_model)

In [None]:
# The CLIP model object wrapped by LLaVA-Med's vision tower (one level above `vision_model`).
print(model.get_vision_tower().vision_tower)