In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import importlib
from torch.utils.data import DataLoader

from config import Blip2Config
from dataset import Blip2Dataset
from tokenizer import Blip2Tokenizer

In [None]:
# Re-load the project modules after editing them on disk.
# The modules are bound to aliases on purpose: a plain `import config, tokenizer` would
# overwrite the notebook variables `config` (Blip2Config instance) and `tokenizer`
# (Hugging Face tokenizer) that later cells define and rely on.
import config as config_module
import dataset as dataset_module
import tokenizer as tokenizer_module

importlib.reload(config_module)
importlib.reload(dataset_module)
importlib.reload(tokenizer_module)

# Re-bind the class names so they point at the freshly reloaded definitions.
from config import Blip2Config
from dataset import Blip2Dataset
from tokenizer import Blip2Tokenizer

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hugging Face checkpoint whose tokenizer is used for captions.
# NOTE: `tokenizer` here is the HF tokenizer object and shadows the local `tokenizer` module.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Seed the global torch RNG so the learned-query initialisation and the
# DataLoader shuffle order are reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# NOTE: this rebinds `config` (previously the module name) to a Blip2Config instance.
config = Blip2Config()

In [None]:
# Wrap the HF tokenizer, then record its vocabulary size on the config
# (QFormer sizes its token embedding from config.vocab_size).
blip2_tokenizer = Blip2Tokenizer(config, tokenizer)
config.vocab_size = blip2_tokenizer.n_vocab

train_dataset = Blip2Dataset(
    config,
    split="train",
    tokenizer=blip2_tokenizer.tokenize_text,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
)

In [None]:
import copy
import timm
from transformers import BertConfig, BertModel

class BertMLPBlock(nn.Module):
    """Feed-forward sub-layer lifted out of a BERT layer.

    Args:
        intermediate: module applied to the block input first.
        output: module called as ``output(intermediate_activations, block_input)``.
            Presumably a HF ``BertOutput``, which combines the two with a residual
            connection + LayerNorm — confirm against the transformers version in use.
    """

    def __init__(self, intermediate, output):
        super().__init__()
        self.intermediate = intermediate
        self.output = output

    def forward(self, x):
        # The block input is forwarded to `output` as the residual branch.
        return self.output(self.intermediate(x), x)
    

class BertEncoderBlock(nn.Module):
    """One Q-Former layer assembled from a pretrained BERT layer.

    Queries and text tokens are concatenated and share the BERT self-attention.
    When ``is_cross_attn`` is set, the query half additionally cross-attends to the
    image embeddings (residual + LayerNorm). Finally each modality goes through its
    own feed-forward block: the pretrained BERT FFN for queries, a deep copy of it
    for text.

    Args:
        bert_layer: object exposing ``attention``, ``intermediate`` and ``output``
            (a HF ``BertLayer``).
        bert_config: config providing ``hidden_size`` and ``num_attention_heads``.
        is_cross_attn: whether to add the query→image cross-attention.
    """

    def __init__(self, bert_layer, bert_config, is_cross_attn=False):
        super().__init__()
        self.bert_config = bert_config
        self.is_cross_attn = is_cross_attn
        # Self-attention is shared by queries and text.
        self.self_attn = bert_layer.attention
        # Queries reuse the pretrained FFN; text gets an independent copy so the two can diverge.
        self.mlp_img_transformer = BertMLPBlock(bert_layer.intermediate, bert_layer.output)
        self.mlp_text_transformer = BertMLPBlock(
            copy.deepcopy(bert_layer.intermediate),
            copy.deepcopy(bert_layer.output),
        )
        if is_cross_attn:
            self.cross_attn = nn.MultiheadAttention(
                embed_dim=self.bert_config.hidden_size,
                num_heads=self.bert_config.num_attention_heads,
                batch_first=True,
            )
            self.cross_layer_norm = nn.LayerNorm(self.bert_config.hidden_size)

    @staticmethod
    def _to_additive_mask(attn_mask, dtype):
        """Convert a keep-mask (1/True = may attend, 0/False = blocked) to an additive mask.

        HF BERT attention adds ``attention_mask`` to the attention scores, so a raw 0/1
        mask would not block anything. Blocked positions become ``finfo(dtype).min``,
        allowed ones 0. A 2-D (S, S) mask is broadcast to (1, 1, S, S); a 3-D
        (B, S, S) mask to (B, 1, S, S).
        """
        if attn_mask is None:
            return None
        keep = attn_mask.to(dtype)
        additive = (1.0 - keep) * torch.finfo(dtype).min
        if additive.dim() == 2:
            additive = additive[None, None]
        elif additive.dim() == 3:
            additive = additive[:, None]
        return additive

    def forward(self, query_embds, img_embds, text_embds, attn_mask):
        """Run one layer.

        Args:
            query_embds: (B, Qs, D) query tokens.
            img_embds: image tokens used as key/value for cross-attention (only read
                when ``is_cross_attn``).
            text_embds: (B, Ts, D) text tokens.
            attn_mask: (Qs + Ts, Qs + Ts) keep-mask, 1 = may attend, 0 = blocked.

        Returns:
            (query_embds, text_embds) with the same shapes as the inputs.
        """
        num_queries = query_embds.shape[1]

        combined_embds = torch.cat((query_embds, text_embds), dim=1)  # (B, Qs + Ts, D)
        additive_mask = self._to_additive_mask(attn_mask, combined_embds.dtype)

        self_attn_output = self.self_attn(combined_embds, attention_mask=additive_mask)[0]
        # Split the *attended* sequence back into its query and text parts.
        query_embds = self_attn_output[:, :num_queries]
        text_embds = self_attn_output[:, num_queries:]

        if self.is_cross_attn:
            hidden_states = self.cross_attn(query_embds, img_embds, img_embds)[0]
            query_embds = self.cross_layer_norm(query_embds + hidden_states)

        query_embds = self.mlp_img_transformer(query_embds)
        text_embds = self.mlp_text_transformer(text_embds)
        return query_embds, text_embds


