# Deliverable 4
----



# 1. Import and Setup

In [None]:
# Mount Google Drive so the dataset and checkpoint paths under
# /content/drive/MyDrive (used by the training cell) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
# Working directory at kernel start; every later path is built from HOME.
HOME = os.getcwd()
print(f'HOME is {HOME}')
import sys
import torch
# Remove artefacts of a previous run so the clone / patch / install cells below
# start from a clean checkout.
# NOTE(review): these paths are hard-coded to /content rather than built from HOME.
!rm -rf /content/GroundingDINO
!rm -rf /content/requirements.txt
!rm -rf /content/groundingdino_swint_ogc.pth
!rm -rf /content/weights

HOME is /content


In [None]:

%cd {HOME}
# Clone GroundingDINO and install its Python requirements.
# The working directory stays at {HOME}/GroundingDINO for the following cells.
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd './GroundingDINO'
!pip install -r requirements.txt
print(f'HOME is {HOME}')

/content
Cloning into 'GroundingDINO'...
remote: Enumerating objects: 463, done.[K
remote: Counting objects: 100% (240/240), done.[K
remote: Compressing objects: 100% (105/105), done.[K
remote: Total 463 (delta 175), reused 135 (delta 135), pack-reused 223 (from 1)[K
Receiving objects: 100% (463/463), 12.87 MiB | 35.72 MiB/s, done.
Resolving deltas: 100% (241/241), done.
/content/GroundingDINO
HOME is /content


In [None]:
# === PATCH ms_deform_attn_cuda.cu for modern PyTorch ===
# Literal substitutions, applied in this order (the third one cleans up text
# produced by the first):
#   value.type()                     -> value.scalar_type()
#   ::detail::scalar_type(the_type)  -> the_type
#   value.scalar_type().is_cuda()    -> value.is_cuda()
# Done in Python rather than `sed -i ''`: that form is BSD/macOS syntax. GNU sed
# (as on Colab) reads '' as the script and the real expression as a file name,
# so the file was never modified.
import pathlib

cuda_file = f"{HOME}/GroundingDINO/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu"
cuda_path = pathlib.Path(cuda_file)

cuda_replacements = [
    ("value.type()", "value.scalar_type()"),
    ("::detail::scalar_type(the_type)", "the_type"),
    ("value.scalar_type().is_cuda()", "value.is_cuda()"),
]

cuda_src = cuda_path.read_text()
cuda_hits = 0
for old, new in cuda_replacements:
    cuda_hits += cuda_src.count(old)
    cuda_src = cuda_src.replace(old, new)

# Report what actually happened instead of claiming success unconditionally.
if cuda_hits:
    cuda_path.write_text(cuda_src)
    print(f"Fully patched CUDA file for modern PyTorch compatibility ({cuda_hits} replacements)")
else:
    print("WARNING: CUDA file unchanged – no target patterns found (already patched?)")


# === PATCH setup.py to return [] instead of None in get_extensions() ===
setup_file = f"{HOME}/GroundingDINO/setup.py"

with open(setup_file, "r") as f:
    lines = f.readlines()

patched_lines = []
inside_else_block = False
setup_patched = False

for line in lines:
    # Heuristic: the `else:` of the CUDA check is recognised by "CUDA" appearing
    # in one of the two previously emitted lines. Only the first such block is
    # patched.
    if not setup_patched and line.strip() == "else:" and "CUDA" in "".join(patched_lines[-2:]):
        inside_else_block = True
        patched_lines.append(line)
        continue

    if inside_else_block and "return None" in line:
        line = line.replace("return None", "return []  # patched for Colab")
        inside_else_block = False
        setup_patched = True
    patched_lines.append(line)

if setup_patched:
    with open(setup_file, "w") as f:
        f.writelines(patched_lines)
    print("Patched setup.py to return [] instead of None (avoids setup crash)")
else:
    print("WARNING: setup.py unchanged – 'return None' in the CUDA else-block not found")



sed: can't read s/value\.type()/value.scalar_type()/g: No such file or directory
sed: can't read s/::detail::scalar_type(the_type)/the_type/g: No such file or directory
sed: can't read s/value\.scalar_type()\.is_cuda()/value.is_cuda()/g: No such file or directory
Fully patched CUDA file for modern PyTorch compatibility
Patched setup.py to return [] instead of None (avoids setup crash)


In [None]:
# ─── MICRO-PATCH GroundingDINO/bertwarper.py ──────────────────────────
# Replaces the list comprehension that torch.stack()s each per-sample list of
# category-to-token masks with a loop. The loop substitutes a single all-False
# mask for an empty list, because torch.stack raises on an empty sequence.
#
# The target statement is found by its first line and consumed up to the line
# where its square brackets balance. This replaces both one-line and multi-line
# formattings completely. Success is printed only if the file was changed.
#
# NOTE(review): the fallback stacks a (1, max_seq_len) mask, producing a
# (1, 1, max_seq_len) tensor. If the non-empty per-token masks are 1-D, the
# per-sample shapes differ — confirm against upstream.
import re, pathlib, textwrap

bert_file = f"{HOME}/GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py"
bert_path = pathlib.Path(bert_file)

PATCH_MARKER = "# ===== PATCHED: avoid empty TensorList crash ====="
PATCH_BODY = textwrap.dedent("""
    # ===== PATCHED: avoid empty TensorList crash =====
    max_seq_len = input_ids.shape[1]
    device      = input_ids.device
    safe_cate_to_token_mask_list = []
    for _mask_list in cate_to_token_mask_list:
        if len(_mask_list) == 0:
            _mask_list = [torch.zeros((1, max_seq_len),
                                      dtype=torch.bool,
                                      device=device)]
        safe_cate_to_token_mask_list.append(
            torch.stack(_mask_list, dim=0)
        )
    cate_to_token_mask_list = safe_cate_to_token_mask_list
    # ===== END PATCH ===========================================
""")

# First line of `cate_to_token_mask_list = [ ... ]`. The negative lookahead
# skips an initialiser such as `[[] for _ in range(bs)]`.
start_pat = re.compile(r"^(\s*)cate_to_token_mask_list\s*=\s*\[(?!\s*\[)")

source_text = bert_path.read_text()
lines   = source_text.splitlines(keepends=True)
output  = []
patched = False

if PATCH_MARKER in source_text:
    print("bertwarper.py already patched – nothing to do")
else:
    i = 0
    while i < len(lines):
        line = lines[i]
        m = None if patched else start_pat.match(line)
        if m:
            # Walk forward until the statement's square brackets balance.
            end = i
            depth = line.count("[") - line.count("]")
            while depth > 0 and end + 1 < len(lines):
                end += 1
                depth += lines[end].count("[") - lines[end].count("]")
            statement = "".join(lines[i:end + 1])
            if depth == 0 and "torch.stack" in statement:
                output.append(textwrap.indent(PATCH_BODY, m.group(1)))  # keep original indent
                patched = True
                i = end + 1  # drop every line of the original statement
                continue
        output.append(line)
        i += 1

    if patched:
        bert_path.write_text("".join(output))
        print("bertwarper.py patched – empty-TensorList crash eliminated")
    else:
        print("WARNING: bertwarper.py NOT patched – target list comprehension not found")


bertwarper.py patched – empty-TensorList crash eliminated


In [None]:
# Editable install of GroundingDINO from the current directory
# ({HOME}/GroundingDINO, set by the clone cell); this also builds setup.py's extensions.
# NOTE(review): -q and -v counteract each other; keep only the one you want.
!pip install -q -e . -v

Obtaining file:///content/GroundingDINO
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: groundingdino
  Running setup.py develop for groundingdino


In [None]:
%cd {HOME}
print(f'HOME IS {HOME}')
# Download the Swin-T OGC checkpoint into {HOME}/weights; the training cell
# reads it from that path.
!wget -q -P weights https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import BertModel, BertTokenizer
from timm import create_model
from PIL import Image
import xml.etree.ElementTree as ET
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import math
import matplotlib.patches as patches
import random
import time
import copy
import json
import warnings
warnings.filterwarnings("ignore")

# Import GroundingDINO modules
%cd {HOME}/GroundingDINO

try:
    from groundingdino.models import build_model
    from groundingdino.util.slconfig import SLConfig
    from groundingdino.util.utils import clean_state_dict
    from groundingdino.util.inference import load_image, load_model, get_phrases_from_posmap
    print("GroundingDINO modules imported successfully!")
except ImportError as e:
    print(f"Error importing GroundingDINO modules: {e}")
    print("Please check your installation and try again.")
    import traceback
    traceback.print_exc()

%cd {HOME}

# Check if imports were successful
print("Libraries imported successfully!")

# 2. RSVG

In [None]:
class VisualEncoder(nn.Module):
    """ResNet-50 feature extractor extended with three extra stride-2 stages.

    ``forward`` returns four feature maps:
    - the deepest backbone map (2048 channels);
    - three 256-channel maps, each produced from the previous one by a 1x1
      reduction conv, ReLU, a 3x3 stride-2 conv and ReLU.
    """

    def __init__(self):
        super().__init__()
        self.backbone = create_model('resnet50', pretrained=True, features_only=True)
        # Extra pyramid stages: (1x1 reduce to 128) -> (3x3, stride 2, to 256).
        self.conv6_1 = nn.Conv2d(2048, 128, kernel_size=1, stride=1)
        self.conv6_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=1, stride=1)
        self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv8_1 = nn.Conv2d(256, 128, kernel_size=1, stride=1)
        self.conv8_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        stages = (
            (self.conv6_1, self.conv6_2),
            (self.conv7_1, self.conv7_2),
            (self.conv8_1, self.conv8_2),
        )
        pyramid = [self.backbone(x)[-1]]
        for reduce_conv, down_conv in stages:
            hidden = self.relu(reduce_conv(pyramid[-1]))
            pyramid.append(self.relu(down_conv(hidden)))
        return pyramid

class TextEncoder(nn.Module):
    """BERT-base (uncased) text encoder.

    ``texts`` may be:
    - a list of strings;
    - a list of dicts with a ``"caption"`` key;
    - any other iterable whose items are converted with ``str``.

    Captions are padded and truncated to at most 40 tokens. The result is
    ``[word_embeddings, sentence_embedding]``, where the sentence embedding is
    BERT's pooled output with a length-1 sequence axis inserted at dim 1.
    """

    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, texts):
        is_list = isinstance(texts, list)
        if is_list and isinstance(texts[0], str):
            captions = texts
        elif is_list and isinstance(texts[0], dict) and "caption" in texts[0]:
            captions = [item["caption"] for item in texts]
        else:
            captions = [str(item) for item in texts]
        encoded = self.tokenizer(captions, return_tensors='pt', padding=True, truncation=True, max_length=40)

        # Move token tensors to wherever the BERT weights live.
        target_device = next(self.bert.parameters()).device
        encoded = {key: tensor.to(target_device) for key, tensor in encoded.items()}
        bert_out = self.bert(**encoded)
        return [bert_out.last_hidden_state, bert_out.pooler_output.unsqueeze(1)]

class MLCM(nn.Module):
    """Multi-level cross-modal refinement of the deepest visual feature map.

    Inputs are projected to 256 dims as token sequences:
    - the four visual maps (2048, 256, 256, 256 channels);
    - the word and sentence embeddings (768-d).

    Tokens of the first map then pass through ``L`` cross-attention blocks,
    which attend over all visual + text tokens, followed by ``N``
    self-attention blocks. Each block is post-norm with a residual FFN.
    Output shape: (B, H1*W1, 256).
    """

    def __init__(self):
        super().__init__()
        self.linear_v1 = nn.Linear(2048, 256)
        self.linear_v2 = nn.Linear(256, 256)
        self.linear_v3 = nn.Linear(256, 256)
        self.linear_v4 = nn.Linear(256, 256)
        self.linear_t_word = nn.Linear(768, 256)
        self.linear_t_sent = nn.Linear(768, 256)
        self.L = 6  # number of cross-attention blocks
        self.N = 6  # number of self-attention blocks

        def attn_stack(count):
            return nn.ModuleList([nn.MultiheadAttention(256, 8, dropout=0.1) for _ in range(count)])

        def norm_stack(count):
            return nn.ModuleList([nn.LayerNorm(256) for _ in range(count)])

        self.cross_attn_layers = attn_stack(self.L)
        self.cross_attn_norms = norm_stack(self.L)
        self.self_attn_layers = attn_stack(self.N)
        self.self_attn_norms = norm_stack(self.N)
        # FFNs / norms [0, L) serve the cross blocks; [L, L+N) the self blocks.
        self.ffn_layers = nn.ModuleList([
            nn.Sequential(nn.Linear(256, 2048), nn.ReLU(), nn.Dropout(0.1), nn.Linear(2048, 256))
            for _ in range(self.L + self.N)
        ])
        self.ffn_norms = norm_stack(self.L + self.N)

    def forward(self, visual_features, text_features):
        f1, f2, f3, f4 = visual_features
        projections = (self.linear_v1, self.linear_v2, self.linear_v3, self.linear_v4)
        # (B, C, H, W) -> (B, H*W, C) -> (B, H*W, 256)
        visual_tokens = [
            proj(fmap.flatten(2).transpose(1, 2))
            for proj, fmap in zip(projections, (f1, f2, f3, f4))
        ]
        word_embeddings, sentence_embedding = text_features
        memory = torch.cat(
            visual_tokens + [self.linear_t_word(word_embeddings), self.linear_t_sent(sentence_embedding)],
            dim=1,
        )
        memory_seq = memory.transpose(0, 1)  # sequence-first layout for nn.MultiheadAttention

        x = visual_tokens[0]
        for layer_idx in range(self.L):
            attended, _ = self.cross_attn_layers[layer_idx](
                query=x.transpose(0, 1), key=memory_seq, value=memory_seq
            )
            x = self.cross_attn_norms[layer_idx](x + attended.transpose(0, 1))
            x = self.ffn_norms[layer_idx](x + self.ffn_layers[layer_idx](x))

        for layer_idx in range(self.N):
            seq = x.transpose(0, 1)
            attended, _ = self.self_attn_layers[layer_idx](query=seq, key=seq, value=seq)
            x = self.self_attn_norms[layer_idx](x + attended.transpose(0, 1))
            ffn_idx = self.L + layer_idx
            x = self.ffn_norms[ffn_idx](x + self.ffn_layers[ffn_idx](x))
        return x

class MultimodalFusionModule(nn.Module):
    """Fuses visual tokens with word embeddings through a Transformer encoder.

    A learnable summary token is prepended to:
    - the projected 256-d visual tokens;
    - the projected 768-d word embeddings (``text_features[0]``).

    The summary token's encoder output, shape (B, d_model), is returned.
    """

    def __init__(self, d_model=256):
        super().__init__()
        self.learnable_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.learnable_token, std=0.02)
        self.visual_proj = nn.Linear(256, d_model)
        self.text_proj = nn.Linear(768, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=2048,
            dropout=0.1, activation='relu', batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)

    def forward(self, visual_features, text_features):
        vis = self.visual_proj(visual_features)
        txt = self.text_proj(text_features[0])
        summary = self.learnable_token.expand(vis.size(0), -1, -1)
        encoded = self.transformer(torch.cat([summary, vis, txt], dim=1))
        return encoded[:, 0, :]

class LocalizationModule(nn.Module):
    """Two-layer MLP head mapping a fused embedding to 4 box values."""

    def __init__(self, d_model=256):
        super().__init__()
        hidden = 256
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden, 4),
        )

    def forward(self, x):
        box = self.mlp(x)
        return box

