In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import importlib
from torch.utils.data import DataLoader
import os

from config import Blip2Config
from dataset import Blip2Dataset
from tokenizer import FlanT5Tokenizer, BertTokenizer

In [2]:
import config, dataset, tokenizer
importlib.reload(config)
importlib.reload(dataset)
importlib.reload(tokenizer)

from config import Blip2Config
from dataset import Blip2Dataset
from tokenizer import FlanT5Tokenizer, BertTokenizer

In [3]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
t5_model_name = "google/flan-t5-small"
bert_autotokenizer =  AutoTokenizer.from_pretrained("bert-base-uncased") 
t5_autotokenizer = AutoTokenizer.from_pretrained(t5_model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
config = Blip2Config()

In [5]:
flan_t5_tokenizer = FlanT5Tokenizer(config, t5_autotokenizer)
bert_tokenizer = BertTokenizer(config, bert_autotokenizer)
config.bert_vocab_size = bert_tokenizer.n_vocab
config.t5_vocab_size = flan_t5_tokenizer.n_vocab

stage1_train_dataset = Blip2Dataset(config, split="train", tokenizer=bert_tokenizer.tokenize_text, type="bert")
stage1_train_dataloader = DataLoader(stage1_train_dataset, batch_size=config.batch_size, shuffle=True)

stage2_train_dataset = Blip2Dataset(config, split="train", tokenizer=flan_t5_tokenizer.tokenize_text, type="flan_t5")
stage2_train_dataloader = DataLoader(stage2_train_dataset, batch_size=config.batch_size, shuffle=True)

Skipping non-image file: image
Skipping non-image file: image


In [9]:
import copy
import timm
from transformers import BertConfig, BertModel

class BertMLPBlock(nn.Module):
    def __init__(self, intermediate, output):
        super().__init__()
        self.intermediate = intermediate
        self.output = output

    def forward(self, x):
        intermediate_output = self.intermediate(x)
        return self.output(intermediate_output, x)
    

class BertEncoderBlock(nn.Module):
    def __init__(self, bert_layer, bert_config, is_cross_attn=False):
        super().__init__()
        self.bert_config = bert_config
        self.is_cross_attn = is_cross_attn
        self.self_attn = bert_layer.attention
        self.mlp_img_transformer = BertMLPBlock(bert_layer.intermediate, bert_layer.output)
        self.mlp_text_transformer = BertMLPBlock(
                    copy.deepcopy(bert_layer.intermediate), 
                    copy.deepcopy(bert_layer.output)
                    )
        if is_cross_attn:
            self.cross_attn = nn.MultiheadAttention(embed_dim=self.bert_config.hidden_size, 
                                                    num_heads=self.bert_config.num_attention_heads, 
                                                    batch_first=True)
            self.cross_layer_norm = nn.LayerNorm(self.bert_config.hidden_size)
        
    def forward(self, query_embds, img_embds, text_embds, attn_mask):
        _, Qs, _ = query_embds.shape
        _, Ts, _ = text_embds.shape

        combined_embds = torch.concat((query_embds, text_embds), dim=1) # B, Qs + Ts, D

        self_attn_output = self.self_attn(combined_embds, attention_mask=attn_mask)[0]
        query_embds = combined_embds[:, :Qs]
        text_embds= combined_embds[:, Qs:]
        
        if self.is_cross_attn:
            hidden_states = self.cross_attn(query_embds, img_embds, img_embds)[0]
            query_embds = self.cross_layer_norm(query_embds + hidden_states)

        query_embds = self.mlp_img_transformer(query_embds)
        text_embds = self.mlp_text_transformer(text_embds)
        return query_embds, text_embds


class QTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert_cfg  = BertConfig.from_pretrained("bert-base-uncased")
        self.bert_model = BertModel.from_pretrained("bert-base-uncased", config = self.bert_cfg)
        
        self.encoder = nn.ModuleList()
        for i, bert_layer in enumerate(self.bert_model.encoder.layer):
            self.encoder.append(BertEncoderBlock(bert_layer, self.bert_cfg, i % 2 == 0))
        
        qs = config.num_queries
        ts = config.context_length
        combined_seq_len = qs + ts

        ####  STAGE 1: ITC, ITM, ITG Loss Masks ####
        # ITC Loss Mask
        itc_attn_mask = torch.zeros((combined_seq_len, combined_seq_len))
        itc_attn_mask[:qs, :qs] = 1
        itc_attn_mask[qs:, qs:] = 1

        # ITM Loss Mask
        itm_attn_mask = torch.ones((combined_seq_len, combined_seq_len))

        # ITG Loss Mask
        itg_attn_mask = torch.ones((combined_seq_len, combined_seq_len))
        itg_attn_mask[:qs, qs:] = 0
        itg_attn_mask[qs:, qs:] = torch.tril(itg_attn_mask[qs:, qs:], diagonal=0)

        self.register_buffer("itc_attn_mask", itc_attn_mask)
        self.register_buffer("itm_attn_mask", itm_attn_mask)
        self.register_buffer("itg_attn_mask", itg_attn_mask)

        ####  STAGE 2: ####
        # ITC Loss Mask will be same as stage 1 and reused for stage 2

    def forward(self, query_embds, img_embds, cls_text_embds, dec_text_embds, stage):

        itc_query_embds = query_embds.clone()
        itm_query_embds = query_embds.clone()
        itg_query_embds = query_embds.clone()

        itc_text_embds = cls_text_embds.clone()
        itm_text_embds = cls_text_embds.clone()
        itg_text_embds = dec_text_embds.clone()


        for encoder in self.encoder:
            itc_query_embds, itc_text_embds = encoder(itc_query_embds, img_embds, itc_text_embds, self.itc_attn_mask)
            if stage == 1:
                itm_query_embds, itm_text_embds = encoder(itm_query_embds, img_embds, itm_text_embds, self.itm_attn_mask)
                itg_query_embds, itg_text_embds = encoder(itg_query_embds, img_embds, itg_text_embds, self.itg_attn_mask)
        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_query_embds, itg_text_embds
    

class QFormer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.q_transformer = QTransformer(config)
        self.learned_query = nn.Parameter(torch.randn(config.num_queries, config.embedding_dim))
        self.output_embedding  = nn.Embedding(config.bert_vocab_size, config.embedding_dim)
        self.position_embedding = nn.Embedding(config.context_length, config.embedding_dim)

        position_ids = torch.arange(self.config.context_length).unsqueeze(0)
        self.register_buffer("position_ids", position_ids)

    def forward(self, image_embedding: torch.tensor, cls_tokens: torch.tensor, dec_tokens: torch.tensor, stage:int):
        B, S, E = image_embedding.shape
        learned_query = self.learned_query.unsqueeze(0).expand(B, -1, -1)

        cls_text_embeddings = self.output_embedding(cls_tokens) #(S,768)
        cls_text_embeddings = cls_text_embeddings + self.position_embedding(self.position_ids.expand(B, -1))
        dec_text_embeddings = self.output_embedding(dec_tokens) #(S,768)
        dec_text_embeddings = dec_text_embeddings + self.position_embedding(self.position_ids.expand(B, -1))

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_query_embds, itg_text_embds = self.q_transformer(
            learned_query, image_embedding, cls_text_embeddings, dec_text_embeddings, stage)

        if itg_text_embds is not None:
            itg_logits = itg_text_embds @ self.output_embedding.weight.T # (S,Vocab_size)
        else:
            itg_logits = None

        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits


class FlanT5Model(nn.Module):
    def __init__(self):
        super(FlanT5Model, self).__init__()
        self.lm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
        for param in self.lm_model.parameters():
            param.requires_grad = False

    def forward(self, query_embedding, input_token, label, enc_mask):
        #query_embd : (B,32,512)
        # input_token : (B,L)
        B, Q, d = query_embedding.shape
        device = query_embedding.device
        with torch.no_grad():
            input_embd = self.lm_model.encoder.embed_tokens(input_token)  #(B,L,512)

        encoder_input = torch.concat((query_embedding, input_embd) , dim = 1).contiguous()

        prefix_mask = torch.ones((B, Q ), dtype= enc_mask.dtype, device=device)
        attention_mask = torch.concat((prefix_mask, enc_mask) , dim=1).contiguous()  # [B, 32+L]
        label = label.contiguous()  # [B, L]
        out = self.lm_model(inputs_embeds=encoder_input,
                                attention_mask=attention_mask,
                                labels=label,
                                return_dict=True)
        return out
    

    def predict(self, query_embedding, input_token, enc_mask):
        B, Q, d = query_embedding.shape
        device = query_embedding.device
        with torch.no_grad():
            input_embd = self.lm_model.encoder.embed_tokens(input_token)  #(B,L,512)

        encoder_input = torch.concat((query_embedding, input_embd) , dim = 1)

        prefix_mask = torch.ones((B, Q ), dtype= enc_mask.dtype, device=device)
        attention_mask = torch.concat((prefix_mask, enc_mask) , dim=1)  # [B, 32+L]
        
        enc_out = self.lm_model.encoder(
            inputs_embeds=encoder_input,
            attention_mask=attention_mask,
            return_dict=True
            )

        gen_ids = self.lm_model.generate(
            encoder_outputs=enc_out,
            max_new_tokens=30,
            decoder_start_token_id=self.lm_model.config.decoder_start_token_id,
            attention_mask=attention_mask,
        )

        return gen_ids


class Blip2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.image_encoder = timm.create_model('vit_tiny_patch16_224', pretrained=True)
        self.image_encoder.reset_classifier(0)

        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.image_proj = nn.Linear(config.img_embd_dim, config.embedding_dim)

        self.q_former = QFormer(config)
        self.z_proj = nn.Linear(config.embedding_dim, config.lm_embedding_dim)

        self.lm_model = FlanT5Model()
    
    
    def stage1(self, image:torch.tensor, cls_caption:torch.tensor, dec_caption:torch.tensor):
        image_embedding = self.image_encoder.forward_features(image)  # [B, C, F]
        image_embedding = self.image_proj(image_embedding)

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = self.q_former(image_embedding, cls_caption, dec_caption, 1)
        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits
    
    
    def stage2(self, image, input_token, label, enc_mask, dummy_input_size):
        image_embedding = self.image_encoder.forward_features(image)  # [B, C, F]
        image_embedding = self.image_proj(image_embedding)
        
        cls_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)
        dec_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)
        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = self.q_former(image_embedding, 
                                                            cls_caption_dummy, dec_caption_dummy, 2)
        
        z = self.z_proj(itc_query_embds)  # [B, Qs, D]

        out = self.lm_model(z, input_token, label, enc_mask)
            
        return out
    
    def predict(self, image, input_token, enc_mask, dummy_input_size):
        image_embedding = self.image_encoder.forward_features(image)  # [B, C, F]
        image_embedding = self.image_proj(image_embedding)

        cls_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)
        dec_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = self.q_former(image_embedding, 
                                                            cls_caption_dummy, dec_caption_dummy, 2)
        
        z = self.z_proj(itc_query_embds)

        gen_ids = self.lm_model.predict(z, input_token, enc_mask)
        return gen_ids

