In [2]:
from torchvision import transforms
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPTokenizer  
from pycocotools.coco import COCO

class CocoImageCaptionDataset(Dataset):
    def __init__(self, image_dir, annotation_file, transform=None, max_length=50):
        self.image_dir = image_dir
        self.coco = COCO(annotation_file)
        self.transform = transform
        # Changed: Use CLIP tokenizer for caption embedding
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.max_length = max_length
        
        # Get image IDs and their captions
        self.image_ids = list(self.coco.imgs.keys())
        self.annotations = {img_id: [] for img_id in self.image_ids}
        
        for ann in self.coco.loadAnns(self.coco.getAnnIds()):
            self.annotations[ann['image_id']].append(ann['caption'])
        if self.tokenizer.pad_token is None or self.tokenizer.pad_token == self.tokenizer.eos_token:
             self.tokenizer.add_special_tokens({'pad_token': '<pad>'})

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        # Get image ID
        image_id = self.image_ids[idx]

        # Load image
        img_info = self.coco.loadImgs(image_id)[0]
        img_path = os.path.join(self.image_dir, img_info['file_name'])
        image = Image.open(img_path).convert('RGB')

        # Apply transforms if available
        if self.transform:
            image = self.transform(image)

        # Choose a random caption for this image
        captions = self.annotations[image_id]
        caption = captions[torch.randint(0, len(captions), (1,)).item()]
        
        # Tokenize caption with CLIP's tokenizer
        encoding = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
            return_attention_mask=False
        )
        
        input_ids = encoding['input_ids'].squeeze(0)

        return image, input_ids

# Image transformations
# Changed: Updated normalization values to those typically used for CLIP
transform = transforms.Compose([
    transforms.Resize((224, 224)), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 
                         std=[0.26862954, 0.26130258, 0.27577711]),
])

# Dataset paths
train_image_dir = '/root/myprojectishere/CocoDataset/train2017'
train_annotation_file = '/root/myprojectishere/CocoDataset/annotations/captions_train2017.json'

val_image_dir = '/root/myprojectishere/CocoDataset/val2017'
val_annotation_file = '/root/myprojectishere/CocoDataset/annotations/captions_val2017.json'

# Initialize datasets
dataset_train = CocoImageCaptionDataset(
    image_dir=train_image_dir,
    annotation_file=train_annotation_file,
    transform=transform,
    max_length=50
)

dataset_val = CocoImageCaptionDataset(
    image_dir=val_image_dir,
    annotation_file=val_annotation_file,
    transform=transform,
    max_length=50
)

# Create DataLoaders
dataloader_train = DataLoader(dataset_train, batch_size=64, shuffle=True, num_workers=12,pin_memory=True)
dataloader_val = DataLoader(dataset_val, batch_size=32, shuffle=False, num_workers=12,pin_memory=True)

# Test iteration



  from .autonotebook import tqdm as notebook_tqdm


loading annotations into memory...
Done (t=0.78s)
creating index...
index created!
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!


In [5]:

from transformers import CLIPTextModel, CLIPTokenizer
import torchvision.models as models
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint, checkpoint_sequential




In [80]:
import math
class RotaryPositionEmbeddings(nn.Module):
    def __init__(self, d_embed, base=10000):
        super().__init__()
        self.d_embed = d_embed
        self.base = base

        # Create the frequencies for rotary embeddings
        inv_freq = 1.0 / (self.base ** (torch.arange(0, d_embed, 2).float() / d_embed))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seq_len,device):
        t = torch.arange(seq_len, dtype=torch.float32,device=device)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        return torch.cat((freqs.sin(), freqs.cos()), dim=-1)

def apply_rotary_embedding(x, rotary_emb):
    # x Batch_Size, H, Seq_Len, D_Head
    seq_len = x.size(-2)
    rotary_emb = rotary_emb[:seq_len, :]

    rotary_emb = rotary_emb.unsqueeze(0).unsqueeze(0)  # (1, 1, Seq_Len, D_Head)

    x_1, x_2 = torch.chunk(x, 2 , dim=-1)
    emb_sin, emb_cos = torch.chunk(rotary_emb, 2, dim=-1)

    x_rot = x_1 * emb_cos - x_2 * emb_sin
    x_pass = x_1 * emb_sin + x_2 * emb_cos

    return torch.cat([x_rot, x_pass], dim=-1)

# Updated SelfAttention with Rotary Position Embeddings
class SelfAttention(nn.Module):
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

        # Rotary Position Embeddings
        self.rotary_emb = RotaryPositionEmbeddings(self.d_head)

    def forward(self, x, attention_mask:None, causal_mask=False):
        batch_size, seq_len, d_embed = x.size()
        dtype = x.dtype
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # Reshape to multi-head shape
        q = q.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # Apply rotary position embeddings
        rotary_emb = self.rotary_emb(seq_len, x.device)
        q = apply_rotary_embedding(q, rotary_emb)
        k = apply_rotary_embedding(k, rotary_emb)

        # Compute attention weights
        attn_weights = q @ k.transpose(-1, -2) / math.sqrt(self.d_head)


        if attention_mask is not None:
        # attention_mask: (batch, seq_len) -> (batch, 1, 1, seq_len)
         extended_mask = attention_mask[:, None, None, :].to(dtype=attn_weights.dtype) 
         attn_weights = attn_weights.masked_fill(extended_mask == 0, float('-inf'))

        if causal_mask:
            mask = torch.ones_like(attn_weights, dtype=torch.bool).triu(1)
            attn_weights.masked_fill_(mask, float('-inf'))

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,dtype=torch.float32).to(dtype)
        
        # Compute attention output
        attn_output = attn_weights @ v
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
        return self.out_proj(attn_output)

