In [49]:
import torch
import torch.nn as nn
from transformers import LlamaModel, LlamaTokenizer

class VisionLanguageModel(nn.Module):
    def __init__(self, qformer, llm_tokenizer, llm_model):
        super(VisionLanguageModel, self).__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ViT model
        self.qformer = qformer
        for param in self.qformer.parameters():
            param.requires_grad = False

        # Llama model
        self.llm_model = llm_model
        for param in self.llm_model.parameters():
            param.requires_grad = False

        # Linear layer
        embedding_size = self.llm_model.config.hidden_size
        # self.linear = nn.Linear(256, embedding_size)
        self.linear = ProcessingLayer(256, embedding_size)

        # Tokenizer
        self.llm_tokenizer = llm_tokenizer

        # Move entire model to device, instead of moving parts individually
        self.to(self.device)

        # Embeddings for <image> and </image>
        self.image_start_token_id = self.llm_tokenizer.encode('<image>', add_special_tokens=False)[0]
        self.image_end_token_id = self.llm_tokenizer.encode('</image>', add_special_tokens=False)[0]

        # Create a prompt for GPT model
        prompt = "Describe this image: "
        # prompt = "this is a picture of "
        self.prompt_tokens = self.llm_tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        

    def forward(self, image, caption):
        batch_size = image.shape[0]

        # Embeddings for <image> and </image> and prompt
        image_start_embedding = self.llm_model.base_model.embed_tokens(torch.tensor([self.image_start_token_id], device=self.device)).repeat(batch_size, 1, 1)
        image_end_embedding = self.llm_model.base_model.embed_tokens(torch.tensor([self.image_end_token_id], device=self.device)).repeat(batch_size, 1, 1)
        prompt_embeddings = self.llm_model.base_model.embed_tokens(self.prompt_tokens).repeat(batch_size, 1, 1)

        # Ensure image is on the correct device
        image = image.to(self.device)

        # Get image features
        features_image = self.qformer.extract_features({'image':image}, mode="image")

        # Pass through linear layer
        linear_output = self.linear(features_image.image_embeds_proj)
        # print('linear_output.shape: ', linear_output.shape)

        # Concatenate the prompt, image start embedding, image features, and image end embedding
        cond_embedding = torch.cat([prompt_embeddings, image_start_embedding, linear_output, image_end_embedding], dim=1)

        # Prepare target token ids for LLM model
        self.llm_tokenizer.padding_side = "right"
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        regression_target = self.llm_tokenizer(caption, return_tensors="pt", padding='longest', truncation=True, max_length=32, add_special_tokens=False)
        regression_target_ids = regression_target.input_ids
        regression_attention_mask = regression_target.attention_mask
        part_targets = regression_target_ids.masked_fill(regression_target_ids == self.llm_tokenizer.eos_token_id, -100)

        regression_embs = self.llm_model.base_model.embed_tokens(regression_target_ids.to(self.device))

        # concat the embedding to condition and the embedding to regress
        cat_embs = torch.cat([cond_embedding, regression_embs], dim=1)
        cat_att = torch.cat([torch.ones((batch_size, cond_embedding.shape[1]), dtype=torch.long), regression_attention_mask], dim=1).to(self.device)

        # get bos token embedding
        bos_token_id = self.llm_tokenizer.bos_token_id
        bos_token_emb = self.llm_model.base_model.embed_tokens(torch.tensor([bos_token_id], device=self.device)).repeat(batch_size, 1, 1)
        bos_att = torch.ones((batch_size, 1), dtype=torch.long, device=self.device)

        # concat bos token embedding to condition and attention mask
        cat_embs = torch.cat([bos_token_emb, cat_embs], dim=1)
        cat_att = torch.cat([bos_att, cat_att], dim=1)
        # print('cat_embs.shape: ', cat_embs.shape)

        # final target token ids
        targets = torch.ones((batch_size, cat_embs.shape[1]), dtype=torch.long).fill_(-100)
        targets[:, 42:] = part_targets
        targets[:, 0] = bos_token_id

        self.llm_model.eval()
        outputs = self.llm_model(inputs_embeds=cat_embs, attention_mask=cat_att, labels=targets)

        return outputs

    def generate(self, image):
        batch_size = image.shape[0]
        prompt_embeddings = self.llm_model.base_model.embed_tokens(self.prompt_tokens).repeat(batch_size, 1, 1)
        image_start_embedding = self.llm_model.base_model.embed_tokens(torch.tensor([self.image_start_token_id], device=self.device)).repeat(batch_size, 1, 1)
        image_end_embedding = self.llm_model.base_model.embed_tokens(torch.tensor([self.image_end_token_id], device=self.device)).repeat(batch_size, 1, 1)

        image = image.to(self.device)
        features_image = self.qformer.extract_features({'image':image}, mode="image")
        with torch.no_grad():
            linear_output = self.linear(features_image.image_embeds_proj)
        image_embedding = torch.cat([prompt_embeddings, image_start_embedding, linear_output, image_end_embedding], dim=1)
        attention_mask = torch.ones((batch_size, image_embedding.shape[1]), dtype=torch.long).to(self.device)
        self.llm_model.eval()
        outputs = self.llm_model.generate(inputs_embeds=image_embedding, attention_mask = attention_mask, max_length=50, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)
        outputs = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return outputs

class ProcessingLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ProcessingLayer, self).__init__()
        self.linear1 = nn.Linear(input_dim, 512)
        self.activation1 = nn.ReLU()
        self.linear2 = nn.Linear(512, 1024)
        self.activation2 = nn.ReLU()
        self.linear3 = nn.Linear(1024, output_dim)
        self.layernorm = nn.LayerNorm(output_dim)
        self.activation3 = nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation1(x)
        x = self.linear2(x)
        x = self.activation2(x)
        x = self.linear3(x)
        x = self.layernorm(x)
        x = self.activation3(x)
        return x

In [2]:
from transformers import LlamaForCausalLM, LlamaTokenizer
from lavis.models import load_model_and_preprocess
import torch

checkpoint_path = "meta-llama/Llama-2-7b-hf"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

llm_model = LlamaForCausalLM.from_pretrained(checkpoint_path)
llm_tokenizer = LlamaTokenizer.from_pretrained(checkpoint_path)

qformer, vis_processors, txt_processors = load_model_and_preprocess(name="blip2_feature_extractor", model_type="pretrain", is_eval=True, device=device)




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Add new tokens
new_tokens = ['<image>', '</image>']
num_added_toks = llm_tokenizer.add_tokens(new_tokens)

print("We have added", num_added_toks, "tokens")

# Resize position embeddings matrix
llm_model.resize_token_embeddings(len(llm_tokenizer))

# Check if new tokens are recognized
print(llm_tokenizer.encode('<image>'))
print(llm_tokenizer.encode('</image>'))

We have added 2 tokens
[1, 29871, 32000]
[1, 29871, 32001]


In [50]:

model = VisionLanguageModel(qformer, llm_tokenizer, llm_model)
model.linear.load_state_dict(torch.load('linear_layer_state_dict.pth'))

<All keys matched successfully>

In [5]:
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import ViTImageProcessor, GPT2Tokenizer

class CustomDataset(Dataset):
    def __init__(self, json_file, image_dir, transform=None):
        with open(json_file, 'r') as f:
            self.data = json.load(f)["annotations"]
        
        self.image_dir = image_dir
        self.vis_processors = vis_processors

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Get image
        image_path = f"{self.image_dir}/{self.data[idx]['image_id']}.jpg"
        images = Image.open(image_path).convert("RGB").resize((224, 224))  # Read image and resize to 224x224
        image_tensor = self.vis_processors["eval"](images)
        # Get caption
        caption = self.data[idx]['caption']
        return {'image': image_tensor, 'caption': caption}

In [6]:
# Instantiate dataset
dataset = CustomDataset(json_file='cc_sbu_align/filter_cap.json', image_dir='cc_sbu_align/image')

# DataLoader for training, with batching and shuffling
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [46]:
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from accelerate import Accelerator

learning_rate = 1e-3
# optimizer with weight decay
optimizer = AdamW(model.linear.parameters(), lr=learning_rate, weight_decay=0.01)
loss_function = CrossEntropyLoss(ignore_index=-100)

# Training settings
num_epochs = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Train mode
# model.train()

accelerator = Accelerator(mixed_precision='fp16')

model, optimizer, dataloader = accelerator.prepare(
    model, optimizer, dataloader
    )


In [52]:
from PIL import Image

raw_image1 = Image.open("dog.jpg").convert("RGB")
raw_image2 = Image.open("merlion.png").convert("RGB")
image1 = vis_processors["eval"](raw_image1).unsqueeze(0).to(device)
image2 = vis_processors["eval"](raw_image2).unsqueeze(0).to(device)
input_test = torch.cat([image1, image2], dim=0)
answer = model.generate(input_test)
for i in range(len(answer)):
    print('Generated output:', answer[i])

Generated output: This image shows a dachshund wiper wiping the dust on a car during a race. The wiper is orange and black and has a curved blade that appears to be made of metal. The car is
Generated output: The image shows a statue of a Greek goddess, possibly Athena, made of gold and standing on a hill overlooking a city at night. The city is illuminated by bright lights and there are people walking around in the foreground


In [44]:

# Training loop
for epoch in range(num_epochs):
    total_loss = 0 # For monitoring training loss
    for i, batch in enumerate(dataloader):  # Assuming dataloader is already defined
        optimizer.zero_grad()

        images = batch['image'].to(device)
        # print(images)
        captions = batch['caption']

        # Forward pass
        output = model(images, captions)
        loss = output.loss
        
        # Backpropagate, optimizer step
        # loss.backward()
        accelerator.backward(loss)
        optimizer.step()

        total_loss += loss.item()

        if i % 10 == 0:  # Print loss every 10 batches for monitoring
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], Loss: {loss.item():.4f}")
            # print a sample output
            output_generate = model.generate(torch.cat([image1, image2], dim=0))
            for i in range(len(output_generate)):
                output_generate[i] = output_generate[i].replace('\n', ' ')
                print('Output_generate: ', output_generate[i])
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss}")
    # Save the state dictionary of the linear layer
    torch.save(model.linear.state_dict(), 'linear_layer_state_dict.pth')


Epoch [1/20], Step [1/108], Loss: 0.1691
Output_generate:  This image is a yellow and black Dachsh lawnmower with a yellow and black body and orange wheels. The lawnmower has a bright orange blade that matches the body color. The mower is running on a green lawn, indicating that it is being used to maintain the grass. A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A: A:
Output_generate:  The image shows a statue of a lion in the middle of a fountain in front of a large building in a city. The fountain appears to be made of gold or silver and the statue of the lion is also made of gold or silver. The building in the background appears to be made of glass or reflective material and has a large dome on top. The image appears to be taken at night, as the lights from the city can be seen in the background. The lion statue appears to be watching over the city and the fountain below. What is the main focus of the


KeyboardInterrupt: 