In [10]:
model = Blip2Model(config)
print(model)

Blip2Model(
  (image_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=192, out_features=192, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=192, out_features=768, bias=True)
        

In [8]:
class ITCLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query_embds, text_embds):
        # query_embds: B, 32, d
        # text_embds: B, 77, d
        text_logit = text_embds[:, :1] # B, 1, d
        B, _, _ = text_logit.shape
        B, Qs, d = query_embds.shape 
        query_embds = query_embds.reshape(B * Qs, d)
        text_embds = text_logit.squeeze()
        logits = query_embds @ text_embds.T   # B*Qs,B
        logits = torch.max(logits.reshape(B,Qs,B),dim=1)[0] # B,B
        label = torch.arange(B,device=query_embds.device)
        return (F.cross_entropy(logits,label)+ F.cross_entropy(logits.T,label)) / 2, logits


class ITMLoss(nn.Module):
    def __init__(self, config):
        super().__init__()
        d = config.embedding_dim
        self.classification_layer = nn.Linear(d,2)
    
    def forward(self, query_embd, label):
        # query_embd --> (B,32,768)
        #label ->(B,1) B x [0/1]

        match_logit = self.classification_layer(query_embd) #(B,32,2)
        match_logit = match_logit.mean(dim=1)
        return F.cross_entropy(match_logit,label)


