In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
from torch.utils.data import DataLoader
import os
!pip install matplotlib

from config import Blip2Config
from dataset import Blip2Dataset
from tokenizer import FlanT5Tokenizer, BertTokenizer



In [2]:
print(torch.cuda.is_available())

import config, dataset, tokenizer
importlib.reload(config)
importlib.reload(dataset)
importlib.reload(tokenizer)

from config import Blip2Config
from dataset import Blip2Dataset
from tokenizer import FlanT5Tokenizer, BertTokenizer

True


In [3]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
t5_model_name = "google/flan-t5-small"
bert_autotokenizer =  AutoTokenizer.from_pretrained("bert-base-uncased") 
t5_autotokenizer = AutoTokenizer.from_pretrained(t5_model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
config = Blip2Config()

In [5]:
flan_t5_tokenizer = FlanT5Tokenizer(config, t5_autotokenizer)
bert_tokenizer = BertTokenizer(config, bert_autotokenizer)
config.bert_vocab_size = bert_tokenizer.n_vocab
config.t5_vocab_size = flan_t5_tokenizer.n_vocab

stage1_train_dataset = Blip2Dataset(config, split="train", tokenizer=bert_tokenizer.tokenize_text, type="bert")
stage1_train_dataloader = DataLoader(stage1_train_dataset, batch_size=config.batch_size, shuffle=True)

stage2_train_dataset = Blip2Dataset(config, split="train", tokenizer=flan_t5_tokenizer.tokenize_text, type="flan_t5")
stage2_train_dataloader = DataLoader(stage2_train_dataset, batch_size=config.batch_size, shuffle=True)

Skipping non-image file: image
Skipping non-image file: image


In [6]:
def to_additive_mask(mask01: torch.Tensor, *, device, dtype):
    """
    Convert 0/1 'allow mask' (1=allow, 0=block) of shape (L,L)
    into additive (L,L) with 0.0 for allow, -inf for block.
    """
    m = mask01.to(device=device, dtype=dtype)
    # where allow(1) -> 0.0, block(0) -> -inf
    return torch.where(m > 0, torch.zeros_like(m), torch.full_like(m, float("-inf")))

In [7]:
import copy
import timm
from transformers import BertConfig, BertModel

class BertMLPBlock(nn.Module):
    def __init__(self, intermediate, output):
        super().__init__()
        self.intermediate = intermediate
        self.output = output

    def forward(self, x):
        intermediate_output = self.intermediate(x)
        return self.output(intermediate_output, x)
    

class BertEncoderBlock(nn.Module):
    def __init__(self, bert_layer, bert_config, is_cross_attn=False):
        super().__init__()
        self.bert_config = bert_config
        self.is_cross_attn = is_cross_attn

        d = bert_config.hidden_size
        h = bert_config.num_attention_heads
        self.self_attn = nn.MultiheadAttention(d, h, batch_first=True)
        self.self_ln = nn.LayerNorm(d)

        self.mlp_img_transformer = BertMLPBlock(bert_layer.intermediate, bert_layer.output)
        self.mlp_text_transformer = BertMLPBlock(
                    copy.deepcopy(bert_layer.intermediate), 
                    copy.deepcopy(bert_layer.output)
                    )
        if is_cross_attn:
            self.cross_attn = nn.MultiheadAttention(embed_dim=self.bert_config.hidden_size, 
                                                    num_heads=self.bert_config.num_attention_heads, 
                                                    batch_first=True)
            self.cross_layer_norm = nn.LayerNorm(self.bert_config.hidden_size)
        
    def forward(self, query_embds, img_embds, text_embds, attn_mask):
        _, Qs, _ = query_embds.shape
        _, Ts, _ = text_embds.shape

        combined_embds = torch.concat((query_embds, text_embds), dim=1) # B, Qs + Ts, D

        attn_out, _ = self.self_attn(
            combined_embds, combined_embds, combined_embds,
            attn_mask=attn_mask,         # (L, L), broadcast over batch & heads
            key_padding_mask=None     # (B, L) bool, True=mask (optional)
        )
        combined_embds = self.self_ln(combined_embds+ attn_out)

        query_embds = combined_embds[:, :Qs]
        text_embds= combined_embds[:, Qs:]
        
        if self.is_cross_attn:
            hidden_states = self.cross_attn(query_embds, img_embds, img_embds)[0]
            query_embds = self.cross_layer_norm(query_embds + hidden_states)

        query_embds = self.mlp_img_transformer(query_embds)
        text_embds = self.mlp_text_transformer(text_embds)
        return query_embds, text_embds


class QTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert_cfg  = BertConfig.from_pretrained("bert-base-uncased")
        self.bert_model = BertModel.from_pretrained("bert-base-uncased", config = self.bert_cfg)
        print("num of bert layers ",self.bert_model.encoder.layer)
        self.encoder = nn.ModuleList()
        for i, bert_layer in enumerate(self.bert_model.encoder.layer):
            if(i>6):
                break
            self.encoder.append(BertEncoderBlock(bert_layer, self.bert_cfg, i % 2 == 0))
        
        qs = config.num_queries
        ts = config.context_length
        combined_seq_len = qs + ts

        ####  STAGE 1: ITC, ITM, ITG Loss Masks ####
        # ITC Loss Mask
        itc_attn_mask = torch.zeros((combined_seq_len, combined_seq_len))
        itc_attn_mask[:qs, :qs] = 1
        itc_attn_mask[qs:, qs:] = 1
        

        # ITM Loss Mask
        itm_attn_mask = torch.ones((combined_seq_len, combined_seq_len))

        # ITG Loss Mask
        itg_attn_mask = torch.ones((combined_seq_len, combined_seq_len))
        itg_attn_mask[:qs, qs:] = 0
        itg_attn_mask[qs:, qs:] = torch.tril(itg_attn_mask[qs:, qs:], diagonal=0)


        self.register_buffer("itc_attn_mask", itc_attn_mask)
        self.register_buffer("itm_attn_mask", itm_attn_mask)
        self.register_buffer("itg_attn_mask", itg_attn_mask)

        ####  STAGE 2: ####
        # ITC Loss Mask will be same as stage 1 and reused for stage 2

    def forward(self, query_embds, img_embds, cls_text_embds, dec_text_embds, stage):

        itc_query_embds = query_embds.clone()
        itm_query_embds = query_embds.clone()
        itg_query_embds = query_embds.clone()

        itc_text_embds = cls_text_embds.clone()
        itm_text_embds = cls_text_embds.clone()
        itg_text_embds = dec_text_embds.clone()

        device = query_embds.device
        dtype  = query_embds.dtype

        # Convert base masks to additive for this device/dtype
        itc_add = to_additive_mask(self.itc_attn_mask, device=device, dtype=dtype)
        itm_add = to_additive_mask(self.itm_attn_mask, device=device, dtype=dtype)
        itg_add = to_additive_mask(self.itg_attn_mask, device=device, dtype=dtype)


        for encoder in self.encoder:
            itc_query_embds, itc_text_embds = encoder(itc_query_embds, img_embds, itc_text_embds, itc_add)
            if stage == 1:
                itm_query_embds, itm_text_embds = encoder(itm_query_embds, img_embds, itm_text_embds, itm_add)
                itg_query_embds, itg_text_embds = encoder(itg_query_embds, img_embds, itg_text_embds, itg_add)
        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_query_embds, itg_text_embds
    

class QFormer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.q_transformer = QTransformer(config)
        self.learned_query = nn.Parameter(torch.randn(config.num_queries, config.embedding_dim))
        self.output_embedding  = nn.Embedding(config.bert_vocab_size, config.embedding_dim)
        self.position_embedding = nn.Embedding(config.context_length, config.embedding_dim)

        position_ids = torch.arange(self.config.context_length).unsqueeze(0)
        self.register_buffer("position_ids", position_ids)

    def forward(self, image_embedding: torch.tensor, cls_tokens: torch.tensor, dec_tokens: torch.tensor, stage:int):
        B, S, E = image_embedding.shape
        learned_query = self.learned_query.unsqueeze(0).expand(B, -1, -1)

        cls_text_embeddings = self.output_embedding(cls_tokens) #(S,768)
        cls_text_embeddings = cls_text_embeddings + self.position_embedding(self.position_ids.expand(B, -1))
        dec_text_embeddings = self.output_embedding(dec_tokens) #(S,768)
        dec_text_embeddings = dec_text_embeddings + self.position_embedding(self.position_ids.expand(B, -1))

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_query_embds, itg_text_embds = self.q_transformer(
            learned_query, image_embedding, cls_text_embeddings, dec_text_embeddings, stage)

        if itg_text_embds is not None:
            itg_logits = itg_text_embds @ self.output_embedding.weight.T # (S,Vocab_size)
        else:
            itg_logits = None

        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits


class FlanT5Model(nn.Module):
    def __init__(self):
        super(FlanT5Model, self).__init__()
        self.lm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
        for param in self.lm_model.parameters():
            param.requires_grad = False

    def forward(self, query_embedding, input_token, label, enc_mask):
        #query_embd : (B,32,512)
        # input_token : (B,L)
        B, Q, d = query_embedding.shape
        device = query_embedding.device
        with torch.no_grad():
            input_embd = self.lm_model.encoder.embed_tokens(input_token)  #(B,L,512)

        encoder_input = torch.concat((query_embedding, input_embd) , dim = 1).contiguous()

        prefix_mask = torch.ones((B, Q ), dtype= enc_mask.dtype, device=device)
        attention_mask = torch.concat((prefix_mask, enc_mask) , dim=1).contiguous()  # [B, 32+L]
        label = label.contiguous()  # [B, L]
        out = self.lm_model(inputs_embeds=encoder_input,
                                attention_mask=attention_mask,
                                labels=label,
                                return_dict=True)
        return out
    

    def predict(self, query_embedding, input_token, enc_mask):
        B, Q, d = query_embedding.shape
        device = query_embedding.device
        with torch.no_grad():
            input_embd = self.lm_model.encoder.embed_tokens(input_token)  #(B,L,512)

        encoder_input = torch.concat((query_embedding, input_embd) , dim = 1)

        prefix_mask = torch.ones((B, Q ), dtype= enc_mask.dtype, device=device)
        attention_mask = torch.concat((prefix_mask, enc_mask) , dim=1)  # [B, 32+L]
        
        enc_out = self.lm_model.encoder(
            inputs_embeds=encoder_input,
            attention_mask=attention_mask,
            return_dict=True
            )

        gen_ids = self.lm_model.generate(
            encoder_outputs=enc_out,
            max_new_tokens=30,
            decoder_start_token_id=self.lm_model.config.decoder_start_token_id,
            attention_mask=attention_mask,
        )

        return gen_ids


class Blip2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.image_encoder = timm.create_model('vit_tiny_patch16_224', pretrained=True)
        self.image_encoder.reset_classifier(0)

        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.image_proj = nn.Linear(config.img_embd_dim, config.embedding_dim)

        self.q_former = QFormer(config)
        self.z_proj = nn.Linear(config.embedding_dim, config.lm_embedding_dim)

        self.lm_model = FlanT5Model()
    
    
    def stage1(self, image:torch.tensor, cls_caption:torch.tensor, dec_caption:torch.tensor):
        image_embedding = self.image_encoder.forward_features(image)  # [B, C, F]
        image_embedding = self.image_proj(image_embedding)

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = self.q_former(image_embedding, cls_caption, dec_caption, 1)
        return itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits
    
    
    def stage2(self, image, input_token, label, enc_mask, dummy_input_size):
        image_embedding = self.image_encoder.forward_features(image)  # [B, C, F]
        image_embedding = self.image_proj(image_embedding)
        
        cls_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)
        dec_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)
        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = self.q_former(image_embedding, 
                                                            cls_caption_dummy, dec_caption_dummy, 2)
        
        z = self.z_proj(itc_query_embds)  # [B, Qs, D]

        out = self.lm_model(z, input_token, label, enc_mask)
            
        return out
    
    def predict(self, image, input_token, enc_mask, dummy_input_size):
        image_embedding = self.image_encoder.forward_features(image)  # [B, C, F]
        image_embedding = self.image_proj(image_embedding)

        cls_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)
        dec_caption_dummy = torch.zeros(dummy_input_size, dtype=torch.long, device = image.device)

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = self.q_former(image_embedding, 
                                                            cls_caption_dummy, dec_caption_dummy, 2)
        
        z = self.z_proj(itc_query_embds)

        gen_ids = self.lm_model.predict(z, input_token, enc_mask)
        return gen_ids



