In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
import torch
from transformers import  AutoTokenizer
from transformers import AutoModelForCausalLM, AutoConfig

device = 'cuda:1'

model_name = "microsoft/phi-2"
phi2_model_pretrained = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,  
    torch_dtype = torch.float16
)

phi2_model_pretrained.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
special_tokens_dict = {'pad_token': '<|PAD|>', 'bos_token': '<|BOS|>'}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
phi2_model_pretrained.resize_token_embeddings(len(tokenizer))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Embedding(50297, 2560)

### Create dataset

In [3]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd 
import json
import os 
import h5py

In [4]:
def get_image_name(image_id_from_caption, list_image_info): 
    for img in list_image_info: 
        if img['id'] == image_id_from_caption: 
            img_name = img['file_name'].split('.')[0]
            return img['file_name'].split('.')[0]
    return 'NoImgNameFound'

In [5]:
# file_path_captions_coco = '/media/App/amaranth/lavanya/Capstone_data/annotations_trainval2017/annotations/captions_train2017.json'

# with open(file_path_captions_coco) as f:
#    data = json.load(f)

# captions_info = []
# for a in data['annotations']: 
#     captions_info.append([a['image_id'], a['caption'], a['id']])

# captions_info_df = pd.DataFrame(data=captions_info, columns=['image_id', 'caption', 'caption_id'])
# captions_info_df['image_name'] = captions_info_df['image_id'].apply(lambda x: get_image_name(x, data['images']))
# captions_info_df['image_name'] = captions_info_df['image_name'].apply(lambda x: '0'*(12-len(str(x))) + str(x))
# captions_info_df.to_csv('captions_images_map_COCO_train2017.csv', index = False)

In [6]:
captions_info_df = pd.read_csv('captions_images_map_COCO_train2017.csv')

  captions_info_df = pd.read_csv('captions_images_map_COCO_train2017.csv')


In [7]:
import h5py    
import numpy as np    

In [8]:
class COCO_CLIP_Dataset(Dataset):

    def __init__(
        self, caption_file, embedding_path, tokenizer, max_token_len_data):
        
        self.embedding_path = embedding_path
        self.caption_file = caption_file
        self.tokenizer = tokenizer
        self.max_token_len_data = max_token_len_data

    def __len__(self):
        return len(self.caption_file)
    
    def __getitem__(self, index):

        row = self.caption_file.iloc[[index]]

        df_img = row['image_id'].values[0]
        img_base_name = '0'*(12-len(str(df_img))) + str(df_img)
        img_base_name = img_base_name.replace(' ', '0')
        img_clip_embedding_path = os.path.join(self.embedding_path, f'{img_base_name}.h5')

        np_array_embed_img = h5py.File(img_clip_embedding_path,'r+')['image_features'][()]
        
        img_caption = row['caption'].values[0] ## Tokenize this 
        img_caption_tokenized = self.tokenizer(img_caption, return_tensors="pt", 
                                               return_attention_mask=False).input_ids

        ## put bos, eos, and padding for batch         
        # input_bos = torch.cat((torch.tensor(self.tokenizer.bos_token_id).view((1,1)), 
        #                                                img_caption_tokenized), dim=1)

        input_bos = img_caption_tokenized

        input_eos = torch.cat((input_bos, 
                               torch.tensor(self.tokenizer.eos_token_id).view((1,1))), dim=1)
        
        if (self.max_token_len_data - input_eos.shape[1]) > 0: 
            input_final =  torch.cat((input_eos,torch.tensor([self.tokenizer.pad_token_id]*(self.max_token_len_data - input_eos.shape[1])).unsqueeze(0)), dim=1)
        else: 
            input_final = input_eos
        
        return torch.tensor(np_array_embed_img).squeeze(0), input_final.squeeze(0)

In [9]:
def file_exists(image_id, fpath = '/media/App/amaranth/lavanya/Capstone_data/clip_features_base_patch32/'): 

    n = '0'*(12-len(str(image_id))) + str(image_id) + '.h5'
    fp = os.path.join(fpath, n)

    if os.path.exists(fp): 
        return True
    else: 
        return False

In [10]:
### captions_info_df contains for 1 image multiple entries, lets reduce keeping one image, one entry. 
captions_info_df_subset = captions_info_df.drop_duplicates(subset='image_id', keep='first')

In [11]:
max_token_len_data = 75
dataset = COCO_CLIP_Dataset(captions_info_df_subset, 
                            '/media/App/amaranth/lavanya/Capstone_data/clip_features_base_patch32/', 
                            tokenizer, max_token_len_data)

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleResBlock(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )
    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
    
