## Importing Libraries

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel, BitsAndBytesConfig
from PIL import Image
import json
import os
import requests
from io import BytesIO
import bitsandbytes as bnb
import types
from tqdm import tqdm


2024-10-25 15:46:47.163939: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-25 15:46:47.164005: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-25 15:46:47.165250: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-25 15:46:47.171488: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(torch.cuda.is_available())

url = "https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_instruct_150k.json"
response = requests.get(url)

with open("llava_instruct_150k.json", "wb") as f:
    f.write(response.content)
    
    
# Opening JSON file - instruct150k
f = open('llava_instruct_150k.json')

# returns JSON object as a dictionary
data = json.load(f)

# create input pickle file by flattening the data
data_instruct150_flatten = []


r = 0

for a_idx,d in enumerate(data):
    image = d['image']
    image_url = 'http://images.cocodataset.org/train2017/' + image

    conv_iter = iter( d['conversations'])
    # Loops over the dataset data where each item d is a dictionary containing an image reference and conversations related to that image. a_idx is the index of the current entry in the dataset.
    for i in conv_iter:
      # Fetches the next item in the iterator, which should be the corresponding GPT response to the human question. The assumption is that the structure alternates between human and GPT responses.
      gpt_ans = next(conv_iter)
      # Checks the length of the GPT response. If the response exceeds 200 characters, it's skipped to avoid overly long answers (likely for cleaner or more concise data).
      if len(gpt_ans['value']) > 200: # filter out too long answers
          continue
      # Ensures that the current item i is from a human, and the next item gpt_ans is from GPT. This ensures that only valid human-question and GPT-response pairs are processed.
      if i['from'] == 'human' and gpt_ans['from'] == 'gpt':
        # The human question, with any occurrences of '<image>\n' or '\n<image>' removed (to clean up the question text).
        data_instruct150_flatten.append((image_url, i['value'].replace('<image>\n','').replace('\n<image>',''),gpt_ans['value']))

    # Every 10,000 entries, it prints the progress.
    if a_idx % 10000 == 0:
      print(f"{10000 * r} processed")
      r += 1
        
# Setting folder where we can store the images for the dataset
image_dir = "cached_images"
os.makedirs(image_dir, exist_ok=True)


## Written code to predownload the images so that during it won't download the images on fly. By predownloading the image we speed up the training a little bit
def download_images(data, image_dir):
    for item in tqdm(data, desc="Downloading images"):
        image_url = item[0]
        
        # Define the filename based on URL hash or unique ID
        image_filename = os.path.join(image_dir, image_url.split('/')[-1])
        
        # Skip download if the image is already cached
        if os.path.exists(image_filename):
            continue

        # Download and save the image
        try:
            response = requests.get(image_url, timeout=10)
            if response.status_code == 200:
                with open(image_filename, 'wb') as f:
                    f.write(response.content)
            else:
                print(f"Failed to download {image_url} (status code: {response.status_code})")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {image_url}: {e}")

# Run the download process
download_images(data_instruct150_flatten, image_dir)
        
        
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def get_clip_embedding(image_path):
    image = Image.open(image_path)
    inputs = clip_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = clip_model.get_image_features(**inputs)
    return outputs.squeeze(0)  # Remove batch dimension

class Adapter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.layer(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
    
    
class ImageProjection(nn.Module):
    def __init__(self, clip_embedding_dim, model_dim):
        super().__init__()
        self.linear = nn.Linear(clip_embedding_dim, model_dim)

    def forward(self, x):
        return self.linear(x)
    
    
# Set up quantization configuration
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)


class MultiModalDataset(Dataset):
    def __init__(self, data, tokenizer, image_dir="cached_images", max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.image_dir = image_dir
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_url = item[0]  
        question = item[1]    
        answer = item[2]   

        # Combine question and answer for context
        text = f"Question: {question} Answer: {answer}"

        # Tokenize the combined text
        encoding = self.tokenizer(text,
                                  truncation=True,
                                  max_length=self.max_length,
                                  padding='max_length',
                                  return_tensors="pt")

        # Load image from local cache
        image_filename = os.path.join(self.image_dir, image_url.split('/')[-1])
        image = Image.open(image_filename).convert("RGB") 
        
        image_embedding = torch.randn(512)  
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': encoding['input_ids'].squeeze(),  # Assuming you're training with input IDs as labels
            'image_embedding': image_embedding
        }
    
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")

# Set padding token
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Configure quantization
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

# Load the model with 8-bit quantization
phi2_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Resize token embeddings
phi2_model.resize_token_embeddings(len(tokenizer))