In [8]:
model = Blip2Model(config)
print(model)
bert = BertModel.from_pretrained("bert-base-uncased")
with torch.no_grad():
    model.q_former.output_embedding.weight.copy_(bert.embeddings.word_embeddings.weight)
    pe = bert.embeddings.position_embeddings.weight
    print(pe.shape)
    model.q_former.position_embedding.weight[:pe.size(0)].copy_(pe[:77])

import torch

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params

total, trainable = count_parameters(model)
print(f"Total parameters: {total:,}")
print(f"Trainable parameters: {trainable:,}")


#model size


num of bert layers  ModuleList(
  (0-11): 12 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
      (intermediate_act_fn): GELUActivation()
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False

In [None]:
class ITCLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, query_embds, text_embds):
        # query_embds: B, 32, d
        # text_embds: B, 77, d
        text_logit = text_embds[:, :1] # B, 1, d
        B, _, _ = text_logit.shape
        B, Qs, d = query_embds.shape 
        query_embds = query_embds.reshape(B * Qs, d)
        text_embds = text_logit.squeeze()
        logits = query_embds @ text_embds.T   # B*Qs,B
        logits = torch.max(logits.reshape(B,Qs,B),dim=1)[0] # B,B
        label = torch.arange(B,device=query_embds.device)
        return (F.cross_entropy(logits,label)+ F.cross_entropy(logits.T,label)) / 2, logits


