In [1]:
import torch
import torch.nn as nn
from transformers import CLIPProcessor, CLIPModel, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from transformers import Trainer, TrainingArguments
from PIL import Image
import os

# Set device to MPS if available, otherwise CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Using device: {device}")

def load_clip():
    # -- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Load CLIP processor and model, saving locally if not already present
    if os.path.exists("data/processor_clip_embeddings_vit_base_patch32.pt"):
        processor = CLIPProcessor.from_pretrained("data/processor_clip_embeddings_vit_base_patch32.pt")
    else:
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        # save processor for later use
        processor.save_pretrained("data/processor_clip_embeddings_vit_base_patch32.pt")

    model_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    model_clip = model_clip.to(device).to(torch.float32)
    return processor, model_clip

def generate_clip_embeddings(processor, model_clip, image_dir, embeddings_path):
    # Generate CLIP embeddings for images and save them
    image_files = [f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
    clip_embeddings = []
    image_ID = {}

    for i, image_file in enumerate(image_files[:100]):  # Limit to 100 images for this example
        image_path = os.path.join(image_dir, image_file)
        image = Image.open(image_path)
        inputs = processor(images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            image_features = model_clip.get_image_features(**inputs)
        clip_embeddings.append(image_features.cpu())
        image_ID[image_file.split(".")[0]] = i

    torch.save(torch.cat(clip_embeddings, dim=0), embeddings_path)
    return image_ID

class ProjectionLayer(nn.Module):
    # Linear projection layer to map CLIP embeddings to Phi model dimensions
    def __init__(self, clip_embedding_dim, phi_hidden_dim):
        super().__init__()
        self.linear = nn.Linear(clip_embedding_dim, phi_hidden_dim)

    def forward(self, image_embeddings):
        return self.linear(image_embeddings)

class MultimodalPhiWithAdapter(nn.Module):
    # Multimodal model combining Phi language model with image embeddings
    def __init__(self, language_model, projection_layer, freeze_language_model=True, freeze_projection_layer=False):
        super().__init__()
        self.language_model = language_model
        self.projection_layer = projection_layer
        
        # Set trainable parameters based on input flags
        self.set_trainable_params(freeze_language_model, freeze_projection_layer)

    def set_trainable_params(self, freeze_language_model, freeze_projection_layer):
        # Set which parts of the model are trainable
        for param in self.language_model.parameters():
            param.requires_grad = not freeze_language_model
        for param in self.projection_layer.parameters():
            param.requires_grad = not freeze_projection_layer


    def forward(self, input_ids, attention_mask, image_embeddings, labels=None):
        # Forward pass combining image embeddings with text input
        batch_size = input_ids.shape[0]
        projected_embeddings = self.projection_layer(image_embeddings)
        
        # Prepend projected image embeddings to the input sequence
        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        combined_embeds = torch.cat([projected_embeddings.unsqueeze(1), input_embeds], dim=1)
        combined_embeds = combined_embeds.to(torch.float32)
        
        # Adjust attention mask and labels for the added image token
        image_attention = torch.ones((batch_size, 1), dtype=torch.long, device=device)
        combined_attention_mask = torch.cat([image_attention, attention_mask], dim=1)

        # Adjust labels
        if labels is not None:
            # Add a padding label for the image token
            pad_labels = torch.full((batch_size, 1), -100, dtype=labels.dtype, device=labels.device)
            labels = torch.cat([pad_labels, labels], dim=1)
        
        outputs = self.language_model(inputs_embeds=combined_embeds, attention_mask=combined_attention_mask)
        
        logits = outputs.logits
        loss = None
        if labels is not None:
            # Calculate loss if labels are provided
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = logits[:, 1:, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        
        return {"loss": loss, "logits": logits}
    
    # return total number of trainable parameters
    def count_trainable_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def count_total_parameters(self):
        return sum(p.numel() for p in self.parameters())

class InstructDataset(torch.utils.data.Dataset):
    # Dataset class for instruction-following data with images
    def __init__(self, instruct_data, clip_embeddings, tokenizer, image_ID, max_length=512):
        self.instruct_data = instruct_data
        self.clip_embeddings = clip_embeddings
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_ID = image_ID

    def __len__(self):
        return len(self.instruct_data)
    
    def __getitem__(self, idx):
        # Prepare a single item from the dataset
        item = self.instruct_data[idx]
        conversations = item['conversations']
        
        full_text = "\n".join([f"{conv['from']}: {conv['value']}" for conv in conversations])
        
        image_id = item["id"]
        img_idx = self.image_ID.get(image_id, 0)  # Default to first embedding if not found
        image_embedding = self.clip_embeddings[img_idx]

        encoded = self.tokenizer(
            full_text, 
            return_tensors="pt", 
            truncation=True, 
            max_length=self.max_length,
            padding="max_length"
        )

        return {
            "input_ids": encoded.input_ids.squeeze(),
            "attention_mask": encoded.attention_mask.squeeze(),
            "image_embeddings": image_embedding,
            "labels": encoded.input_ids.squeeze(),
        }


  from .autonotebook import tqdm as notebook_tqdm


Using device: mps


In [2]:
processor, model_clip = load_clip()
image_dir = "data/images_train2017"
embeddings_path = "data/clip_embeddings.pt"
image_ID = generate_clip_embeddings(processor, model_clip, image_dir, embeddings_path)

In [3]:

clip_embeddings = torch.load(embeddings_path)
clip_embedding_dim = clip_embeddings.shape[1]

In [4]:
# Load your Phi model
if os.path.exists("local_phi2_model"):
    model_phi = AutoModelForCausalLM.from_pretrained("local_phi2_model").to(device)
else:
    model_phi = AutoModelForCausalLM.from_pretrained("microsoft/phi-2").to(device)
    model_phi.save_pretrained("local_phi2_model")

# -- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
if os.path.exists("local_phi2_model"):
    tokenizer = AutoTokenizer.from_pretrained("local_phi2_model")
    print("tokenizer loaded from local")
else:
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    tokenizer.save_pretrained("local_phi2_model")
    print("tokenizer loaded from HF")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.04s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


tokenizer loaded from local


In [5]:

projection_layer = ProjectionLayer(clip_embedding_dim, model_phi.config.hidden_size).to(device)
multimodal_phi = MultimodalPhiWithAdapter(model_phi, projection_layer).to(device)


In [6]:

# Configure QLoRA
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],
    lora_dropout=0.05,
    bias="none",
    #task_type="CAUSAL_LM"
)
multimodal_phi = get_peft_model(multimodal_phi, lora_config)

# After applying LoRA, if you want to unfreeze certain parts:
multimodal_phi.set_trainable_params(freeze_language_model=True, freeze_projection_layer=False)

'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [7]:

# Add this new code block
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(multimodal_phi)
trainable_params = count_trainable_parameters(multimodal_phi)
projection_params = count_parameters(projection_layer)
adapter_params = trainable_params  # Since only adapter layers are trainable

print(f"Total parameters: {total_params:,}")
print(f"Projection layer parameters: {projection_params:,}")
print(f"Adapter (trainable) parameters: {adapter_params:,}")
print(f"Percentage of trainable parameters: {(adapter_params / total_params) * 100:.2f}%")


Total parameters: 2,782,307,840
Projection layer parameters: 1,313,280
Adapter (trainable) parameters: 1,313,280
Percentage of trainable parameters: 0.05%


In [8]:

# -- instruct_data = load_dataset("liuhaotian/LLaVA-Instruct-150K", split='train')
# Check for local copy first
if os.path.exists("model_instruct150k"):
    instruct_data = load_dataset("model_instruct150k", split='train')
    print("instruct data loaded from local")
else:
    print("loading instruct data from HF")
    instruct_data = load_dataset("liuhaotian/LLaVA-Instruct-150K", split='train')
    # Save the dataset locally for future use
    instruct_data.save_to_disk("model_instruct150k")
print("instruct data loaded")

instruct_data = instruct_data.filter(lambda x: x['id'] in image_ID.keys())
train_dataset = InstructDataset(instruct_data, clip_embeddings, tokenizer, image_ID)


instruct data loaded from local
instruct data loaded


In [9]:
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    save_steps=1000,
    save_total_limit=2,
    remove_unused_columns=False,
    fp16=False,
    bf16=False,  # MPS doesn't support bfloat32
    tf32=False,
    half_precision_backend="auto",
)

