In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
import torch
from transformers import  AutoTokenizer

model_name = "microsoft/phi-2"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

from transformers import AutoModelForCausalLM, AutoConfig

model_name = "microsoft/phi-2"

config = AutoConfig.from_pretrained(
    model_name,
    vocab_size=len(tokenizer),
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    trust_remote_code=True
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
model_name = "microsoft/phi-2"
# phi2_model_pretrained = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     trust_remote_code=True,  
# )
phi2_model_pretrained = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

### Create dataset

In [4]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd 
import json
import os 
import h5py


In [5]:
def get_image_name(image_id_from_caption, list_image_info): 
    for img in list_image_info: 
        if img['id'] == image_id_from_caption: 
            img_name = img['file_name'].split('.')[0]
            return img['file_name'].split('.')[0]
    return 'NoImgNameFound'

In [6]:
# file_path_captions_coco = '/media/App/amaranth/lavanya/Capstone_data/annotations_trainval2017/annotations/captions_train2017.json'

# with open(file_path_captions_coco) as f:
#    data = json.load(f)

# captions_info = []
# for a in data['annotations']: 
#     captions_info.append([a['image_id'], a['caption'], a['id']])

# captions_info_df = pd.DataFrame(data=captions_info, columns=['image_id', 'caption', 'caption_id'])
# captions_info_df['image_name'] = captions_info_df['image_id'].apply(lambda x: get_image_name(x, data['images']))
# captions_info_df['image_name'] = captions_info_df['image_name'].apply(lambda x: '0'*(12-len(str(x))) + str(x))
# captions_info_df.to_csv('captions_images_map_COCO_train2017.csv', index = False)

In [7]:
captions_info_df = pd.read_csv('captions_images_map_COCO_train2017.csv')

  captions_info_df = pd.read_csv('captions_images_map_COCO_train2017.csv')


In [8]:
import h5py    
import numpy as np    

In [9]:
class COCO_CLIP_Dataset(Dataset):

    def __init__(
        self, caption_file, embedding_path, tokenizer, max_token_len_data):
        self.embedding_path = embedding_path
        self.caption_file = caption_file
        self.tokenizer = tokenizer
        self.max_token_len_data = max_token_len_data

    def __len__(self):
        return len(self.caption_file)
    
    def __getitem__(self, index):
        row = self.caption_file.iloc[[index]]
        df_img = row['image_name'].values[0]
        img_base_name = '0'*(12-len(str(df_img))) + str(df_img)
        img_base_name = img_base_name.replace(' ', '0')
        img_clip_embedding_path = os.path.join(self.embedding_path, f'{img_base_name}.h5')

        np_array_embed_img = h5py.File(img_clip_embedding_path,'r+')['image_features'][()]
        
        img_caption = row['caption'].values[0] ## Tokenize this 
        img_caption_tokenized = self.tokenizer(img_caption, return_tensors="pt", 
                                               return_attention_mask=False).input_ids

        ## put bos, eos, and padding for batch 
        input_bos = torch.cat((torch.tensor(self.tokenizer.bos_token_id).view((1,1)), 
                                                       img_caption_tokenized), dim=1)
        input_eos = torch.cat((input_bos, 
                               torch.tensor(self.tokenizer.eos_token_id).view((1,1))), dim=1)
        
        if (self.max_token_len_data - input_eos.shape[1]) > 0: 
            input_final =  torch.cat((input_eos,torch.tensor([self.tokenizer.pad_token_id]*(self.max_token_len_data - input_eos.shape[1])).unsqueeze(0)), dim=1)
        else: 
            input_final = input_eos
        
        return torch.tensor(np_array_embed_img).squeeze(0), input_final.squeeze(0)

In [10]:
max_token_len_data = 75
dataset = COCO_CLIP_Dataset(captions_info_df, 
                            '/media/App/amaranth/lavanya/Capstone_data/clip_features_base_patch32/', 
                            tokenizer, max_token_len_data)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleResBlock(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
    
class Phi2wrapper(nn.Module):
    
    #This defines the structure of the NN.
    def __init__(self, input_dim_CLIP=768, input_dim_phi2=2560, 
                 phi2_model=phi2_model_pretrained, 
                 max_token_len_data=max_token_len_data, tokenizer=tokenizer):
        
        super(Phi2wrapper, self).__init__()

        self.input_dim_CLIP = input_dim_CLIP
        self.input_dim_phi2 = input_dim_phi2
        self.projection_img = nn.Linear(self.input_dim_CLIP, self.input_dim_phi2, 
                                        bias=False)
        self.resblock = SimpleResBlock(self.input_dim_phi2)
        self.phi2_model = phi2_model
        self.max_token_len_data = max_token_len_data
        self.tokenizer = tokenizer

    def forward(self, x):

        x = self.projection_img(x)
        x = self.resblock(x)

        # x = self.phi2_model.forward(inputs_embeds=x)
        x = self.phi2_model.generate(inputs_embeds=x, 
                                     max_new_tokens=self.max_token_len_data, 
                                     output_scores=True, return_dict_in_generate = True, 
                                     pad_token_id=self.tokenizer.eos_token_id)

        # x = self.phi2_model.model.layers[0](x)
        # for layer_idx in range(1, 32): 
        #     x = self.phi2_model.model.layers[layer_idx](x[0])
                
        # x = self.phi2_model.model.final_layernorm(x[0])
        # x = self.phi2_model.lm_head(x)
        
        return x 

device = 'cuda:1'
torch.set_grad_enabled(True)  
phi2_projection_model = Phi2wrapper().to(device=device)

## Freezing phi-2 for projection layer training 
for name, param in phi2_projection_model.named_parameters():
    if "phi2_model" in name:
        param.requires_grad = False
    else: 
        param.requires_grad = True

In [12]:
batch_size_train = 4
train_dataloader = DataLoader(dataset, batch_size=batch_size_train, shuffle=True)    

In [13]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, 
                                    phi2_projection_model.parameters()), 
                            lr=1e-5, eps=1e-9) 