# Updated CrossAttention with Rotary Position Embeddings
class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        batch_size, seq_len_q, d_embed = x.size()
        dtype=x.dtype
        if len(y.size()) == 2:
         y=y.unsqueeze(1)
         seq_len_kv=y.size(1)
        else:
         seq_len_kv=y.size(1)

        q = self.q_proj(x).view(batch_size, seq_len_q, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(y).view(batch_size, seq_len_kv, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(y).view(batch_size, seq_len_kv, self.n_heads, self.d_head).transpose(1, 2)

        # Compute attention weights
        attn_weights = q @ k.transpose(-1, -2) / math.sqrt(self.d_head)
        attn_weights = F.softmax(attn_weights,dim=-1,dtype=torch.float32).to(dtype)

        # Compute attention output
        attn_output = attn_weights @ v
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, d_embed)
        return self.out_proj(attn_output)

In [70]:
def test_self_attention():
    import torch
    from torch import nn

    batch_size = 2
    seq_len = 6
    d_embed = 64
    n_heads = 4

    # Dummy input
    x = torch.randn(batch_size, seq_len, d_embed)

    # Simulated padding mask: 1 = valid, 0 = pad
    # Second sequence is shorter and padded at the end
    attention_mask = torch.tensor([
        [1, 1, 1, 1, 1, 1],    # No padding
        [1, 1, 1, 0, 0, 0]     # Last 3 are padding
    ])

    # Instantiate the SelfAttention module
    self_attn = SelfAttention(n_heads=n_heads, d_embed=d_embed)

    # Test without causal mask
    print("=== Without Causal Mask ===")
    out = self_attn(x, attention_mask=attention_mask, causal_mask=False)
    print("Output shape:", out.shape)

    # Test with causal mask
    print("\n=== With Causal Mask ===")
    out_causal = self_attn(x, attention_mask=attention_mask, causal_mask=True)
    print("Output shape:", out_causal)

    # Verify no NaNs
    assert not torch.isnan(out).any(), "Output has NaNs"
    assert not torch.isnan(out_causal).any(), "Causal output has NaNs"
    print("\nTest passed ✅")


test_self_attention()

=== Without Causal Mask ===
Output shape: torch.Size([2, 6, 64])

=== With Causal Mask ===
Output shape: tensor([[[-2.2987e-02,  5.4691e-01,  3.9977e-02,  1.4782e-02, -2.7799e-02,
          -2.5005e-01,  6.4767e-01, -1.7515e-01, -1.4223e-01, -1.8354e-01,
           2.3692e-01, -1.4711e-01,  2.8328e-01,  5.7776e-02,  1.4011e-01,
           6.6510e-03, -3.0640e-01, -1.1216e-01, -2.9874e-01, -1.0162e-01,
           2.0155e-01, -5.3056e-02,  3.4103e-01,  3.6450e-02, -6.9006e-01,
           2.9858e-01, -2.9129e-01, -1.3770e-01,  8.0249e-02,  1.1041e-01,
          -6.4675e-01,  6.0678e-01,  9.2376e-02,  1.8472e-01,  1.5306e-01,
           1.3724e-01,  3.4186e-01,  4.1708e-01,  6.9922e-01, -5.2091e-02,
           8.5425e-02, -3.0541e-02,  1.3735e-02,  3.5371e-01,  4.2012e-01,
           1.8507e-01,  3.5497e-01,  7.1163e-01,  5.6964e-01,  3.7103e-01,
          -2.6992e-01,  8.0485e-01, -3.9273e-01,  8.4107e-02,  1.7595e-01,
          -7.6952e-01,  5.7156e-02,  1.5924e-02,  1.6571e-01, -3.2918e

In [71]:
class ResnetImageEncoder(nn.Module):
    def __init__(self,n_heads,d_cross=2048, d_model=512):
        super(ResnetImageEncoder, self).__init__()
        
        # ResNet backbone
        resnet = models.resnet101(weights='ResNet101_Weights.IMAGENET1K_V2')
        self.model = nn.Sequential(*list(resnet.children())[:-2])
        
        #checkpointing for low memory
        #self.segements=segments
        # Self-attention and normalization
        self.attn_self = SelfAttention(n_heads, d_cross)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_cross) * 0.02,requires_grad=True)  # Intializing with std=0.02 as mentioned in the Vit paper
        self.layernorm = LayerNormalization(d_cross)
        self.dropout=nn.Dropout(0.3)
        
        # Projection for cross-attention
        self.proj_cross = nn.Linear(d_cross, d_model)
        

        
    def forward(self, x):
        x=self.model(x)
        #x=checkpoint_sequential(self.model, segments=10, input=x)
        batch_size, channels, height, width = x.shape
        x = x.permute(0, 2, 3, 1).view(batch_size, height*width, channels)
        
        #necessary since we will get siglip loss for this 
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
       #Pre-ln with dropout recommend in many papers
        attn_output = self.layernorm(x)
        x= x + self.dropout(self.attn_self(attn_output,causal_mask=False))
        
       
        # cls_representation = x[:, 0]  # [b, 2048]
        # contrastive_output = self.proj_contrastive(cls_representation)  # [b, d_model]#<-There is no need to use this since its better to use [b,seq_len,d_model]
        
        # Create sequence representation for cross-attention
        cross_attn_output = self.proj_cross(x)  # [b, seq_len, d_model]
        
        return cross_attn_output

In [72]:
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    def __init__(self, d_model: int, eps=1e-5): #Standard value for eps
        super().__init__() #Corrected super init
        self.para1 = nn.Parameter(torch.ones(d_model)) 
        self.para2 = nn.Parameter(torch.zeros(d_model)) 
        self.eps = eps
    #updated this to support mixed precision training
    def forward(self, x):
        #Imp had forgotten earlier
        if not self.training:
         return x
        x_float=x.to(torch.float32)
        mean = x_float.mean(dim=-1, keepdim=True) 
        std_normal = x_float.std(dim=-1,unbiased=False, keepdim=True) 
        normalized = (x - mean.to(x.dtype)) / torch.sqrt(std_normal + self.eps).to(x.dtype)
        return (self.para1 * normalized) + self.para2

In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math # Added for ceil could also use +1 whatever
from torch.utils.checkpoint import checkpoint

class DeepSeekRouter(nn.Module):
    def __init__(self,d_ff,d_model, num_experts, capacity_factor=1.0,loss_coef=1e-2): 
        super().__init__()
        self.num_experts = num_experts
        self.hidden_size = d_model
        self.capacity_factor = capacity_factor
        self.ffn_dim = d_ff
        self.loss_coef = loss_coef 

        self.router = nn.Linear(d_model, num_experts) 
        # Create expert FFNs
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(num_experts)
        ])

        # Initialize weights 
        self._initialize_weights()

    def _initialize_weights(self):
        # Initialize router
        #all the values have been taken from the switch transformer paper plus recommended by chatgpt so yeah
        nn.init.xavier_uniform_(self.router.weight)
        if self.router.bias is not None:
            nn.init.zeros_(self.router.bias)

        for expert in self.experts:
            for name, param in expert.named_parameters():
                if param.dim() > 1:
                    # Kaiming for Linear layers before GELU, Xavier for the last Linear
                    if '0.weight' in name: # First linear layer
                         nn.init.kaiming_uniform_(param, mode='fan_in', nonlinearity='relu') # GELU ~ ReLU for init will have to research more since we are using gelu
                    elif '2.weight' in name: # Second linear layer
                         nn.init.xavier_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)

    def calculate_capacity(self, num_tokens):
        capacity = math.ceil(num_tokens * self.capacity_factor / self.num_experts)
        capacity = max(1, int(capacity))
        # ensure at min it has capacity that doesn't exceed total tokens
        capacity = min(capacity, num_tokens)
        return capacity

    def process_expert_batch(self, expert_inputs):
        """ Process batched inputs through each expert using activation checkpointing. """

        #(Expert_index,cap,d_model)<-shape
        expert_outputs = torch.zeros_like(expert_inputs)

        for expert_idx, expert in enumerate(self.experts):
            # Get the specific inputs for this expert
            current_expert_inputs = expert_inputs[expert_idx] # Shape: (capacity, hidden_size)

            # Identify which tokens within this expert's batch are actual tokens (not padding)
            valid_token_mask_expert = current_expert_inputs.abs().sum(dim=-1) > 1e-6

            if valid_token_mask_expert.any():
                 # Select only the valid tokens for processing

                valid_inputs = current_expert_inputs[valid_token_mask_expert]
                 #This is using deepspeed checkpointng could also use pytorchs but the training is done in deepspeed
                 # Apply checkpointing: pass the expert module and its valid inputs
                 # checkpoint will handle calling expert(valid_inputs) internally
                 #Lol had to use it
                valid_tokens=checkpoint(expert,valid_inputs,use_reentrant=False)
                 # Place the results back into the correct positions in the output buffer
                expert_outputs[expert_idx, valid_token_mask_expert]=valid_tokens
                 #expert_outputs[expert_idx, valid_token_mask_expert] =   expert(expert_inputs[expert_idx, valid_token_mask_expert])
            # else: No valid tokens for this expert, expert_outputs remains zeros

        return expert_outputs

    def calculate_load_balancing_loss(self, router_probs, final_dispatch_mask_bool):
        """ Calculate aux loss based on router probs and final dispatch counts """
        num_tokens, num_experts = router_probs.shape
        if num_tokens == 0:
            return torch.tensor(0.0, device=router_probs.device)

        tokens_per_expert = final_dispatch_mask_bool.float().sum(dim=0)
        fraction_tokens_per_expert = tokens_per_expert / num_tokens # f_i

        mean_router_prob_per_expert = router_probs.mean(dim=0) # P_i

        loss = self.num_experts * torch.sum(fraction_tokens_per_expert * mean_router_prob_per_expert)
        return loss * self.loss_coef


    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        if num_tokens == 0:
            return x, torch.tensor(0.0, device=x.device)

        # Flatten the input for routing
        flatten_x = x.view(-1, d_model) # Shape: (num_tokens, d_model)

        # 1. Get router probabilities
        router_logits = self.router(flatten_x)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) # Use float32 for mixed precision

        expert_weights_topk, expert_indices_topk = torch.topk(router_probs, k=2, dim=-1)

        # Normalize weights across top-k (important if combining both, good practice anyway)#from chatgpt
        # Using float32 for stability
        normalized_expert_weights_topk = F.softmax(expert_weights_topk, dim=-1, dtype=torch.float32)

        # Extract primary and secondary assignments (use normalized weights)
        primary_expert_indices = expert_indices_topk[:, 0]
        primary_expert_weights = normalized_expert_weights_topk[:, 0] # Shape: (num_tokens,)
        secondary_expert_indices = expert_indices_topk[:, 1]
        secondary_expert_weights = normalized_expert_weights_topk[:, 1] # Shape: (num_tokens,)

        # 3. Calculate Capacity
        capacity = self.calculate_capacity(num_tokens)

        # 4. Initial Assignment & Capacity Check (Primary Only First)
        token_counts = torch.bincount(primary_expert_indices, minlength=self.num_experts)
        over_capacity = token_counts > capacity

        # 5. Handle Overflow: Reroute to Secondary (Simplified Approach)
        # IMPORTANT LIMITATION: Does not check secondary expert capacity after rerouting.
        final_expert_indices = primary_expert_indices.clone()
        final_expert_weights = primary_expert_weights.clone()


        for expert_idx in torch.where(over_capacity)[0]:
            expert_mask = (primary_expert_indices == expert_idx)
            expert_token_indices = torch.where(expert_mask)[0]

            # Find indices of tokens assigned to this primary expert that exceed capacity
            overflow_indices = expert_token_indices[capacity:] 

            # Reroute these specific tokens to their secondary choice
            final_expert_indices[overflow_indices] = secondary_expert_indices[overflow_indices]
            final_expert_weights[overflow_indices] = secondary_expert_weights[overflow_indices]

        # 6. Determine Final Dispatch Positions (respecting capacity post-rerouting for *each expert*)
        # Re-calculate counts based on final assignments
        #final_token_counts = torch.bincount(final_expert_indices, minlength=self.num_experts)

        # Create final dispatch mask and positions
        final_dispatch_mask_bool = torch.zeros((num_tokens, self.num_experts), dtype=torch.bool, device=x.device)
        dispatch_positions = torch.zeros_like(final_expert_indices)

        # Use cumsum approach on final assignments to get positions within capacity limits
        current_expert_counts = torch.zeros(self.num_experts, dtype=torch.long, device=x.device)
        for token_idx in range(num_tokens):
            expert_id = final_expert_indices[token_idx].item()
            current_count = current_expert_counts[expert_id]
            if current_count < capacity:
                 final_dispatch_mask_bool[token_idx, expert_id] = True
                 dispatch_positions[token_idx] = current_count
                 current_expert_counts[expert_id] += 1
            # Else: Token is dropped if its final assigned expert is full

        # Calculate Aux Loss based on original probs and final dispatch mask
        aux_loss = self.calculate_load_balancing_loss(router_probs, final_dispatch_mask_bool)

        # 7. Gather Inputs for Experts
        # Get indices of tokens that are actually dispatched
        #dispatch_token_indices = torch.where(final_dispatch_mask_bool.sum(dim=1) > 0)[0]
        dispatch_token_indices, assigned_expert_indices = torch.nonzero(final_dispatch_mask_bool, as_tuple=True)

        if dispatch_token_indices.numel() == 0:
             # Handle case where no tokens are dispatched
             return torch.zeros_like(x), aux_loss

        # Get assignments for dispatched tokens
        assigned_expert = assigned_expert_indices#<-Tells which we which expert was selected for per token
        assigned_position = dispatch_positions[dispatch_token_indices]#<-Imp remember this gives us the capacity

        # Allocate expert storage
        expert_inputs = torch.zeros(self.num_experts, capacity, d_model, dtype=x.dtype, device=x.device)

        # Scatter inputs using direct indexing (safer)
        expert_inputs[assigned_expert, assigned_position] = flatten_x[dispatch_token_indices]

        # 8. Process tokens with experts
        expert_outputs_buffer = self.process_expert_batch(expert_inputs)

        # 9. Gather results back & Apply Weights
        output_flat = torch.zeros_like(flatten_x)

        # Retrieve outputs for dispatched tokens
        dispatched_outputs = expert_outputs_buffer[assigned_expert, assigned_position] # Shape: (num_dispatched, d_model)

        # Get weights for dispatched tokens
        dispatched_weights = final_expert_weights[dispatch_token_indices] # Shape: (num_dispatched,)

        # Weight the outputs
        weighted_outputs = dispatched_outputs * dispatched_weights.unsqueeze(-1).to(x.dtype)

        # Scatter back using index_add_
        output_flat.index_add_(0, dispatch_token_indices, weighted_outputs)

        # 10. Reshape back to original dimensions
        output = output_flat.view(batch_size, seq_len, d_model)

        return output, aux_loss

