# Fine-tune the Projector Layer + Phi-2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig

#### MLP projector layer taken from:
Courtesy: https://github.com/sshh12/multi_token/blob/81ee75cd4435ebd5c7c7c3cf42c136c4053320fb/multi_token/modalities/projectors.py

In [None]:
def build_patch_mlp_projector(
    input_hidden_size: int, lm_hidden_size: int, num_layers: int
) -> nn.Module:
    """Build a bias-free MLP mapping ``input_hidden_size`` to ``lm_hidden_size``.

    The first layer is a Linear projection; every additional layer
    (``num_layers - 1`` of them) is a GELU followed by a square Linear.
    """
    layers = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
    layers += [
        layer
        for _ in range(num_layers - 1)
        for layer in (nn.GELU(), nn.Linear(lm_hidden_size, lm_hidden_size, bias=False))
    ]
    return nn.Sequential(*layers)


class _MLPVectorProjector(nn.Module):
    """A bank of ``width`` independent bias-free MLPs applied to the same input.

    Each MLP maps ``input_hidden_size`` -> ``lm_hidden_size`` (first a Linear,
    then ``num_layers - 1`` GELU + Linear pairs). The per-MLP outputs are
    concatenated along dim -2, so an input with a length-1 axis at dim -2
    yields ``width`` pseudo-tokens.
    """

    def __init__(
        self, input_hidden_size: int, lm_hidden_size: int, num_layers: int, width: int
    ):
        super().__init__()
        heads = []
        for _ in range(width):
            layers = [nn.Linear(input_hidden_size, lm_hidden_size, bias=False)]
            layers.extend(
                layer
                for _ in range(num_layers - 1)
                for layer in (
                    nn.GELU(),
                    nn.Linear(lm_hidden_size, lm_hidden_size, bias=False),
                )
            )
            heads.append(nn.Sequential(*layers))
        self.mlps = nn.ModuleList(heads)

    def forward(self, x):
        """Run every MLP on ``x`` and concatenate the results along dim -2."""
        per_head = [mlp(x) for mlp in self.mlps]
        return torch.cat(per_head, dim=-2)


def build_mlp_vector_projector(
    input_hidden_size: int, lm_hidden_size: int, num_layers: int, num_tokens: int
):
    """Create a projector that emits ``num_tokens`` pseudo-tokens per input vector.

    Thin factory over ``_MLPVectorProjector``; ``num_tokens`` is its ``width``.
    """
    projector = _MLPVectorProjector(
        input_hidden_size=input_hidden_size,
        lm_hidden_size=lm_hidden_size,
        num_layers=num_layers,
        width=num_tokens,
    )
    return projector

#### Load the projection model that we obtained from Step 1


#### Use the projector layer + Phi-2 model from step 1, and fine-tune the combined model on Q&A pairs instead of captions

In [None]:
from transformers import AutoModelForCausalLM

