# Visgator
[![pdm-managed](https://img.shields.io/badge/pdm-managed-blueviolet)](https://pdm.fming.dev)

Visual grounding of referring expressions with scene graphs and object detection.
This repo includes early-stage work from [visgator-base](https://github.com/halixness/visgator-base).

### Authors
Diego Calanzone, Francesco Gentile. <br>
University of Trento <br>
Deep Learning, Capstone Project, Spring 2023.

### A note on code
The codebase for this project is extensive. We developed a Python module with an object-oriented design, both to train and test our model on multiple machines and to serve as a template for future Deep Learning projects. This report contains comprehensive textual descriptions with annotated code snippets for all the source files where code has been written (class templates excluded). <br>
Moreover, running all the executable cells in sequence reproduces every experimental step, from data generation to model evaluation.

# Introduction
Object detection with referring expressions is a sub-case of a major computer vision challenge: associating semantics with identified regions of an image. Recent approaches in deep computer vision learn without manual annotations, leveraging massive amounts of internet-scraped data. OpenAI CLIP ([Radford et al.](https://arxiv.org/pdf/2103.00020.pdf)) uses a visual and a language backbone to encode images and captions, respectively, and matches their representation vectors in a common embedding space; since it scores how well a caption matches an image, this model can be used for open-set classification and object detection. This contrasts with an older branch of object detection: closed-set detectors. YOLO ([Redmon et al.](https://arxiv.org/abs/1506.02640)) is a widely used object detector that classifies image regions with a finite set of labels. <br>
Grounding DINO ([Liu et al.](https://arxiv.org/abs/2303.05499)) extends object detection from closed to open sets by fusing the language and visual modalities in multiple stages: after a feature enhancer, a language-guided query selection picks the visual features most relevant to the text as query tokens; these visual queries (Q) then attend to the language features (K, V) through cross-attention; finally, the entities in the language modality are matched with the output visual queries through a contrastive objective. <br>
Similarly, mPLUG-2 ([Xu et al.](https://arxiv.org/pdf/2302.00402.pdf)) proposes an architecture consisting of shared self/cross-attention layers (universal layers), a cross-attention fusion module and a shared decoder block. In both methods, the model is massively pre-trained on a collection of datasets, and object detection with referring expressions is only one of many target tasks. <br>
<br>
Given the many available pre-trained models with different, specialized capabilities, this project pursues training efficiency by combining pre-trained visual and language backbones: only a minimal portion of the proposed architecture is trained from scratch, to "glue" together the intermediate outputs of different specialized models. In particular:
1. We want to leverage the grounded text and image representations from CLIP, which has already been trained on massive datasets.
2. We introduce an inductive "bias" by delegating the tasks of identifying entities in images and captions to existing pre-trained models (GDino, OwLViT, YOLO, SceneGraph Parser ([Schuster et al.](https://nlp.stanford.edu/software/scenegraph-parser.shtml))). This spares us from training a massive network from scratch to learn these skills, and lets us test and leverage existing work from the literature.
3. We introduce two attention-based modules trained from scratch: first, a transformer decoder that grounds pairs of visual entities in the referring expression; then, a transformer encoder that compares the language-informed entity pairs and selects the one the expression refers to, as given by the annotations.

# Baseline
### Architecture
We choose as baseline model "YOLOCLIP": a simple architecture consisting of a detection module, [YOLOv8 (Ultralytics)](https://docs.ultralytics.com), and a language-image matching module, [OpenAI CLIP (HuggingFace)](https://huggingface.co/docs/transformers/model_doc/clip). For each `(image, caption)` pair: 
1. Given an image, the bounding boxes are generated with YOLO ([Redmon et al.](https://arxiv.org/abs/1506.02640)) for all the detected entities.
2. Each portion of image denoted by each bounding box is preprocessed with CLIP (resizing, center crop, RGB conversion, normalization) and encoded with the visual backbone (`ViT-B/32`) into a sequence of patch token embeddings.
3. The image caption is tokenized with [CLIP's BytePair Encoding](https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py) and encoded with CLIP's language backbone into a sequence of token embeddings.
4. Both the image and the language token embeddings are projected into a visual and a textual embedding vector, respectively, with the same dimensionality $d = 512$. For each `(caption, bbox_img_embedding)` pair, the matching score is the scaled cosine similarity between the two vectors, i.e. the `logits` term in CLIP's pre-training pseudocode (section 2.4, [Radford et al.](https://arxiv.org/pdf/2103.00020.pdf)); the loss terms are only used during pre-training:
<br><br>
```python
    # image_encoder - ResNet or Vision Transformer
    # text_encoder - CBOW or Text Transformer
    # I[n, h, w, c] - minibatch of aligned images
    # T[n, l] - minibatch of aligned texts
    # W_i[d_i, d_e] - learned proj of image to embed
    # W_t[d_t, d_e] - learned proj of text to embed
    # t - learned temperature parameter
    # extract feature representations of each modality
    I_f = image_encoder(I) #[n, d_i]
    T_f = text_encoder(T) #[n, d_t]
    # joint multimodal embedding [n, d_e]
    I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
    T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
    # scaled pairwise cosine similarities [n, n]
    logits = np.dot(I_e, T_e.T) * np.exp(t)
    # symmetric loss function
    labels = np.arange(n)
    loss_i = cross_entropy_loss(logits, labels, axis=0)
    loss_t = cross_entropy_loss(logits, labels, axis=1)
    loss = (loss_i + loss_t)/2
```
5. The bounding box of the highest-scoring `(caption, bbox_img_embedding)` pair is taken as the prediction for the given referring expression.
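At inference time only the scaled cosine similarity (`logits` in the pseudocode above) matters. As a minimal NumPy sketch, assuming the per-box image embeddings and the caption embedding have already been produced by CLIP's encoders and projections, steps 4 and 5 reduce to normalizing, taking dot products and an argmax. `select_best_box` and its default `logit_scale` are illustrative assumptions, not code from the repository; the implementation below instead reads `logits_per_image` from HuggingFace's `CLIPModel`.

```python
import numpy as np


def l2_normalize(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Scale vectors to unit L2 norm along `axis`."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def select_best_box(
    crop_embeddings: np.ndarray,    # [num_boxes, d] projected embeddings, one per YOLO crop
    caption_embedding: np.ndarray,  # [d] projected embedding of the referring expression
    logit_scale: float = 100.0,     # exp(t) in CLIP's notation (illustrative value)
) -> tuple[int, np.ndarray]:
    """Hypothetical helper: index of the best-matching crop and all scaled similarities."""
    crops = l2_normalize(crop_embeddings, axis=1)
    text = l2_normalize(caption_embedding, axis=0)
    scores = logit_scale * crops @ text  # [num_boxes] scaled cosine similarities
    return int(np.argmax(scores)), scores


# Toy example: the caption embedding is a slightly perturbed copy of crop 1
rng = np.random.default_rng(0)
crops = rng.normal(size=(3, 512))
caption = crops[1] + 0.01 * rng.normal(size=512)
best, scores = select_best_box(crops, caption)
print(best)  # -> 1
```

Scaling by `logit_scale` does not change which box wins; it only matters when the scores feed a softmax or a loss.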

### Implementation
We implement the baseline model in `src/visgator/models/baseline/_model.py`:
<br><br>
```python
class Model(_Model[BBoxes]):
    def __init__(self, config: Config) -> None:
        super().__init__()

        self._postprocessor = PostProcessor()

        self._yolo = YOLO(config.yolo.weights())
        self._clip_processor = CLIPProcessor.from_pretrained(config.clip.weights())
        self._clip = CLIPModel.from_pretrained(config.clip.weights())

        self._toPIL = T.ToPILImage()
    (...)

    def forward(self, batch: Batch) -> BBoxes:
        boxes = []

        # Preprocessing for YOLO
        images = [self._toPIL(sample.image) for sample in batch.samples]
        yolo_results = self._yolo.predict(images, conf=0.5, verbose=False)

        sample: BatchSample
        for sample, result in zip(batch, yolo_results):
            proposals = []

            # Base case: no entities detected, so there is no crop to score with CLIP
            if len(result.boxes) == 0:
                # create a dummy bbox
                tmp = BBox((0, 0, 0, 0), sample.image.shape[1:], BBoxFormat.XYXY, True)
                boxes.append(tmp.to(self._clip.device))
                continue

            # Each bounding box is a portion of the image to encode with CLIP
            for bbox in result.boxes:
                xmin, ymin, xmax, ymax = bbox.xyxy.int()[0]
                clipped_image = sample.image[:, ymin:ymax, xmin:xmax]
                proposals.append(clipped_image)

            inputs = self._clip_processor(
                text=sample.caption.sentence,
                images=proposals,
                return_tensors="pt",
            ).to(self._clip.device)

            # Compute the CLIP scores and take the bbox with the highest score
            output = self._clip(**inputs)
            best = output.logits_per_image.argmax(0).item()
            boxes.append(
                BBox(
                    result.boxes[best].xyxy[0],
                    sample.image.shape[1:],
                    BBoxFormat.XYXY,
                    False,
                )
            )

        return BBoxes.from_bboxes(boxes)

```


# Proposed solution
### Architecture
![visgator architecture](https://github.com/FrancescoGentile/visgator/blob/main/docs/src/architecture.png?raw=true)

We call the proposed architecture **ERP-A: Entity Relationship Pairs Attention**. <br>
Referring expressions identify objects in two ways: by adding details about a specific object (e.g. "the girl with the green jacket") or by describing a relationship between multiple objects (e.g. "the glass on the table"). To exploit both, we combine multiple existing techniques: the Stanford SceneGraph parser to extract pairwise entity relationships from captions; OpenAI's CLIP to encode the image and the caption with the rich representations learned through contrastive pre-training; OwLViT or Grounding DINO, used only to detect in the image the entity nouns from the scene graph. <br>
Formally, for each `(image, caption)` sample:
1. We identify all the Entity-Relationship-Pairs (ERP) from the caption ([Preprocessing](#preprocessing))
2. Given multiple portions of the image as detected candidates for each entity, we instantiate an ERP for all the possible pairs of visual entities, limited with a confidence threshold ([Instantiation](#instantiation)).
3. We enrich each entity pair (a sequence of visual tokens) with the encoded referring expression through cross-attention ([Decoding](#decoding)).
4. We attend to all the ERPs (each a sequence of text-informed visual tokens) with a learnable regression token and project it to predict the proposed visual region of interest ([ERP-Attention](#erpattention)).
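Step 2 is where the number of candidates grows quadratically with the number of detections. A minimal pure-Python sketch, assuming detections are given as a mapping from scene-graph entity index to a list of candidate boxes (a hypothetical representation, not the repository's `DetectionResults`), pairs every subject candidate with every object candidate of each relation. Here each side of a pair is just a box; in ERP-A each entity is a sequence of visual tokens.

```python
from dataclasses import dataclass
from itertools import product

Box = tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax) in pixels


@dataclass(frozen=True)
class Triple:
    """Hypothetical (subject, predicate, object) relation from the caption's scene graph."""
    subject: int  # index of the subject entity
    predicate: str
    obj: int      # index of the object entity


@dataclass(frozen=True)
class ERP:
    """Hypothetical candidate visual instantiation of a triple."""
    triple: Triple
    subject_box: Box
    object_box: Box


def instantiate_erps(triples: list[Triple], candidates: dict[int, list[Box]]) -> list[ERP]:
    """Pair every candidate box of the subject with every candidate box of the object."""
    erps: list[ERP] = []
    for triple in triples:
        subject_boxes = candidates.get(triple.subject, [])
        object_boxes = candidates.get(triple.obj, [])
        for s_box, o_box in product(subject_boxes, object_boxes):
            erps.append(ERP(triple, s_box, o_box))
    return erps


# "the man in yellow coat": 2 candidate men x 3 candidate coats
man_boxes = [(10.0, 10.0, 50.0, 120.0), (200.0, 15.0, 260.0, 130.0)]
coat_boxes = [(12.0, 40.0, 48.0, 90.0), (205.0, 45.0, 255.0, 95.0), (0.0, 0.0, 639.0, 479.0)]
erps = instantiate_erps([Triple(0, "in", 1)], {0: man_boxes, 1: coat_boxes})
print(len(erps))  # -> 6
```

With $n$ candidates per entity, each relation yields $n^2$ pairs, which is why the number of detections has to be capped.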


## Preprocessing
### Scene graph extraction
Multiple techniques have been tested to extract entities and relationships from text:
- <ins>GPT models</ins>: we have designed a prompt to have GPT 3.5/GPT 4 extract (ent1, rel, ent2) tuples.
```
    Consider the sentence: {sentence}.
    What are the named entities?
    What are the relations between the named entities?
    Answer only with tuples "[x, action name, y]"  and without passive forms.
    Please be coherent with the name of the actions that occur multiple times.
    Answer by filling a JSON, follow the example:
    Sentence: "the girl is looking at the table full of drinks"
    Answer:
    {{
    "entities": ["the girl", "the table", "drinks"],
    "relations": [[0, "is on", 1], [1, "full of", 2]]
    }}
```
We attempted to identify the most difficult language expressions and to provide them as few-shot examples. The model generated valid outputs for the simplest captions:
```
Sentence: "the girl is looking at the table full of drinks"
Answer: {
    "entities": ["the girl", "the table full of drinks"],
    "relations": [[0, "is looking at", 1]]
}
```
However, it behaved inconsistently with incomplete, ungrammatical or more abstract descriptions:
```
Sentence: "the leftmost dog"
Answer: { 
    "entities": ["the leftmost dog"], 
    "relations": [] 
}
```
Moreover, the inference overhead is significant: a single query to the GPT APIs takes around $0.5 \pm 1$ s, which adds to the model's own inference time of around $0.8 \pm 1$ s, for a total of roughly $1.3$ s per sample. Over around 62,000 training examples, this amounts to $62{,}000 \times 1.3\,\text{s} \approx 80{,}600\,\text{s} \approx 22$ hours per epoch, which exceeds our computational budget. We tested multiple backends ([YOU](https://you.com), [OpenAI](https://chat.openai.com), [Poe](https://poe.com)) and found no significant differences.

- <ins>Pre-trained Quantized LLMs</ins>: we tested the same approach with large language models running locally on our machines. To fit models with 7-13 billion parameters on single GPUs (8/16 GB of VRAM), we used LLMs quantized to 8-bit weights, which yield acceptable perplexity scores. We tested a variety of models: [WizardLM-Vicuna 13B GGML](https://huggingface.co/TheBloke/wizard-vicuna-13B-GGML), [Falcon-Instruct 7B GPTQ](https://huggingface.co/TheBloke/falcon-7b-instruct-GPTQ). Apart from inference, no significant gains were observed with respect to the GPT API models; conversely, these models generated much more inconsistent outputs, possibly because of their limited capacity.  

- <ins>Stanford SceneGraph Parser</ins>: it consists of a rule-based or a classifier-based model that extracts entities and their relations from a sentence; the latter is shown to perform better ([Schuster et al.](https://nlp.stanford.edu/pubs/schuster-krishna-chang-feifei-manning-vl15.pdf)). This model yields acceptable graph representations of the relationships described in RefCOCOg's sentences: only `281/62176` captions resulted in zero identified entities. These samples were discarded; the implications on accuracy are reported later.
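For the GPT-based variant, the answer has to be validated before use, since the models did not always follow the requested format. A minimal sketch, assuming the answer is the JSON object requested by the prompt above (an `entities` list plus `[subject_idx, predicate, object_idx]` relations), converts it into `(subject, predicate, object)` triples and rejects malformed outputs. `parse_llm_answer` is a hypothetical helper, not code from the repository.

```python
import json


def parse_llm_answer(answer: str) -> list[tuple[str, str, str]]:
    """Hypothetical helper: turn the LLM's JSON answer into (subject, predicate, object) triples.

    Raises ValueError when the answer is not valid JSON or does not follow the prompt's format.
    """
    try:
        data = json.loads(answer)
    except json.JSONDecodeError as e:
        raise ValueError(f"answer is not valid JSON: {e}") from e
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")

    entities = data.get("entities")
    relations = data.get("relations", [])
    if not isinstance(entities, list) or not all(isinstance(e, str) for e in entities):
        raise ValueError("'entities' must be a list of strings")

    triples = []
    for rel in relations:
        if not (isinstance(rel, list) and len(rel) == 3):
            raise ValueError(f"malformed relation: {rel!r}")
        subj, pred, obj = rel
        # Relations refer to entities by index: both indices must be valid
        if not (isinstance(subj, int) and isinstance(obj, int)
                and 0 <= subj < len(entities) and 0 <= obj < len(entities)):
            raise ValueError(f"relation indices out of range: {rel!r}")
        triples.append((entities[subj], str(pred), entities[obj]))
    return triples


answer = '{"entities": ["the girl", "the table full of drinks"], "relations": [[0, "is looking at", 1]]}'
print(parse_llm_answer(answer))
```

Rejecting an answer outright, rather than trying to repair it, keeps the failure cases countable, which is how the inconsistent outputs described above can be measured.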

### Implementation
We use a modified version of the Stanford SceneGraph Parser [built with spaCy](https://github.com/vacancy/SceneGraphParser).
The dataset is first preprocessed to convert captions into [scene graphs](https://en.wikipedia.org/wiki/Scene_graph). In `visgator.datasets.refcocog._generator.Generator.generate()`, each caption is parsed by a `visgator.utils.graph.SpacySceneGraphParser` ([reference](https://github.com/vacancy/SceneGraphParser)) or by a Large Language Model ([reference](https://huggingface.co/tiiuae/falcon-7b-instruct)): entities and the relationships between them are identified and stored in a graph object. The generator stores each dataset sample as a record containing the image file name, the caption (sentence and scene graph) and the target bounding box.
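The parser output then has to be mapped onto the graph schema stored in the annotation file. The sketch below assumes the dictionary layout returned by `sng_parser.parse()` in the referenced SceneGraphParser repository (entities with `span` and `head`, relations with `subject`, `relation` and `object`); `to_stored_graph` is a hypothetical helper and the repository's own conversion may differ.

```python
from typing import Any


def to_stored_graph(parsed: dict[str, Any]) -> dict[str, list[dict[str, Any]]]:
    """Hypothetical helper: keep only the fields stored in the annotation file.

    Assumes the SceneGraphParser layout: entities carry "span" and "head",
    relations carry "subject" and "object" (entity indices) plus "relation" (predicate).
    """
    entities = [
        {"span": entity["span"], "head": entity["head"]}
        for entity in parsed.get("entities", [])
    ]
    relations = [
        {"subject": rel["subject"], "predicate": rel["relation"], "object": rel["object"]}
        for rel in parsed.get("relations", [])
    ]
    return {"entities": entities, "relations": relations}


# Hand-written parser output for "the man in yellow coat" (extra keys are dropped)
parsed = {
    "entities": [
        {"span": "the man", "head": "man", "modifiers": []},
        {"span": "yellow coat", "head": "coat", "modifiers": [{"span": "yellow"}]},
    ],
    "relations": [{"subject": 0, "relation": "in", "object": 1}],
}
graph = to_stored_graph(parsed)
print(graph["relations"])
```

Captions for which the `entities` list comes back empty are the ones discarded above.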

The annotation generation process can be started with:

```
python -m visgator --phase generate --config config/local.yaml
```

This is equivalent to running the following code:

```python
from ruamel.yaml import YAML
from pathlib import Path
import json

# Load the configuration file (JSON or YAML)
config_path = Path("..") / "config" / "local.yaml"
extension = config_path.suffix
match extension:
    case ".json":
        with open(config_path, "r") as f:
            cfg = json.load(f)
    case ".yaml":
        yaml = YAML(typ="safe")
        cfg = yaml.load(config_path)
    case _:
        raise ValueError(f"Unknown config file extension: {extension}.")
```

```python
from visgator.datasets import Config as DatasetConfig
from visgator.datasets import Generator

# Parse the captions of every split into scene graphs and save the annotation file
dataset_config = DatasetConfig.from_dict(cfg["dataset"])
generator = Generator.new(dataset_config)
generator.generate()
```

The `generate()` method called above is implemented as follows:
<br><br>
```python
def generate(self) -> None:
  split_samples = get_preprocessed_samples(
      self._config,
      [Split.TRAIN, Split.VALIDATION, Split.TEST],
  )

  parser = SceneGraphParser.new(self._config.generation.parser)
  if self._config.generation.num_workers is not None:
      num_workers = self._config.generation.num_workers
  else:
      num_workers = multiprocessing.cpu_count() // 2

  output: dict[str, list] = {}  # type: ignore
  for split, samples in split_samples.items():
      split_output: list[dict] = []  # type: ignore
      output[str(split)] = split_output

      graphs = process_map(
          parser.parse,
          (sample.caption.sentence for sample in samples),
          max_workers=num_workers,
          chunksize=self._config.generation.chunksize,
          desc=f"Generating {split} split",
          total=len(samples),
      )

      output[str(split)] = [
          {
              "image": sample.path.name,
              "caption": Caption(sample.caption.sentence, graph).to_dict(),
              "bbox": sample.bbox,
          }
          for sample, graph in zip(samples, graphs)
      ]

  output_path = (
      self._config.path / f"annotations/info_{self._config.split_provider}.json"
  )
  with open(output_path, "w") as f:
      json.dump(output, f, indent=2)
```
<br><br>
An excerpt of the generated annotation file for RefCOCOg (UMD splits):

```
{
  "test": [
    {
      "image": "COCO_train2014_000000380440.jpg",
      "caption": {
        "sentence": "the man in yellow coat",
        "graph": {
          "entities": [{"span": "the man", "head": "man"}, ...],
          "relations": [
            {"subject": 0, "predicate": "in", "object": 1}
          ]
        }
      },
      "bbox": [374.31, 65.06, 136.04, 201.94]
    },
    ...
  ]
}
```
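The `bbox` field presumably follows COCO's `[x, y, width, height]` convention in absolute pixels, inherited from the COCO annotations that RefCOCOg builds on. A minimal sketch for reading one split of the generated file and converting boxes to the corner format `(xmin, ymin, xmax, ymax)` used by the models is shown below; `load_split` and `xywh_to_xyxy` are hypothetical helpers, not the repository's dataset classes.

```python
import json
from pathlib import Path


def xywh_to_xyxy(box: list[float]) -> tuple[float, float, float, float]:
    """Convert a COCO-style [x, y, width, height] box to (xmin, ymin, xmax, ymax)."""
    x, y, w, h = box
    return (x, y, x + w, y + h)


def load_split(path: Path, split: str) -> list[dict]:
    """Hypothetical helper: read one split, skipping captions whose graph has no entities."""
    with open(path, "r") as f:
        data = json.load(f)

    samples = []
    for entry in data[split]:
        graph = entry["caption"].get("graph") or {}
        if not graph.get("entities"):
            continue  # the parser found no entity: the sample cannot be grounded
        samples.append(
            {
                "image": entry["image"],
                "sentence": entry["caption"]["sentence"],
                "graph": graph,
                "bbox_xyxy": xywh_to_xyxy(entry["bbox"]),
            }
        )
    return samples


print(xywh_to_xyxy([374.31, 65.06, 136.04, 201.94]))
```

Given the generator code above, the file is written to `annotations/info_{split_provider}.json` under the dataset path.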

## Training logic

Training starts with the instantiation of a `visgator.engines.Trainer` object, followed by a call to `Trainer.run()`. <br>
Under the hood, `Trainer._train_epoch(epoch)` is invoked once per epoch, for `num_epochs` epochs:
<br><br>
```python
def _train_epoch(self, epoch: int) -> None:
    self._logger.info(f"Training epoch {epoch + 1} started.")

    start = timer()

    self._model.train()
    self._postprocessor.train()
    self._criterion.train()

    self._tl_tracker.increment()
    self._tm_tracker.increment()
    self._optimizer.zero_grad()

    counter = tqdm(
        desc="Training",
        total=self._get_steps_per_epoch() * self._params.train_batch_size,
    )

    with counter as progress_bar:
        batch: Batch
        bboxes: BBoxes
        device_type = "cuda" if self._device.is_cuda else "cpu"

        for idx, (batch, bboxes) in enumerate(self._train_loader):
            # since the dataloader batch size is equal to true batch size
            # // gradient accumulation steps, the last samples in the dataloader
            # may not be enough to fill the true batch size, so we break the loop
            # for example, if the true batch size is 8, the gradient accumulation
            # steps is 4 and the dataset size is 50, the last 2 samples will be
            # ignored
            if progress_bar.total == progress_bar.n:
                break

            batch = batch.to(self._device.to_torch())
            bboxes = bboxes.to(self._device.to_torch())

            # Mixed precision training support
            with autocast(device_type, enabled=self._params.mixed_precision):
                outputs = self._model(batch)
                tmp_losses = self._criterion(outputs, bboxes)
                losses = self._tl_tracker(tmp_losses)

                # Gradient accumulation allows for larger batch sizes with memory limitations!
                loss = losses.total / self._params.gradient_accumulation_steps

            # GradScaling to prevent gradient underflow
            self._scaler.scale(loss).backward()

            if (idx + 1) % self._params.gradient_accumulation_steps == 0:
                if self._params.max_grad_norm is not None:
                    self._scaler.unscale_(self._optimizer)
                    torch.nn.utils.clip_grad_norm_(  # type: ignore
                        self._model.parameters(), self._params.max_grad_norm
                    )

                self._scaler.step(self._optimizer)
                self._scaler.update()
                self._optimizer.zero_grad()
                self._lr_scheduler.step_after_batch()
                

            with torch.no_grad():
                pred_bboxes = self._postprocessor(outputs)
                self._tm_tracker.update(
                    pred_bboxes.to_xyxy().normalize().tensor,
                    bboxes.to_xyxy().normalize().tensor,
                )

            progress_bar.update(len(batch))
            
            del batch
            del bboxes
            del outputs
            del loss

            # Memory cleanup to prevent leaking
            if self._device.is_cuda:
                torch.cuda.empty_cache()
            gc.collect()
            
    # Step the LR scheduler once per epoch (for schedulers configured with an epoch interval)
    self._lr_scheduler.step_after_epoch()

    end = timer()
    elapsed = end - start

    self._logger.info(f"Training epoch {epoch + 1} finished.")
    self._log_statistics(epoch, elapsed, train=True)
```
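The comment at the top of the loop describes how many samples are skipped when the dataset size is not a multiple of the effective batch size. The arithmetic can be checked with a small sketch, assuming (as the comment suggests) that `train_batch_size` is the effective batch size, that the dataloader uses `train_batch_size // gradient_accumulation_steps`, and that the number of steps per epoch is `dataset_size // train_batch_size`. `accumulation_plan` is a hypothetical helper.

```python
from typing import NamedTuple


class AccumulationPlan(NamedTuple):
    micro_batch_size: int  # samples per forward/backward pass (dataloader batch size)
    optimizer_steps: int   # optimizer updates per epoch
    dropped_samples: int   # trailing samples that never reach an optimizer step


def accumulation_plan(dataset_size: int, batch_size: int, accumulation_steps: int) -> AccumulationPlan:
    """Hypothetical helper: gradient-accumulation bookkeeping for an effective batch of `batch_size`."""
    if batch_size % accumulation_steps != 0:
        raise ValueError("batch_size must be divisible by accumulation_steps")
    micro_batch_size = batch_size // accumulation_steps
    optimizer_steps = dataset_size // batch_size
    dropped = dataset_size - optimizer_steps * batch_size
    return AccumulationPlan(micro_batch_size, optimizer_steps, dropped)


# Example from the code comment: effective batch 8, 4 accumulation steps, 50 samples
print(accumulation_plan(50, 8, 4))
```

With the hyperparameters listed below (batch size 32, 8 accumulation steps), each forward/backward pass would process 4 samples.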


### Hyperparameters

<table>
    <tr>
        <td>Training batch size</td>
        <td>32</td>
    </tr>
    <tr>
        <td>Evaluation batch size</td>
        <td>4</td>
    </tr>
    <tr>
        <td>Gradient accumulation steps</td>
        <td>8</td>
    </tr>
    <tr>
        <td>Start learning rate</td>
        <td>1e-3</td>
    </tr>
    <tr>
        <td>LR Scheduler</td>
        <td><a href="https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html">OneCycleLR</a></td>
    </tr>
    <tr>
        <td>LR scheduler interval</td>
        <td>1 epoch</td>
    </tr>
    <tr>
        <td>Max LR</td>
        <td>1e-3</td>
    </tr>
    <tr>
        <td>Optimizer</td>
        <td>AdamW</td>
    </tr>
    <tr>
        <td>Box confidence threshold</td>
        <td>0.1</td>
    </tr>
    <tr>
        <td>Text confidence threshold</td>
        <td>0.25</td>
    </tr>
    <tr>
        <td>Max detected boxes</td>
        <td>50</td>
    </tr>
    <tr>
        <td>Hidden dim</td>
        <td>256</td>
    </tr>
</table>
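With `OneCycleLR`, the learning rate is not constant: it warms up to `max_lr` and then anneals towards a much smaller value. The pure-Python sketch below mirrors PyTorch's default two-phase cosine schedule (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`) so the values can be inspected without building an optimizer. These defaults are PyTorch's; whether the project overrides them is not stated here. Under the defaults, the first value would be `max_lr / div_factor` rather than `max_lr`. With a per-epoch interval, `step` counts epochs. `one_cycle_lr` is a hypothetical helper.

```python
import math


def one_cycle_lr(
    step: int,
    total_steps: int,
    max_lr: float = 1e-3,
    pct_start: float = 0.3,
    div_factor: float = 25.0,
    final_div_factor: float = 1e4,
) -> float:
    """Hypothetical helper: learning rate at `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor          # value at step 0
    min_lr = initial_lr / final_div_factor    # value at the last step
    warmup_end = pct_start * total_steps - 1  # step at which max_lr is reached
    last_step = total_steps - 1
    if warmup_end <= 0:
        raise ValueError("too few steps for a warm-up phase")

    def cos_anneal(start: float, end: float, pct: float) -> float:
        # Cosine interpolation from `start` (pct=0) to `end` (pct=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


# Learning rate of a 10-epoch run stepped once per epoch
for epoch in range(10):
    print(epoch, f"{one_cycle_lr(epoch, total_steps=10):.2e}")
```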

## Instantiation

### Batch formatting

Samples from RefCOCOg are organized in a sequence of `Batch` objects, each containing `BatchSample` objects.<br> 
A `BatchSample` object consists of an `image: uint8 tensor [3, H, W]` and a `caption: Caption`.<br>
A `Caption` object encapsulates the original text description and the parsed `SceneGraph`. 
<br><br>
```python
    @serde.serde(type_check=serde.Strict)
    @dataclass(frozen=True)
    class Caption:
        """A caption with an optional scene graph."""
        sentence: str
        graph: Optional[SceneGraph] = serde.field(
            default=None,
            serializer=SceneGraph.to_dict,
            deserializer=SceneGraph.from_dict,
        )
        (...)


    @dataclass(frozen=True)
    class BatchSample:
        """A batch sample with an image and a caption."""
        image: UInt8[Tensor, "3 H W"]
        caption: Caption
        (...)


    @dataclass(frozen=True)
    class Batch:
        """A batch of samples."""
        samples: tuple[BatchSample, ...]
        (...)
```
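The `SceneGraph` type referenced by `Caption` is not shown here. A minimal stand-in, assuming it mirrors the `entities`/`relations` layout of the annotation file, could be written with frozen dataclasses and explicit `to_dict`/`from_dict` methods. The method names match the serializer hooks used by `Caption` above, but the classes and fields are our assumption, not the repository's definition.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Entity:
    span: str  # full noun phrase, e.g. "the man"
    head: str  # head noun, e.g. "man"


@dataclass(frozen=True)
class Relation:
    subject: int  # index into SceneGraph.entities
    predicate: str
    obj: int      # index into SceneGraph.entities (stored as "object")


@dataclass(frozen=True)
class SceneGraph:
    """Hypothetical stand-in for the repository's SceneGraph."""
    entities: tuple[Entity, ...]
    relations: tuple[Relation, ...]

    def to_dict(self) -> dict[str, Any]:
        return {
            "entities": [{"span": e.span, "head": e.head} for e in self.entities],
            "relations": [
                {"subject": r.subject, "predicate": r.predicate, "object": r.obj}
                for r in self.relations
            ],
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SceneGraph":
        entities = tuple(Entity(e["span"], e["head"]) for e in data["entities"])
        relations = tuple(
            Relation(r["subject"], r["predicate"], r["object"]) for r in data["relations"]
        )
        # Reject relations pointing to entities that do not exist
        for r in relations:
            if not (0 <= r.subject < len(entities) and 0 <= r.obj < len(entities)):
                raise ValueError(f"relation {r} references a missing entity")
        return cls(entities, relations)
```

Freezing the dataclasses matches the `frozen=True` convention of `Caption`, `BatchSample` and `Batch`, so a parsed graph cannot be mutated after loading.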

### Architecture
Our model, named `erpa` in the configuration, is invoked in `visgator.engines.trainer.Trainer._train_epoch()` and is composed of four blocks:

### 1. Encoders
We use the [OpenCLIP](https://github.com/mlfoundations/open_clip) implementation of CLIP's visual and textual backbones. The source has been copied into `src/visgator/models/erpa/_encoders.py`:

```python

def build_encoders(config: EncodersConfig) -> tuple[VisionEncoder, TextEncoder]:
    """ Returns the vision and text backbone from OpenCLIP"""

    model, _, preprocess = open_clip.create_model_and_transforms(
        config.model,
        pretrained=config.pretrained,
    )

    # Use the tokenizer that matches the configured model
    tokenizer = open_clip.get_tokenizer(config.model)

    if type(preprocess) is not T.Compose:
        raise NotImplementedError

    # Recover the normalization statistics (mean/std) from CLIP's preprocessing pipeline
    mean: Optional[tuple[float, float, float]] = None
    std: Optional[tuple[float, float, float]] = None

    for transform in preprocess.transforms:
        if type(transform) is T.Normalize:
            mean = transform.mean
            std = transform.std
            break

    vision = VisionEncoder(model.visual, config.hidden_dim, mean, std)
    text = TextEncoder(model, tokenizer, config.hidden_dim)

    return vision, text


class VisionEncoder(nn.Module):
    def __init__(self, encoder: OpenClipVisionTransformer) -> None:
    (...)

class _VisionTransformer(nn.Module):
    def __init__(self, transformer: OpenClipTransformer) -> None:
    (...)

class _ResidualAttentionBlock(nn.Module):
    def __init__(self, block: OpenClipResidualAttentionBlock) -> None:
    (...)

class TextEncoder(nn.Module):
    def __init__(self, model: CLIP, tokenizer: Tokenizer, output_dim: int) -> None:
    (...)

```

### 2. Detector
We tested two variants to extract visual entities: OwLViT and GroundingDINO. <br>
Two assumptions on handling borderline cases are crucial:
- **No boxes detected for an entity**: a set of $k$ regions of interest, each covering the whole image, is used as visual context for that entity.
- **Too many objects detected**: with permissive confidence thresholds (e.g. $0.25$), detectors can identify many entities. Since the number of Entity Relationship Pairs grows quadratically with the number of detections, this quickly leads to memory overflow. As a trade-off between detection tolerance and memory consumption, we cap the number of detected boxes at $50$, i.e. at most $50^2 = 2500$ pairwise object relationships per image: only the top $50$ detected boxes by confidence are kept.
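Both rules can be summarized in a few lines. The sketch below works on plain Python lists and uses a single whole-image box per missing entity, as the OwlViT implementation below does; `cap_and_fill` and the `(entity_idx, score, box)` tuple layout are hypothetical, while the defaults for the threshold and the cap come from the hyperparameter table.

```python
Box = tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax) in pixels


def cap_and_fill(
    detections: list[tuple[int, float, Box]],  # (entity_idx, score, box) from the detector
    num_entities: int,                         # number of entities in the scene graph
    image_size: tuple[int, int],               # (width, height), as in PIL's Image.size
    box_threshold: float = 0.1,
    max_detections: int = 50,
) -> list[tuple[int, Box]]:
    """Hypothetical helper: threshold, keep the top-k boxes, then add a whole-image
    box for every entity that received no detection."""
    width, height = image_size
    kept = [d for d in detections if d[1] >= box_threshold]
    kept.sort(key=lambda d: d[1], reverse=True)
    kept = kept[:max_detections]

    result = [(entity_idx, box) for entity_idx, _, box in kept]
    found = {entity_idx for entity_idx, _ in result}
    whole_image: Box = (0.0, 0.0, float(width - 1), float(height - 1))
    for entity_idx in range(num_entities):
        if entity_idx not in found:
            result.append((entity_idx, whole_image))
    return result
```

Note that the cap is applied before the fallback, so an entity whose only boxes fall outside the top-k still receives the whole-image region.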

### 2a. Detector (OwLViT)

We test OwL-ViT ([Minderer et al.](https://arxiv.org/pdf/2205.06230.pdf)), a lightweight, transformer-based, open-set object detector that, given a set of entities, returns multiple object bounding boxes with scores. <br> A simple prompt template improved its detection capabilities:
- Querying OwLViT with the bare entity label `"[noun]"` detects at least one entity in $89\pm3\%$ of the RefCOCOg samples.
- Expanding the label to `"a photo of [noun]"` raises this share to $96\pm2\%$ of the RefCOCOg samples.

```python
class OwlViTDetector(nn.Module):
    def __init__(self, config: DetectorConfig) -> None:
        super().__init__()

        assert config.owlvit is not None
        self._dummy = nn.Parameter(torch.empty(0))
        self._box_threshold = config.box_threshold
        self._max_detections = config.max_detections
        self._preprocessor = OwlViTProcessor.from_pretrained(config.owlvit)
        self._detector = OwlViTForObjectDetection.from_pretrained(config.owlvit)
        self._freeze()

    def _freeze(self) -> None:
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, batch: Batch) -> list[DetectionResults]:
        # partially taken from: https://huggingface.co/docs/transformers/model_doc/owlvit
        # Preprocessing
        images = [T.to_pil_image(sample.image) for sample in batch.samples]
        captions = [sample.caption for sample in batch.samples]

        B = len(captions)

        # Extracting graph entities
        entities: list[list[str]] = [None] * B  # type: ignore
        for i, caption in enumerate(captions):
            graph = caption.graph
            assert graph is not None
            entities[i] = [
                f"a photo of {entity.head.lower().strip()}" for entity in graph.entities
            ]

        # Object detection (open-vocabulary)
        inputs = self._preprocessor(
            text=entities, images=images, return_tensors="pt"
        ).to(self._dummy.device)
        detector_results = self._detector(**inputs)

        # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        # (PIL's Image.size is (width, height), hence the reversal)
        target_sizes = torch.tensor(
            [image.size[::-1] for image in images], device=self._dummy.device
        )
        # Convert outputs (bounding boxes and class logits) to COCO API
        results = self._preprocessor.post_process_object_detection(
            outputs=detector_results, target_sizes=target_sizes
        )

        # For each result
        detections: list[DetectionResults] = [None] * B  # type: ignore

        for sample_idx in range(B):
            boxes, scores, labels = (
                results[sample_idx]["boxes"],
                results[sample_idx]["scores"],
                results[sample_idx]["labels"],
            )

            matched_indices = []
            matched_boxes = []
            width, height = images[sample_idx].size  # PIL order: (width, height)

            entities_found = [False] * len(entities[sample_idx])

            idx = scores >= self._box_threshold
            boxes = boxes[idx]
            scores = scores[idx]
            labels = labels[idx]

            # If there are too many detections, keep only the top-K by score
            if len(boxes) > self._max_detections:
                _, idx = torch.topk(scores, self._max_detections)
                boxes = boxes[idx]
                scores = scores[idx]
                labels = labels[idx]

            # Record the entity (label) associated with each kept box
            for box, label in zip(boxes, labels):
                matched_boxes.append(box)
                matched_indices.append(label)
                entities_found[label] = True

            # If no boxes are found for an entity: suppose the whole image as ROI
            for entity_idx, found in enumerate(entities_found):
                if not found:
                    matched_indices.append(entity_idx)
                    matched_boxes.append(
                        torch.tensor([0.0, 0.0, width - 1, height - 1]).to(
                            self._dummy.device
                        )
                    )

            boxes = BBoxes(
                boxes=torch.stack(matched_boxes),
                images_size=images[sample_idx].size[::-1],  # (height, width), as sample.image.shape[1:]
                format=BBoxFormat.XYXY,
                normalized=False,
            )

            detections[sample_idx] = DetectionResults(
                entities=torch.tensor(matched_indices, device=self._dummy.device),
                boxes=boxes,
            )
        return detections
```

### 2b. Detector (GroundingDINO)
Alternatively, we test a detector based on GroundingDINO ([Liu et al.](https://arxiv.org/pdf/2303.05499.pdf)) that, given an image and a caption, returns a set of bounding boxes with associated labels (phrases) from the caption.
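The least obvious step in the detector below is mapping each box back to a phrase of the text prompt `"noun . noun ."`. The pure-Python sketch shows the idea: find the highest-scoring token, locate the separators (`[CLS]`, `[SEP]`, `.`) around it with `bisect`, and keep the tokens in between whose score exceeds the text threshold. It assumes whole-word tokens joined with spaces, whereas the original decodes word pieces through the tokenizer; `phrase_for_box` is a hypothetical helper, and like the original it assumes the best token lies strictly between two separators.

```python
import bisect

# Token ids of [CLS], [SEP] and "." in the bert-base-uncased vocabulary
SEPARATOR_IDS = (101, 102, 1012)


def phrase_for_box(
    token_ids: list[int],
    tokens: list[str],
    token_scores: list[float],
    text_threshold: float = 0.25,
) -> str:
    """Hypothetical helper: return the prompt phrase a box refers to.

    The phrase is the span between the two separators surrounding the
    highest-scoring token; inside it, only tokens above `text_threshold` are kept.
    """
    sep_idx = [i for i, tid in enumerate(token_ids) if tid in SEPARATOR_IDS]
    max_idx = max(range(len(token_scores)), key=token_scores.__getitem__)
    insert_idx = bisect.bisect_left(sep_idx, max_idx)
    left, right = sep_idx[insert_idx - 1], sep_idx[insert_idx]
    words = [tokens[i] for i in range(left + 1, right) if token_scores[i] > text_threshold]
    return " ".join(words)


# Prompt "woman . piano ." (the ids of "woman" and "piano" are placeholders)
token_ids = [101, 2001, 1012, 2002, 1012, 102]
tokens = ["[CLS]", "woman", ".", "piano", ".", "[SEP]"]
print(phrase_for_box(token_ids, tokens, [0.0, 0.1, 0.0, 0.9, 0.05, 0.0]))  # -> piano
```

The recovered phrase is then looked up in the `head -> entity indices` mapping built from the scene graph, so a single detection can be assigned to every entity sharing that head noun.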
<br><br>
```python
class GroundigDINODetector(nn.Module):
    def __init__(self, config: DetectorConfig) -> None:
        super().__init__()

        assert config.gdino is not None
        self._dummy = nn.Parameter(torch.empty(0))
        self._mean = (0.485, 0.456, 0.406)
        self._std = (0.229, 0.224, 0.225)
        self._gdino: GroundingDINO = load_model(
            str(config.gdino.config), str(config.gdino.weights)
        )
        self._box_threshold = config.box_threshold
        self._text_threshold = config.text_threshold
        self._max_detections = config.max_detections
        self._freeze()

    def _freeze(self) -> None:
        for param in self.parameters():
            param.requires_grad = False

    def forward(
        self, images: Nested4DTensor, captions: list[Caption]
    ) -> list[DetectionResults]:
        
        # Image preprocessing
        img_tensor = F.normalize(images.tensor, self._mean, self._std)
        img_tensor.masked_fill_(images.mask.unsqueeze(1).expand(-1, 3, -1, -1), 0.0)
        images = Nested4DTensor(img_tensor, images.sizes, images.mask)
        B = len(captions)

        entities: list[dict[str, list[int]]] = [{} for _ in range(B)]
        sentences: list[str] = [""] * B

        # Preprocess the extracted entities as sequences "noun . noun . noun" as in the paper
        for i, caption in enumerate(captions):
            graph = caption.graph
            assert graph is not None

            for entity_idx, entity in enumerate(graph.entities):
                head = entity.head.lower().strip()
                entities[i].setdefault(head, []).append(entity_idx)

            sentences[i] = " . ".join(entities[i].keys()) + " ."
        del i

        # Associate category names sequences with image portions
        gdino_images = NestedTensor(images.tensor, images.mask)
        output = self._gdino(gdino_images, captions=sentences)

        pred_logits = output["pred_logits"].sigmoid()
        pred_boxes = output["pred_boxes"]

        # Apply a defined confidence threshold
        masks = pred_logits.max(dim=2)[0] > self._box_threshold

        detections: list[DetectionResults] = [None] * B  # type: ignore

        for sample_idx in range(B):
            mask = masks[sample_idx]
            detected_boxes = pred_boxes[sample_idx, mask]
            logits = pred_logits[sample_idx, mask]  # (K, T): per-token scores of each kept box

            # Keep only the top-k most confident boxes
            # (avoids a combinatorial explosion)
            if len(logits) > self._max_detections:
                scores = logits.max(dim=1)[0]  # (K,): best token score of each box
                _, indices = torch.topk(scores, self._max_detections)
                logits = logits[indices]
                detected_boxes = detected_boxes[indices]

            tokenized = self._gdino.tokenizer(sentences[sample_idx])

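            # Separator positions in the prompt: 101 = [CLS], 102 = [SEP], 1012 = "."
            # (ids of the uncased BERT vocabulary used by GroundingDINO's text encoder)
            sep_idx = [
                i
                for i in range(len(tokenized["input_ids"]))
                if tokenized["input_ids"][i] in [101, 102, 1012]
            ]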

            phrases: list[str] = []

            # For each box, recover the prompt phrase containing its highest-scoring token
            for logit in logits:
                max_idx = logit.argmax()
                insert_idx = bisect.bisect_left(sep_idx, max_idx)
                right_idx = sep_idx[insert_idx]
                left_idx = sep_idx[insert_idx - 1]
                phrases.append(
                    self._get_phrases_from_posmap(
                        logit > self._text_threshold,
                        tokenized,
                        self._gdino.tokenizer,
                        left_idx,
                        right_idx,
                    ).replace(".", "")
                )

            indexes = []
            boxes = []
            entities_found: list[bool] = [False] * len(
                captions[sample_idx].graph.entities  # type: ignore
            )

            # Phrases are matched with the entities from the SceneGraph
            # each entity in the scenegraph is then marked as found
            for det_idx, det_name in enumerate(phrases):
                if det_name in entities[sample_idx]:
                    for entity_idx in entities[sample_idx][det_name]:
                        indexes.append(entity_idx)
                        boxes.append(detected_boxes[det_idx])
                        entities_found[entity_idx] = True
                else:
                    for entity_name, entity_idxs in entities[sample_idx].items():
                        if det_name in entity_name:
                            for entity_idx in entity_idxs:
                                indexes.append(entity_idx)
                                boxes.append(detected_boxes[det_idx])
                                entities_found[entity_idx] = True

            # Entities that were not detected fall back to the whole image as bounding box
            for entity_idx, found in enumerate(entities_found):
                if not found:
                    indexes.append(entity_idx)
                    height, width = images.sizes[sample_idx]
                    box = torch.tensor(
                        [0.0, 0.0, width - 1, height - 1], device=self._dummy.device
                    )
                    box = box.unsqueeze(0)
                    box = ops.from_xyxy_to_cxcywh(box)
                    box = ops.normalize(
                        box, torch.tensor([[width, height]], device=box.device)
                    )
                    box = box.squeeze(0)
                    boxes.append(box)

            detections[sample_idx] = DetectionResults(
                entities=torch.tensor(indexes, device=self._dummy.device),
                boxes=BBoxes(
                    boxes=torch.stack(boxes),
                    images_size=images.sizes[sample_idx],
                    format=BBoxFormat.CXCYWH,
                    normalized=True,
                ),
            )
        return detections
```
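
The prompt construction and the phrase lookup above can be reproduced without GroundingDINO. The sketch below is a minimal illustration, not the project's code: `build_prompt` and `phrase_span` are hypothetical helpers, and the token positions in the usage example assume a toy tokenization in which every word and every `.` is a single token. With the real tokenizer, the separators are the ids 101, 102 and 1012 filtered in the code above.

```python
import bisect


def build_prompt(heads: list[str]) -> tuple[str, dict[str, list[int]]]:
    """Group entity indices by normalized head noun and build an 'a . b .' prompt."""
    groups: dict[str, list[int]] = {}
    for idx, head in enumerate(heads):
        groups.setdefault(head.lower().strip(), []).append(idx)
    prompt = " . ".join(groups.keys()) + " ."
    return prompt, groups


def phrase_span(sep_idx: list[int], token_idx: int) -> tuple[int, int]:
    """Return the positions of the two separators enclosing token_idx.

    sep_idx must be sorted and token_idx must lie strictly after sep_idx[0].
    """
    pos = bisect.bisect_left(sep_idx, token_idx)
    return sep_idx[pos - 1], sep_idx[pos]


prompt, groups = build_prompt(["Man", "dog", "man "])
print(prompt)  # man . dog .
print(groups)  # {'man': [0, 2], 'dog': [1]}

# Toy tokens: [CLS] man . dog . [SEP] -> separators at positions 0, 2, 4, 5
sep_idx = [0, 2, 4, 5]
print(phrase_span(sep_idx, 1))  # (0, 2): token 1 is "man"
print(phrase_span(sep_idx, 3))  # (2, 4): token 3 is "dog"
```

Grouping repeated head nouns means a single detected phrase can be assigned to several scene-graph entities at once, which is what the matching loop above does with `entities[sample_idx][det_name]`.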

### 3. ERP Decoder

This module is a transformer decoder:
- For each `(entity1, relationship, entity2)` tuple in a sample's SceneGraph, the sample image is masked with the union of the Gaussian heatmaps centered on the detected entities `(entity1, entity2)`. This allows the model to process the original image while focusing on the target visual entities. In the implementation below, the log-heatmaps are applied as an additive attention mask over the image patch embeddings.
- The image is encoded with CLIP's visual backbone `ViT-B/32`, which we found to best capture long-range spatial relationships. We use the output sequence of patch embeddings of the Vision Transformer, while the final CLIP projection is discarded (also for the language backbone).
- The CLIP text embeddings of the entities and of each `relationship` attend to the masked visual embeddings through a stack of cross-attention layers. Recalling the SceneGraph parsing paragraph: a `relationship` is a substring of the original image caption, extracted with a semantic parser.
<br><br>
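
The heatmap masking can be sketched in a few lines. The `GaussianHeatmaps` module used by the `Decoder` is not shown in this report, so the function below is an assumption: it places an axis-aligned Gaussian at the center of each normalized `(cx, cy, w, h)` box, with standard deviations equal to half the box width and height (a hypothetical choice), evaluates it at the centers of an `H x W` patch grid, and combines the heatmaps of an entity pair with an element-wise maximum before taking the logarithm, as the `Decoder` does.

```python
import torch


def gaussian_heatmaps(boxes: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
    """boxes: (K, 4) normalized cxcywh; returns (K, H*W) values in (0, 1]."""
    H, W = size
    # Normalized coordinates of the patch centers
    ys = (torch.arange(H, dtype=torch.float32) + 0.5) / H
    xs = (torch.arange(W, dtype=torch.float32) + 0.5) / W
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")  # (H, W)

    cx, cy, w, h = boxes.unbind(-1)
    # Assumption: sigma equals half the box size (clamped to avoid division by zero)
    sx = (w / 2).clamp_min(1e-3)[:, None, None]
    sy = (h / 2).clamp_min(1e-3)[:, None, None]
    dx = (xx - cx[:, None, None]) ** 2 / (2 * sx**2)
    dy = (yy - cy[:, None, None]) ** 2 / (2 * sy**2)
    return torch.exp(-(dx + dy)).flatten(1)  # (K, H*W)


def edge_attention_bias(
    h1: torch.Tensor, h2: torch.Tensor, h_union: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
    """Union (max) of two entity heatmaps and their enclosing-box heatmap, in log space."""
    return torch.log(torch.maximum(torch.maximum(h1, h2), h_union) + eps)


boxes = torch.tensor([[0.5, 0.5, 0.5, 0.5], [0.2, 0.2, 0.2, 0.2]])
maps = gaussian_heatmaps(boxes, (3, 3))
print(maps.shape)        # torch.Size([2, 9])
print(maps[0].argmax())  # tensor(4): the central patch
```

Adding the log of the heatmap to the attention logits leaves patches near the entities almost untouched (bias close to 0) and pushes distant patches towards a large negative bias.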
The procedure described above is implemented with the aid of a `Graph` data structure:
- A `Graph` represents the relationships between the entities of a scene (i.e., a sample), as described by the image and the caption.
- The `nodes` are the embeddings of all the entities found by the Detector in the image.
- The `edges` are derived from the SceneGraph parsing of the caption; only the edges between detected entities are kept.
<br><br>
```python
@dataclass(frozen=True)
class Graph:
    nodes: Float[Tensor, "N D"]
    edges: Float[Tensor, "M D"]
    edge_index: Int64[Tensor, "2 M"]

    @classmethod
    def new(
        cls,
        caption: Caption,
        embeddings: CaptionEmbeddings,
        detections: DetectionResults,
    ) -> Self:
        
        # The caption carries its parsed scene graph
        graph = caption.graph
        assert graph is not None

        # Nodes: text embeddings of the detected entities (one per detection)
        nodes = embeddings.entities[detections.entities]

        edge_index_list = []
        edge_rel_index_list = []

        for idx, detection in enumerate(detections.entities):
            entity: int = detection.item()

            # Build edges from the scenegraph's connections
            for connection in graph.connections(entity):
                if connection.end < entity:
                    continue

                device = detections.entities[0].device
                tmp = (
                    (detections.entities == connection.end)
                    .nonzero(as_tuple=True)[0]
                    .to(device)
                )
                indexes = torch.cat(
                    [
                        torch.tensor([idx])[None].expand(1, len(tmp)).to(device),
                        tmp[None],
                    ],
                    dim=0,
                )

                edge_index_list.append(indexes)
                edge_rel_index_list.extend([connection.relation] * len(tmp))

        # No relationships found in the caption
        if len(edge_index_list) == 0:
            edge_index = torch.empty(
                (2, 0),
                dtype=torch.long,
                device=detections.entities.device,
            )
        else:
            edge_index = torch.cat(edge_index_list, dim=1)  # (2, M)

        edge_rel_index = torch.tensor(
            edge_rel_index_list,
            dtype=torch.long,
            device=embeddings.relations.device,
        )  # (M,)
        edges = embeddings.relations[edge_rel_index]

        return cls(nodes, edges, edge_index)
```
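
To see how `Graph.new` expands scene-graph connections when an entity is detected more than once, the sketch below rebuilds only the `edge_index` and the relation indices with plain tensors. It is a simplification under stated assumptions: the `connections` dictionary (entity id mapped to a list of `(relation_id, other_entity)` pairs) is a hypothetical stand-in for `graph.connections(entity)`, and we assume connections are stored in both directions, which is why pairs with `other < entity` are skipped as in the original code.

```python
import torch


def build_edges(
    detected: torch.Tensor, connections: dict[int, list[tuple[int, int]]]
) -> tuple[torch.Tensor, torch.Tensor]:
    """detected: (N,) entity id of each detection (node).

    Returns edge_index (2, M) over detection indices and relation ids (M,).
    """
    src: list[int] = []
    dst: list[int] = []
    rel: list[int] = []
    for node, entity in enumerate(detected.tolist()):
        for relation, other in connections.get(entity, []):
            if other < entity:  # each symmetric pair is handled once
                continue
            # Connect this node to every detection of the other entity
            targets = (detected == other).nonzero(as_tuple=True)[0].tolist()
            src += [node] * len(targets)
            dst += targets
            rel += [relation] * len(targets)
    return torch.tensor([src, dst], dtype=torch.long), torch.tensor(rel, dtype=torch.long)


# Entity 0 --(relation 0)--> entity 1; entity 1 was detected twice
detected = torch.tensor([0, 1, 1])
connections = {0: [(0, 1)], 1: [(0, 0)]}
edge_index, relations = build_edges(detected, connections)
print(edge_index.tolist())  # [[0, 0], [1, 2]]
print(relations.tolist())   # [0, 0]
```

One connection in the caption becomes two graph edges here, one per detected box of entity 1, each carrying the same relation embedding.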

<br><br>
A batch of `Graph` objects is organized in a wrapper data structure, namely a `NestedGraph`.
<br><br>
```python
class NestedGraph:
    def __init__(
        self,
        nodes: Float[Tensor, "B N D"],
        edges: Float[Tensor, "B E D"],
        edge_index: Int64[Tensor, "2 BE"],
        sizes: list[tuple[int, int]],
    ) -> None:
        self._nodes = nodes
        self._edges = edges
        self._edge_index = edge_index
        self._sizes = sizes
    (...)
```
<br><br>

To process a batch of `Graph` objects, each graph is padded to the largest number of nodes and edges in the batch, so that the node and edge embeddings can be stacked into dense tensors:
<br><br>
```python
def pad_sequences(
    detections: list[DetectionResults],
    graphs: list[Graph],
) -> tuple[BBoxes, NestedGraph]:
    if len(detections) != len(graphs):
        raise ValueError(
            f"The number of detections ({len(detections)}) must be equal "
            f"to the number of graphs ({len(graphs)})"
        )

    batch = len(detections)
    sizes = [(graph.nodes.shape[0], graph.edges.shape[0]) for graph in graphs]

    # We add one to max_nodes so that every graph, including the largest one,
    # has at least one padding node. Padding edges are self-loops on that node
    # (index graph.nodes.shape[0]): without the extra slot, the padding edges of
    # the graph with the most nodes would point outside its own node slots,
    # causing an index out of bounds error.
    max_nodes = max([nodes for nodes, _ in sizes]) + 1
    max_edges = max([edges for _, edges in sizes])

    padded_boxes = detections[0].boxes.tensor.new_ones(batch * max_nodes, 4)
    images_size = detections[0].boxes.images_size.new_ones(batch * max_nodes, 2)

    nodes = graphs[0].nodes.new_zeros(batch, max_nodes, graphs[0].nodes.shape[1])
    edges = graphs[0].edges.new_zeros(batch, max_edges, graphs[0].edges.shape[1])
    edge_index = graphs[0].edge_index.new_empty(2, batch * max_edges)

    for i, (detection, graph) in enumerate(zip(detections, graphs)):
        nodes[i, : graph.nodes.shape[0]].copy_(graph.nodes)
        edges[i, : graph.edges.shape[0]].copy_(graph.edges)

        # Pad edge index
        start = i * max_edges
        middle = start + graph.edge_index.shape[1]
        end = start + max_edges

        edge_index[:, start:middle].copy_(graph.edge_index + i * max_nodes)
        edge_index[0, middle:end] = graph.nodes.shape[0] + i * max_nodes
        edge_index[1, middle:end] = graph.nodes.shape[0] + i * max_nodes

        # Pad boxes
        start = i * max_nodes
        boxes = detection.boxes.to_cxcywh().normalize()
        end = start + len(boxes)

        padded_boxes[start:end].copy_(boxes.tensor)
        images_size[start:end].copy_(boxes.images_size)

    boxes = BBoxes(padded_boxes, images_size, BBoxFormat.CXCYWH, True)

    return boxes, NestedGraph(nodes, edges, edge_index, sizes)
```
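
The edge-index padding is easiest to check on a toy batch. The sketch below isolates that part of `pad_sequences` (the helper name `pad_edge_index` is hypothetical): node indices of graph `i` are shifted by `i * max_nodes`, and the unused edge slots become self-loops on the first padding node, which exists for every graph thanks to the extra `+ 1` in `max_nodes`.

```python
import torch


def pad_edge_index(
    edge_indices: list[torch.Tensor], num_nodes: list[int]
) -> tuple[torch.Tensor, int, int]:
    """Flatten per-graph (2, E_i) edge indices into one (2, B * max_edges) tensor."""
    batch = len(edge_indices)
    max_nodes = max(num_nodes) + 1  # guarantees one padding node per graph
    max_edges = max(ei.shape[1] for ei in edge_indices)
    out = torch.empty(2, batch * max_edges, dtype=torch.long)
    for i, (ei, n) in enumerate(zip(edge_indices, num_nodes)):
        start = i * max_edges
        middle = start + ei.shape[1]
        end = start + max_edges
        out[:, start:middle] = ei + i * max_nodes  # real edges, shifted
        out[:, middle:end] = n + i * max_nodes     # self-loops on the padding node
    return out, max_nodes, max_edges


# Graph 0: 2 nodes, edge 0->1. Graph 1: 3 nodes, edges 0->1 and 1->2.
g0 = torch.tensor([[0], [1]])
g1 = torch.tensor([[0, 1], [1, 2]])
padded, max_nodes, max_edges = pad_edge_index([g0, g1], [2, 3])
print(max_nodes, max_edges)  # 4 2
print(padded.tolist())       # [[0, 2, 4, 5], [1, 2, 5, 6]]
```

Graph 0 has one padding edge, the self-loop `2 -> 2` on its padding node; graph 1 fills both edge slots, so no self-loop is needed, but its padding node (index 7) still exists.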

Finally, the Decoder module is implemented as follows:
<br><br>
```python
class Decoder(nn.Module):
    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()

        self._num_heads = config.num_heads
        self._hidden_dim = config.hidden_dim
        self._same_entity_edge = nn.Parameter(torch.randn(1, config.hidden_dim))

        # self._patch_encondings = PatchSpatialEncodings(config.hidden_dim)
        self._gaussian_heatmaps = GaussianHeatmaps()

        self._layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.num_layers)]
        )

    def forward(
        self,
        images: Nested4DTensor,
        graph: NestedGraph,
        boxes: BBoxes,
    ) -> NestedGraph:
        H, W = images.shape[2:]

        # (entity1, entity2), edges
        edge_index = graph.edge_index(False)  # (2 BE)

        # Select bboxes that have a connection (entity pairs)
        boxes1 = boxes[edge_index[0]]  # (BE 4)
        boxes2 = boxes[edge_index[1]]  # (BE 4)

        # Compute the union of the boxes and the respective gaussian heatmaps
        union_boxes = boxes1.union(boxes2)  # (BE 4)

        heatmaps = self._gaussian_heatmaps(boxes, (H, W))  # (BN HW)
        union_heatmaps = self._gaussian_heatmaps(union_boxes, (H, W))  # (BE, HW)
        heatmaps1 = heatmaps[edge_index[0]]  # (BE HW)
        heatmaps2 = heatmaps[edge_index[1]]  # (BE HW)

        edge_heatmaps = torch.maximum(
            torch.maximum(heatmaps1, heatmaps2),
            union_heatmaps,
        )  # (BE HW)

        heatmaps = torch.log(heatmaps + 1e-8)  # (BN HW)
        edge_heatmaps = torch.log(edge_heatmaps + 1e-8)  # (BE HW)

        node_heatmaps = heatmaps.view(len(graph), -1, H * W)  # (B N HW)
        edge_heatmaps = edge_heatmaps.view(len(graph), -1, H * W)  # (B E HW)
        heatmaps = torch.cat((node_heatmaps, edge_heatmaps), dim=1)  # (B (N+E) HW)

        flattened_images = images.flatten()  # (B HW D)
        masks = flattened_images.mask.unsqueeze(1).expand(-1, heatmaps.shape[1], -1)
        masks = heatmaps.masked_fill_(masks, -torch.inf)  # (B (N+E) HW)
        masks = masks.repeat_interleave(self._num_heads, dim=0)  # (Bh (N+E) HW), sample-major as nn.MultiheadAttention expects

        nodes = graph.nodes(True)  # (B N D)
        edges = graph.edges(True)  # (B E D)
        x = torch.cat((nodes, edges), dim=1)  # (B (N+E) D)

        # Stack of decoder layers
        for block in self._layers:
            x = block(
                x,
                flattened_images.tensor,
                masks,
            )

        nodes = x[:, : nodes.shape[1]]  # (B N D)
        edges = x[:, nodes.shape[1] :]  # (B E D)

        return graph.new_like(nodes, edges)

    def __call__(
        self, images: Nested4DTensor, graph: NestedGraph, boxes: BBoxes
    ) -> NestedGraph:
        return super().__call__(images, graph, boxes)  # type: ignore


class DecoderLayer(nn.Module):
    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()

        # attention with images
        self._norm1 = nn.LayerNorm(config.hidden_dim)
        self._attn = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self._layerscale1 = LayerScale(config.hidden_dim, config.epsilon_layer_scale)

        # feedforward
        self._norm2 = nn.LayerNorm(config.hidden_dim)
        self._ffn = nn.Sequential(
            nn.Linear(config.hidden_dim, 4 * config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(4 * config.hidden_dim, config.hidden_dim),
        )
        self._layerscale2 = LayerScale(config.hidden_dim, config.epsilon_layer_scale)

    def forward(
        self,
        x: Float[Tensor, "B (N+E) D"],
        images: Float[Tensor, "B HW D"],
        mask: Float[Tensor, "Bh (N+E) HW"],
    ) -> Float[Tensor, "B (N+E) D"]:
        # image attention
        x1 = self._norm1(x)  # (B (N+E) D)
        x1, _ = self._attn(
            x1,
            images,
            images,
            attn_mask=mask,
            need_weights=False,
        )  # (B (N+E) D)
        x1 = x + self._layerscale1(x1)  # (B (N+E) D)

        # feedforward
        x2 = self._norm2(x1)  # (B (N+E) D)
        x2 = self._ffn(x2)  # (B (N+E) D)
        x2 = x1 + self._layerscale2(x2)  # (B (N+E) D)

        return x2  # type: ignore


# taken from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_value: float = 0.1,
        inplace: bool = False,
    ) -> None:
        super().__init__()

        self._inplace = inplace
        self._scale = nn.Parameter(torch.ones(dim) * init_value)

    def forward(self, x: Float[Tensor, "B ..."]) -> Float[Tensor, "B ..."]:
        return x.mul_(self._scale) if self._inplace else x * self._scale
```

<br><br>
Each `DecoderLayer` in the `Decoder` is a pre-norm Transformer decoder block without self-attention: a cross-attention sublayer followed by a feed-forward sublayer, each with a LayerScale-weighted residual connection. For each sample:
- As `Keys, Values`, the CLIP-encoded image patches are used, with the log-heatmaps as an additive attention mask.
- As `Query`, the concatenated nodes and edges of the sample scene graph are used. Nodes and edges are, respectively, the CLIP text embeddings of the detected entity names and of the parsed relationships.
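
A 3D `attn_mask` passed to `torch.nn.MultiheadAttention` must have shape `(B * num_heads, L, S)`, and PyTorch reads it sample-major: the mask at index `b * num_heads + h` is used by head `h` of sample `b`. The sketch below (toy sizes, not the project's code) checks this layout with a mask that lets sample `b` attend only to key `b`: building the per-head mask with `repeat_interleave` yields the intended attention weights, while `repeat` hands masks of one sample to heads of another.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, num_heads, L, S, D = 2, 3, 1, 4, 6

# Additive mask per sample: sample b may only attend to key b
per_sample = torch.full((B, L, S), float("-inf"))
for b in range(B):
    per_sample[b, :, b] = 0.0

attn = nn.MultiheadAttention(D, num_heads, batch_first=True)
q = torch.randn(B, L, D)
kv = torch.randn(B, S, D)


def attention_weights(mask: torch.Tensor) -> torch.Tensor:
    # Attention weights averaged over heads: (B, L, S)
    _, weights = attn(q, kv, kv, attn_mask=mask, need_weights=True)
    return weights


good = attention_weights(per_sample.repeat_interleave(num_heads, dim=0))
bad = attention_weights(per_sample.repeat(num_heads, 1, 1))
print(good[:, 0])  # each sample puts all its weight on its own key
print(bad[:, 0])   # heads received masks belonging to the other sample
```

With `repeat`, the per-head masks are ordered `[s0, s1, s0, s1, ...]`, so as soon as both `B` and `num_heads` exceed one, some heads of each sample are masked with another sample's heatmaps.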

### 4. Regression Head

The last module is a Transformer encoder:
- For each sample graph, a sinusoidal positional encoding is added to the node and edge embeddings.
- The language-informed (graph) visual tokens of each scene attend to each other. A learnable `[REG]` token is prepended to the input sequence to capture the scene semantics.
- The output `[REG]` token is projected by an MLP with a final sigmoid to the normalized `(cx, cy, w, h)` coordinates of the predicted bounding box.
- Self-attention is computed for the whole batch at once, with a key padding mask so that each token only attends to the non-padded tokens of its own graph.
<br><br>
```python
class RegressionHead(nn.Module):
    def __init__(self, config: HeadConfig) -> None:
        super().__init__()

        self._dim = config.hidden_dim
        self._token = nn.Parameter(torch.randn(1, 1, config.hidden_dim))

        self._layers = nn.ModuleList(
            [ResidualAttentionLayer(config) for _ in range(config.num_layers)]
        )

        self._regression_head = nn.Sequential(
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, 4),
            nn.Sigmoid(),
        )

    def _positional_encoding(self, mask: Bool[Tensor, "B L"]) -> Float[Tensor, "B L D"]:
        """ Sentence-wise positional embedding (cosine) """
        not_mask = ~mask  # (B, L)
        embed = not_mask.cumsum(dim=1, dtype=torch.float32)  # (B, L)

        dim_t = torch.arange(self._dim, dtype=torch.float32, device=mask.device)
        dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self._dim)

        pos = embed[:, :, None] / dim_t  # (B, L, D)
        pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=3)
        pos = pos.view(pos.shape[0], pos.shape[1], -1)  # (B, L, D)

        return pos

    def forward(self, graph: NestedGraph, images_size: list[tuple[int, int]]) -> BBoxes:

        B = len(graph)
        nodes = graph.nodes(True)  # (B, N, D)
        edges = graph.edges(True)  # (B, E, D)
        # Each graph is a concatenation of nodes and edges
        tokens = torch.cat([nodes, edges], dim=1)  # (B, N+E, D)
        N, E = nodes.shape[1], edges.shape[1]

        # Key padding mask: True marks padded positions; the [REG] token is always kept
        mask = nodes.new_ones((B, 1 + N + E), dtype=torch.bool)  # (B, 1+N+E)
        for idx, (num_nodes, num_edges) in enumerate(graph.sizes):
            mask[idx, 0] = False
            mask[idx, 1 : 1 + num_nodes] = False
            mask[idx, 1 + N : 1 + N + num_edges] = False

        # Apply positional encoding (entity-wise)
        tokens = tokens + self._positional_encoding(mask[:, 1:])  # (B, N+E, D)

        # Concatenate all node-edges
        # concatenate also the learnable [REG] token
        x = torch.cat([self._token.expand(B, -1, -1), tokens], dim=1)  # (B, 1+N+E, D)

        # Masked self attention
        for layer in self._layers:
            x = layer(x, mask)

        # Regression on the learnable token
        token = self._regression_head(x[:, 0])  # (B, 4)
        boxes = BBoxes(token, images_size, BBoxFormat.CXCYWH, True)  # (B, 4)

        return boxes

    def __call__(
        self, graph: NestedGraph, images_size: list[tuple[int, int]]
    ) -> BBoxes:
        return super().__call__(graph, images_size)  # type: ignore


class ResidualAttentionLayer(nn.Module):
    def __init__(self, config: HeadConfig) -> None:
        super().__init__()

        self._norm1 = nn.LayerNorm(config.hidden_dim)
        self._attn = nn.MultiheadAttention(
            embed_dim=config.hidden_dim,
            num_heads=config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self._dropout1 = nn.Dropout(config.dropout)

        self._norm2 = nn.LayerNorm(config.hidden_dim)
        self._ffn = nn.Sequential(
            nn.Linear(config.hidden_dim, 4 * config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(4 * config.hidden_dim, config.hidden_dim),
        )
        self._dropout2 = nn.Dropout(config.dropout)

    def forward(
        self, x: Float[Tensor, "B L D"], mask: Bool[Tensor, "B L"]
    ) -> Float[Tensor, "B L D"]:
        x1 = self._norm1(x)
        x1 = self._attn(x1, x1, x1, key_padding_mask=mask, need_weights=False)[0]
        x1 = self._dropout1(x1)
        x1 = x + x1

        x2 = self._norm2(x1)
        x2 = self._ffn(x2)
        x2 = self._dropout2(x2)
        x2 = x1 + x2

        return x2  # type: ignore
```
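
The masking and positional-encoding logic of `RegressionHead` can be exercised in isolation. The sketch below re-implements both on toy sizes (the helper names are hypothetical): the key padding mask marks padded nodes and edges with `True` while always keeping the `[REG]` token, and positions are counted only over non-padded tokens before being mapped to interleaved sine/cosine features, as in `_positional_encoding`.

```python
import torch


def padding_mask(sizes: list[tuple[int, int]], N: int, E: int) -> torch.Tensor:
    """(B, 1+N+E) mask; True = padding. Layout: [REG], nodes, edges."""
    mask = torch.ones(len(sizes), 1 + N + E, dtype=torch.bool)
    for i, (num_nodes, num_edges) in enumerate(sizes):
        mask[i, 0] = False
        mask[i, 1 : 1 + num_nodes] = False
        mask[i, 1 + N : 1 + N + num_edges] = False
    return mask


def sinusoidal_encoding(mask: torch.Tensor, dim: int) -> torch.Tensor:
    """(B, L) padding mask -> (B, L, dim) interleaved sin/cos encoding."""
    positions = (~mask).cumsum(dim=1, dtype=torch.float32)  # 1-based positions
    i = torch.arange(dim, dtype=torch.float32)
    dim_t = 10000 ** (2 * torch.div(i, 2, rounding_mode="floor") / dim)
    pos = positions[:, :, None] / dim_t  # (B, L, dim)
    pe = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=3)
    return pe.flatten(2)  # (B, L, dim)


mask = padding_mask([(2, 1), (1, 0)], N=2, E=1)
print(mask.int().tolist())  # [[0, 0, 0, 0], [0, 0, 1, 1]]
pe = sinusoidal_encoding(mask[:, 1:], dim=8)
print(pe.shape)             # torch.Size([2, 3, 8])
```

Padded tokens repeat the position of the last real token, which is harmless because the key padding mask prevents any token from attending to them.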

### Architecture overview

The described modules are combined according to the following pipeline:
- Each batch of `(image, caption)` pairs is fed to the Detector, which returns the `DetectionResults`, and is encoded with CLIP to extract the sequences of image and text embeddings.
- The detections and the text embeddings are assembled into one `Graph` per sample, and the graphs are padded into a `NestedGraph` with `pad_sequences`.
- The batch of graphs, together with the image embeddings and the padded boxes, is fed to the `Decoder` to fuse the two modalities.
- Finally, the batch of graphs is fed to the `RegressionHead`, where all nodes and edges within each graph attend to each other and the bounding box is regressed.
<br><br>
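
The data flow can be traced with random tensors. The sketch below is a shape-level approximation with hypothetical sizes: single `nn.MultiheadAttention` layers stand in for the `Decoder` and `RegressionHead` stacks (no normalization, LayerScale, feed-forward or padding masks), a zero attention bias replaces the log-heatmaps, and a zero vector replaces the learnable `[REG]` token.

```python
import torch
from torch import nn

B, N, E, D, HW, heads = 2, 3, 2, 64, 49, 4  # hypothetical sizes
nodes = torch.randn(B, N, D)     # entity text embeddings
edges = torch.randn(B, E, D)     # relation text embeddings
patches = torch.randn(B, HW, D)  # image patch embeddings
bias = torch.zeros(B * heads, N + E, HW)  # stand-in for the log-heatmap mask

# Decoder: graph tokens (queries) attend to image patches (keys/values)
cross = nn.MultiheadAttention(D, heads, batch_first=True)
x = torch.cat([nodes, edges], dim=1)  # (B, N+E, D)
x = x + cross(x, patches, patches, attn_mask=bias, need_weights=False)[0]

# Regression head: [REG] + graph tokens attend to each other
self_attn = nn.MultiheadAttention(D, heads, batch_first=True)
reg = torch.zeros(B, 1, D)
y = torch.cat([reg, x], dim=1)  # (B, 1+N+E, D)
y = y + self_attn(y, y, y, need_weights=False)[0]

# Project the [REG] token to a normalized (cx, cy, w, h) box
head = nn.Sequential(nn.Linear(D, 4), nn.Sigmoid())
boxes = head(y[:, 0])  # (B, 4)
print(boxes.shape)  # torch.Size([2, 4])
```
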
The final model is implemented in `src/visgator/models/erpa/_model.py`:

```python
class Model(_Model[BBoxes]):
    def __init__(self, config: Config) -> None:
        super().__init__()

        self._transform = Compose(
            [Resize(800, max_size=1333, p=1.0)],
            p=1.0,
        )

        self._criterion = Criterion(config.criterion)
        self._postprocessor = PostProcessor()

        # Encoders
        self._vision, self._text = build_encoders(config.encoders)
        self._decoder = Decoder(config.decoder)

        # Detector
        self._gdino = None
        self._owlvit = None
        if config.detector.gdino is not None:
            self._gdino = GroundigDINODetector(config.detector)
        elif config.detector.owlvit is not None:
            self._owlvit = OwlViTDetector(config.detector)

        # Regression Head
        self._head = RegressionHead(config.head)
    (...)

    def forward(self, batch: Batch) -> BBoxes:

        # Preprocess the images for the detector
        images = Nested4DTensor.from_tensors(
            [self._transform(sample.image) for sample in batch.samples]
        )
        img_tensor = images.tensor / 255.0
        images = Nested4DTensor(img_tensor, images.sizes, images.mask)

        # Detection results extraction
        if self._gdino is not None:
            detections = self._gdino(
                images, [sample.caption for sample in batch.samples]
            )
        elif self._owlvit is not None:
            detections = self._owlvit(batch)
        else:
            raise RuntimeError("No detector is initialized.")

        # CLIP encoded img+text
        img_embeddings = self._vision(images)
        text_embeddings = self._text(batch)

        # Constructing the batch graphs with entity embeddings
        graphs = [
            Graph.new(batch.samples[idx].caption, text_embeddings[idx], detections[idx])
            for idx in range(len(batch))
        ]

        boxes, graph = pad_sequences(detections, graphs)
        
        # ERP-Cross Attention (img+text)
        graph = self._decoder(img_embeddings, graph, boxes)

        # ERP-Self Attention and bbox regression
        boxes = self._head(graph, img_embeddings.sizes)
        
        if not boxes.normalized:
            raise RuntimeError("Boxes must be normalized.")

        boxes = BBoxes(
            boxes.tensor,
            [tuple(sample.image.shape[1:]) for sample in batch],  # type: ignore
            boxes.format,
            True,
        )

        return boxes
```

## Experiments

### Challenges
- When the object detector does not find an entity, the whole image is used as that entity's bounding box.
- OWL-ViT and YOLO are limited for this task (YOLO, in particular, is a closed-set detector), while GroundingDINO is computationally expensive.
- The dataset is noisy: it contains grammatical errors and phrases with incomplete syntax. This leads to low-quality entities being parsed from the text and thus to a degraded textual grounding. The model may overfit the training set to reach high accuracy, at the cost of generalization.
- Memory leaks: fixing them required debugging and meticulous tracking of the computational graph to make sure that no residual references remained. Garbage collection is also called after every backward pass.
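
As a worked example of the whole-image fallback: the detector code builds the box `[0, 0, W - 1, H - 1]` in `xyxy` pixels, converts it to `cxcywh` and normalizes it by `(W, H)`. The function below reproduces that arithmetic in plain Python, assuming that `ops.from_xyxy_to_cxcywh` computes `w = x1 - x0` and `h = y1 - y0` (the helper is not shown in this report).

```python
def whole_image_box(width: int, height: int) -> tuple[float, float, float, float]:
    """Normalized (cx, cy, w, h) of the fallback box [0, 0, width - 1, height - 1]."""
    x0, y0, x1, y1 = 0.0, 0.0, width - 1.0, height - 1.0
    cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
    w, h = x1 - x0, y1 - y0
    return cx / width, cy / height, w / width, h / height


print(whole_image_box(640, 480))  # approximately (0.4992, 0.4990, 0.9984, 0.9979)
```

The fallback box is centered on the image and covers almost all of it, so it carries no localization information for the entity.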

<hr>

### [Drafts and miscellaneous notes below]

### Preliminaries: downloading data & code

```python
# Downloading the dataset
# !pip -q install gdown
# !gdown 1hxk0f62WtczYGp_zMBE_SQuqfsUvSNld

# Extracting the dataset locally
# !apt-get install unrar
# !unrar x -vb refcocog.rar
```

```python
# Alternatively, mounting a Drive folder with data
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/"Deep Learning 2023"
# !unrar x -vb refcocog.rar
```



```python
# Repo
!rm -rf visgator
!git clone https://github.com/FrancescoGentile/visgator
%cd visgator
!git checkout detector
!git pull
!pip3 install -q -e .
```


```python
import json
from pathlib import Path
from ruamel.yaml import YAML

# Loading the model configuration file from the repo
config_path = Path("config/local.yaml")
extension = config_path.suffix

match extension:
    case ".json":
        with open(config_path, "r") as f:
            cfg = json.load(f)
    case ".yaml":
        yaml = YAML(typ="safe")
        cfg = yaml.load(config_path)
    case _:
        raise ValueError(f"Unknown config file extension: {extension}.")
```

### 2. Instantiation

As shown below, training starts with the instantiation of a `visgator.engines.Trainer` object, followed by a call to `Trainer.run()`. Under the hood, the method `Trainer._train_epoch(epoch)` is invoked once per epoch, for `num_epochs` epochs:

```python
def _train_epoch(self, epoch: int) -> None:
        self._logger.info(f"Training epoch {epoch + 1} started.")

        start = timer()

        self._model.train()
        self._postprocessor.train()
        self._criterion.train()

        self._tl_tracker.increment()
        self._tm_tracker.increment()
        self._optimizer.zero_grad()

        counter = tqdm(
            desc="Training",
            total=self._get_steps_per_epoch() * self._params.train_batch_size,
        )

        with counter as progress_bar:
            batch: Batch
            bboxes: BBoxes
            device_type = "cuda" if self._device.is_cuda else "cpu"
            for idx, (batch, bboxes) in enumerate(self._train_loader):
                # since the dataloader batch size is equal to true batch size
                # // gradient accumulation steps, the last samples in the dataloader
                # may not be enough to fill the true batch size, so we break the loop
                # for example, if the true batch size is 8, the gradient accumulation
                # steps is 4 and the dataset size is 50, the last 2 samples will be
                # ignored
                if progress_bar.total == progress_bar.n:
                    break

                batch = batch.to(self._device.to_torch())
                bboxes = bboxes.to(self._device.to_torch())

                with autocast(device_type, enabled=self._params.mixed_precision):
                    outputs = self._model(batch)
                    tmp_losses = self._criterion(outputs, bboxes)
                    losses = self._tl_tracker(tmp_losses)
                    loss = losses.total / self._params.gradient_accumulation_steps

                self._scaler.scale(loss).backward()

                if (idx + 1) % self._params.gradient_accumulation_steps == 0:
                    if self._params.max_grad_norm is not None:
                        self._scaler.unscale_(self._optimizer)
                        torch.nn.utils.clip_grad_norm_(  # type: ignore
                            self._model.parameters(), self._params.max_grad_norm
                        )

                    self._scaler.step(self._optimizer)
                    self._scaler.update()
                    self._optimizer.zero_grad()
                    self._lr_scheduler.step_after_batch()

                with torch.no_grad():
                    pred_bboxes = self._postprocessor(outputs)
                    self._tm_tracker.update(
                        pred_bboxes.to_xyxy().normalize().tensor,
                        bboxes.to_xyxy().normalize().tensor,
                    )

                progress_bar.update(len(batch))
                
                del batch
                del bboxes
                del outputs
                del loss

                # Release cached GPU memory after each step
                if self._device.is_cuda:
                    torch.cuda.empty_cache()

        self._lr_scheduler.step_after_epoch()

        end = timer()
        elapsed = end - start

        self._logger.info(f"Training epoch {epoch + 1} finished.")
        self._log_statistics(epoch, elapsed, train=True)
```
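
Dividing each loss by `gradient_accumulation_steps` is what makes accumulation equivalent to a larger batch. The sketch below (a toy `nn.Linear` model with mean-squared error, not the project's trainer) checks that the gradient accumulated over equal-sized micro-batches matches the gradient of the full batch; mixed precision and gradient clipping, handled by the scaler and `clip_grad_norm_` in the loop above, are omitted.

```python
import torch
from torch import nn


def accumulated_grads(model: nn.Module, x: torch.Tensor, y: torch.Tensor, steps: int):
    """Accumulate gradients over `steps` equal micro-batches, scaling each loss."""
    model.zero_grad()
    for xb, yb in zip(x.chunk(steps), y.chunk(steps)):
        loss = nn.functional.mse_loss(model(xb), yb) / steps
        loss.backward()  # gradients add up in .grad
    return [p.grad.clone() for p in model.parameters()]


def full_batch_grads(model: nn.Module, x: torch.Tensor, y: torch.Tensor):
    """Gradients of the mean loss over the whole batch in a single backward pass."""
    model.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
acc = accumulated_grads(model, x, y, steps=4)
full = full_batch_grads(model, x, y)
print(all(torch.allclose(a, f, atol=1e-6) for a, f in zip(acc, full)))  # True
```

The equivalence requires equal-sized micro-batches, which is one reason why the loop above drops the trailing samples that cannot fill a whole true batch.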

The ERP-Attention model is invoked in `visgator.engines.trainer.Trainer._train_epoch()` and is implemented in `visgator.models.erpa.Model.forward()` as follows:

```python
def forward(self, batch: Batch) -> BBoxes:
        images = Nested4DTensor.from_tensors(
            [self._transform(sample.image) for sample in batch.samples]
        )
        img_tensor = images.tensor / 255.0
        images = Nested4DTensor(img_tensor, images.sizes, images.mask)
        detections = self._detector((batch, images), (self._model, self._tokenizer))

        # CLIP encoded img+text
        img_embeddings = self._vision(images)
        text_embeddings = self._text(batch)

        # Constructing the batch graphs with entity embeddings
        graphs = [
            Graph.new(batch.samples[idx].caption, text_embeddings[idx], detections[idx])
            for idx in range(len(batch))
        ]

        boxes, graph = pad_sequences(detections, graphs)

        # ERP-Cross Attention (img+text)
        graph = self._decoder(img_embeddings, graph, boxes)

        # ERP-Self Attention and bbox regression
        boxes = self._head(graph, img_embeddings.sizes)
        if not boxes.normalized:
            raise RuntimeError("Boxes must be normalized.")

        boxes = BBoxes(
            boxes.tensor,
            [tuple(sample.image.shape[1:]) for sample in batch],  # type: ignore
            boxes.format,
            True,
        )

        return boxes
```

The image and the caption of each `(image, caption)` pair are encoded with the vision and the text backbone of CLIP (ViT-B/32), respectively. Note that the final projections into the common embedding space are discarded (`visgator.models.erpa._model.forward()`).

In parallel, each `(image, caption)` pair is also processed with a detector (e.g. [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO)/[YOLOv8](https://github.com/ultralytics/ultralytics)/[OwL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)), implemented in `visgator.models.erpa._detector.Detector`, to generate region proposals for each entity; the proposals are then used to build the graphs (`visgator.models.erpa._misc.Graph.new()`).

The per-entity region proposals, the text embeddings and the caption are combined into a `NestedGraph`, that is, the set of scene graphs of the training batch. The padded region proposals are stored in `visgator.utils.bbox.BBoxes` objects, which provide convenient geometric operations (e.g., box union and format conversion).
Finally, the `NestedGraph` is passed to the ERP-Decoder, along with the formatted bounding boxes and the image embeddings.
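
The `BBoxes` class is not listed in this report; assuming that `BBoxes.union` (used by the `Decoder`) returns the smallest box enclosing both inputs, the geometry reduces to a coordinate-wise min/max in `xyxy` format. A minimal sketch on normalized `cxcywh` tensors, with hypothetical helper names:

```python
import torch


def cxcywh_to_xyxy(b: torch.Tensor) -> torch.Tensor:
    cx, cy, w, h = b.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)


def xyxy_to_cxcywh(b: torch.Tensor) -> torch.Tensor:
    x0, y0, x1, y1 = b.unbind(-1)
    return torch.stack([(x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0], dim=-1)


def enclosing_box(b1: torch.Tensor, b2: torch.Tensor) -> torch.Tensor:
    """Smallest cxcywh box containing both inputs (assumed semantics of union)."""
    a, b = cxcywh_to_xyxy(b1), cxcywh_to_xyxy(b2)
    top_left = torch.minimum(a[..., :2], b[..., :2])
    bottom_right = torch.maximum(a[..., 2:], b[..., 2:])
    return xyxy_to_cxcywh(torch.cat([top_left, bottom_right], dim=-1))


b1 = torch.tensor([[0.25, 0.25, 0.5, 0.5]])  # top-left quadrant
b2 = torch.tensor([[0.75, 0.75, 0.5, 0.5]])  # bottom-right quadrant
print(enclosing_box(b1, b2).tolist())  # [[0.5, 0.5, 1.0, 1.0]]
```
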


The invocation of `self._detector(...)` is implemented in `visgator.models.erpa.Detector.forward()`:

```python
def forward(self, data: tuple, model: tuple) -> list[DetectionResults]:
        # partially taken from: https://huggingface.co/docs/transformers/model_doc/owlvit

        batch, nested_images = data
        clip, tokenizer = model

        # Preprocessing: convert images to PIL for the OWL-ViT processor
        images = [self._toPIL(sample.image) for sample in batch.samples]
        captions = [sample.caption for sample in batch.samples]
        
        B = len(captions)

        # Extracting graph entities
        entities: list[list[str]] = [None] * B  # type: ignore
        for i, caption in enumerate(captions):
            graph = caption.graph
            assert graph is not None
            entities[i] = [entity.head.lower().strip() for entity in graph.entities]

        # Object detection (open-vocabulary)
        with torch.no_grad():
            inputs = self._detector_processor(text=entities, images=images, return_tensors="pt").to(self.device)
            detector_results = self._detector_model(**inputs)

            # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
            target_sizes = torch.Tensor([image.size for image in images]).to(self.device)
            # Convert outputs (bounding boxes and class logits) to COCO API
            results = self._detector_processor.post_process_object_detection(outputs=detector_results, target_sizes=target_sizes)

        # For each result
        detections: list[DetectionResults] = [None] * B  # type: ignore

        
        for sample_idx in range(B):
            boxes, scores, labels = results[sample_idx]["boxes"], results[sample_idx]["scores"], results[sample_idx]["labels"]
            
            matched_indices = []
            matched_boxes = []
            height, width = images[sample_idx].size

            # Check identified identities by score
            for j, (box, score, label) in enumerate(zip(boxes, scores, labels)):
                box = [round(i, 2) for i in box.tolist()]
                if score >= self._detection_threshold:
                    matched_boxes.append(torch.tensor(box))
                    matched_indices.append(label)
                # not detected => suppose the entire image
                else:
                    matched_boxes.append(torch.tensor([0, 0, width-1, height-1]).to(self.device))
                    matched_indices.append(j) # entity index
               
            # if the detector hasn't identified an object => whole image as bounding box
            if len(boxes) == 0:
                for entity_idx, entity in enumerate(entities[sample_idx]):
                    matched_indices.append(entity_idx)
                    matched_boxes.append(torch.tensor([0, 0, width-1, height-1]).to(self.device))
            

            boxes = BBoxes(
                boxes=torch.stack(matched_boxes).to(self.device),
                images_size=images[sample_idx].size,
                format=BBoxFormat.XYXY,
                normalized=False,
            ).to_cxcywh().normalize()

            detections[sample_idx] = DetectionResults(
                entities=torch.tensor(matched_indices, device=self.device, dtype=torch.int),
                boxes=boxes,
            )

        del inputs
        del detector_results
        del target_sizes
        del results
        torch.cuda.empty_cache()

        return detections
```
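Two conventions in the detector step are easy to mix up: PIL's `Image.size` is `(width, height)`, while `post_process_object_detection` expects `target_sizes` rows as `(height, width)`; the boxes are then converted from pixel XYXY to normalized CXCYWH. The sketch below (a hypothetical helper, assuming normalization divides x-coordinates by the width and y-coordinates by the height) shows the conversion and what happens when the two axes are swapped.

```python
def xyxy_to_normalized_cxcywh(
    box: tuple[float, float, float, float],
    image_size_hw: tuple[int, int],
) -> tuple[float, float, float, float]:
    """Pixel (x1, y1, x2, y2) -> normalized (cx, cy, w, h); x divided by width, y by height."""
    height, width = image_size_hw
    x1, y1, x2, y2 = box
    return (
        (x1 + x2) / 2 / width,
        (y1 + y2) / 2 / height,
        (x2 - x1) / width,
        (y2 - y1) / height,
    )


pil_size = (200, 100)  # what PIL's Image.size would return: width=200, height=100
box = (50, 25, 150, 75)

print(xyxy_to_normalized_cxcywh(box, pil_size[::-1]))  # (0.5, 0.5, 0.5, 0.5)
print(xyxy_to_normalized_cxcywh(box, pil_size))        # (1.0, 0.25, 1.0, 0.25), axes swapped
```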

### 3. Decoding


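In each ERP, a Gaussian heatmap is computed for every entity bounding box. For each edge (a pair of related entities), the heatmaps of the two entities and of their union box are merged with an element-wise maximum. The logarithms of these node and edge heatmaps, with padded image positions set to `-inf`, form the attention masks over the input image (`visgator.models.erpa.Decoder.forward()`). A stack of attention layers then lets the node and edge embeddings of the NestedGraph attend to the visual tokens, and the output is split back into the nodes and edges of the graph.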
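A framework-free sketch of the node and edge heatmaps on a small grid. The body of `_gaussian_heatmaps` is not shown in this report, so the parametrisation below (standard deviations equal to half the box width and height, evaluated at cell centres) is a hypothetical choice; the edge map follows the decoder's element-wise maximum of the two entity heatmaps and the union-box heatmap.

```python
import math

Box = tuple[float, float, float, float]  # normalized (cx, cy, w, h)


def gaussian_heatmap(box: Box, grid_hw: tuple[int, int]) -> list[float]:
    """Flattened H*W heatmap peaking (value 1) at the box centre.

    Hypothetical parametrisation: sigma_x = w / 2, sigma_y = h / 2.
    """
    cx, cy, w, h = box
    H, W = grid_hw
    sx, sy = max(w / 2, 1e-6), max(h / 2, 1e-6)
    heat = []
    for i in range(H):
        y = (i + 0.5) / H  # normalized y of the cell centre
        for j in range(W):
            x = (j + 0.5) / W  # normalized x of the cell centre
            heat.append(math.exp(-((x - cx) ** 2 / (2 * sx**2) + (y - cy) ** 2 / (2 * sy**2))))
    return heat


def union_box(a: Box, b: Box) -> Box:
    """Smallest normalized (cx, cy, w, h) box enclosing both a and b."""
    x1 = min(a[0] - a[2] / 2, b[0] - b[2] / 2)
    y1 = min(a[1] - a[3] / 2, b[1] - b[3] / 2)
    x2 = max(a[0] + a[2] / 2, b[0] + b[2] / 2)
    y2 = max(a[1] + a[3] / 2, b[1] + b[3] / 2)
    return ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)


grid = (4, 4)
subj = (0.375, 0.375, 0.25, 0.25)  # e.g. "man"
obj = (0.625, 0.625, 0.25, 0.25)   # e.g. "horse"

h_subj = gaussian_heatmap(subj, grid)
h_obj = gaussian_heatmap(obj, grid)
h_union = gaussian_heatmap(union_box(subj, obj), grid)

# Edge heatmap: element-wise max, i.e. the fuzzy union of the three maps
h_edge = [max(a, b, c) for a, b, c in zip(h_subj, h_obj, h_union)]
```

The maximum keeps the edge map high wherever either entity, or the region between them, lies, which is where a relation such as "riding" is expected to be visible.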

```python
def forward(
    self,
    images: Nested4DTensor,
    graph: NestedGraph,
    boxes: BBoxes,
) -> NestedGraph:
    H, W = images.shape[2:]

    # Edges as (entity1, entity2) index pairs
    edge_index = graph.edge_index(False)  # (2 BE)

    # Select the bboxes of the entities connected by each edge
    boxes1 = boxes[edge_index[0]]  # (BE 4)
    boxes2 = boxes[edge_index[1]]  # (BE 4)

    union_boxes = boxes1.union(boxes2)  # (BE 4)

    heatmaps = self._gaussian_heatmaps(boxes, (H, W))  # (BN HW)
    union_heatmaps = self._gaussian_heatmaps(union_boxes, (H, W))  # (BE HW)
    heatmaps1 = heatmaps[edge_index[0]]  # (BE HW)
    heatmaps2 = heatmaps[edge_index[1]]  # (BE HW)

    # Edge heatmap: element-wise max of both entity heatmaps and the union-box heatmap
    edge_heatmaps = torch.maximum(
        torch.maximum(heatmaps1, heatmaps2),
        union_heatmaps,
    )  # (BE HW)

    # Log-heatmaps are used as additive attention biases (log(1e-8) ~ -18.4 where heat is 0)
    heatmaps = torch.log(heatmaps + 1e-8)  # (BN HW)
    edge_heatmaps = torch.log(edge_heatmaps + 1e-8)  # (BE HW)

    node_heatmaps = heatmaps.view(len(graph), -1, H * W)  # (B N HW)
    edge_heatmaps = edge_heatmaps.view(len(graph), -1, H * W)  # (B E HW)
    heatmaps = torch.cat((node_heatmaps, edge_heatmaps), dim=1)  # (B (N+E) HW)

    # Padded image positions are excluded from attention with -inf
    flattened_images = images.flatten()  # (B HW D)
    masks = flattened_images.mask.unsqueeze(1).expand(-1, heatmaps.shape[1], -1)
    masks = heatmaps.masked_fill_(masks, -torch.inf)  # (B (N+E) HW)
    # One mask per attention head. `repeat` tiles the batch head-major (index h*B + b);
    # this must match the attention blocks' layout (torch.nn.MultiheadAttention, for
    # instance, expects batch-major order, i.e. repeat_interleave along dim 0)
    masks = masks.repeat(self._num_heads, 1, 1)  # (Bh (N+E) HW)

    # Node and edge embeddings are the queries attending the visual tokens
    nodes = graph.nodes(True)  # (B N D)
    edges = graph.edges(True)  # (B E D)
    x = torch.cat((nodes, edges), dim=1)  # (B (N+E) D)

    for block in self._layers:
        x = block(
            x,
            flattened_images.tensor,
            masks,
        )

    nodes = x[:, : nodes.shape[1]]  # (B N D)
    edges = x[:, nodes.shape[1] :]  # (B E D)

    return graph.new_like(nodes, edges)
```
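Why the log-heatmaps work as masks: assuming the attention blocks add the float mask to the attention logits (as `torch.nn.MultiheadAttention` does with a float `attn_mask`), adding $\log h_k$ to score $s_k$ gives $\mathrm{softmax}(s + \log h)_k = h_k e^{s_k} / \sum_j h_j e^{s_j}$, i.e. the ordinary attention weights reweighted by the heatmap and renormalized, while `-inf` on padded positions gives them exactly zero weight. The toy numbers below (one query, four image tokens) are illustrative only.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


# One query (a node or edge token) attending four image tokens; the last one is padding
scores = [0.2, 1.0, -0.5, 0.3]  # q . k / sqrt(d)
heat = [1.0, 0.5, 0.1, 1.0]     # heatmap values at the four tokens
padding = [False, False, False, True]

# Additive mask as in the decoder: log(heatmap + 1e-8), -inf on padded positions
bias = [-math.inf if pad else math.log(h + 1e-8) for h, pad in zip(heat, padding)]
weights = softmax([s + b for s, b in zip(scores, bias)])

# The same weights written multiplicatively: heatmap-reweighted, renormalized attention
raw = [0.0 if pad else h * math.exp(s) for s, h, pad in zip(scores, heat, padding)]
expected = [r / sum(raw) for r in raw]
```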

### Training ERP-A

```python
from typing import Any

from visgator.engines.trainer import Config as TrainerConfig
from visgator.engines.trainer import Trainer

# `cfg`: the configuration dictionary built in an earlier cell
train_config = TrainerConfig.from_dict(cfg)
trainer: Trainer[Any] = Trainer(train_config)
trainer.run()
```