class RSVGModel(nn.Module):
    """End-to-end RSVG grounding model.

    Pipeline: visual/text encoders -> MLCM refinement -> multimodal fusion ->
    box head.
    """

    def __init__(self):
        super().__init__()
        self.visual_encoder = VisualEncoder()
        self.text_encoder = TextEncoder()
        self.mlcm = MLCM()
        self.multimodal_fusion = MultimodalFusionModule()
        self.localization = LocalizationModule()

    def forward(self, images, texts):
        visual = self.visual_encoder(images)
        textual = self.text_encoder(texts)
        fused = self.multimodal_fusion(self.mlcm(visual, textual), textual)
        return self.localization(fused)

# 3. LoRA Layers

In [None]:
class LoRALayer(nn.Module):
    """
    Low-rank additive update ``(x @ lora_A @ lora_B) * (alpha / rank)``.

    ``lora_A`` (in_features x rank) gets Kaiming-uniform init, and ``lora_B``
    (rank x out_features) starts at zero, so the update is zero until trained.
    """
    def __init__(self, in_features, out_features, rank=4, alpha=16):
        super().__init__()
        self.rank, self.alpha = rank, alpha
        self.scaling = alpha / rank
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        down = x @ self.lora_A
        return down @ self.lora_B * self.scaling