class ITGLoss(nn.Module):
    def __init__(self, pad_token_id):
        super().__init__()
        self.pad_token_id = pad_token_id

    def forward(self, itg_logits, label_token):
        #itg_logits -> B,S,vocab size
        #label_token -> B,S
        B, S, V = itg_logits.shape
        loss = F.cross_entropy(
            itg_logits.view(B * S, V),
            label_token.view(B * S),
            ignore_index=self.pad_token_id
        )
        return loss


In [None]:
##########################   STAGE 1    #####################################

device = torch.device('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
num_epochs = 30

itc_loss_func = ITCLoss()
itm_loss_func = ITMLoss(config)
itg_loss_func = ITGLoss(bert_tokenizer.pad_token_id)

itc_loss_func = itc_loss_func.to(device)
itm_loss_func = itm_loss_func.to(device)
itg_loss_func = itg_loss_func.to(device)

model = model.to(device)
model.train()

for epoch in range(num_epochs):
    iteration = 0
    print(f"***************   Epoch {epoch + 1}  ***************")
    for img, cls_caption, dec_caption in stage1_train_dataloader:
        img = img.to(device)
        cls_caption = cls_caption.to(device)
        dec_caption = dec_caption.to(device)
        B, _, _, _ = img.shape

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = model.stage1(img, cls_caption, dec_caption)

        # ITC Loss
        itc_loss, itc_logits = itc_loss_func(itc_query_embds, itc_text_embds)

        # ITM Loss
        idx = torch.arange(B,device = device)
        itc_logits[idx,idx] = -1e9
        next_best_text_value , next_best_text_idx = torch.max(itc_logits,dim=1)
        mismatched_cls_caption = cls_caption[next_best_text_idx]
        mismatched_dec_caption = dec_caption[next_best_text_idx]

        _,_,mismatched_itm_query_embeds,_,_ = model.stage1(img, mismatched_cls_caption, mismatched_dec_caption)

        itm_query_embed_concatenated = torch.concat((itm_query_embds, mismatched_itm_query_embeds) ,dim=0 )
        itm_labels = torch.zeros(2 * B, dtype=torch.long).to(device)
        itm_labels[B:] = 1
        itm_loss = itm_loss_func(itm_query_embed_concatenated, itm_labels)

        # ITG Loss
        itg_labels = torch.concat((dec_caption[:, 1:], dec_caption[:, -1].unsqueeze(1)), dim=1)
        itg_loss = itg_loss_func(itg_logits, itg_labels)


        total_loss = itc_loss + itm_loss + itg_loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1} : Iter [{iteration} / {len(stage1_train_dataloader)}]")
        print(f"Total Loss: {total_loss.item()}")
        print(f"ITC Loss: {itc_loss}, ITM Loss: {itm_loss}, ITG Loss: {itg_loss}")
        print("" + "*" * 50)
        iteration += 1

    torch.save(model.state_dict(), "q_former.pt")