model_name = "microsoft/phi-2"
# torch_dtype is left at the library default (a float16 option was commented out).
phi2 = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
phi2 = phi2.to("cuda")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
class ImageWithPhiLayer(nn.Module):
    """Image-conditioned phi-2: projects image embeddings into phi-2's
    token-embedding space and runs the LM on [image tokens | Question+Answer tokens].

    The training loss covers only the answer part of each sample.
    """

    def __init__(self,
                 clip_emb: int = 512,
                 token_emb: int = 2560,
                 projection_n_tokens: int = 4,
                 projection_n_layers: int = 1,
                 phi2_model=None,
                ):
        """
        Args:
            clip_emb: size of each input image embedding.
            token_emb: hidden size of the language model's token embeddings.
            projection_n_tokens: number of pseudo-tokens the projector emits per image.
            projection_n_layers: depth of each projector MLP.
            phi2_model: causal LM to wrap. Defaults to the module-level ``phi2``
                loaded earlier in the notebook (the original behavior).
        """
        super().__init__()
        self.projection_n_tokens = projection_n_tokens
        self.ll1 = build_mlp_vector_projector(
            clip_emb, token_emb, projection_n_layers, self.projection_n_tokens).to("cuda")
        # Projector weights produced in step 1.
        self.ll1.load_state_dict(torch.load('proj_layer.pth'))
        model_name = "microsoft/phi-2"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.vocab_size = len(self.tokenizer)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.phi2Model = phi2 if phi2_model is None else phi2_model
        self.token_embedding = self.phi2Model.get_submodule('model.embed_tokens')
        # Freezing of the base LM is left to PEFT rather than done here.

    def generate_text_from_embeddings(self, logits):
        """Greedy-decode per-position predictions for display.

        Args:
            logits: (batch, seq, vocab) LM output logits.

        Returns:
            One decoded string per batch element (argmax token at every position).
        """
        # argmax over logits equals argmax over softmax(logits); skipping the softmax
        # avoids a full (batch, seq, vocab) temporary, and no_grad keeps it off the graph.
        with torch.no_grad():
            predicted_indices = logits.argmax(dim=-1)
        return [self.tokenizer.decode(seq) for seq in predicted_indices]

    def forward(self, x, QnAtokens, QTokenLength, Atokens):
        """Compute the answer-only LM loss and teacher-forced predictions.

        Args:
            x: image embeddings fed to the projector; the dataset in this notebook
                yields (batch, 1, clip_emb).
            QnAtokens: (batch, seq) token ids of "Question: ...Answer: ..." (padded).
            QTokenLength: per-sample number of question tokens (tensor or sequence).
            Atokens: unused; kept for call-site compatibility.

        Returns:
            (loss, predictions): the sum over the batch of per-sample mean
            cross-entropy on the answer tokens, and the decoded argmax text.
        """
        x = self.ll1(x)
        qna_embeddings = self.token_embedding(QnAtokens)

        inputs = torch.cat((x, qna_embeddings), dim=-2)
        logits = self.phi2Model(inputs_embeds=inputs).logits
        predictions = self.generate_text_from_embeddings(logits)

        per_sample_losses = []
        for i in range(logits.shape[0]):
            q_len = int(QTokenLength[i])
            # First answer-content token. The +2 skips the "Answer" ":" prefix tokens.
            # This assumes "Answer:" encodes to exactly two tokens; TODO confirm.
            answer_start = q_len + 2
            # Logits at position p predict input token p + 1. QnA token k sits at
            # input position n_proj + k, so it is predicted by logit n_proj + k - 1.
            first_logit = self.projection_n_tokens + answer_start - 1
            per_sample_losses.append(
                F.cross_entropy(
                    logits[i, first_logit:-1, :],
                    QnAtokens[i, answer_start:],
                    ignore_index=50256,  # padding id (pad_token is set to eos)
                )
            )
        loss = torch.stack(per_sample_losses).sum()

        return loss, predictions

#### Add adapters (for image and text modes) using PEFT

In [None]:
model = ImageWithPhiLayer(
    clip_emb=512,
    token_emb=2560,
    projection_n_tokens=4,
    projection_n_layers=1,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
import peft
from peft import LoraConfig

lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

# The trainable-parameter count printed with the previous target list (52,428,800)
# matches rank-64 LoRA on fc1/fc2 alone (assuming phi-2's 32 layers of width
# 2560/10240; verify). In other words, "Wqkv"/"out_proj" matched no module in the
# loaded Phi implementation. The list below names the attention projections of both
# the remote-code layout ("Wqkv", "out_proj") and the transformers-native layout
# ("q_proj", "k_proj", "v_proj", "dense"). PEFT only errors if none of the names match.
# Expect the printed trainable count to increase.
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    target_modules=[
        "Wqkv",
        "out_proj",
        "q_proj",
        "k_proj",
        "v_proj",
        "dense",
        "fc1",
        "fc2",
    ],
    # get_peft_model freezes every non-adapter parameter, including the image
    # projector `ll1` that this stage is meant to fine-tune. modules_to_save keeps
    # a trainable copy of it.
    modules_to_save=["ll1"],
)

In [None]:
import copy

# Move the wrapper model (including the phi-2 instance it references) to the GPU,
# then let PEFT inject adapters according to peft_config.
module = model.to('cuda')
peft_model = peft.get_peft_model(module, peft_config)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=module.parameters(), lr=1e-5)