# Example Usage (similar to before)
"""if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 4
    seq_len = 10
    d_model = 32
    num_experts = 8
    ffn_dim = d_model * 4 # Example FFN dim

    router_layer = DeepSeekRouter(
        d_model=d_model,
        num_experts=num_experts,
        d_ff=ffn_dim,
        capacity_factor=1.25, # Example capacity factor
        loss_coef=0.01        # Example loss coefficient
    ).to(device)

    input_tensor = torch.randn(batch_size, seq_len, d_model, device=device)

    router_layer.train()
    output_tensor, aux_loss = router_layer(input_tensor)

    print("Input Shape:", input_tensor.shape)
    print("Output Shape:", output_tensor)
    print(f"Auxiliary Loss: {aux_loss.item():.4f}")


    router_layer.eval()
    with torch.no_grad():
         output_inf, _ = router_layer(input_tensor)
    print("Inference output shape:", output_inf.shape)"""
    

'if __name__ == \'__main__\':\n    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")\n\n    batch_size = 4\n    seq_len = 10\n    d_model = 32\n    num_experts = 8\n    ffn_dim = d_model * 4 # Example FFN dim\n\n    router_layer = DeepSeekRouter(\n        d_model=d_model,\n        num_experts=num_experts,\n        d_ff=ffn_dim,\n        capacity_factor=1.25, # Example capacity factor\n        loss_coef=0.01        # Example loss coefficient\n    ).to(device)\n\n    input_tensor = torch.randn(batch_size, seq_len, d_model, device=device)\n\n    router_layer.train()\n    output_tensor, aux_loss = router_layer(input_tensor)\n\n    print("Input Shape:", input_tensor.shape)\n    print("Output Shape:", output_tensor)\n    print(f"Auxiliary Loss: {aux_loss.item():.4f}")\n\n\n    router_layer.eval()\n    with torch.no_grad():\n         output_inf, _ = router_layer(input_tensor)\n    print("Inference output shape:", output_inf.shape)'