def apply_lora_to_linear_layers(module, target_names=None, rank=4, alpha=16):
    """
    Attach LoRA adapters to matching ``nn.Linear`` submodules of ``module``.

    A submodule matches when its full dotted name contains any entry of
    ``target_names`` as a substring. Each matched layer's ``forward`` is replaced
    by ``base(x) + lora(x)``. The ``LoRALayer`` is registered on ``module`` itself
    as ``"<dotted_name_with_underscores>_lora"``, so it appears in
    ``module.parameters()`` / ``state_dict()`` and follows ``module.to(...)``.
    Layers that already have an adapter registered are skipped.

    Args:
        module: root module to adapt.
        target_names: substrings to look for in submodule names. Defaults to
            common attention projection names.
        rank, alpha: LoRA rank and scaling numerator (scaling = alpha / rank).

    Returns:
        List of the newly created LoRA parameters (e.g. for an optimizer group).

    NOTE(review): some containers read a child's weight directly instead of
    calling its ``forward`` (e.g. ``nn.MultiheadAttention`` with ``out_proj``).
    The patched ``forward`` has no effect there — confirm for the targeted
    models.
    """
    if target_names is None:
        # Default to common attention projection layers
        target_names = ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'query', 'key', 'value']

    # Collect first, then modify, so named_modules() is not mutated while iterating.
    modules_to_modify = [
        (name, submodule)
        for name, submodule in module.named_modules()
        if isinstance(submodule, nn.Linear) and any(target in name for target in target_names)
    ]

    def create_forward_hook(orig_forward, lora):
        # The factory binds the current layer/adapter (avoids late-binding closures).
        def forward_hook(x):
            return orig_forward(x) + lora(x)
        return forward_hook

    lora_params = []
    for name, submodule in modules_to_modify:
        # Sanitized attribute name: dots are not allowed in module attribute names.
        lora_name = f"{name.replace('.', '_')}_lora"
        if hasattr(module, lora_name):
            # Wrapping again would stack a second adapter and leave the first one
            # applied but no longer registered (not saved, not moved by .to()).
            print(f"LoRA already applied to {name}, skipping")
            continue

        # Create the adapter on the base weight's device/dtype so the wrapped
        # forward also works when `module` was already moved (e.g. to CUDA).
        base_weight = submodule.weight
        lora_layer = LoRALayer(submodule.in_features, submodule.out_features, rank, alpha).to(
            device=base_weight.device, dtype=base_weight.dtype
        )

        submodule.forward = create_forward_hook(submodule.forward, lora_layer)
        setattr(module, lora_name, lora_layer)
        lora_params.extend(lora_layer.parameters())

        print(f"Applied LoRA to {name}")

    return lora_params

# 4. Helper Functions

In [None]:
# !pip install spacy
# !python -m spacy download en_core_web_sm

In [None]:
class RSVGLoss(nn.Module):
    """Smooth-L1 box regression loss plus ``lambda_giou`` times a GIoU loss.

    Boxes are rows of (x1, y1, x2, y2), i.e. tensors of shape (B, 4).
    """

    def __init__(self, lambda_giou=1.0):
        super().__init__()
        self.smooth_l1 = nn.SmoothL1Loss()
        self.lambda_giou = lambda_giou

    def forward(self, pred_boxes, target_boxes):
        regression = self.smooth_l1(pred_boxes, target_boxes)
        return regression + self.lambda_giou * self.generalized_box_iou_loss(pred_boxes, target_boxes)

    def generalized_box_iou_loss(self, pred_boxes, target_boxes):
        """Return ``1 - mean(GIoU)`` over the batch (1e-7 guards the divisions)."""
        eps = 1e-7

        def area(boxes):
            return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

        overlap_wh = (
            torch.min(pred_boxes[:, 2:], target_boxes[:, 2:])
            - torch.max(pred_boxes[:, :2], target_boxes[:, :2])
        ).clamp(min=0)
        overlap = overlap_wh[:, 0] * overlap_wh[:, 1]
        union = area(pred_boxes) + area(target_boxes) - overlap
        iou = overlap / (union + eps)

        # Smallest axis-aligned box enclosing both.
        hull_wh = (
            torch.max(pred_boxes[:, 2:], target_boxes[:, 2:])
            - torch.min(pred_boxes[:, :2], target_boxes[:, :2])
        ).clamp(min=0)
        hull = hull_wh[:, 0] * hull_wh[:, 1]
        giou = iou - (hull - union) / (hull + eps)
        return 1 - giou.mean()