class QTransformer(nn.Module):
    """Stack of BertEncoderBlocks initialised from ``bert-base-uncased``.

    The same layers are applied three times per forward pass, once per BLIP-2
    pre-training objective. Each objective uses its own
    (num_queries + context_length)² 0/1 attention mask, where 1 = may attend:

    * ITC: queries and text attend only within their own modality.
    * ITM: every position attends to every position.
    * ITG: queries cannot see text; text sees every query plus earlier/current text.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert_cfg = BertConfig.from_pretrained("bert-base-uncased")
        self.bert_model = BertModel.from_pretrained("bert-base-uncased", config=self.bert_cfg)

        # Image cross-attention is enabled on even-indexed layers (0, 2, 4, ...).
        self.encoder = nn.ModuleList(
            BertEncoderBlock(layer, self.bert_cfg, layer_idx % 2 == 0)
            for layer_idx, layer in enumerate(self.bert_model.encoder.layer)
        )

        n_queries = config.num_queries
        n_text = config.context_length
        total = n_queries + n_text

        # ITC: block-diagonal, so each modality only sees itself.
        itc = torch.block_diag(torch.ones(n_queries, n_queries), torch.ones(n_text, n_text))

        # ITM: fully bidirectional.
        itm = torch.ones(total, total)

        # ITG: queries are blind to text; text is causal over text but sees all queries.
        itg = torch.ones(total, total)
        itg[:n_queries, n_queries:] = 0
        itg[n_queries:, n_queries:] = torch.tril(torch.ones(n_text, n_text))

        for buffer_name, mask in (
            ("itc_attn_mask", itc),
            ("itm_attn_mask", itm),
            ("itg_attn_mask", itg),
        ):
            self.register_buffer(buffer_name, mask)

    def forward(self, query_embds, img_embds, cls_text_embds, dec_text_embds):
        """Run the shared encoder stack once per objective.

        ITC and ITM start from ``cls_text_embds``; ITG starts from ``dec_text_embds``.
        Each stream starts from its own clones of the inputs.

        Returns:
            (itc_query, itc_text, itm_query, itm_text, itg_query, itg_text)
        """
        streams = [
            [query_embds.clone(), cls_text_embds.clone(), self.itc_attn_mask],
            [query_embds.clone(), cls_text_embds.clone(), self.itm_attn_mask],
            [query_embds.clone(), dec_text_embds.clone(), self.itg_attn_mask],
        ]
        # Layer-major traversal keeps the per-layer ITC -> ITM -> ITG call order, so any
        # stochastic layers (e.g. dropout) are invoked in the same sequence.
        for block in self.encoder:
            for stream in streams:
                stream[0], stream[1] = block(stream[0], img_embds, stream[1], stream[2])
        return tuple(tensor for stream in streams for tensor in stream[:2])
    

class QFormer(nn.Module):
    """Q-Former front-end: learned queries + embedded caption tokens -> QTransformer.

    The token embedding (``output_embedding``) is shared between the text input and
    the ITG output head (logits = hidden @ embedding.T). Text positions use a learned
    absolute position embedding of size ``config.context_length``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.q_transformer = QTransformer(config)
        self.learned_query = nn.Parameter(torch.randn(config.num_queries, config.embedding_dim))
        self.output_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.position_embedding = nn.Embedding(config.context_length, config.embedding_dim)

        position_ids = torch.arange(self.config.context_length).unsqueeze(0)  # (1, context_length)
        self.register_buffer("position_ids", position_ids)

    def _embed_text(self, tokens: torch.Tensor) -> torch.Tensor:
        """Token + position embeddings for a (B, T) id tensor, T <= context_length -> (B, T, D)."""
        batch_size, seq_len = tokens.shape
        # Slice to the actual length so the embedding step does not require T == context_length.
        # NOTE: QTransformer's attention masks are still sized for context_length.
        positions = self.position_ids[:, :seq_len].expand(batch_size, -1)
        return self.output_embedding(tokens) + self.position_embedding(positions)

    def forward(self, image_embedding: torch.Tensor, cls_tokens: torch.Tensor, dec_tokens: torch.Tensor):
        """Embed the queries and captions and run them through the QTransformer.

        Args:
            image_embedding: (B, N, D) projected image tokens (D presumably == config.embedding_dim).
            cls_tokens: (B, T) token ids used for the ITC/ITM streams.
            dec_tokens: (B, T) token ids used for the ITG stream.

        Returns:
            (itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits),
            where itg_logits is (B, T, vocab_size).
        """
        batch_size = image_embedding.shape[0]
        learned_query = self.learned_query.unsqueeze(0).expand(batch_size, -1, -1)  # (B, Qs, D)

        cls_text_embeddings = self._embed_text(cls_tokens)  # (B, T, D)
        dec_text_embeddings = self._embed_text(dec_tokens)  # (B, T, D)

        (itc_query_embds, itc_text_embds,
         itm_query_embds, itm_text_embds,
         _itg_query_embds, itg_text_embds) = self.q_transformer(
            learned_query, image_embedding, cls_text_embeddings, dec_text_embeddings)

        # Weight-tied LM head: project ITG hidden states onto the token embedding matrix.
        itg_logits = itg_text_embds @ self.output_embedding.weight.T  # (B, T, vocab_size)

        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits


