# Grad-ECLIP-based Fine-grained Fine-tuning of CLIP

This notebook implements the fine-grained fine-tuning approach using Grad-ECLIP as described in the paper. The method combines global contrastive loss with local focal loss to enhance CLIP's fine-grained understanding capabilities.

## Methodology Overview:
1. **Global Loss**: Standard CLIP contrastive learning for instance-level alignment
2. **Local Loss**: Fine-grained region-text matching using Grad-ECLIP explanations
3. **Dense Features**: Extract spatial features from modified ViT encoder
4. **Phrase Extraction**: Use NLTK to extract "adjective + noun" concepts
5. **Region-Text Alignment**: Use Grad-ECLIP heat maps for automatic alignment

In [45]:
!pip install nltk
!pip install pycocotools



In [46]:
import nltk
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('maxent_ne_chunker_tab')
nltk.download("words")

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package maxent_ne_chunker_tab to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package maxent_ne_chunker_tab is already up-to-date!
[nltk_data] Downloading package words to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package words is already up-to-date!


True

In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import nltk
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag
import clip
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from transformers import CLIPModel, CLIPProcessor
import re
import time
from torchvision.transforms import Resize
from pathlib import Path

# Download required NLTK data
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')
import os


device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print(os.getcwd())

SCRIPT_DIR = Path(os.getcwd())
   
check_point_path = SCRIPT_DIR/ "checkpoints"
check_point_path.mkdir(parents=True, exist_ok=True)

Using device: cuda
/home/infres/pmbathe-24/Projet-IA-Fairness/Grad_ECLIP


[nltk_data] Downloading package punkt to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## 1. Modified CLIP Architecture for Dense Features

We modify the last transformer layer of the ViT encoder to extract dense spatial features by keeping projection and norm layers while discarding self-attention.