In [10]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, d_cross, n_heads, d_ff, num_experts, 
                 capacity_factor=1.0,dropout=0.1, strategy="flamingo"):
        super().__init__()
        self.strategy = strategy.lower()
        
        # Self-Attention
        self.norm1 = LayerNormalization(d_model)
        self.self_attn = SelfAttention(n_heads, d_model)
        
        # Cross-Attention
        self.norm2 = LayerNormalization(d_model)
        self.norm_encoder = LayerNormalization(d_cross)
        self.cross_attn = CrossAttention(n_heads, d_model, d_cross)
        
        # Reverse Cross-Attention
        self.norm_img = LayerNormalization(d_cross)
        self.norm_rev_cross = LayerNormalization(d_model)
        self.rev_cross_attn = CrossAttention(n_heads, d_cross, d_model)
        
        # Gated Fusion (if enabled)
        if self.strategy == "gating":
            self.fusion_gate = nn.Sequential(
                nn.Linear(d_model * 2, d_model),
                nn.Sigmoid()
            )
            self.fusion_norm = LayerNormalization(d_model)
        
        # MoE Feed-Forward
        self.norm3 = LayerNormalization(d_model)
        self.moe = DeepSeekRouter(d_ff=d_ff,d_model=d_model, num_experts=num_experts, capacity_factor=capacity_factor,loss_coef=1e-2)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, attention_mask=None, encoder_output=None,):
        aux_loss_total = 0.0
        
        x_norm = self.norm1(x)
        self_attn_out = self.self_attn(x_norm, attention_mask=attention_mask, causal_mask=True)
        x = x + self.dropout(self_attn_out)
        
        # Only proceed with cross-attention if encoder output is provided
        if encoder_output is not None:
            # Cross-Attention Block (Pre-LN)
            x_norm = self.norm2(x)
            encoder_norm = self.norm_encoder(encoder_output)
            cross_attn_out = self.cross_attn(x_norm, encoder_norm)
            enhanced_x = x + self.dropout(cross_attn_out)
            
            # Reverse Cross-Attention (Pre-LN)
            enhanced_x_norm = self.norm_rev_cross(enhanced_x)
            encoder_norm_img = self.norm_img(encoder_output)
            rev_cross_out = self.rev_cross_attn(encoder_norm_img, enhanced_x_norm)
            enhanced_encoder = encoder_output + self.dropout(rev_cross_out)
            
            # Apply strategy based on the various papers
            
            if self.strategy == "vilbert":
                #print(f"Using : {self.strategy}")
                x = enhanced_encoder
            elif self.strategy == "gating":
               # print(f"Using strategy: {self.strategy}")
                fusion_input = torch.cat((enhanced_x, enhanced_encoder), dim=-1)
                gate = self.fusion_gate(fusion_input)
                x = enhanced_x * (1 - gate) + enhanced_encoder * gate
                x = self.fusion_norm(x)
            elif self.strategy == "flamingo":
                #print(f"Using strategy: {self.strategy}")
                x = enhanced_x
            else:
                #print(f"Since no stratergy is selected default case for : {self.strategy}")
                x = enhanced_encoder + enhanced_x
        
        # MoE Feed-Forward Block (Pre-LN)
        x_norm = self.norm3(x)
        moe_out, moe_aux_loss = self.moe(x_norm)
        x = x + self.dropout(moe_out)
        aux_loss_total += moe_aux_loss
                        
        #print(f"x shape: {x.shape}")
        #print(f"enhanced_x shape: {enhanced_x.shape}")
        #print(f"enhanced_encoder shape: {enhanced_encoder.shape}")
        
        return x, enhanced_x, enhanced_encoder, aux_loss_total

In [81]:
x=torch.rand(3,50,512)
z=torch.rand(3,50,512)
a=torch.randint(0,1,(1,50))
y=DecoderBlock(512,512,6,1024,3)

In [86]:
def test_decoder_block():
    import torch
    batch_size = 2
    seq_len = 5
    img_len = 3
    d_model = 64
    d_cross = 64
    n_heads = 8
    d_ff = 128
    num_experts = 4

    # Create dummy data
    x = torch.randn(batch_size, seq_len, d_model)
    encoder_output = torch.randn(batch_size, img_len, d_cross)
    attention_mask = torch.ones(batch_size, seq_len)  # no padding

    # Instantiate decoder block
    decoder = DecoderBlock(
        d_model=d_model,
        d_cross=d_cross,
        n_heads=n_heads,
        d_ff=d_ff,
        num_experts=num_experts,
        strategy="vilbert"  # Try "vilbert", "flamingo", etc.
    )

    # Forward pass
    x_out, enhanced_x, enhanced_encoder, aux_loss = decoder(
        x,
        attention_mask=attention_mask,
        encoder_output=encoder_output
    )

    # Assertions
    assert x_out.shape == (batch_size, seq_len, d_model), "Output shape mismatch"
    assert enhanced_x.shape == (batch_size, seq_len, d_model), "Enhanced_x shape mismatch"
    assert enhanced_encoder.shape == (batch_size, img_len, d_cross), "Enhanced encoder shape mismatch"
    assert isinstance(aux_loss, torch.Tensor), "Aux loss not returned as tensor"

    print("✅ DecoderBlock test passed!")

test_decoder_block()


TypeError: DecoderBlock.forward() got an unexpected keyword argument 'attention_mask'

In [11]:
class DecoderBlocks(nn.Module):
    def __init__(self, num_layers, d_model, d_cross, n_heads, d_ff, num_experts, 
                 capacity_factor=1.0, dropout=0.1, strategy="flamingo"):
        super().__init__()
        self.strategy = strategy.lower()
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, d_cross, n_heads, d_ff, num_experts, capacity_factor, dropout, strategy)
            for _ in range(num_layers)
        ])
        self.norm = LayerNormalization(d_model)
    
    def forward(self, x, encoder_output=None):
        aux_loss_total = 0.0
        if encoder_output is not None:
            if self.strategy == "gating" or encoder_output.size(1) > x.size(1):
                encoder_output = encoder_output[:, :x.size(1), :]
            for layer in self.layers:
                x, enhanced_x, enhanced_encoder, aux_loss = layer(x, encoder_output)
                aux_loss_total += aux_loss
                encoder_output = enhanced_encoder
                #encoder_output=enhanced_x
        
        x = self.norm(x)
        return x, aux_loss_total

In [12]:

from transformers import CLIPTextModel

class ClipTextEncoder(nn.Module):
    def __init__(self, dropout=0.1, model_name="openai/clip-vit-base-patch32", freeze_clip=False):
        super().__init__()
        self.clip_text = CLIPTextModel.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        if freeze_clip:
            for param in self.clip_text.parameters():
                param.requires_grad = False
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.clip_text(input_ids=input_ids, attention_mask=attention_mask)
        return self.dropout(outputs.last_hidden_state)


