In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Khối xây dựng này không đổi
class FusedTransformerBlock(nn.Module):
    """Fused Transformer block used by both the grounded encoder and decoder.

    Sub-layers run in the order self-attention -> cross-attention -> feed-forward.
    Each sub-layer output goes through dropout, is added to its input (residual)
    and is then layer-normalised (post-norm).

    Args:
        embed_dim: Width of the text and visual embeddings.
        num_heads: Number of attention heads in both attention sub-layers.
        ff_dim: Hidden width of the feed-forward network.
        dropout_rate: Dropout probability used throughout the block.
        is_causal: When True, self-attention uses a causal (look-back only) mask.
    """
    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1, is_causal=False):
        super().__init__()
        self.is_causal = is_causal
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.GELU(), nn.Dropout(dropout_rate), nn.Linear(ff_dim, embed_dim))
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout3 = nn.Dropout(dropout_rate)

    def forward(self, text_embeds, visual_embeds, text_mask=None, visual_mask=None):
        """Run the block on batch-first text embeddings attended over visual embeddings.

        Args:
            text_embeds: Text sequence, (batch, text_len, embed_dim).
            visual_embeds: Visual sequence used as key/value of the cross-attention.
            text_mask: Optional mask forwarded as ``key_padding_mask`` of the self-attention.
            visual_mask: Optional mask forwarded as ``key_padding_mask`` of the cross-attention.

        Returns:
            The updated text embeddings, same shape as ``text_embeds``.
        """
        sa_attn_mask = None
        if self.is_causal:
            # Build the mask where the activations live (and in their dtype) so no
            # host-to-device copy is needed on every call.
            sa_attn_mask = self._generate_causal_mask(
                text_embeds.size(1), device=text_embeds.device, dtype=text_embeds.dtype
            )
        sa_out, _ = self.self_attn(query=text_embeds, key=text_embeds, value=text_embeds, attn_mask=sa_attn_mask, key_padding_mask=text_mask)
        text_embeds = self.norm1(text_embeds + self.dropout1(sa_out))
        ca_out, _ = self.cross_attn(query=text_embeds, key=visual_embeds, value=visual_embeds, key_padding_mask=visual_mask)
        text_embeds = self.norm2(text_embeds + self.dropout2(ca_out))
        ffn_out = self.ffn(text_embeds)
        text_embeds = self.norm3(text_embeds + self.dropout3(ffn_out))
        return text_embeds

    def _generate_causal_mask(self, sz, device=None, dtype=torch.float32):
        """Return an additive (sz, sz) causal mask.

        Entries on and below the diagonal are 0.0; entries above it are -inf, so
        position i cannot attend to positions j > i.

        Args:
            sz: Sequence length.
            device: Device to allocate the mask on (default: PyTorch's default device).
            dtype: Floating-point dtype of the mask (default: float32).
        """
        blocked = torch.full((sz, sz), float('-inf'), device=device, dtype=dtype)
        # triu(diagonal=1) keeps -inf strictly above the diagonal and zeroes the rest.
        return torch.triu(blocked, diagonal=1)

# --- CÁC THÀNH PHẦN SẼ ĐƯỢC CHIA SẺ ---

class ImageEncoder(nn.Module):
    """Placeholder for a ViT: embeds already-flattened image patches.

    A learnable CLS token is prepended, learnable positional embeddings are
    added and the result is layer-normalised.
    """
    def __init__(self, embed_dim=768, image_size=224, patch_size=16):
        super().__init__()
        patches_per_side = image_size // patch_size
        token_count = patches_per_side ** 2 + 1  # all patches plus the CLS slot
        flat_patch_dim = 3 * patch_size * patch_size
        self.patch_proj = nn.Linear(flat_patch_dim, embed_dim)  # placeholder projection
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, token_count, embed_dim))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, image_patches):
        """Project the patches, prepend CLS, add positions and normalise."""
        projected = self.patch_proj(image_patches)
        batch_cls = self.cls_token.expand(projected.shape[0], -1, -1)
        tokens = torch.cat((batch_cls, projected), dim=1)
        tokens.add_(self.pos_embed)
        return self.norm(tokens)

class SharedTextEmbeddings(nn.Module):
    """Token and positional embedding tables meant to be shared across modules."""
    def __init__(self, vocab_size, embed_dim=768, max_seq_len=512):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, embed_dim))

    def forward(self, text_tokens):
        """Return token embeddings plus the positional embeddings of the first positions."""
        token_vectors = self.token_embed(text_tokens)
        # Only as many positional rows as there are tokens in the sequence.
        positions = self.pos_embed[:, :text_tokens.size(1), :]
        return token_vectors + positions

In [None]:
# --- MODULE 1: UNIMODAL ENCODER ---
class UnimodalModule(nn.Module):
    """Encodes images and text independently of each other.

    The image encoder and the text embedding module are injected by the
    caller; the text Transformer encoder stack is owned by this module.
    """
    def __init__(self, image_encoder, shared_text_embeddings, num_layers=12, embed_dim=768, num_heads=12, ff_dim=3072):
        super().__init__()
        # Components created elsewhere and handed in.
        self.image_encoder = image_encoder
        self.shared_text_embeddings = shared_text_embeddings

        # Component private to this module.
        self.text_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=ff_dim,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

    def forward(self, image_patches=None, text_tokens=None, text_mask=None):
        """Return ``(visual_output, text_output)``; an entry is None when its input is None."""
        visual_output = None
        if image_patches is not None:
            visual_output = self.image_encoder(image_patches)

        if text_tokens is None:
            return visual_output, None

        embedded_text = self.shared_text_embeddings(text_tokens)
        encoded_text = self.text_encoder(embedded_text, src_key_padding_mask=text_mask)
        return visual_output, encoded_text

# --- MODULE 2: IMAGE-GROUNDED TEXT ENCODER ---
class ImageGroundedEncoderModule(nn.Module):
    """Stack of non-causal fused blocks that attends text over visual embeddings."""
    def __init__(self, shared_text_embeddings, num_layers=12, embed_dim=768, num_heads=12, ff_dim=3072):
        super().__init__()
        self.shared_text_embeddings = shared_text_embeddings
        self.blocks = nn.ModuleList(
            FusedTransformerBlock(embed_dim, num_heads, ff_dim, is_causal=False)
            for _ in range(num_layers)
        )

    def forward(self, text_tokens, visual_embeds, text_mask=None):
        """Embed the tokens, then pass them through every block in order."""
        hidden = self.shared_text_embeddings(text_tokens)
        for layer in self.blocks:
            hidden = layer(text_embeds=hidden, visual_embeds=visual_embeds, text_mask=text_mask)
        return hidden

# --- MODULE 3: IMAGE-GROUNDED TEXT DECODER ---
class ImageGroundedDecoderModule(nn.Module):
    """Stack of causal fused blocks followed by a linear vocabulary predictor."""
    def __init__(self, shared_text_embeddings, vocab_size, num_layers=12, embed_dim=768, num_heads=12, ff_dim=3072):
        super().__init__()
        self.shared_text_embeddings = shared_text_embeddings
        self.blocks = nn.ModuleList(
            FusedTransformerBlock(embed_dim, num_heads, ff_dim, is_causal=True)
            for _ in range(num_layers)
        )
        self.vocab_predictor = nn.Linear(embed_dim, vocab_size)

    def forward(self, text_tokens, visual_embeds, text_mask=None):
        """Return per-position vocabulary logits for the given tokens."""
        hidden = self.shared_text_embeddings(text_tokens)
        for layer in self.blocks:
            hidden = layer(text_embeds=hidden, visual_embeds=visual_embeds, text_mask=text_mask)
        logits = self.vocab_predictor(hidden)
        return logits