class ITMLoss(nn.Module):
    def __init__(self, config):
        super().__init__()
        d = config.embedding_dim
        self.classification_layer = nn.Linear(d,2)
    
    def forward(self, query_embd, label):
        # query_embd --> (B,32,768)
        #label ->(B,1) B x [0/1]
        query_embd = query_embd.mean(dim=1)    #(B,768)
        match_logit = self.classification_layer(query_embd) #(B,2)
        return F.cross_entropy(match_logit,label)


class ITGLoss(nn.Module):
    def __init__(self, pad_token_id):
        super().__init__()
        self.pad_token_id = pad_token_id

    def forward(self, itg_logits, label_token):
        #itg_logits -> B,S,vocab size
        #label_token -> B,S
        B, S, V = itg_logits.shape
        loss = F.cross_entropy(
            itg_logits.view(B * S, V),
            label_token.view(B * S),
            ignore_index=self.pad_token_id
        )
        return loss


In [None]:
# --- put this near your imports ---
import math

def global_grad_norm(parameters):
    total = 0.0
    for p in parameters:
        if p.grad is None: 
            continue
        param_norm = p.grad.data.float().norm(2)
        total += param_norm.item() ** 2
    return math.sqrt(total)

def grad_stats(model, topk=5):
    stats = []

    for n,p in model.named_parameters():
        if p.grad is None: 
            continue
        g = p.grad.detach()
        abs_g = g.abs()
        stats.append({
            "name":n,
            "numel": g.numel(),
            "mean": float(abs_g.mean()),
            "median": float(abs_g.median()),
            "max": float(abs_g.max()),
            "min": float(abs_g.min()),
            "has_nan": bool(torch.isnan(g).any()),
            "has_inf": bool(torch.isinf(g).any()),
        })
    # sort by median magnitude (smallest first)
    stats.sort(key=lambda d: d["median"])
    # print a few smallest and a few largest
    print("\n=== Smallest gradient medians ===")
    for s in stats[:topk]:
        print(f"{s['name']:<60} med={s['median']:.3e} mean={s['mean']:.3e} max={s['max']:.3e} NaN={s['has_nan']} Inf={s['has_inf']}")
    print("\n=== Largest gradient medians ===")
    for s in stats[-topk:]:
        print(f"{s['name']:<60} med={s['median']:.3e} mean={s['mean']:.3e} max={s['max']:.3e} NaN={s['has_nan']} Inf={s['has_inf']}")