In [13]:
import random
class ImageEncoder(nn.Module):
    def __init__(self, num_layers: int, d_model: int, d_cross: int, n_heads: int, d_ff: int, 
                 num_experts: int, vocab_size: int, strategy:str):
        super().__init__()
        self.encoder = ResnetImageEncoder(n_heads)
        self.decoder = DecoderBlocks(num_layers, d_model, d_cross, n_heads, d_ff, num_experts,capacity_factor=1.0, dropout=0.1, strategy=strategy)
        self.fc1 = nn.Linear(d_model, vocab_size)  # This is the output layer
        self.embedding = ClipTextEncoder()
        for param in self.embedding.parameters():
         param.data = param.data.to(torch.bfloat16)
        self.img_embedding = nn.Linear(512, d_model)
        self.text_embedding = nn.Linear(d_model, d_model)
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.log(torch.tensor(1 / 0.07)))
        self.d_model=d_model

    def get_img_embedding(self, image):
        img_emb = self.encoder(image)
        img_emb = img_emb[:, 0]  #<- Extract [CLS] token representation using for the resnet atm
        return self.img_embedding(img_emb)
    
    def get_text_embedding(self, captions, encoder_output,attention_mask=None):
        captions_embeds = self.embedding(captions,attention_mask=attention_mask)
        x, _ = self.decoder(captions_embeds, encoder_output=encoder_output)
        x = x[:, 0]  # Assuming CLS token is the representation
        return self.text_embedding(x)
    
    def siglip_constrastive_loss(self, text_emb, img_emb):
        batch_size = img_emb.shape[0]
        img_emb = F.normalize(img_emb, dim=1)  # Normalize for cosine similarity
        text_emb = F.normalize(text_emb, dim=1)  # Normalize for cosine similarity
        logit_scale = self.logit_scale.exp()
        logit_per_image = logit_scale * img_emb @ text_emb.t()
        logit_per_text = logit_per_image.t()
        labels = torch.eye(batch_size, device=img_emb.device)
        img_loss = F.binary_cross_entropy_with_logits(logit_per_image, labels)#<-since we need similar score thats why
        text_loss = F.binary_cross_entropy_with_logits(logit_per_text, labels)
        return (img_loss + text_loss) / 2
    
    def forward(self, images, captions,attention_mask=None,teacher_forcing_ratio=1.0):
        device = images.device
        batch_size, seq_len = captions.shape
        
        # Get image embeddings
        encoder_output = self.encoder(images)
        image_embed = self.get_img_embedding(images)
        
        # Prepare decoder inputs (remove last token from captions)
        decoder_input = captions[:, :-1]
        captions_embeds = self.embedding(decoder_input,attention_mask=attention_mask)
        
        # Full teacher forcing or evaluation mode
        if not self.training or teacher_forcing_ratio == 1.0:
            # Forward pass through decoder
            decoder_output, aux_loss = self.decoder(captions_embeds, encoder_output=encoder_output)
            
            # Apply output projection
            logits = self.fc1(decoder_output)
            
            # Get text embeddings for contrastive loss
            text_embeds = self.get_text_embedding(captions, encoder_output=encoder_output)
            clip_loss = self.siglip_constrastive_loss(text_embeds, image_embed)
            #predicted_tokens = logits.argmax(dim=-1)  # shape: (batch_size, seq_len-1)
            #targets = captions[:, 1:]  # ground truth tokens from position 1 onwards

            # For a sample in the batch, convert indices to tokens 
            #print("Predicted tokens for sample 0:", predicted_tokens[0])
            #print("Target tokens for sample 0:   ", targets[0])
            
            return logits, aux_loss, clip_loss
        
        # Partial teacher forcing (during training)<-Logic is flawed dont use this 
        decoder_output = torch.zeros(batch_size, seq_len-1, 512, device=device)
        aux_losses = []
        
        # Start with first token
        current_input = captions_embeds[:, 0:1]
        assert captions[:, :-1].shape == captions[:, 1:].shape, "Shapes of decoder inputs and targets do not match!"
        
        for t in range(seq_len - 1):
            # For subsequent tokens, concatenate with previous tokens
            if t == 0:
                step_output, step_aux_loss = self.decoder(
                    current_input,
                    encoder_output=encoder_output)
            else:
                accumulated_input = torch.cat([captions_embeds[:, :t], current_input], dim=1)
                step_output, step_aux_loss = self.decoder(
                    accumulated_input,
                    encoder_output=encoder_output)
            
            aux_losses.append(step_aux_loss)
            decoder_output[:, t:t+1] = step_output[:, -1:]
            
            # Decide whether to use teacher forcing for next token
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            
            # Don't need to compute next input for the last iteration
            #Earlier it had t<seq_len-2
            if use_teacher_forcing:
                    next_input = captions_embeds[:, t+1:t+2]
            else:
                # ...
                top1_pred = step_output.argmax(2)

                if top1_pred.dtype != captions.dtype:
                 top1_pred = top1_pred.to(captions.dtype)

                # This self.embedding is the DECODER's embedding layer
                next_input_embedding = self.embedding(top1_pred)

                # FIX: Cast if necessary
                is_bf16_mode = step_output.dtype == torch.bfloat16

                if is_bf16_mode and next_input_embedding.dtype == torch.float32:
                    print('yes')
                    next_input = next_input_embedding.to(torch.bfloat16)
                else:
                    next_input = next_input_embedding
                
                next_input=next_input.detach()
                # ...
            current_input = next_input
        
        # Apply output projection
        logits = self.fc1(decoder_output)
        #predicted_tokens = logits.argmax(dim=-1)  # shape: (batch_size, seq_len-1)
        #targets = captions[:, 1:]  # ground truth tokens from position 1 onwards

        # For a sample in the batch, convert indices to tokens (assuming you have a tokenizer or vocab mapping)
        #print("Predicted tokens for sample 0:", predicted_tokens[0])
        #print("Target tokens for sample 0:   ", targets[0])
        
        # Get text embeddings from predictions
        pred_tokens = logits.argmax(2)
        final_sequence_embeds = self.get_text_embedding(pred_tokens, encoder_output=encoder_output,attention_mask=None)
        
        # Compute contrastive loss
        clip_loss = self.siglip_constrastive_loss(final_sequence_embeds, image_embed)
        combined_aux_loss = sum(aux_losses) / len(aux_losses) if aux_losses else 0
        
        return logits, combined_aux_loss, clip_loss
    
    def greedy_caption_generation(self,image,attention_mask=None, seq_len=50, start_idx_token=49406, end_idx_token=49407):
        self.eval()
        device = image.device
        with torch.no_grad():
            # Get image features
            encoder_output = self.encoder(image.unsqueeze(0))
            # Start with start token
            current_token = torch.tensor([[start_idx_token]], device=device)
            output_tokens = [start_idx_token]
            
            for _ in range(seq_len):
                # Embed current sequence
                token_emb = self.embedding(current_token,attention_mask=attention_mask)
                
                # Get decoder output
                decoder_out, _ = self.decoder(token_emb, encoder_output=encoder_output)
                
                # Get logits and predicted token
                logits = self.fc1(decoder_out[:, -1])
                pred_token = logits.argmax(1)
                
                # Add to sequence
                pred_token_item = pred_token.item()
                output_tokens.append(pred_token_item)
                current_token = torch.cat((current_token, pred_token.view(1,1)), dim=1)
                
                # Stop if end token
                if pred_token_item == end_idx_token:
                    break
        
        self.train()
        return output_tokens[1:]  # Remove start token
    
    def beam_search_caption_generation(self,image,attention_mask=None, seq_len=50, start_idx_token=49406, end_idx_token=49407, beam_width=3):
        self.eval()
        device = image.device
        with torch.no_grad():
            # Get image features
            encoder_output = self.encoder(image.unsqueeze(0))
            
            # Initialize beam
            start_seq = torch.tensor([[start_idx_token]], device=device)
            sequences = [(start_seq, 0.0, False)]  # (sequence, score, is_complete)            
            # Beam search loop
            for _ in range(seq_len):
                # Break if all sequences are complete
                if all(is_complete for _, _, is_complete in sequences):
                    break
                
                candidates = []
                
                # Expand each sequence
                for seq, score, is_complete in sequences:
                    # Skip completed sequences
                    if is_complete:
                        candidates.append((seq, score, True))
                        continue
                    
                    # Forward pass
                    token_emb = self.embedding(seq,attention_mask=attention_mask)
                    decoder_out, _ = self.decoder(token_emb, encoder_output=encoder_output)
                    logits = self.fc1(decoder_out[:, -1])
                    
                    # Get top-k tokens
                    probs = F.softmax(logits, dim=-1)
                    top_k_probs, top_k_indices = torch.topk(probs, beam_width)
                    
                    # Create new candidates
                    for i in range(beam_width):
                        next_token = top_k_indices[0, i]
                        next_score = score - torch.log(top_k_probs[0, i]).item()  # Convert to log probability#check here if - or +
                        #could also use score+top_k_probs[:,i] your choice
                        next_seq = torch.cat((seq, next_token.view(1,1)), dim=1)
                        is_end = (next_token == end_idx_token)
                        
                        candidates.append((next_seq, next_score, is_end))
                
                # Keep top sequences
                sequences = sorted(candidates, key=lambda x: x[1])[:beam_width]
            
            # Return best completed sequence, or best incomplete if none completed
            completed = [seq for seq, _, is_complete in sequences if is_complete]
            if completed:
                best_seq = completed[0]
            else:
                best_seq = sequences[0][0]
            
        self.train()
        return best_seq.squeeze(0)[1:].tolist()  # Remove batch dim and start token
    
    def top_p_sampling(self,image,attention_mask=None, p=0.9, seq_len=50, temperature=1.0, start_idx_token=49406, end_idx_token=49407):
        self.eval()
        device = image.device
        with torch.no_grad():
            # Get image features
            encoder_output = self.encoder(image.unsqueeze(0))
            
            # Start with start token
            current_token = torch.tensor([[start_idx_token]], device=device)
            output_tokens = [start_idx_token]

            
            for _ in range(seq_len):
                # Embed current sequence
                token_emb = self.embedding(current_token,attention_mask=attention_mask)
                
                # Get decoder output
                decoder_out, _ = self.decoder(token_emb, encoder_output=encoder_output)
                
                # Get logits and apply temperature
                logits = self.fc1(decoder_out[:, -1]) / temperature
                
                # Convert to probabilities
                probs = F.softmax(logits, dim=-1)
                
                # Sort probabilities
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                
                # Compute cumulative probabilities
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                
                # Find tokens within nucleus
                nucleus = cumulative_probs < p
                
                # Ensure at least one token is selected
                if not torch.any(nucleus):
                    nucleus[0] = True
                else:
                    # Add first token that exceeds p
                    nucleus_end = torch.argmax(cumulative_probs >= p)
                    if nucleus_end > 0:  # Avoid adding again if it's the first token
                        nucleus[nucleus_end] = True
                
                # Filter by nucleus
                nucleus_probs = sorted_probs[nucleus]
                nucleus_indices = sorted_indices[nucleus]
                
                # Sample token
                sampled_idx = torch.multinomial(nucleus_probs, 1)
                pred_token = nucleus_indices[sampled_idx].item()
                
                # Add to sequence
                output_tokens.append(pred_token)
                current_token = torch.cat((current_token, 
                                          torch.tensor([[pred_token]], device=device)), dim=1)
                
                # Stop if end token
                if pred_token == end_idx_token:
                    break
        
        self.train()
        return output_tokens[1:]  # Remove start token