peft_model.print_trainable_parameters()

trainable params: 52,428,800 || all params: 2,837,355,520 || trainable%: 1.8478050998698958


#### Get the dataset used to fine-tune Phi-2 + the projector layer: llava_instruct_150k

In [None]:
import json

# LLaVA-Instruct-150k conversations. JSON is UTF-8, so decode it explicitly rather
# than relying on the platform's default encoding.
instruct_dataset = './llava_instruct_150k.json'
with open(instruct_dataset, 'r', encoding='utf-8') as f:
    instruct_data = json.load(f)

In [None]:
# Peek at the first record: an image file name plus alternating human/gpt turns.
instruct_data[0]

{'id': '000000033471',
 'image': '000000033471.jpg',
 'conversations': [{'from': 'human',
   'value': '<image>\nWhat are the colors of the bus in the image?'},
  {'from': 'gpt', 'value': 'The bus in the image is white and red.'},
  {'from': 'human',
   'value': 'What feature can be seen on the back of the bus?'},
  {'from': 'gpt', 'value': 'The back of the bus features an advertisement.'},
  {'from': 'human',
   'value': 'Is the bus driving down the street or pulled off to the side?'},
  {'from': 'gpt',
   'value': 'The bus is driving down the street, which is crowded with people and other vehicles.'}]}

#### Create a custom dataset that consists of image embedding from CLIP with Q&A from 150k dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

class CustomTextDataset(Dataset):
    """Question/answer pairs from LLaVA-Instruct paired with precomputed image embeddings.

    Every human turn that is immediately followed by a gpt turn becomes one entry,
    provided its image has an embedding and the two turns together are at most
    512 characters long.

    Args:
        json_data: list of records with 'image' and 'conversations' keys.
        image_embedding_dict: mapping from image file name to its embedding.
        tokenizer: object exposing ``encode(text, add_special_tokens=...)`` and
            ``pad_token_id``.
        maxContext: fixed token length every returned sequence is padded or
            truncated to.
    """

    def __init__(self, json_data, image_embedding_dict,  tokenizer, maxContext=1024):
        self.image_embedding_dict = image_embedding_dict
        self.tokenizer = tokenizer
        self.json_data = json_data
        self.maxContext = maxContext

        self.entries = []
        for entry in json_data:
            image = entry['image']
            image_embedding = self.getEmbeddingForImage(image)
            if image_embedding is None:
                continue

            conversations = entry['conversations']
            # Pair each turn with the one after it. A trailing human turn has no
            # reply and is skipped instead of raising IndexError.
            for turn, reply in zip(conversations, conversations[1:]):
                if turn['from'] != 'human' or reply['from'] != 'gpt':
                    continue
                if len(turn['value'] + reply['value']) > 512:
                    continue
                # Remove the '<image>' placeholder wherever it appears. The image is
                # supplied as an embedding. str.lstrip('<image>\n') strips a
                # *character set*, so it also ate leading letters such as "ma" in "many".
                question = 'Question: ' + turn['value'].replace('<image>', '').strip()
                answer = 'Answer: ' + reply['value']
                self.entries.append({
                    'image_name': image,
                    'image_embedding': image_embedding,
                    'Question': question,
                    'Answer': answer,
                    'QnAText': question + answer
                    })
        print('------------- num entries = -----------------')
        print(len(self.entries))

    def getEmbeddingForImage(self, image):
        """Return the embedding for ``image``, or None if it has none."""
        return self.image_embedding_dict.get(image)

    def __len__(self):
        return len(self.entries)

    def _encode(self, text):
        """Tokenize ``text`` and truncate it to ``maxContext`` tokens."""
        return self.tokenizer.encode(text, add_special_tokens=True)[:self.maxContext]

    def _pad(self, tokens):
        """Right-pad ``tokens`` with the pad id to ``maxContext`` and return a tensor."""
        padding = [self.tokenizer.pad_token_id] * (self.maxContext - len(tokens))
        return torch.tensor(tokens + padding)

    def __getitem__(self, idx):
        """Return the text, embedding and fixed-length token tensors for one entry."""
        entry = self.entries[idx]
        # Truncation keeps every tensor exactly maxContext long, which the default
        # collate function needs in order to stack a batch.
        q_tokens = self._encode(entry['Question'])
        a_tokens = self._encode(entry['Answer'])
        qna_tokens = self._encode(entry['QnAText'])

        return {'image_name': entry['image_name'],
                'QText': entry['Question'],
                'AText': entry['Answer'],
                'image_embedding':  entry['image_embedding'],
                'QTokens': self._pad(q_tokens),
                'ATokens': self._pad(a_tokens),
                'QnA_tokens': self._pad(qna_tokens),
                'QTokensLength': len(q_tokens)
               }



