In [None]:
# Mount Google Drive inside the Colab runtime so that files stored under
# /content/drive can be read by the cells of this notebook.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
from PIL import Image
import torchvision.transforms as T

# Location of the input image on the mounted Drive -- edit this to point at
# your own file.
image_path = '/content/drive/MyDrive/peddet_2_1.jpg'

# Per-channel normalisation statistics (the usual ImageNet mean / std).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize, convert the PIL image to a float tensor,
# then normalise each channel.
preprocessing_steps = [
    T.Resize(800),
    T.ToTensor(),
    T.Normalize(channel_mean, channel_std),
]
transform = T.Compose(preprocessing_steps)

# Decode as 3-channel RGB, run the pipeline, and add a leading batch axis.
img = Image.open(image_path).convert('RGB')
img_tensor = transform(img).unsqueeze(0)  # shape: [1, 3, H, W]


In [None]:
import torch
import torch.nn.functional as F
import types

def make_patched_forward(layer_type):
    """Build a logging replacement for ``torch.nn.MultiheadAttention.forward``.

    The returned function is meant to be bound to an attention module (for
    example with ``types.MethodType``). On every call it recomputes the
    Q/K/V input projections from the module's packed ``in_proj_weight`` /
    ``in_proj_bias``, prints their shapes plus the first ten values of the
    first head, and then delegates the real computation to
    ``F.multi_head_attention_forward``.

    Args:
        layer_type: Label printed (upper-cased) in the log header, e.g.
            ``"encoder"``, ``"decoder_self"`` or ``"decoder_cross"``.

    Returns:
        A function ``patched_forward(self, query, key, value, **kwargs)``
        that returns the ``(attn_output, attn_weights)`` tuple produced by
        ``F.multi_head_attention_forward``. ``need_weights`` defaults to
        ``True``, matching ``nn.MultiheadAttention.forward``, so hooks and
        callers that read the attention map keep working after patching.

    Notes:
        * Inputs are expected as 3-D ``(seq_len, batch, embed_dim)`` tensors;
          ``batch_first`` modules and unbatched inputs are not handled.
        * The module must use the packed projection, i.e.
          ``in_proj_weight`` must not be ``None``.
    """
    def patched_forward(self, query, key, value, **kwargs):
        embed_dim = self.embed_dim
        num_heads = self.num_heads
        head_dim = embed_dim // num_heads
        scaling = float(head_dim) ** -0.5

        # Everything in this block exists only for logging, so keep it out of
        # the autograd graph; the real attention is computed further down.
        with torch.no_grad():
            # Split the packed input projection into its Q / K / V thirds.
            q_proj_weight = self.in_proj_weight[:embed_dim]
            k_proj_weight = self.in_proj_weight[embed_dim:2*embed_dim]
            v_proj_weight = self.in_proj_weight[2*embed_dim:]

            # Modules built with bias=False carry no input-projection bias.
            if self.in_proj_bias is None:
                q_proj_bias = k_proj_bias = v_proj_bias = None
            else:
                q_proj_bias = self.in_proj_bias[:embed_dim]
                k_proj_bias = self.in_proj_bias[embed_dim:2*embed_dim]
                v_proj_bias = self.in_proj_bias[2*embed_dim:]

            q = F.linear(query, q_proj_weight, q_proj_bias) * scaling
            k = F.linear(key,   k_proj_weight, k_proj_bias)
            v = F.linear(value, v_proj_weight, v_proj_bias)

            # Expose the head axis: (seq, batch, embed) -> (seq, batch, heads, head_dim).
            seq_len, batch_size, _ = q.shape
            q_heads = q.view(seq_len, batch_size, num_heads, head_dim)
            k_heads = k.view(k.shape[0], batch_size, num_heads, head_dim)
            v_heads = v.view(v.shape[0], batch_size, num_heads, head_dim)

        print(f"\n {layer_type.upper()} ATTENTION")
        print(f"Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}")
        print(f"Per-head shape: Q {q_heads.shape}, K {k_heads.shape}, V {v_heads.shape}")

        # Print value slices
        print(" Q[0, 0, 0, :10]:", q_heads[0, 0, 0, :10])
        print(" K[0, 0, 0, :10]:", k_heads[0, 0, 0, :10])
        print(" V[0, 0, 0, :10]:", v_heads[0, 0, 0, :10])

        # Forward the optional arguments only when the caller supplied them,
        # so nothing is passed that the installed functional API may lack.
        passthrough = {
            name: kwargs[name]
            for name in ("average_attn_weights", "is_causal")
            if name in kwargs
        }

        return F.multi_head_attention_forward(
            query, key, value,
            embed_dim, num_heads,
            self.in_proj_weight, self.in_proj_bias,
            self.bias_k, self.bias_v,
            self.add_zero_attn, self.dropout,
            self.out_proj.weight, self.out_proj.bias,
            training=self.training,
            # Same default as nn.MultiheadAttention.forward.
            need_weights=kwargs.get("need_weights", True),
            attn_mask=kwargs.get("attn_mask", None),
            key_padding_mask=kwargs.get("key_padding_mask", None),
            use_separate_proj_weight=False,
            **passthrough,
        )
    return patched_forward