In [12]:


from torch.utils.tensorboard import SummaryWriter
import deepspeed
from deepspeed.ops.adam import FusedAdam
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

#Setting up the linear Lr with cosine annealing deepspeed didnt provide me with it

learning_rate = 1e-4
num_gradient_steps_per_epoch = len(dataloader_train) // 4 
num_training_steps = num_gradient_steps_per_epoch * 3 # TOTAL training steps<-check this
#probably have to write this like get the num batch from len of dataloader then int divide from gradstep since we have to see when .step will be called and then for total number 
#epochs
num_warmup_steps = int(0.1*(num_training_steps))

eta_min = 0 # Minimum LR for cosine decay

# Model Intialization

model=ImageEncoder(5,512,512,8,2048,3,49408,'vilbert')

#Setting up the optimizer
optimizer=FusedAdam(model.parameters(),lr=1e-4,betas=(0.9,0.999),eps=1e-8,weight_decay=0.01)

# 1. Warmup Scheduler
# Starts from factor 1e-4 (approx 0) and goes to 1.0 over num_warmup_steps
scheduler_warmup = LinearLR(optimizer, start_factor=1e-4, end_factor=1.0, total_iters=num_warmup_steps)

# 2. Cosine Decay Scheduler
# Starts decaying AFTER warmup is finished
num_cosine_steps = num_training_steps - num_warmup_steps
scheduler_cosine = CosineAnnealingLR(optimizer, T_max=num_cosine_steps, eta_min=eta_min)

# 3. Chain them
# The scheduler_cosine starts at the step specified in milestones
scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_cosine], milestones=[num_warmup_steps])

# Initialize DeepSpeed
model_engine= deepspeed.initialize(
    model=model,
    optimizer=optimizer, 
    lr_scheduler=scheduler,
    config_params='/root/myprojectishere/myproject/ds_config.json'
)


# Extract output_path and job_name from the loaded config
# Create the TensorBoard writer
#writer = SummaryWriter(log_dir='/root/myprojectishere/myproject/tensor')


[2025-04-05 13:35:31,741] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
Using /root/.cache/torch_extensions/py310_cu124 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py310_cu124/fused_adam/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module fused_adam...