In [None]:
class RemoteSensingDataset(Dataset):
    """Visual-grounding samples from an image directory and an XML annotation directory.

    Each ``.jpg``/``.png`` file in ``img_dir`` is paired with ``<stem>.xml`` in
    ``ann_dir``. A sample is ``(image_tensor, query, box)``:
    - ``query`` is the root-level ``grounding_caption`` text if present,
      otherwise the first ``object``'s ``description``, otherwise ``"object"``;
    - ``box`` is that object's ``bndbox`` as (x1, y1, x2, y2), normalised by
      the ORIGINAL image width/height.

    Samples that are unreadable or lack an object/bndbox yield ``None``.
    """

    def __init__(self, img_dir, ann_dir, transform=None):
        self.img_dir = img_dir
        self.ann_dir = ann_dir
        # Sorted so the index -> file mapping is reproducible (os.listdir order is arbitrary).
        self.imgs = sorted(f for f in os.listdir(self.img_dir) if f.endswith(('.jpg', '.png')))
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((640, 640)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img_path = os.path.join(self.img_dir, img_name)
        # splitext swaps only the real extension; str.replace would also rewrite
        # a '.jpg'/'.png' occurring inside the file stem.
        xml_path = os.path.join(self.ann_dir, os.path.splitext(img_name)[0] + '.xml')

        try:
            # The context manager closes the file; convert() returns a separate loaded image.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
            w, h = image.size
            image = self.transform(image)

            root = ET.parse(xml_path).getroot()
            obj = root.find('object')
            if obj is None:
                return None

            grounding_caption_elem = root.find('grounding_caption')
            if grounding_caption_elem is not None and grounding_caption_elem.text is not None:
                query = grounding_caption_elem.text.strip()
            else:
                # Reported only when the caption really is missing.
                print(f'CAPTION NOT FOUND for {img_name}; using object description')
                description = obj.find('description')
                if description is None or description.text is None:
                    query = "object"
                else:
                    query = description.text.strip()

            bbox = obj.find('bndbox')
            if bbox is None:
                return None

            x1 = float(bbox.find('xmin').text)
            y1 = float(bbox.find('ymin').text)
            x2 = float(bbox.find('xmax').text)
            y2 = float(bbox.find('ymax').text)

            # Normalize by the pre-resize size so coordinates lie in [0, 1] image units.
            box = torch.tensor([x1/w, y1/h, x2/w, y2/h], dtype=torch.float32)

            return image, query, box

        except Exception as e:
            print(f"Error loading sample {img_name}: {e}")
            return None

def compute_iou(pred_box, true_box):
    """
    Intersection-over-union of two boxes given as [x1, y1, x2, y2].

    Non-tensor inputs are converted with ``torch.tensor``. Returns a 0-dim
    tensor, or the float ``0.0`` when either box has no elements.
    """
    pred = pred_box if isinstance(pred_box, torch.Tensor) else torch.tensor(pred_box)
    true = true_box if isinstance(true_box, torch.Tensor) else torch.tensor(true_box)

    if pred.nelement() == 0 or true.nelement() == 0:
        return 0.0

    zero = torch.tensor(0.0)
    overlap_w = torch.max(zero, torch.min(pred[2], true[2]) - torch.max(pred[0], true[0]))
    overlap_h = torch.max(zero, torch.min(pred[3], true[3]) - torch.max(pred[1], true[1]))
    overlap = overlap_w * overlap_h

    def box_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    union = box_area(pred) + box_area(true) - overlap
    return overlap / (union + 1e-7)

def custom_collate_fn(batch):
    """Collate ``(image, query, box)`` samples, dropping ``None`` entries.

    Returns ``(stacked_images, list_of_queries, stacked_boxes)``, or ``None``
    when no valid sample remains.
    """
    valid = [sample for sample in batch if sample is not None]
    if not valid:
        return None
    image_seq, query_seq, box_seq = zip(*valid)
    return torch.stack(image_seq, dim=0), list(query_seq), torch.stack(box_seq, dim=0)


# 5. RSVG_DINO

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.misc import clean_state_dict

class RSVG_DINO(nn.Module):
    """Fuses a frozen, pretrained RSVG model with GroundingDINO for box prediction.

    Both models process the same images and captions. RSVG's fused embedding and
    a 256-d GroundingDINO feature are concatenated and fused. A small head turns
    the fused features into two softmax weights. ``"boxes"`` is the
    weight-averaged combination of:
    - the RSVG box;
    - GroundingDINO's top-scoring box, converted from (cx, cy, w, h) to
      (x1, y1, x2, y2).
    A separate ``enhanced_localization`` head regresses boxes directly from the
    fused features.

    Args:
        rsvg_model_path: ``RSVGModel`` checkpoint (raw state dict or a dict with
            ``"model_state_dict"``); loaded with ``strict=False``.
        config_path: GroundingDINO ``SLConfig`` file.
        groundingdino_weights_path: GroundingDINO checkpoint (raw state dict or a
            dict with ``"model"``); loaded with ``strict=False``.
        debug: print verbose tracing from ``forward`` (default ``True``).
    """

    def __init__(
        self,
        rsvg_model_path: str,
        config_path: str,
        groundingdino_weights_path: str,
        debug: bool = True,
    ):
        super().__init__()

        # RSVG model: loaded, frozen, and kept in eval mode (see train()).
        print("Loading RSVG model...")
        self.rsvg = RSVGModel()
        ckpt = torch.load(rsvg_model_path, map_location='cpu')
        rsvg_state = ckpt.get("model_state_dict", ckpt)
        self.rsvg.load_state_dict(rsvg_state, strict=False)
        self.rsvg.eval()
        for p in self.rsvg.parameters():
            p.requires_grad_(False)
        print("RSVG model loaded successfully!")

        # GroundingDINO model (its parameters are left trainable).
        print("Loading GroundingDINO model...")
        args = SLConfig.fromfile(config_path)
        self.groundingdino = build_model(args)
        dino_ckpt = torch.load(groundingdino_weights_path, map_location='cpu')
        dino_state = clean_state_dict(dino_ckpt.get("model", dino_ckpt))
        self.groundingdino.load_state_dict(dino_state, strict=False)
        self.groundingdino.eval()
        print("GroundingDINO model loaded successfully!")

        # Tokenizer used to build caption token spans for GroundingDINO.
        # Loaded once here rather than on every forward() call.
        self.dino_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

        # Default fallback for tokens
        self.default_tokens_positive = [[1]]

        # Projection & fusion heads
        self.dino_proj = nn.Linear(256, 256)
        self.feature_fusion = nn.Sequential(
            nn.Linear(256 + 256, 256),
            nn.ReLU(inplace=True)
        )
        self.confidence_weighting = nn.Linear(256, 2)
        self.enhanced_localization = nn.Linear(256, 4)

        # Initialize weights to favor RSVG initially
        with torch.no_grad():
            if hasattr(self.confidence_weighting, 'bias'):
                self.confidence_weighting.bias.data[0] = 1.0  # Higher bias for RSVG
                self.confidence_weighting.bias.data[1] = 0.0  # Lower bias for DINO

        print("RSVG_DINO initialization complete!")

        # Debug mode
        self.debug = debug

    def train(self, mode: bool = True):
        """Set training mode, but keep the frozen RSVG sub-model in eval mode.

        Without this, ``model.train()`` would re-enable RSVG's dropout layers and
        let any BatchNorm layers (e.g. in its ResNet-50 backbone) update their
        running statistics, even though its parameters are frozen.
        """
        super().train(mode)
        self.rsvg.eval()
        return self

    def _format_dino_queries(self, texts):
        """Build GroundingDINO targets ``{"caption", "tokens_positive"}`` for each text.

        ``texts`` items are strings or dicts with a ``"caption"`` key.
        ``tokens_positive`` covers every token between the first and last
        (special) tokens. A caption with no such tokens gets ``[1]``.
        """
        dino_texts = []
        for cap in texts:
            caption = cap["caption"] if isinstance(cap, dict) and "caption" in cap else cap
            ids = self.dino_tokenizer(caption, add_special_tokens=True).input_ids
            tokens_pos = list(range(1, len(ids) - 1)) or [1]
            dino_texts.append({
                "caption": caption,
                "tokens_positive": [tokens_pos],
            })
        return dino_texts

    def _best_dino_boxes(self, logits, boxes, B, device):
        """For each sample, pick GroundingDINO's top query and convert its box to corners.

        Scores are a softmax over queries of the first token logit
        (``logits[b, :, 0]``). Returns ``(boxes[B, 4], scores[B])``.
        """
        box_list, score_list = [], []
        for b in range(B):
            scores = F.softmax(logits[b, :, 0], dim=0)
            if self.debug and b == 0:
                print(f"Sample scores shape: {scores.shape}")
                print(f"Sample scores values: {scores[:5]}")  # First 5 scores

            if scores.numel() > 0:
                # Get best prediction
                i = scores.argmax()
                s = scores[i]
                c = boxes[b, i]

                # Convert from center-size to corners
                x1 = c[0] - c[2] / 2
                y1 = c[1] - c[3] / 2
                x2 = c[0] + c[2] / 2
                y2 = c[1] + c[3] / 2

                box_list.append(torch.stack([x1, y1, x2, y2], dim=0))
                score_list.append(s)

                if self.debug and b == 0:
                    print(f"Best box for sample 0: {[x1.item(), y1.item(), x2.item(), y2.item()]}")
                    print(f"Confidence score: {s.item()}")
            else:
                # Default box (full image) with zero confidence. Dtypes match the
                # predictions so the stacked result stays floating point.
                box_list.append(torch.tensor([0.0, 0.0, 1.0, 1.0], device=device, dtype=boxes.dtype))
                score_list.append(torch.tensor(0.0, device=device, dtype=scores.dtype))

                if self.debug and b == 0:
                    print("Using default box [0,0,1,1] with zero confidence")

        return torch.stack(box_list), torch.stack(score_list)

    def forward(self, images, texts):
        """Run both pipelines and fuse their outputs.

        Args:
            images: image batch; dim 0 is the batch size, and both sub-models
                are moved to its device.
            texts: per-sample captions (strings or dicts with ``"caption"``).

        Returns:
            dict with keys "boxes" (confidence-weighted RSVG/DINO boxes),
            "enhanced_boxes", "rsvg_boxes", "dino_boxes", "dino_scores",
            "confidence_weights" and "fused_features". If the GroundingDINO
            branch raises, its boxes/scores/features fall back to zeros.
        """
        B = images.shape[0]
        device = images.device

        if self.debug:
            print("\n==== DEBUG: RSVG_DINO.forward ====")
            print(f"Batch size: {B}")
            print(f"Images shape: {images.shape}")
            print(f"Texts type: {type(texts)}, length: {len(texts)}")
            for i, t in enumerate(texts[:min(3, len(texts))]):  # Print first few
                print(f"  Text {i}: {t}")
            if len(texts) > 3:
                print(f"  ... and {len(texts) - 3} more")

        # Send models to device (no-op once they are already there).
        self.rsvg.to(device)
        self.groundingdino.to(device)

        # RSVG
        if self.debug:
            print("\n--- Processing RSVG pipeline ---")

        with torch.no_grad():
            # Get inputs for RSVG
            if isinstance(texts[0], dict) and "caption" in texts[0]:
                rsvg_texts = [t["caption"] for t in texts]
            else:
                rsvg_texts = texts

            if self.debug:
                print(f"RSVG texts: {rsvg_texts[:min(3, len(rsvg_texts))]}")

            # Process through RSVG pipeline
            vf = self.rsvg.visual_encoder(images)
            tf = self.rsvg.text_encoder(rsvg_texts)
            rv = self.rsvg.mlcm(vf, tf)
            rsvg_feats = self.rsvg.multimodal_fusion(rv, tf)
            rsvg_boxes = self.rsvg.localization(rsvg_feats)

            if self.debug:
                print(f"RSVG features shape: {rsvg_feats.shape}")
                print(f"RSVG boxes shape: {rsvg_boxes.shape}")
                print(f"RSVG sample boxes: {rsvg_boxes[0]}")

        # GROUNDING DINO
        if self.debug:
            print("\n--- Processing GroundingDINO pipeline ---")

        try:
            # Format queries with guaranteed valid token spans
            dino_texts = self._format_dino_queries(texts)

            if self.debug:
                print(f"Formatted {len(dino_texts)} texts for GroundingDINO")
                print(f"Sample formatted text: {dino_texts[0]}")

            # NOTE(review): this overwrites GroundingDINO's own `specical_tokens`
            # attribute with per-caption position lists, and the change persists
            # after this call. Upstream it appears to hold the special-token ids
            # used to split captions into phrases. Verify this is intended — it
            # may be why the GroundingDINO forward below fails.
            self.groundingdino.specical_tokens = [q["tokens_positive"][0] for q in dino_texts]

            if self.debug:
                print("Set self.groundingdino.specical_tokens to:")
                for i, tok in enumerate(self.groundingdino.specical_tokens[:min(3, len(self.groundingdino.specical_tokens))]):
                    print(f"  Sample {i}: {tok}")

            # CRITICAL DEBUGGING POINT - Before actual GroundingDINO forward
            if self.debug:
                print("\n>>> About to call groundingdino forward <<<")
                print(f"groundingdino input image shape: {images.shape}")
                print(f"groundingdino input texts: {len(dino_texts)} items")
                try:
                    model_state = {
                        "has_tokenizer": hasattr(self.groundingdino, "tokenizer"),
                        "has_special_tokens": hasattr(self.groundingdino, "specical_tokens"),
                        "tokenizer_type": type(self.groundingdino.tokenizer).__name__ if hasattr(self.groundingdino, "tokenizer") else None,
                    }
                    print(f"Model state checks: {model_state}")
                except Exception as e:
                    print(f"Error during model inspection: {e}")

            # Process through GroundingDINO
            dino_out = self.groundingdino(images, dino_texts)

            if self.debug:
                print("\n>>> GroundingDINO forward succeeded! <<<")
                print(f"Output keys: {list(dino_out.keys())}")

            # Extract outputs
            logits = dino_out["pred_logits"]
            boxes = dino_out["pred_boxes"]

            if self.debug:
                print(f"Pred logits shape: {logits.shape}")
                print(f"Pred boxes shape: {boxes.shape}")

            # Get feature embeddings
            if "hidden_states" in dino_out:
                df = dino_out["hidden_states"][-1][:, 0, :]
                if self.debug:
                    print(f"Using hidden_states for features, shape: {df.shape}")
            elif "encoder_hidden_states" in dino_out:
                df = dino_out["encoder_hidden_states"][-1][:, 0, :]
                if self.debug:
                    print(f"Using encoder_hidden_states for features, shape: {df.shape}")
            else:
                df = torch.zeros(B, 256, device=device)
                if self.debug:
                    print("No hidden states found, using zeros")

            dino_boxes, dino_scores = self._best_dino_boxes(logits, boxes, B, device)

        except Exception as e:
            if self.debug:
                print(f"\nERROR in GroundingDINO processing: {e}")
                import traceback
                traceback.print_exc()
            else:
                # Best-effort fallback is kept, but the failure is no longer silent.
                print(f"WARNING: GroundingDINO processing failed ({e!r}); using zero DINO outputs")

            # Use default values on error
            df = torch.zeros(B, 256, device=device)
            dino_boxes = torch.zeros(B, 4, device=device)
            dino_scores = torch.zeros(B, device=device)

        # FEATURE FUSION & OUTPUT
        if self.debug:
            print("\n--- Processing Feature Fusion ---")

        df_proj = self.dino_proj(df)
        fused_in = torch.cat([rsvg_feats, df_proj], dim=1)
        fused_feats = self.feature_fusion(fused_in)
        weights = F.softmax(self.confidence_weighting(fused_feats), dim=1)
        enhanced_bxes = self.enhanced_localization(fused_feats)

        # Weighted average of boxes based on confidence
        weighted_bxes = weights[:, 0:1] * rsvg_boxes + weights[:, 1:2] * dino_boxes

        if self.debug:
            print(f"Fusion weights: RSVG={weights[0, 0].item():.4f}, DINO={weights[0, 1].item():.4f}")
            print(f"Final boxes sample: {weighted_bxes[0]}")
            print("==== END DEBUG: RSVG_DINO.forward ====\n")

        return {
            "boxes": weighted_bxes,
            "enhanced_boxes": enhanced_bxes,
            "rsvg_boxes": rsvg_boxes,
            "dino_boxes": dino_boxes,
            "dino_scores": dino_scores,
            "confidence_weights": weights,
            "fused_features": fused_feats
        }

# 6. Training Loop

In [None]:
# import os
# import torch
# from tqdm import tqdm

# def train_rsvgdino_model(
#     model,
#     train_loader,
#     criterion,
#     optimizer,
#     device,
#     checkpoint_dir,
#     num_epochs
# ):
#     """
#     model         : your RSVG_DINO
#     train_loader  : DataLoader yielding (images, texts, targets)
#     criterion     : RSVGLoss(pred_boxes, target_boxes)
#     optimizer     : as you already built it
#     device        : torch.device
#     checkpoint_dir: path to save checkpoints
#     num_epochs    : int
#     """
#     os.makedirs(checkpoint_dir, exist_ok=True)
#     history = {"loss": []}

#     for epoch in range(1, num_epochs + 1):
#         model.train()
#         running_loss = 0.0

#         loop = tqdm(train_loader, desc=f"[Epoch {epoch}/{num_epochs}]", leave=False)
#         for batch_idx, (images, texts, targets) in enumerate(loop, start=1):
#             # ---- 1) move images to device ----
#             images = images.to(device)

#             # ---- 2) pull out target_boxes ----
#             if isinstance(targets, dict) and "boxes" in targets:
#                 # your collate returned a dict of batched tensors
#                 target_boxes = targets["boxes"].to(device)

#             elif torch.is_tensor(targets):
#                 # collate returned a single tensor of shape [B,4]
#                 target_boxes = targets.to(device)

#             elif isinstance(targets, (list, tuple)) and torch.is_tensor(targets[0]):
#                 # collate returned list of per-sample box tensors [4] -> stack into [B,4]
#                 target_boxes = torch.stack([t.to(device) for t in targets], dim=0)

#             else:
#                 raise ValueError(f"Unrecognized targets format: {type(targets)}")

#             # ---- 3) forward + loss ----
#             optimizer.zero_grad()
#             outputs = model(images, texts)
#             pred_boxes = outputs["boxes"]

#             loss = criterion(pred_boxes, target_boxes)

#             loss.backward()
#             optimizer.step()

#             # ---- 4) logging ----
#             running_loss += loss.item()
#             if batch_idx % 10 == 0:
#                 avg_batch_loss = running_loss / batch_idx
#                 print(f"[Epoch {epoch}] Batch {batch_idx}/{len(train_loader)} — avg loss: {avg_batch_loss:.4f}")
#             loop.set_postfix(loss=running_loss / batch_idx)

#         # end of epoch
#         avg_loss = running_loss / len(train_loader)
#         history["loss"].append(avg_loss)
#         print(f"Epoch {epoch}/{num_epochs} — avg loss: {avg_loss:.4f}")

#         # save every 5 epochs
#         if epoch % 5 == 0:
#             ckpt_name = f"rsvgdino_{epoch}.pth"
#             ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
#             torch.save(model.state_dict(), ckpt_path)
#             print(f"Saved checkpoint: {ckpt_path}")

#     return model, history


In [None]:
import os
import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torch.optim import AdamW


# 1. Training hyperparameters
batch_size = 14
num_epochs = 20

# 2. Paths to the pretrained RSVG checkpoint and the GroundingDINO config/weights.
#    HOME is already absolute, so paths are built with os.path.join rather than
#    f"/{HOME}/...", which produced a leading "//" (implementation-defined on POSIX)
#    and was inconsistent with the other paths in this cell.
rsvg_model_path            = os.path.join(HOME, "drive/MyDrive/Checkpoints/rsvg_checkpoint_epoch_150.pth")
config_path                = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
groundingdino_weights_path = os.path.join(HOME, "weights/groundingdino_swint_ogc.pth")

# 3. Where to save new checkpoints
checkpoint_dir = os.path.join(HOME, "drive/MyDrive/Checkpoints2")
os.makedirs(checkpoint_dir, exist_ok=True)

# 4. Data directories
train_img_dir = os.path.join(HOME, "drive/MyDrive/Dataset/train_data/train_images")
train_ann_dir = os.path.join(HOME, "drive/MyDrive/Dataset/train_data/train_annotations")

# 5. Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 6. Dataset & DataLoader
train_dataset = RemoteSensingDataset(
    img_dir=train_img_dir,
    ann_dir=train_ann_dir
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate_fn
)

print(f"Training dataset contains {len(train_dataset)} samples")

# 7. Model
model = RSVG_DINO(
    rsvg_model_path=rsvg_model_path,
    config_path=config_path,
    groundingdino_weights_path=groundingdino_weights_path
)
model = model.to(device)

# 8. Optimizer: split trainable parameters into groups with separate learning rates.
backbone_params = []
text_params     = []
fusion_params   = []
lora_params     = []

BACKBONE_PREFIXES = (
    "rsvg.visual_encoder.backbone",
    "groundingdino.backbone",
    "groundingdino.input_proj",
)
TEXT_PREFIXES = (
    "rsvg.text_encoder.bert",
    "groundingdino.bert",
)

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # frozen weights are not handed to the optimizer

    if "lora" in name:
        # Checked first, so LoRA adapters living inside a backbone/text module
        # still receive the LoRA learning rate.
        lora_params.append(param)
    elif name.startswith(BACKBONE_PREFIXES):
        # Vision & DINO backbones
        backbone_params.append(param)
    elif name.startswith(TEXT_PREFIXES):
        # Text encoders (BERT)
        text_params.append(param)
    else:
        # Everything else new: fusion, heads, MLCM, projections
        fusion_params.append(param)

# An empty group does not stop training, so a name mismatch (e.g. no parameter
# containing "lora") would otherwise go unnoticed.
for group_name, group in (
    ("backbone", backbone_params),
    ("text", text_params),
    ("fusion", fusion_params),
    ("lora", lora_params),
):
    if not group:
        print(f"WARNING: no trainable parameters matched the '{group_name}' group")

optimizer = AdamW([
    { "params": backbone_params, "lr": 5e-6 },   # backbone
    { "params": text_params,     "lr": 1e-5 },   # text
    { "params": fusion_params,   "lr": 2e-5 },   # fusion heads
    { "params": lora_params,     "lr": 8e-5 },   # LoRA
], weight_decay=1e-2)


# 9. Loss function
criterion = RSVGLoss(lambda_giou=1.0)



Using device: cuda
Training dataset contains 12340 samples
Loading RSVG model...
RSVG model loaded successfully!
Loading GroundingDINO model...
final text_encoder_type: bert-base-uncased
GroundingDINO model loaded successfully!
RSVG_DINO initialization complete!


In [None]:
# # Assume you have already loaded your model as `model`
# # and have a batch of images and queries (texts)

# # Example: get a batch from your DataLoader
# batch = next(iter(train_loader))  # or test_loader
# images, queries, boxes = batch  # boxes are optional for inference

# # Move images to device
# images = images.to(device)

# # Call the model directly
# outputs = model(images, queries)

# # outputs is a dict with keys: "boxes", "enhanced_boxes", etc.
# print(f' BOXES: {outputs["boxes"]}')
# print(f' ENHANCED BOXES: {outputs["enhanced_boxes"]}')

In [None]:
# # Kick off training
# trained_model, history = train_rsvgdino_model(
#     model=model,
#     train_loader=train_loader,
#     criterion=criterion,
#     optimizer=optimizer,
#     device=device,
#     checkpoint_dir=checkpoint_dir,
#     num_epochs=num_epochs
# )

# print("Training completed successfully!")

In [None]:
!pwd

/content


In [None]:
!pip install -q git+https://github.com/facebookresearch/segment-anything.git
!pip install -q opencv-python-headless tqdm
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
import torch

def load_sam_predictor(sam_checkpoint_path, device="cuda"):
    """
    Build a SAM ViT-B predictor used for box-prompted refinement.

    Args:
        sam_checkpoint_path: path to the ViT-B SAM weights file.
        device: device the SAM model is moved to.

    Returns:
        A SamPredictor wrapping the model, with the model put in eval mode.
    """
    build_vit_b = sam_model_registry["vit_b"]
    sam_model = build_vit_b(checkpoint=sam_checkpoint_path)
    sam_model = sam_model.to(device)
    sam_model = sam_model.eval()
    return SamPredictor(sam_model)

def refine_boxes_with_sam(image_np, boxes_np, predictor):
    """
    Refine each box by snapping it to the extent of the SAM mask it prompts.

    Args:
        image_np: (H, W, 3) RGB image as a uint8 numpy array, the format
            ``SamPredictor.set_image`` expects by default.
        boxes_np: (N, 4) array-like of boxes in XYXY *pixel* coordinates of
            ``image_np``. Normalized [0, 1] boxes must be scaled beforehand.
        predictor: SamPredictor instance, or any object exposing the same
            ``set_image`` / ``predict`` interface.

    Returns:
        (N, 4) float32 array of refined boxes. A box whose mask is empty is
        returned unchanged. An empty input yields an array of shape (0, 4).
    """
    boxes = np.asarray(boxes_np, dtype=np.float32).reshape(-1, 4)
    if len(boxes) == 0:
        # Nothing to refine: skip computing the (expensive) image embedding.
        return np.zeros((0, 4), dtype=np.float32)

    # The image embedding is computed once and reused for every box prompt.
    predictor.set_image(image_np)

    refined = np.empty_like(boxes)
    for k, box in enumerate(boxes):
        # SAM accepts a numpy box prompt directly; box[None, :] has shape (1, 4).
        masks, _, _ = predictor.predict(box=box[None, :], multimask_output=False)
        ys, xs = np.nonzero(masks[0])
        if xs.size == 0:
            # SAM found nothing inside the prompt: keep the original box.
            refined[k] = box
        else:
            # Inclusive pixel extents, the same convention SAM's own
            # mask-to-box utility uses.
            refined[k] = (xs.min(), ys.min(), xs.max(), ys.max())

    return refined


  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
def _as_box_array(boxes):
    """Convert boxes (tensor, ndarray or nested sequence) to a float32 (N, 4) numpy array."""
    if torch.is_tensor(boxes):
        boxes = boxes.detach().cpu().numpy()
    # reshape turns a single (4,) box into (1, 4) and flattens extra leading dims.
    return np.asarray(boxes, dtype=np.float32).reshape(-1, 4)


def evaluate_model_with_sam(model, dataloader, sam_checkpoint_path, device="cuda", iou_threshold=0.5):
    """
    Evaluate ``model`` on ``dataloader``, refining every predicted box with SAM.

    Each sample is run through the model on its own (a batch of one). Its
    predicted boxes are tightened with ``refine_boxes_with_sam``, and the refined
    boxes are greedily matched one-to-one to ground-truth boxes at
    IoU >= ``iou_threshold``. The model is left in eval mode.

    Args:
        model: grounding model called as ``model(images, texts)`` and returning
            a dict with a ``"boxes"`` entry.
        dataloader: yields ``(images, texts, gt_boxes)`` batches.
        sam_checkpoint_path: path to the SAM ViT-B checkpoint.
        device: device used for the model inputs and for SAM.
        iou_threshold: minimum IoU for a refined box to count as correct.

    Returns:
        dict with "precision", "recall", "f1" and "mean_iou" (the same values
        are printed). "mean_iou" is averaged over matched (true-positive) pairs
        only, not over all samples.
    """
    model.eval()
    predictor = load_sam_predictor(sam_checkpoint_path, device)

    correct, total_preds, total_gts = 0, 0, 0
    matched_ious = []

    for images, texts, gt_boxes in tqdm(dataloader, desc="Evaluating"):
        for i in range(len(images)):
            img_tensor = images[i].unsqueeze(0).to(device)  # (1, C, H, W)

            # A per-sample target may be a single [4] box. Normalising it to (M, 4)
            # makes the loop below iterate boxes (not scalar coordinates) and makes
            # len(gt) count boxes rather than 4 coordinates.
            gt = _as_box_array(gt_boxes[i])

            with torch.no_grad():
                out = model(img_tensor, [texts[i]])
            pred_boxes = _as_box_array(out["boxes"])

            # NOTE(review): SAM expects XYXY pixel coordinates of image_np, but the
            # RSVG debug output shows boxes in [0, 1]. If out["boxes"] is normalized,
            # scale by (W, H, W, H) before refinement and make sure `gt` is in the
            # same frame as the refined (pixel) boxes. TODO confirm.
            # NOTE(review): assumes image tensors hold values in [0, 1]; undo any
            # mean/std normalization first if the dataset applies it. TODO confirm.
            image_np = images[i].permute(1, 2, 0).cpu().numpy()
            image_np = (image_np * 255).clip(0, 255).astype(np.uint8)

            refined = refine_boxes_with_sam(image_np, pred_boxes, predictor)

            matched = set()
            for p in refined:
                for j, g in enumerate(gt):
                    if j in matched:
                        continue  # each GT box may be matched at most once
                    iou = compute_iou(p, g)
                    if iou >= iou_threshold:
                        correct += 1
                        matched.add(j)
                        matched_ious.append(iou)
                        break

            total_preds += len(refined)
            total_gts += len(gt)

    precision = correct / total_preds if total_preds > 0 else 0
    recall = correct / total_gts if total_gts > 0 else 0
    f1 = 2 * precision * recall / (precision + recall + 1e-6)
    mean_iou = float(np.mean(matched_ious)) if matched_ious else 0

    print(f"📊 Precision: {precision:.3f} | Recall: {recall:.3f} | F1: {f1:.3f} | mIoU: {mean_iou:.3f}")
    return {"precision": precision, "recall": recall, "f1": f1, "mean_iou": mean_iou}

In [None]:
# Test dataset paths, built from HOME the same way as the training paths
# (identical to "/content/drive/..." when HOME is /content).
test_img_dir = os.path.join(HOME, "drive/MyDrive/Dataset/test_data/test_images")
test_ann_dir = os.path.join(HOME, "drive/MyDrive/Dataset/test_data/test_annotations")

# Create test dataset
test_dataset = RemoteSensingDataset(
    img_dir=test_img_dir,
    ann_dir=test_ann_dir
)

# Create test DataLoader
from torch.utils.data import DataLoader

test_loader = DataLoader(
    test_dataset,
    # evaluate_model_with_sam processes each sample of a batch individually, so
    # batch_size does not change the metrics; 1 keeps the progress bar per sample.
    batch_size=1,
    shuffle=False,  # fixed order so evaluation runs are reproducible
    num_workers=2,
    collate_fn=custom_collate_fn
)

print(f"Test dataset contains {len(test_dataset)} samples.")


Test dataset contains 3372 samples.


In [None]:
# SAM ViT-B weights used to refine the model's predicted boxes during evaluation.
sam_ckpt = "/content/drive/MyDrive/Checkpoints2/sam_vit_b_01ec64.pth"
evaluate_model_with_sam(
    model,
    test_loader,
    sam_checkpoint_path=sam_ckpt,
)

Evaluating:   0%|          | 0/3372 [00:00<?, ?it/s]


==== DEBUG: RSVG_DINO.forward ====
Batch size: 1
Images shape: torch.Size([1, 3, 640, 640])
Texts type: <class 'list'>, length: 1
  Text 0: The golf field at the bottom

--- Processing RSVG pipeline ---
RSVG texts: ['The golf field at the bottom']
RSVG features shape: torch.Size([1, 256])
RSVG boxes shape: torch.Size([1, 4])
RSVG sample boxes: tensor([0.2451, 0.4868, 0.7370, 0.9557], device='cuda:0')

--- Processing GroundingDINO pipeline ---
Formatted 1 texts for GroundingDINO
Sample formatted text: {'caption': 'The golf field at the bottom', 'tokens_positive': [[1, 2, 3, 4, 5, 6]]}
Set self.groundingdino.specical_tokens to:
  Sample 0: [1, 2, 3, 4, 5, 6]

>>> About to call groundingdino forward <<<
groundingdino input image shape: torch.Size([1, 3, 640, 640])
groundingdino input texts: 1 items
Model state checks: {'has_tokenizer': True, 'has_special_tokens': True, 'tokenizer_type': 'BertTokenizerFast'}

ERROR in GroundingDINO processing: stack expects a non-empty TensorList

--- Pro

Traceback (most recent call last):
  File "<ipython-input-104-d52a0d23cdf8>", line 160, in forward
    dino_out = self.groundingdino(images, dino_texts)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/GroundingDINO/groundingdino/models/GroundingDINO/groundingdino.py", line 255, in forward
    ) = generate_masks_with_special_tokens_and_transfer_map(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/GroundingDINO/groundingdino/models/GroundingDINO/bertwarper.py", line 264, in generate_masks_with_special_tokens_and_transfer_map
    cate_to_token_mask_list = [
    

IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number