# Define Adapter class
class Adapter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.down = bnb.nn.Linear8bitLt(config.hidden_size, config.hidden_size // 8, has_fp16_weights=False, threshold=6.0)
        self.up = bnb.nn.Linear8bitLt(config.hidden_size // 8, config.hidden_size, has_fp16_weights=False, threshold=6.0)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.up(self.act(self.down(x))) + x


# Define ImageProjection class
class ImageProjection(nn.Module):
    def __init__(self, image_embed_dim, hidden_size):
        super().__init__()
        self.linear = bnb.nn.Linear8bitLt(image_embed_dim, hidden_size, has_fp16_weights=False, threshold=6.0)

    def forward(self, x):
        if x.dim() == 3:
            x = x.view(x.size(0), -1)
        return self.linear(x)
    
    
# Enable gradient checkpointing
phi2_model.gradient_checkpointing_enable()

# Function to add adapters to the model
def add_adapters(model):
    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        for layer in model.model.layers:
            # Get the device of any parameter in the layer
            layer_device = next(layer.parameters()).device
            layer.adapter = Adapter(model.config).to(layer_device)

# Add adapters to the model
add_adapters(phi2_model)

# Add image projection layer
phi2_model.image_projection = ImageProjection(512, phi2_model.config.hidden_size).to(device)

# Modify the forward method
def modified_forward(self, input_ids, attention_mask=None, image_embeddings=None):
    outputs = self.model(input_ids, attention_mask=attention_mask, use_cache=False)
    hidden_states = outputs.last_hidden_state

    if image_embeddings is not None:
        batch_size, seq_len, hidden_dim = hidden_states.shape
        projected_embeddings = self.image_projection(image_embeddings)
        projected_embeddings = projected_embeddings.unsqueeze(1).expand(-1, seq_len, -1)
        hidden_states = torch.cat([projected_embeddings, hidden_states], dim=-1)
        if hidden_states.shape[-1] != hidden_dim:
            hidden_states = bnb.nn.Linear8bitLt(hidden_states.shape[-1], hidden_dim, has_fp16_weights=False, threshold=6.0).to(hidden_states.device)(hidden_states)

    for layer in self.model.layers:
        if hasattr(layer, 'adapter'):
            hidden_states = layer.adapter(hidden_states)

    lm_logits = self.lm_head(hidden_states)
    return lm_logits

phi2_model.forward = types.MethodType(modified_forward, phi2_model)

# Function to save model checkpoint
# Function to save model checkpoint
def save_model(model, optimizer, epoch, file_path):
    # Check if file_path is not empty
    if not file_path:
        raise ValueError("File path cannot be empty")
    
    # Ensure directory exists
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    
    # Save the model state
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, file_path)
    print(f"Model checkpoint saved at {file_path}")
    
    
# Training function
def train(model, train_loader, optimizer, epochs, accumulation_steps=4):
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        optimizer.zero_grad()

        try:
            for i, batch in enumerate(train_loader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                image_embeddings = batch['image_embedding'].to(device)

                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    image_embeddings=image_embeddings)

                    loss = nn.functional.cross_entropy(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                    loss = loss / accumulation_steps

                loss.backward()

                if (i + 1) % accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    optimizer.zero_grad()

                total_loss += loss.item() * accumulation_steps
                print('I ', i)

                # Free up memory after each batch
                del input_ids, attention_mask, labels, image_embeddings, outputs, loss
                torch.cuda.empty_cache()

            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss:.4f}")

            # Save model checkpoint after each epoch
            save_model(model, optimizer, epoch + 1, f"model_checkpoints/model_checkpoint_epoch_{epoch + 1}.pt")

        except Exception as e:
            print(f"Error during epoch {epoch + 1}: {e}")
            # Save model if an error occurs (useful in long training processes)
            save_model(model, optimizer, epoch + 1, f"model_checkpoints/error_checkpoint_epoch_{epoch + 1}.pt")
            raise  # Optionally, re-raise the error to stop the process

    # Save the final model after all epochs
    save_model(model, optimizer, epochs, "model_checkpoints/final_model.pt")
    
    
dataset = MultiModalDataset(data_instruct150_flatten, tokenizer)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

# Update the optimizer
optimizer = bnb.optim.Adam8bit(phi2_model.parameters(), lr=1e-5, betas=(0.9, 0.999), eps=1e-8)

# Train the model
train(phi2_model, train_loader, optimizer, epochs=3, accumulation_steps=4)

True
0 processed
10000 processed
20000 processed
30000 processed
40000 processed
50000 processed
60000 processed
70000 processed
80000 processed
90000 processed
100000 processed
110000 processed
120000 processed
130000 processed
140000 processed
150000 processed


Downloading images: 100%|██████████| 199770/199770 [00:00<00:00, 316668.19it/s]


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

I  0
I  1
I  2
I  3
I  4
I  5
I  6
I  7
I  8
I  9
I  10
I  11
I  12
I  13
I  14
I  15
I  16
I  17
I  18
I  19
I  20
I  21
I  22
I  23
I  24
I  25
I  26
I  27
I  28
I  29
I  30
I  31
I  32
I  33
I  34
I  35
I  36
I  37
I  38
I  39
I  40
I  41
I  42
I  43
I  44
I  45
I  46
I  47
I  48
I  49
I  50
I  51
I  52
I  53
I  54
I  55
I  56
I  57
I  58
I  59
I  60
I  61
I  62
I  63
I  64
I  65
I  66
I  67
I  68
I  69
I  70
I  71
I  72
I  73
I  74
I  75
I  76
I  77
I  78
I  79
I  80
I  81
I  82
I  83
I  84
I  85
I  86
I  87
I  88
I  89
I  90
I  91
I  92
I  93
I  94
I  95
I  96
I  97
I  98
I  99
I  100
I  101
I  102
I  103
I  104
I  105
I  106
I  107
I  108
I  109
I  110
I  111
I  112
I  113
I  114
I  115
I  116
I  117
I  118
I  119
I  120
I  121
I  122
I  123
I  124
I  125
I  126
I  127
I  128
I  129
I  130
I  131
I  132
I  133
I  134
I  135
I  136
I  137
I  138
I  139
I  140
I  141
I  142
I  143
I  144
I  145
I  146
I  147
I  148
I  149
I  150
I  151
I  152
I  153
I  154
I  155
I  156
I  157
I  1