In [None]:
# Authenticate this notebook session with the Hugging Face Hub (prompts for an access token).
from huggingface_hub import notebook_login

notebook_login()

In [None]:
%pip install -qq -U datasets transformers pyarrow
%pip install -qq --upgrade transformers ftfy accelerate regex tqdm
%pip install git+https://github.com/openai/CLIP.git


**All the imports**

In [None]:
import os
import torch
import pickle
import json
import torch.nn as nn
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from tqdm import tqdm
from pathlib import Path

**Model definition**

First, the projection layers that map image embeddings into the language model's embedding space.

In [None]:
class IdentityMap(nn.Module):
    """Pass-through projector: ``forward`` hands back its first argument untouched.

    Any additional positional or keyword arguments are accepted and discarded so
    this module can stand in for projectors with richer call signatures.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        del args, kwargs  # accepted only for signature compatibility
        return x

    @property
    def config(self):
        """Dictionary describing this projector type."""
        return dict(mm_projector_type='identity')


class SimpleResBlock(nn.Module):
    """Pre-norm residual MLP block computing ``y = LN(x) + MLP(LN(x))``.

    Args:
        in_channels: size of the input's last dimension (normalized by LayerNorm).
        out_channels: size of the output's last dimension. The residual sum
            requires ``in_channels == out_channels``.
        hidden_size: width of the MLP's hidden layer. Defaults to ``out_channels``,
            which reproduces the original layer shapes exactly.
    """
    def __init__(self, in_channels, out_channels, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            hidden_size = out_channels
        self.pre_norm = nn.LayerNorm(in_channels)

        self.proj = nn.Sequential(
            nn.Linear(in_channels, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, out_channels)
        )

    def forward(self, x):
        # The residual is taken around the *normalized* input, not the raw one.
        x = self.pre_norm(x)
        return x + self.proj(x)


class SimpleLinearBlock(nn.Module):
    """Two-layer GELU MLP projection ``in_size -> hidden_size -> out_size``.

    Args:
        in_size: size of the input's last dimension.
        out_size: size of the output's last dimension.
        hidden_size: width of the hidden layer.
        add_residual_connection: when true, the input is added to the MLP output,
            but only where that sum is shape-valid (``in_size == out_size``). For
            non-square projections the MLP output is returned unchanged.
    """
    def __init__(self, in_size, out_size, hidden_size = 50, add_residual_connection=True):
        super().__init__()
        # Kept so parameter names / state_dict stay stable, but deliberately not
        # applied in forward(): LayerNorm over a 1-wide feature dimension maps every
        # input to its bias, which would erase the signal when in_size == 1.
        self.pre_norm = nn.LayerNorm(in_size)
        self.proj = nn.Sequential(nn.Linear(in_size, hidden_size),
                                  nn.GELU(),
                                  nn.Linear(hidden_size, out_size))
        self.add_residual_connection = add_residual_connection

    def forward(self, x):
        out = self.proj(x)
        # Honour the residual flag only when input and output shapes agree.
        if self.add_residual_connection and out.shape == x.shape:
            return x + out
        return out


def build_resnet_projection_layer(in_channels, out_channels, hidden_size = 50, mlp_depth=2):
    """Build a projector: ``Linear(in_channels -> out_channels)`` followed by
    ``mlp_depth - 1`` independent ``SimpleResBlock(out_channels, out_channels)``.

    Args:
        in_channels: size of the input's last dimension.
        out_channels: size of the output's last dimension.
        hidden_size: accepted for API compatibility; currently ignored.
        mlp_depth: total number of stages (the initial Linear counts as one).
            Values below 2 yield just the Linear layer.

    Returns:
        nn.Sequential holding the stages in order.
    """
    # Project to the output width first so every residual block is
    # shape-preserving (SimpleResBlock adds its normalized input to its output).
    modules = [nn.Linear(in_channels, out_channels)]
    for _ in range(1, mlp_depth):
        # A fresh block per stage; re-appending one instance would share weights.
        modules.append(SimpleResBlock(out_channels, out_channels))
    return nn.Sequential(*modules)

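Next, the multimodal model: it projects the image embeddings and passes them to the frozen LLM as input embeddings for generation.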

In [None]:
class MultiModalGPT(nn.Module):
    """Frozen causal language model fed by a trainable projection of input features.

    ``forward`` maps its input through ``projection_layer`` (a SimpleLinearBlock of
    width ``projection_layer_in_channels -> projection_layer_out_channels``) and
    hands the result to ``llm_model.generate`` as ``inputs_embeds``. The output
    width presumably has to match the LLM's embedding size (TODO confirm).

    Args:
        llm_model: model exposing ``parameters()`` and ``generate(inputs_embeds=...)``;
            every one of its parameters is frozen here.
        tokenizer: stored on the module for callers; not used internally.
        projection_layer_in_channels: last-dim size of the input to ``forward``.
        projection_layer_out_channels: last-dim size produced by the projection.
        device: stored as ``self.device``; the module is not moved here, so call
            ``.to(device)`` separately.
        hidden_size: hidden width of the projection MLP.
    """
    def __init__(self,
                 llm_model,
                 tokenizer,
                 projection_layer_in_channels,
                 projection_layer_out_channels,
                 device,
                 hidden_size = 32,
                 ):
        super().__init__()
        self.llm_model = llm_model
        # Freeze the LLM so only the projection layer's parameters stay trainable.
        for param in self.llm_model.parameters():
            param.requires_grad = False
        self.tokenizer = tokenizer
        self.projection_layer = SimpleLinearBlock(projection_layer_in_channels,
                                                  projection_layer_out_channels,
                                                  hidden_size=hidden_size)
        self.device = device

    def forward(self, x, max_length=1, **generate_kwargs):
        """Project ``x`` and run generation on the frozen LLM.

        Args:
            x: input features whose last dimension is ``projection_layer_in_channels``.
            max_length: forwarded to ``llm_model.generate``.
            **generate_kwargs: extra keyword arguments forwarded unchanged to
                ``llm_model.generate`` (e.g. ``max_new_tokens``, ``attention_mask``).

        Returns:
            Whatever ``llm_model.generate`` returns (token ids for an HF causal LM).
        """
        x = self.projection_layer(x)
        # Generation runs without autograd, so no gradient reaches projection_layer
        # through this path: forward() is inference-only as written.
        with torch.no_grad():
            x = self.llm_model.generate(inputs_embeds=x, max_length=max_length, **generate_kwargs)
        return x


**Data loader**

In [None]:
def get_absolute_paths(directory_path, max_files = None):
    """Recursively list the files under ``directory_path``.

    Parameters:
    - directory_path (str): root directory, walked recursively with os.walk.
    - max_files (int | None): stop after collecting exactly this many files;
      None means no limit.

    Returns:
    - (absolute_paths, image_ids): two parallel lists. image_ids holds each file
      name without its final suffix (Path.stem). Both lists are empty when
      directory_path is not a directory.
    """
    absolute_paths = []
    image_ids = []

    # A missing / non-directory path yields empty lists rather than an error.
    if not os.path.isdir(directory_path):
        return absolute_paths, image_ids

    for root, _, files in os.walk(directory_path):
        for file in tqdm(files):
            # Check *before* appending so exactly max_files entries are kept, and
            # return (not just break) so the limit also holds across subdirectories.
            if max_files is not None and len(absolute_paths) >= max_files:
                return absolute_paths, image_ids
            image_ids.append(Path(file).stem)
            absolute_paths.append(os.path.abspath(os.path.join(root, file)))
    return absolute_paths, image_ids


def parse_captions_file(captions_path, captions_key):
    """Read a captions JSON file and index its captions by image id.

    Parameters:
    - captions_path (str): path to the JSON file (decoded as UTF-8).
    - captions_key (str): top-level key holding the list of annotation dicts;
      each annotation must provide 'image_id' and 'caption'.

    Returns:
    - dict: maps each annotation's 'image_id' to its 'caption'. When several
      annotations share an image id, the one appearing last in the file wins.
    - None if the file does not exist or is not valid JSON (a message is printed).

    Raises:
    - KeyError if captions_key or a per-annotation field is missing.
    """
    try:
        # JSON is UTF-8; don't depend on the platform's default locale encoding.
        with open(captions_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except FileNotFoundError:
        print(f"Error: File not found - {captions_path}")
        return None
    except json.JSONDecodeError:
        print(f"Error: Unable to decode JSON in file - {captions_path}")
        return None
    return {annotation['image_id']: annotation['caption'] for annotation in data[captions_key]}

        
def load_pickle_file(file_path):
    """Unpickle a mapping that holds exactly one entry and return that entry's value.

    NOTE(security): pickle.load can execute arbitrary code while loading; only
    call this on files from a trusted source.
    """
    with open(file_path, 'rb') as handle:
        payload = pickle.load(handle)
    entry_keys = list(payload.keys())
    # The file is expected to wrap a single object under one key.
    assert len(entry_keys) == 1
    only_key = entry_keys[0]
    return payload[only_key]


class PickleDataset(Dataset):
    """Map-style dataset pairing per-image embedding pickles with captions.

    Every file found (recursively) under ``images_path`` is one sample. The file
    stem is the image id, which is converted with ``int()`` to look up the
    caption parsed from ``captions_path`` under ``captions_key``.

    NOTE(review): ``tokenizer`` and ``max_embd_len`` are accepted but not used by
    this class, and ``ds`` / ``image_file_names`` are initialised to None and
    never set.
    """

    def __init__(self,
                 images_path,
                 captions_path,
                 captions_key,
                 tokenizer,
                 max_embd_len=2048):
        super().__init__()
        self.tokenizer = tokenizer
        self.ds = None
        self.image_file_names = None
        self.captions_key = captions_key
        self.images_path = images_path
        embedding_files, stems = get_absolute_paths(images_path)
        self.all_images = embedding_files
        self.image_ids = stems
        self.captions = parse_captions_file(captions_path, captions_key)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        raw_embedding = load_pickle_file(self.all_images[idx])
        # Insert a singleton axis at position 1 (a (D,) array becomes (D, 1)).
        # Presumably each pickle stores a 1-D embedding -- TODO confirm.
        embedding_tensor = torch.tensor(np.expand_dims(raw_embedding, 1))
        image_id = self.image_ids[idx]
        return {
            "image_embeddings": embedding_tensor,
            "image_id": image_id,
            # Captions are keyed by integer ids, file stems are strings.
            "caption": self.captions[int(image_id)],
        }
    


**Download the LLM and tokenizer**

In [None]:
# Fetch Phi-2's tokenizer and causal-LM weights from the Hugging Face Hub.
PHI_CHECKPOINT = "microsoft/phi-2"
phi_tokenizer = AutoTokenizer.from_pretrained(PHI_CHECKPOINT)
phi_model = AutoModelForCausalLM.from_pretrained(PHI_CHECKPOINT)

**Define paths and hyperparameters**

In [None]:
# Kaggle input mounts: precomputed CLIP image embeddings and COCO 2017 train captions.
train_dataset_path = '/kaggle/input/coco2017-clip-image-embeddings/coco_embeddings_clip_vision_1x768'
captions_path = '/kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json'
# Top-level JSON key that holds the list of caption annotations.
captions_key = 'annotations'
batch_size = 1
# Use the GPU when torch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


**Define train dataset and train dataloader**

In [None]:
# Index every embedding file under the training path and parse the caption annotations.
train_ds = PickleDataset(
    images_path=train_dataset_path,
    captions_path=captions_path,
    captions_key=captions_key,
    tokenizer=phi_tokenizer,
)


In [None]:
# Shuffled training loader; the default collate stacks tensors and gathers the
# string ids/captions into lists.
# TODO: add a validation dataset/loader once a validation split is prepared.
train_dataloader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    num_workers=1,
    shuffle=True,
)

In [None]:
# Pull a single batch from the loader and move its image embeddings to `device`
# for a quick end-to-end check of the model.
cc = next(iter(train_dataloader))
input_embeds = cc['image_embeddings'].to(device)

In [None]:
# Inspect the batch shape; its last dimension must match the projection's input width.
input_embeds.shape

In [None]:
# Projection input width is 1: presumably each embedding value is treated as one
# "token" (TODO confirm against input_embeds.shape). The output width is read from
# the LLM config (2560 for Phi-2) instead of being hard-coded.
multimodal_gpt_model = MultiModalGPT(
    phi_model,
    phi_tokenizer,
    projection_layer_in_channels=1,
    projection_layer_out_channels=phi_model.config.hidden_size,
    device=device,
    hidden_size=32,
).to(device)

In [None]:
# Release cached, unused GPU blocks, then report allocated and reserved CUDA memory (bytes).
torch.cuda.empty_cache()
for memory_stat in (torch.cuda.memory_allocated, torch.cuda.memory_reserved):
    print(memory_stat())

In [None]:
# Display the embedding batch that will be fed to the model.
input_embeds

In [None]:
# Run projection + generation on the sample batch (token ids for an HF causal LM).
# NOTE(review): forward() defaults to max_length=1. Confirm how the installed
# transformers version counts inputs_embeds positions toward max_length; a
# new-token limit (max_new_tokens) may be what is intended.
output = multimodal_gpt_model(input_embeds)

In [None]:
# Show GPU utilisation and per-process memory after the forward pass.
!nvidia-smi