In [None]:
##########################   STAGE 1    #####################################
model.load_state_dict(torch.load("q_former.pt"))
device = torch.device('cuda')
itm_loss_func = ITMLoss(config)
itm_loss_func.load_state_dict(torch.load("itm_head.pt"))

optimizer = torch.optim.Adam(list(model.parameters()) + list(itm_loss_func.parameters()), lr = 1e-5)
num_epochs = 0

itc_loss_func = ITCLoss()
itg_loss_func = ITGLoss(bert_tokenizer.pad_token_id)

itc_loss_func = itc_loss_func.to(device)
itm_loss_func = itm_loss_func.to(device)
itg_loss_func = itg_loss_func.to(device)

model = model.to(device)
model.train()
grad_accum_steps = 64
for epoch in range(num_epochs):
    iteration = 0
    
    print(f"***************   Epoch {epoch + 1}  ***************")
    for img, cls_caption, dec_caption in (stage1_train_dataloader):
        img = img.to(device)
        cls_caption = cls_caption.to(device)
        dec_caption = dec_caption.to(device)
        B, _, _, _ = img.shape

        itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = model.stage1(img, cls_caption, dec_caption)

        # ITC Loss
        itc_loss, itc_logits = itc_loss_func(itc_query_embds, itc_text_embds)

        # ITM Loss
        idx = torch.arange(B,device = device)
        itc_logits[idx,idx] = -1e9
        next_best_text_value , next_best_text_idx = torch.min(itc_logits,dim=1)
        mismatched_cls_caption = cls_caption[next_best_text_idx]
        mismatched_dec_caption = dec_caption[next_best_text_idx]

        _,_,mismatched_itm_query_embeds,_,_ = model.stage1(img, mismatched_cls_caption, mismatched_dec_caption)

        itm_query_embed_concatenated = torch.concat((itm_query_embds, mismatched_itm_query_embeds) ,dim=0 )
        itm_labels = torch.zeros(2 * B, dtype=torch.long).to(device)
        itm_labels[B:] = 1
        itm_loss = itm_loss_func(itm_query_embed_concatenated, itm_labels)
        # ITG Loss
        itg_labels = torch.concat((dec_caption[:, 1:], dec_caption[:, -1].unsqueeze(1)), dim=1)
        itg_loss = itg_loss_func(itg_logits, itg_labels)


        total_loss = itg_loss 
        total_loss /= grad_accum_steps
        total_loss.backward()
        if (iteration+1) % grad_accum_steps == 0:
            max_norm = 0.1  # good starting point; try 0.5–2.0
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            # (optional) also clip loss-heads if they live outside model, e.g., itm_loss_func
            torch.nn.utils.clip_grad_norm_(itm_loss_func.parameters(), max_norm)

            grad_stats(model, topk=6)   # per-parameter snapshot
            gn = global_grad_norm(model.parameters())
            reported_gn = gn * grad_accum_steps
            # print(f"Iter {iteration} : Grad Norm: {reported_gn}")
            optimizer.step()
            optimizer.zero_grad()


            print(f"Epoch {epoch + 1} : Iter [{iteration} / {len(stage1_train_dataloader)}]")
            print(f"Total Loss: {total_loss.item()}")
            print(f"ITC Loss: {itc_loss}, ITM Loss: {itm_loss.item()}, ITG Loss: {itg_loss}")
            print("" + "*" * 50)
        iteration += 1
        

        
        

    torch.save(model.state_dict(), "q_former.pt")
    torch.save(itm_loss_func.state_dict(), "itm_head.pt")