In [14]:
num_epochs = 10
vocab_size = 50295

phi2_projection_model.train()

for epoch in range(num_epochs):

    print(f"working on epoch {epoch}")

    loss_epoch = 0 
    for batch in train_dataloader:

        optimizer.zero_grad()

        input = batch[0]
        gt = batch[1] 
        output = phi2_projection_model(input.to(device))

        ## need to map gt token_ids to one-hot enocding vocab_size
        gt_one_hot = torch.nn.functional.one_hot(gt, vocab_size).to(torch.float32)

        ## output in correct shape
        output_tensor_new = torch.empty(batch_size_train, max_token_len_data, vocab_size)
        for idx, s in enumerate(output.scores): 
            output_tensor_new[:, idx, :] = s

        ## ce loss between output_tensor_new and gt_one_hot
        loss = F.cross_entropy(output_tensor_new, gt_one_hot)

        loss.requires_grad = True
        loss.backward()

        optimizer.step() 
        # optimizer.zero_grad(set_to_none=True) 

        ## print gt and output decoded tokens for visual inspection for the 1st el of batch
        gt_idx_0_decoded = tokenizer.decode(gt[0]).replace('<|endoftext|>', '')
        output_idx_0_decoded = tokenizer.decode(torch.argmax(output_tensor_new[0], dim=1)).replace('<|endoftext|>', '')

        ## print loss 
        print(f"Loss: {loss}\nCaption (gt): {gt_idx_0_decoded}\nCaption (pred): {output_idx_0_decoded}")

    print(f"Epoch {epoch} finished")
    print("")

working on epoch 0
Loss: 0.006729808170348406
Caption (gt): An all-way stop sign next to a red light
Caption (pred):  Ob Strategyjahumbles Emb chem Mir ank contributing concealking Knot031 butterflies FOX996 voluntesilver aber moderatedozenika bettingModLoader interfere envisioned chem envisioned imm]"etc suc031 afford tin Over HTTPS lifetime colours Kul desperate populated150 Memorial suc bould illumvelengthdozen Reggie sax Meow Perez996 afford733king tomb envisioneddozenidges996 morphology buffers031ments envisioned teaching996 gy996 upvelength sher Ob
Loss: 0.006720697041600943
Caption (gt): a couple of chairs in front of a kitchen counter
Caption (pred): dozen saxbrow condensed introductory commonly exhaust Emb Emblish becoming disenfranchging EmblishEasytv disruptive Emb prin introductory Probably rooms suc bould Rik Coy rooms Myree illum Coy rooms sucSynopsisropri Supreme AAClish Dragonbound ape loosen exhaust ENT Rik k Dragonboundsilver Skiptv bowling resultant costumes unless E