***************   Epoch 1  ***************
Epoch 1 : Iter [0 / 506]
Total Loss: 102.34361267089844
ITC Loss: 40.18841552734375, ITM Loss: 0.7101684808731079, ITG Loss: 61.44502639770508
**************************************************
Epoch 1 : Iter [1 / 506]
Total Loss: 58.904991149902344
ITC Loss: 8.537028312683105, ITM Loss: 0.702996015548706, ITG Loss: 49.66496658325195
**************************************************
Epoch 1 : Iter [2 / 506]
Total Loss: 48.40378952026367
ITC Loss: 5.037400245666504, ITM Loss: 0.6962429285049438, ITG Loss: 42.67014694213867
**************************************************
Epoch 1 : Iter [3 / 506]
Total Loss: 44.332603454589844
ITC Loss: 6.2193403244018555, ITM Loss: 0.6926547288894653, ITG Loss: 37.42060852050781
**************************************************
Epoch 1 : Iter [4 / 506]
Total Loss: 41.41795349121094
ITC Loss: 4.824690818786621, ITM Loss: 0.6971796154975891, ITG Loss: 35.89608383178711
****************************************

In [None]:
##########################   STAGE 2    #####################################

MODEL_PATH = "q_former.pt"
if os.path.exists(MODEL_PATH):
    model  = torch.load(MODEL_PATH)

device = torch.device('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
num_epochs = 30

model = model.to(device)
model.train()

for epoch in range(num_epochs):
    iteration = 0
    print(f"***************   Epoch {epoch + 1}  ***************")
    for img, input_caption, input_mask in stage2_train_dataloader:
        img = img.to(device)
        input_caption = input_caption.to(device).squeeze()
        input_mask = input_mask.to(device).squeeze()
        B, S = input_caption.shape

        out = model.stage2(img, input_caption[:,:2], input_caption[:,2:], input_mask[:,:2], (B, S))
        total_loss = out.loss
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1} : Iter [{iteration} / {len(stage2_train_dataloader)}]")
        print(f"Total Loss: {total_loss.item()}")
        print("" + "*" * 50)
        iteration += 1

    torch.save(model.state_dict(), "blip2.pt")


In [18]:
from PIL import Image

device = torch.device('cuda')
model.load_state_dict(torch.load("blip2.pt"))
model = model.to(device)
model.eval()

image_path = "example.jpg"
image = Image.open(image_path).convert('RGB')
image = stage2_train_dataset.transform(image)
image = image.unsqueeze(0)  # Add batch dimension

input_caption = "What is in the image? Answer:"
input_tokens, input_mask = flan_t5_tokenizer.tokenize_text(input_caption)

image = image.to(device)
input_tokens = input_tokens.to(device)
input_mask = input_mask.to(device)
print(input_tokens.shape, input_mask.shape)

gen_ids = model.predict(image, input_tokens, input_mask, (input_tokens.shape[0], input_tokens.shape[1]))

decoded_output = flan_t5_tokenizer.decode(gen_ids[0])
print(decoded_output)



torch.Size([1, 77]) torch.Size([1, 77])
serrurier serrurier serrurier serrurier serrurier serrurierhütte drill serrurier Clickfunnel serrurier footsteps clock serrurier Clickfunnel serrurier footsteps clock serrurier footsteps clock serrurier footsteps clock serrurier footsteps clock serrurier footsteps clock