***************   Epoch 1  ***************

=== Smallest gradient medians ===
name                                                         med=0.000e+00 mean=2.814e-05 max=7.002e-04 NaN=False Inf=False
name                                                         med=5.454e-09 mean=7.916e-07 max=1.173e-03 NaN=False Inf=False
name                                                         med=1.489e-07 mean=5.275e-07 max=5.312e-05 NaN=False Inf=False
name                                                         med=1.582e-07 mean=6.927e-07 max=3.630e-05 NaN=False Inf=False
name                                                         med=1.724e-07 mean=5.880e-07 max=3.024e-05 NaN=False Inf=False
name                                                         med=1.805e-07 mean=5.444e-07 max=2.915e-05 NaN=False Inf=False

=== Largest gradient medians ===
name                                                         med=3.215e-05 mean=3.822e-05 max=1.538e-04 NaN=False Inf=False
name                

In [74]:
# state_dict = torch.load("q_former.pt", map_location="cpu")

# # Apply to your model
# model.load_state_dict(state_dict)

img, cls_caption, dec_caption  = next(iter(stage1_train_dataloader))
print(device)
img = img.to(device)
cls_caption = cls_caption.to(device)
dec_caption = dec_caption.to(device)
#randomly mask cls_caption and dec_caption

itc_query_embds, itc_text_embds, itm_query_embds, itm_text_embds, itg_logits = model.stage1(img, cls_caption, dec_caption)