multimodal_phi = multimodal_phi.to(torch.float32)

trainer = Trainer(
    model=multimodal_phi.to(device),
    args=training_args,
    train_dataset=train_dataset,
)

trainer.train()

100%|██████████| 120/120 [13:42<00:00,  6.85s/it]

{'train_runtime': 822.0464, 'train_samples_per_second': 1.146, 'train_steps_per_second': 0.146, 'train_loss': 10.175467936197917, 'epoch': 3.0}





TrainOutput(global_step=120, training_loss=10.175467936197917, metrics={'train_runtime': 822.0464, 'train_samples_per_second': 1.146, 'train_steps_per_second': 0.146, 'total_flos': 0.0, 'train_loss': 10.175467936197917, 'epoch': 3.0})

In [None]:
# Save only the LoRA state dict
multimodal_phi.save_pretrained("fine_tuned_phi_lora", state_dict=multimodal_phi.state_dict())
# save projection layer
torch.save(projection_layer.state_dict(), "projection_layer.pt")

In [None]:
exit()
    
# Load the base model
base_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

# Load the LoRA configuration
peft_config = PeftConfig.from_pretrained("fine_tuned_phi_lora")

projection_layer = ProjectionLayer(clip_embedding_dim, model_phi.config.hidden_size)
projection_layer.load_state_dict(torch.load("projection_layer.pt"))

# Load the fine-tuned model
fine_tuned_model = PeftModel.from_pretrained(base_model, "fine_tuned_phi_lora")

# use base, projection layer, and fine tuned model to make whole model
whole_model = MultimodalPhiWithAdapter(base_model, projection_layer)
whole_model.load_state_dict(torch.load("fine_tuned_phi_lora.pt"))
# use whole model to generate text
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(device)
whole_model.generate(**inputs)