<a href="https://colab.research.google.com/github/khanmhmdi/Moe-llm-edge-computing/blob/OBJECT-DETECTION/DETR/Evaluating_DETR_on_COCO_validation_2017.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we are going to evaluate `DetrForObjectDetection` on the COCO validation 2017 dataset.

Make sure to set Runtime to GPU.

## Set up environment

We install Transformers straight from GitHub, along with [timm](https://github.com/rwightman/pytorch-image-models/), the library that provides DETR's convolutional backbone. DETR uses a ResNet-50 (or ResNet-101 for its larger variant) as its backbone.

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git timm


## Create PyTorch dataset + dataloaders

Here we define a regular PyTorch dataset. Torchvision already provides a `CocoDetection` dataset, which we can use. We only add a feature extractor (namely `DetrFeatureExtractor`) to convert the data from COCO format into the format that DETR expects. It resizes the images and their annotated bounding boxes, and normalizes the images across the RGB channels using the ImageNet mean and standard deviation.
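
To make the target conversion concrete, here is a minimal sketch of the box part of it, assuming the usual conventions: COCO stores boxes as absolute `[x, y, width, height]` with `(x, y)` the top-left corner, while DETR works with `(center_x, center_y, width, height)` normalized by the image size. The helper name `coco_box_to_detr` is hypothetical and not part of Transformers; the feature extractor performs this conversion internally.

```python
def coco_box_to_detr(box, image_width, image_height):
    """Convert a COCO [x, y, w, h] box in pixels to normalized (cx, cy, w, h)."""
    x, y, w, h = box
    cx = (x + w / 2) / image_width    # box center, as a fraction of image width
    cy = (y + h / 2) / image_height   # box center, as a fraction of image height
    return (cx, cy, w / image_width, h / image_height)


# A 100x50 box whose top-left corner sits at (200, 100) in a 640x480 image.
print(coco_box_to_detr([200, 100, 100, 50], 640, 480))
```

All four values end up in the range 0 to 1, which makes the targets independent of the resized image resolution.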

In [2]:
# Download the COCO 2017 validation images and the train/val annotations.
!wget http://images.cocodataset.org/zips/val2017.zip
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip

import zipfile

# Extract both archives into /content/a.
for archive in ("/content/val2017.zip", "/content/annotations_trainval2017.zip"):
    with zipfile.ZipFile(archive, 'r') as zip_ref:
        zip_ref.extractall("/content/a")

--2025-03-27 18:14:22--  http://images.cocodataset.org/zips/val2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.146.123, 16.182.103.201, 54.231.199.121, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.146.123|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 815585330 (778M) [application/zip]
Saving to: ‘val2017.zip’


2025-03-27 18:14:37 (51.6 MB/s) - ‘val2017.zip’ saved [815585330/815585330]

--2025-03-27 18:14:37--  http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 16.182.68.73, 3.5.25.244, 3.5.28.235, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|16.182.68.73|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 252907541 (241M) [application/zip]
Saving to: ‘annotations_trainval2017.zip’


2025-03-27 18:14:42 (50.4 MB/s) - ‘annotations_trainval2017.zip’ saved [252907541/252907541]



In [24]:
!wget http://images.cocodataset.org/zips/train2017.zip

--2025-03-27 18:33:16--  http://images.cocodataset.org/zips/train2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 3.5.30.217, 52.217.228.209, 3.5.30.50, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|3.5.30.217|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 19336861798 (18G) [application/zip]
Saving to: ‘train2017.zip’


2025-03-27 18:39:30 (49.3 MB/s) - ‘train2017.zip’ saved [19336861798/19336861798]



In [26]:
with zipfile.ZipFile("/content/detr/train2017.zip", 'r') as zip_ref:
    zip_ref.extractall("/content/t")

In [3]:
import torchvision
import os

class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, feature_extractor):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.feature_extractor = feature_extractor

    def __getitem__(self, idx):
        # read in PIL image and target in COCO format
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}

        # preprocess image and target (converting target to DETR format, resizing + normalization of both image and target)
        encoding = self.feature_extractor(images=img, annotations=target, return_tensors="pt")
        pixel_values = encoding["pixel_values"].squeeze() # remove batch dimension
        target = encoding["labels"][0] # remove batch dimension

        return pixel_values, target

When initializing the `test_dataset`, one should specify the path to the directory containing the images and the path to the annotation JSON file. You can download COCO detection validation 2017 from the [official website](https://cocodataset.org/#download). In this notebook, the download cell above extracts both to `/content/a`.

Note that we call it a test dataset here, whereas COCO calls it a "validation" set. It consists of 5,000 annotated images.

In [9]:
from transformers import DetrFeatureExtractor

feature_extractor = DetrFeatureExtractor()

test_dataset = CocoDetection(img_folder='/content/a/val2017',
                              ann_file='/content/a/annotations/instances_val2017.json',
                              feature_extractor=feature_extractor)

loading annotations into memory...
Done (t=1.41s)
creating index...
index created!


In [30]:
print(len(test_dataset))

118287


Let's visualize a random image with its annotations:

In [31]:
import numpy as np
import os
from PIL import Image, ImageDraw

# based on https://github.com/woctezuma/finetune-detr/blob/master/finetune_detr.ipynb
image_ids = test_dataset.coco.getImgIds()
# let's pick a random image
image_id = image_ids[np.random.randint(0, len(image_ids))]
print('Image n°{}'.format(image_id))
image = test_dataset.coco.loadImgs(image_id)[0]
image = Image.open(os.path.join('/content/a/val2017', image['file_name']))

annotations = test_dataset.coco.imgToAnns[image_id]
draw = ImageDraw.Draw(image, "RGBA")

cats = test_dataset.coco.cats
id2label = {k: v['name'] for k,v in cats.items()}

for annotation in annotations:
  box = annotation['bbox']
  class_idx = annotation['category_id']
  x,y,w,h = tuple(box)
  draw.rectangle((x,y,x+w,y+h), outline='red', width=1)
  draw.text((x, y), id2label[class_idx], fill='white')

image

Image n°554752


FileNotFoundError: [Errno 2] No such file or directory: '/content/a/val2017/000000554752.jpg'

Next, let's create a corresponding `test_dataloader`. As `DetrFeatureExtractor` resizes every image with a `min_size` of 800 and a `max_size` of 1333, images can have different sizes. Hence, we define a custom `collate_fn` that batches images and their labels together by padding the images in a batch up to the largest one, and that also creates a `pixel_mask` indicating which pixels are real and which are padding.
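
The resizing rule can be sketched as follows, assuming the scheme used in the original DETR code: scale the shorter side to `min_size`, unless that would push the longer side past `max_size`, in which case the target for the shorter side is reduced. `resized_shape` is a hypothetical helper for illustration only, and the extractor's exact rounding may differ by a pixel.

```python
def resized_shape(height, width, min_size=800, max_size=1333):
    """Return (new_height, new_width) after an aspect-preserving resize."""
    short, long = min(height, width), max(height, width)
    target = min_size
    # Shrink the target if the longer side would otherwise exceed max_size.
    if long / short * target > max_size:
        target = int(round(max_size * short / long))
    if height <= width:
        return target, int(target * width / height)
    return int(target * height / width), target


print(resized_shape(480, 640))    # landscape image, shorter side becomes 800
print(resized_shape(500, 2000))   # very wide image, limited by max_size
```

Because the aspect ratio is preserved, two images in the same batch generally end up with different widths or heights, which is why padding is needed.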

In [32]:
from torch.utils.data import DataLoader

def collate_fn(batch):
  pixel_values = [item[0] for item in batch]
  encoding = feature_extractor.pad(pixel_values, return_tensors="pt")
  labels = [item[1] for item in batch]
  batch = {}
  batch['pixel_values'] = encoding['pixel_values']
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = labels
  return batch

test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=2)

Let's verify the first batch:

In [33]:
batch = next(iter(test_dataloader))
batch.keys()

dict_keys(['pixel_values', 'pixel_mask', 'labels'])

In [34]:
batch['pixel_values'].shape

torch.Size([2, 3, 800, 1201])

In [35]:
batch['pixel_mask'].shape

torch.Size([2, 800, 1201])
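
The shapes above follow from padding: every image in a batch is padded at the bottom and right up to the largest height and width in that batch, and `pixel_mask` marks real pixels with 1 and padding with 0. Below is a pure-Python sketch of that bookkeeping, using a hypothetical helper `padded_batch_layout` and made-up toy image sizes; the real work is done by `feature_extractor.pad`.

```python
def padded_batch_layout(sizes):
    """Given (height, width) per image, return the padded shape and per-image masks."""
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    masks = []
    for h, w in sizes:
        # 1 marks a real pixel, 0 marks padding added at the bottom/right.
        mask = [[1 if (r < h and c < w) else 0 for c in range(max_w)]
                for r in range(max_h)]
        masks.append(mask)
    return (max_h, max_w), masks


shape, masks = padded_batch_layout([(2, 3), (3, 2)])  # toy sizes, not real images
print(shape)
for row in masks[0]:
    print(row)
```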

## Run evaluation

Finally, we can run the evaluation. The original implementation of DETR (by Facebook AI) contains a nice `CocoEvaluator` object, which allows us to easily compute the metrics. Let's use it, by git cloning the original repository.
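
COCO's bounding-box metrics are built on the intersection over union (IoU) between predicted and ground-truth boxes. `CocoEvaluator` delegates this to pycocotools, so the following is only an illustrative sketch of the quantity involved, with a hypothetical helper `box_iou_xywh` operating on COCO-style `[x, y, width, height]` boxes.

```python
def box_iou_xywh(box_a, box_b):
    """Intersection over union of two [x, y, width, height] boxes."""
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    # Overlap along each axis, clamped at zero for disjoint boxes.
    inter_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    inter_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    inter = inter_w * inter_h
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes, the second shifted 5 pixels to the right.
print(box_iou_xywh([0, 0, 10, 10], [5, 0, 10, 10]))
```

The headline COCO AP averages over IoU thresholds from 0.50 to 0.95 in steps of 0.05, so a prediction counts as a match at progressively stricter overlaps.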

In [15]:
! git clone https://github.com/facebookresearch/detr.git
%cd detr

Cloning into 'detr'...
remote: Enumerating objects: 265, done.[K
remote: Total 265 (delta 0), reused 0 (delta 0), pack-reused 265 (from 1)[K
Receiving objects: 100% (265/265), 21.19 MiB | 17.33 MiB/s, done.
Resolving deltas: 100% (120/120), done.
/content/detr


We initialize the `CocoEvaluator` by providing the ground truths to it.

In [2]:
%cd /content/detr/
# from coco_utils import get_coco_api_from_dataset
from datasets.coco_eval import CocoEvaluator # Import CocoEvaluator
from datasets import get_coco_api_from_dataset

base_ds = get_coco_api_from_dataset(test_dataset) # this is actually just calling the coco attribute
iou_types = ['bbox']
coco_evaluator = CocoEvaluator(base_ds, iou_types) # initialize evaluator with ground truths

/content/detr


NameError: name 'test_dataset' is not defined

Now, let's load the trained model from HuggingFace's model hub, and run the evaluation, batch by batch. Each time a batch is processed, the metrics are added to the `CocoEvaluator`.

In [37]:
from transformers import DetrForObjectDetection

model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Below, we move the model to the GPU. Note that we don't need to put the model in evaluation mode (`model.eval()`), as the `.from_pretrained()` method above already does this by default. We use the `tqdm` library to print a nice progress bar.

As this can take a while, I limited the evaluation below to 50 batches. You can of course remove that limit to run the evaluation on the entire validation set.
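
One way to cap the number of batches is the standard library's `itertools.islice`, sketched here. `MAX_BATCHES`, `limited`, and the stand-in `fake_dataloader` are hypothetical names introduced for illustration; in the notebook you would wrap `test_dataloader` instead.

```python
from itertools import islice

MAX_BATCHES = 50  # set to None to iterate over everything


def limited(iterable, max_items=None):
    """Yield at most max_items items; None means no limit."""
    return islice(iterable, max_items)


fake_dataloader = range(2500)  # stand-in for a DataLoader with 2,500 batches
processed = sum(1 for _ in limited(fake_dataloader, MAX_BATCHES))
print(processed)
```

In an evaluation loop this would read `for batch in limited(test_dataloader, MAX_BATCHES): ...`. Metrics computed on such a subset are only a rough estimate of the full-set numbers.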

In [38]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from datasets.coco_eval import CocoEvaluator
from datasets import get_coco_api_from_dataset
from tqdm.notebook import tqdm

# Custom COCO dataset that uses the DETR feature_extractor.
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, feature_extractor):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.feature_extractor = feature_extractor

    def __getitem__(self, idx):
        # Read in PIL image and target in COCO format
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}

        # Preprocess image and target: conversion to DETR format, resizing, normalization, etc.
        encoding = self.feature_extractor(images=img, annotations=target, return_tensors="pt")
        # Remove batch dimension added by feature_extractor
        pixel_values = encoding["pixel_values"].squeeze(0)
        # DETR returns 'labels' as a list inside the encoding when annotations are provided.
        target = encoding["labels"][0]  # remove batch dimension

        return pixel_values, target

# Initialize the feature extractor.
feature_extractor = DetrFeatureExtractor()

# Define paths for your COCO dataset.
img_folder = '/content/t/train2017'
ann_file = '/content/a/annotations/instances_train2017.json'

# Create the dataset instance.
dataset = CocoDetection(img_folder=img_folder,
                        ann_file=ann_file,
                        feature_extractor=feature_extractor)

# Define a collate function to combine images/targets into a batch.
def collate_fn(batch):
    pixel_values = [item[0] for item in batch]
    encoding = feature_extractor.pad(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    batch = {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }
    return batch

# Create DataLoader. (Here we use the same dataset for demonstration.
# In practice, use a separate training split.)
train_dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=2, shuffle=True)
eval_dataloader  = DataLoader(dataset, collate_fn=collate_fn, batch_size=2)

# Initialize the evaluation helper (COCO API).
base_ds = get_coco_api_from_dataset(dataset)
iou_types = ['bbox']
coco_evaluator = CocoEvaluator(base_ds, iou_types)

# Load the DETR model.
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# Set the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

################################################################################
# Fine-tuning Setup: Define L0 Regularization on FFN (fc1 and fc2) layers.
################################################################################

l0_lambda = 1e-4  # Hyperparameter for L0 regularization strength.

def approximate_l0(weight):
    """
    A smooth surrogate for the L0 norm.
    Applies a scaled sigmoid to the absolute weight values.
    Adjust scaling factor and offset as needed.
    """
    return torch.sum(torch.sigmoid(10 * (torch.abs(weight) - 0.001)))

def compute_l0_loss(model, l0_lambda):
    """
    Computes the total L₀ loss for each transformer layer by concatenating the flattened weights
    of both fc1 and fc2 layers and applying the L0 surrogate function on the combined tensor.

    The model architecture is assumed as:
      - Encoder layers: model.model.encoder.layers
      - Decoder layers: model.model.decoder.layers
    """
    l0_loss = 0.0

    # Process encoder layers.
    for layer in model.model.encoder.layers:
        # Concatenate flattened weights of fc1 and fc2.
        combined_weights = torch.cat([layer.fc1.weight.view(-1), layer.fc2.weight.view(-1)])
        l0_loss += l0_lambda * approximate_l0(combined_weights)

    # Process decoder layers.
    for layer in model.model.decoder.layers:
        combined_weights = torch.cat([layer.fc1.weight.view(-1), layer.fc2.weight.view(-1)])
        l0_loss += l0_lambda * approximate_l0(combined_weights)

    return l0_loss

################################################################################
# Optimization Setup & Training Loop with tqdm Progress Bar
################################################################################

optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 3

print("Starting fine-tuning with L0 regularization...")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # Initialize a tqdm progress bar for the training dataloader.
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", leave=True)

    for batch in progress_bar:
        pixel_values = batch["pixel_values"].to(device)
        pixel_mask = batch["pixel_mask"].to(device)
        # Move each target dictionary to the device.
        labels = [{k: v.to(device) for k, v in target.items()} for target in batch["labels"]]

        optimizer.zero_grad()

        # Forward pass with detection loss.
        outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
        detection_loss = outputs.loss

        # Compute the combined L0 loss on fc1 and fc2 weights.
        l0_loss = compute_l0_loss(model, l0_lambda)

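        # Note: compute_l0_loss already multiplies by l0_lambda, so the effective
        # weight on the L0 term here is 0.0001 * l0_lambda.
        total_loss = detection_loss + 0.0001 * l0_loss

        # Backpropagation: without this call no gradients are computed and
        # optimizer.step() leaves the weights unchanged.
        total_loss.backward()
        optimizer.step()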

        running_loss += total_loss.item()

        # Update the progress bar with the current total loss.
        progress_bar.set_postfix(loss=total_loss.item())

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}: Average Loss: {avg_loss:.4f}")

    # (Optional) Evaluate after each epoch.
    model.eval()
    coco_evaluator.reset()
    with torch.no_grad():
        for eval_batch in tqdm(eval_dataloader, desc="Evaluation"):
            pixel_values = eval_batch["pixel_values"].to(device)
            pixel_mask = eval_batch["pixel_mask"].to(device)
            labels = [{k: v.to(device) for k, v in t.items()} for t in eval_batch["labels"]]

            outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
            # Assume that each label dict has an "orig_size" key for post_process.
            orig_target_sizes = torch.stack([t["orig_size"] for t in labels], dim=0)
            results = feature_extractor.post_process(outputs, orig_target_sizes)
            res = {target['image_id'].item(): output for target, output in zip(labels, results)}
            coco_evaluator.update(res)

    coco_evaluator.synchronize_between_processes()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()

print("Training complete.")


loading annotations into memory...
Done (t=22.06s)
creating index...
index created!


Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Starting fine-tuning with L0 regularization...


Epoch 1:   0%|          | 0/59144 [00:00<?, ?it/s]

KeyboardInterrupt: 
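
A quick way to see what the `approximate_l0` surrogate in the cell above measures is to evaluate the same expression, `sigmoid(10 * (|w| - 0.001))`, for single weights. The sketch below uses only the standard library; `surrogate` and its `scale` and `offset` parameters are names introduced here for illustration.

```python
import math


def surrogate(w, scale=10.0, offset=0.001):
    """Per-weight term of the smooth L0 surrogate: sigmoid(scale * (|w| - offset))."""
    return 1.0 / (1.0 + math.exp(-scale * (abs(w) - offset)))


# Compare the notebook's scale (10) with a much sharper one (1000).
for w in (0.0, 0.01, 0.1, 1.0):
    print(f"w={w:<5} scale=10: {surrogate(w):.4f}   "
          f"scale=1000: {surrogate(w, scale=1000.0):.4f}")
```

With a scale of 10, a weight of exactly zero still contributes about 0.5, so the sum is far from a count of non-zero weights. A larger scale gives a sharper step around the offset, at the cost of very small gradients for weights away from it, so the scale is worth tuning together with `l0_lambda`.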

In [12]:
import os
import copy # Needed for deep copying modules if necessary, though direct replacement might work
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from datasets.coco_eval import CocoEvaluator
from datasets import get_coco_api_from_dataset
from tqdm.notebook import tqdm
import numpy as np # For calculating number of parameters

# --- Dataset Loading and Preparation ---
# Custom COCO dataset that uses the DETR feature_extractor.
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, feature_extractor):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.feature_extractor = feature_extractor

    def __getitem__(self, idx):
        # Read in PIL image and target in COCO format
        try:
            img, target = super(CocoDetection, self).__getitem__(idx)
        except (OSError, FileNotFoundError) as e:
            print(f"Skipping image {self.ids[idx]} due to error: {e}")
            # Return None so that collate_fn can filter this sample out.
            # A more robust approach is to filter self.ids during init or use a sampler.
            return None

        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}

        # Preprocess image and target: conversion to DETR format, resizing, normalization, etc.
        # Add error handling for feature extractor if needed
        try:
            encoding = self.feature_extractor(images=img, annotations=target, return_tensors="pt")
            # Remove batch dimension added by feature_extractor
            pixel_values = encoding["pixel_values"].squeeze(0)
            # DETR returns 'labels' as a list inside the encoding when annotations are provided.
            target = encoding["labels"][0]  # remove batch dimension
        except Exception as e:
            print(f"Skipping image {image_id} due to feature extraction error: {e}")
            return None # Filter this out later

        return pixel_values, target

# Initialize the feature extractor.
# Use recommended settings for evaluation if needed, though default is often fine.
feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")

# Define paths for your COCO dataset. Adjust if necessary.
# Ensure these paths are correct and accessible in your environment.
img_folder = '/content/a/val2017' # Example path
ann_file = '/content/a/annotations/instances_val2017.json' # Example path

# --- Safety check for paths ---
if not os.path.isdir(img_folder):
    print(f"Warning: Image folder not found at {img_folder}")
    # Consider raising an error or exiting if data is essential
if not os.path.isfile(ann_file):
    print(f"Warning: Annotation file not found at {ann_file}")
    # Consider raising an error or exiting

# Create the dataset instance. Add error handling if paths are invalid.
try:
    dataset = CocoDetection(img_folder=img_folder,
                            ann_file=ann_file,
                            feature_extractor=feature_extractor)
    if len(dataset) == 0:
         print("Warning: Dataset loaded but is empty. Check paths and annotation file.")
except Exception as e:
    print(f"Error creating dataset: {e}. Please check paths and file integrity.")
    # Exit or handle error appropriately
    exit()

# Define a collate function to combine images/targets into a batch.
# Filter out None items resulting from dataset errors
def collate_fn(batch):
    batch = [item for item in batch if item is not None] # Filter out problematic samples
    if not batch: # If batch becomes empty after filtering
        return None

    pixel_values = [item[0] for item in batch]
    encoding = feature_extractor.pad(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    # Make sure every label dict has 'orig_size', which post-processing needs at evaluation time.
    # Fallback only: this uses the image size after feature extraction (before padding),
    # which is not the true original size. Prefer the real original size when available.
    for i, label_dict in enumerate(labels):
        if 'orig_size' not in label_dict:
            img_height, img_width = pixel_values[i].shape[1:]
            # Keep the new tensor on the same device as the other label tensors.
            label_dict['orig_size'] = torch.tensor([img_height, img_width], device=label_dict['image_id'].device)


    batch_dict = {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }
    return batch_dict

# Create DataLoader. (Using same dataset for train/eval for demo - use separate splits!)
# Reduced batch_size for potential memory constraints after loading model
# Use num_workers > 0 for faster data loading if possible, but can cause issues in some environments
# Handle potential empty dataset
if len(dataset) > 0:
    train_dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=2, shuffle=True, num_workers=0) # Start with 0 workers for stability
    eval_dataloader  = DataLoader(dataset, collate_fn=collate_fn, batch_size=2, num_workers=0)
else:
    print("Dataset is empty, cannot create DataLoaders.")
    train_dataloader = None
    eval_dataloader = None
    # Exit or handle appropriately
    exit()

# Initialize the evaluation helper (COCO API).
try:
    base_ds = get_coco_api_from_dataset(dataset) # Make sure dataset is valid coco
    iou_types = ['bbox']
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
except Exception as e:
    print(f"Error initializing CocoEvaluator: {e}. Ensure dataset is COCO format.")
    # Fallback or exit
    coco_evaluator = None
    exit()

# Load the DETR model.
print("Loading pre-trained DETR model...")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
print("Model loaded.")

# Set the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)

# --- Function to count parameters ---
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# --- One-Shot Pruning Implementation ---

def calculate_ffn_neuron_importance(model):
    """
    Calculates the L1 norm of weights connected *from* each intermediate neuron
    in the FFN layers (fc1 -> fc2) of the Transformer encoder and decoder.

    Returns:
        List of dicts with keys 'score', 'layer_type', 'layer_index', 'neuron_index', 'fc1', 'fc2',
        sorted by 'score' ascending (least important first).
    """
    importances = []
    model_base = model.model # Access the underlying DetrModel

    print("Calculating neuron importance scores...")
    with torch.no_grad():
        # Encoder FFNs
        for i, layer in enumerate(model_base.encoder.layers):
            fc1 = layer.fc1
            fc2 = layer.fc2
            # Importance = L1 norm of each column of fc2.weight (the connections FROM one intermediate neuron).
            # fc2.weight has shape (output_dim, intermediate_dim) = (d_model, ffn_dim).
            l1_norms = torch.linalg.norm(fc2.weight, ord=1, dim=0)  # reduce over rows -> one norm per column
            for j in range(l1_norms.shape[0]): # Iterate through intermediate neurons
                importances.append({
                    'score': l1_norms[j].item(),
                    'layer_type': 'encoder',
                    'layer_index': i,
                    'neuron_index': j, # Index in the intermediate dimension
                    'fc1': fc1,
                    'fc2': fc2
                })

        # Decoder FFNs
        for i, layer in enumerate(model_base.decoder.layers):
            fc1 = layer.fc1
            fc2 = layer.fc2
            # Importance = L1 norm of each column of fc2.weight (one per intermediate neuron)
            l1_norms = torch.linalg.norm(fc2.weight, ord=1, dim=0)
            for j in range(l1_norms.shape[0]): # Iterate through intermediate neurons
                importances.append({
                    'score': l1_norms[j].item(),
                    'layer_type': 'decoder',
                    'layer_index': i,
                    'neuron_index': j, # Index in the intermediate dimension
                    'fc1': fc1,
                    'fc2': fc2
                })

    # Sort by score (ascending)
    importances.sort(key=lambda x: x['score'])
    print(f"Calculated importance for {len(importances)} FFN neurons.")
    return importances

def prune_ffn_neurons(model, importances, pruning_ratio):
    """
    Prunes the FFN neurons based on the provided importance scores and ratio.
    This function PERMANENTLY removes neurons by reconstructing the layers.

    Args:
        model: The DETR model to prune.
        importances (list): Sorted list of neuron importances from calculate_ffn_neuron_importance.
        pruning_ratio (float): The fraction of neurons to prune (e.g., 0.3 for 30%).
    """
    if not 0 <= pruning_ratio < 1:
        raise ValueError("Pruning ratio must be between 0 (inclusive) and 1 (exclusive).")

    total_neurons = len(importances)
    num_to_prune = int(total_neurons * pruning_ratio)

    if num_to_prune == 0:
        print("Pruning ratio is too low, no neurons will be pruned.")
        return

    print(f"Targeting {num_to_prune} out of {total_neurons} FFN neurons for pruning ({pruning_ratio*100:.2f}%)...")

    # Get the details of neurons to prune
    neurons_to_prune = importances[:num_to_prune]

    # Group neurons to prune by the layer they belong to
    pruning_plan = {} # Key: (layer_type, layer_index), Value: list of neuron_indices to prune

    for neuron_info in neurons_to_prune:
        key = (neuron_info['layer_type'], neuron_info['layer_index'])
        if key not in pruning_plan:
            pruning_plan[key] = {
                'fc1': neuron_info['fc1'],
                'fc2': neuron_info['fc2'],
                'indices_to_prune': []
            }
        pruning_plan[key]['indices_to_prune'].append(neuron_info['neuron_index'])

    print(f"Pruning neurons across {len(pruning_plan)} FFN blocks.")

    # --- Perform the actual pruning by layer reconstruction ---
    model_base = model.model
    device = next(model.parameters()).device # Get device model is on

    with torch.no_grad():
        for (layer_type, layer_index), plan in tqdm(pruning_plan.items(), desc="Reconstructing Layers"):
            fc1_orig = plan['fc1']
            fc2_orig = plan['fc2']
            indices_to_prune_set = set(plan['indices_to_prune'])

            original_intermediate_dim = fc1_orig.out_features
            original_input_dim = fc1_orig.in_features # Should be model dim (d_model)
            original_output_dim = fc2_orig.out_features # Should be model dim (d_model)

            # Determine indices to KEEP
            indices_to_keep = sorted([i for i in range(original_intermediate_dim) if i not in indices_to_prune_set])
            new_intermediate_dim = len(indices_to_keep)

            if new_intermediate_dim == 0:
                print(f"Warning: Attempting to prune all neurons in {layer_type} layer {layer_index}. Skipping this layer.")
                continue # Avoid creating empty layers

            # Create new layers
            fc1_new = nn.Linear(original_input_dim, new_intermediate_dim, bias=fc1_orig.bias is not None).to(device)
            fc2_new = nn.Linear(new_intermediate_dim, original_output_dim, bias=fc2_orig.bias is not None).to(device)

            # Copy weights and biases for KEPT indices
            # fc1: Weight shape (intermediate_dim, input_dim), Bias shape (intermediate_dim)
            fc1_new.weight.data = fc1_orig.weight.data[indices_to_keep, :]
            if fc1_orig.bias is not None:
                fc1_new.bias.data = fc1_orig.bias.data[indices_to_keep]

            # fc2: Weight shape (output_dim, intermediate_dim), Bias shape (output_dim)
            fc2_new.weight.data = fc2_orig.weight.data[:, indices_to_keep]
            if fc2_orig.bias is not None:
                fc2_new.bias.data = fc2_orig.bias.data # fc2 bias is not affected by intermediate dim change

            # Replace original layers in the model structure
            if layer_type == 'encoder':
                model_base.encoder.layers[layer_index].fc1 = fc1_new
                model_base.encoder.layers[layer_index].fc2 = fc2_new
            elif layer_type == 'decoder':
                model_base.decoder.layers[layer_index].fc1 = fc1_new
                model_base.decoder.layers[layer_index].fc2 = fc2_new

    print("Pruning complete. Layers have been reconstructed.")


# --- Main Pruning Execution ---

# 1. Calculate Parameter Count Before Pruning
params_before = count_parameters(model)
print(f"Parameters before pruning: {params_before:,}")

# 2. Calculate Importance Scores
neuron_importances = calculate_ffn_neuron_importance(model)

# 3. Define Pruning Ratio and Prune
PRUNING_RATIO = 0.95 # Example: prune 95% of the least important FFN neurons globally
prune_ffn_neurons(model, neuron_importances, PRUNING_RATIO)

# 4. Calculate Parameter Count After Pruning
params_after = count_parameters(model)
print(f"Parameters after pruning:  {params_after:,}")
reduction = (params_before - params_after) / params_before * 100
print(f"Parameter reduction: {reduction:.2f}%")

# --- Fine-tuning After Pruning (Crucial Step!) ---
# After pruning, the model's accuracy will likely drop significantly.
# Fine-tuning on the target dataset is necessary to recover performance.

print("\n--- Starting Fine-tuning of the Pruned Model ---")

# Define optimizer for the pruned model (it has different parameters now)
# Use a smaller learning rate for fine-tuning
optimizer = optim.Adam(model.parameters(), lr=1e-5) # Example: Lower LR
num_finetune_epochs = 5 # Example: Fine-tune for a few epochs

for epoch in range(num_finetune_epochs):
    model.train()
    running_loss = 0.0

    # Ensure train_dataloader is not None
    if train_dataloader is None:
        print("Train dataloader is not available. Skipping fine-tuning.")
        break

    progress_bar = tqdm(train_dataloader, desc=f"Fine-tune Epoch {epoch+1}", leave=True)

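    batch_count = 0
    # NOTE: The training loop below is commented out, so no weights are updated and
    # every epoch reports the same evaluation metrics. Uncomment it to actually fine-tune.
    # for batch in progress_bar: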
    #     # Handle potentially empty batches from collate_fn filtering
    #     if batch is None:
    #         print("Skipping empty batch during fine-tuning.")
    #         continue

    #     pixel_values = batch["pixel_values"].to(device)
    #     pixel_mask = batch["pixel_mask"].to(device)
    #     # Move each target dictionary to the device
    #     labels = [{k: v.to(device) for k, v in target.items()} for target in batch["labels"]]

    #     optimizer.zero_grad()

    #     # Forward pass with detection loss
    #     outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
    #     detection_loss = outputs.loss

    #     # Check if loss is valid
    #     if torch.isnan(detection_loss) or torch.isinf(detection_loss):
    #         print(f"Warning: NaN or Inf loss detected at batch {batch_count}. Skipping batch.")
    #         # Potentially skip optimizer step or investigate further
    #         # Consider gradient clipping here if gradients explode
    #         continue


    #     # Backpropagation
    #     detection_loss.backward()

    #     # Optional: Gradient Clipping (can help stabilize fine-tuning after pruning)
    #     # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    #     optimizer.step()

    #     running_loss += detection_loss.item()
    #     batch_count += 1

    #     # Update the progress bar with the current loss
    #     progress_bar.set_postfix(loss=detection_loss.item())

    # if batch_count > 0:
    #    avg_loss = running_loss / batch_count
    #    print(f"Fine-tune Epoch {epoch+1}: Average Loss: {avg_loss:.4f}")
    # else:
    #    print(f"Fine-tune Epoch {epoch+1}: No valid batches processed.")


    # --- Evaluation after each fine-tuning epoch (Optional but Recommended) ---
    if eval_dataloader is not None and coco_evaluator is not None:
        model.eval()
        from datasets.coco_eval import CocoEvaluator # Import CocoEvaluator
        from datasets import get_coco_api_from_dataset

        base_ds = get_coco_api_from_dataset(test_dataset) # this is actually just calling the coco attribute
        iou_types = ['bbox']
        coco_evaluator = CocoEvaluator(base_ds, iou_types) # initialize evaluator with ground truths
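        eval_progress_bar = tqdm(eval_dataloader, desc="Evaluation", leave=False)
        # Debug limit: only the first 4 batches are evaluated, so the reported AP is a rough estimate.
        max_eval_batches = 4
        with torch.no_grad():
            for batch_idx, eval_batch in enumerate(eval_progress_bar):
                if batch_idx >= max_eval_batches:
                    break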
                if eval_batch is None:
                    print("Skipping empty batch during evaluation.")
                    continue

                pixel_values = eval_batch["pixel_values"].to(device)
                pixel_mask = eval_batch["pixel_mask"].to(device)
                # Ensure labels have 'orig_size' which post_process needs
                labels = [{k: v.to(device) for k, v in t.items()} for t in eval_batch["labels"]]
                # Make sure 'orig_size' exists and is on the correct device
                if not all('orig_size' in t for t in labels):
                     print("Warning: 'orig_size' missing in some evaluation labels. Skipping batch or using fallback.")
                     # Fallback example (use with caution):
                     # h, w = pixel_values.shape[-2:]
                     # dummy_size = torch.tensor([h, w], device=device)
                     # for t in labels: t['orig_size'] = t.get('orig_size', dummy_size)
                     continue # Safer to skip if orig_size is crucial and missing


                outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)

                # Ensure 'orig_size' is on CPU for post_process if needed, or handle device mismatch
                # The post_process function expects sizes as a tensor on the CPU typically.
                try:
                   orig_target_sizes = torch.stack([t["orig_size"].to("cpu") for t in labels], dim=0)
                   results = feature_extractor.post_process_object_detection(outputs, threshold=0.1, target_sizes=orig_target_sizes) # Use updated function name and add threshold
                except Exception as e:
                   print(f"Error during post-processing: {e}. Skipping batch.")
                   continue

                # Prepare results for COCO evaluator
                # Ensure image_id is correctly extracted and used as key
                res = {}
                for i, target in enumerate(labels):
                   img_id = target['image_id'].item() # Get Python int
                   res[img_id] = results[i] # Results is a list of dicts

                if res: # Only update if results were generated
                   coco_evaluator.update(res)

        # Synchronize, accumulate and summarize results after iterating through all eval batches
        try:
           coco_evaluator.synchronize_between_processes() # Important in distributed settings, safe otherwise
           coco_evaluator.accumulate()
           coco_evaluator.summarize() # Prints COCO mAP scores
        except Exception as e:
           print(f"Error during COCO evaluation summary: {e}")

    else:
        print("Evaluation dataloader or evaluator not available. Skipping evaluation.")


print("Fine-tuning complete.")

# You can now save the pruned and fine-tuned model
# torch.save(model.state_dict(), "detr_resnet50_pruned_finetuned.pth")


loading annotations into memory...
Done (t=0.76s)
creating index...
index created!
Loading pre-trained DETR model...


Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded.
Using device: cuda
Parameters before pruning: 41,302,368
Calculating neuron importance scores...
Calculated importance for 24576 FFN neurons.
Targeting 23347 out of 24576 FFN neurons for pruning (95.00%)...
Pruning neurons across 12 FFN blocks.


Reconstructing Layers:   0%|          | 0/12 [00:00<?, ?it/s]

Pruning complete. Layers have been reconstructed.
Parameters after pruning:  30,375,981
Parameter reduction: 26.45%

--- Starting Fine-tuning of the Pruned Model ---


Fine-tune Epoch 1:   0%|          | 0/2500 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.05s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.159
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.284
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.148
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.019
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.165
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.426
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.095
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.248
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.263
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.022
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.297
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

Fine-tune Epoch 2:   0%|          | 0/2500 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.05s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.159
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.284
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.148
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.019
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.165
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.426
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.095
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.248
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.263
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.022
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.297
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

Fine-tune Epoch 3:   0%|          | 0/2500 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.05s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.159
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.284
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.148
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.019
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.165
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.426
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.095
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.248
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.263
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.022
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.297
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

Fine-tune Epoch 4:   0%|          | 0/2500 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.06s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.159
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.284
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.148
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.019
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.165
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.426
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.095
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.248
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.263
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.022
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.297
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

Fine-tune Epoch 5:   0%|          | 0/2500 [00:00<?, ?it/s]

Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.07s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.159
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.284
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.148
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.019
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.165
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.426
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.095
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.248
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.263
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.022
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.297
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la
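
The parameter counts in the log above can be cross-checked with a little arithmetic. The sketch below assumes a model dimension of 256 for `facebook/detr-resnet-50` (an assumption, not printed in the log) and counts, per removed neuron, one row of `fc1` (weights plus bias) and one column of `fc2`. Under that assumption the drop in parameters corresponds to fewer removed neurons than the 23347 targeted, and the difference equals the size of one FFN block, which would be consistent with one block having been selected for complete removal and left unpruned by the `new_intermediate_dim == 0` guard.

```python
# Values copied from the log above.
params_before = 41_302_368
params_after = 30_375_981
targeted_neurons = 23_347
total_neurons = 24_576
num_ffn_blocks = 12

# Assumption: d_model = 256 for facebook/detr-resnet-50.
d_model = 256
neurons_per_block = total_neurons // num_ffn_blocks

# Removing one intermediate neuron drops one fc1 row (d_model weights + 1 bias)
# and one fc2 column (d_model weights).
params_per_neuron = d_model + 1 + d_model

removed_params = params_before - params_after
removed_neurons, remainder = divmod(removed_params, params_per_neuron)
skipped_neurons = targeted_neurons - removed_neurons

print(f"parameters removed:        {removed_params:,}")
print(f"neurons actually removed:  {removed_neurons} (remainder {remainder})")
print(f"targeted but not removed:  {skipped_neurons}")
print(f"neurons per FFN block:     {neurons_per_block}")
print(f"reduction: {removed_params / params_before * 100:.2f}%")
```

If the remainder is zero and the number of neurons that were targeted but not removed is a multiple of the block size, the log is fully explained by whole blocks being skipped.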

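The layer reconstruction used for pruning keeps selected rows of `fc1` and the matching columns of `fc2`. Here is a minimal sketch on a toy FFN of why this is equivalent to zeroing the pruned neurons. The layer sizes and the pruned indices are arbitrary values chosen for illustration, and ReLU is assumed as the non-linearity; none of this is taken from the DETR checkpoint.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, d_ff = 4, 6  # toy sizes, not DETR's
fc1 = nn.Linear(d_model, d_ff)
fc2 = nn.Linear(d_ff, d_model)

indices_to_prune = {1, 4}  # hypothetical choice for illustration
indices_to_keep = [i for i in range(d_ff) if i not in indices_to_prune]

# Build the smaller layers and copy the surviving rows/columns.
fc1_new = nn.Linear(d_model, len(indices_to_keep))
fc2_new = nn.Linear(len(indices_to_keep), d_model)
with torch.no_grad():
    fc1_new.weight.copy_(fc1.weight[indices_to_keep, :])  # rows = kept neurons
    fc1_new.bias.copy_(fc1.bias[indices_to_keep])
    fc2_new.weight.copy_(fc2.weight[:, indices_to_keep])  # columns = kept neurons
    fc2_new.bias.copy_(fc2.bias)                          # output bias is unchanged

x = torch.randn(3, d_model)
with torch.no_grad():
    hidden = torch.relu(fc1(x))
    # Reference: zero the pruned neurons instead of removing them.
    mask = torch.zeros(d_ff)
    mask[indices_to_keep] = 1.0
    masked_out = fc2(hidden * mask)
    pruned_out = fc2_new(torch.relu(fc1_new(x)))

same = torch.allclose(masked_out, pruned_out, atol=1e-6)
print("pruned FFN matches masked FFN:", same)
```

The pruned network is therefore exactly the original network with the removed neurons silenced. Any accuracy loss comes from silencing them, not from the reconstruction itself.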
In [4]:
%cd /content/detr/

import os
import copy # Needed for deep copying modules if necessary, though direct replacement might work
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from datasets.coco_eval import CocoEvaluator
from datasets import get_coco_api_from_dataset
from tqdm.notebook import tqdm
import numpy as np # For calculating number of parameters
from collections import defaultdict


# --- Dataset Loading and Preparation ---
# Custom COCO dataset that uses the DETR feature_extractor.
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, feature_extractor):
        super().__init__(img_folder, ann_file)
        self.feature_extractor = feature_extractor

    def __getitem__(self, idx):
        # Read in PIL image and target in COCO format
        try:
            img, target = super().__getitem__(idx)
        except (OSError, FileNotFoundError) as e:
            # Return None for unreadable images; collate_fn filters these out.
            print(f"Skipping image {self.ids[idx]} due to error: {e}")
            return None

        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}

        # Preprocess image and target: conversion to DETR format, resizing, normalization, etc.
        try:
            encoding = self.feature_extractor(images=img, annotations=target, return_tensors="pt")
            # Remove batch dimension added by feature_extractor
            pixel_values = encoding["pixel_values"].squeeze(0)
            # 'labels' is a list with one entry per image when annotations are provided.
            target = encoding["labels"][0]  # remove batch dimension
        except Exception as e:
            print(f"Skipping image {image_id} due to feature extraction error: {e}")
            return None # Filtered out in collate_fn

        return pixel_values, target

# Initialize the feature extractor with the preprocessing config of the pretrained checkpoint.
feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")

# Dataset instance used by the evaluation code (created only after the class is defined).
test_dataset = CocoDetection(img_folder='/content/a/val2017',
                             ann_file='/content/a/annotations/instances_val2017.json',
                             feature_extractor=feature_extractor)

# Define paths for your COCO dataset. Adjust if necessary.
# Ensure these paths are correct and accessible in your environment.
img_folder = '/content/a/val2017' # Example path
ann_file = '/content/a/annotations/instances_val2017.json' # Example path

# --- Safety check for paths ---
if not os.path.isdir(img_folder):
    print(f"Warning: Image folder not found at {img_folder}")
    # Consider raising an error or exiting if data is essential
if not os.path.isfile(ann_file):
    print(f"Warning: Annotation file not found at {ann_file}")
    # Consider raising an error or exiting

# Create the dataset instance. Add error handling if paths are invalid.
try:
    dataset = CocoDetection(img_folder=img_folder,
                            ann_file=ann_file,
                            feature_extractor=feature_extractor)
    if len(dataset) == 0:
         print("Warning: Dataset loaded but is empty. Check paths and annotation file.")
except Exception as e:
    # Re-raise so the notebook stops here instead of continuing without a dataset.
    raise RuntimeError(f"Error creating dataset: {e}. Please check paths and file integrity.") from e

# Define a collate function to combine images/targets into a batch.
# Filter out None items resulting from dataset errors
def collate_fn(batch):
    batch = [item for item in batch if item is not None] # Filter out problematic samples
    if not batch: # If batch becomes empty after filtering
        return None

    pixel_values = [item[0] for item in batch]
    encoding = feature_extractor.pad(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    # Fallback: if a label dict lacks 'orig_size' (needed by evaluation post-processing),
    # use the size of the image after feature extraction. This is NOT the true original size,
    # so predicted boxes would be scaled to the resized image; prefer the real original size.
    for i, label_dict in enumerate(labels):
        if 'orig_size' not in label_dict:
            img_height, img_width = pixel_values[i].shape[1:] # (C, H, W) after feature extraction
            label_dict['orig_size'] = torch.tensor([img_height, img_width], device=label_dict['image_id'].device) # Ensure device consistency


    batch_dict = {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }
    return batch_dict

# Create DataLoader. (Using same dataset for train/eval for demo - use separate splits!)
# Reduced batch_size for potential memory constraints after loading model
# Use num_workers > 0 for faster data loading if possible, but can cause issues in some environments
# Handle potential empty dataset
if len(dataset) > 0:
    train_dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=2, shuffle=True, num_workers=0) # Start with 0 workers for stability
    eval_dataloader  = DataLoader(dataset, collate_fn=collate_fn, batch_size=2, num_workers=0)
else:
    print("Dataset is empty, cannot create DataLoaders.")
    train_dataloader = None
    eval_dataloader = None
    # Exit or handle appropriately
    exit()

# Initialize the evaluation helper (COCO API).
try:
    base_ds = get_coco_api_from_dataset(dataset) # Make sure dataset is valid coco
    iou_types = ['bbox']
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
except Exception as e:
    # Re-raise so the notebook stops here instead of continuing without an evaluator.
    raise RuntimeError(f"Error initializing CocoEvaluator: {e}. Ensure dataset is COCO format.") from e

# Load the DETR model.
print("Loading pre-trained DETR model...")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
print("Model loaded.")

# Set the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)

# --- Function to count parameters ---
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# --- One-Shot Pruning Implementation ---

def get_ffn_intermediate_activations(model, calibration_loader, device):
    """
    Runs data through the model and collects intermediate activations from FFN layers using hooks.
    Moves collected activations to CPU to save GPU memory.

    Returns:
        dict: Key: (layer_type, layer_index), Value: List of activation tensors [on CPU] for intermediate neurons
    """
    model.eval()
    activations = defaultdict(list)
    hooks = []

    def get_activation_hook(name):
        def hook(module, inputs, output):
            # Output of fc1, i.e. the pre-activation values (the FFN non-linearity has not been applied yet).
            # Shape: (batch_size, seq_len, intermediate_dim)
            activations[name].append(output.detach().cpu()) # Move to CPU immediately
        return hook

    # Register hooks
    model_base = model.model
    # Determine where the intermediate activation occurs. Usually after fc1.
    # Let's assume hooking fc1's output is sufficient. If an explicit activation module exists (e.g., nn.ReLU), hook that instead.
    try:
        for i, layer in enumerate(model_base.encoder.layers):
            # Hook the output of the first linear layer in the FFN
            hook_handle = layer.fc1.register_forward_hook(get_activation_hook(('encoder', i)))
            hooks.append(hook_handle)
        print(f"Registered {len(hooks)} hooks for encoder FFN layers.")

        decoder_hooks_start_index = len(hooks)
        for i, layer in enumerate(model_base.decoder.layers):
            # Hook the output of the first linear layer in the FFN
            hook_handle = layer.fc1.register_forward_hook(get_activation_hook(('decoder', i)))
            hooks.append(hook_handle)
        print(f"Registered {len(hooks) - decoder_hooks_start_index} hooks for decoder FFN layers.")

    except AttributeError as e:
        print(f"Error registering hooks: {e}. Check layer naming (fc1) in the model structure.")
        # Clean up any hooks that were successfully registered before the error
        for handle in hooks:
            handle.remove()
        return None # Indicate failure

    # Debug limit: only the first 9 batches are used for calibration.
    max_calibration_batches = 9
    print(f"Running calibration data (up to {max_calibration_batches} of {len(calibration_loader)} batches) to collect activations...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(calibration_loader, desc="Calibration")):
            if batch_idx >= max_calibration_batches:
                break
            if batch is None:
                continue
            # Ensure all required inputs are moved to the device
            pixel_values = batch["pixel_values"].to(device)
            pixel_mask = batch["pixel_mask"].to(device)
            # No labels needed for activation collection
            _ = model(pixel_values=pixel_values, pixel_mask=pixel_mask)

    # Remove hooks AFTER the loop finishes
    for handle in hooks:
        handle.remove()

    print("Activation collection complete.")
    if not activations:
        print("Warning: No activations were collected. Check hook registration and model structure.")
    return activations


def calculate_ffn_neuron_importance_activation(model, calibration_loader, device):
    """
    Calculates importance based on the average L1 magnitude of activations
    of intermediate neurons in FFN layers, handling varying sequence lengths.

    Returns:
        List of dicts: [{'score': score, 'layer_type': ..., 'layer_index': ..., 'neuron_index': ..., 'fc1': ..., 'fc2': ...}, ...]
                       Sorted by importance_score ascending (least important first).
    """
    # 1. Collect activations
    if calibration_loader is None:
        print("Error: Calibration dataloader is required for activation-based importance.")
        return []
    collected_activations = get_ffn_intermediate_activations(model, calibration_loader, device)
    if collected_activations is None or not collected_activations:
         print("Error: Failed to collect activations.")
         return []

    # 2. Calculate scores by aggregating stats batch-by-batch per layer
    importances = []
    model_base = model.model

    print("Calculating importance scores from activations...")
    # Process collected activations layer by layer
    # collected_activations format: {(layer_type, layer_index): [batch1_acts_cpu, batch2_acts_cpu, ...]}
    for (layer_type, layer_index), activation_list in tqdm(collected_activations.items(), desc="Processing Layers"):
        if not activation_list:
            print(f"Warning: No activations found for layer {layer_type} {layer_index}. Skipping.")
            continue

        # Initialize accumulators for this layer
        total_sum_abs_acts_per_neuron = None
        total_samples = 0
        intermediate_dim = -1

        # Iterate through activations from each batch for this specific layer
        for batch_acts_cpu in activation_list:
            # batch_acts_cpu shape is (batch_size, seq_len, intermediate_dim)
            if batch_acts_cpu.dim() != 3:
                print(f"Warning: Unexpected activation dimension {batch_acts_cpu.dim()} in batch for {layer_type} {layer_index}. Skipping this batch.")
                continue

            current_intermediate_dim = batch_acts_cpu.shape[-1]

            # Initialize sum tensor on first valid batch for this layer
            if total_sum_abs_acts_per_neuron is None:
                intermediate_dim = current_intermediate_dim
                # Initialize on CPU, matches the batch_acts_cpu device
                total_sum_abs_acts_per_neuron = torch.zeros(intermediate_dim, device=batch_acts_cpu.device)
            # Ensure consistency across batches for the same layer
            elif intermediate_dim != current_intermediate_dim:
                 print(f"Warning: Inconsistent intermediate dimension found within layer {layer_type} {layer_index}. Expected {intermediate_dim}, got {current_intermediate_dim}. Skipping this batch.")
                 continue

            # Calculate sum of absolute values per neuron for this batch
            # Sum over batch and sequence length dimensions (0 and 1)
            sum_abs_batch = torch.sum(torch.abs(batch_acts_cpu), dim=(0, 1)) # Shape: (intermediate_dim,)

            # Accumulate sums (already on CPU)
            total_sum_abs_acts_per_neuron += sum_abs_batch

            # Accumulate count of samples (batch_size * seq_len)
            total_samples += batch_acts_cpu.shape[0] * batch_acts_cpu.shape[1]

        # Check if we collected any valid data for this layer after iterating through batches
        if total_samples == 0 or total_sum_abs_acts_per_neuron is None:
            print(f"Warning: No valid activation samples aggregated for layer {layer_type} {layer_index}. Skipping.")
            continue

        # Calculate average magnitude per neuron for this layer
        avg_mag_per_neuron = total_sum_abs_acts_per_neuron / total_samples

        # Retrieve corresponding layers (fc1, fc2)
        try:
            if layer_type == 'encoder':
                layer = model_base.encoder.layers[layer_index]
            elif layer_type == 'decoder':
                layer = model_base.decoder.layers[layer_index]
            else:
                print(f"Warning: Unknown layer_type '{layer_type}'. Skipping.")
                continue
            fc1 = layer.fc1
            fc2 = layer.fc2
        except (AttributeError, IndexError) as e:
             print(f"Warning: Could not retrieve fc1/fc2 for {layer_type} layer {layer_index}. Error: {e}. Skipping.")
             continue


        # Dimension verification (crucial!)
        if fc1.out_features != intermediate_dim or fc2.in_features != intermediate_dim:
            print(f"CRITICAL Warning: Dimension mismatch after aggregation for {layer_type} layer {layer_index}.")
            print(f"  Aggregated Activation dim: {intermediate_dim}")
            print(f"  fc1 out_features: {fc1.out_features}")
            print(f"  fc2 in_features: {fc2.in_features}")
            print("  This suggests a potential issue in hook placement or model understanding. Skipping layer.")
            continue

        # Store importance scores for each neuron in this layer
        for j in range(intermediate_dim):
            importances.append({
                'score': avg_mag_per_neuron[j].item(), # Lower score = less important
                'layer_type': layer_type,
                'layer_index': layer_index,
                'neuron_index': j,
                'fc1': fc1, # Store reference to the actual layer
                'fc2': fc2  # Store reference to the actual layer
            })

    # Sort all collected importances by score (ascending)
    importances.sort(key=lambda x: x['score'])
    print(f"Calculated activation-based importance for {len(importances)} FFN neurons across all layers.")

    if not importances:
        print("Warning: No importance scores were calculated. Pruning cannot proceed.")

    return importances

def prune_ffn_neurons(model, importances, pruning_ratio):
    """
    Prunes the FFN neurons based on the provided importance scores and ratio.
    This function PERMANENTLY removes neurons by reconstructing the layers.

    Args:
        model: The DETR model to prune.
        importances (list): Sorted list of neuron importances from calculate_ffn_neuron_importance.
        pruning_ratio (float): The fraction of neurons to prune (e.g., 0.3 for 30%).
    """
    if not 0 <= pruning_ratio < 1:
        raise ValueError("Pruning ratio must be between 0 (inclusive) and 1 (exclusive).")

    total_neurons = len(importances)
    num_to_prune = int(total_neurons * pruning_ratio)

    if num_to_prune == 0:
        print("Pruning ratio is too low, no neurons will be pruned.")
        return

    print(f"Targeting {num_to_prune} out of {total_neurons} FFN neurons for pruning ({pruning_ratio*100:.2f}%)...")

    # Get the details of neurons to prune
    neurons_to_prune = importances[:num_to_prune]

    # Group neurons to prune by the layer they belong to
    pruning_plan = {} # Key: (layer_type, layer_index), Value: list of neuron_indices to prune

    for neuron_info in neurons_to_prune:
        key = (neuron_info['layer_type'], neuron_info['layer_index'])
        if key not in pruning_plan:
            pruning_plan[key] = {
                'fc1': neuron_info['fc1'],
                'fc2': neuron_info['fc2'],
                'indices_to_prune': []
            }
        pruning_plan[key]['indices_to_prune'].append(neuron_info['neuron_index'])

    print(f"Pruning neurons across {len(pruning_plan)} FFN blocks.")

    # --- Perform the actual pruning by layer reconstruction ---
    model_base = model.model
    device = next(model.parameters()).device # Get device model is on

    with torch.no_grad():
        for (layer_type, layer_index), plan in tqdm(pruning_plan.items(), desc="Reconstructing Layers"):
            fc1_orig = plan['fc1']
            fc2_orig = plan['fc2']
            indices_to_prune_set = set(plan['indices_to_prune'])

            original_intermediate_dim = fc1_orig.out_features
            original_input_dim = fc1_orig.in_features # Should be model dim (d_model)
            original_output_dim = fc2_orig.out_features # Should be model dim (d_model)

            # Determine indices to KEEP
            indices_to_keep = sorted([i for i in range(original_intermediate_dim) if i not in indices_to_prune_set])
            new_intermediate_dim = len(indices_to_keep)

            if new_intermediate_dim == 0:
                print(f"Warning: Attempting to prune all neurons in {layer_type} layer {layer_index}. Skipping this layer.")
                continue # Avoid creating empty layers

            # Create new layers
            fc1_new = nn.Linear(original_input_dim, new_intermediate_dim, bias=fc1_orig.bias is not None).to(device)
            fc2_new = nn.Linear(new_intermediate_dim, original_output_dim, bias=fc2_orig.bias is not None).to(device)

            # Copy weights and biases for KEPT indices
            # fc1: Weight shape (intermediate_dim, input_dim), Bias shape (intermediate_dim)
            fc1_new.weight.data = fc1_orig.weight.data[indices_to_keep, :]
            if fc1_orig.bias is not None:
                fc1_new.bias.data = fc1_orig.bias.data[indices_to_keep]

            # fc2: Weight shape (output_dim, intermediate_dim), Bias shape (output_dim)
            fc2_new.weight.data = fc2_orig.weight.data[:, indices_to_keep]
            if fc2_orig.bias is not None:
                fc2_new.bias.data = fc2_orig.bias.data # fc2 bias is not affected by intermediate dim change

            # Replace original layers in the model structure
            if layer_type == 'encoder':
                model_base.encoder.layers[layer_index].fc1 = fc1_new
                model_base.encoder.layers[layer_index].fc2 = fc2_new
            elif layer_type == 'decoder':
                model_base.decoder.layers[layer_index].fc1 = fc1_new
                model_base.decoder.layers[layer_index].fc2 = fc2_new

    print("Pruning complete. Layers have been reconstructed.")


# --- Main Pruning Execution ---

# 1. Calculate Parameter Count Before Pruning
params_before = count_parameters(model)
print(f"Parameters before pruning: {params_before:,}")

calibration_loader = eval_dataloader # Or create a smaller subset if eval is too large
# 2. Calculate Importance Scores
# neuron_importances = calculate_ffn_neuron_importance(model)
neuron_importances = calculate_ffn_neuron_importance_activation(model, calibration_loader, device)

# 3. Define Pruning Ratio and Prune
PRUNING_RATIO = 0.95 # Prune 95% of the least important FFN neurons globally (use e.g. 0.3 for a gentler 30%)
prune_ffn_neurons(model, neuron_importances, PRUNING_RATIO)

# 4. Calculate Parameter Count After Pruning
params_after = count_parameters(model)
print(f"Parameters after pruning:  {params_after:,}")
reduction = (params_before - params_after) / params_before * 100
print(f"Parameter reduction: {reduction:.2f}%")

# --- Fine-tuning After Pruning (Crucial Step!) ---
# After pruning, the model's accuracy will likely drop significantly.
# Fine-tuning on the target dataset is necessary to recover performance.

print("\n--- Starting Fine-tuning of the Pruned Model ---")

# Define optimizer for the pruned model (it has different parameters now)
# Use a smaller learning rate for fine-tuning
optimizer = optim.Adam(model.parameters(), lr=1e-5) # Example: Lower LR
num_finetune_epochs = 5 # Example: Fine-tune for a few epochs

for epoch in range(num_finetune_epochs):
    model.train()
    running_loss = 0.0

    # Ensure train_dataloader is not None
    if train_dataloader is None:
        print("Train dataloader is not available. Skipping fine-tuning.")
        break

    progress_bar = tqdm(train_dataloader, desc=f"Fine-tune Epoch {epoch+1}", leave=True)

    batch_count = 0
    for batch in progress_bar:
        # Handle potentially empty batches from collate_fn filtering
        if batch is None:
            print("Skipping empty batch during fine-tuning.")
            continue

        pixel_values = batch["pixel_values"].to(device)
        pixel_mask = batch["pixel_mask"].to(device)
        # Move each target dictionary to the device
        labels = [{k: v.to(device) for k, v in target.items()} for target in batch["labels"]]

        optimizer.zero_grad()

        # Forward pass with detection loss
        outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
        detection_loss = outputs.loss

        # Check if loss is valid
        if torch.isnan(detection_loss) or torch.isinf(detection_loss):
            print(f"Warning: NaN or Inf loss detected at batch {batch_count}. Skipping batch.")
            # Potentially skip optimizer step or investigate further
            # Consider gradient clipping here if gradients explode
            continue


        # Backpropagation
        detection_loss.backward()

        # Optional: Gradient Clipping (can help stabilize fine-tuning after pruning)
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        running_loss += detection_loss.item()
        batch_count += 1

        # Update the progress bar with the current loss
        progress_bar.set_postfix(loss=detection_loss.item())

    if batch_count > 0:
        avg_loss = running_loss / batch_count
        print(f"Fine-tune Epoch {epoch+1}: Average Loss: {avg_loss:.4f}")
    else:
        print(f"Fine-tune Epoch {epoch+1}: No valid batches processed.")


    # --- Evaluation after each fine-tuning epoch (Optional but Recommended) ---
    if eval_dataloader is not None and coco_evaluator is not None:
        model.eval()
        from datasets.coco_eval import CocoEvaluator # Import CocoEvaluator
        from datasets import get_coco_api_from_dataset

        base_ds = get_coco_api_from_dataset(test_dataset) # this is actually just calling the coco attribute
        iou_types = ['bbox']
        coco_evaluator = CocoEvaluator(base_ds, iou_types) # initialize evaluator with ground truths
        eval_progress_bar = tqdm(eval_dataloader, desc="Evaluation", leave=False)
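        # Debug cap: the loop breaks when the counter reaches 0, so only the
        # first 4 batches are evaluated. Remove the cap for a full evaluation.
        eval_batches_left = 5
        with torch.no_grad():
            for eval_batch in eval_progress_bar:
                eval_batches_left -= 1
                if eval_batches_left == 0:
                    break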
                if eval_batch is None:
                    print("Skipping empty batch during evaluation.")
                    continue

                pixel_values = eval_batch["pixel_values"].to(device)
                pixel_mask = eval_batch["pixel_mask"].to(device)
                # Ensure labels have 'orig_size' which post_process needs
                labels = [{k: v.to(device) for k, v in t.items()} for t in eval_batch["labels"]]
                # Make sure 'orig_size' exists and is on the correct device
                if not all('orig_size' in t for t in labels):
                    print("Warning: 'orig_size' missing in some evaluation labels. Skipping batch or using fallback.")
                    # Fallback example (use with caution):
                    # h, w = pixel_values.shape[-2:]
                    # dummy_size = torch.tensor([h, w], device=device)
                    # for t in labels: t['orig_size'] = t.get('orig_size', dummy_size)
                    continue # Safer to skip if orig_size is crucial and missing


                outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)

                # Ensure 'orig_size' is on CPU for post_process if needed, or handle device mismatch
                # The post_process function expects sizes as a tensor on the CPU typically.
                try:
                    orig_target_sizes = torch.stack([t["orig_size"].to("cpu") for t in labels], dim=0)
                    results = feature_extractor.post_process_object_detection(outputs, threshold=0.1, target_sizes=orig_target_sizes) # Use updated function name and add threshold
                except Exception as e:
                    print(f"Error during post-processing: {e}. Skipping batch.")
                    continue

                # Prepare results for COCO evaluator
                # Ensure image_id is correctly extracted and used as key
                res = {}
                for i, target in enumerate(labels):
                    img_id = target['image_id'].item() # Get Python int
                    res[img_id] = results[i] # Results is a list of dicts

                if res: # Only update if results were generated
                    coco_evaluator.update(res)

        # Synchronize, accumulate and summarize results after iterating through all eval batches
        try:
            coco_evaluator.synchronize_between_processes() # Important in distributed settings, safe otherwise
            coco_evaluator.accumulate()
            coco_evaluator.summarize() # Prints COCO mAP scores
        except Exception as e:
            print(f"Error during COCO evaluation summary: {e}")

    else:
        print("Evaluation dataloader or evaluator not available. Skipping evaluation.")


print("Fine-tuning complete.")

# You can now save the pruned and fine-tuned model
# torch.save(model.state_dict(), "detr_resnet50_pruned_finetuned.pth")


/content/detr
loading annotations into memory...
Done (t=1.88s)
creating index...
index created!
loading annotations into memory...
Done (t=0.80s)
creating index...
index created!
Loading pre-trained DETR model...


Some weights of the model checkpoint at facebook/detr-resnet-50 were not used when initializing DetrForObjectDetection: ['model.backbone.conv_encoder.model.layer1.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer2.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer3.0.downsample.1.num_batches_tracked', 'model.backbone.conv_encoder.model.layer4.0.downsample.1.num_batches_tracked']
- This IS expected if you are initializing DetrForObjectDetection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DetrForObjectDetection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded.
Using device: cuda
Parameters before pruning: 41,302,368
Registered 6 hooks for encoder FFN layers.
Registered 6 hooks for decoder FFN layers.
Running calibration data (2500 batches) to collect activations...


Calibration:   0%|          | 0/2500 [00:00<?, ?it/s]

Activation collection complete.
Calculating importance scores from activations...


Processing Layers:   0%|          | 0/12 [00:00<?, ?it/s]

Calculated activation-based importance for 24576 FFN neurons across all layers.
Targeting 23347 out of 24576 FFN neurons for pruning (95.00%)...
Pruning neurons across 12 FFN blocks.


Reconstructing Layers:   0%|          | 0/12 [00:00<?, ?it/s]

Pruning complete. Layers have been reconstructed.
Parameters after pruning:  29,325,357
Parameter reduction: 29.00%

--- Starting Fine-tuning of the Pruned Model ---


Fine-tune Epoch 1:   0%|          | 0/2500 [00:00<?, ?it/s]

Fine-tune Epoch 1: Average Loss: 2.0663


Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.06s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.218
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.374
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.195
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.044
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.157
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.513
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.166
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.296
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.298
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.061
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.244
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

Fine-tune Epoch 2:   0%|          | 0/2500 [00:00<?, ?it/s]

Fine-tune Epoch 2: Average Loss: 1.9150


Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.10s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.232
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.421
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.182
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.081
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.223
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.583
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.188
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.299
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.310
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.081
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.294
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

Fine-tune Epoch 3:   0%|          | 0/2500 [00:00<?, ?it/s]

Fine-tune Epoch 3: Average Loss: 1.8318


Evaluation:   0%|          | 0/2500 [00:00<?, ?it/s]

Accumulating evaluation results...
DONE (t=0.06s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.317
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.532
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.269
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.138
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.351
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.637
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.253
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.349
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.363
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.159
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.358
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

Fine-tune Epoch 4:   0%|          | 0/2500 [00:00<?, ?it/s]

KeyboardInterrupt: 
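
As a sanity check on the pruning step, here is a minimal sketch (not part of the original notebook) of the same row/column slicing that `prune_ffn_neurons` performs, applied to a toy `fc1`/`fc2` pair. The helper names `prune_ffn_pair` and `count_params`, the toy sizes, and the ReLU activation are assumptions made for illustration. Hidden units that are forced to output zero can be removed without changing the FFN output, and each removed unit frees `2 * d_model + 1` parameters (one `fc1` row, one `fc1` bias, one `fc2` column). With `d_model = 256`, that accounting matches the parameter counts printed in the log above.

```python
import torch
import torch.nn as nn

def prune_ffn_pair(fc1, fc2, keep):
    """Return new (fc1, fc2) layers that keep only the hidden units in `keep`."""
    new_fc1 = nn.Linear(fc1.in_features, len(keep), bias=fc1.bias is not None)
    new_fc2 = nn.Linear(len(keep), fc2.out_features, bias=fc2.bias is not None)
    with torch.no_grad():
        new_fc1.weight.copy_(fc1.weight[keep, :])  # rows of fc1 are hidden units
        if fc1.bias is not None:
            new_fc1.bias.copy_(fc1.bias[keep])
        new_fc2.weight.copy_(fc2.weight[:, keep])  # columns of fc2 are hidden units
        if fc2.bias is not None:
            new_fc2.bias.copy_(fc2.bias)           # output bias is unaffected
    return new_fc1, new_fc2

def count_params(*modules):
    return sum(p.numel() for m in modules for p in m.parameters())

torch.manual_seed(0)
d_model, d_ff = 8, 32  # toy sizes, much smaller than DETR's
fc1, fc2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)

# Force the first 16 hidden units to output exactly zero after ReLU,
# so that removing them is lossless.
dead = list(range(16))
with torch.no_grad():
    fc1.weight[dead] = 0.0
    fc1.bias[dead] = 0.0

dead_set = set(dead)
keep = [i for i in range(d_ff) if i not in dead_set]
x = torch.randn(4, d_model)
with torch.no_grad():
    before = fc2(torch.relu(fc1(x)))
    p_fc1, p_fc2 = prune_ffn_pair(fc1, fc2, keep)
    after = p_fc2(torch.relu(p_fc1(x)))

removed = count_params(fc1, fc2) - count_params(p_fc1, p_fc2)
print("outputs match:", torch.allclose(before, after, atol=1e-5))
print("params removed per unit:", removed // len(dead))

# Same accounting applied to the numbers in the log (d_model = 256).
log_removed = 41_302_368 - 29_325_357
print("log is consistent:", log_removed == 23_347 * (2 * 256 + 1))
```

In the real run the pruned units are not exactly dead, so the output does change, which is why fine-tuning follows. The check also shows why removing 95% of the FFN neurons only yields a 29% parameter reduction: the backbone, attention, and head parameters are untouched.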

Finally, we accumulate the metrics and print a summary. This is the output I got:


```
Accumulating evaluation results...
DONE (t=0.44s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.459
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.640
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.485
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.253
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.530
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.729
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.359
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.566
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.585
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.325
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.618
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.845
 ```



In [None]:
coco_evaluator.synchronize_between_processes()
coco_evaluator.accumulate()
coco_evaluator.summarize()

Accumulating evaluation results...
DONE (t=0.45s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.477
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.670
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.500
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.263
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.517
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.712
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.364
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.578
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.604
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.340
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.625
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= la

In [13]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from transformers import DetrFeatureExtractor, DetrForObjectDetection
from datasets.coco_eval import CocoEvaluator
from datasets import get_coco_api_from_dataset
from tqdm.notebook import tqdm

# Custom COCO dataset that uses the DETR feature_extractor.
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, feature_extractor):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.feature_extractor = feature_extractor

    def __getitem__(self, idx):
        # Read in PIL image and target in COCO format
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}

        # Preprocess image and target: conversion to DETR format, resizing, normalization, etc.
        encoding = self.feature_extractor(images=img, annotations=target, return_tensors="pt")
        # Remove batch dimension added by feature_extractor
        pixel_values = encoding["pixel_values"].squeeze(0)
        # DETR returns 'labels' as a list inside the encoding when annotations are provided.
        target = encoding["labels"][0]  # remove batch dimension

        return pixel_values, target

# Initialize the feature extractor.
feature_extractor = DetrFeatureExtractor()

# Define paths for your COCO dataset
img_folder = '/content/a/val2017'
ann_file = '/content/a/annotations/instances_val2017.json'

# Create the dataset instance.
dataset = CocoDetection(img_folder=img_folder,
                        ann_file=ann_file,
                        feature_extractor=feature_extractor)

# Define a collate function to combine images/targets into a batch.
def collate_fn(batch):
    pixel_values = [item[0] for item in batch]
    encoding = feature_extractor.pad(pixel_values, return_tensors="pt")
    labels = [item[1] for item in batch]
    batch = {
        'pixel_values': encoding['pixel_values'],
        'pixel_mask': encoding['pixel_mask'],
        'labels': labels,
    }
    return batch

# Create DataLoader. (Here we use the same dataset for demonstration.
# In practice, use a separate training split.)
train_dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=2, shuffle=True)
eval_dataloader  = DataLoader(dataset, collate_fn=collate_fn, batch_size=2)

# Initialize the evaluation helper (COCO API)
base_ds = get_coco_api_from_dataset(dataset)
iou_types = ['bbox']
coco_evaluator = CocoEvaluator(base_ds, iou_types)

# Load the DETR model; note that the transformer layers are inside model.model.encoder/decoder.
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

# Define device and transfer the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

################################################################################
# Fine-tuning Setup: Define L0 Regularization on FFN (fc1 and fc2) layers.
################################################################################

# Set the regularization strength.
l0_lambda = 1e-4  # Tune this hyperparameter as needed.

def approximate_l0(weight):
    """
    A smooth-sparsity surrogate for the L0 norm.
    Here we use a sigmoid on a scaled version of the absolute weight values.
    This is only a heuristic. For a more principled L0 loss, consider the hard concrete formulation.
    """
    # Adjust the scaling factor and offset as desired.
    return torch.sum(torch.sigmoid(10 * (torch.abs(weight) - 0.001)))

def compute_l0_loss(model, l0_lambda):
    """
    Computes the total L0 loss for the FFN layers (fc1 & fc2)
    of both the encoder and the decoder.

    In the provided DETR architecture:
      - Encoder layers are found in model.model.encoder.layers.
      - Decoder layers are found in model.model.decoder.layers.
      - Each transformer layer has two feed-forward submodules: fc1 and fc2.
    """
    l0_loss = 0.0
    # Encoder FFN layers.
    for layer in model.model.encoder.layers:
        for fc_layer in [layer.fc1, layer.fc2]:
            l0_loss += l0_lambda * approximate_l0(fc_layer.weight)
    # Decoder FFN layers.
    for layer in model.model.decoder.layers:
        for fc_layer in [layer.fc1, layer.fc2]:
            l0_loss += l0_lambda * approximate_l0(fc_layer.weight)
    return l0_loss

################################################################################
# Optimization Setup & Training Loop
################################################################################

# Use Adam optimizer; you may consider grouping parameters if needed.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Number of training epochs (adjust as needed)
num_epochs = 3

print("Starting fine-tuning with L0 regularization...")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Loop over training batches.
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        pixel_values = batch["pixel_values"].to(device)
        pixel_mask = batch["pixel_mask"].to(device)
        # Convert each target dictionary to device.
        labels = [{k: v.to(device) for k, v in target.items()} for target in batch["labels"]]

        # Forward pass:
        # When providing labels the model returns a dict that includes a combined detection loss.
        outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
        detection_loss = outputs.loss  # This is the sum (or weighted sum) of detection losses.

        # Compute the L0 loss for FFN layers.
        l0_loss = compute_l0_loss(model, l0_lambda)

        # Total loss is the sum of detection and sparsity (L0) losses.
        total_loss = detection_loss + l0_loss

        # Backpropagation and optimization step.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}: Average Loss: {avg_loss:.4f}")

    # (Optional) Evaluate on the validation dataset after each epoch.
    model.eval()
    coco_evaluator.reset()
    with torch.no_grad():
        for idx, eval_batch in enumerate(tqdm(eval_dataloader, desc="Evaluation")):
            # You might choose to limit the number of batches during evaluation.
            pixel_values = eval_batch["pixel_values"].to(device)
            pixel_mask = eval_batch["pixel_mask"].to(device)
            labels = [{k: v.to(device) for k, v in t.items()} for t in eval_batch["labels"]]

            outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
            # DETR post-processing converts outputs to COCO-format predictions.
            orig_target_sizes = torch.stack([t["orig_size"] for t in labels], dim=0)
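            # post_process is deprecated; threshold=0.0 keeps every query, as post_process did.
            results = feature_extractor.post_process_object_detection(outputs, threshold=0.0, target_sizes=orig_target_sizes)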
            # Create a results dictionary mapping image_id to outputs.
            res = {target['image_id'].item(): output for target, output in zip(labels, results)}
            coco_evaluator.update(res)

    # (Optional) You can compute and print evaluation metrics here after epoch completion.
    coco_evaluator.synchronize_between_processes()
    coco_evaluator.accumulate()
    coco_evaluator.summarize()

print("Training complete.")


loading annotations into memory...
Done (t=0.45s)
creating index...
index created!




Starting fine-tuning with L0 regularization...


Epoch 1:   0%|          | 0/2500 [00:00<?, ?it/s]

KeyboardInterrupt: 
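
To get a feel for the sigmoid surrogate in `approximate_l0`, the sketch below restates its per-weight term in plain Python and evaluates it at a few magnitudes. This is an illustration only: the function name `surrogate_term` is hypothetical, and the FFN weight count assumes `d_model = 256`, an FFN width of 2048, and 6 encoder plus 6 decoder layers, which is consistent with the 24576 FFN neurons reported in the pruning log.

```python
import math

def surrogate_term(w, scale=10.0, offset=0.001):
    """Per-weight contribution: sigmoid(scale * (|w| - offset))."""
    return 1.0 / (1.0 + math.exp(-scale * (abs(w) - offset)))

# How strongly does the surrogate separate zero from non-zero weights?
for w in (0.0, 0.01, 0.1, 1.0):
    print(f"|w| = {w:<4} -> {surrogate_term(w):.4f}")

# Number of FFN weights under the assumptions stated above:
# 12 layers x (fc1 + fc2) x (256 x 2048) weights each.
n_ffn_weights = 12 * 2 * 256 * 2048
l0_lambda = 1e-4

# Penalty that would remain even if every FFN weight were exactly zero.
floor_penalty = l0_lambda * n_ffn_weights * surrogate_term(0.0)
print(f"FFN weights: {n_ffn_weights:,}")
print(f"penalty with all-zero FFN weights: {floor_penalty:.1f}")
```

With a scale of 10, a weight of exactly zero still contributes about 0.5, so the penalty has a large constant floor compared with the detection losses of roughly 2 seen in the fine-tuning log. That suggests the scale and offset, or `l0_lambda`, may need tuning before the penalty behaves like a sparsity count.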