itg_predicted_ids = torch.argmax(itg_logits, dim=-1)  # (B, S)

itg_predicted_texts = [
    bert_tokenizer.decode(ids) for ids in itg_predicted_ids.cpu().tolist()
]

print("ITG Predicted Texts:", [s.split('.') for s in itg_predicted_texts])

gt_texts = [
    bert_tokenizer.decode(ids) for ids in dec_caption.cpu().tolist()
]
print("Ground Truth Texts:", gt_texts)


cuda
ITG Predicted Texts: [['a man in a trick on a surfboard', ' the ocean', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ''], ['a white and is a red collar is is on the beach', ' the beachy beach', '', '', '', '', ' a', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ''], ['a girl girl in a red and is a pink shirt is in a grass', '', ' her', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', ''], ['a dogs play in the poolie pool', '', '', '', '', '', '', '', '', '', '', '', '', '']]
Ground Truth Texts: ['[unused99] a man performs a trick on a surfboard in the water.', '[unused99] a white dog with a black chest piece runs along the shore of a grainy beach.', '[unused99] a young girl with a read shirt and a white hat playing among the foliage.', '[unused99] two dogs wrestle in a kiddie pool.']


In [115]:
##########################   STAGE 2    #####################################
lr_list = [1e-4,3e-5]
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

for lr_ in lr_list:
    device = torch.device('cuda')
    optimizer = torch.optim.Adam(model.parameters(), lr = lr_)
    num_epochs = 8

    model = model.to(device)
    model.train()
    grad_accum_steps = 64
    for epoch in range(num_epochs):
        iteration = 0
        print(f"***************   Epoch {epoch + 1}  ***************")
        for img, question_placeholder, input_caption, input_mask in stage2_train_dataloader:
            img = img.to(device)
            question_placeholder = question_placeholder.to(device)
            question_placeholder = question_placeholder.squeeze()
            input_caption = input_caption.to(device).squeeze()
           

            input_mask = input_mask.to(device).squeeze()
            B, S = input_caption.shape

            pad_id = tokenizer.pad_token_id  # should be 0 for flan-t5
            labels = input_caption.clone()
            labels[labels == pad_id] = -100

            out = model.stage2(img, question_placeholder, labels, input_mask, (B, S))
            total_loss = out.loss

            total_loss /= grad_accum_steps
            total_loss.backward()

            if (iteration+1) % grad_accum_steps == 0:
                max_norm = 0.5
                gn = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                grad_stats(model, topk=6)   # per-parameter snapshot
                gn = global_grad_norm(model.parameters())
                reported_gn = gn * grad_accum_steps
                # print(f"Iter {iteration} : Grad Norm: {reported_gn}")
                optimizer.step()
                optimizer.zero_grad()

                print(f"Epoch {epoch + 1} : Iter [{iteration} / {len(stage2_train_dataloader)}]")
                print(f"Total Loss: {total_loss.item()}")
                print("" + "*" * 50)
            iteration += 1
            
            

        torch.save(model.state_dict(), "blip2.pt")


***************   Epoch 1  ***************

=== Smallest gradient medians ===
name                                                         med=0.000e+00 mean=0.000e+00 max=0.000e+00 NaN=False Inf=False
name                                                         med=0.000e+00 mean=0.000e+00 max=0.000e+00 NaN=False Inf=False
name                                                         med=0.000e+00 mean=0.000e+00 max=0.000e+00 NaN=False Inf=False
name                                                         med=0.000e+00 mean=0.000e+00 max=0.000e+00 NaN=False Inf=False
name                                                         med=0.000e+00 mean=0.000e+00 max=0.000e+00 NaN=False Inf=False
name                                                         med=0.000e+00 mean=0.000e+00 max=0.000e+00 NaN=False Inf=False

=== Largest gradient medians ===
name                                                         med=1.072e-04 mean=1.275e-04 max=4.741e-04 NaN=False Inf=False
name                

In [11]:
model.load_state_dict(torch.load("blip2.pt"))
device = torch.device('cuda')
model.to(device)
model.eval()
with torch.no_grad():
    img, question_placeholder, input_caption, input_mask = next(iter(stage2_train_dataloader))
    img = img.to(device)
    question_placeholder = question_placeholder.to(device).squeeze()
    input_caption = input_caption.to(device).squeeze()
    input_mask = input_mask.to(device).squeeze()

    B, S = input_caption.shape

    # forward pass
    out = model.stage2(img, question_placeholder, input_caption, input_mask, (B, S))

    # logits: (B, L, V)
    logits = out.logits

    # greedy prediction
    pred_ids = torch.argmax(logits, dim=-1)   # (B, L)

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

    decoded = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

    print("Predicted texts:", decoded)

    ground_truth = tokenizer.batch_decode(input_caption, skip_special_tokens=True)
    print("Ground truth texts:", ground_truth)

        

  return torch._native_multi_head_attention(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Predicted texts: ['A  gold  are bya s  on    cars cars are a                        A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A', 'A man woman is a red shirt book is playing  hands.  ...               A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A', 'A man is a  shirt is  is riding a trick   bike .                  A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A', 'A man is  is a white shirt is  a streetcliffwaying.s        A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A A']
Ground truth texts: ['Silver and blue car marked 104 raises dust on road as two background people watch .', 'A young girl in a grey illustrated shirt is holding her hands over her head .', 'A boy wearing a red shirt and jeans is doing a flip on his bike .', 'A boy in jeans and a black shirt skating down a stair rail .']


In [47]:
from PIL import Image

device = torch.device('cuda')
model.load_state_dict(torch.load("blip2.pt"))
model = model.to(device)
img, question_placeholder,cls_caption, dec_caption  = next(iter(stage2_train_dataloader))
img = img.to(device)
question_placeholder = question_placeholder.to(device)
question_placeholder = question_placeholder.squeeze()[:1]

model.eval()

image_path = "example.jpg"
image = Image.open(image_path).convert('RGB')
image = stage2_train_dataset.transform(image)
image = image.unsqueeze(0)  # Add batch dimension

input_caption = "Question: describe the image. Answer: "
input_tokens, input_mask = flan_t5_tokenizer.tokenize_text(input_caption)

image = image.to(device)
input_tokens = input_tokens.to(device)
input_mask = input_mask.to(device)
print(input_tokens.shape, input_mask.shape)



torch.Size([1, 77]) torch.Size([1, 77])


In [57]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
device = torch.device('cuda')
model.load_state_dict(torch.load("blip2.pt"))
model = model.to(device)
img, question_placeholder,cls_caption, dec_caption  = next(iter(stage2_train_dataloader))
img = img.to(device)
question_placeholder = question_placeholder.to(device)
question_placeholder = question_placeholder.squeeze()[:1]

model.eval()

image_path = "example.jpg"
image = Image.open(image_path).convert('RGB')
plt.imshow(image)


img_pil = Image.open(image_path).convert("RGB")

plt.figure(figsize=(5,5))
plt.imshow(np.array(img_pil))
plt.axis("off")
plt.tight_layout()
plt.show()

image = stage2_train_dataset.transform(image)
image = image.unsqueeze(0)  # Add batch dimension

input_caption = "Question: describe the image. Answer: "
input_tokens, input_mask = flan_t5_tokenizer.tokenize_text(input_caption)

image = image.to(device)
input_tokens = input_tokens.to(device)
input_mask = input_mask.to(device)
print(input_tokens.shape, input_mask.shape)

gen_ids = model.predict(image, question_placeholder, input_mask, (input_tokens.shape[0], input_tokens.shape[1]))

decoded_output = flan_t5_tokenizer.decode(gen_ids[0])
print(decoded_output)

# ground_truth = flan_t5_tokenizer.decode(cls_caption[0][0])
# print("Ground truth texts:", ground_truth)


torch.Size([1, 77]) torch.Size([1, 77])
A man is riding a bike in the snow.
