In [25]:
import sys
from pathlib import Path
notebook_path = Path.cwd()
sys.path.insert(0, str(notebook_path))

from ml_engine.models.teacher.grounding_dino_lora import load_grounding_dino_with_lora
import sys
from pathlib import Path
from typing import Dict, Optional, List
import logging
import torch
from torch import nn
from ml_engine.training.peft_utils import (
    verify_freezing, save_lora_adapters, apply_lora, load_lora_model
)
from groundingdino.util.slconfig import SLConfig
from groundingdino.models import build_model
from groundingdino.util.utils import clean_state_dict

from core.constants import DEFAULT_DINO_LORA_CONFIG

logger = logging.getLogger(__name__)

In [None]:
base_checkpoint = "data/models/pretrained/groundingdino_swint_ogc.pth"

lora_config = {
    "enabled": True,
    "r": 16,
    "lora_alpha": 32,
    "target_modules": [
        # vision-focused fine-tuning
        "value_proj", "output_proj", "v_proj", "out_v_proj"
        # For text-focused fine-tuning
        "out_proj", "l_proj", "out_l_proj"
        # Focus cross-model fusion
        "v_proj", "l_proj", "values_v_proj", "values_l_proj"
    ],
    "lora_dropout": 0.1,
    "bias": "none",
    "task_type": "FEATURE_EXTRACTION"
}

bert_model_path = "data/models/pretrained/bert-base-uncased"


In [34]:
def load_base_model(checkpoint_path: str, bert_model_path: str):
    """
    Load pretrained Grounding DINO model.
    
    This method loads the official GroundingDINO model WITHOUT modification.
    Grounding DINO is an open-vocabulary model - it has NO fixed num_classes.
    
    Args:
        checkpoint_path: Path to pretrained checkpoint (.pth file)
    
    Returns:
        Loaded GroundingDINO model
        
    Raises:
        FileNotFoundError: If checkpoint or config file not found
        RuntimeError: If model building or loading fails
    """
    # Add GroundingDINO to path

    # Load config
    config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"

    args = SLConfig.fromfile(str(config_file))

    if bert_model_path:
        bert_path = Path(bert_model_path)
        if not bert_path.exists():
            raise FileNotFoundError(
                f"Local BERT model not found: {bert_path}\n"
            )
        # Set bert_base_uncased_path, NOT text_encoder_type!
        args.bert_base_uncased_path = str(bert_path.absolute())
    else:
        # Ensure bert_base_uncased_path is None for online mode
        if not hasattr(args, 'bert_base_uncased_path'):
            args.bert_base_uncased_path = None
    args.aux_loss = True
    
    try:
        model = build_model(args)
    except Exception as e:
        raise RuntimeError(f"Failed to build Grounding DINO model: {e}") from e

    # Load pretrained checkpoint
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {checkpoint_path}\n"
            f"Download pretrained weights from:\n"
            f"  https://github.com/IDEA-Research/GroundingDINO/releases"
        )

    try:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
    except Exception as e:
        raise RuntimeError(f"Failed to load checkpoint: {e}") from e

    # Extract state dict (handle different formats)
    if isinstance(checkpoint, dict):
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
            logger.info("Checkpoint contains: epoch=%s", checkpoint.get('epoch', 'N/A'))
        else:
            state_dict = checkpoint
    else:
        state_dict = checkpoint

    # Clean and load state dict
    try:
        state_dict = clean_state_dict(state_dict)
        msg = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        raise RuntimeError(f"Failed to load state dict: {e}") from e

    # Report loading results
    if msg.missing_keys:
        logger.warning("Missing keys (%d): %s...", len(msg.missing_keys), msg.missing_keys[:5])

    if msg.unexpected_keys:
        logger.warning("Unexpected keys (%d): %s...", len(msg.unexpected_keys), msg.unexpected_keys[:5])

    if not msg.missing_keys and not msg.unexpected_keys:
        logger.info("All keys matched perfectly")

    # Verify model is in correct mode
    model.eval()  # Start in eval mode (will be set to train by trainer)

    return model

In [35]:
base_model = load_base_model(base_checkpoint, bert_model_path)

use local bert model path: /root/coding/platform/Grounded-Segment-Anything/data/models/pretrained/bert-base-uncased


Unexpected keys (2): ['label_enc.weight', 'bert.embeddings.position_ids']...


In [31]:
final_model = load_grounding_dino_with_lora(
    base_checkpoint=base_checkpoint,
    lora_config=lora_config,
    bert_model_path=bert_model_path
)



use local bert model path: /root/coding/platform/Grounded-Segment-Anything/data/models/pretrained/bert-base-uncased


Unexpected keys (2): ['label_enc.weight', 'bert.embeddings.position_ids']...
 Non-LoRA trainable parameters found: ['base_model.model.transformer.decoder.bbox_embed.0.layers.0.weight', 'base_model.model.transformer.decoder.bbox_embed.0.layers.0.bias', 'base_model.model.transformer.decoder.bbox_embed.0.layers.1.weight', 'base_model.model.transformer.decoder.bbox_embed.0.layers.1.bias', 'base_model.model.transformer.decoder.bbox_embed.0.layers.2.weight']...


In [32]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params:,} || "
        f"all params: {all_param:,} || "
        f"trainable%: {100 * trainable_params / all_param:.2f}%"
    )

In [33]:
print_trainable_parameters(final_model)

trainable params: 820,740 || all params: 173,527,810 || trainable%: 0.47%


In [12]:
for name, param in final_model.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.shape}")

model.base_model.model.transformer.encoder.text_layers.0.self_attn.out_proj.lora_A.default.weight: torch.Size([16, 256])
model.base_model.model.transformer.encoder.text_layers.0.self_attn.out_proj.lora_B.default.weight: torch.Size([256, 16])
model.base_model.model.transformer.encoder.text_layers.1.self_attn.out_proj.lora_A.default.weight: torch.Size([16, 256])
model.base_model.model.transformer.encoder.text_layers.1.self_attn.out_proj.lora_B.default.weight: torch.Size([256, 16])
model.base_model.model.transformer.encoder.text_layers.2.self_attn.out_proj.lora_A.default.weight: torch.Size([16, 256])
model.base_model.model.transformer.encoder.text_layers.2.self_attn.out_proj.lora_B.default.weight: torch.Size([256, 16])
model.base_model.model.transformer.encoder.text_layers.3.self_attn.out_proj.lora_A.default.weight: torch.Size([16, 256])
model.base_model.model.transformer.encoder.text_layers.3.self_attn.out_proj.lora_B.default.weight: torch.Size([256, 16])
model.base_model.model.transform

In [29]:
# Find all modules with "proj" in their name
print("=== All projection layers ===")
for name, module in base_model.named_modules():
    if "proj" in name.lower():
        print(f"{name}: {module}")

=== All projection layers ===
transformer.encoder.layers.0.self_attn.value_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.0.self_attn.output_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.1.self_attn.value_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.1.self_attn.output_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.2.self_attn.value_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.2.self_attn.output_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.3.self_attn.value_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.3.self_attn.output_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.4.self_attn.value_proj: Linear(in_features=256, out_features=256, bias=True)
transformer.encoder.layers.4.s