ninja: no work to do.
Time to load fused_adam op: 0.03848910331726074 seconds
[2025-04-05 13:35:36,682] [INFO] [logging.py:107:log_dist] [Rank -1] DeepSpeed info: version=0.16.5, git-hash=unknown, git-branch=unknown
[2025-04-05 13:35:36,683] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-04-05 13:35:36,684] [INFO] [comm.py:673:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...
[2025-04-05 13:35:36,917] [INFO] [comm.py:728:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=172.17.0.2, master_port=29500
[2025-04-05 13:35:36,919] [INFO] [comm.py:689:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2025-04-05 13:35:36,924] [INFO] [config.py:734:__init__] Config mesh_device None world_size = 1


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2025-04-05 13:35:37,330] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2025-04-05 13:35:37,333] [INFO] [logging.py:107:log_dist] [Rank 0] Using client Optimizer as basic optimizer
[2025-04-05 13:35:37,333] [INFO] [logging.py:107:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2025-04-05 13:35:37,426] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed Basic Optimizer = FusedAdam
[2025-04-05 13:35:37,427] [INFO] [utils.py:59:is_zero_supported_optimizer] Checking ZeRO support for optimizer=FusedAdam type=<class 'deepspeed.ops.adam.fused_adam.FusedAdam'>
[2025-04-05 13:35:37,427] [INFO] [logging.py:107:log_dist] [Rank 0] Creating torch.bfloat16 ZeRO stage 2 optimizer
[2025-04-05 13:35:37,428] [INFO] [stage_1_and_2.py:149:__init__] Reduce bucket size 500000000
[2025-04-05 13:35:37,428] [INFO] [stage_1_and_2.py:150:__init__] Allgather bucket size 500000000
[2025-04-05 13:35:37,429] [INFO] [stage_1_and_2.py:151:__init

In [13]:
def train_epoch_deepspeed(model_engine,criteria, dataloader,epoch, total_epochs): 
    """
    Trains the model for one epoch using the DeepSpeed engine and logs to TensorBoard.
    ... (rest of docstring) ...
    Args:
        ...
        writer: A torch.utils.tensorboard.SummaryWriter instance (or None if not rank 0).
    """
    num_batches = len(dataloader)
    if epoch==0:
        print(f"The number of batches that have been selected---->:{num_batches}")

        
    model_engine.train()

    total_loss = 0
    total_ce_loss = 0
    total_clip_loss = 0
    total_aux_loss = 0

    # Adaptive teacher forcing schedule
    min_teacher_forcing = 0.3
    teacher_forcing_ratio = max(
        min_teacher_forcing,
        1.0 - (epoch / (total_epochs * 0.75))
    )
    print(f"Epoch {epoch+1}/{total_epochs} using teacher forcing ratio: {teacher_forcing_ratio:.3f}")

    for batch_idx, (images, captions) in enumerate(dataloader):
        images = images.to(model_engine.local_rank).to(dtype=torch.bfloat16)
        captions = captions.to(model_engine.local_rank).to(dtype=torch.long)

        # Forward pass
        outputs, aux_loss, clip_loss = model_engine(images, captions, None, teacher_forcing_ratio)

        # Calculate losses
        targets = captions[:, 1:]
        ce_loss = criteria(outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1))

        alpha = min(0.2, 0.05 + (epoch / total_epochs) * 0.15)
        beta = 0.1
        combined_loss = ce_loss + alpha * clip_loss
        if aux_loss is not None:
            combined_loss += beta * aux_loss

        # Backward pass
        model_engine.backward(combined_loss)

        # Optimizer step 
        model_engine.step()
        
        # # --- OR ---# If you didn't pass lr_scheduler to initialize:
        # model_engine.step() # Only steps the optimizer
        # lr_scheduler_engine.step() # Manually step the scheduler AFTER the optimizer step


        # Accumulate for epoch average logging (optional, could just rely on step logs)
        current_lr = model_engine.get_lr()[0]
        total_loss += combined_loss.item()
        total_ce_loss += ce_loss.item()
        total_clip_loss += clip_loss.item() if torch.is_tensor(clip_loss) else clip_loss
        if aux_loss is not None:
         total_aux_loss += aux_loss.item() if torch.is_tensor(aux_loss) else aux_loss
    



        #Print progress to console 
        if  (batch_idx + 1) % 32 == 0:
             print(f"Epoch: {epoch+1}/{total_epochs}, Batch: {batch_idx+1}/{len(dataloader)}, "
                   f"Global Step: {model_engine.global_steps}, "
                   f"Step Loss: {combined_loss.item():.4f}, "
                   f"LR: {current_lr:.6e}",
                   f"Aux_loss:{aux_loss.item():.4f}",
                   f"Clip_loss:{clip_loss.item():.4f}",
                   f"Ce_loss:{ce_loss.item():.4f}"
                   )
        """if(batch_idx+1)==200:
            torch.cuda.empty_cache()
            return model_engine"""

    # --- End of epoch ---#
    num_batches = len(dataloader)
    #if writer is not None: # Log average epoch losses if rank 0
    avg_loss = total_loss / num_batches
    avg_ce_loss = total_ce_loss / num_batches
    avg_clip_loss = total_clip_loss / num_batches
    avg_aux_loss = total_aux_loss / num_batches if total_aux_loss > 0 else 0

        #writer.add_scalar('Loss/train_epoch_combined', avg_loss, epoch + 1) 
        #writer.add_scalar('Loss/train_epoch_ce', avg_ce_loss, epoch + 1)
        #writer.add_scalar('Loss/train_epoch_clip', avg_clip_loss, epoch + 1)
        #writer.add_scalar('Loss/train_epoch_aux', avg_aux_loss, epoch + 1)

    print(f"\n--- Epoch {epoch+1} Summary ---")
    print(f"Avg Loss: {avg_loss:.4f}")
    print(f"avg_clip_loss:{avg_clip_loss}")
    print(f"avg_ce_loss:{avg_ce_loss:.4f}")
        # ... (rest of the print summary) ...
    print("")
    return model_engine 


In [14]:
def model_for_validation(base_model, dataloader, criteria):
    base_model.eval()
    teacher_forcing = 1.0  # Full teacher forcing for evaluation

    total_loss = 0
    total_batches = len(dataloader)

    with torch.no_grad():  # <-- Important context for disabling gradients
        for batch_idx, (images, captions) in enumerate(dataloader):
            images = images.to('cuda').to(dtype=torch.bfloat16)
            captions = captions.to('cuda').to(dtype=torch.long)

            outputs, aux_loss, siglip_loss = base_model(images, captions, None, teacher_forcing)
            
            targets = captions[:, 1:]
            ce_loss = criteria(outputs.reshape(-1, outputs.shape[-1]), targets.reshape(-1))
            
            combined_loss = ce_loss + siglip_loss
            if aux_loss is not None:
                combined_loss += aux_loss

            total_loss += combined_loss.item()

            if (batch_idx + 1) % 8 == 0:
                print(f"Batch {batch_idx+1}/{total_batches}, "
                      f"Combined Loss: {combined_loss.item():.4f}, "
                      f"CE: {ce_loss.item():.4f}, "
                      f"AUX: {aux_loss.item() if aux_loss is not None else 0:.4f}, "
                      f"CLIP: {siglip_loss.item():.4f}")

    avg_loss = total_loss / total_batches
    print(f"\nValidation Avg Loss: {avg_loss:.4f}")
    return avg_loss


In [15]:
total_epochs = 3
criteria = nn.CrossEntropyLoss(ignore_index=49407)
model_engine = model_engine[0]  # If needed, unpack just once

for epoch in range(total_epochs):
    model_engine = train_epoch_deepspeed(
        model_engine=model_engine,
        dataloader=dataloader_train,
        epoch=epoch,
        criteria=criteria,
        total_epochs=total_epochs
    )

    val_loss = model_for_validation(
        model_engine,
        dataloader_val,
        criteria
    )
    print(f"[Epoch {epoch+1}/{total_epochs}] Validation Loss: {val_loss:.4f}")


The number of batches that have been selected---->:1849
Epoch 1/3 using teacher forcing ratio: 1.000


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Epoch: 1/3, Batch: 32/1849, Global Step: 8, Step Loss: 10.9158, LR: 5.806522e-06 Aux_loss:0.0474 Clip_loss:0.7216 Ce_loss:10.8750
Epoch: 1/3, Batch: 64/1849, Global Step: 16, Step Loss: 10.0865, LR: 1.160304e-05 Aux_loss:0.0475 Clip_loss:0.3845 Ce_loss:10.0625
Epoch: 1/3, Batch: 96/1849, Global Step: 24, Step Loss: 9.2639, LR: 1.739957e-05 Aux_loss:0.0486 Clip_loss:0.1818 Ce_loss:9.2500
Epoch: 1/3, Batch: 128/1849, Global Step: 32, Step Loss: 8.8223, LR: 2.319609e-05 Aux_loss:0.0478 Clip_loss:0.1002 Ce_loss:8.8125
Epoch: 1/3, Batch: 160/1849, Global Step: 40, Step Loss: 8.2588, LR: 2.899261e-05 Aux_loss:0.0473 Clip_loss:0.0812 Ce_loss:8.2500
Epoch: 1/3, Batch: 192/1849, Global Step: 48, Step Loss: 7.6338, LR: 3.478913e-05 Aux_loss:0.0477 Clip_loss:0.0803 Ce_loss:7.6250
[2025-04-05 13:55:19,535] [INFO] [logging.py:107:log_dist] [Rank 0] step=50, skipped=0, lr=[3.6238260869565215e-05], mom=[(0.9, 0.999)]
[2025-04-05 13:55:19,546] [INFO] [timer.py:264:stop] epoch=0/micro_step=200/global_s



