### Code to finetune the vision encoder for llava-med

https://huggingface.co/docs/transformers/en/main_classes/image_processor

In [116]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

In [2]:
import pandas as pd
import os

In [3]:
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


[2025-03-01 16:39:40,961] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/r11kaijun/anaconda3/envs/llava-med/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/r11kaijun/anaconda3/envs/llava-med/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status
  _torch_pytree._register_pytree_node(


### Dataset and Dataloaders

In [68]:
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir):
        self.df = pd.read_csv(annotations_file)
        self.img_dir = img_dir

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        '''Returns the absolute image file path and list of disease classifications'''
        image = os.path.join(self.img_dir, self.df.at[idx, "image_path"])
        labels = self.df.iloc[idx, 5:].values

        labels_str = ""
        for label in labels:
            labels_str += str(label)
            labels_str += ","

        labels_str = labels_str[:-1]

        return image, labels_str
    

In [100]:
def convert_label_str(label_str_list):
    '''
    label_str_list: list of strings. Each string corresponds to each image
    output: list of list of numbers. Each "sublist" contains the One-Hot encoded labels for each image
    '''
    label_list_list = []

    for label_str in label_str_list:
        # print("label_str:", label_str)
        label_list = []
        for label in label_str.split(","):
            # print("label:", label)
            if label == '-1.0':
                label_list.extend([1, 0, 0, 0])
            elif label == '0.0':
                label_list.extend([0, 1, 0, 0])
            elif label == '1.0':
                label_list.extend([0, 0, 1, 0])
            elif label == '2.0':
                label_list.extend([0, 0, 0, 1])
            else:
                raise ValueError("Invalid label: f{label}")

        label_list_list.append(label_list)
            
    return label_list_list

In [101]:
training_data = CustomImageDataset(
    annotations_file="/home/r11kaijun/MIMIC-CXR/processed_data/processed_mimic-cxr-2.0.0-chexpert_train.csv",
    img_dir="/home/r11kaijun/physionet.org/files/mimic-cxr-jpg/2.1.0"
)

In [104]:
train_dataloader = DataLoader(training_data, batch_size=2, shuffle=True)

In [109]:
for i, (batch_image, batch_labels_str) in enumerate(train_dataloader):
    print("i:", i)
    print("length of batch_image:", len(batch_image))
    print("batch_labels_str:", batch_labels_str)
    print("batch_labels:", convert_label_str(batch_labels_str)[1])
    print("length of batch_labels:", len(convert_label_str(batch_labels_str)[1]))
    break

i: 0
length of batch_image: 2
batch_labels_str: ('2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,1.0', '2.0,1.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,-1.0,2.0,2.0')
batch_labels: [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]
length of batch_labels: 56


### Finetune the Vision Encoder
- Note: Vision Encoder is in the Vision Tower

In [8]:
from llava.mm_utils import process_images
from PIL import Image