class Blip2Model(nn.Module):
    """BLIP-2 model: timm ViT image encoder -> linear projection -> Q-Former.

    A FLAN-T5 language model (``lm_model``) and the projection ``z_proj`` are created
    here but are not used by ``forward`` yet.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Pretrained ViT-Tiny/16 backbone with its classification head removed.
        self.image_encoder = timm.create_model('vit_tiny_patch16_224', pretrained=True)
        self.image_encoder.reset_classifier(0)
        self.image_proj = nn.Linear(config.img_embd_dim, config.embedding_dim)

        self.lm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
        self.q_former = QFormer(config)
        self.z_proj = nn.Linear(config.embedding_dim, config.lm_embedding_dim)

    def forward(self, image:torch.tensor, cls_caption:torch.tensor, dec_caption:torch.tensor, stage:int):
        """Encode the image and run the Q-Former.

        Returns the Q-Former's five outputs
        (itc_query, itc_text, itm_query, itm_text, itg_logits).
        NOTE(review): `stage` currently selects the same computation for every value.
        """
        projected = self.image_proj(self.image_encoder.forward_features(image))
        outputs = self.q_former(projected, cls_caption, dec_caption)
        if stage != 1:
            return outputs
        (itc_query_embds, itc_text_embds,
         itm_query_embds, itm_text_embds, itg_logits) = outputs
        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits

        

In [None]:
# Build the full model (downloads pretrained ViT, BERT and FLAN-T5 weights on first run).
blip2_model = Blip2Model(config)
blip2_model

In [None]:
# Fall back to CPU so this cell also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One batch from the training loader, moved to the target device.
img, cls_caption, dec_caption = next(iter(train_dataloader))
img = img.to(device)
cls_caption = cls_caption.to(device)
dec_caption = dec_caption.to(device)

model = blip2_model.to(device)

# Single stage-1 forward pass in eval mode without autograd, as a smoke test.
model.eval()
with torch.no_grad():
    itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = model(
        img, cls_caption, dec_caption, 1
    )

In [None]:
class ITCLoss(nn.Module):
    """Query-to-text similarity matrix for image-text contrastive (ITC) learning.

    Despite the name, ``forward`` returns the (B, B) similarity logits, not a scalar
    loss. Entry [i, j] is the best match between any query of sample i and the
    first text token of sample j.
    """

    def __init__(self):
        super().__init__()

    def forward(self, query_embds, text_embds):
        """Compute max-over-queries similarities.

        Args:
            query_embds: (B, Q, D) query outputs.
            text_embds: (B, T, D) text outputs; only token 0 is used.

        Returns:
            (B, B) tensor with ``out[i, j] = max_q <query_embds[i, q], text_embds[j, 0]>``,
            on the inputs' device and in their result dtype (autograd-friendly).
        """
        text_cls = text_embds[:, 0]  # (B, D)
        # sim[i, j, q] = <query_embds[i, q], text_cls[j]>, computed for all pairs at once.
        sim = torch.einsum("iqd,jd->ijq", query_embds, text_cls)
        return sim.amax(dim=-1)



tensor([[[4, 0],
         [6, 3]],

        [[9, 1],
         [3, 8]]]) tensor([[[9, 3]],

        [[5, 2]]])
tensor([[63., 36.],
        [84., 47.]])