Epoch: 1/3, Batch: 576/1849, Global Step: 144, Step Loss: 4.1961, LR: 9.999430e-05 Aux_loss:0.0493 Clip_loss:0.0732 Ce_loss:4.1875
[2025-04-05 14:34:47,485] [INFO] [logging.py:107:log_dist] [Rank 0] step=150, skipped=0, lr=[9.99771892244767e-05], mom=[(0.9, 0.999)]
[2025-04-05 14:34:47,497] [INFO] [timer.py:264:stop] epoch=0/micro_step=600/global_step=150, RunningAvgSamplesPerSec=10.838072268088874, CurrSamplesPerSec=10.698158225140014, MemAllocated=2.89GB, MaxMemAllocated=17.16GB
Epoch: 1/3, Batch: 608/1849, Global Step: 152, Step Loss: 3.9618, LR: 9.996895e-05 Aux_loss:0.0490 Clip_loss:0.0763 Ce_loss:3.9531
Epoch: 1/3, Batch: 640/1849, Global Step: 160, Step Loss: 3.8991, LR: 9.992334e-05 Aux_loss:0.0483 Clip_loss:0.0727 Ce_loss:3.8906
Epoch: 1/3, Batch: 672/1849, Global Step: 168, Step Loss: 4.0398, LR: 9.985749e-05 Aux_loss:0.0485 Clip_loss:0.0729 Ce_loss:4.0312
Epoch: 1/3, Batch: 704/1849, Global Step: 176, Step Loss: 3.7274, LR: 9.977142e-05 Aux_loss:0.0486 Clip_loss:0.0768 Ce_lo

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Batch 8/157, Combined Loss: 0.5434, CE: 0.3574, AUX: 0.0479, CLIP: 0.1381
Batch 16/157, Combined Loss: 0.8074, CE: 0.6211, AUX: 0.0482, CLIP: 0.1381
Batch 24/157, Combined Loss: 0.5164, CE: 0.3301, AUX: 0.0477, CLIP: 0.1386
Batch 32/157, Combined Loss: 0.4254, CE: 0.2412, AUX: 0.0478, CLIP: 0.1364
Batch 40/157, Combined Loss: 0.5312, CE: 0.3457, AUX: 0.0483, CLIP: 0.1372
Batch 48/157, Combined Loss: 0.5663, CE: 0.3809, AUX: 0.0480, CLIP: 0.1374
Batch 56/157, Combined Loss: 0.9729, CE: 0.7852, AUX: 0.0487, CLIP: 0.1390
Batch 64/157, Combined Loss: 0.6917, CE: 0.5078, AUX: 0.0474, CLIP: 0.1365
Batch 72/157, Combined Loss: 0.7278, CE: 0.5430, AUX: 0.0477, CLIP: 0.1371
Batch 80/157, Combined Loss: 0.4223, CE: 0.2344, AUX: 0.0474, CLIP: 0.1405
Batch 88/157, Combined Loss: 0.6597, CE: 0.4746, AUX: 0.0484, CLIP: 0.1366
Batch 96/157, Combined Loss: 0.8014, CE: 0.6133, AUX: 0.0477, CLIP: 0.1404
Batch 104/157, Combined Loss: 0.5947, CE: 0.4102, AUX: 0.0475, CLIP: 0.1371
Batch 112/157, Combined L

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

RuntimeError: NVML_SUCCESS == r INTERNAL ASSERT FAILED at "/pytorch/c10/cuda/CUDACachingAllocator.cpp":1015, please report a bug to PyTorch. 

In [18]:
torch.load('/root/myprojectishere/myproject/model_weights_maybe_this_works.pth')

RuntimeError: NVML_SUCCESS == r INTERNAL ASSERT FAILED at "/pytorch/c10/cuda/CUDACachingAllocator.cpp":1015, please report a bug to PyTorch. 

In [33]:
for img,caption in dataloader_val:
    break

In [35]:
img.to('cpu')

tensor([[[-0.1572, -0.3032, -0.4930,  ..., -1.2813, -1.5879, -1.5879],
         [-0.1864, -0.4346, -0.5806,  ..., -1.4273, -1.5879, -1.3689],
         [-0.2156, -0.4930, -0.6536,  ..., -1.4711, -1.5733, -1.4857],
         ...,
         [ 0.2077,  0.2661,  0.2953,  ..., -1.0331, -1.0039, -1.0331],
         [ 0.2369,  0.2515,  0.2807,  ..., -1.0623, -1.0331, -1.0769],
         [ 0.2807,  0.2661,  0.2807,  ..., -1.0623, -1.0477, -1.0915]],

        [[ 0.1089, -0.0412, -0.3264,  ..., -1.3769, -1.5570, -1.5420],
         [ 0.0939, -0.1913, -0.4314,  ..., -1.4369, -1.5570, -1.3319],
         [ 0.0488, -0.3564, -0.5065,  ..., -1.4519, -1.5420, -1.3919],
         ...,
         [-0.1613, -0.1463, -0.1163,  ..., -1.1968, -1.1818, -1.1968],
         [-0.1613, -0.1463, -0.1313,  ..., -1.1968, -1.2268, -1.2568],
         [-0.1313, -0.1163, -0.1313,  ..., -1.2118, -1.2268, -1.2718]],

        [[-0.5559, -0.5559, -0.4706,  ..., -1.1532, -1.2954, -1.2669],
         [-0.5844, -0.6697, -0.3568,  ..., -1

In [34]:
img=img[0]
img


tensor([[[-0.1572, -0.3032, -0.4930,  ..., -1.2813, -1.5879, -1.5879],
         [-0.1864, -0.4346, -0.5806,  ..., -1.4273, -1.5879, -1.3689],
         [-0.2156, -0.4930, -0.6536,  ..., -1.4711, -1.5733, -1.4857],
         ...,
         [ 0.2077,  0.2661,  0.2953,  ..., -1.0331, -1.0039, -1.0331],
         [ 0.2369,  0.2515,  0.2807,  ..., -1.0623, -1.0331, -1.0769],
         [ 0.2807,  0.2661,  0.2807,  ..., -1.0623, -1.0477, -1.0915]],

        [[ 0.1089, -0.0412, -0.3264,  ..., -1.3769, -1.5570, -1.5420],
         [ 0.0939, -0.1913, -0.4314,  ..., -1.4369, -1.5570, -1.3319],
         [ 0.0488, -0.3564, -0.5065,  ..., -1.4519, -1.5420, -1.3919],
         ...,
         [-0.1613, -0.1463, -0.1163,  ..., -1.1968, -1.1818, -1.1968],
         [-0.1613, -0.1463, -0.1313,  ..., -1.1968, -1.2268, -1.2568],
         [-0.1313, -0.1163, -0.1313,  ..., -1.2118, -1.2268, -1.2718]],

        [[-0.5559, -0.5559, -0.4706,  ..., -1.1532, -1.2954, -1.2669],
         [-0.5844, -0.6697, -0.3568,  ..., -1

In [23]:
model.load_state_dict(torch.load('/root/myprojectishere/myproject/model_weights_maybe_this_works.pth'))

<All keys matched successfully>

In [46]:
z=model.beam_search_caption_generation(img)

In [55]:
z

[836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 269,
 269,
 269,
 269,
 269,
 269,
 269,
 269,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836,
 836]

In [44]:
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")


In [64]:
token_ids = [] # example greedy output

caption = tokenizer.decode(token_ids, skip_special_tokens=True)
print("Caption:", caption)


Caption: nicole
