# Grounding DINO Internals: Understanding pred_logits

This notebook walks through how Grounding DINO processes images and text to produce predictions.

We'll explore:
1. Text tokenization and encoding
2. Visual feature extraction  
3. Contrastive similarity computation (pred_logits)
4. Token-to-class mapping


In [None]:
import sys
from pathlib import Path

# Assumes the notebook lives one directory below the project root
# (e.g. <root>/notebooks/) — TODO confirm if the notebook is moved.
project_root = Path.cwd().parent

# Make the project root and the vendored GroundingDINO checkout importable.
# Each entry is inserted at index 0, so after this loop GroundingDINO/ comes
# first on sys.path, followed by the project root.
for extra_path in (project_root, project_root / 'GroundingDINO'):
    sys.path.insert(0, str(extra_path))

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


## Step 0: Load the Model


In [None]:
from groundingdino.util.slconfig import SLConfig
from groundingdino.models import build_model
from groundingdino.util.utils import clean_state_dict

# Model config and pretrained weights, relative to the project root
config_path = project_root / 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
checkpoint_path = project_root / 'data/models/pretrained/groundingdino_swint_ogc.pth'

args = SLConfig.fromfile(str(config_path))
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = build_model(args)

# Load the weights on CPU first; the whole model is moved to args.device below.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(str(checkpoint_path), map_location='cpu')

# strict=False tolerates key mismatches, so report them instead of discarding
# the result: a wrong or incompatible checkpoint would otherwise "load" fine
# while leaving layers at their initial values.
load_result = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
print(f"Missing keys: {len(load_result.missing_keys)}, "
      f"unexpected keys: {len(load_result.unexpected_keys)}")
if load_result.missing_keys:
    print(f"  First missing keys: {load_result.missing_keys[:5]}")

model = model.to(args.device)
model.eval()

print(f"Model loaded on {args.device}")
print(f"Number of parameters: {sum(p.numel() for p in model.parameters()):,}")


## Step 1: Understand the Text Pipeline

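Grounding DINO uses BERT to encode text. Let's see how the class names become a caption, then token IDs (which BERT turns into token embeddings), and finally which tokens belong to which class.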


In [None]:
# Example classes to detect
class_names = ['dog', 'cat', 'car']

# Join the class names with " . " and close with a period,
# e.g. ['dog', 'cat', 'car'] -> "dog . cat . car."
caption = f"{' . '.join(class_names)}."
print(f"Caption: '{caption}'")


In [None]:
# The text tokenizer is exposed on the model object
tokenizer = model.tokenizer

# Pad to a fixed 256 positions, so the listing below always has 15 entries
tokenized = tokenizer(caption, padding='max_length', max_length=256, return_tensors='pt')

input_ids = tokenized['input_ids'][0]
attention_mask = tokenized['attention_mask'][0]

print(f"Input IDs shape: {input_ids.shape}")
print(f"Attention mask shape: {attention_mask.shape}")
print(f"\nFirst 15 tokens:")

# Decode each of the first 15 ids on its own to show which token it stands for
first_ids = input_ids[:15].tolist()
first_masks = attention_mask[:15].tolist()
for position, (tok_id, mask_val) in enumerate(zip(first_ids, first_masks)):
    tok_text = tokenizer.decode([tok_id])
    print(f"  Position {position:2d}: ID={tok_id:5d}, Token='{tok_text:10s}', Mask={mask_val}")


In [None]:
# Use the official utility to get character spans for each class
from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span

caption_formatted, cat2tokenspan = build_captions_and_token_span(class_names, force_lowercase=False)

print(f"Formatted caption: '{caption_formatted}'")
print(f"\nCharacter spans per class:")
for class_name, spans in cat2tokenspan.items():
    print(f"  '{class_name}': characters {spans}")

# Create positive_map: which tokens belong to which class.
# Tokenize the exact caption that is later fed to the model so character->token
# lookups line up with the model's token positions.
tokenized_for_map = tokenizer(caption_formatted, padding='longest', return_tensors='pt')
# A class missing from cat2tokenspan falls back to an empty span list here
token_span_per_class = [cat2tokenspan.get(name, []) for name in class_names]

positive_map = create_positive_map_from_span(
    tokenized_for_map,
    token_span_per_class,
    max_text_len=256
)  # Shape: [num_classes, max_text_len]

print(f"\nPositive map shape: {positive_map.shape}")
print(f"\nToken positions for each class:")
for i, class_name in enumerate(class_names):
    token_positions = (positive_map[i] > 0).nonzero(as_tuple=True)[0].tolist()
    print(f"  Class {i} '{class_name}': tokens at positions {token_positions}")
    if not token_positions:
        # An all-zero row means Step 4 can never give this class a non-zero
        # score; make that visible instead of failing silently downstream.
        print(f"    WARNING: class '{class_name}' maps to no tokens in positive_map")


## Step 2: Prepare an Image and Run Forward Pass


In [None]:
import groundingdino.datasets.transforms as T

# Use a sample image from the project's data directory if one exists (JPG
# first, then PNG); otherwise fall back to random noise. min() picks the
# lexicographically first match, so the choice does not depend on the
# filesystem's directory-listing order.
data_dir = project_root / 'data'
sample_images = []
for pattern in ('*.jpg', '*.png'):
    first_match = min(data_dir.rglob(pattern), default=None)
    if first_match is not None:
        sample_images = [first_match]
        break

# Resize (target 800, max side 1333), convert to a tensor, and normalize with
# ImageNet mean/std
transform = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

if sample_images:
    image_path = sample_images[0]
    print(f"Using image: {image_path}")
    # The context manager closes the file handle once the pixels are copied out
    with Image.open(image_path) as img:
        image_source = np.array(img.convert('RGB'))
else:
    print("No sample image found, creating a random 640x480 image")
    # Seeded for reproducibility. The upper bound is exclusive, so 256 covers
    # the full uint8 range 0..255.
    rng = np.random.default_rng(0)
    image_source = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)

image_pil = Image.fromarray(image_source)
image_tensor, _ = transform(image_pil, None)

print(f"Original image shape: {image_source.shape}")
print(f"Preprocessed tensor shape: {image_tensor.shape}")

fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(image_source)
ax.set_title("Input Image")
ax.axis('off')
plt.show()


In [None]:
device = args.device

# Batch of a single image on the model's device: [1, 3, H, W]
images = image_tensor[None].to(device)
# One caption string per image in the batch (presumably; batch size is 1 here)
captions = [caption_formatted]

print(f"Input images shape: {images.shape}")
print(f"Caption: {captions}")

# Pure inference: skip autograd bookkeeping during the forward pass
with torch.no_grad():
    outputs = model(samples=images, captions=captions)

print(f"\nOutput keys: {outputs.keys()}")


## Step 3: Understand pred_logits (The Key Step!)

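pred_logits contains the **dot-product similarity** between:
- Visual queries (900 learned queries that attend to image regions)
- Text tokens (256 positions; only the caption's real tokens carry finite values, and the remaining padding positions are filled with -inf)

These are raw, pre-sigmoid logits, **not** class probabilities yet!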


In [None]:
# Inspect the outputs
pred_logits = outputs['pred_logits']  # [B, num_queries, num_tokens]
pred_boxes = outputs['pred_boxes']    # [B, num_queries, 4]

print(f"pred_logits shape: {pred_logits.shape}")
print(f"pred_boxes shape: {pred_boxes.shape}")
print(f"\nNumber of queries: {pred_logits.shape[1]}")
print(f"Number of text tokens: {pred_logits.shape[2]}")

# Drop the batch dimension (the batch holds a single image)
logits = pred_logits[0]  # [num_queries, num_tokens]

# Padding positions hold -inf. Including them would make Min and Mean print
# -inf, so summarize only the finite entries.
finite_mask = torch.isfinite(logits)
finite_logits = logits[finite_mask]

print(f"\nRaw logits statistics (finite entries only):")
if finite_logits.numel() > 0:
    print(f"  Min: {finite_logits.min().item():.2f}")
    print(f"  Max: {finite_logits.max().item():.2f}")
    print(f"  Mean: {finite_logits.mean().item():.2f}")
else:
    print("  (no finite logits)")

# Check for -inf (padding tokens)
num_inf = torch.isinf(logits).sum().item()
print(f"  Number of -inf values: {num_inf} (these are padding tokens)")
# A token position counts as usable only if it is finite for every query
num_valid_tokens = finite_mask.all(dim=0).sum().item()
print(f"  Token positions with finite logits: {num_valid_tokens} / {logits.shape[1]}")


In [None]:
# Convert to probabilities with sigmoid; -inf (padding) positions become 0
token_probs = logits.sigmoid()  # [num_queries, num_tokens]

# Summarize only entries whose logit is finite. The padding positions map to
# probability 0, which would pin Min at 0 and dilute Mean.
valid_probs = token_probs[torch.isfinite(logits)]

print(f"Token probabilities after sigmoid (finite-logit entries only):")
if valid_probs.numel() > 0:
    print(f"  Min: {valid_probs.min().item():.6f}")
    print(f"  Max: {valid_probs.max().item():.6f}")
    print(f"  Mean: {valid_probs.mean().item():.6f}")
else:
    print("  (no finite logits)")

# Find the query with the highest confidence, where a query's confidence is
# its highest probability over all token positions
max_prob_per_query = token_probs.max(dim=1).values  # [num_queries]
top_query_idx = max_prob_per_query.argmax().item()
top_query_confidence = max_prob_per_query[top_query_idx].item()

print(f"\nQuery with highest confidence: {top_query_idx}")
print(f"Highest token probability: {top_query_confidence:.4f}")

# Look at this query's token probabilities for class tokens
top_query_probs = token_probs[top_query_idx]  # [num_tokens]

# tokenized_for_map was built with padding='longest', so it holds only the
# caption's real tokens and can be shorter than 10. Indexing past its end would
# raise IndexError, so clamp to the available length.
caption_token_ids = tokenized_for_map['input_ids'][0]
num_shown = min(10, caption_token_ids.shape[0])

print(f"\nToken probabilities for top query (first {num_shown} tokens):")
for i in range(num_shown):
    prob = top_query_probs[i].item()
    token_id = caption_token_ids[i].item()
    token = tokenizer.decode([token_id])
    print(f"  Position {i}: '{token:10s}' = {prob:.4f}")


## Step 4: Map Token Probabilities to Class Scores

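Now we use positive_map to aggregate token probabilities into class scores. Each class score is the mean sigmoid probability over the tokens that make up that class name; a class with no tokens keeps a score of 0.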


In [None]:
positive_map = positive_map.to(device)

num_queries, num_classes = token_probs.shape[0], len(class_names)

# class_probs[q, k] = mean sigmoid probability of query q over the tokens that
# belong to class k. A class with no tokens keeps a score of 0.
class_probs = torch.zeros(num_queries, num_classes, device=device)
for class_idx in range(num_classes):
    class_token_mask = positive_map[class_idx] > 0  # token positions of this class
    if not class_token_mask.any():
        continue
    class_probs[:, class_idx] = token_probs[:, class_token_mask].mean(dim=-1)

print(f"Class probabilities shape: {class_probs.shape}")
print(f"  [num_queries={num_queries}, num_classes={num_classes}]")

# Best class and its score for every query: both [num_queries]
scores, labels = class_probs.max(dim=-1)

# How many queries clear the confidence threshold
threshold = 0.3
keep = scores > threshold

print(f"\nQueries with confidence > {threshold}: {keep.sum().item()}")

# Ten highest-scoring queries (listed regardless of the threshold)
top_indices = scores.argsort(descending=True)[:10]

print(f"\nTop 10 detections:")
print(f"{'Query':>6} {'Class':>10} {'Score':>8}")
print("-" * 30)
for query_idx in top_indices.tolist():
    best_label = labels[query_idx].item()
    best_score = scores[query_idx].item()
    print(f"{query_idx:6d} {class_names[best_label]:>10} {best_score:8.4f}")


## Summary

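The key insight is that **pred_logits are not class probabilities**: they are per-token similarity logits that must be passed through a sigmoid and then grouped by class.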

The full pipeline:

```
1. Text: "dog . cat . car." → BERT → text_features [1, T, 768]   (T = number of real caption tokens)
                                    ↓ projection (768 → 256)
                              proj_text [1, T, 256]

2. Image → Swin → Transformer Decoder → query_features [1, 900, 256]

3. pred_logits = query_features @ proj_text.transpose(-1, -2)
   Shape: [1, 900, T], padded with -inf to [1, 900, 256]
   (dot-product similarity between each query and each token)

4. token_probs = sigmoid(pred_logits)
   Shape: [1, 900, 256] (probabilities in [0, 1]; -inf padding → 0)

5. class_probs = mean of token_probs over each class's tokens (positive_map)
   Shape: [1, 900, 3] (probability per class)

6. Final: filter by threshold, get (boxes, scores, labels)
```