class Phi2wrapper(nn.Module):
    
    #This defines the structure of the NN.
    def __init__(self, input_dim_CLIP=768, input_dim_phi2=2560, 
                 phi2_model=phi2_model_pretrained, 
                 max_token_len_data=max_token_len_data, tokenizer=tokenizer, device=device):
        
        super(Phi2wrapper, self).__init__()

        self.input_dim_CLIP = input_dim_CLIP
        self.input_dim_phi2 = input_dim_phi2
        self.projection_img = nn.Linear(self.input_dim_CLIP, self.input_dim_phi2, 
                                        bias=False)
        self.resblock = SimpleResBlock(self.input_dim_phi2)
        self.phi2_model = phi2_model
        self.max_token_len_data = max_token_len_data
        self.tokenizer = tokenizer

        self.device = device

        bos = self.tokenizer("Image: ", return_tensors="pt", return_attention_mask=False)
        eoi = self.tokenizer(" Caption: ", return_tensors="pt", return_attention_mask=False)
    
        self.bos_embedding = self.phi2_model.get_input_embeddings()(bos.input_ids.to(self.device)).squeeze(0)
        self.eoi_embedding = self.phi2_model.get_input_embeddings()(eoi.input_ids.to(self.device)).squeeze(0)
        self.eos_embedding = self.phi2_model.get_input_embeddings()(torch.tensor(self.tokenizer.eos_token_id).to(self.device)).unsqueeze(0)

    def forward(self, x, input_caption):

        x = self.projection_img(x)
        x = self.resblock(x)

        batch_size = x.shape[0]

#         imgae_prompt_embed = x.clone()
        x = torch.cat((self.bos_embedding.repeat(batch_size,1,1), x, 
                    self.eoi_embedding.repeat(batch_size,1,1)), dim=1)
        
        loss = 0 
        word_output_pred_tokens = None

        for idx in range(input_caption.shape[1]): 
            
#             next_word = self.phi2_model.forward(inputs_embeds=x.to(torch.float16))["logits"][:, -1, :]
            
            next_word = self.phi2_model.generate(inputs_embeds=x.to(torch.float16), max_new_tokens = 1, 
                                            output_scores=True, return_dict_in_generate = True, 
                                            pad_token_id=self.tokenizer.pad_token_id, 
                                            bos_token_id=self.tokenizer.bos_token_id, 
                                            eos_token_id=self.tokenizer.eos_token_id) ## this gives first word  
                        
            caption_word_token = input_caption[:,idx]
            
            no_of_pad_tokens = sum(torch.eq(torch.tensor([self.tokenizer.pad_token_id]*batch_size).to(self.device), caption_word_token))
            if no_of_pad_tokens == torch.tensor(batch_size): 
                break 
            
            caption_word_embedding = self.phi2_model.get_input_embeddings()(caption_word_token).unsqueeze(1)
            
            ## instead of append like instruct image output words.. instruct image w1 out, instruct image w2 output ..
            x = torch.cat((x, caption_word_embedding), dim=1)

#             caption_word_token_new = input_caption[:, :idx+1]
#             caption_word_embedding_new = self.phi2_model.get_input_embeddings()(caption_word_token_new)
#             x = torch.cat((self.bos_embedding.repeat(batch_size,1,1), imgae_prompt_embed, caption_word_embedding_new,  
#                     self.eoi_embedding.repeat(batch_size,1,1)), dim=1)

#             loss_val = F.cross_entropy(F.softmax(next_word, dim=-1), caption_word_token, 
#                         ignore_index=self.tokenizer.pad_token_id, label_smoothing=0.1)

            loss_val = F.cross_entropy(F.softmax(next_word.scores[0], dim=-1), caption_word_token, 
                        ignore_index=self.tokenizer.pad_token_id, label_smoothing=0.1)

            loss += loss_val
            
#             if word_output_pred_tokens is None: 
#                 word_output_pred_tokens = torch.argmax(next_word,dim=-1).unsqueeze(1) 
#             else:

#                 word_output_pred_tokens = torch.cat((word_output_pred_tokens, torch.argmax(next_word,dim=-1).unsqueeze(1)), dim=1)
    
            if word_output_pred_tokens is None:
                word_output_pred_tokens = next_word.sequences[:, 1].unsqueeze(1)
            else:
                word_output_pred_tokens = torch.cat((word_output_pred_tokens, next_word.sequences[:, 1].unsqueeze(1)), dim=1)
    
        loss_tosend = loss/idx

        return loss_tosend, word_output_pred_tokens

        ### Without feature forcing
        # x = self.phi2_model.generate(inputs_embeds=x, 
        #                              max_new_tokens=self.max_token_len_data, 
        #                              output_scores=True, return_dict_in_generate = True, 
        #                              pad_token_id=self.tokenizer.eos_token_id, 
        #                              bos_token_id=self.tokenizer.bos_token_id, 
        #                              eos_token_id=self.tokenizer.eos_token_id)

        # return x 

torch.set_grad_enabled(True)  
phi2_projection_model = Phi2wrapper().to(device=device)

## Freezing phi-2 for projection layer training 
for name, param in phi2_projection_model.named_parameters():
    if "phi2_model" in name:
        param.requires_grad = False
    else: 
        param.requires_grad = True