In [48]:
class ModifiedViTEncoder(nn.Module):
    """Modified ViT encoder to extract dense spatial features"""
    
    def __init__(self, original_clip_model):
        super().__init__()
        self.visual = original_clip_model.visual
        self.original_forward = self.visual.forward
        
        # Modify the last transformer layer
        last_layer = self.visual.transformer.resblocks[-1]
        self.modified_last_layer = ModifiedTransformerBlock(last_layer)
        # Replace the last layer in the transformer
        self.visual.transformer.resblocks = nn.ModuleList(
            list(self.visual.transformer.resblocks[:-1]) + [self.modified_last_layer]
        )
        
    def forward(self, x, return_dense=False):
        if return_dense:
            return self.forward_with_dense_features(x)
        else:
            return self.visual(x)
    
    def forward_with_dense_features(self, x):
        # Process through visual encoder up to the last layer
        x = self.visual.conv1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)
        x = torch.cat([self.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        x = x + self.visual.positional_embedding.to(x.dtype)
        x = self.visual.ln_pre(x)
        
        # Process through transformer blocks except the last one
        for block in self.visual.transformer.resblocks[:-1]:
            x = block(x)
        
        # Process through modified last layer to get dense features
        x, dense_features = self.modified_last_layer(x, return_dense=True)
        
        # Get global features
        global_features = x[:, 0, :]  # CLS token
        global_features = self.visual.ln_post(global_features)
        if self.visual.proj is not None:
            global_features = global_features @ self.visual.proj
            
        return global_features, dense_features

class ModifiedTransformerBlock(nn.Module):
    """Modified transformer block that can output dense spatial features"""
    
    def __init__(self, original_block):
        super().__init__()
        self.ln_1 = original_block.ln_1
        self.ln_2 = original_block.ln_2
        self.mlp = original_block.mlp
        self.attn = original_block.attn
        
    def forward(self, x, return_dense=False):
        if return_dense:
            # For dense features, skip attention for spatial tokens, only apply norm
            dense_x = x[:, 1:, :]  # Remove CLS token for dense features
            dense_features = self.ln_1(dense_x)  # Apply norm only
            
            # Regular forward for CLS token and full sequence
            attn_out = self.attn(self.ln_1(x), self.ln_1(x), self.ln_1(x))[0]  # Self-attention with q,k,v
            x = x + attn_out
            x = x + self.mlp(self.ln_2(x))
            
            return x, dense_features
        else:
            # Regular transformer block forward
            attn_out = self.attn(self.ln_1(x), self.ln_1(x), self.ln_1(x))[0]
            x = x + attn_out
            x = x + self.mlp(self.ln_2(x))
            return x

## 2. Phrase Extraction using NLTK

Extract object concepts from captions using "adjective + noun" patterns as specified in the methodology.

In [49]:
class PhraseExtractor:
    """Extract phrases containing object concepts using NLTK"""
    
    def __init__(self):
        pass
    
    def extract_phrases(self, caption, max_phrases=10):
        """
        Extract phrases following 'adjective + noun' pattern from caption
        Args:
            caption: Input text caption
            max_phrases: Maximum number of phrases to extract
        Returns:
            List of extracted phrases
        """
        # Tokenize and tag parts of speech
        tokens = word_tokenize(caption.lower())
        pos_tags = pos_tag(tokens)
        
        phrases = []
        
        # Extract individual nouns
        for word, pos in pos_tags:
            if pos.startswith('NN') and len(word) > 2:  # Noun
                phrases.append(word)
        
        # Extract adjective + noun patterns
        for i in range(len(pos_tags) - 1):
            word1, pos1 = pos_tags[i]
            word2, pos2 = pos_tags[i + 1]
            
            # Adjective + Noun pattern
            if pos1.startswith('JJ') and pos2.startswith('NN'):
                phrase = f"{word1} {word2}"
                phrases.append(phrase)
        
        # Remove duplicates and limit number
        phrases = list(set(phrases))[:max_phrases]
        
        # Ensure we have at least some phrases
        if not phrases:
            # Fallback to any nouns if no patterns found
            phrases = [word for word, pos in pos_tags if pos.startswith('NN')][:max_phrases]
        
        return phrases

# Test phrase extraction
extractor = PhraseExtractor()
test_caption = "a dog in a black car waiting for traffic lights"
extracted_phrases = extractor.extract_phrases(test_caption)
print(f"Caption: {test_caption}")
print(f"Extracted phrases: {extracted_phrases}")

Caption: a dog in a black car waiting for traffic lights
Extracted phrases: ['lights', 'traffic', 'black car', 'dog', 'car']


## 3. Grad-ECLIP Implementation for Heat Map Generation

Implement the Grad-ECLIP method to generate explanation heat maps for region-text alignment.

In [50]:
import os
import cv2
import math
import clip
import json
import numpy as np
from clip import tokenize
import torch
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize
import torch.nn.functional as F
from skimage.transform import resize as np_resize
from transformers import CLIPTokenizerFast
from tqdm import tqdm
import nltk
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag

# Download required NLTK data
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Transform for image processing
_transform = Compose([
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

def imgprocess_keepsize(img, patch_size=[16, 16], scale_factor=1):
    w, h = img.size
    ph, pw = patch_size
    nw = int(w * scale_factor / pw + 0.5) * pw
    nh = int(h * scale_factor / ph + 0.5) * ph

    ResizeOp = Resize((nh, nw), interpolation=InterpolationMode.BICUBIC)
    img = ResizeOp(img).convert("RGB")
    return _transform(img)

def attention_layer(q, k, v, num_heads=1, attn_mask=None):
    """Compute 'Scaled Dot Product Attention'"""
    tgt_len, bsz, embed_dim = q.shape
    head_dim = embed_dim // num_heads
    scaling = float(head_dim) ** -0.5
    q = q * scaling
    
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    if attn_mask is not None:
        attn_output_weights += attn_mask
    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_heads = torch.bmm(attn_output_weights, v)
    assert list(attn_output_heads.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output_heads.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, -1)
    attn_output_weights = attn_output_weights.sum(dim=1) / num_heads
    return attn_output, attn_output_weights

def clip_encode_dense(x, clipmodel):
    """Dense encoding following the exact Grad-ECLIP implementation"""
    vision_width = clipmodel.visual.transformer.width
    vision_heads = vision_width // 64
    clip_inres = clipmodel.visual.input_resolution
    clip_ksize = clipmodel.visual.conv1.kernel_size
    
    # modified from CLIP
    x = x.half()
    x = clipmodel.visual.conv1(x)  
    feah, feaw = x.shape[-2:]

    x = x.reshape(x.shape[0], x.shape[1], -1) 
    x = x.permute(0, 2, 1) 
    class_embedding = clipmodel.visual.class_embedding.to(x.dtype)

    x = torch.cat([class_embedding + torch.zeros(x.shape[0], 1, x.shape[-1]).to(x), x], dim=1)

    pos_embedding = clipmodel.visual.positional_embedding.to(x.dtype)
    tok_pos, img_pos = pos_embedding[:1, :], pos_embedding[1:, :]
    pos_h = clip_inres // clip_ksize[0]
    pos_w = clip_inres // clip_ksize[1]
    assert img_pos.size(0) == (pos_h * pos_w), f"the size of pos_embedding ({img_pos.size(0)}) does not match resolution shape pos_h ({pos_h}) * pos_w ({pos_w})"
    img_pos = img_pos.reshape(1, pos_h, pos_w, img_pos.shape[1]).permute(0, 3, 1, 2)
    img_pos = torch.nn.functional.interpolate(img_pos, size=(feah, feaw), mode='bicubic', align_corners=False)
    img_pos = img_pos.reshape(1, img_pos.shape[1], -1).permute(0, 2, 1)
    pos_embedding = torch.cat((tok_pos[None, ...], img_pos), dim=1)
    x = x + pos_embedding
    x = clipmodel.visual.ln_pre(x)
    
    x = x.permute(1, 0, 2)  # NLD -> LND
    x_in = torch.nn.Sequential(*clipmodel.visual.transformer.resblocks[:-1])(x)

    ##################
    # LastTR.attention
    targetTR = clipmodel.visual.transformer.resblocks[-1]
    x_before_attn = targetTR.ln_1(x_in)
    
    linear = torch._C._nn.linear    
    q, k, v = linear(x_before_attn, targetTR.attn.in_proj_weight, targetTR.attn.in_proj_bias).chunk(3, dim=-1)
    attn_output, attn = attention_layer(q, k, v, 1) #vision_heads
    x_after_attn = linear(attn_output, targetTR.attn.out_proj.weight, targetTR.attn.out_proj.bias)
    
    x = x_after_attn + x_in
    x_out = x + targetTR.mlp(targetTR.ln_2(x))

    x = x_out.permute(1, 0, 2)  # LND -> NLD
    x = clipmodel.visual.ln_post(x)
    x = x @ clipmodel.visual.proj
    
    ## ==== get lastv ==============
    with torch.no_grad():
        qkv = torch.stack((q, k, v), dim=0)
        qkv = linear(qkv, targetTR.attn.out_proj.weight, targetTR.attn.out_proj.bias)
        q_out, k_out, v_out = qkv[0], qkv[1], qkv[2]

        v_final = v_out + x_in
        v_final = v_final + targetTR.mlp(targetTR.ln_2(v_final))
        v_final = v_final.permute(1, 0, 2)
        v_final = clipmodel.visual.ln_post(v_final)
        v_final = v_final @ clipmodel.visual.proj
    ##############
    
    return x, v_final[:,1:], x_in, v, q_out, k_out, attn, att_output, (feah, feaw)

def grad_eclip(c, q_out, k_out, v, att_output, map_size, withksim=True):
    """Generate Grad-ECLIP heat map"""
    D = k_out.shape[-1]
    ## gradient on last attention output
    grad = torch.autograd.grad(
        c,
        att_output,
        retain_graph=True)[0]
    grad = grad.detach()
    grad_cls = grad[:1,0,:]
    if withksim:
        q_cls = q_out[:1,0,:]
        k_patch = k_out[1:,0,:]
        q_cls = F.normalize(q_cls, dim=-1)
        k_patch = F.normalize(k_patch, dim=-1)
        cosine_qk = (q_cls * k_patch).sum(-1) 
        cosine_qk = (cosine_qk-cosine_qk.min()) / (cosine_qk.max()-cosine_qk.min())
        emap_lastv = F.relu_((grad_cls * v[1:,0,:] * cosine_qk[:,None]).detach().sum(-1)) # 
    else:
        emap_lastv = F.relu_((grad_cls * v[1:,0,:]).detach().sum(-1)) 
    return emap_lastv.reshape(*map_size)

def generate_hm(img, txt_embedding, resize, clipmodel, hm_type="eclip_gt"):
    """Generate heat map for image-text pair"""
    if isinstance(img, torch.Tensor):
        # Convert tensor back to PIL for processing
        img_pil = T.ToPILImage()(img.cpu())
        img_keepsized = imgprocess_keepsize(img_pil).to(device).unsqueeze(0)
    else:
        img_keepsized = imgprocess_keepsize(img).to(device).unsqueeze(0)
    
    outputs, v_final, last_input, v, q_out, k_out,\
        attn, att_output, map_size = clip_encode_dense(img_keepsized, clipmodel)
    img_embedding = F.normalize(outputs[:,0], dim=-1)
    cosines = (img_embedding @ txt_embedding.T)[0]

    emap = [grad_eclip(c, q_out, k_out, v, att_output, map_size, withksim=True) for c in cosines]
    emap = torch.stack(emap, dim=0).sum(0)  
    emap -= emap.min()
    emap /= emap.max()
    emap = resize(emap.unsqueeze(0))[0]
    return emap

class GradECLIP:
    """Grad-ECLIP implementation using the exact methodology"""
    
    def __init__(self, clipmodel, device):
        self.clipmodel = clipmodel
        self.device = device
    
    def generate_heatmap_for_phrase(self, image, phrase_text, target_size):
        """
        Generate heat map for a single phrase using Grad-ECLIP
        Args:
            image: PIL Image or tensor
            phrase_text: String phrase
            target_size: (H, W) target size for heat map
        Returns:
            Heat map tensor of shape [H, W]
        """
        # Process text
        text_tokens = clip.tokenize([phrase_text]).to(self.device)
        text_embedding = self.clipmodel.encode_text(text_tokens)
        text_embedding = F.normalize(text_embedding, dim=-1)
        
        # Create resize transform
        resize_transform = Resize(target_size)
        
        # Generate heat map
        heatmap = generate_hm(image, text_embedding, resize_transform, self.clipmodel)
        
        return heatmap
    
    def generate_heatmaps_batch(self, images, phrases_list, spatial_size):
        """
        Generate heat maps for batch of images and phrases
        Args:
            images: Batch of images [B, C, H, W]
            phrases_list: List of phrases for each image
            spatial_size: (H, W) spatial dimensions
        Returns:
            Heat maps [B, max_phrases, H, W]
        """
        batch_size = images.shape[0]
        max_phrases = max(len(phrases) for phrases in phrases_list)
        
        all_heatmaps = []
        
        for b in range(batch_size):
            image = images[b]  # [C, H, W]
            phrases = phrases_list[b]
            
            batch_heatmaps = []
            
            for n in range(max_phrases):
                if n < len(phrases) and phrases[n].strip():
                    # Generate heatmap for this phrase
                    heatmap = self.generate_heatmap_for_phrase(
                        image, phrases[n], spatial_size
                    )
                    batch_heatmaps.append(heatmap)
                else:
                    # Empty heatmap for padded phrases
                    heatmap = torch.zeros(spatial_size, device=self.device)
                    batch_heatmaps.append(heatmap)
            
            all_heatmaps.append(torch.stack(batch_heatmaps))
        
        return torch.stack(all_heatmaps)  # [B, max_phrases, H, W]

Using device: cuda


[nltk_data] Downloading package punkt to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/infres/pmbathe-24/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## 4. Loss Functions Implementation

Implement both global contrastive loss and local focal loss as described in the methodology.

In [51]:
class GlobalContrastiveLoss(nn.Module):
    """Global contrastive loss for CLIP fine-tuning"""
    
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = nn.Parameter(torch.tensor(temperature))
    
    def forward(self, image_features, text_features):
        """
        Compute global contrastive loss
        Args:
            image_features: [B, D] normalized image embeddings
            text_features: [B, D] normalized text embeddings
        """
        batch_size = image_features.shape[0]
        
        # Normalize features
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        
        # Compute similarity matrix
        logits_per_image = image_features @ text_features.t() / self.temperature
        logits_per_text = text_features @ image_features.t() / self.temperature
        
        # Create labels
        labels = torch.arange(batch_size, device=image_features.device)
        
        # Compute cross-entropy loss
        loss_img = F.cross_entropy(logits_per_image, labels)
        loss_txt = F.cross_entropy(logits_per_text, labels)
        
        # Average the two losses
        global_loss = (loss_img + loss_txt) / 2
        
        return global_loss

class FixedLocalFocalLoss(nn.Module):
    """Fixed local focal loss with proper tensor handling"""
    
    def __init__(self, alpha=2.0):
        super().__init__()
        self.alpha = alpha
    
    def forward(self, region_features, phrase_features):
        """
        Compute focal loss for region-phrase matching
        Args:
            region_features: [B, N, D] region embeddings  
            phrase_features: [B, N, D] phrase embeddings
        """
        # Ensure proper shapes - fix any dimension issues silently
        if len(region_features.shape) == 4:
            # Pool over extra dimension [B, N, H, D] -> [B, N, D]
            region_features = region_features.mean(dim=2)
        
        if len(region_features.shape) != 3 or len(phrase_features.shape) != 3:
            return torch.tensor(0.0, device=region_features.device, requires_grad=True)
        
        B, N, D = region_features.shape
        B_p, N_p, D_p = phrase_features.shape
        
        if B != B_p or N != N_p:
            return torch.tensor(0.0, device=region_features.device, requires_grad=True)
        
        # Normalize features
        region_features = F.normalize(region_features, dim=-1)
        phrase_features = F.normalize(phrase_features, dim=-1)
        
        total_loss = 0.0
        valid_pairs = 0
        
        for b in range(B):
            for t in range(N):
                # Check if phrase is valid (not all zeros)
                if torch.norm(phrase_features[b, t]) > 1e-6:
                    # Positive pair loss
                    pos_sim = torch.cosine_similarity(region_features[b, t], phrase_features[b, t], dim=0)
                    pos_sim = torch.clamp(pos_sim, min=1e-8, max=1-1e-8)
                    pos_loss = -(1 - pos_sim) ** self.alpha * torch.log(pos_sim)
                    total_loss += pos_loss
                    valid_pairs += 1
                    
                    # Negative pairs loss  
                    for t_prime in range(N):
                        if t_prime != t and torch.norm(phrase_features[b, t_prime]) > 1e-6:
                            neg_sim = torch.cosine_similarity(region_features[b, t], phrase_features[b, t_prime], dim=0)
                            neg_sim = torch.clamp(neg_sim, min=1e-8, max=1-1e-8)
                            neg_loss = -neg_sim ** self.alpha * torch.log(1 - neg_sim)
                            total_loss += neg_loss
        
        if valid_pairs > 0:
            return total_loss / valid_pairs
        else:
            return torch.tensor(0.0, device=region_features.device, requires_grad=True)

## 5. Complete Fine-tuning Model

Integrate all components into a complete fine-tuning model that combines global and local losses.

In [52]:
class GradECLIPFineTuner(nn.Module):
    """Complete Grad-ECLIP fine-tuning model"""
    
    def __init__(self, clip_model_name="ViT-B/32"):
        super().__init__()
        
        # Load pre-trained CLIP
        self.clip_model, self.preprocess = clip.load(clip_model_name, device=device)
        
        # Create modified visual encoder for dense features
        self.visual_encoder = ModifiedViTEncoder(self.clip_model)
        self.text_encoder = self.clip_model.encode_text
        
        # Initialize components
        self.phrase_extractor = PhraseExtractor()
        
        # Loss functions
        self.global_loss_fn = GlobalContrastiveLoss()
        self.local_loss_fn = FixedLocalFocalLoss()
        
        # Get spatial dimensions based on model
        if "ViT-B/32" in clip_model_name:
            self.spatial_size = (7, 7)
        elif "ViT-B/16" in clip_model_name:
            self.spatial_size = (14, 14)
        else:
            self.spatial_size = (7, 7)
    
    def forward(self, images, texts):
        """Forward pass with fixed tensor shapes"""
        batch_size = images.shape[0]
        
        try:
            # 1. Extract global and dense features
            global_image_features, dense_features = self.visual_encoder(images.float(), return_dense=True)
            
            # 2. Encode full text captions for global loss
            text_tokens = clip.tokenize(texts, truncate=True).to(device)
            global_text_features = self.text_encoder(text_tokens)
            
            # 3. Compute global loss
            global_loss = self.global_loss_fn(global_image_features, global_text_features)
            
            # 4. Extract phrases and encode them
            all_phrases = []
            max_phrases = 0
            
            for text in texts:
                phrases = self.phrase_extractor.extract_phrases(text, max_phrases=3)
                all_phrases.append(phrases)
                max_phrases = max(max_phrases, len(phrases))
            
            if max_phrases == 0:
                return global_loss, torch.tensor(0.0, device=device), global_loss
            
            # Pad phrases to same length
            phrase_features_list = []
            valid_phrases_mask = []
            
            for phrases in all_phrases:
                batch_phrase_features = []
                batch_mask = []
                
                for i in range(max_phrases):
                    if i < len(phrases):
                        phrase_tokens = clip.tokenize([phrases[i]], truncate=True).to(device)
                        phrase_feat = self.text_encoder(phrase_tokens).squeeze(0)
                        batch_phrase_features.append(phrase_feat)
                        batch_mask.append(1.0)
                    else:
                        phrase_feat = torch.zeros_like(global_text_features[0])
                        batch_phrase_features.append(phrase_feat)
                        batch_mask.append(0.0)
                
                phrase_features_list.append(torch.stack(batch_phrase_features))
                valid_phrases_mask.append(batch_mask)
            
            phrase_features = torch.stack(phrase_features_list)  # [B, N, D_text]
            
            # 5. Generate region features directly without intermediate heatmaps
            B, HW, D_dense = dense_features.shape
            _, N, D_text = phrase_features.shape
            
            # Project dense features to text space if needed
            if D_dense != D_text:
                if not hasattr(self, 'dense_to_text_proj'):
                    self.dense_to_text_proj = nn.Linear(D_dense, D_text).to(device)
                dense_features_proj = self.dense_to_text_proj(dense_features)  # [B, HW, D_text]
            else:
                dense_features_proj = dense_features
            
            # Compute attention-weighted region features directly
            dense_norm = F.normalize(dense_features_proj, dim=-1)  # [B, HW, D_text]
            phrase_norm = F.normalize(phrase_features, dim=-1)     # [B, N, D_text]
            
            # Compute attention scores: [B, N, HW]  
            attention_scores = torch.bmm(phrase_norm, dense_norm.transpose(1, 2))
            attention_weights = F.softmax(attention_scores, dim=-1)  # [B, N, HW]
            
            # Extract region features: [B, N, D_text]
            region_features = torch.bmm(attention_weights, dense_features_proj)
            
            # Apply masking for padded phrases
            for b in range(B):
                for n in range(N):
                    if valid_phrases_mask[b][n] == 0.0:
                        region_features[b, n] = torch.zeros_like(region_features[b, n])
                        phrase_features[b, n] = torch.zeros_like(phrase_features[b, n])
            
            # 6. Compute local loss - now with proper [B, N, D] shapes
            local_loss = self.local_loss_fn(region_features, phrase_features)
            
            # 7. Combine losses with weighting
            total_loss = global_loss + 0.5 * local_loss  # Weight local loss
            
            return global_loss, local_loss, total_loss
            
        except Exception as e:
            # Silent fallback
            return (torch.tensor(0.0, device=device, requires_grad=True), 
                    torch.tensor(0.0, device=device, requires_grad=True),
                    torch.tensor(0.0, device=device, requires_grad=True))

## 6. Dataset and Training Setup

Set up the training dataset (Conceptual Captions 3M) and training configuration following the paper specifications.

In [53]:
class ConceptualCaptionsDataset(Dataset):
    """Dataset class for Conceptual Captions 3M"""
    
    def __init__(self, image_paths, captions, transform=None):
        self.image_paths = image_paths
        self.captions = captions
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Load image
        try:
            image = Image.open(self.image_paths[idx]).convert('RGB')
            if self.transform:
                image = self.transform(image)
        except:
            # Return a dummy image if loading fails
            image = torch.zeros(3, 224, 224)
        
        caption = self.captions[idx]
        return image, caption

# Training configuration following the paper
training_config = {
    'batch_size': 64,  # 64 per GPU (paper uses 2 RTX 6000 Ada)
    'learning_rate': 1e-5,
    'weight_decay': 0.1,
    'num_epochs': 10,
    'image_size': 224,
    'warmup_steps': 1000,
    'save_every': 1000,
}

# Data transforms
transform = transforms.Compose([
    transforms.Resize((training_config['image_size'], training_config['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Training configuration:")
for key, value in training_config.items():
    print(f"  {key}: {value}")

Training configuration:
  batch_size: 64
  learning_rate: 1e-05
  weight_decay: 0.1
  num_epochs: 10
  image_size: 224
  warmup_steps: 1000
  save_every: 1000


## 7. Training Loop Implementation

Implement the complete training loop with proper optimization and logging.

In [54]:
def train_grad_eclip_model(model, train_dataloader, config):
    """
    Train the Grad-ECLIP fine-tuning model
    """
    # Optimizer setup
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    
    # Learning rate scheduler with warmup
    def lr_schedule(step):
        if step < config['warmup_steps']:
            return step / config['warmup_steps']
        else:
            return 1.0
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    
    # Training loop
    model.train()
    global_step = 0
    
    for epoch in range(config['num_epochs']):
        epoch_global_loss = 0.0
        epoch_local_loss = 0.0
        epoch_total_loss = 0.0
        num_batches = 0
        
        print(f"\nEpoch {epoch + 1}/{config['num_epochs']}")
        print("-" * 50)
        
        for batch_idx, (images, captions) in enumerate(train_dataloader):
            try:
                images = images.to(device)
                
                # Forward pass
                global_loss, local_loss, total_loss = model(images, captions)
                
                # Backward pass
                optimizer.zero_grad()
                total_loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
                scheduler.step()
                
                # Accumulate losses
                epoch_global_loss += global_loss.item()
                epoch_local_loss += local_loss.item()
                epoch_total_loss += total_loss.item()
                num_batches += 1
                global_step += 1
                
                # Logging
                if batch_idx % 100 == 0:
                    print(f"Batch {batch_idx:4d} | Global: {global_loss.item():.4f} | "
                          f"Local: {local_loss.item():.4f} | Total: {total_loss.item():.4f}")
                
                # Save checkpoint
                if global_step % config['save_every'] == 0:
                    checkpoint = {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'global_step': global_step,
                        'epoch': epoch,
                        'config': config
                    }
                    torch.save(checkpoint, check_point_path/ f'grad_eclip_checkpoint_step_{global_step}.pt')
                    print(f"Checkpoint saved at step {global_step}")
                
            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue
        
        # Epoch summary
        if num_batches > 0:
            avg_global_loss = epoch_global_loss / num_batches
            avg_local_loss = epoch_local_loss / num_batches
            avg_total_loss = epoch_total_loss / num_batches
            
            print(f"\nEpoch {epoch + 1} Summary:")
            print(f"  Average Global Loss: {avg_global_loss:.4f}")
            print(f"  Average Local Loss: {avg_local_loss:.4f}")
            print(f"  Average Total Loss: {avg_total_loss:.4f}")
        
        # Save epoch checkpoint
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'global_step': global_step,
            'epoch': epoch + 1,
            'config': config,
            'epoch_losses': {
                'global': avg_global_loss if num_batches > 0 else 0,
                'local': avg_local_loss if num_batches > 0 else 0,
                'total': avg_total_loss if num_batches > 0 else 0
            }
        }
        torch.save(checkpoint, check_point_path / f"grad_eclip_epoch_{epoch + 1}.pt")
    
    print("\nTraining completed!")
    return model

# For demonstration with dummy data
print("Training setup complete. Ready to train with actual dataset.")
print("Note: Replace with actual Conceptual Captions 3M dataset for real training.")

Training setup complete. Ready to train with actual dataset.
Note: Replace with actual Conceptual Captions 3M dataset for real training.


## 9. MS COCO 2017 Fine-tuning Setup

Set up fine-tuning on MS COCO 2017 dataset following the paper's methodology. This uses the COCO train2017 split for training with image-caption pairs.

In [None]:
class GradECLIPFineTuner(nn.Module):
    """Working version with no inplace operations that break gradients"""
    
    def __init__(self, clip_model_name="ViT-B/16"):
        super().__init__()
        
        self.clip_model, self.preprocess = clip.load(clip_model_name, device=device)
        self.clip_model.float()
        
        self.visual_encoder = ModifiedViTEncoder(self.clip_model)
        self.text_encoder = self.clip_model.encode_text
        self.phrase_extractor = PhraseExtractor()
        
        # Loss functions
        self.global_loss_fn = GlobalContrastiveLoss()
        self.local_loss_fn = FixedLocalFocalLoss()
        
        if "ViT-B/32" in clip_model_name:
            self.spatial_size = (7, 7)
        elif "ViT-B/16" in clip_model_name:
            self.spatial_size = (14, 14)
        else:
            self.spatial_size = (7, 7)
    
    def forward(self, images, texts):
        batch_size = images.shape[0]
        
        try:
            # 1. Extract features
            global_image_features, dense_features = self.visual_encoder(images, return_dense=True)
            text_tokens = clip.tokenize(texts, truncate=True).to(device)
            global_text_features = self.text_encoder(text_tokens)
            
            # 2. Global loss
            global_loss = self.global_loss_fn(global_image_features, global_text_features)
            
            # 3. Extract phrases
            all_phrases = []
            max_phrases = 0
            for text in texts:
                phrases = self.phrase_extractor.extract_phrases(text, max_phrases=3)
                all_phrases.append(phrases)
                max_phrases = max(max_phrases, len(phrases))
        
            if max_phrases == 0:
                return global_loss, torch.tensor(0.0, device=device), global_loss
            
            # 4. Encode phrases - NO INPLACE OPERATIONS
            phrase_features_list = []
            valid_phrases_mask = torch.zeros(batch_size, max_phrases, device=device)
            
            for b, phrases in enumerate(all_phrases):
                batch_phrase_features = []
                
                for i in range(max_phrases):
                    if i < len(phrases):
                        phrase_tokens = clip.tokenize([phrases[i]], truncate=True).to(device)
                        phrase_feat = self.text_encoder(phrase_tokens).squeeze(0).float()
                        batch_phrase_features.append(phrase_feat)
                        valid_phrases_mask[b, i] = 1.0
                    else:
                        phrase_feat = torch.zeros_like(global_text_features[0], device=device)
                        batch_phrase_features.append(phrase_feat)
                
                phrase_features_list.append(torch.stack(batch_phrase_features))
            
            phrase_features = torch.stack(phrase_features_list).float() 
            
            B, HW, D_dense = dense_features.shape
            _, N, D_text = phrase_features.shape
            
            if D_dense != D_text:
                if not hasattr(self, 'proj'):
                    self.proj = nn.Linear(D_dense, D_text).to(device).float()
                dense_proj = self.proj(dense_features)
            else:
                dense_proj = dense_features
            
            region_features_list = []
            
            for n in range(N):
                phrase_n = phrase_features[:, n, :]  
                
                similarities = torch.sum(
                    F.normalize(dense_proj, dim=-1) * F.normalize(phrase_n.unsqueeze(1), dim=-1), 
                    dim=-1
                )
                attention = F.softmax(similarities, dim=-1) 
                
                # Weighted sum
                region_feat = torch.sum(dense_proj * attention.unsqueeze(-1), dim=1)
                region_features_list.append(region_feat)
            
            region_features = torch.stack(region_features_list, dim=1)  
            
            # Apply masking WITHOUT inplace operations
            # Create masked versions instead of modifying originals
            masked_region_features = region_features * valid_phrases_mask.unsqueeze(-1)
            masked_phrase_features = phrase_features * valid_phrases_mask.unsqueeze(-1)
            
            local_loss = self.local_loss_fn(masked_region_features, masked_phrase_features)
            total_loss = global_loss + 0.5 * local_loss
            
            return global_loss, local_loss, total_loss
            
        except Exception as e:
            print(f"Error in forward: {e}")
            return (torch.tensor(0.0, device=device, requires_grad=True), 
                    torch.tensor(0.0, device=device, requires_grad=True),
                    torch.tensor(0.0, device=device, requires_grad=True))

# Replace model
model = GradECLIPFineTuner("ViT-B/16").to(device).float()
print("✓ Model fixed - no more inplace operations!")

✓ Model fixed - no more inplace operations!


In [59]:
import json
import os
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
import requests
from PIL import Image
import io
import torch
import torchvision.transforms as transforms

class MSCOCO2017Dataset(Dataset):
    """MS COCO 2017 dataset for fine-tuning"""
    
    def __init__(self, root_dir, annotation_file, transform=None, max_samples=None):
        """
        Args:
            root_dir: Path to COCO images directory (e.g., './Grad_ECLIP/data/coco/train2017')
            annotation_file: Path to annotations file (e.g., './Grad_ECLIP/data/coco/annotations/captions_train2017.json')
            transform: Image transformations
            max_samples: Maximum number of samples to use (for debugging)
        """
        self.root_dir = root_dir
        self.transform = transform
        
        # Load COCO annotations
        self.coco = COCO(annotation_file)
        
        # Get all image IDs that have captions
        self.img_ids = list(self.coco.imgs.keys())
        
        # Get all annotations (image-caption pairs)
        self.annotations = []
        for img_id in self.img_ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            
            img_info = self.coco.loadImgs(img_id)[0]
            img_path = os.path.join(self.root_dir, img_info['file_name'])
            
            # Add all captions for this image
            for ann in anns:
                self.annotations.append({
                    'image_path': img_path,
                    'caption': ann['caption'],
                    'image_id': img_id
                })
        
        # Limit samples if specified
        if max_samples and max_samples < len(self.annotations):
            self.annotations = self.annotations[:max_samples]
            
        print(f"Loaded {len(self.annotations)} image-caption pairs from MS COCO 2017")
    
    def __len__(self):
        return len(self.annotations)
    
    def __getitem__(self, idx):
        ann = self.annotations[idx]
        
        try:
            # Load image
            if os.path.exists(ann['image_path']):
                image = Image.open(ann['image_path']).convert('RGB')
            else:
                # Create dummy image if file doesn't exist
                image = Image.new('RGB', (224, 224), color='black')
                
            if self.transform:
                image = self.transform(image)
                
        except Exception as e:
            print(f"Error loading image {ann['image_path']}: {e}")
            # Return dummy image
            image = torch.zeros(3, 224, 224)
        
        caption = ann['caption']
        return image, caption

def setup_coco_training(coco_root_dir, batch_size=32, max_samples=None):
    """
    Set up MS COCO 2017 training
    Args:
        coco_root_dir: Root directory containing COCO dataset
        batch_size: Training batch size
        max_samples: Maximum samples for training (None for full dataset)
    """
    
    # COCO dataset paths
    train_images_dir = os.path.join(coco_root_dir, 'train2017')
    train_annotations = os.path.join(coco_root_dir, 'annotations', 'captions_train2017.json')
    
    # Check if paths exist
    if not os.path.exists(train_images_dir):
        print(f"Warning: Training images directory not found: {train_images_dir}")
        print("Please download MS COCO 2017 dataset:")
        print("- Images: http://images.cocodataset.org/zips/train2017.zip")
        print("- Annotations: http://images.cocodataset.org/annotations/annotations_trainval2017.zip")
        return None
    
    if not os.path.exists(train_annotations):
        print(f"Warning: Annotations file not found: {train_annotations}")
        return None
    
    # Data transforms for COCO training
    coco_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create dataset
    train_dataset = MSCOCO2017Dataset(
        root_dir=train_images_dir,
        annotation_file=train_annotations,
        transform=coco_transform,
        max_samples=max_samples
    )
    
    # Create dataloader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
    
    return train_loader

def train_on_coco(model, coco_root_dir, config=None, test_mode=False):
    """
    Train Grad-ECLIP model on MS COCO 2017
    Args:
        model: GradECLIPFineTuner model
        coco_root_dir: Path to COCO dataset root directory
        config: Training configuration
        test_mode: If True, use small sample for testing approach
    """
    
    if config is None:
        # Default configuration based on mode
        if test_mode:
            config = {
                'batch_size': 16,  # Smaller batch for testing
                'learning_rate': 1e-5,
                'weight_decay': 0.05,
                'num_epochs': 5,  # Few epochs for testing
                'warmup_steps': 500,  # Few warmup steps
                'save_every': 100,  # Save more frequently
                'max_samples': 10000,  # Small sample for testing
                'gradient_clip': 1.0
            }
        else:
            config = {
                'batch_size': 32,  # Full training batch size
                'learning_rate': 5e-6,  # Lower LR for fine-tuning
                'weight_decay': 0.05,
                'num_epochs': 5,  # Full training epochs
                'warmup_steps': 500,
                'save_every': 1000,
                'max_samples': None,  # Use full dataset
                'gradient_clip': 1.0
            }
    
    mode_str = "TEST MODE" if test_mode else "FULL TRAINING"
    print(f"Setting up MS COCO 2017 training - {mode_str}")
    print("Configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    
    # Setup data loader
    train_loader = setup_coco_training(
        coco_root_dir, 
        batch_size=config['batch_size'],
        max_samples=config['max_samples']
    )
    
    if train_loader is None:
        print("Failed to setup COCO training data")
        return None
        
    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    
    # Learning rate scheduler
    def lr_schedule(step):
        if step < config['warmup_steps']:
            return step / config['warmup_steps']
        else:
            return max(0.1, 1.0 - (step - config['warmup_steps']) / (len(train_loader) * config['num_epochs']))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    
    # Training loopquick_test_training
    model.train()
    global_step = 0
    best_loss = float('inf')
    
    print(f"\nStarting {mode_str} on {len(train_loader)} batches...")
    print("=" * 60)
    

    
    for epoch in range(config['num_epochs']):
        epoch_losses = {'global': 0.0, 'local': 0.0, 'total': 0.0}
        num_batches = 0
        
        print(f"\nEpoch {epoch + 1}/{config['num_epochs']}")
        print("-" * 50)
        
        for batch_idx, (images, captions) in enumerate(train_loader):
            try:
                images = images.to(device)
                
                # Forward pass
                global_loss, local_loss, total_loss = model(images, captions)
                
                optimizer.zero_grad()
                total_loss.backward()
                
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip'])
                
                optimizer.step()
                scheduler.step()
                
                epoch_losses['global'] += global_loss.item()
                epoch_losses['local'] += local_loss.item()
                epoch_losses['total'] += total_loss.item()
                num_batches += 1
                global_step += 1
                
                log_freq = 10 if test_mode else 100
                if batch_idx % log_freq == 0:
                    lr = scheduler.get_last_lr()[0]
                    print(f"Batch {batch_idx:4d}/{len(train_loader)} | "
                          f"Global: {global_loss.item():.4f} | "
                          f"Local: {local_loss.item():.4f} | "
                          f"Total: {total_loss.item():.4f} | "
                          f"LR: {lr:.2e}")
                
                # Save checkpoint
                if global_step % config['save_every'] == 0:
                    avg_loss = epoch_losses['total'] / max(num_batches, 1)
                    
                    checkpoint = {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'global_step': global_step,
                        'epoch': epoch,
                        'avg_loss': avg_loss,
                        'config': config,
                        'test_mode': test_mode
                    }
                    
                    suffix = "test" if test_mode else "full"
                    

                    checkpoint_path =  check_point_path/ f'grad_eclip_coco_{suffix}_step_{global_step}.pt'
                    torch.save(checkpoint, checkpoint_path)
                    print(f"Checkpoint saved: {checkpoint_path}")
                    
                    # Save best model
                    if avg_loss < best_loss:
                        best_loss = avg_loss
                        best_path = check_point_path/ f'grad_eclip_coco_{suffix}_best.pt'
                        torch.save(checkpoint, best_path)
                        print(f"Best model saved: {best_path}")
                
            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue
        
        # Epoch summary
        if num_batches > 0:
            for key in epoch_losses:
                epoch_losses[key] /= num_batches
            
            print(f"\nEpoch {epoch + 1} Summary:")
            print(f"  Average Global Loss: {epoch_losses['global']:.4f}")
            print(f"  Average Local Loss: {epoch_losses['local']:.4f}")
            print(f"  Average Total Loss: {epoch_losses['total']:.4f}")
        
        # Save epoch checkpoint
        suffix = "test" if test_mode else "full"
        epoch_checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'epoch': epoch + 1,
            'epoch_losses': epoch_losses,
            'config': config,
            'test_mode': test_mode
        }
        epoch_path = check_point_path / f'grad_eclip_coco_{suffix}_epoch_{epoch + 1}.pt'
        torch.save(epoch_checkpoint, epoch_path)
    
    print(f"\n{mode_str} completed successfully!")
    print(f"Best loss achieved: {best_loss:.4f}")
    
    if test_mode:
        print("\n" + "="*60)
        print("TEST MODE COMPLETED - Next steps:")
        print("1. Check the results and loss convergence")
        print("2. If satisfied, run full training:")
        print("   trained_model = train_on_coco(model, coco_root, test_mode=False)")
        print("="*60)
    
    return model

# Example usage functions
def quick_test_training(model, coco_root):
    """Quick test with small sample"""
    print("Starting QUICK TEST with small sample...")
    return train_on_coco(model, coco_root, test_mode=True)

def full_training(model, coco_root):
    """Full training with complete dataset"""
    print("Starting FULL TRAINING with complete dataset...")
    return train_on_coco(model, coco_root, test_mode=False)

# Example usage
print("MS COCO 2017 fine-tuning setup complete!")
print("\nTwo training modes available:")
print("\n1. QUICK TEST (recommended first):")
print("   - Small sample (1000 images)")
print("   - 2 epochs")
print("   - Quick feedback on approach")
print("   Usage: test_model = quick_test_training(model, coco_root)")
print("\n2. FULL TRAINING:")
print("   - Complete dataset")
print("   - 5 epochs")
print("   - Full fine-tuning")
print("   Usage: trained_model = full_training(model, coco_root)")
print("\nRecommended workflow:")
print("1. coco_root = '/path/to/coco/dataset'")
print("2. test_model = quick_test_training(model, coco_root)  # Test first")
print("3. trained_model = full_training(model, coco_root)    # Then full training")
print("\nNote: Download MS COCO 2017 dataset first:")

MS COCO 2017 fine-tuning setup complete!

Two training modes available:

1. QUICK TEST (recommended first):
   - Small sample (1000 images)
   - 2 epochs
   - Quick feedback on approach
   Usage: test_model = quick_test_training(model, coco_root)

2. FULL TRAINING:
   - Complete dataset
   - 5 epochs
   - Full fine-tuning
   Usage: trained_model = full_training(model, coco_root)

Recommended workflow:
1. coco_root = '/path/to/coco/dataset'
2. test_model = quick_test_training(model, coco_root)  # Test first
3. trained_model = full_training(model, coco_root)    # Then full training

Note: Download MS COCO 2017 dataset first:


In [60]:
model = quick_test_training(model,"../Grad_ECLIP/data/coco/")

Starting QUICK TEST with small sample...
Setting up MS COCO 2017 training - TEST MODE
Configuration:
  batch_size: 16
  learning_rate: 1e-05
  weight_decay: 0.05
  num_epochs: 5
  warmup_steps: 500
  save_every: 100
  max_samples: 10000
  gradient_clip: 1.0
loading annotations into memory...
Done (t=0.90s)
creating index...
index created!
Loaded 10000 image-caption pairs from MS COCO 2017

Starting TEST MODE on 625 batches...

Epoch 1/5
--------------------------------------------------
Batch    0/625 | Global: 2.7767 | Local: 0.4654 | Total: 3.0094 | LR: 2.00e-08
Batch   10/625 | Global: 2.7766 | Local: 0.4892 | Total: 3.0212 | LR: 2.20e-07
Batch   20/625 | Global: 2.7780 | Local: 0.4798 | Total: 3.0179 | LR: 4.20e-07
Batch   30/625 | Global: 2.7743 | Local: 0.4736 | Total: 3.0111 | LR: 6.20e-07
Batch   40/625 | Global: 2.7745 | Local: 0.4677 | Total: 3.0083 | LR: 8.20e-07
Batch   50/625 | Global: 2.7743 | Local: 0.4741 | Total: 3.0114 | LR: 1.02e-06
Batch   60/625 | Global: 2.7747 | 

KeyboardInterrupt: 