In [112]:
class CustomCLIPVisionTower(nn.Module):

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.vision_tower_name = vision_tower

        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        print("self.select_feature:", self.select_feature)
        self.image_aspect_ratio = getattr(args, "image_aspect_ratio", "pad")
        self.is_loaded = False

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
            print("self.cfg_only:", self.cfg_only)
        self.is_loaded = True

    def load_model(self):
        if self.is_loaded:
            print(
                "{} is already loaded, `load_model` called again, skipping.".format(
                    self.vision_tower_name
                )
            )
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(
            self.vision_tower_name
        )
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Returns the CLS token and the patch embeddings (ie select_feature == 'cls_patch')"""
        image_features = image_forward_outs.hidden_states[self.select_layer]
        return image_features

    def get_tokens(self, select_feature, image_features):
        """
        Function to obtain the CLS/ patch tokens after extracting the image features.
        Can only be used when "select_feature" is "cls_patch"
        """
        if select_feature == "patch":
            image_features = image_features[:, 1:]
        elif select_feature == "cls":
            image_features = image_features[:, 0]
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")

        return image_features

    def preprocess(self, image_paths):
        if not self.is_loaded:
            raise ValueError(f"Image processor is not loaded yet")
        images = []
        for image_path in image_paths:
            image = Image.open(image_path).convert("RGB")
            images.append(image)

        return process_images(images, self.image_processor, self.config)

    # @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def vision_tower(self):
        return self.vision_tower

    @property
    def image_processor(self):
        return self.image_processor

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

In [48]:
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List

In [49]:
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="microsoft/llava-med-v1.5-mistral-7b")
    version: Optional[str] = field(default="mistral_instruct")
    freeze_backbone: bool = field(default=False)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default="openai/clip-vit-large-patch14-336")
    mm_vision_select_layer: Optional[int] = field(default=-2)  # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    mm_projector_type: Optional[str] = field(default='mlp2x_gelu')
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_patch_merge_type: Optional[str] = field(default='flat')
    mm_vision_select_feature: Optional[str] = field(default="cls_patch")

In [50]:
def build_vision_tower(vision_tower_cfg, **kwargs):
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    print("vision_tower:", vision_tower)
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
        return CustomCLIPVisionTower(vision_tower, args=vision_tower_cfg, input_neurones=1, output_neurones=1, hidden_layers=1, output_layers=1,**kwargs )


In [51]:
vision_tower = build_vision_tower(ModelArguments())

vision_tower: openai/clip-vit-large-patch14-336
self.select_feature: cls_patch


  return torch.load(checkpoint_file, map_location=map_location)


In [52]:
train_dataloader = DataLoader(training_data, batch_size=2, shuffle=True)

In [53]:
for i, (image_path, labels) in enumerate(train_dataloader):
  print("image_path:", image_path[0])
  print("labels:", labels)

  image_tensors = vision_tower.preprocess(list(image_path))
  # print("image_tensors:", image_tensors)

  image_features = vision_tower.forward(image_tensors)
  print("image_features:", image_features, image_features.shape)

  # global average pooling
  # pooled_features = image_features.mean(dim=1) 
  # print("pooled_features:", pooled_features, pooled_features.shape)

  break
  

image_path: /home/r11kaijun/physionet.org/files/mimic-cxr-jpg/2.1.0/files/p18/p18298823/s54194522/372bc95b-ff7a832c-0c51d0b3-80acc594-d66814f5.jpg
labels: ('2.0,2.0,2.0,2.0,2.0,2.0,1.0,1.0,2.0,1.0,1.0,2.0,2.0,2.0', '2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,1.0')
image_features: tensor([[[-0.2858,  0.1829, -0.2237,  ..., -0.4651,  0.1020, -0.2103],
         [ 0.6322, -0.5140, -0.0848,  ...,  0.9966,  0.1525,  0.2769],
         [ 0.2476,  0.1778,  0.4171,  ...,  0.3098,  0.0962,  0.1796],
         ...,
         [ 0.7404,  1.3373, -0.0708,  ...,  1.1104,  0.0839,  0.8365],
         [ 0.7460, -0.4267, -0.4826,  ...,  0.1110,  0.2611,  0.6659],
         [ 0.6995,  0.6328,  0.7132,  ...,  0.3153,  0.1317,  0.2101]],

        [[-0.3073,  0.0718,  0.0143,  ..., -0.5742, -0.2717, -0.1639],
         [ 0.6445,  0.3420,  0.1778,  ...,  0.1889,  0.8064, -0.1838],
         [ 0.9056,  0.5606,  0.0175,  ...,  0.4767, -0.0091,  1.1256],
         ...,
         [ 0.4904,  1.0878,  0.4074,  ...

In [115]:
# vision_tower.get_tokens("cls", image_features)
# vision_tower.get_tokens("patch", image_features)

In [113]:
class CLIPDiseaseClassifier(nn.Module):
    def __init__(self, input_neurons=1024, hidden_dim=4096, output_neurons=56):
        # MLP Classification Head
        self.mlp = nn.Sequential(
            nn.Linear(input_neurons, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, output_neurons)  # Output: (B, 56)
        )

    def forward(self, image_features):
        output = self.mlp(image_features)  # (B, 56)
        # return output.view(-1, 14, 4)  # Reshape to (batch, 14 diseases, 4 states)
        return output


def train_clip_classifier_cls(
    vision_tower_instance:CustomCLIPVisionTower, classifier:CLIPDiseaseClassifier, train_loader, val_loader, output_dir, epochs=10, lr=1e-4
):
    vision_tower = vision_tower_instance.vision_tower
    device = vision_tower.device

    # unfreeze the layers that we want to finetune in the clip encoder

    criterion = nn.BCEWithLogitsLoss()  # Multi-label classification loss
    optimizer = optim.AdamW(vision_tower.parameters(), lr=lr)

    for epoch in range(epochs):
        vision_tower.train()
        total_loss = 0

        for batch, (image_paths, labels_str) in enumerate(train_loader):
            optimizer.zero_grad()
            
            ground_truths = convert_label_str(labels_str).to(device)
            print("ground_truths:", ground_truths)

            # preprocess images
            images = vision_tower_instance.preprocess(image_paths)
            image_features = vision_tower_instance.forward(images)

            # get CLS tokens
            cls_tokens = vision_tower_instance.get_tokens("cls", image_features)

            predicted_classes = classifier.forward(cls_tokens)

            loss = criterion(predicted_classes, labels)  # Compute loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # # Validation Loop
        # model.eval()
        # val_loss = 0
        # with torch.no_grad():
        #     for images, labels in val_loader:
        #         images, labels = images.to(device), labels.to(device)
        #         outputs = model(images)
        #         val_loss += criterion(outputs, labels).item()

        # avg_val_loss = val_loss / len(val_loader)
        # print(f"Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # save the models
    torch.save()

    return vision_tower, classifier


vision_tower = build_vision_tower(ModelArguments())
training_data = CustomImageDataset(
    annotations_file="/home/r11kaijun/MIMIC-CXR/processed_data/processed_mimic-cxr-2.0.0-chexpert_train.csv",
    img_dir="/home/r11kaijun/physionet.org/files/mimic-cxr-jpg/2.1.0",
)
validation_data = training_data

train_dataloader = DataLoader(training_data, batch_size=2, shuffle=True)
valdation_dataloader = DataLoader(validation_data, batch_size=2, shuffle=True)

train_clip_classifier(
    vision_tower,
    train_loader=train_dataloader,
    val_loader=valdation_dataloader,
    output_dir=".",
)

vision_tower: openai/clip-vit-large-patch14-336
self.select_feature: cls_patch


  return torch.load(checkpoint_file, map_location=map_location)


labels: [[0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]]
labels: [[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0]]
labels: [[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1], [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,

KeyboardInterrupt: 

### Ground Truth (using LLava-Med model and CLIP model as reference)

In [4]:
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
print(image_processor)



CLIPImageProcessor {
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}



In [5]:
vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
print(vision_tower)

  return torch.load(checkpoint_file, map_location=map_location)


CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(577, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import LlavaMistralForCausalLM
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


[2025-03-01 14:20:52,178] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/home/r11kaijun/anaconda3/envs/llava-med/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/r11kaijun/anaconda3/envs/llava-med/compiler_compat/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status
  _torch_pytree._register_pytree_node(


In [11]:
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):

    kwargs = {}

    if device != "cuda":
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16
    
    if 'llava' in model_name.lower():
        # Load LLaVA model
            if 'mistral' in model_name.lower():
                print("model_name:", model_name)
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                print("initialised tokenizer:")
                model = LlavaMistralForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    use_flash_attention_2=False,
                    **kwargs
                )
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            use_fast = False
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if 'llava' in model_name.lower(): # or 'mistral' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()
        vision_tower.to(device=device, dtype=torch.float16)
        model.model.mm_projector.to(device=device, dtype=torch.float16)
        model.to(device=device, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len

In [14]:
def load_model(model_path, model_base, model_name):
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path=model_path,
        model_base=model_base,
        model_name=model_name,
        device="cuda",
    )
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device="cuda", dtype=torch.float16)
    model.model.mm_projector.to(device="cuda", dtype=torch.float16)
    model.to(device="cuda", dtype=torch.float16)
    image_processor = vision_tower.image_processor
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len


def load_base_model():
    return load_model(
        model_path="microsoft/llava-med-v1.5-mistral-7b",
        model_base="",
        model_name="microsoft/llava-med-v1.5-mistral-7b",
        # load_8bit=load_8bit,
        # load_4bit=load_4bit,
    )

In [15]:
tokenizer, model, image_processor, context_len = load_base_model()

model_name: microsoft/llava-med-v1.5-mistral-7b




initialised tokenizer:


Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.80s/it]
Some weights of the model checkpoint at microsoft/llava-med-v1.5-mistral-7b were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.encoder.layers.2.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.3.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.22.self_attn.v_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.22.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.8.self_attn.k_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.19.mlp.fc2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.4.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.19.self_attn.out_proj.bias', 'model.vision_tower.vision_tower.visio

In [16]:
print("image_processor:", image_processor)

image_processor: CLIPImageProcessor {
  "crop_size": {
    "height": 336,
    "width": 336
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 336
  }
}



In [17]:
# LLava-Med model overall architecture
print("model:", model)

model: LlavaMistralForCausalLM(
  (model): LlavaMistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): Mi

In [35]:
print("model:", model.model.layers)

model: LlavaMistralModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x MistralDecoderLayer(
      (self_attn): MistralAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): MistralRotaryEmbedding()
      )
      (mlp): MistralMLP(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): MistralRMSNorm()
      (post_attention_layernorm): MistralRMSNorm()
    )
  )
  (norm): MistralRMSNorm()
  (vision_tower): CLIPVisionTower(
    (vision_tower): CLIPVisi

In [32]:
# LLava-Med MLP layer
print(model.model.mm_projector)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): GELU(approximate='none')
  (2): Linear(in_features=4096, out_features=4096, bias=True)
)


In [22]:
print(model.get_vision_tower().vision_tower.vision_model)

CLIPVisionTransformer(
  (embeddings): CLIPVisionEmbeddings(
    (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (position_embedding): Embedding(577, 1024)
  )
  (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (encoder): CLIPEncoder(
    (layers): ModuleList(
      (0-23): 24 x CLIPEncoderLayer(
        (self_attn): CLIPAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): CLIPMLP(
          (activation_fn): QuickGELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias

In [21]:
print(model.get_vision_tower().vision_tower)

CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(577, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_features=4096, bias=