def patch_all_attention_layers(model):
    """Replace ``forward`` on every ``torch.nn.MultiheadAttention`` in ``model``.

    Each attention submodule is labelled from its dotted module path:
    paths containing ``"encoder"`` are ``"encoder"``; otherwise paths
    containing ``"self_attn"`` are ``"decoder_self"``; everything else is
    ``"decoder_cross"``. A line is printed for every patched module, and the
    module's ``forward`` is rebound in place to the function returned by
    ``make_patched_forward`` for that label.

    Args:
        model: Module whose ``named_modules()`` are scanned.
    """
    attention_modules = (
        (path, submodule)
        for path, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.MultiheadAttention)
    )
    for path, submodule in attention_modules:
        if "encoder" in path:
            layer_type = "encoder"
        else:
            layer_type = "decoder_self" if "self_attn" in path else "decoder_cross"
        print(f"Patching: {path} --> {layer_type}")
        # Bind as an instance attribute so only this module is affected.
        replacement = make_patched_forward(layer_type)
        submodule.forward = types.MethodType(replacement, submodule)


In [None]:
# ========== STEP 4: Load DETR and Run Inference ==========
# Fetch the pretrained 'detr_resnet50' entry from torch.hub and put it in
# evaluation mode right away (``eval()`` returns the module itself).
model = torch.hub.load(
    'facebookresearch/detr',
    'detr_resnet50',
    pretrained=True,
).eval()

# Swap in the logging forward on every attention layer before running.
patch_all_attention_layers(model)

# Single gradient-free forward pass; the patched layers print Q/K/V shapes
# and per-head shapes as they execute.
with torch.no_grad():
    outputs = model(img_tensor)


Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main


Patching: transformer.encoder.layers.0.self_attn --> encoder
Patching: transformer.encoder.layers.1.self_attn --> encoder
Patching: transformer.encoder.layers.2.self_attn --> encoder
Patching: transformer.encoder.layers.3.self_attn --> encoder
Patching: transformer.encoder.layers.4.self_attn --> encoder
Patching: transformer.encoder.layers.5.self_attn --> encoder
Patching: transformer.decoder.layers.0.self_attn --> decoder_self
Patching: transformer.decoder.layers.0.multihead_attn --> decoder_cross
Patching: transformer.decoder.layers.1.self_attn --> decoder_self
Patching: transformer.decoder.layers.1.multihead_attn --> decoder_cross
Patching: transformer.decoder.layers.2.self_attn --> decoder_self
Patching: transformer.decoder.layers.2.multihead_attn --> decoder_cross
Patching: transformer.decoder.layers.3.self_attn --> decoder_self
Patching: transformer.decoder.layers.3.multihead_attn --> decoder_cross
Patching: transformer.decoder.layers.4.self_attn --> decoder_self
Patching: transf