In [None]:
# Precomputed CLIP image embeddings. Insert a length-1 axis at dim 1 so each
# image becomes a one-token "sequence".
img_emb = torch.load("img_embeddings.pth")
img_emb = img_emb.unsqueeze(dim=1)
print(img_emb.shape)

torch.Size([118287, 1, 512])


In [None]:
with open("./image_names.json", 'r', encoding='utf-8') as file:
    image_names = json.load(file)
# A count mismatch means the names file and the embeddings file are out of sync.
# zip() would silently drop the extras, so fail loudly instead.
if len(image_names) != len(img_emb):
    raise ValueError(
        f"image_names.json has {len(image_names)} names but img_embeddings.pth "
        f"has {len(img_emb)} embeddings"
    )
imgEmbDict = dict(zip(image_names, img_emb))

In [None]:
model_name = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token

custom_dataset = CustomTextDataset(
    json_data=instruct_data,
    image_embedding_dict=imgEmbDict,
    tokenizer=tokenizer,
)
custom_dataloader = DataLoader(dataset=custom_dataset, batch_size=4, shuffle=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


------------- num entries = -----------------
225484


#### Train - finetune the projector layer + Phi-2 peft model with QnA dataset

In [None]:
# Training loop
num_epochs = 200
# Only parameters that still require grad (e.g. the LoRA adapters injected by PEFT)
# can be updated. Frozen ones would just sit in the optimizer.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-5
)

for epoch in range(num_epochs):
    model.train()
    for ix, batch in enumerate(custom_dataloader):
        embeddings = batch['image_embedding'].to('cuda')
        QnAtokens = batch['QnA_tokens'].to('cuda')
        QTokenLength = batch['QTokensLength'].to('cuda')
        ATokens = batch['ATokens'].to('cuda')

        loss, predictions = peft_model(embeddings, QnAtokens, QTokenLength, ATokens)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if ix % 10 == 0:
            print("------------captions text = -------------------")
            print(batch['QText'][0])
            print(batch['AText'][0])
            print("------------Teacher forced predictions text = -------------------")
            # str.rstrip() treats its argument as a character set, not a suffix, so
            # rstrip('<|endoftext|>') also removed trailing letters such as 'e', 'd', 't'.
            # Remove the EOS marker explicitly, then trim trailing whitespace.
            print(predictions[0].replace('<|endoftext|>', '').rstrip()[:200])
            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")
    # Checkpoint once per epoch (overwrites the previous file).
    torch.save(
        peft_model.state_dict(),
        "stage_2_model.pth"
    )

------------captions text = -------------------
Question: Describe the color and type of backpack the baseball player is carrying.
<image>
Answer: The baseball player is carrying a pink feminine backpack while walking onto the field.
------------Teacher forced predictions text = -------------------
-,. a mark aending the scene of shape of the. girl player is wearing. A| src
: A baseball player is carrying a red backpack backpack. wearing on the field. 












































Epoch 1/200, Loss: 30.28024673461914
------------captions text = -------------------
Question: Where is the garbage truck in relation to the garbage cans?
Answer: The garbage truck is near two garbage cans on the side of the road, possibly picking up trash from those cans at the curb.
------------Teacher forced predictions text = -------------------

,. a: a a the cat can going the to the house can? : The garbage truck is in the garbage cans. the street of the road. but in up trash. the cans. t

KeyboardInterrupt: 

In [None]:
# Persist the current peft_model weights to stage_2_model.pth.
torch.save(obj=peft_model.state_dict(), f="stage_2_model.pth")