In [13]:
batch_size_train = 32
train_dataloader = DataLoader(dataset, batch_size=batch_size_train, shuffle=True, num_workers=8)

num_batches_train_on = 1500  
num_batches_train_on, len(train_dataloader)

(1500, 3697)

In [14]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, 
                                    phi2_projection_model.parameters()), 
                            lr=1e-3, eps=1e-9) 

In [None]:
num_epochs = 10
vocab_size = len(tokenizer)

phi2_projection_model.train()
N_batches = len(train_dataloader)
                
for epoch in range(num_epochs):

    print(f"Working on epoch {epoch}")

    for iteration, batch in enumerate(train_dataloader):

        if iteration == num_batches_train_on: 
            break 

        print(f"Iteration {iteration}/{num_batches_train_on}", end='\r')

        optimizer.zero_grad()

        input_ = batch[0]
        gt = batch[1] 

        loss, output_pred_tokens = phi2_projection_model(input_.to(device), gt.to(device))

        loss.requires_grad = True
        loss.backward()

        optimizer.step()

        if (iteration % 1) == 0: 
            print("Loss:", loss)
            print("Predictions:", tokenizer.batch_decode(output_pred_tokens)[0].rstrip())
            print("Gt:", tokenizer.batch_decode(gt)[0].split('<|endoftext|>')[0])

    print("")
    print(f"Epoch {epoch} finished")
    print("")

Working on epoch 0
Loss: tensor(10.6511, device='cuda:1', requires_grad=True)
Predictions: 
 man sign. a to a intersection billboard light.
Theance
.


 and
Gt: A stop sign is next to an electronic traffic signal.
Loss: tensor(10.6895, device='cuda:1', requires_grad=True)
Predictions: 
.'s in a beach court. a tennisage on her finger hand.The of
Gt: a woman standing on a tennis court with a bandage on her left knee
Loss: tensor(10.7097, device='cuda:1', requires_grad=True)
Predictions: 
  number bear. in. the lake. the zoo in
Theald:
Gt: A large polar bear swimming underwater in a cage at a zoo.
Loss: tensor(10.5945, device='cuda:1', requires_grad=True)
Predictions: 
 man- white photograph of a man, a beach.
The.:. to.
.,

...
Gt: A black and white photo of a person on the street.
Loss: tensor(10.6653, device='cuda:1', requires_grad=True)
Predictions: 
 is many ways in the road light. One
,:g the one
, and
 of
Gt: There are two signs on the traffic signal.
Loss: tensor(10.6599, device='

Loss: tensor(10.6494, device='cuda:1', requires_grad=True)
Predictions: 
 man who into a cliff. a cliffboard.The
:.r..
Gt: A person jumping off a ledge on a skateboard
Loss: tensor(10.6863, device='cuda:1', requires_grad=True)
Predictions: 
 man signikie with a and and with food types of food.
The:
Gt: A wooden tabled topped with blue plates filled with different types of cakes.
Loss: tensor(10.6859, device='cuda:1', requires_grad=True)
Predictions: 
 is man isfer is a wave wave inThet the.. the.. the
Gt: This young male surfer rides a small wave
Loss: tensor(10.6563, device='cuda:1', requires_grad=True)
Predictions: 
 is a sal of words in on a branch.The
ation
yt
Gt: this is a pair of birds sitting on a log
Loss: tensor(10.6993, device='cuda:1', requires_grad=True)
Predictions: 
hest- employee player in the team roster for front of W empty of the.
Gt: Highest paid baseball players on the Cubs pose in front of an image of cash.
Loss: tensor(10.6758, device='cuda:1', requires_grad=True)

Loss: tensor(10.6795, device='cuda:1', requires_grad=True)
Predictions: 
  is 8, pepper, pineapple. a. 

:: the.
Gt: A pizza contains sausage, broccoli and cheese within it.    
Loss: tensor(10.6845, device='cuda:1', requires_grad=True)
Predictions: 
. man in a bookboard on her head as


 to the . man the
Gt: a young girl balances a surfboard on her head 
Loss: tensor(10.6594, device='cuda:1', requires_grad=True)
Predictions: 
  is her her hair in a bathroom.

:


..
Gt: A woman blow drying her hair in a bathroom.
Loss: tensor(10.6754, device='cuda:1', requires_grad=True)
Predictions: 
 man- with a and a vegetables.

ent...:

.
Gt: A white plate with broccoli and other vegetables.
Loss: tensor(10.6894, device='cuda:1', requires_grad=True)
Predictions: 
 man at is a of a salad bag.

lye
ed.
Gt: A BBQ sandwich on top of a plastic container.
Loss: tensor(10.6879, device='cuda:1', requires_grad=True)
Predictions: 
 b's on a table with eating her a cake.

: the
Gt: a woman sitting at a tabl

In [None]:
'''
TODO: 1. Tensorboard, 
2. Flash attention, 
3. Phi2 (float16), 
4.model.forwrd?, 
5.eos_token to inputprompt, 
6.Onecycle policy
7.Smaller ''' 