In [None]:
class MED_Composed(nn.Module):
    """Composition of the unimodal, grounded-encoder and grounded-decoder modules.

    One image encoder and one text-embedding module are created here and
    injected into the sub-modules so that their parameters are shared.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, vocab_size, embed_dim=768, num_layers=12, num_heads=12, **kwargs):
        super().__init__()

        # 1. Components shared by every sub-module.
        self.shared_image_encoder = ImageEncoder(embed_dim=embed_dim)
        self.shared_text_embeddings = SharedTextEmbeddings(vocab_size=vocab_size, embed_dim=embed_dim)

        # 2. Individual modules, each receiving the shared components it needs.
        stack_cfg = dict(num_layers=num_layers, embed_dim=embed_dim, num_heads=num_heads)

        self.unimodal_module = UnimodalModule(
            image_encoder=self.shared_image_encoder,
            shared_text_embeddings=self.shared_text_embeddings,
            **stack_cfg,
        )
        self.grounded_encoder_module = ImageGroundedEncoderModule(
            shared_text_embeddings=self.shared_text_embeddings,
            **stack_cfg,
        )
        self.grounded_decoder_module = ImageGroundedDecoderModule(
            shared_text_embeddings=self.shared_text_embeddings,
            vocab_size=vocab_size,
            **stack_cfg,
        )

    def forward(self, mode, image_patches=None, text_tokens=None, text_mask=None):
        """Dispatch to one sub-module according to ``mode``.

        Args:
            mode: 'unimodal', 'grounded_encoder' or 'grounded_decoder'.
            image_patches: Image input; required by the grounded modes.
            text_tokens: Token ids; required by the grounded modes.
            text_mask: Optional text mask passed through to the chosen module.

        Raises:
            ValueError: For an unknown mode, or a grounded mode missing an input.
        """
        if mode == 'unimodal':
            # The unimodal module drives the image encoder itself.
            return self.unimodal_module(image_patches, text_tokens, text_mask)

        if mode not in ['grounded_encoder', 'grounded_decoder']:
            raise ValueError(f"Unknown mode: {mode}")

        if image_patches is None or text_tokens is None:
            raise ValueError("Grounded modes require both image and text inputs.")

        # Encode the image exactly once with the shared encoder.
        visual_embeds = self.shared_image_encoder(image_patches)

        if mode == 'grounded_encoder':
            grounded_module = self.grounded_encoder_module
        else:
            grounded_module = self.grounded_decoder_module
        return grounded_module(text_tokens, visual_embeds, text_mask)

### Style 2

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, BertConfig

# --- Helper functions/classes for better readability ---
# (Simplified stand-ins are used where a full implementation would have dedicated layers.)

class ImageEncoder(nn.Module):
    """Simplified ViT-style image encoder.

    A strided convolution cuts the image into non-overlapping patches and
    embeds them; a learnable CLS token and positional embeddings are added
    before a stack of standard PyTorch Transformer encoder layers.
    A real BLIP would use a full Vision Transformer here.
    """
    def __init__(self, image_size=224, patch_size=16, embed_dim=768, num_layers=12, num_heads=12):
        super().__init__()
        # Patch embedding: kernel == stride == patch_size, 3 input channels.
        self.patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        patches_per_side = image_size // patch_size
        token_count = patches_per_side ** 2 + 1  # patches plus the CLS slot
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, token_count, embed_dim))

        # Batch-first encoder stack with a 4x feed-forward expansion.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 4,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

    def forward(self, images):
        """Encode a batch of images into a (B, N_patches + 1, E) token sequence."""
        patch_tokens = self.patch_embedding(images).flatten(2).transpose(1, 2)  # (B, N_patches, E)
        batch_cls = self.cls_token.expand(patch_tokens.shape[0], -1, -1)
        tokens = torch.cat((batch_cls, patch_tokens), dim=1)  # (B, N_patches+1, E)
        tokens += self.pos_embedding[:, :tokens.shape[1]]
        return self.transformer_encoder(tokens)

class TextEncoder(nn.Module):
    """Unimodal text encoder wrapping a pretrained Hugging Face model and its tokenizer."""
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        # Tokenizer and backbone are both loaded through Hugging Face Transformers.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        backbone = AutoModel.from_pretrained(model_name)
        self.bert_model = backbone
        self.config = backbone.config

    def forward(self, text_input_ids, attention_mask):
        """Return the backbone's ``last_hidden_state`` for the given ids and mask."""
        encoded = self.bert_model(input_ids=text_input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state  # (B, seq_len, E)

# --- Custom Transformer Layer for Image-Grounded Encoder ---
# This class specifically modifies the standard EncoderLayer to insert cross-attention
# as per the BLIP diagram (after self-attention and before FFN).
class ImageGroundedEncoderLayer(nn.Module):
    """Post-norm Transformer layer: self-attention, cross-attention, then feed-forward.

    Cross-attention sits between the self-attention and the feed-forward
    network; queries come from ``src`` and keys/values from ``memory``.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.GELU()):
        super().__init__()
        # Attention sub-layers (batch-first tensors).
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm / Dropout pair per sub-layer (self-attn, cross-attn, FFN).
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        self.dropout1, self.dropout2, self.dropout3 = (nn.Dropout(dropout) for _ in range(3))

        self.activation = activation

    def forward(self, src, memory, src_mask=None, memory_mask=None, src_key_padding_mask=None, memory_key_padding_mask=None):
        """Apply the three sub-layers, each followed by residual add and LayerNorm.

        ``src_mask`` / ``src_key_padding_mask`` go to the self-attention;
        ``memory_mask`` / ``memory_key_padding_mask`` go to the cross-attention.
        """
        # 1) Self-attention over the source sequence.
        attended, _ = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attended))

        # 2) Cross-attention: source queries against the memory.
        grounded, _ = self.cross_attn(src, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)
        src = self.norm2(src + self.dropout2(grounded))

        # 3) Feed-forward network.
        expanded = self.activation(self.linear1(src))
        ffn_out = self.linear2(self.dropout(expanded))
        return self.norm3(src + self.dropout3(ffn_out))

class ImageGroundedTextEncoder(nn.Module):
    """Text encoder whose layers cross-attend to image features.

    Token embeddings come from the embedding module of a pretrained
    ``bert-base-uncased``; they are then refined by a stack of
    ``ImageGroundedEncoderLayer`` sized from ``bert_config``.
    """
    def __init__(self, bert_config, num_image_grounded_layers=4):
        super().__init__()
        # Only the embedding module of the pretrained BERT is kept.
        pretrained_bert = AutoModel.from_pretrained("bert-base-uncased", config=bert_config)
        self.bert_embedding_layer = pretrained_bert.embeddings
        self.embed_dim = bert_config.hidden_size
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # kept for special-token handling

        # Custom layers that add cross-attention between self-attention and the FFN.
        self.encoder_layers = nn.ModuleList(
            ImageGroundedEncoderLayer(
                self.embed_dim,
                bert_config.num_attention_heads,
                bert_config.intermediate_size,
                bert_config.hidden_dropout_prob,
            )
            for _ in range(num_image_grounded_layers)
        )
        # The [Encode] token is expected to be part of the tokenized input already;
        # no dedicated embedding is created for it here.

    def forward(self, text_input_ids, attention_mask, image_features):
        """Return hidden states for every text position, grounded on ``image_features``.

        ``attention_mask`` uses 0 for padding; it is converted to the boolean
        "True means ignore" form passed as ``src_key_padding_mask``. No padding
        mask is applied to the image features. Selecting the [Encode] position
        is left to the caller.
        """
        padding_positions = attention_mask == 0
        hidden_states = self.bert_embedding_layer(input_ids=text_input_ids)

        for grounded_layer in self.encoder_layers:
            # Text is the source sequence, image features are the memory.
            hidden_states = grounded_layer(
                hidden_states,
                image_features,
                src_key_padding_mask=padding_positions,
            )

        return hidden_states  # (B, seq_len, E)

# --- ImageGroundedTextDecoder Class (for LM) ---
class ImageGroundedTextDecoder(nn.Module):
    """Causal text decoder that cross-attends to image features and predicts tokens.

    Embeddings come from a pretrained ``bert-base-uncased``; decoding uses
    PyTorch's ``nn.TransformerDecoder`` (which already contains self- and
    cross-attention) followed by a linear language-modelling head.
    """
    def __init__(self, bert_config, num_image_grounded_layers=4):
        super().__init__()
        pretrained_bert = AutoModel.from_pretrained("bert-base-uncased", config=bert_config)
        self.bert_embedding_layer = pretrained_bert.embeddings
        self.embed_dim = bert_config.hidden_size

        # Batch-first decoder stack sized from the BERT configuration.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=self.embed_dim,
                nhead=bert_config.num_attention_heads,
                dim_feedforward=bert_config.intermediate_size,
                dropout=bert_config.hidden_dropout_prob,
                activation=nn.GELU(),
                batch_first=True,
            ),
            num_layers=num_image_grounded_layers,
        )

        # Projects hidden states to vocabulary logits.
        self.lm_head = nn.Linear(self.embed_dim, bert_config.vocab_size)

    def forward(self, text_input_ids, attention_mask, image_features):
        """Return logits of shape (B, seq_len, vocab_size).

        A square causal mask restricts each position to earlier positions, and
        padding positions (``attention_mask == 0``) are flagged through
        ``tgt_key_padding_mask``. No padding mask is applied to the image features.
        """
        token_count = text_input_ids.shape[1]
        causal_mask = nn.Transformer.generate_square_subsequent_mask(token_count).to(text_input_ids.device)
        padding_positions = attention_mask == 0  # True where the token is padding

        decoded = self.transformer_decoder(
            tgt=self.bert_embedding_layer(input_ids=text_input_ids),
            memory=image_features,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=padding_positions,
        )
        return self.lm_head(decoded)  # (B, seq_len, vocab_size)


# --- BLIP Main Class (Tổng hợp) ---
class BLIP(nn.Module):
    """Top-level model combining the image encoder with three text branches.

    Branches:
      * unimodal text encoder, used for the image-text contrastive (ITC) loss;
      * image-grounded text encoder, used for the image-text matching (ITM) loss;
      * image-grounded text decoder, used for the language-modelling (LM) loss.

    Args:
        image_encoder_params: Keyword arguments for ``ImageEncoder``.
        text_encoder_params: Keyword arguments for ``TextEncoder``.
        bert_config: BERT configuration used to size the grounded branches and heads.
        itm_num_layers: Number of layers in the image-grounded encoder.
        lm_num_layers: Number of layers in the image-grounded decoder.
    """
    def __init__(self, image_encoder_params, text_encoder_params, bert_config, 
                 itm_num_layers=4, lm_num_layers=4):
        super().__init__()

        # Base modules.
        self.image_encoder = ImageEncoder(**image_encoder_params)

        # Text encoder for the unimodal / ITC path.
        self.text_encoder_unimodal = TextEncoder(**text_encoder_params) 

        # Image-grounded text encoder (ITM).
        self.image_grounded_text_encoder = ImageGroundedTextEncoder(
            bert_config=bert_config,
            num_image_grounded_layers=itm_num_layers
        )

        # Image-grounded text decoder (LM).
        self.image_grounded_text_decoder = ImageGroundedTextDecoder(
            bert_config=bert_config,
            num_image_grounded_layers=lm_num_layers
        )

        # Linear heads for ITC and ITM.
        embed_dim = bert_config.hidden_size 
        self.itc_head = nn.Linear(embed_dim, embed_dim)  # contrastive projection (shared by image and text)
        self.itm_head = nn.Linear(embed_dim, 2)  # binary matched / unmatched classifier

        # Token-level loss for language modelling: padding targets are ignored.
        # It must NOT be used for ITC / ITM, whose class indices are unrelated to token ids.
        self.cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=self.text_encoder_unimodal.tokenizer.pad_token_id)
        # Learnable temperature dividing the contrastive similarities.
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)

    # --- Functional entry points ---

    def unimodal_encode(self, images, text_input_ids, text_attention_mask):
        """Return the first-position (CLS) features of the image and of the text."""
        image_features = self.image_encoder(images) 
        image_cls_feature = image_features[:, 0, :]

        text_features = self.text_encoder_unimodal(text_input_ids, text_attention_mask)
        text_cls_feature = text_features[:, 0, :]

        return image_cls_feature, text_cls_feature

    def image_grounded_encode(self, images, text_input_ids, text_attention_mask):
        """Return the image-grounded representation taken at the last text position.

        The [Encode] token is assumed to be the final token of every sequence.
        """
        image_features = self.image_encoder(images) 

        multimodal_representations = self.image_grounded_text_encoder(
            text_input_ids, text_attention_mask, image_features
        )

        encode_token_representation = multimodal_representations[:, -1, :] 
        return encode_token_representation

    def image_grounded_decode(self, images, text_input_ids, text_attention_mask):
        """Return vocabulary logits from the image-grounded decoder."""
        image_features = self.image_encoder(images) 
        logits = self.image_grounded_text_decoder(
            text_input_ids, text_attention_mask, image_features
        )
        return logits

    def forward(self, images, 
                text_input_ids_itc=None, text_attention_mask_itc=None,
                text_input_ids_itm=None, text_attention_mask_itm=None, itm_labels=None,
                text_input_ids_lm=None, text_attention_mask_lm=None, lm_labels=None,
                task_type="pretrain"):
        """Compute the loss(es) selected by ``task_type``.

        Args:
            images: Image batch passed to the image encoder.
            text_input_ids_itc, text_attention_mask_itc: Text inputs of the ITC task.
            text_input_ids_itm, text_attention_mask_itm, itm_labels: Inputs and
                labels (0 unmatched, 1 matched) of the ITM task.
            text_input_ids_lm, text_attention_mask_lm, lm_labels: Inputs and target
                token ids of the LM task.
            task_type: "pretrain" (all three losses), "itc", "itm" or "lm".

        Returns:
            ``(total_loss, losses_dict)`` for "pretrain"; the single loss otherwise.

        Raises:
            NotImplementedError: For ``task_type == "inference"``.
            ValueError: For any other unknown ``task_type``.
        """
        total_loss = 0
        losses_dict = {}

        if task_type == "pretrain" or task_type == "itc":
            itc_loss = self.calculate_itc_loss(images, text_input_ids_itc, text_attention_mask_itc)
            total_loss += itc_loss
            losses_dict["itc_loss"] = itc_loss

        if task_type == "pretrain" or task_type == "itm":
            itm_loss = self.calculate_itm_loss(images, text_input_ids_itm, text_attention_mask_itm, itm_labels)
            total_loss += itm_loss
            losses_dict["itm_loss"] = itm_loss

        if task_type == "pretrain" or task_type == "lm":
            lm_loss = self.calculate_lm_loss(images, text_input_ids_lm, text_attention_mask_lm, lm_labels)
            total_loss += lm_loss
            losses_dict["lm_loss"] = lm_loss

        if task_type == "pretrain":
            return total_loss, losses_dict
        elif task_type in ["itc", "itm", "lm"]:
            return total_loss  # only one loss was accumulated
        elif task_type == "inference":
            # Text generation / captioning is not part of this sketch.
            raise NotImplementedError("Inference logic not implemented yet.")
        else:
            raise ValueError(f"Unknown task_type: {task_type}")

    # --- Loss computations ---
    def calculate_itc_loss(self, images, text_input_ids, text_attention_mask):
        """Symmetric InfoNCE loss between image and text CLS features.

        The i-th image and the i-th text of the batch form the positive pair.
        """
        img_cls, text_cls = self.unimodal_encode(images, text_input_ids, text_attention_mask)

        # Project, then L2-normalise so the dot product is a cosine similarity.
        img_proj = torch.nn.functional.normalize(self.itc_head(img_cls), dim=-1)
        text_proj = torch.nn.functional.normalize(self.itc_head(text_cls), dim=-1)

        # Cosine similarities divided by the temperature.
        logits_per_image = torch.matmul(img_proj, text_proj.T) / self.temperature
        logits_per_text = logits_per_image.T

        labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
        # Plain cross-entropy on purpose: the labels are batch indices, so the
        # pad-token ``ignore_index`` of ``self.cross_entropy_loss`` would silently
        # drop whichever pair happens to share that index.
        loss_i = torch.nn.functional.cross_entropy(logits_per_image, labels)
        loss_t = torch.nn.functional.cross_entropy(logits_per_text, labels)

        return (loss_i + loss_t) / 2

    def calculate_itm_loss(self, images, text_input_ids, text_attention_mask, labels):
        """Binary matching loss; ``labels`` are 0 for unmatched and 1 for matched pairs."""
        multimodal_representation = self.image_grounded_encode(images, text_input_ids, text_attention_mask)
        logits = self.itm_head(multimodal_representation) 
        # Plain cross-entropy on purpose: with a pad-token ``ignore_index`` of 0 every
        # "unmatched" sample would be excluded from the loss.
        return torch.nn.functional.cross_entropy(logits, labels)

    def calculate_lm_loss(self, images, text_input_ids, text_attention_mask, labels):
        """Token-level cross-entropy between decoder logits and ``labels``.

        Logits and labels are compared position by position without shifting, so
        the caller must supply inputs already shifted right relative to the
        labels. Label positions equal to the pad token id are ignored.
        """
        logits = self.image_grounded_decode(images, text_input_ids, text_attention_mask)

        # reshape (not view) so non-contiguous label slices are accepted too.
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_labels = labels.reshape(-1)
        return self.cross_entropy_loss(flat_logits, flat_labels)

# --- Ví dụ sử dụng ---
# --- Usage example / smoke test ---
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyper-parameters of each sub-module.
    image_encoder_params = {
        "image_size": 224,
        "patch_size": 16,
        "embed_dim": 768,
        "num_layers": 12,  # number of encoder blocks in the ViT stand-in
        "num_heads": 12
    }

    text_encoder_params = {
        "model_name": "bert-base-uncased"
    }

    # Reuse BERT's configuration so embedding width, head count, etc. stay consistent.
    bert_config = BertConfig.from_pretrained("bert-base-uncased")

    # Build the BLIP model.
    blip_model = BLIP(
        image_encoder_params=image_encoder_params,
        text_encoder_params=text_encoder_params,
        bert_config=bert_config,
        itm_num_layers=4,  # image-grounded encoder layers
        lm_num_layers=4    # image-grounded decoder layers
    ).to(device)

    blip_model.eval()  # evaluation mode for the shape checks below

    print("BLIP Model initialized successfully!")
    print(f"Total parameters: {sum(p.numel() for p in blip_model.parameters() if p.requires_grad):,}")

    # --- Dummy data for the forward-pass checks ---
    batch_size = 2
    dummy_images = torch.randn(batch_size, 3, 224, 224).to(device) 

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    dummy_texts = ["a little girl holding a kitten next to a blue fence", "a dog running in a park"]

    # Register the task tokens. `add_tokens` returns the NUMBER of tokens added
    # (an int), so the ids have to be looked up afterwards.
    enc_token, dec_token = "[ENC]", "[DEC]"
    tokenizer.add_tokens([enc_token, dec_token], special_tokens=True)
    encode_token_id = tokenizer.convert_tokens_to_ids(enc_token)
    decode_token_id = tokenizer.convert_tokens_to_ids(dec_token)

    # The new ids lie beyond the pretrained vocabulary, so grow the word-embedding
    # tables of the two grounded branches (pretrained rows are copied over).
    for grounded in (blip_model.image_grounded_text_encoder, blip_model.image_grounded_text_decoder):
        old_table = grounded.bert_embedding_layer.word_embeddings
        if len(tokenizer) > old_table.num_embeddings:
            new_table = nn.Embedding(
                len(tokenizer), old_table.embedding_dim, padding_idx=old_table.padding_idx
            ).to(old_table.weight.device)
            with torch.no_grad():
                new_table.weight[:old_table.num_embeddings] = old_table.weight
            grounded.bert_embedding_layer.word_embeddings = new_table

    # ITC: original image-text pairs.
    encoded_itc = tokenizer(dummy_texts, padding="max_length", truncation=True, max_length=50, return_tensors="pt")
    text_input_ids_itc = encoded_itc.input_ids.to(device)
    text_attention_mask_itc = encoded_itc.attention_mask.to(device)

    # ITM: [CLS] text [SEP] (padding) [ENC]; the model reads the last position,
    # so [ENC] is appended after padding and marked as attended.
    encoded_itm_raw = tokenizer(dummy_texts, padding=True, truncation=True, max_length=49, return_tensors="pt")
    text_input_ids_itm = torch.cat([encoded_itm_raw.input_ids, torch.full((batch_size, 1), encode_token_id, dtype=torch.long)], dim=1).to(device)
    text_attention_mask_itm = torch.cat([encoded_itm_raw.attention_mask, torch.ones((batch_size, 1), dtype=torch.long)], dim=1).to(device)
    itm_labels = torch.tensor([1, 0], dtype=torch.long).to(device)  # dummy labels: 1 matched, 0 unmatched

    # LM: decoder input is "[DEC] <prefix>", target is the full caption.
    lm_input_texts = [f"{dec_token} a little girl", f"{dec_token} a dog running"]
    lm_target_texts = ["a little girl holding a kitten next to a blue fence", "a dog running in a park"]

    encoded_lm_input = tokenizer(lm_input_texts, padding="max_length", truncation=True, max_length=50, return_tensors="pt")
    text_input_ids_lm = encoded_lm_input.input_ids.to(device)
    text_attention_mask_lm = encoded_lm_input.attention_mask.to(device)

    # Inputs and labels are padded to the same length so the LM loss can compare
    # them position by position; padding targets are ignored by the loss.
    encoded_lm_target = tokenizer(lm_target_texts, padding="max_length", truncation=True, max_length=50, return_tensors="pt")
    lm_labels = encoded_lm_target.input_ids.to(device)

    print("\n--- Testing Unimodal Encoding (for ITC) ---")
    with torch.no_grad():
        img_cls, text_cls = blip_model.unimodal_encode(dummy_images, text_input_ids_itc, text_attention_mask_itc)
    print(f"Image CLS feature shape: {img_cls.shape}") 
    print(f"Text CLS feature shape: {text_cls.shape}") 

    print("\n--- Testing Image-Grounded Encoding (for ITM) ---")
    with torch.no_grad():
        itm_output = blip_model.image_grounded_encode(dummy_images, text_input_ids_itm, text_attention_mask_itm)
    print(f"Image-Grounded Encode (ITM) output shape: {itm_output.shape}")

    print("\n--- Testing Image-Grounded Decoding (for LM) ---")
    with torch.no_grad():
        lm_logits = blip_model.image_grounded_decode(dummy_images, text_input_ids_lm, text_attention_mask_lm)
    print(f"Image-Grounded Decode (LM) logits shape: {lm_logits.shape}")

    print("\n--- Testing Full Pretrain Forward Pass ---")
    # Train mode so the pass mirrors a real training step (dropout active, grads tracked).
    blip_model.train() 
    # Call the module itself (not .forward) so registered hooks are honoured.
    total_loss, losses_dict = blip_model(
        images=dummy_images,
        text_input_ids_itc=text_input_ids_itc,
        text_attention_mask_itc=text_attention_mask_itc,
        text_input_ids_itm=text_input_ids_itm,
        text_attention_mask_itm=text_attention_mask_itm,
        itm_labels=itm_labels,
        text_input_ids_lm=text_input_ids_lm,
        text_attention_mask_lm=text_attention_mask_lm,
        lm_labels=lm_labels,
        task_type="pretrain"
    )
    print(f"Total Pretrain Loss: {total_loss.item():.4f}")
    print(f"Individual Losses: {losses_dict}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, BertConfig
import copy # For deepcopy for momentum encoders

# --- 1. Image Encoder (Simplified Vision Transformer structure) ---
class ImageEncoder(nn.Module):
    """Simplified Vision-Transformer-style image encoder.

    Images with 3 channels are split into non-overlapping patches by a strided
    convolution, a learnable CLS token is prepended, learnable positional
    embeddings are added and the sequence is run through a Transformer encoder.
    """
    def __init__(self, image_size=224, patch_size=16, embed_dim=768, num_layers=12, num_heads=12):
        super().__init__()
        # kernel_size == stride == patch_size gives one embedding per non-overlapping patch.
        self.patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        patches_per_side = image_size // patch_size
        token_count = patches_per_side ** 2 + 1  # patches plus the CLS slot
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))  # learnable CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, token_count, embed_dim))  # learnable positions

        # Built-in encoder layers: 4x FFN expansion, GELU, dropout 0.1, batch-first.
        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=0.1,
            activation=nn.GELU(),
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=num_layers)

    def forward(self, images):
        """Encode a batch of images into a (B, N_patches + 1, E) token sequence."""
        patch_tokens = self.patch_embedding(images).flatten(2).transpose(1, 2)  # (B, N_patches, E)

        batch_cls = self.cls_token.expand(patch_tokens.shape[0], -1, -1)
        tokens = torch.cat((batch_cls, patch_tokens), dim=1)  # (B, N_patches+1, E)
        # Positional table is sliced to the actual sequence length.
        tokens += self.pos_embedding[:, :tokens.shape[1]]

        return self.transformer_encoder(tokens)

# --- 2. Text Encoder (Standard BERT for Unimodal) ---
class TextEncoder(nn.Module):
    """Unimodal text encoder wrapping a pretrained HuggingFace model.

    Args:
        model_name: Checkpoint identifier passed to both
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Attributes set in ``__init__``: ``tokenizer``, ``bert_model`` and
    ``config`` (the loaded model's config object).
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        backbone = AutoModel.from_pretrained(model_name)
        self.bert_model = backbone
        self.config = backbone.config

    def forward(self, text_input_ids, attention_mask):
        """Return the backbone's ``last_hidden_state`` for the given tokens.

        Args:
            text_input_ids: Token ids forwarded as ``input_ids``.
            attention_mask: Mask forwarded unchanged to the backbone.

        Returns:
            The ``last_hidden_state`` attribute of the backbone output,
            i.e. per-token features ``(B, seq_len, E)``.
        """
        backbone_output = self.bert_model(
            input_ids=text_input_ids,
            attention_mask=attention_mask,
        )
        return backbone_output.last_hidden_state

# --- 3. Image-Grounded Transformer Layer (for Multimodal Encoder/Decoder) ---
# This class implements the core logic of a Transformer layer with cross-attention.
# It uses built-in MultiheadAttention and Linear layers.
class ImageGroundedTransformerLayer(nn.Module):
    """Post-norm transformer layer: self-attention -> cross-attention -> FFN.

    Each of the three sub-layers is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Token embedding width.
        nhead: Number of attention heads for both attention modules.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside the attention modules, inside
            the FFN and on every residual branch.
        activation: Callable applied between the two FFN linear layers.
        is_decoder: Stored on the instance as ``is_decoder``; nothing in this
            class reads it.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.GELU(), is_decoder=False):
        super().__init__()
        self.is_decoder = is_decoder

        # Attention sub-layers (batch-first tensors).
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        # Position-wise feed-forward network.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout_ffn = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-layer, in order:
        # 1 = self-attention, 2 = cross-attention, 3 = feed-forward.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation

    def forward(self, src, memory,
                src_mask=None, memory_mask=None,
                src_key_padding_mask=None, memory_key_padding_mask=None):
        """Run one layer over ``src`` while attending to ``memory``.

        Args:
            src: Query-side sequence (also key/value for self-attention).
            memory: Sequence used as key and value in cross-attention.
            src_mask: ``attn_mask`` forwarded to self-attention.
            memory_mask: ``attn_mask`` forwarded to cross-attention.
            src_key_padding_mask: ``key_padding_mask`` for self-attention.
            memory_key_padding_mask: ``key_padding_mask`` for cross-attention.

        Returns:
            The transformed ``src`` sequence after the three sub-layers.
        """
        # 1) Self-attention over the source sequence.
        attended, _ = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        hidden = self.norm1(src + self.dropout1(attended))

        # 2) Cross-attention: queries from the text stream, keys/values from memory.
        grounded, _ = self.cross_attn(
            hidden, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        hidden = self.norm2(hidden + self.dropout2(grounded))

        # 3) Feed-forward network.
        expanded = self.activation(self.linear1(hidden))
        projected = self.linear2(self.dropout_ffn(expanded))
        return self.norm3(hidden + self.dropout3(projected))

# --- 4. Image-Grounded Text Encoder (for ITM) ---
# This part is effectively a BERT encoder with injected cross-attention layers.
class ImageGroundedTextEncoder(nn.Module):
    """Text encoder whose layers cross-attend to image features (used for ITM).

    The token embeddings come from a pretrained ``bert-base-uncased`` model;
    on top of them sits a stack of ``ImageGroundedTransformerLayer`` blocks
    built from scratch (they do not reuse BERT's encoder weights).

    Args:
        bert_config: Config supplying ``hidden_size``, ``num_attention_heads``,
            ``intermediate_size`` and ``hidden_dropout_prob``; it is also passed
            to ``AutoModel.from_pretrained``.
        num_image_grounded_layers: Number of stacked layers.
    """

    def __init__(self, bert_config, num_image_grounded_layers=4):
        super().__init__()
        # Only the embedding module of the pretrained model is kept.
        pretrained_bert = AutoModel.from_pretrained("bert-base-uncased", config=bert_config)
        self.bert_embedding_layer = pretrained_bert.embeddings
        self.embed_dim = bert_config.hidden_size

        grounded_layers = []
        for _ in range(num_image_grounded_layers):
            grounded_layers.append(
                ImageGroundedTransformerLayer(
                    d_model=self.embed_dim,
                    nhead=bert_config.num_attention_heads,
                    dim_feedforward=bert_config.intermediate_size,
                    dropout=bert_config.hidden_dropout_prob,
                    is_decoder=False,  # encoder-style layers
                )
            )
        self.encoder_layers = nn.ModuleList(grounded_layers)

    def forward(self, text_input_ids, attention_mask, image_features):
        """Fuse text tokens with image features.

        Args:
            text_input_ids: Token ids passed to the embedding layer.
            attention_mask: Mask where 0 marks padding positions.
            image_features: Sequence used as cross-attention memory in every layer.

        Returns:
            Hidden states produced by the last layer of the stack.
        """
        # nn.MultiheadAttention expects True at positions to be ignored.
        padding_positions = attention_mask == 0

        hidden_states = self.bert_embedding_layer(input_ids=text_input_ids)
        for grounded_layer in self.encoder_layers:
            hidden_states = grounded_layer(
                src=hidden_states,
                memory=image_features,
                src_key_padding_mask=padding_positions,
            )
        return hidden_states

# --- 5. Image-Grounded Text Decoder (for LM) ---
# This uses PyTorch's built-in TransformerDecoderLayer directly,
# as it naturally supports causal self-attention and cross-attention.
class ImageGroundedTextDecoder(nn.Module):
    """Causal text decoder that cross-attends to image features (used for LM).

    Token embeddings come from a pretrained ``bert-base-uncased`` model. The
    decoder body is PyTorch's ``nn.TransformerDecoder`` (self-attention,
    cross-attention, feed-forward), and the output projection shares its
    weight matrix with the word-embedding table.

    Args:
        bert_config: Config supplying ``hidden_size``, ``num_attention_heads``,
            ``intermediate_size``, ``hidden_dropout_prob`` and ``vocab_size``;
            it is also passed to ``AutoModel.from_pretrained``.
        num_image_grounded_layers: Number of decoder layers.
    """

    def __init__(self, bert_config, num_image_grounded_layers=4):
        super().__init__()
        # Only the embedding module of the pretrained model is kept.
        bert_model_base = AutoModel.from_pretrained("bert-base-uncased", config=bert_config)
        self.bert_embedding_layer = bert_model_base.embeddings
        self.embed_dim = bert_config.hidden_size

        # The stock decoder layer provides self-attention plus cross-attention;
        # causality is NOT built in, it is imposed by the mask passed in forward().
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=self.embed_dim,
            nhead=bert_config.num_attention_heads,
            dim_feedforward=bert_config.intermediate_size,
            dropout=bert_config.hidden_dropout_prob,
            activation=nn.GELU(),
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_image_grounded_layers)

        self.lm_head = nn.Linear(self.embed_dim, bert_config.vocab_size)
        # Weight tying: the output projection reuses the input embedding matrix.
        self.lm_head.weight = self.bert_embedding_layer.word_embeddings.weight

    def forward(self, text_input_ids, attention_mask, image_features):
        """Predict next-token logits for every position.

        Args:
            text_input_ids: Token ids, shape ``(B, seq_len)``.
            attention_mask: Mask where 0 marks padding positions.
            image_features: Sequence used as cross-attention memory.

        Returns:
            Logits of shape ``(B, seq_len, vocab_size)``.
        """
        text_embeddings = self.bert_embedding_layer(input_ids=text_input_ids)

        seq_len = text_input_ids.shape[1]
        # Boolean causal mask (True = may NOT attend), created directly on the
        # input device. Using bool here keeps it the same type as the boolean
        # key-padding mask below; mixing a float attn_mask with a bool
        # key_padding_mask is deprecated in nn.MultiheadAttention.
        tgt_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=text_input_ids.device),
            diagonal=1,
        )

        tgt_key_padding_mask = (attention_mask == 0)  # True for padding positions

        hidden_states = self.transformer_decoder(
            tgt=text_embeddings,
            memory=image_features,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        return self.lm_head(hidden_states)  # (B, seq_len, vocab_size)


# --- BLIP Main Class (combines all components) ---
class BLIP(nn.Module):
    """Simplified BLIP model combining ITC, ITM and LM objectives.

    Components built in ``__init__``:
      * ``image_encoder`` / ``text_encoder_unimodal``: unimodal encoders whose
        CLS features feed the image-text contrastive (ITC) loss.
      * ``image_encoder_m`` / ``text_encoder_unimodal_m``: frozen deep copies
        updated by exponential moving average (momentum encoders).
      * ``image_grounded_text_encoder``: cross-attention encoder for the
        image-text matching (ITM) loss.
      * ``image_grounded_text_decoder``: causal decoder for the language
        modelling (LM) loss and for ``generate``.

    Args:
        image_encoder_params: Keyword arguments for ``ImageEncoder``; must
            contain ``"embed_dim"``.
        text_encoder_params: Keyword arguments for ``TextEncoder``; must
            contain ``"model_name"``.
        bert_config: Config passed to the image-grounded encoder and decoder.
        itm_num_layers: Number of layers in the image-grounded encoder.
        lm_num_layers: Number of layers in the image-grounded decoder.
        add_special_tokens: When true, ``[ENC]`` and ``[DEC]`` are added to the
            tokenizer and all embedding tables are resized accordingly.
    """

    def __init__(self, image_encoder_params, text_encoder_params, bert_config,
                 itm_num_layers=4, lm_num_layers=4,
                 add_special_tokens=True):
        super().__init__()

        # Tokenizer setup for special tokens.
        self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_params["model_name"])
        if add_special_tokens:
            # Register both markers in ONE call: a second call with the
            # 'additional_special_tokens' key may replace (not extend) the
            # list written by the first. Tokens already in the vocabulary are
            # not added again, so this is idempotent.
            vocab = self.tokenizer.get_vocab()
            if '[ENC]' not in vocab or '[DEC]' not in vocab:
                self.tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]', '[DEC]']})
            self.enc_token_id = self.tokenizer.convert_tokens_to_ids('[ENC]')
            self.dec_token_id = self.tokenizer.convert_tokens_to_ids('[DEC]')
        else:
            self.enc_token_id = None
            self.dec_token_id = None

        # 1. Base unimodal modules.
        self.image_encoder = ImageEncoder(**image_encoder_params)
        self.embed_dim = image_encoder_params["embed_dim"]

        self.text_encoder_unimodal = TextEncoder(**text_encoder_params)

        # 2. Momentum encoders (for ITC); frozen further below, after the
        #    optional vocabulary resize.
        self.image_encoder_m = copy.deepcopy(self.image_encoder)
        self.text_encoder_unimodal_m = copy.deepcopy(self.text_encoder_unimodal)

        # 3. Image-grounded text encoder (for ITM).
        self.image_grounded_text_encoder = ImageGroundedTextEncoder(
            bert_config=bert_config,
            num_image_grounded_layers=itm_num_layers
        )

        # 4. Image-grounded text decoder (for LM).
        self.image_grounded_text_decoder = ImageGroundedTextDecoder(
            bert_config=bert_config,
            num_image_grounded_layers=lm_num_layers
        )

        # Resize token embeddings for all text models if the vocab size changed.
        if add_special_tokens and len(self.tokenizer) != self.text_encoder_unimodal.bert_model.config.vocab_size:
            new_vocab_size = len(self.tokenizer)
            self.text_encoder_unimodal.bert_model.resize_token_embeddings(new_vocab_size)
            self.text_encoder_unimodal_m.bert_model.resize_token_embeddings(new_vocab_size)

            shared_weight = self.text_encoder_unimodal.bert_model.embeddings.word_embeddings.weight

            # Each resize initialises its new rows independently; copy the
            # online table so the momentum encoder starts as an exact copy.
            with torch.no_grad():
                self.text_encoder_unimodal_m.bert_model.embeddings.word_embeddings.weight.copy_(shared_weight)

            # Point the image-grounded embedding tables at the resized matrix.
            for grounded in (self.image_grounded_text_encoder, self.image_grounded_text_decoder):
                word_embeddings = grounded.bert_embedding_layer.word_embeddings
                word_embeddings.weight = shared_weight
                word_embeddings.num_embeddings = shared_weight.shape[0]

            # Rebuild the LM head: re-tying only the weight would leave a bias
            # of the OLD vocabulary size and fail with a shape mismatch in forward.
            old_head = self.image_grounded_text_decoder.lm_head
            new_head = nn.Linear(old_head.in_features, new_vocab_size)
            kept = min(old_head.out_features, new_vocab_size)
            with torch.no_grad():
                new_head.bias.zero_()
                new_head.bias[:kept].copy_(old_head.bias[:kept])
            new_head.weight = shared_weight  # tie again
            self.image_grounded_text_decoder.lm_head = new_head

        # Freeze momentum encoders (no direct gradient updates). Done after the
        # resize because resizing can install a fresh embedding parameter.
        for momentum_module in (self.image_encoder_m, self.text_encoder_unimodal_m):
            momentum_module.requires_grad_(False)

        # 5. Linear layers for ITC and ITM heads.
        self.itc_head = nn.Linear(self.embed_dim, self.embed_dim)
        self.itm_head = nn.Linear(self.embed_dim, 2)

        self.temperature = nn.Parameter(torch.ones([]) * 0.07)

        # Token-level loss for LM only: padding targets are ignored. It must
        # NOT be used for ITC/ITM, where class index 0 is a real label that
        # would collide with a pad id of 0.
        self.cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)

    # --- Update momentum encoders (used in training loop) ---
    @torch.no_grad()
    def _momentum_update(self, m=0.995):
        """EMA update: ``momentum = m * momentum + (1 - m) * online``."""
        pairs = (
            (self.image_encoder, self.image_encoder_m),
            (self.text_encoder_unimodal, self.text_encoder_unimodal_m),
        )
        for online, momentum in pairs:
            for param_q, param_k in zip(online.parameters(), momentum.parameters()):
                param_k.mul_(m).add_(param_q, alpha=1. - m)

    # --- Encoding functions ---
    def unimodal_encode(self, images, text_input_ids, text_attention_mask):
        """Return the CLS (index 0) features of the image and text encoders."""
        image_features = self.image_encoder(images)
        image_cls_feature = image_features[:, 0, :]

        text_features = self.text_encoder_unimodal(text_input_ids, text_attention_mask)
        text_cls_feature = text_features[:, 0, :]

        return image_cls_feature, text_cls_feature

    @torch.no_grad()
    def unimodal_encode_m(self, images, text_input_ids, text_attention_mask):
        """Same as ``unimodal_encode`` but through the momentum encoders, without gradients."""
        image_features_m = self.image_encoder_m(images)
        image_cls_feature_m = image_features_m[:, 0, :]

        text_features_m = self.text_encoder_unimodal_m(text_input_ids, text_attention_mask)
        text_cls_feature_m = text_features_m[:, 0, :]
        return image_cls_feature_m, text_cls_feature_m

    def image_grounded_encode(self, images, text_input_ids, text_attention_mask):
        """Return the multimodal representation at the LAST sequence position.

        Callers are expected to place the ``[ENC]`` token at the final
        position of ``text_input_ids`` (after any padding).
        """
        image_features = self.image_encoder(images)

        multimodal_representations = self.image_grounded_text_encoder(
            text_input_ids, text_attention_mask, image_features
        )

        return multimodal_representations[:, -1, :]

    def image_grounded_decode(self, images, text_input_ids, text_attention_mask):
        """Encode ``images`` and return decoder logits ``(B, seq_len, vocab)``."""
        image_features = self.image_encoder(images)
        logits = self.image_grounded_text_decoder(
            text_input_ids, text_attention_mask, image_features
        )
        return logits

    # --- Loss Calculation Functions ---
    def calculate_itc_loss(self, images, text_input_ids, text_attention_mask):
        """Image-text contrastive loss against momentum-encoder targets.

        The i-th image and i-th text of the batch form the positive pair.
        The momentum encoders are EMA-updated only while the module is in
        training mode, so evaluating the loss does not mutate weights.
        """
        img_cls, text_cls = self.unimodal_encode(images, text_input_ids, text_attention_mask)

        img_proj = F.normalize(self.itc_head(img_cls), dim=-1)
        text_proj = F.normalize(self.itc_head(text_cls), dim=-1)

        with torch.no_grad():
            if self.training:
                self._momentum_update()
            img_cls_m, text_cls_m = self.unimodal_encode_m(images, text_input_ids, text_attention_mask)
            img_proj_m = F.normalize(self.itc_head(img_cls_m), dim=-1)
            text_proj_m = F.normalize(self.itc_head(text_cls_m), dim=-1)

        logits_per_image = torch.matmul(img_proj, text_proj_m.T) / self.temperature
        logits_per_text = torch.matmul(text_proj, img_proj_m.T) / self.temperature

        labels = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)

        # Plain cross-entropy: batch index 0 is a valid target and must not be
        # dropped by a padding ignore_index.
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)

        return (loss_i + loss_t) / 2

    def calculate_itm_loss(self, images, text_input_ids, text_attention_mask, labels):
        """Binary image-text matching loss (``labels``: 1 matched, 0 unmatched)."""
        multimodal_representation = self.image_grounded_encode(images, text_input_ids, text_attention_mask)
        logits = self.itm_head(multimodal_representation)

        # Plain cross-entropy so that the negative class (0) contributes.
        return F.cross_entropy(logits, labels)

    def calculate_lm_loss(self, images, text_input_ids, text_attention_mask, labels):
        """Token-level language-modelling loss.

        Targets equal to the pad token id OR to -100 are ignored, so both
        labelling conventions work.
        """
        logits = self.image_grounded_decode(images, text_input_ids, text_attention_mask)

        ignore_index = self.cross_entropy_loss.ignore_index
        # reshape (not view) tolerates non-contiguous label tensors; masked_fill
        # returns a new tensor, so the caller's labels are left untouched.
        flat_labels = labels.reshape(-1)
        flat_labels = flat_labels.masked_fill(flat_labels == -100, ignore_index)

        return self.cross_entropy_loss(logits.reshape(-1, logits.shape[-1]), flat_labels)

    def forward(self, images,
                text_input_ids_itc=None, text_attention_mask_itc=None,
                text_input_ids_itm=None, text_attention_mask_itm=None, itm_labels=None,
                text_input_ids_lm=None, text_attention_mask_lm=None, lm_labels=None,
                task_type="pretrain"):
        """Compute the loss(es) selected by ``task_type``.

        Returns:
            ``(total_loss, losses_dict)`` for ``"pretrain"``; the single loss
            tensor for ``"itc"``, ``"itm"`` or ``"lm"``.

        Raises:
            NotImplementedError: for ``task_type="inference"``.
            ValueError: for any other unknown ``task_type``.
        """
        if task_type == "inference":
            raise NotImplementedError("Inference (e.g., text generation) not implemented in forward. Use .generate() method.")
        if task_type not in ("pretrain", "itc", "itm", "lm"):
            raise ValueError(f"Unknown task_type: {task_type}")

        run_all = task_type == "pretrain"
        total_loss = 0
        losses_dict = {}

        if run_all or task_type == "itc":
            itc_loss = self.calculate_itc_loss(images, text_input_ids_itc, text_attention_mask_itc)
            total_loss += itc_loss
            losses_dict["itc_loss"] = itc_loss

        if run_all or task_type == "itm":
            itm_loss = self.calculate_itm_loss(images, text_input_ids_itm, text_attention_mask_itm, itm_labels)
            total_loss += itm_loss
            losses_dict["itm_loss"] = itm_loss

        if run_all or task_type == "lm":
            lm_loss = self.calculate_lm_loss(images, text_input_ids_lm, text_attention_mask_lm, lm_labels)
            total_loss += lm_loss
            losses_dict["lm_loss"] = lm_loss

        if run_all:
            return total_loss, losses_dict
        return losses_dict[f"{task_type}_loss"]

    @torch.no_grad()
    def generate(self, images, max_length=20, num_beams=1, temperature=1.0):
        """Greedy caption generation.

        Args:
            images: Image batch; its first dimension is the batch size.
            max_length: Maximum number of tokens generated per image.
            num_beams: Accepted for interface compatibility; decoding is
                always greedy and this value is not used.
            temperature: Divides the logits before the argmax (it does not
                change the greedy choice for positive values).

        Returns:
            List of decoded strings, one per image.

        The module's train/eval mode is restored to what it was on entry.
        """
        was_training = self.training
        self.eval()
        try:
            batch_size = images.shape[0]
            device = images.device

            # Fall back to [CLS] as start token when [DEC] was never registered.
            start_id = self.dec_token_id if self.dec_token_id is not None else self.tokenizer.cls_token_id
            sep_id = self.tokenizer.sep_token_id
            pad_id = self.tokenizer.pad_token_id

            input_ids = torch.full((batch_size, 1), start_id, dtype=torch.long, device=device)
            attention_mask = torch.ones(batch_size, 1, dtype=torch.long, device=device)

            # Encode the images once and reuse the features at every step.
            image_features = self.image_encoder(images)

            finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
            generated_ids = []
            for _ in range(max_length):
                logits = self.image_grounded_text_decoder(input_ids, attention_mask, image_features)

                next_token_logits = logits[:, -1, :] / temperature
                next_token = torch.argmax(next_token_logits, dim=-1)
                # Sequences that already ended only emit padding from now on.
                next_token = next_token.masked_fill(finished, pad_id)

                generated_ids.append(next_token)
                # BERT has no EOS token; SEP (or PAD) marks the end of a caption.
                finished = finished | (next_token == sep_id) | (next_token == pad_id)

                input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
                attention_mask = torch.cat(
                    [attention_mask, torch.ones((batch_size, 1), dtype=torch.long, device=device)], dim=-1
                )

                if bool(finished.all()):
                    break

            if not generated_ids:
                return ["" for _ in range(batch_size)]

            generated_ids = torch.stack(generated_ids, dim=1)
            # Decode, skipping special tokens like [CLS], [SEP], [PAD], [ENC], [DEC].
            return [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids]
        finally:
            self.train(was_training)

# --- Usage example ---
if __name__ == "__main__":
    # Smoke test: build the model, create dummy inputs and exercise every
    # encoding path, the combined pre-training loss and caption generation.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    image_encoder_params = {
        "image_size": 224,
        "patch_size": 16,
        "embed_dim": 768, # ViT-Base dimension
        "num_layers": 12, # Number of Encoder blocks in ViT
        "num_heads": 12
    }

    text_encoder_params = {
        "model_name": "bert-base-uncased"
    }

    bert_config = BertConfig.from_pretrained("bert-base-uncased")

    blip_model = BLIP(
        image_encoder_params=image_encoder_params,
        text_encoder_params=text_encoder_params,
        bert_config=bert_config,
        itm_num_layers=4,
        lm_num_layers=4,
        add_special_tokens=True
    ).to(device)

    tokenizer = blip_model.tokenizer
    enc_token_id = tokenizer.convert_tokens_to_ids('[ENC]')
    dec_token_id = tokenizer.convert_tokens_to_ids('[DEC]')
    print(f"Tokenizer vocab size: {len(tokenizer)}")
    print(f"[ENC] token ID: {enc_token_id}")
    print(f"[DEC] token ID: {dec_token_id}")
    print(f"PAD token ID: {tokenizer.pad_token_id}")
    print(f"SEP token ID (often used as EOS for BERT): {tokenizer.sep_token_id}")

    blip_model.eval()
    print("BLIP Model initialized successfully!")
    # Only parameters with requires_grad=True are counted (momentum encoders are frozen).
    print(f"Trainable parameters: {sum(p.numel() for p in blip_model.parameters() if p.requires_grad):,}")

    # --- Create dummy data ---
    batch_size = 2
    dummy_images = torch.randn(batch_size, 3, 224, 224).to(device)
    dummy_texts = ["a little girl holding a kitten next to a blue fence", "a dog running in a park"]

    max_seq_len = 50

    # For ITC (standard text input)
    encoded_itc = tokenizer(dummy_texts, padding="max_length", truncation=True, max_length=max_seq_len, return_tensors="pt")
    text_input_ids_itc = encoded_itc.input_ids.to(device)
    text_attention_mask_itc = encoded_itc.attention_mask.to(device)

    # For ITM ([CLS] text [SEP] [PAD]... [ENC])
    # The [ENC] token is appended as the last position (the model reads the
    # representation at index -1), so the base encoding is one token shorter.
    encoded_itm_base = tokenizer(dummy_texts, padding="max_length", truncation=True, max_length=max_seq_len-1, return_tensors="pt")
    text_input_ids_itm = torch.cat([encoded_itm_base.input_ids, torch.full((batch_size, 1), enc_token_id, dtype=torch.long)], dim=1).to(device)
    text_attention_mask_itm = torch.cat([encoded_itm_base.attention_mask, torch.ones((batch_size, 1), dtype=torch.long)], dim=1).to(device)
    itm_labels = torch.tensor([1, 0], dtype=torch.long).to(device) # 1 matched, 0 unmatched

    # For LM: decoder input is [DEC] T1 T2 ... [SEP] [PAD]...
    # The tokenizer emits [CLS] first; overwrite it with [DEC] so training
    # matches generate(), which starts decoding from [DEC]. (Decoding the
    # [DEC] id with skip_special_tokens=True yields an empty string, so it
    # cannot be prepended as text.)
    encoded_lm_input = tokenizer(dummy_texts, padding="max_length", truncation=True, max_length=max_seq_len, return_tensors="pt")
    text_input_ids_lm = encoded_lm_input.input_ids.clone().to(device)
    text_input_ids_lm[:, 0] = dec_token_id
    text_attention_mask_lm = encoded_lm_input.attention_mask.to(device)

    # Labels are the inputs shifted left by one: position i predicts token i+1.
    #   input : [DEC] T1 T2 T3 [SEP] [PAD] [PAD]
    #   labels: T1    T2 T3 [SEP] [PAD] [PAD] [PAD]
    # Padding targets keep the pad token id, which the model's LM loss ignores.
    pad_column = torch.full((batch_size, 1), tokenizer.pad_token_id, dtype=torch.long, device=device)
    lm_labels = torch.cat([text_input_ids_lm[:, 1:], pad_column], dim=1)

    print("\n--- Testing Unimodal Encoding (for ITC) ---")
    with torch.no_grad():
        img_cls, text_cls = blip_model.unimodal_encode(dummy_images, text_input_ids_itc, text_attention_mask_itc)
    print(f"Image CLS feature shape: {img_cls.shape}")
    print(f"Text CLS feature shape: {text_cls.shape}")

    print("\n--- Testing Image-Grounded Encoding (for ITM) ---")
    with torch.no_grad():
        itm_output = blip_model.image_grounded_encode(dummy_images, text_input_ids_itm, text_attention_mask_itm)
    print(f"Image-Grounded Encode (ITM) output shape: {itm_output.shape}")

    print("\n--- Testing Image-Grounded Decoding (for LM) ---")
    with torch.no_grad():
        lm_logits = blip_model.image_grounded_decode(dummy_images, text_input_ids_lm, text_attention_mask_lm)
    print(f"Image-Grounded Decode (LM) logits shape: {lm_logits.shape}")

    print("\n--- Testing Full Pretrain Forward Pass (Train Mode) ---")
    blip_model.train()
    # Call the module itself (not .forward) so registered hooks run.
    total_loss, losses_dict = blip_model(
        images=dummy_images,
        text_input_ids_itc=text_input_ids_itc,
        text_attention_mask_itc=text_attention_mask_itc,
        text_input_ids_itm=text_input_ids_itm,
        text_attention_mask_itm=text_attention_mask_itm,
        itm_labels=itm_labels,
        text_input_ids_lm=text_input_ids_lm,
        text_attention_mask_lm=text_attention_mask_lm,
        lm_labels=lm_labels,
        task_type="pretrain"
    )
    print(f"Total Pretrain Loss: {total_loss.item():.4f}")
    print(f"Individual Losses: {losses_dict}")

    print("\n--- Testing Inference (Generation) ---")
    generated_captions = blip_model.generate(dummy_images, max_length=20)
    print("Generated Captions:")
    for i, caption in enumerate(generated_captions):
        print(f"Image {i+1}: {caption}")

### Implement

In [None]:
import torch 
import torch.nn as nn 

class ImageEncoder(nn.Module):
    def __init__(self, image_size = 224, patch_size = 16, embed_dim = 768):
        super().__init__()
        num_patch = (image_size // patch_size) ** 2
        