## Referring Expression Comprehension as Scene Graph Grounding

### Authors

Diego Calanzone, Francesco Gentile <br>
University of Trento <br>
Deep Learning course project, Spring 2023

### Notes on the code

For this project we developed a complete framework that can be easily extended to many computer vision tasks by simply defining the necessary metrics, inputs and targets. This made it extremely easy to test different models without requiring any modification to parts of the code unrelated to the model.

In this notebook, we do not report all the infrastructure, since it is not the relevant part of the project. We only report the code snippets associated with our proposed model. The complete framework can be found at https://github.com/FrancescoGentile/visgator in the `deepsight` branch.

The structure of the project is heavily inspired by that of [detrex](https://github.com/IDEA-Research/detrex) and is:

```
datasets/ # here you can find all the datasets (at the moment only RefCOCOg)
deepsight/ # this contains the framework source code
  data/
    structs/ # dataclasses used to model the input and output of different tasks
    datasets/ # interface to be implemented by all datasets
    transformations/ # data augmentation transforms
  engines/ # contains the trainer and tester
  modeling/
    layers/ # reusable layers
    detectors/ # wrappers around available object detectors
    parsers/ # modules to extract scene graphs from sentences
    pipeline/ # contains the interfaces for the creation of new models
  measures/ # contains losses and metrics
  optimizers/
  lr_schedulers/
  utils/
projects/ # contains the implementation of the proposed model and the baseline
```

### Preliminaries

In [None]:
!git clone https://github.com/FrancescoGentile/visgator
%cd visgator
!git checkout deepsight

Cloning into 'visgator'...
remote: Enumerating objects: 1264, done.
remote: Counting objects: 100% (434/434), done.
remote: Compressing objects: 100% (277/277), done.
remote: Total 1264 (delta 153), reused 377 (delta 128), pack-reused 830
Receiving objects: 100% (1264/1264), 1.90 MiB | 24.31 MiB/s, done.
Resolving deltas: 100% (655/655), done.
/content/visgator
Branch 'deepsight' set up to track remote branch 'deepsight' from 'origin'.
Switched to a new branch 'deepsight'


In [None]:
# Download and extract the RefCOCOg dataset
!pip install -q gdown
!gdown 1hxk0f62WtczYGp_zMBE_SQuqfsUvSNld
!apt-get install unrar
!unrar x refcocog.rar

In [None]:
!gdown 1HA0GbP0MXMm8UOC-ties9DtFoI_NUt9l

In [None]:
!rm refcocog.rar
!mkdir data
!mv refcocog data
!mv scene_graphs.json data/refcocog/annotations

In [None]:
!pip install -q torch==2.0.1 torchvision==0.15.2
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
!pip install -q jaxtyping torchmetrics pyserde numpy rustworkx transformers scikit-learn ruamel-yaml wandb albumentations openai ultralytics
!pip install -q SceneGraphParser diaparser
!python -m spacy download en_core_web_sm

In [None]:
!pip install -e .

Obtaining file:///content/visgator
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Installing backend dependencies ... done
  Preparing editable metadata (pyproject.toml) ... done
Collecting numpy>=1.25.0 (from deepsight==0.1.0+editable)
  Downloading numpy-1.25.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.6 MB)
Collecting scikit-learn>=1.3.0 (from deepsight==0.1.0+editable)
  Downloading scikit_learn-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.8 MB)
Collecting albumentations>=1.3.1 (from deepsight==0.1.0+editable)
  Downloading albumentatio

## Introduction

Referring Expression Comprehension (REC) is the task of localizing a target object in an image given a natural language expression that refers to it. Most recent approaches ([Zhang et al. 2022](https://arxiv.org/abs/2206.05836), [Xu et al. 2023](https://arxiv.org/abs/2302.00402), [Liu et al. 2023](https://arxiv.org/abs/2303.05499)) that obtain state-of-the-art results on this task are not specifically designed for it; rather, they are designed to solve a large variety of tasks that require fusing the vision and language modalities, such as open-set object detection, image captioning and visual question answering. In particular, most of them first encode the visual and textual inputs independently, using Transformer-based vision and text encoders respectively; another transformer module then fuses the two modalities by making the visual features attend to the textual features and vice versa. Finally, the fused features are fed to a task-specific module (a simple head, a transformer decoder, etc.).

Here we argue that REC requires a high-level understanding of the scene described by the region caption. For example, given a caption like *"The girl approaching the table while holding a glass"*, to correctly localize the associated bounding box we first need to identify all the entities referred to by the sentence (`the girl`, `the table`, `a glass`) and the relations that exist among them ((`the girl` -- `approaching` --> `the table`), (`the girl` -- `holding` --> `a glass`)). In other words, we need to extract from the sequence of words that form the sentence an intermediate, higher-level representation of the scene. Then, instead of grounding the sequence of words to the image, we can ground this intermediate representation. The previously cited approaches, on the other hand, need to generalize to many image-text tasks, so they simply ground the word features extracted by the text encoder to the image features (here we make the simplifying assumption that each token corresponds to a word). Thus, to perform well on this task, the text encoder needs to encode into each token not only the meaning of the corresponding word, but also its relations with the other entities in the sentence, which may be referred to by groups of tokens. In other words, the text encoder needs to learn to extract a higher-level representation from the sequence of tokens without being explicitly supervised to do so.

Based on this observation, we propose a new approach to REC that is specifically designed for this task by making the network directly exploit the higher-level semantic information encoded in the input sentence. In particular, from the input sentence we extract a scene graph representing which entities are present in the region and how they are related to each other. Then, we localize the target region by localizing in the image the referred entities that satisfy the referred relations.

### High-level architecture overview

Our architecture is heavily inspired by DETR-like models ([Carion et al. 2020](https://arxiv.org/abs/2005.12872), [Gao et al. 2021](https://arxiv.org/abs/2101.07448), [Liu et al. 2022](https://arxiv.org/abs/2201.12329)). Such models use a (CNN) backbone followed by a transformer-based vision encoder to extract visual features from the input image. Then a set of queries (representing candidate bounding boxes) is fed to a transformer-based decoder, where the queries go through layers of self-attention and of cross-attention with the visual features. Finally, the output of the decoder is fed to a simple head that predicts the coordinates and the class of each bounding box. At training time, Hungarian matching is used to obtain a one-to-one matching between queries and ground truth bounding boxes. Once this association is obtained, the loss is computed by comparing each predicted bounding box with its associated ground truth bounding box. At inference time, a simple post-processing step removes the predicted bounding boxes that have a low confidence score.
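
The one-to-one matching step can be illustrated with a minimal sketch. It is not the DETR implementation: the cost below uses only `1 - IoU` (DETR also includes classification and L1 terms), and the optimum is found by exhaustive search over permutations, a stand-in for the Hungarian algorithm that is only viable for a handful of boxes. The names `box_iou` and `match_queries` and the example boxes are hypothetical.

```python
from itertools import permutations


def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def match_queries(pred_boxes, gt_boxes):
    """Return {gt_index: query_index} minimising the total (1 - IoU) cost.

    Assumes there are at least as many predictions as ground truth boxes.
    Exhaustive search: same optimum as Hungarian matching, but exponential.
    """
    best_cost, best_assignment = float("inf"), ()
    for perm in permutations(range(len(pred_boxes)), len(gt_boxes)):
        cost = sum(
            1.0 - box_iou(pred_boxes[q], gt_boxes[g]) for g, q in enumerate(perm)
        )
        if cost < best_cost:
            best_cost, best_assignment = cost, perm
    return dict(enumerate(best_assignment))


# Three predicted boxes (queries) and two ground truth boxes.
pred_boxes = [(0, 0, 10, 10), (20, 20, 40, 40), (5, 5, 15, 15)]
gt_boxes = [(21, 19, 41, 39), (0, 0, 9, 11)]
matching = match_queries(pred_boxes, gt_boxes)
print(matching)  # each ground truth box is paired with exactly one query
```

Queries left unmatched (here the third one) would be supervised as "no object" in a DETR-like setup.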

Similarly, we extract the visual features with a transformer-based vision encoder (no backbone is needed, since we use the CLIP vision encoder). Then, differently from DETR-like models, we do not generate a predefined set of fixed or learnable queries; instead, we create a graph based on the one extracted from the sentence. In particular, for each entity in the sentence scene graph we create multiple nodes in the graph (since the image may contain multiple instances of the same entity), whose embeddings are initialized by encoding the textual description of the entity, as extracted from the sentence, with the CLIP text encoder. We then create an edge between two nodes if the corresponding entities are related in the sentence scene graph; as before, the edge features are initialized by encoding the textual description of the relation with the CLIP text encoder. Finally, the generated graph is fed to the transformer-based decoder, whose blocks consist of a sequence of `Multi-Head Cross Attention`, `Graph Attention` and `FFN`.
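
A minimal sketch of this graph construction is given below. It only tracks the structure (which text each node or edge would be initialized from), not the CLIP embeddings. The names `QueryNode`, `QueryEdge`, `build_query_graph` and `instances_per_entity` are hypothetical, and connecting every pair of instances of two related entities is an assumption about how the description above is realized.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QueryNode:
    entity: int    # index of the entity in the sentence scene graph
    instance: int  # which candidate instance of that entity this node is
    text: str      # phrase that would be encoded to initialise the node


@dataclass(frozen=True)
class QueryEdge:
    source: int    # index into the node list
    target: int    # index into the node list
    text: str      # relation phrase that would be encoded for the edge


def build_query_graph(entities, triplets, instances_per_entity):
    """Replicate each entity into several nodes and copy relations as edges."""
    nodes, node_ids = [], {}
    for ent_idx, phrase in enumerate(entities):
        node_ids[ent_idx] = []
        for inst in range(instances_per_entity[ent_idx]):
            node_ids[ent_idx].append(len(nodes))
            nodes.append(QueryNode(ent_idx, inst, phrase))

    edges = []
    for subj, relation, obj in triplets:
        # Assumption: every instance of the subject is linked to every
        # instance of the object, since any pair could be the right one.
        for s in node_ids[subj]:
            for o in node_ids[obj]:
                edges.append(QueryEdge(s, o, relation))
    return nodes, edges


entities = ["the girl", "the table", "a glass"]
triplets = [(0, "approaching", 1), (0, "holding", 2)]
nodes, edges = build_query_graph(entities, triplets, instances_per_entity=[2, 1, 3])
print(len(nodes), len(edges))
```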

In the multi-head cross attention layer, each query (nodes and edges) can attend to the visual features extracted by the vision encoder. This allows each query to verify whether the associated entity/relation is present in a specific region of the image. Then, in the graph attention layer, each node can communicate with its neighbours and the associated relations to verify whether the encoded instance of the entity satisfies the relations encoded in the sentence scene graph.

Finally, from the graph output by the decoder, we extract the nodes that represent the subject of the sentence (i.e. the target entity) and feed them to a simple head to obtain candidate bounding boxes for the target entity. At training time, a simple matching algorithm associates the ground truth bounding box with one of the predicted bounding boxes, and the loss is computed. At inference time, we select the node (and the corresponding bounding box) whose embedding is the most similar to the embedding obtained by encoding the full sentence with the CLIP text encoder.
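
The inference-time selection can be sketched as follows. Cosine similarity is assumed as the similarity measure (the text above only says "most similar"), and the embeddings, boxes and the function name `select_target` are toy, hypothetical values rather than outputs of the model.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def select_target(node_embeddings, boxes, sentence_embedding):
    """Pick the subject node closest to the sentence embedding and its box."""
    scores = [cosine_similarity(e, sentence_embedding) for e in node_embeddings]
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, boxes[best]


# Toy 3-dimensional embeddings for three candidate subject nodes.
node_embeddings = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.1, 0.1, 0.9]]
boxes = [(10, 10, 50, 80), (60, 20, 90, 70), (0, 0, 20, 20)]
sentence_embedding = [0.1, 1.0, 0.0]

best_index, best_box = select_target(node_embeddings, boxes, sentence_embedding)
print(best_index, best_box)
```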

As presented so far, the decoder would also have to perform open-set object detection, since for each entity it should localize all its instances in the image. Thus, we would need to create a sufficiently high number of nodes for each entity to be able to localize all of its instances. For example, Grounding DINO ([Liu et al. 2023](https://arxiv.org/abs/2303.05499)) creates 900 queries for each image. This would clearly require a lot of memory and computational power. Furthermore, current open-set object detectors are trained on huge amounts of data, on many GPUs and for long periods of time (Grounding DINO uses 64 A100 GPUs). Given the limited resources available, we decided to employ an existing open-set object detector to obtain all instances of an entity in the image and an estimate of their location. In this way, the decoder does not need to perform open-set object detection from scratch; it only needs to refine the estimated locations. Since existing open-set object detectors are mainly trained on closed-set object detection datasets, where each entity to be detected is represented by a single noun (i.e., the category name), we do not use the full textual description of an entity to make the detector localize it. Instead, for each entity we extract the single noun that best describes it. For example, given the entity `The woman with dark hair`, we extract the noun `woman` to localize the entity.

## Method

### How to extract scene graphs?

As previously said, one of the first steps is the extraction of the region scene graph from its text description. This task can be seen as the union of two closely related problems: named entity recognition and relation extraction. Since these tasks have long been studied by the NLP community, we tried many existing solutions and took inspiration from them to build our own. In the following we describe the main approaches we tried.

Before diving into the details, we note that generating a scene graph from a sentence is an ambiguous task, that is, the same sentence could be parsed into different scene graphs. When two noun phrases are connected by an action verb, it seems obvious to identify each noun as an entity and the verb as a relation. However, when the nouns are connected by a preposition, the situation is more ambiguous. For example, in the sentence *"the woman in a green shirt"*, the noun phrase `a green shirt` could be considered an attribute of `the woman` or a different entity related to `the woman` with the relation `in`/`wearing`. Similarly, in the sentence *"the woman on the right"*, some people may consider `the right` as an actual physical location and thus as an entity, while others may consider it as an attribute of `the woman`.

Since most phrases in the RefCOCOg dataset are quite short, had we preferred the attribute interpretation, we would have obtained many scene graphs with very few nodes and edges, or no edges at all, thus jeopardizing the idea underlying the model. For this reason, we generally preferred to create a new entity for each noun phrase. However, we preferred the attribute interpretation when we thought that the detector would find it difficult to localize the entity (for example, in the case of spatial locations like `the right`, `the left`, `the background`, etc.).
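
This policy can be sketched as a small post-processing step on a parsed scene graph. The sketch below is an illustration under stated assumptions, not the rule set used in the project: `SPATIAL_NOUNS` is a hypothetical word list built from the examples above, and `fold_spatial_entities` is a hypothetical helper that merges a spatial-location entity into the phrase of the entity it is related to.

```python
# Hypothetical list of head nouns treated as locations rather than entities.
SPATIAL_NOUNS = {"right", "left", "background", "foreground", "middle"}


def fold_spatial_entities(entities, triplets):
    """Turn spatial-location entities into attributes of their subject.

    entities: list of (head_noun, phrase)
    triplets: list of (subject_index, relation, object_index)
    """
    spatial = {i for i, (noun, _) in enumerate(entities) if noun in SPATIAL_NOUNS}
    phrases = [phrase for _, phrase in entities]

    kept_triplets = []
    for subj, relation, obj in triplets:
        if obj in spatial and subj not in spatial:
            # "the woman" + "on" + "the right" -> "the woman on the right"
            phrases[subj] = f"{phrases[subj]} {relation} {phrases[obj]}"
        elif subj not in spatial and obj not in spatial:
            kept_triplets.append((subj, relation, obj))

    # Re-index the surviving entities so that triplet indices stay valid.
    new_index, new_entities = {}, []
    for i, (noun, _) in enumerate(entities):
        if i not in spatial:
            new_index[i] = len(new_entities)
            new_entities.append((noun, phrases[i]))
    new_triplets = [(new_index[s], r, new_index[o]) for s, r, o in kept_triplets]
    return new_entities, new_triplets


entities = [("woman", "the woman"), ("shirt", "a green shirt"), ("right", "the right")]
triplets = [(0, "in", 1), (0, "on", 2)]
folded_entities, folded_triplets = fold_spatial_entities(entities, triplets)
print(folded_entities)
print(folded_triplets)
```

With this input, `a green shirt` stays a separate entity (the detector can plausibly localize a shirt), while `the right` is absorbed into the description of the woman.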

#### Dependency Graph based parsers

Historically, one of the first approaches to parsing natural language sentences into scene graphs was the one proposed by [Schuster et al. 2015](https://aclanthology.org/W15-2812.pdf). First, the sentence is parsed into a semantic graph (i.e., a dependency graph to which some refinements are applied, like the handling of pronouns and plural nouns) using the CoreNLP pipeline. Then, based on a set of human-written rules, the semantic graph is converted into a scene graph. This approach has been used in many past works for different purposes, like evaluating generated image captions ([Anderson et al. 2016](https://arxiv.org/abs/1607.08822)) or creating pseudo ground truth scene graphs for Weakly Supervised Scene Graph Generation ([Ye et al. 2021](https://arxiv.org/abs/2105.13994), [Zhong et al. 2021](https://arxiv.org/abs/2109.02227), [Li et al. 2022](https://arxiv.org/abs/2208.01834)).

Note that, since the original parser is written in Java, we used a Python-based [tool](https://github.com/vacancy/SceneGraphParser) that covers all the rules implemented in the Stanford Parser with some additional ones (however, it does not implement some features like pronoun handling and quantificational modifiers).


In [None]:
import sng_parser

from deepsight.data.structs import SceneGraph, Entity, Triplet

gdict = sng_parser.parse("The girl approaching the table")

graph = SceneGraph.new(
    entities=[Entity(ent['head'], ent['span']) for ent in gdict["entities"]],
    triplets=[Triplet(trip['subject'], trip['relation'], trip['object']) for trip in gdict["relations"]]
)

print(graph.entities())
print(graph.triplets(None, True, False))

[Entity(noun='girl', phrase='The girl'), Entity(noun='table', phrase='the table')]
[Triplet(subject=0, relation='approaching', object=1)]


Although this parser has been widely used in previous works, we noticed that the quality of the parsed scene graphs rapidly degrades as the structure of the sentence departs from the *subject* *predicate* *object* pattern. For example, a sentence like *"the girl approaching the table"* is correctly parsed, as the previous Python snippet shows.

However, as the sentence becomes more complex, many entities are not found, or prepositions/adjectives are classified as entities. Similarly, many relations are missing, or the wrong relation is assigned to a pair of entities. For example, if we simply extend the previous sentence with a subordinate clause (*"while holding a glass"*), the parser completely ignores the relation `(0, "holding", 2)`. As a result, the entity `a glass` is left disconnected and will not be present in the final scene graph (as will be described later, the generated scene graph is pruned to remove entities not connected to the subject of the description).

In [None]:
gdict = sng_parser.parse("The girl approaching the table while holding a glass")

graph = SceneGraph.new(
    entities=[Entity(ent['head'], ent['span']) for ent in gdict["entities"]],
    triplets=[Triplet(trip['subject'], trip['relation'], trip['object']) for trip in gdict["relations"]]
)

print(graph.entities())
print(graph.triplets(None, True, False))

[Entity(noun='girl', phrase='The girl'), Entity(noun='table', phrase='the table'), Entity(noun='glass', phrase='a glass')]
[Triplet(subject=0, relation='approaching', object=1)]
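
The effect of the missing relation on the pruning step mentioned above can be sketched as follows. This is a minimal illustration, not the project's pruning code: it assumes that the subject is the entity with index 0 and that connectivity is checked ignoring edge direction; `prune_unconnected` is a hypothetical name.

```python
from collections import deque


def prune_unconnected(num_entities, triplets, subject=0):
    """Return the sorted indices of the entities reachable from the subject."""
    # Build an undirected adjacency map from the (subject, relation, object) triplets.
    neighbours = {i: set() for i in range(num_entities)}
    for subj, _, obj in triplets:
        neighbours[subj].add(obj)
        neighbours[obj].add(subj)

    # Breadth-first search starting from the subject of the description.
    reachable, queue = {subject}, deque([subject])
    while queue:
        current = queue.popleft()
        for nxt in neighbours[current] - reachable:
            reachable.add(nxt)
            queue.append(nxt)
    return sorted(reachable)


# Entities: 0 = girl, 1 = table, 2 = glass; only "approaching" was extracted.
kept = prune_unconnected(3, [(0, "approaching", 1)])
print(kept)

# With the missing relation restored, the glass would survive the pruning.
kept_full = prune_unconnected(3, [(0, "approaching", 1), (0, "holding", 2)])
print(kept_full)
```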


Similarly, if we consider the sentence *"There is a truck covered in snow farthest from the right"*, the parser completely ignores the entity `snow` and the relation `covered in` that links it to the truck. Furthermore, the parser wrongly classifies the expression `farthest from the right` as a relation plus an entity, when it should be considered an attribute of the entity `the truck`.

In [None]:
gdict = sng_parser.parse("There is a truck covered in snow farthest from the right")

graph = SceneGraph.new(
    entities=[Entity(ent['head'], ent['span']) for ent in gdict["entities"]],
    triplets=[Triplet(trip['subject'], trip['relation'], trip['object']) for trip in gdict["relations"]]
)

print(graph.entities())
print(graph.triplets(None, True, False))

[Entity(noun='truck', phrase='a truck'), Entity(noun='right', phrase='the right')]
[Triplet(subject=0, relation='from', object=1)]


By analyzing how the parser works, we noticed that its poor results are mainly due to the NLP pipeline used to generate the dependency graph and to the limited set of rules used to convert the dependency graph into a scene graph. In particular, due to the low quality of many sentences in the dataset (e.g., typos, imperfect syntactic structure), the dependency graph generated by the spaCy `en_core_web_sm` pipeline in many cases assigns the wrong Universal Dependencies relation tag to a pair of words, thus leading to a wrong conversion into a scene graph.

Thus, since the quality of the scene graph is paramount for the success of the model, we tried to develop a new tool using a more powerful dependency parser and a more refined set of rules. In particular, we used one of the state-of-the-art dependency parsers, by [Attardi et al. 2022](https://github.com/Unipisa/diaparser), which extends the architecture of the Biaffine Parser by exploiting both the embeddings and the attentions provided by transformers. Then, based on the dependency graphs generated for some sentences of the dataset and the corresponding ground-truth scene graphs, we developed a new set of rules.
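
To make the idea of a conversion rule concrete, the sketch below applies a single rule to a hand-written dependency parse: a noun attached to another noun with `nmod`, and carrying a `case` child, becomes a triplet whose relation is the case marker. The parse is written by hand following Universal Dependencies conventions (it is not an output of the parser), and `nmod_rule` is a hypothetical, heavily simplified rule.

```python
# Hand-written parse of "the woman in green clothes".
# Each token is (id, form, head, deprel); head 0 is the artificial root.
TOKENS = [
    (1, "the", 2, "det"),
    (2, "woman", 0, "root"),
    (3, "in", 5, "case"),
    (4, "green", 5, "amod"),
    (5, "clothes", 2, "nmod"),
]


def children(tokens, head_id):
    """All tokens whose head is `head_id`."""
    return [tok for tok in tokens if tok[2] == head_id]


def phrase(tokens, noun_id):
    """The noun plus its determiner/adjective children, in sentence order."""
    ids = [noun_id]
    ids += [tok[0] for tok in children(tokens, noun_id) if tok[3] in ("det", "amod")]
    forms = {tok[0]: tok[1] for tok in tokens}
    return " ".join(forms[i] for i in sorted(ids))


def nmod_rule(tokens):
    """noun --nmod--> noun2 with a `case` child becomes (noun, case, noun2)."""
    triplets = []
    for tok_id, _, head, deprel in tokens:
        if deprel != "nmod":
            continue
        case = [tok[1] for tok in children(tokens, tok_id) if tok[3] == "case"]
        if case:
            triplets.append((phrase(tokens, head), " ".join(case), phrase(tokens, tok_id)))
    return triplets


extracted = nmod_rule(TOKENS)
print(extracted)
```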

In [None]:
from dataclasses import dataclass

from diaparser.parsers import Parser
import rustworkx as rx


@dataclass(frozen=True)
class Word:
    tag: str
    text: str


@dataclass(frozen=True)
class DepEntity:
    head: list[int]
    others: list[int]


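class ERParser:
    """Rule-based parser that converts a dependency graph into a scene graph."""

    def __init__(self) -> None:
        # Pretrained English (EWT treebank) ELECTRA-based DiaParser model.
        self._parser = Parser.load("en_ewt-electra")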

    def _get_dependency_graph(self, sentence: str) -> rx.PyDiGraph:
        dataset = self._parser.predict(sentence, text="en")
        tokens = dataset.sentences[0].to_tokens()

        graph = rx.PyDiGraph()  # type: ignore

        graph.add_node(Word("ROOT", "-ROOT-"))
        graph.add_nodes_from([Word(token["deprel"], token["form"]) for token in tokens])

        for token in tokens:
            head_id = int(token["head"])
            id = int(token["id"])
            graph.add_edge(head_id, id, None)

        return graph

    def _get_child_by_tag(self, graph: rx.PyDiGraph, node: int, tag: str) -> list[int]:
        children = []

        for child_id in graph.neighbors(node):
            if tag in graph.get_node_data(child_id).tag:
                children.append(child_id)

        return children

    def _compose_span(self, dep_graph: rx.PyDiGraph, words_ids: list[int]) -> str:
        # Join the words in sentence order without mutating the caller's list.
        words = [dep_graph.get_node_data(id).text for id in sorted(words_ids)]
        return " ".join(words)

    def _get_all_children(self, dep_graph: rx.PyDiGraph, node_id: int) -> list[int]:
        children = []
        for child_id in dep_graph.neighbors(node_id):
            children.append(child_id)
            children.extend(self._get_all_children(dep_graph, child_id))
        return children

    def _compose(
        self,
        orig: rx.PyDiGraph,
        other: rx.PyDiGraph,
        parent_id: int,
        rel_ids: list[int],
    ) -> None:
        roots = []
        for node_id in other.node_indexes():
            if other.in_degree(node_id) == 0:
                roots.append(node_id)

        if len(roots) == 0:
            raise ValueError("No root found")

        new_node_ids = orig.compose(other, {})

        for root_id in roots:
            orig.add_edge(parent_id, new_node_ids[root_id], rel_ids)

    def _get_coordinated_verbs(self, dep_graph: rx.PyDiGraph, verb_id: int) -> list[int]:
        verbs = [verb_id]
        for child_id in dep_graph.neighbors(verb_id):
            tag = dep_graph.get_node_data(child_id).tag
            if "conj" in tag or "parataxis" in tag:
                cc_ids = self._get_child_by_tag(dep_graph, child_id, "cc")
                for cc_id in cc_ids:
                    dep_graph.remove_edge(child_id, cc_id)

                punct_ids = self._get_child_by_tag(dep_graph, child_id, "punct")
                for punct_id in punct_ids:
                    dep_graph.remove_edge(child_id, punct_id)

                dep_graph.remove_edge(verb_id, child_id)
                verbs.append(child_id)
        return verbs

    def _parse_noun(self, dep_graph: rx.PyDiGraph, noun_id: int) -> rx.PyDiGraph:
        graph = rx.PyDiGraph()

        noun_ids = DepEntity([noun_id], [])
        noun_node_id = graph.add_node(noun_ids)

        for child_id in dep_graph.neighbors(noun_id):
            child = dep_graph.get_node_data(child_id)

            if dep_graph.out_degree(child_id) == 0:
                if "compound" in child.tag:
                    noun_ids.head.append(child_id)
                else:
                    noun_ids.others.append(child_id)
                continue

            if "det" in child.tag:
                # this determiner should have no children
                raise NotImplementedError
            elif "amod" in child.tag:
                obl_ids = self._get_child_by_tag(dep_graph, child_id, "obl")
                if len(obl_ids) == 0:
                    noun_ids.others.append(child_id)
                    noun_ids.others.extend(self._get_all_children(dep_graph, child_id))
                else:
                    for obl_id in obl_ids:
                        dep_graph.remove_edge(child_id, obl_id)
                        case_ids = self._get_child_by_tag(dep_graph, obl_id, "case")
                        for case_id in case_ids:
                            dep_graph.remove_edge(obl_id, case_id)

                    rel_ids = [child_id, *case_ids]
                    for obl_id in obl_ids:
                        sub_graph = self._parse_noun(dep_graph, obl_id)
                        self._compose(graph, sub_graph, noun_node_id, rel_ids)
            elif "compound" in child.tag:
                raise NotImplementedError
            elif "nmod" in child.tag or "obl" in child.tag:
                case_ids = self._get_child_by_tag(dep_graph, child_id, "case")
                if len(case_ids) == 0:
                    raise ValueError("No case found")
                for case_id in case_ids:
                    dep_graph.remove_edge(child_id, case_id)
                    sub_graph = self._parse_noun(dep_graph, child_id)
                    self._compose(graph, sub_graph, noun_node_id, [case_id])
            elif "acl:relcl" in child.tag:
                verb_ids = self._get_coordinated_verbs(dep_graph, child_id)
                for verb_id in verb_ids:
                    nsubj_ids = self._get_child_by_tag(dep_graph, verb_id, "nsubj")
                    if len(nsubj_ids) == 0:
                        raise ValueError("No nsubj found")
                    if len(nsubj_ids) > 1:
                        raise ValueError("More than one nsubj found")

                    dep_graph.remove_edge(verb_id, nsubj_ids[0])
                    rel_ids, sub_graph = self._parse_verb(dep_graph, verb_id)
                    if sub_graph.num_nodes() == 0:
                        noun_ids.others.extend(rel_ids)
                    else:
                        self._compose(graph, sub_graph, noun_node_id, rel_ids)
            elif "acl" in child.tag or "root" in child.tag:
                verb_ids = self._get_coordinated_verbs(dep_graph, child_id)
                for verb_id in verb_ids:
                    rel_ids, sub_graph = self._parse_verb(dep_graph, verb_id)
                    if sub_graph.num_nodes() == 0:
                        noun_ids.others.extend(rel_ids)
                    else:
                        self._compose(graph, sub_graph, noun_node_id, rel_ids)
            elif "conj" in child.tag:
                cc_ids = self._get_child_by_tag(dep_graph, child_id, "cc")
                for cc_id in cc_ids:
                    dep_graph.remove_edge(child_id, cc_id)

                punct_ids = self._get_child_by_tag(dep_graph, child_id, "punct")
                for punct_id in punct_ids:
                    dep_graph.remove_edge(child_id, punct_id)

                sub_graph = self._parse_noun(dep_graph, child_id)
                graph.compose(sub_graph, {})
            else:
                raise ValueError(f"Unknown tag: {child.tag}")

        return graph

    def _parse_verb(
        self, dep_graph: rx.PyDiGraph, verb_id: int
    ) -> tuple[list[int], rx.PyDiGraph]:
        graph = rx.PyDiGraph()
        rel_ids = [verb_id]

        for child_id in dep_graph.neighbors(verb_id):
            tag = dep_graph.get_node_data(child_id).tag

            if dep_graph.out_degree(child_id) == 0:
                if "punct" not in tag:
                    rel_ids.append(child_id)
                continue

            if "aux" in tag:
                # this auxiliary verb should have no children
                raise NotImplementedError
            elif "nsubj" in tag:
                raise ValueError("nsubj should be the root")
            elif "nmod" in tag or "obl" in tag:
                case_ids = self._get_child_by_tag(dep_graph, child_id, "case")
                if len(case_ids) == 0:
                    raise ValueError("No case found")

                for case_id in case_ids:
                    rel_ids.append(case_id)
                    dep_graph.remove_edge(child_id, case_id)

                sub_graph = self._parse_noun(dep_graph, child_id)
                graph.compose(sub_graph, {})
            elif "obj" in tag:
                sub_graph = self._parse_noun(dep_graph, child_id)
                graph.compose(sub_graph, {})
            elif "conj" in tag:
                raise ValueError("conj should have been removed")
            else:
                raise ValueError(f"Unknown tag {tag}")

        return rel_ids, graph

    def parse(self, sentence: str) -> SceneGraph:
        dep_graph = self._get_dependency_graph(sentence)

        root = list(dep_graph.out_edges(0))[0][1]
        dep_graph.remove_node(0)

        nsubj_ids = self._get_child_by_tag(dep_graph, root, "nsubj")
        if len(nsubj_ids) == 0:
            tmp = self._parse_noun(dep_graph, root)
        elif len(nsubj_ids) > 1:
            raise ValueError("More than one nsubj found")
        else:
            dep_graph.remove_edge(root, nsubj_ids[0])
            dep_graph.add_edge(nsubj_ids[0], root, None)
            tmp = self._parse_noun(dep_graph, nsubj_ids[0])

        entities = []
        for span_ids in tmp.nodes():
            head = self._compose_span(dep_graph, span_ids.head)
            span = self._compose_span(dep_graph, span_ids.head + span_ids.others)

            entities.append(Entity(head, span))

        relations = []
        for edge in tmp.edge_list():
            subject = edge[0]
            object = edge[1]
            span_ids = tmp.get_edge_data(subject, object)
            predicate = self._compose_span(dep_graph, span_ids)

            relations.append(Triplet(subject, predicate, object))

        return SceneGraph.new(entities, relations)


parser = ERParser()

Downloading: "https://github.com/Unipisa/diaparser/releases/download/v1.0/en_ewt.electra-base" to /root/.cache/diaparser/en_ewt.electra-base
100%|██████████| 452M/452M [00:28<00:00, 16.7MB/s]



Some weights of the model checkpoint at google/electra-base-discriminator were not used when initializing ElectraModel: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).



Using bos_token, but it is not set yet.
Using eos_token, but it is not set yet.


In [None]:
graph = parser.parse("The boy wearing a white shirt having dinner with his friends")

print(graph.entities())
print(graph.triplets(None, True, False))

[Entity(noun='boy', phrase='The boy'), Entity(noun='shirt', phrase='a white shirt'), Entity(noun='friends', phrase='his friends')]
[Triplet(subject=0, relation='wearing', object=1), Triplet(subject=0, relation='having dinner with', object=2)]




Despite the promising results obtained on some sentences, we abandoned this approach for two reasons:
1. The quality of the resulting scene graph depends too much on the quality of the dependency graph generated by the parser. In particular, even though the new parser correctly parses more complex sentences, it still fails to handle long-range connections between words or relations that can only be inferred from common sense. For example, in the sentence *"There is a truck covered in snow farthest from the right"*, the parser attaches the clause `farthest from the right` to the clause `covered in snow`. While this interpretation may be deemed correct (the sentence is ambiguous), the preferred interpretation is that `farthest from the right` is an attribute of the entity `the truck`. The parser is not able to infer this, which leads to the wrong scene graph.

In [None]:
from spacy import displacy

sent = parser._parser.predict("There is a truck covered in snow farthest from the right", text="en").sentences[0]
displacy.render(sent.to_displacy(), style='dep', manual=True, options={'compact': True, 'distance': 120}, jupyter=True)

2. Defining a set of universal rules that handles all possible structures of English sentences (even malformed ones) is extremely hard. In particular, many sentences have very similar dependency graphs that should nonetheless be transformed into different scene graphs, so the rules must account not only for the structure of the sentence but also for its semantics. For example, the expressions *"the part of the table"*, *"the first horse from the left"* and *"the woman in green clothes"* have very similar dependency graphs, but the first two should each be considered a single entity, while the last should be split into two entities connected by the relation `wearing`.

In [None]:
sent = parser._parser.predict("the part of the table", text="en").sentences[0]
displacy.render(sent.to_displacy(), style='dep', manual=True, jupyter=True, options={'compact': True, 'distance': 120})

In [None]:
sent = parser._parser.predict("the first horse from the left", text="en").sentences[0]
displacy.render(sent.to_displacy(), style='dep', manual=True, jupyter=True, options={'compact': True, 'distance': 120})

In [None]:
sent = parser._parser.predict("the woman in green clothes", text="en").sentences[0]
displacy.render(sent.to_displacy(), style='dep', manual=True, options={'compact': True, 'distance': 120}, jupyter=True)

#### Large Language Models

Large Language Models pretrained on large text corpora have recently been shown to reach state-of-the-art results on many NLP tasks for which they were not explicitly trained, like question answering, summarization and text generation. Furthermore, since these models seem to show emergent reasoning capabilities ([Liu et al. 2023](https://arxiv.org/abs/2304.03439)), we considered using them for this task, which, as previously said, requires common sense and the ability to understand the scene context described by the sentence.

In our first attempts, we tried open LLMs available on the HuggingFace site, like [Falcon-7B-Instruct](https://huggingface.co/tiiuae/falcon-7b-instruct), [Alpaca-13B](https://crfm.stanford.edu/2023/03/13/alpaca.html) and Stable-Vicuna-13B, since their quantized versions can run on a single GPU with only 8/16 GB of memory. For all our experiments we adopted a few-shot approach, i.e., in the prompt we described the task and then provided a set of examples of the task. When designing the prompt, the main tradeoff was between the number of examples (and thus the inference time) and the quality of the generated scene graphs. In the end, we decided to use 5 examples, which allowed us to show the model a variety of sentences with different structures and relations.

In [None]:
prompt = \
"""\
The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, and a human user, called User. Falcon is extremely good at visualizing and understanding a scene from a short description of it, and can answer questions about the scene. User is a human who is curious about the world, and wants to know more about the entities present in the scene, even if not explicitly stated in the description, and which relations occur among them.
In the following interactions, User will make requests in natural language, while Falcon will answer to each of these requests with a well formed JSON. The conversation begins.
User: Given the following sentence: "{sentence1}", what are the entities present in the scene? What are their relations? For each entity, please provide also a single word that best summarizes it and make sure that the subject of the sentence is the first entity.
Falcon: {example1}
User: Do the same for the following sentence: "{sentence2}".
Falcon: {example2}
User: Try with this one: "{sentence3}".
Falcon: {example3}
User: Please, do the same also for this sentence: "{sentence4}".
Falcon: {example4}
User: Let's try with this one: "{sentence5}".
Falcon: {example5}
User: Finally, try with this one: "{sentence6}".
Falcon:
"""

In [None]:
import json
from typing import Iterable

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class LLMParser:
    def __init__(self) -> None:
        name = "tiiuae/falcon-7b-instruct"

        self._prompt = prompt

        tokenizer = AutoTokenizer.from_pretrained(name)
        model = AutoModelForCausalLM.from_pretrained(
            name,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )

        self._pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
            batch_size=1, # we found that increasing the batch slows down the generation
        )

        self._pipe.tokenizer.pad_token_id = model.config.eos_token_id

    def _create_prompt(self, sentence: str) -> str:
        sentence1 = "the girl looking at the table full of drinks"
        example1 = {
            "entities": [
                ("the girl", "girl"),
                ("the table", "table"),
                ("drinks", "drinks"),
            ],
            "relations": [
                (0, "looking at", 1),
                (1, "full of", 2),
            ],
        }

        sentence2 = (
            "the man wearing a long sleeved white shirt and a pair of blue jeans "
            "catching a freesbie"
        )
        example2 = {
            "entities": [
                ("the man", "man"),
                ("a long sleeved white shirt", "shirt"),
                ("a pair of blue jeans", "jeans"),
                ("a freesbie", "freesbie"),
            ],
            "relations": [
                (0, "wearing", 1),
                (0, "wearing", 2),
                (0, "catching", 3),
            ],
        }

        sentence3 = "Skateboarder in green"
        example3 = {
            "entities": [
                ("Skateboarder", "Skateboarder"),
                ("green clothes", "clothes"),
            ],
            "relations": [
                (0, "in", 1),
            ],
        }

        sentence4 = "glass far right"
        example4 = {
            "entities": [("glass far right", "glass")],
            "relations": [],
        }

        sentence5 = "2nd to theleft brown horse drinking"
        example5 = {
            "entities": [
                ("brown horse drinking", "horse"),
                ("leftmost brown horse", "horse"),
            ],
            "relations": [
                (0, "at the right of", 1),
            ],
        }

        return self._prompt.format(
            sentence1=sentence1,
            example1=json.dumps(example1),
            sentence2=sentence2,
            example2=json.dumps(example2),
            sentence3=sentence3,
            example3=json.dumps(example3),
            sentence4=sentence4,
            example4=json.dumps(example4),
            sentence5=sentence5,
            example5=json.dumps(example5),
            sentence6=sentence,
        )

    def parse(self, sentences: Iterable[str]) -> Iterable[SceneGraph]:
        """Extracts scene graphs from a list of sentences.

        Parameters
        ----------
        sentences : Iterable[str]
            A list of sentences to parse.

        Returns
        -------
        Iterable[SceneGraph]
            A list of scene graphs, one for each sentence.
        """

        prompts = (self._create_prompt(s) for s in sentences)

        generator = self._pipe(
            prompts,
            max_length=1000,
            num_return_sequences=1,
            do_sample=False,
            # top_k=10,
            return_full_text=False,
        )

        for output in generator:
            generated = output[0]["generated_text"]
            generated = generated.split("\n")[0]
            gen_json = json.loads(generated)

            entities = []
            for span, head in gen_json["entities"]:
                entities.append(Entity(head, span))

            triplets = []
            for s, pred, o in gen_json["relations"]:
                triplets.append(Triplet(s, pred, o))

            yield SceneGraph.new(entities=entities, triplets=triplets)

Despite the recent claims about the capabilities of open LLMs, we found that the quality of the generated scene graphs was not good enough for our purposes, except for some simple sentences. Furthermore, the inference time was too high to preprocess the whole dataset in a reasonable time with our limited computational resources. For example, on a single A100 GPU, the inference time for a single sentence was around 15 seconds, so preprocessing the whole dataset would have required more than 16 days.

For these reasons, we decided to use ChatGPT (`gpt-3.5-turbo`) through the API made available by OpenAI. By making ChatGPT parse multiple sentences (`10`) in each request, we were able to reduce the time required to preprocess the whole dataset to around 10 hours. The number of sentences per request was chosen high enough to reduce the inference time (and the cost of the API calls), but not so high that the model starts hallucinating.
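
As a rough sanity check of the figures quoted above, the snippet below derives the dataset size implied by the Falcon timings and the effective per-sentence time of the ChatGPT run. It is only back-of-the-envelope arithmetic: the implied sentence count is a lower bound obtained from "15 seconds" and "more than 16 days", not an official dataset statistic, and all variable names are introduced here for illustration.

```python
SECONDS_PER_HOUR = 60 * 60
SECONDS_PER_DAY = 24 * SECONDS_PER_HOUR

# Figures quoted in the text.
falcon_seconds_per_sentence = 15
falcon_total_days = 16      # "more than 16 days", so this is a lower bound
chatgpt_total_hours = 10    # "around 10 hours"
sentences_per_request = 10

# Dataset size implied by the Falcon figures (a lower bound, not an official count).
implied_sentences = falcon_total_days * SECONDS_PER_DAY // falcon_seconds_per_sentence

# Number of API requests with 10 captions each (ceiling division).
implied_requests = -(-implied_sentences // sentences_per_request)

# Effective wall-clock time per sentence for the ChatGPT run, and the resulting speedup.
chatgpt_seconds_per_sentence = chatgpt_total_hours * SECONDS_PER_HOUR / implied_sentences
speedup = falcon_seconds_per_sentence / chatgpt_seconds_per_sentence

print(f"implied sentences:       {implied_sentences}")
print(f"implied requests:        {implied_requests}")
print(f"ChatGPT s/sentence:      {chatgpt_seconds_per_sentence:.2f}")
print(f"speedup over Falcon-7B:  {speedup:.1f}x")
```

The effective time per sentence is a throughput figure rather than the latency of a single request, so it is compatible with individual requests taking several seconds if multiple requests are in flight at the same time.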

Even in this case, when designing the prompt, we had to balance the number of examples against the inference time and cost. By increasing the number of examples, we were able to show the model a wider variety of sentences, thus improving the quality of the scene graphs. However, this also increased the inference time and the cost of the API calls. In the end, we decided to use 7 examples that cover many typical constructions of the sentences in the dataset (for a short description of the reasons why we chose these examples, see the comments next to the example definitions below).

In [None]:
import ast
import enum
import os
from typing import Any

import openai

from deepsight.data.structs import SceneGraph


class GPTModel(enum.Enum):
    GPT3_5 = "gpt-3.5-turbo"
    GPT4 = "gpt-4"

    def openai_model(self) -> str:
        return self.value


class SceneGraphParser:
    def __init__(
        self,
        api_key: str | None = None,
        model: GPTModel = GPTModel.GPT3_5,
        temperature: float = 0.2,
    ) -> None:
        """Initializes the parser with the given parameters.

        Parameters
        ----------
        api_key : str, optional
            OpenAI API key. If not provided, the token will be read from the
            environment variable OPENAI_API_KEY.
        model : GPTModel, optional
            The GPT model to use. Defaults to GPTModel.GPT3_5.
        temperature : float, optional
            The temperature to use when sampling from the model. Should be between
            0 and 2, where higher values will make the output more random, while
            lower values will make the output more focused and deterministic.
            Defaults to 0.2.
        """
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY")
            if api_key is None:
                raise ValueError(
                    "No OpenAI API key provided. Please provide a key or set the "
                    "environment variable OPENAI_API_KEY."
                )

        openai.api_key = api_key
        self.model = model
        self.temperature = temperature

    def _build_requests(self, captions: list[str]) -> str:
        output = ""
        for caption in captions:
            output += f"<caption>{caption}</caption>\n"

        return output

    def _build_examples(self, graphs: list[SceneGraph]) -> str:
        output = ""
        for graph in graphs:
            output += f"<json>{graph.to_dict()}</json>\n"

        return output

    def _match_entity(self, entity: str, entities: list[dict[str, Any]]) -> int | None:
        """Matches the given entity to an entity in the list of entities."""

        for idx, ent in enumerate(entities):
            if entity in ent["phrase"]:
                return idx

        return None

    def _get_entity_index(
        self, entity: str | int | None, entities: list[dict[str, Any]]
    ) -> int | None:
        if entity is None:
            return None

        entity_idx: int
        if isinstance(entity, int):
            entity_idx = entity
        elif entity.isdigit():
            entity_idx = int(entity)
        else:
            idx = self._match_entity(entity, entities)
            if idx is None:
                entities.append({"noun": entity, "phrase": entity})
                entity_idx = len(entities) - 1
            else:
                entity_idx = idx

        if entity_idx >= len(entities):
            return None

        return entity_idx

    def _postprocess(self, output: dict[str, Any]) -> SceneGraph | None:
        entities = output.get("entities", [])
        if len(entities) == 0:
            return None

        triplets = output.get("triplets", [])
        new_triplets = []
        for triplet in triplets:
            subj = self._get_entity_index(triplet.get("subject"), entities)
            obj = self._get_entity_index(triplet.get("object"), entities)

            match (subj, obj):
                case (None, None):
                    continue
                case (None, obj):
                    entities[obj]["phrase"] = (
                        entities[obj]["phrase"] + " " + triplet["relation"]
                    )
                case (subj, None):
                    entities[subj]["phrase"] = (
                        entities[subj]["phrase"] + " " + triplet["relation"]
                    )
                case (subj, obj):
                    new_triplets.append(
                        {
                            "subject": subj,
                            "object": obj,
                            "relation": triplet["relation"],
                        }
                    )

        return SceneGraph.from_dict({"entities": entities, "triplets": new_triplets})

    async def parse(
        self, examples: list[tuple[str, SceneGraph]], captions: list[str]
    ) -> list[tuple[str, SceneGraph | None]]:
        """Parses the given captions into scene graphs.

        Parameters
        ----------
        examples : list[tuple[str, SceneGraph]]
            A list of examples to use for the prompt. Each example is a tuple
            consisting of a caption and the corresponding scene graph.
        captions : list[str]
            A list of captions to parse into scene graphs.

        Returns
        -------
        list[tuple[str, SceneGraph | None]]
            A list of tuples consisting of the original caption and the parsed
            scene graph. If the parsing fails due to formatting issues, the scene
            graph will be `None`.

        Raises
        ------
        RuntimeError
            If the parsing fails.
        openai.error.OpenAIError
            The error returned by the OpenAI API.
        """

        try:
            res = await openai.ChatCompletion.acreate(
                model=self.model.openai_model(),
                temperature=self.temperature,
                n=1,
                messages=[
                    {"role": "system", "content": system},
                    { # we add the example captions as previous requests from the user
                        "role": "user",
                        "content": self._build_requests([cap for cap, _ in examples]),
                    },
                    { # we add the corresponding scene graphs as previous responses from the assistant
                        "role": "assistant",
                        "content": self._build_examples(
                            [graph for _, graph in examples]
                        ),
                    },
                    { # we add the captions to parse as the next request from the user
                        "role": "user",
                        "content": self._build_requests(captions),
                    },
                ],
            )
        except openai.error.OpenAIError as e:
            raise e
        except Exception as e:
            raise RuntimeError(f"Input: {captions}") from e

        response: str = res["choices"][0]["message"]["content"]
        outputs = response.split("\n")

        results: list[tuple[str, SceneGraph | None]] = []
        for caption, output in zip(captions, outputs):
            start = output.find("{")
            end = output.rfind("}")
            output = output[start : end + 1]

            try:
                output_dict = ast.literal_eval(output)
                graph = self._postprocess(output_dict)
                # graph is None when post-processing finds no entities
                results.append((caption, graph))
            except Exception as e:
                results.append((caption, None))
                print(f"Input: {caption} | Output: {output}")
                print(f"Exception: {e}")

        return results


# this is the first part of the prompt
# after this, the examples are added
system = """\
You will be provided with a set of captions each describing a region in an image. \
For each region, first identify the entities, like people, objects or places, present in the region. Specify both a single noun and a phrase that describes the entity. \
Then, identify the triplets of subject, relation and object that describe the relationships between the entities. \
"""  # noqa: E501
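
The `parse` method above decodes each response line with `ast.literal_eval` instead of `json.loads`. A likely reason is that `_build_examples` formats each example with an f-string, so the dictionaries shown to the model use Python's `repr` (single quotes), and the model tends to mirror that format. The sketch below illustrates the decoding step on a single made-up response line; the line itself is an assumption about what the model returns, with keys taken from `_postprocess`.

```python
import ast
import json

# Hypothetical response line, mirroring the `<json>...</json>` format of the examples.
line = (
    "<json>{'entities': [{'noun': 'girl', 'phrase': 'the girl'}, "
    "{'noun': 'table', 'phrase': 'the table'}], "
    "'triplets': [{'subject': 0, 'relation': 'approaching', 'object': 1}]}</json>"
)

# Same slicing as in `parse`: keep everything between the first "{" and the last "}".
start = line.find("{")
end = line.rfind("}")
payload = line[start : end + 1]

# literal_eval accepts Python literals, including single-quoted strings.
parsed = ast.literal_eval(payload)

# json.loads rejects the same payload because JSON requires double quotes.
try:
    json.loads(payload)
    json_ok = True
except json.JSONDecodeError:
    json_ok = False

print(parsed["entities"][0]["phrase"], "|", parsed["triplets"][0]["relation"])
print("valid JSON:", json_ok)
```

Unlike `eval`, `ast.literal_eval` only evaluates literals, so a malformed or unexpected response raises an exception instead of executing code; in `parse` this is caught and the caption is paired with `None`.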


In [None]:
# This example is useful to show the model that the attributes of an entity can be non adjacent to the entity itself.
# Furthermore, it shows that the location of an entity in the scene should be considered an attribute of the entity.
example1 = (
    "There is a truck covered in snow farthest from the right",
    SceneGraph.new(
        entities=[
            Entity("truck", "a truck farthest from the right"),
            Entity("snow", "snow"),
        ],
        triplets=[
            Triplet(0, "covered in", 1),
        ],
    ),
)

# This example shows that the attributes of an entity do not need to be adjectives or adverbial phrases,
# but can also be relative clauses.
example2 = (
    "A placemat is empty behind a placemat that is full",
    SceneGraph.new(
        entities=[
            Entity("placemat", "an empty placemat"),
            Entity("placemat", "a full placemat"),
        ],
        triplets=[
            Triplet(0, "behind", 1),
        ],
    ),
)

# This example condenses in one sentence the information provided by the two previous examples.
example3 = (
    "the chair not being used in the background, perpendicular to the viewer",
    SceneGraph.new(
        entities=[
            Entity("chair", "the chair not being used in the background"),
            Entity("viewer", "the viewer"),
        ],
        triplets=[
            Triplet(0, "perpendicular to", 1),
        ],
    ),
)

# This example was added because in the dataset we found many descriptions of this form,
# "the book with the title X" and we found that the model was not able to parse them.
example4 = (
    "A double decker bus with the wording The Ghost Bus Tours.com on the side.",
    SceneGraph.new(
        entities=[
            Entity("bus", "a double decker bus"),
            Entity("wording", "the wording The Ghost Bus Tours.com"),
            Entity("side", "the side"),
        ],
        triplets=[
            Triplet(0, "with", 1),
            Triplet(1, "on", 2),
        ],
    ),
)

# This example shows a quite complex sentence, with multiple entities and relations.
# In particular, it shows the model that a pronoun should not be parsed as a new entity, but as a reference to an existing entity,
# thus all relations involving the pronoun should be between the subject/object of the relation and the entity the pronoun refers to.
# Furthermore, it shows that the subject of the caption does not need to be the subject of all relations, but can also be the object of some relations.
example5 = (
    "a man stands majestically on his skis on a snow covered area with 2 other people "
    + "behind him in the distance",
    SceneGraph.new(
        entities=[
            Entity("man", "a man"),
            Entity("ski", "his skis"),
            Entity("snow area", "a snow covered area"),
            Entity("people", "2 other people"),
        ],
        triplets=[
            Triplet(0, "stands on", 1),
            Triplet(0, "on", 2),
            Triplet(3, "behind", 0),
        ],
    ),
)

# This example shows a quite simple and linear sentence, similar to many of the sentences in the dataset.
# In particular, it shows the model that the expression "<subject> in <clothing>" (largely present in the dataset)
# should be parsed as two different entities, one for the subject and one for the clothing (in a relation) and not as a single entity.
example6 = (
    "Woman in white shirt looking down at laptop computer and " + "holding a glass",
    SceneGraph.new(
        entities=[
            Entity("woman", "woman"),
            Entity("shirt", "white shirt"),
            Entity("computer", "laptop computer"),
            Entity("glass", "a glass"),
        ],
        triplets=[
            Triplet(0, "in", 1),
            Triplet(0, "looking down at", 2),
            Triplet(0, "holding", 3),
        ],
    ),
)

# This simple example was added because we found that, based on the previous examples,
# the model still did not correctly parse the location of an entity as an attribute of the entity.
example7 = (
    "Woman on the right",
    SceneGraph.new(entities=[Entity("woman", "woman on the right")], triplets=[]),
)

examples = [example1, example2, example3, example4, example5, example6, example7]

In [None]:
import asyncio

gpt_parser = SceneGraphParser("api key")

sentence = "The girl approaching the table while holding a glass"

# note: this request will fail since the provided api key is not valid
# if necessary, we can provide a temporary api key for testing
res = await gpt_parser.parse(examples, [sentence])
graph = res[0][1]
if graph is not None: # if the parsing was successful
    print(graph.entities())
    print(graph.triplets(None, True, False))

AuthenticationError: ignored
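
Preprocessing the whole dataset means calling `parse` once per group of 10 captions. The sketch below shows one way this batching could be organized. The names `chunked`, `parse_dataset`, `fake_parse_batch` and the `max_concurrency` limit are hypothetical and not part of the framework, and the stub coroutine stands in for `gpt_parser.parse(examples, batch)` so that the snippet runs without an API key.

```python
import asyncio
from typing import Awaitable, Callable, Iterator, Optional

ParseResult = list[tuple[str, Optional[object]]]


def chunked(items: list[str], size: int) -> Iterator[list[str]]:
    """Yields consecutive slices of `items` with at most `size` elements."""
    for start in range(0, len(items), size):
        yield items[start : start + size]


async def parse_dataset(
    captions: list[str],
    parse_batch: Callable[[list[str]], Awaitable[ParseResult]],
    batch_size: int = 10,
    max_concurrency: int = 4,
) -> ParseResult:
    """Parses all captions in batches, keeping a bounded number of requests in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run(batch: list[str]) -> ParseResult:
        async with semaphore:
            return await parse_batch(batch)

    # gather preserves the order of the batches, so results stay aligned with captions
    nested = await asyncio.gather(*(run(b) for b in chunked(captions, batch_size)))
    return [pair for batch_result in nested for pair in batch_result]


async def fake_parse_batch(batch: list[str]) -> ParseResult:
    """Stand-in for the real API call: returns no graph for every caption."""
    await asyncio.sleep(0)
    return [(caption, None) for caption in batch]


captions = [f"caption {i}" for i in range(25)]
# In a notebook, use `await parse_dataset(...)` instead of `asyncio.run(...)`.
results = asyncio.run(parse_dataset(captions, fake_parse_batch))
print(len(results), results[0])
```

Captions whose graph comes back as `None` can be collected and retried in a later pass, since `parse` already reports formatting failures this way.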

Some considerations on the quality of the generated scene graphs:
1. Among the methods previously illustrated, the graphs generated with ChatGPT are by far the best ones, but they are still far from perfect. In particular, we noticed that ChatGPT is inconsistent, i.e., clauses with extremely similar structures (which should therefore yield similar scene graphs) are associated with different scene graphs. For example, spatial locations are not always treated as attributes of the entity they refer to, but are sometimes parsed as spatial relations, despite the multiple examples in the prompt. Furthermore, we noticed that sometimes the text associated with a relation is also included in the phrase of one of the entities involved.
2. The textual descriptions of the parsed entities and relations are strongly based on the words present in the input sentence. For example, in a sentence like *"the woman in a green shirt"*, the relation is described simply using the preposition `in`, exactly as it appears in the sentence. While this is not wrong, a better description of the relation would be `wearing`. Similarly, in the sentence *"the zebra walking with its young one"*, the detected entities are `the zebra` and `its young one`, which is not wrong, but a better description would be `the zebra` and `the young zebra`. When designing the prompt, we tried to teach the model not to rely exclusively on the words present in the sentence, but the model started to hallucinate, thus compromising the quality of the generated scene graphs. <br>
We noticed similar results when we tried to make the model generate different relations based on whether the subject of the relation in the scene graph is also the subject of the relation in the sentence. For example, given a sentence like *"the girl looking at the table"*, the generated scene graph has two nodes (`the girl`, `the table`) and an undirected edge between them representing the relation `looking at`. We tried to make the model create a directed scene graph with the relation `looking at` from `the girl` to `the table` and the relation `being looked at` from `the table` to `the girl`, but the model started to hallucinate. <br>
Notice that GPT-4 is instead able to generate such more complex scene graphs, showing that it has a better "understanding" of the content of the sentence. Unfortunately, the GPT-4 API was not openly available at the time of the project and its cost is 20x higher than that of the ChatGPT API.

### Architecture modules

<img src="https://github.com/FrancescoGentile/visgator/blob/deepsight/docs/img/sgg.png?raw=true" width="800px"/>

The developed framework decomposes the model architecture (here called `Pipeline`) into four main modules:
- `PreProcessor`: module used to transform the input data before feeding it to the model. Notice that at training time this module also takes the target data as input, since it may be needed to implement some training strategies like denoising bounding box coordinates ([Li et al. 2022](https://arxiv.org/abs/2203.01305), [Zhang et al. 2022](https://arxiv.org/abs/2203.03605)).
- `Model`: this is the core of the pipeline and it is responsible for all the main computations.
- `PostProcessor`: module used to transform the output of the model into the same format as the target data.
- `Criterion`: module used to compute the loss between the output of the model and the target data.

#### PreProcessor

In the preprocessor, we perform two main preprocessing steps. First, we resize each image such that its shortest side is 800px, keeping its original aspect ratio. If by doing so the longest side exceeds 1333px, we instead resize the image such that its longest side is 1333px (still keeping the aspect ratio). This is the same strategy used by Grounding DINO. Then we standardize the image using the channel means and standard deviations used by CLIP.
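
The resizing rule can be sketched as a small function that computes the target size. This is only an illustration of the rule described above: the function name `target_size` is hypothetical, and rounding to the nearest pixel is an assumption (the `Resize` transform of the framework may round differently).

```python
def target_size(
    height: int, width: int, side: int = 800, max_side: int = 1333
) -> tuple[int, int]:
    """Computes the resized (height, width) while keeping the aspect ratio."""
    # first try to bring the shortest side to `side`
    scale = side / min(height, width)
    # if the longest side would exceed `max_side`, scale by the longest side instead
    if max(height, width) * scale > max_side:
        scale = max_side / max(height, width)

    # assumption: round to the nearest pixel
    return round(height * scale), round(width * scale)


# a 480x640 image: the shortest side becomes 800 and the longest stays below 1333
print(target_size(480, 640))
# a 400x1000 image: the longest side is capped at 1333, so the shortest ends below 800
print(target_size(400, 1000))
```

For elongated images the shortest side therefore ends up smaller than 800px, which is the price paid for bounding the total number of patches fed to the encoder.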

The preprocessor is also the step where the scene graph is generated. Since generating the scene graph using ChatGPT has a high latency (more than 10 seconds), it is not feasible to perform this step in real time. Thus, we preprocessed the whole dataset before training the model. The generated scene graphs are then stored in a json file that is loaded by the preprocessor at initialization.

Since the candidate bounding boxes are generated from the embeddings of the subject nodes, if the generated scene graph has more than one connected component, we remove all components except the one containing the subject entity. In fact, during the `Graph Attention` step each node can exchange messages (directly, or indirectly through its neighbours) only with the nodes in the same connected component, so nodes in other connected components cannot influence the embeddings of the subject nodes. Keeping them would thus be useless.
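
To make the pruning step concrete, the following sketch finds the nodes that lie in the same connected component as the subject entity (node `0`) with a breadth-first search. It works on plain index pairs and is not the framework's `SceneGraph` implementation; the function name `connected_component` and its signature are introduced here for illustration only.

```python
from collections import deque


def connected_component(
    num_nodes: int, edges: list[tuple[int, int]], root: int = 0
) -> list[int]:
    """Returns the sorted indices of the nodes reachable from `root`."""
    # relations are treated as undirected: messages can flow both ways
    neighbours: dict[int, set[int]] = {node: set() for node in range(num_nodes)}
    for subj, obj in edges:
        neighbours[subj].add(obj)
        neighbours[obj].add(subj)

    seen = {root}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for other in neighbours[node]:
            if other not in seen:
                seen.add(other)
                queue.append(other)

    return sorted(seen)


# nodes 3 and 4 are related to each other but not to the subject (node 0)
edges = [(0, 1), (2, 1), (3, 4)]
kept = connected_component(5, edges)
print(kept)
```

In this example only nodes `0`, `1` and `2` are kept; the triplet between nodes `3` and `4` would be dropped together with its entities.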

In [None]:
import json
from typing import Any

from deepsight.data.structs import Batch, RECInput, RECOutput, SceneGraph
from deepsight.data.transformations import Compose, Resize, Standardize
from deepsight.modeling.pipeline import PreProcessor as _PreProcessor
from deepsight.utils.torch import Batched3DTensors

from projects.sgg.modeling import PreprocessorConfig
from projects.sgg.modeling._structs import ModelInput


class PreProcessor(_PreProcessor[RECInput, RECOutput, ModelInput]):
    def __init__(self, config: PreprocessorConfig) -> None:
        super().__init__()

        self._preparsed: dict[str, dict[str, Any]] = {}
        if config.file is not None:
            with config.file.open("r") as f:
                self._preparsed = json.load(f)

        # self._parser = gpt.SceneGraphParser(config.token)

        self._transform = Compose(
            [
                Resize(config.side, max_size=config.max_side, p=1.0),
                Standardize(config.mean, config.std, p=1.0),
            ],
            p=1.0,
        )

    def forward(
        self,
        inputs: Batch[RECInput],
        targets: Batch[RECOutput] | None,
    ) -> ModelInput:
        graphs = []
        for inp in inputs:
            if inp.description in self._preparsed:
                scene_graph = SceneGraph.from_dict(self._preparsed[inp.description])
                # remove all nodes not connected to the root node
                # since they will never pass messages to the root node
                # (directly or through other nodes)
                scene_graph = scene_graph.node_connected_component(0)
                graphs.append(scene_graph)
            else:
                # we currently do not support real-time scene graph generation
                raise NotImplementedError

        return ModelInput(
            images=[i.image for i in inputs],
            features=Batched3DTensors.from_list(
                [self._transform(inp.image)[0].to_tensor().data for inp in inputs]
            ),
            captions=[i.description for i in inputs],
            graphs=graphs,
        )



#### Model

As previously said, the model consists of four main components:
1. Vision Encoder
2. Text Encoder
3. Object Detector
4. Decoder

##### Vision Encoder

The vision encoder is a modified version of the ViT encoder used by CLIP. The original CLIP vision encoder returns a single embedding for each image. In our case, we need it to return a feature map for each image, so that each query can attend to different parts of the image. Thus, we modified the CLIP implementation by removing the attention-based pooling and taking the patch embeddings of the last layer as the feature map. Similarly to OwlViT ([Minderer et al. 2022](https://arxiv.org/abs/2205.06230)), we further multiply each patch embedding with the class token and apply a layer norm. The authors of OwlViT state that this last operation improves the performance of the model, but no explanation is given. Probably, this allows each patch to encode information about the whole image and not only about a specific region. Since the feature dimension of each patch is 768, we also apply a final linear projection to reduce the dimension to 256 (the same dimension used by DETR-like models). To avoid losing all the information of the discarded linear projection, we initialize the weights by applying PCA to the original weights (even though no significant improvement was observed with respect to random initialization).

Another difference is that the original CLIP encoder resizes and center-crops each image to 224x224. Resizing the image to such a small size may lose small details that are important for the object detection task. Furthermore, center cropping may cut out parts of the image that contain some of the entities referred to in the sentence. Thus, as previously said, we resize each image to a larger size and we do not apply cropping. This is similar to what is done by OwlViT. However, while OwlViT resizes each image to a fixed size, we keep the original aspect ratio and use padding (with attention masking) to handle images of different sizes in the same batch. Similarly to OwlViT, we do not discard the learned positional embeddings of CLIP, but simply interpolate them to the new image size.
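
The interpolation of the positional embeddings can be sketched as follows. This is a generic illustration and not necessarily how the encoder in this project does it: the function name, the bicubic mode and the grid sizes are assumptions (a 7x7 grid corresponds to a 224x224 input with 32px patches, and 25x34 to an 800x1088 input with the same patch size), and the random tensor stands in for the learned CLIP embeddings.

```python
import torch
import torch.nn.functional as F


def interpolate_positional_embeddings(
    pos_embed: torch.Tensor,
    old_grid: tuple[int, int],
    new_grid: tuple[int, int],
) -> torch.Tensor:
    """Resizes patch positional embeddings of shape (1 + old_h * old_w, dim).

    The first row is the class token embedding and is left untouched.
    """
    old_h, old_w = old_grid
    new_h, new_w = new_grid
    cls_pos, patch_pos = pos_embed[:1], pos_embed[1:]
    dim = patch_pos.shape[-1]

    # (old_h * old_w, dim) -> (1, dim, old_h, old_w), the layout expected by F.interpolate
    patch_pos = patch_pos.reshape(old_h, old_w, dim).permute(2, 0, 1).unsqueeze(0)
    patch_pos = F.interpolate(
        patch_pos, size=(new_h, new_w), mode="bicubic", align_corners=False
    )
    # (1, dim, new_h, new_w) -> (new_h * new_w, dim)
    patch_pos = patch_pos.squeeze(0).permute(1, 2, 0).reshape(new_h * new_w, dim)

    return torch.cat([cls_pos, patch_pos], dim=0)


# stand-in for the learned embeddings: 1 class token + 7x7 patches, 768 features
pos_embed = torch.randn(1 + 7 * 7, 768)
resized = interpolate_positional_embeddings(pos_embed, (7, 7), (25, 34))
print(resized.shape)
```

Because the target grid is passed explicitly, the same function handles images with different aspect ratios, which is what keeping the original aspect ratio requires.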

Note: We used the ViT-based CLIP vision encoder instead of the ResNet-50 one because we verified that, when the encoder is given images larger than those it was trained on, the final embeddings of the ViT encoder were more similar to the embeddings obtained from the 224x224 image. Thus it seems that the ViT encoder is more robust to changes in the input size.

In [None]:
import torch
import torch.nn.functional as F
from jaxtyping import Float
from sklearn.decomposition import PCA
from torch import Tensor, nn
from transformers.models.clip.modeling_clip import (
    CLIPVisionEmbeddings,
    CLIPVisionModelWithProjection,
)

from deepsight.utils.torch import Batched2DTensors, Batched3DTensors

from deepsight.modeling.layers.clip._misc import Models


class VisionEncoder(nn.Module):
    """A modified version of the CLIP [1]_ vision encoder.

    There are two main differences:
    - While the CLIP vision encoder returns a single vector representation for each
    image by pooling the patch embeddings, this encoder returns a 2D feature map for
    each image. Similarly to OwlViT [2]_, the 2D feature map is obtained by multiplying
    the patches with the class token and applying a layer norm. Each patch is then
    projected to the output dimension using a linear layer.
    - While the CLIP vision encoder requires all images to be rescaled to the same
    size (224x224 or 336x336), this encoder does not require images to be rescaled to
    the same fixed size. Instead, the positional embeddings are interpolated to the size
    of each image. This should improve the performance of the encoder on large images
    with fine-grained details.

    .. note::
        The CLIP vision encoder has an output dimension of 512 that is double what is
        used by most object detection models. Thus, if the specified `output_dim` is
        not 512, the projection layer is replaced with a linear layer that has the
        specified output dimension. The weights of the linear layer are initialized by
        applying PCA to the weights of the original projection layer.

    References
    ----------
    .. [1] Radford, A., Kim, J.W., Hallacy, C., Ramesh, A., Goh, G., Agarwal,
        S., Sastry, G., Askell, A., Mishkin, P., Clark, J. and Krueger, G., 2021, July.
        Learning transferable visual models from natural language supervision.
        In International conference on machine learning (pp. 8748-8763). PMLR.
    .. [2] Minderer, M., Gritsenko, A., Stone, A., Neumann, M., Weissenborn, D.,
        Dosovitskiy, A., Mahendran, A., Arnab, A., Dehghani, M., Shen, Z. and Wang, X.,
        2022, October. Simple open-vocabulary object detection. In European Conference
        on Computer Vision (pp. 728-755). Cham: Springer Nature Switzerland.
    """

    def __init__(self, model: Models, output_dim: int) -> None:
        super().__init__()

        clip = CLIPVisionModelWithProjection.from_pretrained(model.weights())

        vision = clip.vision_model
        self.embeddings = VisionEmbeddings(vision.embeddings)
        self.pre_layernorm = vision.pre_layrnorm
        self.encoder = vision.encoder
        self.post_layernorm = vision.post_layernorm

        self.last_layernorm = nn.LayerNorm(clip.config.hidden_size)

        projection = clip.visual_projection
        if projection.out_features != output_dim:
            weights = projection.weight.transpose(0, 1).detach().numpy()
            weights = PCA(output_dim).fit_transform(weights)
            self.projection = nn.Linear(
                in_features=projection.in_features,
                out_features=output_dim,
                bias=False,
            )

            with torch.no_grad():
                self.projection.weight = nn.Parameter(
                    torch.from_numpy(weights).transpose(0, 1)
                )

        else:
            self.projection = projection

    def _create_attention_mask(self, x: Batched2DTensors) -> Float[Tensor, "B 1 L L"]:
        """Creates an attention mask to mask out the padding tokens.

        Parameters
        ----------
        x : Batched2DTensors
            The input flattened image tensors.

        Returns
        -------
        Float[Tensor, "B 1 L L"]
            The attention mask.
        """

        mask = x.mask[:, None, None, :].expand(-1, 1, x.shape[1], -1)

        dtype = x.tensor.dtype
        attn_mask = torch.zeros_like(mask, dtype=dtype)
        attn_mask.masked_fill_(mask, -torch.inf)

        return attn_mask

    def forward(self, images: Batched3DTensors) -> Batched3DTensors:
        x, new_sizes = self.embeddings(images)  # (B, 1+HW, C)

        attn_mask = self._create_attention_mask(x)

        hidden: Tensor = self.pre_layernorm(x.tensor)
        tmp = self.encoder(
            inputs_embeds=hidden,
            output_attentions=False,
            output_hidden_states=False,
            attention_mask=attn_mask,
            return_dict=True,
        )

        hidden = tmp.last_hidden_state  # (B, 1+HW, C)
        hidden = self.post_layernorm(hidden)

        class_token = hidden[:, :1]  # (B, 1, C)
        image_embeds = hidden[:, 1:]  # (B, HW, C)

        out: Tensor = class_token * image_embeds
        out = self.last_layernorm(out)

        out = self.projection(out)  # (B, HW, D)

        H = max(size[0] for size in new_sizes)
        W = max(size[1] for size in new_sizes)

        out = out.view(out.shape[0], H, W, -1).permute(0, 3, 1, 2)  # (B, D, H, W)

        return Batched3DTensors(out, sizes=new_sizes)

    def __call__(self, images: Batched3DTensors) -> Batched3DTensors:
        return super().__call__(images)  # type: ignore


class VisionEmbeddings(nn.Module):
    """A wrapper around the CLIP vision embeddings.

    This wrapper allows the CLIP vision encoder to work with batches of images of
    different sizes. To avoid discarding the learned positional embeddings, the
    positional embeddings are interpolated to the size of each image.
    """

    def __init__(self, embeddings: CLIPVisionEmbeddings) -> None:
        super().__init__()

        self.patch_embedding = embeddings.patch_embedding
        self.class_embedding = embeddings.class_embedding

        h, w = (int(embeddings.num_patches**0.5),) * 2
        patch_pos_embedding = embeddings.position_embedding.weight.data[1:]
        patch_pos_embedding = patch_pos_embedding.reshape(h, w, -1).permute(2, 0, 1)
        class_pos_embedding = embeddings.position_embedding.weight.data[0]

        self.patch_pos_embedding = nn.Parameter(patch_pos_embedding)
        self.class_pos_embedding = nn.Parameter(class_pos_embedding)

    def _compute_new_size(self, old_size: tuple[int, int]) -> tuple[int, int]:
        """Computes the new size of the image after patch embedding.

        Parameters
        ----------
        old_size : tuple[int, int]
            The size of the image before patch embedding.

        Returns
        -------
        tuple[int, int]
            The size of the image after patch embedding.
        """

        kh, kw = self.patch_embedding.kernel_size
        sh, sw = self.patch_embedding.stride
        ph, pw = self.patch_embedding.padding

        H, W = old_size
        h = (H + 2 * ph - kh) // sh + 1
        w = (W + 2 * pw - kw) // sw + 1

        return h, w

    def forward(
        self, images: Batched3DTensors
    ) -> tuple[Batched2DTensors, list[tuple[int, int]]]:
        B = len(images)
        x: Tensor = self.patch_embedding(images.tensor)

        new_sizes = []
        patch_pos_emb = torch.zeros_like(x)  # (B, C, H, W)

        for idx in range(len(x)):
            h, w = self._compute_new_size(images.sizes[idx])
            new_sizes.append((h, w))

            # resize the positional embeddings to the new size of the image
            emb = F.interpolate(
                self.patch_pos_embedding[None],
                size=(h, w),
                mode="bilinear",
                align_corners=False,
            )[0]

            patch_pos_emb[idx, :, :h, :w] = emb

        patch_pos_emb = patch_pos_emb.flatten(2).transpose(1, 2)  # (B, HW, C)
        class_pos_emb = self.class_pos_embedding.expand(B, 1, -1)  # (B, 1, C)
        pos_emb = torch.cat([class_pos_emb, patch_pos_emb], dim=1)  # (B, 1+HW, C)

        class_token = self.class_embedding.expand(B, 1, -1)  # (B, 1, C)
        x = x.flatten(2).transpose(1, 2)  # (B, HW, C)
        x = torch.cat([class_token, x], dim=1)  # (B, 1+HW, C)

        x = x + pos_emb

        out = Batched2DTensors(x, sizes=[(1 + h * w) for h, w in new_sizes])

        return out, new_sizes

    def __call__(
        self, images: Batched3DTensors
    ) -> tuple[Batched2DTensors, list[tuple[int, int]]]:
        return super().__call__(images)  # type: ignore


##### Text Encoder

The text encoder is the CLIP text encoder. Thus, unlike the vision encoder, it encodes a sentence consisting of many words into a single embedding and not into a sequence of embeddings.

The only difference is that, to make the output dimension match the feature dimension of each visual patch, we replace the last linear projection with a new one that reduces the dimension to 256. As for the vision encoder, we initialize the weights of this linear projection by applying PCA to the original weights.

In [None]:
import torch
from jaxtyping import Float
from sklearn.decomposition import PCA
from torch import Tensor, nn
from transformers.models.clip.modeling_clip import (
    CLIPTextModelWithProjection,
)
from transformers.models.clip.processing_clip import CLIPProcessor


class TextEncoder(nn.Module):
    """A wrapper around the CLIP [1]_ text encoder.

    .. note::
        To make the output dimension of the text encoder match the output dimension
        of the vision encoder, if the specified `output_dim` is different from
        the dimension of the text encoder's projection layer, the projection layer
        is replaced with a new linear layer that has the specified output dimension.
        The weights of the linear layer are initialized by applying PCA to the weights
        of the original projection layer.

    References
    ----------
    .. [1] Radford, A., Kim, J.W., Hallacy, C., Ramesh, A., Goh, G., Agarwal,
        S., Sastry, G., Askell, A., Mishkin, P., Clark, J. and Krueger, G., 2021, July.
        Learning transferable visual models from natural language supervision.
        In International conference on machine learning (pp. 8748-8763). PMLR."""

    def __init__(self, model: Models, output_dim: int) -> None:
        super().__init__()

        self._dummy = nn.Parameter(torch.empty(0))

        self.processor = CLIPProcessor.from_pretrained(model.weights())
        self.transformer = CLIPTextModelWithProjection.from_pretrained(model.weights())

        projection = self.transformer.text_projection
        if projection.out_features != output_dim:
            weights = projection.weight.transpose(0, 1).detach().numpy()
            weights = PCA(output_dim).fit_transform(weights)
            self.projection = nn.Linear(
                in_features=projection.in_features,
                out_features=output_dim,
                bias=False,
            )

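            with torch.no_grad():
                self.projection.weight = nn.Parameter(
                    torch.from_numpy(weights).transpose(0, 1)
                )
        else:
            # `text_embeds` returned by the transformer already has the requested
            # dimension, so `forward` needs no further projection.
            self.projection = nn.Identity()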

    def forward(self, text: list[str]) -> Float[Tensor, "N D"]:
        inputs = self.processor(text=text, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(self._dummy.device)
        attention_mask = inputs["attention_mask"].to(self._dummy.device)

        x = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        text_embeds = x.text_embeds  # (N, D)
        out: Tensor = self.projection(text_embeds)

        return out

    def __call__(self, text: list[str]) -> Float[Tensor, "N D"]:
        """Encodes each text in the batch into a vector.

        Parameters
        ----------
        text : list[str]
            A list of texts to encode.

        Returns
        -------
        Float[Tensor, "N D"]
            Tensor of shape (N, D) where N is the number of texts in the batch and D
            is the output dimension of the text encoder.
        """

        return super().__call__(text)  # type: ignore


##### OwlViT

As open-set object detector we decided to use OwlViT ([Minderer et al. 2022](https://arxiv.org/abs/2205.06230)) given its reasonable performance (both in terms of inference speed and localization capabilities) and its simplicity. Indeed, OwlViT is based on a pair of vision and text encoders that are pretrained on a contrastive image-text task, exactly like CLIP. This architecture is then fine-tuned on object detection by adding a simple regression head on top of the vision encoder. In particular, each patch embedding output by the vision encoder is fed to the regression head to predict the bounding box offsets with respect to the patch position. The similarity between each patch and the text embeddings of the entities to detect gives the confidence score that the patch contains the entity. A threshold is then applied to the confidence scores to obtain the final set of bounding boxes.
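
The idea of predicting offsets with respect to the patch position can be sketched as follows: the position of each patch is turned into a bias in logit space, so that a zero offset from the regression head decodes back to the patch position after the sigmoid. This is a scalar illustration in plain Python; the names `logit`, `sigmoid` and `decode` are hypothetical, and the epsilon of `1e-4` mirrors the one used in the wrapper below.

```python
import math


def logit(p, eps=1e-4):
    """Inverse sigmoid, with a small epsilon for numerical stability at 0 and 1."""
    return math.log(p + eps) - math.log1p(-p + eps)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def decode(offset, patch_center):
    """Turn a raw regression output into a normalized coordinate in (0, 1)."""
    return sigmoid(offset + logit(patch_center))


side = 4  # a 4x4 grid of patches
centers = [i / side for i in range(1, side + 1)]  # 0.25, 0.5, 0.75, 1.0

for c in centers:
    # With a zero offset, the decoded coordinate is (almost) the patch position.
    print(c, round(decode(0.0, c), 3))

# A positive offset moves the prediction towards 1, a negative one towards 0.
print(decode(1.0, 0.5) > 0.5, decode(-1.0, 0.5) < 0.5)
```

Because the sigmoid is applied last, the decoded coordinates always stay inside the image whatever the regression head outputs.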

Regarding the implementation, we used the one provided by [HuggingFace](https://huggingface.co/docs/transformers/model_doc/owlvit) with a modification to handle the case in which a given entity is not detected. In particular, if no instance of an entity is detected with a confidence score higher than the threshold, we select the __k__ most confident bounding boxes for that entity even if their confidence score is lower than the threshold. Even if the estimated locations of an entity are not accurate, we consider this approach better than not detecting the entity at all. Indeed, if the input scene graph is (`the girl` -- `approaching` --> `the table`) and no table is detected, the new graph will consist only of nodes associated with the entity `the girl`, thus losing the information that `the girl` that needs to be found is the one that is approaching `the table`.
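
The fallback rule can be summarized with a minimal sketch for a single entity. The function `select_boxes` is a hypothetical helper written in plain Python for illustration; the wrapper below applies the same rule to batched tensors.

```python
def select_boxes(scores, threshold, k=None):
    """Return the indices of the boxes kept for one entity.

    Boxes above the threshold are always kept. If none passes the threshold and
    `k` is given, the `k` most confident boxes are returned instead.
    """
    keep = [i for i, s in enumerate(scores) if s > threshold]
    if keep or k is None:
        return keep

    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]


print(select_boxes([0.1, 0.6, 0.3], threshold=0.5))       # [1]
print(select_boxes([0.1, 0.2, 0.3], threshold=0.5, k=2))  # [2, 1]
print(select_boxes([0.1, 0.2, 0.3], threshold=0.5))       # []
```

With `k=None` the entity is simply dropped, which corresponds to the default behaviour without the modification.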

Besides OwlViT, we also tried Grounding DINO as object detector. However, despite its better detection performance, its inference was too slow: the training time for a single epoch more than doubled.

In [None]:
import torch
from jaxtyping import Float
from torch import Tensor, nn
from transformers import OwlViTForObjectDetection, OwlViTProcessor

from deepsight.data.structs import (
    Batch,
    BoundingBoxes,
    BoundingBoxFormat,
    ODInput,
    ODOutput,
)


class OwlViT(nn.Module):
    """Wrapper around the OwlViT model for open-set detection.

    With respect to the original OwlViT model, this wrapper adds the possibility
    to return the bounding boxes even when the confidence of the entity is below
    a certain threshold.
    """

    def __init__(self, threshold: float, num_boxes: int | None = None) -> None:
        """Initializes the OwlViT model.

        Parameters
        ----------
        threshold : float
            The threshold used to determine whether an entity is present in the
            input image.
        num_boxes : int | None
            If not None, when an entity is not found with a confidence above the
            `threshold`, the model will return the top `num_boxes` boxes with the
            highest confidence scores for that entity. If None, the model will
            return only the entities that are found with a confidence above the
            `threshold`. Defaults to None.
        """

        super().__init__()

        self._threshold = threshold
        self._num_boxes = num_boxes

        self._dummy = nn.Parameter(torch.empty(0))

        model_id = "google/owlvit-base-patch32"
        processor = OwlViTProcessor.from_pretrained(model_id)
        model = OwlViTForObjectDetection.from_pretrained(model_id)

        self.processor = processor
        self.owlvit = model.owlvit
        self.class_head = model.class_head
        self.box_head = model.box_head

        self.layer_norm = model.layer_norm

    def _get_boxes(
        self, image_embeds: Float[Tensor, "B L D"]
    ) -> Float[Tensor, "B L 4"]:
        """Returns for each patch the corresponding bounding box.

        The bounding box associated to a patch is obtained by computing the coordinates offsets
        with respect to the center of the patch using a simple regression head.

        Parameters
        ----------
        image_embeds : Float[Tensor, "B L D"]
            The image embeddings obtained from the OwlViT model. The shape is (B, L, D) where
            B is the batch size, L is the number of patches and D is the dimension of the
            embeddings.

        Returns
        -------
        Float[Tensor, "B L 4"]
            The bounding boxes associated to each patch. The shape is (B, L, 4) where B is the
            batch size, L is the number of patches and 4 are the normalized coordinates of the bounding box
            in the format (center_x, center_y, width, height).
        """

        L = image_embeds.shape[1]
        side = int(L**0.5)
        device = image_embeds.device
        dtype = image_embeds.dtype

        coords = torch.stack(
            torch.meshgrid(
                torch.arange(1, side + 1, device=device, dtype=dtype),
                torch.arange(1, side + 1, device=device, dtype=dtype),
                indexing="xy",
            ),
            dim=-1,
        )
        coords = coords / side
        coords = coords.view(L, 2)

        coords = torch.clamp(coords, 0.0, 1.0)  # (L, 2)
        coord_bias = torch.log(coords + 1e-4) - torch.log1p(-coords + 1e-4)

        size = torch.full_like(coord_bias, 1.0 / side)  # (L, 2)
        size_bias = torch.log(size + 1e-4) - torch.log1p(-size + 1e-4)

        box_bias = torch.cat((coord_bias, size_bias), dim=-1)  # (L, 4)

        pred_boxes: Tensor = self.box_head(image_embeds)  # (B, L, 4)
        pred_boxes = pred_boxes + box_bias  # (B, L, 4)
        pred_boxes = torch.sigmoid(pred_boxes)  # (B, L, 4)

        return pred_boxes

    def forward(self, inputs: Batch[ODInput]) -> Batch[ODOutput]:
        images = [inp.image.to_pil().data for inp in inputs]

        # Create list of list of entities and remove duplicates
        entities: list[list[str]] = []
        str_to_idx: list[dict[str, list[int]]] = []
        for inp in inputs:
            sample_entities = []
            sample_str_to_idx = {}
            for ent_idx, ent in enumerate(inp.entities):
                # Since the OwlViT was trained by adding the prefix "a photo of a"
                # to each category name, we do the same here.
                # Indeed we find that the model is much more accurate and confident
                # when the prefix is added.
                ent = f"a photo of a {ent}"
                if ent not in sample_str_to_idx:
                    sample_str_to_idx[ent] = [ent_idx]
                    sample_entities.append(ent)
                else:
                    sample_str_to_idx[ent].append(ent_idx)

            entities.append(sample_entities)
            str_to_idx.append(sample_str_to_idx)

        B = len(images)
        max_queries = max(len(ent) for ent in entities)

        tmp = self.processor(
            images=images, text=entities, return_tensors="pt", truncation=True
        )
        pixel_values = tmp["pixel_values"].to(self._dummy.device)
        input_ids = tmp["input_ids"].to(self._dummy.device)
        attention_mask = tmp["attention_mask"].to(self._dummy.device)

        outputs = self.owlvit(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        image_tokens = outputs.vision_model_output[0]  # (B, 1+L, D)
        image_tokens = self.owlvit.vision_model.post_layernorm(
            image_tokens
        )  # (B, 1+L, D)

        class_token = image_tokens[:, :1, :]  # (B, 1, D)
        image_patches = image_tokens[:, 1:, :]  # (B, L, D)
        image_embeds = image_patches * class_token  # (B, L, D)
        image_embeds = self.layer_norm(image_embeds)  # (B, L, D)

        query_embeds = outputs[-4]  # (BQ, D) where BQ = B * max_queries
        query_embeds = query_embeds.view(B, max_queries, -1)  # (B, Q, D)

        query_mask = torch.zeros(
            (B, max_queries),
            dtype=torch.bool,
            device=query_embeds.device,
        )
        for sample_idx, sample_entities in enumerate(entities):
            query_mask[sample_idx, : len(sample_entities)] = True

        # (B, L, Q) means that each image patch is compared to each query
        # embedding, and the result is a scalar value representing the
        # similarity between the two.
        pred_logits, _ = self.class_head(image_embeds, query_embeds, query_mask)
        pred_boxes = self._get_boxes(image_embeds)

        probs, labels = pred_logits.max(dim=-1)  # (B, L)
        if self._num_boxes is not None:
            _, top_index_per_query = torch.topk(
                pred_logits, self._num_boxes, dim=1
            )  # (B, K, Q)

        scores = probs.sigmoid()  # (B, L)

        results = []
        for sample_idx in range(B):
            mask = scores[sample_idx] > self._threshold

            sample_pred_boxes = pred_boxes[sample_idx, mask]  # (N, 4)
            sample_pred_labels = labels[sample_idx, mask]  # (N,)
            sample_pred_scores = scores[sample_idx, mask]  # (N,)

            boxes_list: list[Tensor] = []
            labels_list: list[int] = []
            scores_list: list[Tensor] = []

            for ent_idx, ent in enumerate(entities[sample_idx]):
                indices = sample_pred_labels == ent_idx
                num_found = indices.sum().item()
                if num_found > 0:
                    # entity found in image
                    # add all duplicates to the list
                    for j in str_to_idx[sample_idx][ent]:
                        boxes_list.append(sample_pred_boxes[indices])
                        labels_list.extend([j] * num_found)
                        scores_list.append(sample_pred_scores[indices])
                elif self._num_boxes is not None:
                    # entity not found in image
                    # add top K boxes for the entity
                    topk_boxes = pred_boxes[
                        sample_idx, top_index_per_query[sample_idx, :, ent_idx]
                    ]

                    topk_scores = scores[
                        sample_idx, top_index_per_query[sample_idx, :, ent_idx]
                    ]

                    for j in str_to_idx[sample_idx][ent]:
                        boxes_list.append(topk_boxes)
                        labels_list.extend([j] * self._num_boxes)
                        # one score per box: keep the scores unscaled
                        scores_list.append(topk_scores)

            sample_labels = torch.tensor(
                labels_list, dtype=torch.long, device=self._dummy.device
            )

            boxes = BoundingBoxes(
                tensor=torch.cat(boxes_list, dim=0),
                images_size=inputs[sample_idx].image.size,
                format=BoundingBoxFormat.CXCYWH,
                normalized=True,
            )

            sample_scores = torch.cat(scores_list, dim=0)

            results.append(
                ODOutput(
                    boxes=boxes,
                    entities=sample_labels,
                    scores=sample_scores,
                )
            )

        return Batch(results)

    def __call__(self, inputs: Batch[ODInput]) -> Batch[ODOutput]:
        """Given a batch of images and entities to be detected, returns a list
        of ODOutput objects containing the bounding boxes of the detected entities.

        .. note::
            In case of duplicate entities for the same image, this implementation will
            return the same bounding boxes for both entities. This is different from the
            implementation of HuggingFace's OwlViT model which returns the bounding
            boxes only for one of the entities (usually the first one).

        Parameters
        ----------
        inputs : Batch[ODInput]
            A Batch object containing ODInput objects containing the images and
            entities.

        Returns
        -------
        Batch[ODOutput]
            A Batch object containing the ODOutput objects containing the bounding
            boxes.
        """

        return super().__call__(inputs)  # type: ignore


##### Decoder

In DETR-like models, the decoder contains a self-attention step in which each query proposal attends to all the other query proposals, allowing them to exchange information. Such information can be used, for example, to prevent two queries from being assigned to the same entity or to suggest that an entity may be present (for example, if a query is associated with a ball, then there may be a kid in the adjacent region). In our case, however, an entity node does not need to attend to all the other entities but only to the ones with which it should have a relation according to the scene graph. In addition, we add edges (whose features are initialized with learnable parameters) between all nodes associated with the same entity, to prevent multiple nodes from focusing on the same entity instance.

By exchanging information with its neighbours, each node can determine whether the associated entity instance is the one referred to in the sentence. For example, if the sentence were *"the girl approaching the table"*, the generated scene graph would be (`the girl` -- `approaching` -> `the table`). If two different instances of a girl are found in the image, we need to choose which instance is the right one. Suppose also that only one instance of table is detected. When exchanging messages with the `table` node, each `girl` node can determine whether it is the right one by checking whether the relation `approaching` exists between the two nodes.

For these reasons, to update the graph features, we do not use the self-attention mechanism (which can be considered a [graph operation](https://thegradient.pub/transformers-are-graph-neural-networks/)), but a more traditional message passing graph neural network (MPNN). In particular, we decided to use Graph Attention ([Veličković et al. 2017](https://arxiv.org/abs/1710.10903), [Brody et al. 2021](https://arxiv.org/abs/2105.14491)), given its simplicity and the good results obtained in many tasks (there are better-performing graph operations, but since we operate on very small graphs the advantages should not be significant).

In the second version of Graph Attention ([Brody et al. 2021](https://arxiv.org/abs/2105.14491)), the node features are updated as follows:
1. the messages passed by the node's neighbours (usually, the one-hop neighbourhood) are collected;
2. the messages are aggregated, each weighted by an attention coefficient that measures the relevance of the message for the node (to compute the relevance, edge features can also be used if present);
3. the current node features are combined with the aggregated messages to obtain the new node features.

In formulas:
$$ x_i' = \alpha_{ii} \theta x_i + \sum_{j \in N_i}{\alpha_{ij} \theta x_j} $$
and the attention coefficients are computed as:
$$ \alpha_{ij} = \frac{\exp{(W_2^\top \textrm{LeakyReLU}(W_1 [x_i || x_j || e_{ij}]))}}{\sum_{k \in N_i \cup \{i\}}{\exp{(W_2^\top \textrm{LeakyReLU}(W_1 [x_i || x_k || e_{ik}]))}}} $$

where $\theta$, $W_1$ and $W_2$ are learnable parameters, $e_{ij}$ is the edge feature between node $i$ and $j$ and $||$ is the concatenation operator. Usually, multi-head attention is used, where different attention coefficients are computed for each head, thus obtaining different node features for each head. The final node features are obtained by concatenating the node features of each head.

With respect to such formulation, we make two main modifications:
1. we do not use $\textrm{LeakyReLU}$ but we use $\textrm{GELU}$ in the attention coefficients computation;
2. we add a step to also update the edge features. In particular, we compute the edge features as:
$$ e_{ij}' = W_3 \textrm{GELU}(W_1 [x_i || x_j || e_{ij}]) $$
that is, we concatenate the features of the two endpoint nodes and the edge features and apply an MLP with one hidden layer to obtain the new edge features.
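
As a small worked example of the aggregation step, the sketch below computes the attention coefficients $\alpha_{ij}$ with a softmax normalized separately over the incoming edges of each node, and then sums the weighted messages. It is written in plain Python with scalar features for readability; the pre-softmax scores are given as inputs instead of being computed by $W_1$ and $W_2$, and the function names only mirror the `torch_scatter` operations used in the layer below.

```python
import math
from collections import defaultdict


def scatter_softmax(scores, targets):
    """Softmax of the edge scores, normalized separately for each target node."""
    max_per_node = defaultdict(lambda: -math.inf)
    for s, t in zip(scores, targets):
        max_per_node[t] = max(max_per_node[t], s)

    # Subtract the per-node maximum for numerical stability.
    exps = [math.exp(s - max_per_node[t]) for s, t in zip(scores, targets)]
    denom = defaultdict(float)
    for e, t in zip(exps, targets):
        denom[t] += e

    return [e / denom[t] for e, t in zip(exps, targets)]


def scatter_add(values, targets, num_nodes):
    """Sum the edge values that share the same target node."""
    out = [0.0] * num_nodes
    for v, t in zip(values, targets):
        out[t] += v
    return out


# Edges (i <- j), self loops included. Node 2 has no incoming edge.
targets = [0, 0, 0, 1, 1]
sources = [0, 1, 2, 1, 0]
scores = [0.0, 1.0, 2.0, 0.5, 0.5]  # pre-softmax attention scores
node_values = [1.0, 2.0, 3.0]       # scalar stand-ins for theta * x_j

alphas = scatter_softmax(scores, targets)
messages = [a * node_values[j] for a, j in zip(alphas, sources)]
new_nodes = scatter_add(messages, targets, num_nodes=3)

print([round(a, 3) for a in alphas])
print([round(x, 3) for x in new_nodes])
```

Node 1 receives two equally scored messages, so its new value is the plain average of the values of nodes 1 and 0.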

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_scatter import scatter_add, scatter_softmax

from deepsight.utils.torch import BatchedGraphs


class GATConv(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        head_dim = embed_dim // num_heads
        self.num_heads = num_heads

        self.first_node_proj = nn.Linear(embed_dim, embed_dim, bias)
        self.second_node_proj = nn.Linear(embed_dim, embed_dim, bias)
        self.edge_proj = nn.Linear(embed_dim, embed_dim, bias)
        self.attn_proj = nn.Parameter(torch.randn(1, num_heads, head_dim))

        self.node_out_proj = nn.Linear(embed_dim, embed_dim, bias)
        self.edge_out_proj = nn.Linear(embed_dim, embed_dim, bias)

        self.attn_dropout = nn.Dropout(dropout)

    def forward(
        self,
        graphs: BatchedGraphs,
        embeddings: BatchedGraphs | None = None,
    ) -> BatchedGraphs:
        nodes = graphs.nodes(None)
        edges = graphs.edges(None)

        N, _ = nodes.shape
        E, _ = edges.shape
        H = self.num_heads

        query_nodes = nodes
        query_edges = edges

        if embeddings is not None:
            query_nodes = query_nodes + embeddings.nodes(None)
            query_edges = query_edges + embeddings.edges(None)

        first_node = self.first_node_proj(query_nodes)[graphs.edge_indices[0]]
        second_node = self.second_node_proj(query_nodes)[graphs.edge_indices[1]]
        query_edges = self.edge_proj(query_edges)

        hidden = first_node + second_node + query_edges
        hidden = F.gelu(hidden)

        hidden_head = hidden.view(E, H, -1)
        presoftmax_alpha = (hidden_head * self.attn_proj).sum(dim=-1)  # (E, H)
        alpha = scatter_softmax(presoftmax_alpha, graphs.edge_indices[0], dim=0)
        alpha = self.attn_dropout(alpha)

        new_edges = self.edge_out_proj(hidden)
        values = nodes[graphs.edge_indices[1]] + new_edges
        values = self.node_out_proj(values)
        values = values.view(E, H, -1)
        values = values * alpha.unsqueeze(-1)
        new_nodes = scatter_add(values, graphs.edge_indices[0], dim=0)
        new_nodes = new_nodes.view(N, -1)

        return graphs.new_like(nodes=new_nodes, edges=new_edges)

    def __call__(
        self,
        graphs: BatchedGraphs,
        embeddings: BatchedGraphs | None = None,
    ) -> BatchedGraphs:
        return super().__call__(graphs, embeddings)  # type: ignore


In DETR-like models, in the self-attention step, before the box queries are projected into the attention queries and keys, each box query embedding is summed with an embedding encoding the spatial position of the box ([Meng et al. 2021](https://arxiv.org/abs/2108.06152), [Liu et al. 2022](https://arxiv.org/abs/2201.12329)). In our model, we do the same. In particular, the four coordinates of the bounding box are encoded using sinusoidal functions as in [Vaswani et al. 2017](https://arxiv.org/abs/1706.03762). Inspired by [Zhang et al. 2020](https://arxiv.org/abs/2012.06060), we also add to each edge feature a positional encoding that encodes the spatial relation between the two endpoint nodes. In particular, we use the same sinusoidal functions to encode the difference between the centers of the two bounding boxes, their intersection over union (IoU) and their union.
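
The pairwise quantities that feed the edge positional encoding can be computed as in the following sketch. It is plain Python on normalized `(cx, cy, w, h)` boxes; `pair_features` is a hypothetical helper, and interpreting "union" as the area of the union of the two boxes is our assumption for this illustration.

```python
def to_xyxy(box):
    """Convert a (cx, cy, w, h) box to (x1, y1, x2, y2)."""
    cx, cy, w, h = box
    return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2


def pair_features(a, b):
    """Return (dx, dy, iou, union) for two normalized (cx, cy, w, h) boxes."""
    ax1, ay1, ax2, ay2 = to_xyxy(a)
    bx1, by1, bx2, by2 = to_xyxy(b)

    # Width and height of the intersection (zero if the boxes do not overlap).
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = iw * ih

    union = a[2] * a[3] + b[2] * b[3] - inter
    iou = inter / union if union > 0 else 0.0

    return b[0] - a[0], b[1] - a[1], iou, union


girl = (0.25, 0.5, 0.5, 0.5)
table = (0.5, 0.5, 0.5, 0.5)
dx, dy, iou, union = pair_features(girl, table)
print(dx, dy, round(iou, 3), union)  # 0.25 0.0 0.333 0.375
```

Each of these scalars is then passed through the sinusoidal encoding in the same way as a box coordinate.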

Since all the values (coordinates, IoU and union) are normalized between 0 and 1, following [Liu et al. 2022](https://arxiv.org/abs/2201.12329), we set the temperature for the sinusoidal functions to 20 instead of 10000.
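
The effect of the temperature can be checked with a scalar sketch of the encoding. It follows the common DETR-style formulation (frequencies `temperature ** (2k / D)`, sine and cosine interleaved); this layout is an assumption for illustration, and the exact ordering used by the module below may differ.

```python
import math


def sinusoidal_embedding(x, dim, temperature=20.0, scale=2 * math.pi):
    """Encode a scalar x in [0, 1] into `dim` sinusoidal features."""
    if dim % 2 != 0:
        raise ValueError(f"dim must be even, got {dim}.")

    emb = []
    for i in range(0, dim, 2):
        freq = temperature ** (i / dim)  # one frequency per (sin, cos) pair
        angle = x * scale / freq
        emb.extend([math.sin(angle), math.cos(angle)])
    return emb


emb = sinusoidal_embedding(0.5, dim=8)
print([round(v, 3) for v in emb])

# Slowest-varying sine feature at x = 1 for the two temperatures.
low = sinusoidal_embedding(1.0, dim=8, temperature=20.0)[-2]
high = sinusoidal_embedding(1.0, dim=8, temperature=10000.0)[-2]
print(round(low, 3), round(high, 3))
```

With a temperature of 10000 the slowest features barely change over the whole range [0, 1], so they carry almost no information for normalized inputs; a temperature of 20 keeps all the features responsive.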

In [None]:
import torch
from jaxtyping import Float
from torch import Tensor, nn

from deepsight.data.structs import BoundingBoxes


class SinusoidalBoxEmbeddings(nn.Module):
    """Sinusoidal box embeddings.

    This module computes sinusoidal embeddings for a set of bounding boxes. The
    embeddings are computed by applying sinusoidal functions (as described in [1]_)
    to the coordinates of the boxes and concatenating the results.

    .. note::
        Since such position embeddings are intended to be matched with the ones
        computed for the feature maps, the coordinates embeddings are concatenated
        in the following order: cx, cy, (w, h).

    Attributes
    ----------
    dim : int
        The final embedding dimension. This is equal to the feature dimension used
        for each coordinate times the number of coordinates used (2 or 4). If
        `include_wh` is `True`, the dimension must be divisible by 4, otherwise it
        must be even.
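    temperature : int
        The temperature of the sinusoidal function. Defaults to `20`.
    scale : float
        The factor by which the normalized coordinates are multiplied before the
        sinusoidal functions are applied. Defaults to `2 * pi`.
    include_wh : bool
        Whether to include the width and height of the boxes in the embeddings.
        Defaults to `False`.

    References
    ----------
    .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez,
        A.N., Kaiser, L. and Polosukhin, I., 2017. Attention is all you need.
        Advances in neural information processing systems, 30.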
    """

    def __init__(
        self,
        dim: int,
        temperature: int = 20,
        scale: float = 2 * torch.pi,
        include_wh: bool = False,
    ) -> None:
        super().__init__()

        if include_wh:
            if dim % 4 != 0:
                raise ValueError(f"dim must be divisible by 4, got {dim}.")
        else:
            if dim % 2 != 0:
                raise ValueError(f"dim must be even, got {dim}.")

        self.dim = dim
        self.temperature = temperature
        self.scale = scale
        self.include_wh = include_wh

    def forward(self, boxes: BoundingBoxes) -> Float[Tensor, "... D"]:
        boxes = boxes.to_cxcywh().normalize()

        dim = self.dim // 4 if self.include_wh else self.dim // 2  # D
        temperature = self.temperature

        dim_t = torch.arange(dim, dtype=torch.float32, device=boxes.device)  # (D,)
        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)

        if self.include_wh:
            coords = boxes.tensor  # (..., 4)
        else:
            coords = boxes.tensor[..., :2]  # (..., 2)

        pos = coords.unsqueeze(-1) * self.scale / dim_t
        pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1)
        pos = pos.flatten(start_dim=-3)  # (..., D)

        return pos

    def __call__(self, boxes: BoundingBoxes) -> Float[Tensor, "... D"]:
        """Computes the embeddings for the given boxes.

        Parameters
        ----------
        boxes : BoundingBoxes
            The bounding boxes to compute the embeddings for. The bounding boxes
            tensor can have any number of leading dimension.

        Returns
        -------
        Float[Tensor, "... D"]
            The computed embeddings. The number of leading dimensions of the returned
            tensor is equal to the number of leading dimensions of the bounding boxes
            tensor. The last dimension is equal to the embedding dimension.
        """

        return super().__call__(boxes)  # type: ignore



class SinusoidalPairwiseBoxEmbeddings(nn.Module):
    """Sinusoidal embeddings for pairs of bounding boxes."""

    def __init__(self, dim: int, temperature: int = 20) -> None:
        super().__init__()

        if dim % 4 != 0:
            raise ValueError(f"dim must be divisible by 4, got {dim}.")

        self.dim = dim
        self.temperature = temperature

    def forward(
        self,
        first: BoundingBoxes,
        second: BoundingBoxes,
    ) -> Float[Tensor, "... D"]:
        first = first.to_cxcywh().normalize()
        second = second.to_cxcywh().normalize()

        distance = first.tensor[..., :2] - second.tensor[..., :2]  # (..., 2)
        iou = first.iou(second).unsqueeze(-1)  # (..., 1)
        union = first.union(second).area().unsqueeze(-1)  # (..., 1)

        coords = torch.cat((distance, iou, union), dim=-1)  # (..., 4)
        coords = coords.unsqueeze(-1)  # (..., 4, 1)

        dim = self.dim // 4  # D
        temperature = self.temperature

        dim_t = torch.arange(dim, dtype=torch.float32, device=coords.device)  # (D,)
        dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)

        pos = coords / dim_t
        pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1)
        pos = pos.flatten(start_dim=-3)  # (..., D)

        return pos

    def __call__(
        self,
        first: BoundingBoxes,
        second: BoundingBoxes,
    ) -> Float[Tensor, "... D"]:
        return super().__call__(first, second)  # type: ignore

In [None]:
import torch
from jaxtyping import Bool, Float
from torch import Tensor, nn

from deepsight.data.structs import BoundingBoxes


class GaussianHeatmaps(nn.Module):
    """Gaussian heatmaps.

    This module computes a Gaussian heatmap for a set of bounding boxes.
    The Gaussian is centered at the center of the bounding box and its standard
    deviation in each direction is equal to the corresponding side of the bounding
    box.

    These heatmaps can be used to bias the attention of each query to a specific region
    of the image during the cross-attention step of DETR-like models. See [1]_ for more
    details.

    Attributes
    ----------
    beta : float
        A scaling factor for the gaussian standard deviation. The smaller the value,
        the more concentrated the gaussian will be around the center of the bounding
        box. Defaults to 1.0.

    References
    ----------
    .. [1] Gao, P., Zheng, M., Wang, X., Dai, J. and Li, H., 2021. Fast convergence of
        detr with spatially modulated co-attention. In Proceedings of the IEEE/CVF
        international conference on computer vision (pp. 3621-3630).
    """

    def __init__(self, beta: float = 1.0) -> None:
        super().__init__()

        self.beta = beta

    def forward(
        self,
        boxes: BoundingBoxes,
        mask: Bool[Tensor, "... H W"],
    ) -> Float[Tensor, "... H W"]:
        boxes = boxes.to_cxcywh().normalize()

        mean = boxes.tensor[..., :2]  # (..., 2)
        std = boxes.tensor[..., 2:]  # (..., 2)

        not_mask = ~mask
        y_coords = not_mask.cumsum(dim=-2, dtype=torch.float32)  # (..., H, W)
        x_coords = not_mask.cumsum(dim=-1, dtype=torch.float32)  # (..., H, W)

        eps = torch.finfo(torch.float32).eps
        y_coords = y_coords / (y_coords[..., -1:, :] + eps)
        x_coords = x_coords / (x_coords[..., -1:] + eps)

        y = (y_coords - mean[..., 1, None, None]) ** 2
        y = y / (self.beta * (std[..., 1, None, None] ** 2))

        x = (x_coords - mean[..., 0, None, None]) ** 2
        x = x / (self.beta * (std[..., 0, None, None] ** 2))

        out: Tensor = torch.exp(-(x + y))  # (..., H, W)
        out.masked_fill_(mask, 0.0)

        return out

    def __call__(
        self,
        boxes: BoundingBoxes,
        mask: Bool[Tensor, "... H W"],
    ) -> Float[Tensor, "... H W"]:
        """Computes Gaussian heatmaps.

        Parameters
        ----------
        boxes : BoundingBoxes
            The bounding boxes for which to compute the heatmaps. The bounding boxes
            tensor can have any number of leading dimensions.
        mask : Bool[Tensor, "... H W"]
            A boolean mask indicating which pixels in the heatmaps should be considered
            as padding, i.e. which pixels are outside the image. The mask tensor must
            have the same number of leading dimensions as the bounding boxes tensor.

        Returns
        -------
        Float[Tensor, "... H W"]
            The computed heatmaps. The number of leading dimensions of the returned
            tensor is equal to the number of leading dimensions of the bounding boxes
            tensor. The last two dimensions are equal to the height and width of the
            heatmaps, respectively.
        """

        return super().__call__(boxes, mask)  # type: ignore


The cross-attention operation in the decoder is the same as in DETR-like models: the nodes and edges of the graph returned by the previous decoder layer are used as queries, while the visual features output by the encoder are used to compute the keys and values. This lets each node check whether the corresponding entity is present in its associated region, and lets each edge check whether the relation holds between its two endpoint nodes.

Recent works ([Gao et al. 2021](https://arxiv.org/abs/2101.07448), [Meng et al. 2021](https://arxiv.org/abs/2108.06152), [Liu et al. 2022](https://arxiv.org/abs/2201.12329)) have argued that one of the reasons why DETR requires an extremely long training time is that, since in cross-attention each box query attends to all the patches in the image, the network needs to learn how to make each box focus only on a small area of the image by giving high importance only to the patches associated with that area. To make this easier for the network, such works propose different techniques. For example, [Meng et al. 2021](https://arxiv.org/abs/2108.06152) propose to concatenate to each patch its positional encoding and to do the same for each query box. Thus, boxes and patches with close spatial positions will have similar positional encodings, so the dot product between them will be higher and more importance will be given to those patches.

In our model, we adopt the technique proposed by [Gao et al. 2021](https://arxiv.org/abs/2101.07448). In particular, given a node with associated bounding box coordinates $(c_x, c_y, w, h)$, we compute a Gaussian-like weight map as
$$ G(i, j) = \exp\left(- \frac{(i - c_x)^2}{\beta w^2} - \frac{(j - c_y)^2}{\beta h^2} \right) $$
where $(i, j) \in [0, W] \times [0, H]$ is the position of a patch in the feature map and $\beta$ is a hyperparameter (here, set to 1) that modulates the width of the Gaussian-like function. When computing the cross-attention matrix between a node and the patches, we add this weight map to the dot product between the node and the patches. In this way, the network is encouraged to give more importance to the patches close to the center of the node bounding box. For the edges, we compute the Gaussian-like weight map by choosing for each position $(i, j)$ the maximum value between the weight maps of the edge endpoints and the weight map computed using the smallest bounding box containing the two bounding boxes of the endpoints. Thus, similarly to previous works ([Wang et al. 2020](https://arxiv.org/abs/2003.14023)), we assume that the key information to detect a relation is always in the middle of the two bounding boxes, even if it has been shown that this is not always true ([Tamura et al. 2021](https://arxiv.org/abs/2103.05399)).
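As a small worked illustration of the formula and of the edge rule, the sketch below evaluates the weight map on a tiny grid in plain Python. It assumes boxes in normalized `(cx, cy, w, h)` format and patch coordinates normalized to $(0, 1]$, as done in the `GaussianHeatmaps` module; the helper names `gaussian_map`, `enclosing_box` and `edge_map` are hypothetical and only serve this example.

```python
import math


def gaussian_map(box, height, width, beta=1.0):
    """Weight map of a normalized (cx, cy, w, h) box on a height x width grid."""
    cx, cy, w, h = box
    rows = []
    for r in range(height):
        y = (r + 1) / height  # normalized row coordinate in (0, 1]
        row = []
        for c in range(width):
            x = (c + 1) / width  # normalized column coordinate in (0, 1]
            exponent = (x - cx) ** 2 / (beta * w**2) + (y - cy) ** 2 / (beta * h**2)
            row.append(math.exp(-exponent))
        rows.append(row)
    return rows


def enclosing_box(a, b):
    """Smallest (cx, cy, w, h) box containing both input boxes."""
    x0 = min(a[0] - a[2] / 2, b[0] - b[2] / 2)
    y0 = min(a[1] - a[3] / 2, b[1] - b[3] / 2)
    x1 = max(a[0] + a[2] / 2, b[0] + b[2] / 2)
    y1 = max(a[1] + a[3] / 2, b[1] + b[3] / 2)
    return ((x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0)


def edge_map(first, second, height, width, beta=1.0):
    """Element-wise maximum of the two endpoint maps and the enclosing-box map."""
    maps = [
        gaussian_map(box, height, width, beta)
        for box in (first, second, enclosing_box(first, second))
    ]
    return [
        [max(m[r][c] for m in maps) for c in range(width)] for r in range(height)
    ]


if __name__ == "__main__":
    subject = (0.25, 0.25, 0.5, 0.5)
    obj = (0.75, 0.75, 0.5, 0.5)
    for row in edge_map(subject, obj, 4, 4):
        print(" ".join(f"{value:.2f}" for value in row))
```

Taking the maximum means that an edge keeps a high weight near both endpoints and in the region between them, instead of being concentrated around a single center.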

In [None]:
import torch
from jaxtyping import Float
from torch import Tensor, nn

from deepsight.data.structs import BoundingBoxes
from deepsight.modeling.layers import LayerScale
from deepsight.utils.torch import Batched3DTensors, BatchedGraphs

from projects.sgg.modeling import DecoderConfig


class Decoder(nn.Module):
    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()

        self._num_heads = config.num_heads

        self.gaussian_heatmaps = GaussianHeatmaps()
        self.node_embeddings = SinusoidalBoxEmbeddings(
            config.hidden_dim, include_wh=True
        )
        self.edge_embeddings = SinusoidalPairwiseBoxEmbeddings(config.hidden_dim)

        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.num_layers)]
        )

    def forward(
        self,
        features: Batched3DTensors,
        graphs: BatchedGraphs,
        boxes: BoundingBoxes,
    ) -> list[BatchedGraphs]:
        H, W = features.shape[-2:]

        # Visual queries matched with the graph nodes, aligned with the edge connections
        edge_indices = graphs.edge_indices  # (2, E)
        first_boxes = boxes[edge_indices[0]]
        second_boxes = boxes[edge_indices[1]]
        union_boxes = first_boxes | second_boxes  # (E, 4)

        # Masks to account for padding in batch of SceneGraphs (multiple sentences)
        node_mask_list = []
        edge_mask_list = []
        for idx, (num_nodes, num_edges) in enumerate(graphs.sizes):
            node_mask_list.append(features.mask[idx, None].expand(num_nodes, -1, -1))
            edge_mask_list.append(features.mask[idx, None].expand(num_edges, -1, -1))

        nodes_mask = torch.cat(node_mask_list, dim=0)  # (N, H, W)
        edges_mask = torch.cat(edge_mask_list, dim=0)  # (E, H, W)

        # Build heatmaps for visual queries for cross attention
        node_heatmaps = self.gaussian_heatmaps(boxes, nodes_mask)  # (N, H, W)
        union_heatmaps = self.gaussian_heatmaps(union_boxes, edges_mask)  # (E, H, W)
        first_heatmaps = node_heatmaps[edge_indices[0]]  # (E, H, W)
        second_heatmaps = node_heatmaps[edge_indices[1]]  # (E, H, W)
        edge_heatmaps = torch.maximum(
            torch.maximum(first_heatmaps, second_heatmaps),
            union_heatmaps,
        )  # (E, H, W)

        node_heatmaps = node_heatmaps.flatten(1)  # (N, H * W)
        edge_heatmaps = edge_heatmaps.flatten(1)  # (E, H * W)
        heatmaps_graph = graphs.new_like(node_heatmaps, edge_heatmaps)

        heatmaps = torch.cat(
            [
                heatmaps_graph.nodes(pad_value=0.0),
                heatmaps_graph.edges(pad_value=0.0),
            ],
            dim=1,
        )  # (B, N + E, H * W)

        flattened_features = features.to_batched2d()  # (B, H * W, C)
        mask = flattened_features.mask[:, None].expand_as(heatmaps)
        attn_mask = heatmaps.masked_fill_(mask, -torch.inf)
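        # nn.MultiheadAttention lays out the first dimension as (sample, head), so the
        # mask of each sample must be repeated contiguously for all of its heads.
        attn_mask = attn_mask.repeat_interleave(
            self._num_heads, dim=0
        )  # (B * heads, N + E, H * W)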

        # Add bbox sinusoidal embeddings to nodes and edges
        node_embeddings = self.node_embeddings(boxes)  # (N, D)
        edge_embeddings = self.edge_embeddings(first_boxes, second_boxes)  # (E, D)
        embeddings_graph = graphs.new_like(node_embeddings, edge_embeddings)

        # graphs: scenegraphs with text embeddings
        # embeddings_graph: visual queries with added sinus bbox embeddings
        # flattened_features: the visual patch embeddings masked with gaussian heatmaps
        layer: DecoderLayer
        outputs = []
        for layer in self.layers:
            graphs = layer(
                graphs,
                embeddings_graph,
                flattened_features.tensor,
                attn_mask,
            )
            outputs.append(graphs)

        return outputs

    def __call__(
        self,
        features: Batched3DTensors,
        graphs: BatchedGraphs,
        boxes: BoundingBoxes,
    ) -> list[BatchedGraphs]:
        return super().__call__(features, graphs, boxes)  # type: ignore


class DecoderLayer(nn.Module):
    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()

        self.pre_cross_attn_layernorm = nn.LayerNorm(config.hidden_dim)
        self.cross_attn = nn.MultiheadAttention(
            config.hidden_dim,
            config.num_heads,
            dropout=config.dropout,
            batch_first=True,
        )
        self.post_cross_attn_layerscale = LayerScale(config.hidden_dim)

        self.pre_gat_layernorm = nn.LayerNorm(config.hidden_dim)
        self.gat = GATConv(
            config.hidden_dim,
            config.num_heads,
            bias=True,
            dropout=config.dropout,
        )
        self.post_gat_layerscale = LayerScale(config.hidden_dim)

        self.pre_ffn_layernorm = nn.LayerNorm(config.hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(config.hidden_dim, 4 * config.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(4 * config.hidden_dim, config.hidden_dim),
        )
        self.post_ffn_layerscale = LayerScale(config.hidden_dim)

    def _perform_cross_attention(
        self,
        graphs: BatchedGraphs,
        features: Float[Tensor, "B HW D"],
        attn_mask: Float[Tensor, "Bh (N+E) HW"],
    ) -> BatchedGraphs:
        nodes = graphs.nodes(pad_value=0.0)  # (B, N, D)
        edges = graphs.edges(pad_value=0.0)  # (B, E, D)
        # padded number of nodes and edges
        N, E = nodes.shape[1], edges.shape[1]

        queries = torch.cat([nodes, edges], dim=1)  # (B, N + E, D)
        queries, _ = self.cross_attn(
            queries,
            features,
            features,
            attn_mask=attn_mask,
            need_weights=False,
        )

        nodes, edges = torch.split(queries, [N, E], dim=1)
        return graphs.new_like(nodes, edges)

    def forward(
        self,
        graphs: BatchedGraphs,
        embeddings: BatchedGraphs,
        features: Float[Tensor, "B HW D"],
        attn_mask: Float[Tensor, "Bh (N+E) HW"],
    ) -> BatchedGraphs:
        # Perform cross-attention.
        nodes, edges = graphs.nodes(None), graphs.edges(None)
        N, E = nodes.shape[0], edges.shape[0]
        pre_cross_attn_queries = torch.cat([nodes, edges], dim=0)  # (N + E, D)
        queries = self.pre_cross_attn_layernorm(pre_cross_attn_queries)
        nodes, edges = torch.split(queries, [N, E], dim=0)
        graphs = graphs.new_like(nodes, edges)
        graphs = self._perform_cross_attention(graphs, features, attn_mask)
        nodes, edges = graphs.nodes(None), graphs.edges(None)
        queries = torch.cat([nodes, edges], dim=0)  # (N + E, D)
        post_cross_attn_queries = (
            pre_cross_attn_queries + self.post_cross_attn_layerscale(queries)
        )

        # Perform GAT
        # Keep the un-normalized queries for the residual connection (pre-norm).
        pre_gat_queries = post_cross_attn_queries
        queries = self.pre_gat_layernorm(pre_gat_queries)
        nodes, edges = torch.split(queries, [N, E], dim=0)
        graphs = graphs.new_like(nodes, edges)
        graphs = self.gat(graphs, embeddings)
        nodes, edges = graphs.nodes(None), graphs.edges(None)
        queries = torch.cat([nodes, edges], dim=0)  # (N + E, D)
        post_gat_queries = pre_gat_queries + self.post_gat_layerscale(queries)

        # Perform FFN
        pre_ffn_queries = post_gat_queries
        # Normalize the output of the GAT step (not its input) before the FFN.
        queries = self.pre_ffn_layernorm(pre_ffn_queries)
        queries = self.ffn(queries)
        post_ffn_queries = pre_ffn_queries + self.post_ffn_layerscale(queries)

        # Update graphs
        nodes, edges = torch.split(post_ffn_queries, [N, E], dim=0)
        return graphs.new_like(nodes, edges)

    def __call__(
        self,
        graphs: BatchedGraphs,
        embeddings: BatchedGraphs,
        features: Float[Tensor, "B HW D"],
        attn_mask: Float[Tensor, "Bh (N+E) HW"],
    ) -> BatchedGraphs:
        return super().__call__(graphs, embeddings, features, attn_mask)  # type: ignore

In [None]:
import torch
from jaxtyping import Integer
from torch import Tensor, nn

from deepsight.data.structs import (
    Batch,
    BoundingBoxes,
    BoundingBoxFormat,
    ODInput,
    ODOutput,
    SceneGraph,
)
from deepsight.modeling.detectors import OwlViT
from deepsight.modeling.layers import clip
from deepsight.modeling.pipeline import Model as _Model
from deepsight.utils.torch import BatchedGraphs, Graph

from projects.sgg.modeling import Config
from projects.sgg.modeling._structs import ModelInput, ModelOutput, TextEmbeddings


class Model(_Model[ModelInput, ModelOutput]):
    def __init__(self, config: Config) -> None:
        super().__init__()

        self.vision_encoder = VisionEncoder(
            config.encoders.model,
            config.encoders.output_dim,
        )

        self.text_encoder = TextEncoder(
            config.encoders.model,
            config.encoders.output_dim,
        )

        self.detector = OwlViT(
            config.detector.box_threshold, config.detector.num_queries
        )

        self.same_entity_edge = nn.Parameter(torch.randn(1, config.encoders.output_dim))
        self.decoder = Decoder(config.decoder)

        # Nodes are projected before computing their similarity with the caption
        # embedding. The caption embedding is itself obtained by projecting the
        # pooled output of the CLIP text transformer; since both projections are
        # learned, we can assume that the similarity is computed in the same space.
        self.projection = nn.Linear(
            config.decoder.hidden_dim, config.decoder.hidden_dim
        )

        self.regression_head = nn.Sequential(
            nn.Linear(config.decoder.hidden_dim, config.decoder.hidden_dim),
            nn.GELU(),
            nn.Dropout(config.decoder.dropout),
            nn.Linear(config.decoder.hidden_dim, 4),
        )

    def _get_text_embeddings(self, inputs: ModelInput) -> list[TextEmbeddings]:
        texts = []
        for caption, graph in zip(inputs.captions, inputs.graphs):
            # here we do not add the article 'a' after 'a photo of' since
            # most entities already have the article in their phrase
            texts.append(caption)
            texts.extend(f"a photo of {e.phrase}" for e in graph.entities())
            texts.extend(
                f"a photo of {r.subject.phrase} {r.relation} {r.object.phrase}"
                for r in graph.triplets(None, False, False)
            )

        tmp = self.text_encoder(texts)

        embeddings = []
        count = 0

        for graph in inputs.graphs:
            caption_emb = tmp[count]
            count += 1

            entities_emb = tmp[count : count + len(graph.entities())]
            count += len(graph.entities())

            num_relations = len(graph.triplets(None, True, True))
            relations_emb = tmp[count : count + num_relations]
            count += num_relations

            embeddings.append(
                TextEmbeddings(
                    entities=entities_emb,
                    relations=relations_emb,
                    caption=caption_emb,
                )
            )

        return embeddings

    def _get_detections(self, inputs: ModelInput) -> list[ODOutput]:
        batch = Batch(
            [
                ODInput(image, [e.noun for e in graph.entities()])
                for image, graph in zip(inputs.images, inputs.graphs)
            ]
        )

        return list(self.detector(batch))

    def _get_graph(
        self,
        graph: SceneGraph,
        embeddings: TextEmbeddings,
        detections: ODOutput,
    ) -> Graph:
        device = embeddings.entities.device
        num_relations = embeddings.relations.shape[0]

        edge_index_list: list[Integer[Tensor, "2 N"]] = []
        rel_index_list: list[int] = []

        for det_idx, detection in enumerate(detections.entities):
            entity_idx = int(detection)

            # add relations between entities based on the scene graph
            for rel in graph.triplets(entity_idx, True, True):
                end = (detections.entities == rel.object).nonzero(as_tuple=True)[0]
                end = end[None]  # (1, K)

                start = torch.tensor([det_idx], device=device).expand_as(end)
                indexes = torch.cat([start, end], dim=0)  # (2, K)

                edge_index_list.append(indexes)
                rel_index_list.extend([rel.relation] * indexes.shape[1])

            # add relations between instances of the same entity
            end = (detections.entities == entity_idx).nonzero(as_tuple=True)[0]
            end = end[None]  # (1, K)

            start = torch.tensor([det_idx], device=device).expand_as(end)
            indexes = torch.cat([start, end], dim=0)  # (2, K)

            edge_index_list.append(indexes)
            rel_index_list.extend([num_relations] * indexes.shape[1])

        edge_indices = torch.cat(edge_index_list, dim=1)  # (2, E)

        relations_emb = torch.cat([embeddings.relations, self.same_entity_edge])
        rel_indices = torch.tensor(rel_index_list, device=device)
        relations = relations_emb[rel_indices]  # (E, D)

        nodes = embeddings.entities[detections.entities]  # (N, D)

        return Graph(
            nodes=nodes,
            edges=relations,
            edge_indices=edge_indices,
        )

    def forward(self, inputs: ModelInput) -> ModelOutput:
        features = self.vision_encoder(inputs.features)
        embeddings = self._get_text_embeddings(inputs)
        detections = self._get_detections(inputs)

        tmp = [
            self._get_graph(graph, embedding, detection)
            for graph, embedding, detection in zip(
                inputs.graphs, embeddings, detections
            )
        ]

        # Build decoder inputs
        graph = BatchedGraphs.from_list(tmp)
        boxes = BoundingBoxes.cat([detection.boxes for detection in detections])
        graphs = self.decoder(features, graph, boxes)

        # Compute new boxes
        base_boxes = BoundingBoxes.pad_sequence(
            [detection.boxes for detection in detections]
        )  # (B, N, 4)
        base_boxes = base_boxes.to_cxcywh().normalize()

        new_boxes = []
        for idx in range(len(graphs)):
            graph = graphs[idx]

            nodes = graph.nodes(pad_value=0)  # (B, N, D)
            offsets = self.regression_head(nodes)  # (B, N, 4)
            box_tensor = torch.logit(base_boxes.tensor) + offsets
            box_tensor = torch.sigmoid(box_tensor)
            box = BoundingBoxes(
                box_tensor,
                base_boxes.images_size,
                format=BoundingBoxFormat.CXCYWH,
                normalized=True,
            )
            new_boxes.append(box)

            nodes = graph.nodes(None)  # (N, D)
            nodes = self.projection(nodes)  # (N, D)
            graphs[idx] = graph.new_like(nodes=nodes, clone=False)

        max_detections = max(len(detection.entities) for detection in detections)
        padded_entities = torch.nn.utils.rnn.pad_sequence(
            [detection.entities for detection in detections],
            batch_first=True,
            padding_value=max_detections,
        )

        return ModelOutput(
            captions=torch.stack([embedding.caption for embedding in embeddings]),
            graphs=graphs,
            boxes=new_boxes,
            padded_entities=padded_entities,
        )

##### Criterion

The criterion is responsible for the loss calculation. To calculate the loss, we first need to choose which of the candidate bounding boxes to associate with the ground truth bounding box. As previously said, we only consider the bounding boxes obtained from the graph nodes that correspond to the entity that is the subject of the region description, since this is the target entity to detect. Notice that here we assume that the first entity of the scene graph is the subject of the region description. We deemed it unnecessary to apply an NLP tool to extract the subject of the sentence, since we never observed a case in the dataset where the subject was not the first entity.

To decide which of the candidate bounding boxes is matched to the target bounding box, we adopt the same approach used by DETR-like models. In particular, for each candidate bounding box, we compute the cost of matching it with the ground truth. The cost is computed as the weighted sum of three different factors:
1. the L1 distance between the coordinates of the candidate bounding box and the ground truth bounding box;
2. the Generalized IoU loss between the candidate bounding box and the ground truth bounding box;
3. the negative cosine similarity between the node embedding of the candidate bounding box and the text embedding of the region description.

Notice that since there is only one ground truth bounding box, we do not need to apply the full Hungarian matching algorithm: we can simply select the candidate that minimizes the cost.
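The selection step described above can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions: boxes are given in normalized `(x0, y0, x1, y1)` format with positive area, each candidate is a `(box, embedding)` pair, and the three cost weights default to 1, which is a placeholder and not the value used in our configuration. The names `giou_loss`, `cosine_similarity` and `select_candidate` are hypothetical.

```python
import math


def giou_loss(a, b):
    """1 - GIoU for two (x0, y0, x1, y1) boxes with positive area."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    # Area of the smallest box enclosing both boxes.
    hull = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    giou = inter / union - (hull - union) / hull
    return 1.0 - giou


def cosine_similarity(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def select_candidate(
    candidates, target_box, caption, l1_cost=1.0, giou_cost=1.0, similarity_cost=1.0
):
    """Return the index of the candidate with the lowest matching cost."""
    costs = []
    for box, embedding in candidates:
        l1 = sum(abs(x - y) for x, y in zip(box, target_box))
        cost = (
            l1_cost * l1
            + giou_cost * giou_loss(box, target_box)
            - similarity_cost * cosine_similarity(embedding, caption)
        )
        costs.append(cost)
    # With a single ground truth box, matching reduces to an argmin.
    return min(range(len(costs)), key=costs.__getitem__)
```

With a single target, the assignment problem solved by the Hungarian algorithm has only one row, so taking the minimum over the candidates gives the same result at a fraction of the cost.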

Once the matching is done, we can compute the loss between the matched bounding boxes. The loss is computed as the weighted sum of three different losses:
1. the L1 loss between the coordinates of the matched bounding boxes;
2. the Generalized IoU ([Rezatofighi et al. 2019](https://arxiv.org/abs/1902.09630)) loss between the matched bounding boxes;
3. the InfoNCE loss ([Oord et al. 2018](https://arxiv.org/abs/1807.03748)) where the positive sample is the embedding of the matched node and the negative samples are the embeddings of the other nodes.

The first two losses are the canonical losses used by DETR-like models to make the network learn to predict the bounding box coordinates correctly. Here, we remove the classification loss (usually computed using focal loss) since we do not need to predict the class of the selected node; its class is already known from the input scene graph. However, since at inference time we select, among the candidate nodes, the one with the highest similarity with the text embedding of the region description, we use a contrastive loss (here, InfoNCE) to force the network to pull together the embedding of the matched node and the text embedding of the region description and to push away the embeddings of the other nodes. Since InfoNCE has been shown to perform better when the number of negative samples is high, we use as negative samples not only the other subject nodes but all the nodes of all the graphs in the same batch.
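For reference, the contrastive term for a single region description can be written as a softmax cross-entropy over cosine similarities. The following plain Python sketch shows one common formulation of InfoNCE; it is not necessarily identical to the `InfoNCELoss` class of the framework, and the temperature of `0.1` is an assumed placeholder rather than the value in our configuration.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def info_nce(query, positive, negatives, temperature=0.1):
    """InfoNCE loss for one query, one positive key and a list of negative keys."""
    # The positive logit comes first, followed by one logit per negative.
    logits = [cosine_similarity(query, positive) / temperature]
    logits += [cosine_similarity(query, key) / temperature for key in negatives]

    # Numerically stable log-sum-exp over all logits.
    largest = max(logits)
    log_norm = largest + math.log(sum(math.exp(x - largest) for x in logits))

    # Negative log-probability assigned to the positive key.
    return log_norm - logits[0]


if __name__ == "__main__":
    caption = [0.0, 1.0]
    matched_node = [0.1, 0.9]
    other_nodes = [[1.0, 0.0], [0.7, 0.3], [-0.5, 0.2]]
    print(info_nce(caption, matched_node, other_nodes))
```

Every additional negative adds a term to the normalizer, which is why drawing negatives from all the graphs in the batch gives a stronger training signal than using only the other subject nodes.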

Finally, similarly to other DETR-like models, the loss is computed not only with respect to the output of the last decoder layer but also with respect to the output of the intermediate decoder layers. In particular, for each layer we recompute the matching and the corresponding loss.

In [None]:
import torch

from deepsight.data.structs import Batch, BoundingBoxes, RECOutput
from deepsight.measures import Loss, Reduction
from deepsight.measures.losses import BoxL1Loss, GeneralizedBoxIoULoss, InfoNCELoss
from deepsight.modeling.pipeline import Criterion as _Criterion

from projects.sgg.modeling._config import CriterionConfig
from projects.sgg.modeling._structs import ModelOutput


class Criterion(_Criterion[ModelOutput, RECOutput]):
    def __init__(self, config: CriterionConfig) -> None:
        super().__init__()

        self.auxiliary = config.auxiliary
        self.num_layers = config.num_layers

        self.l1_cost = config.l1_cost
        self.giou_cost = config.giou_cost
        self.similarity_cost = config.similarity_cost

        self.l1_weight = config.l1_weight
        self.giou_weight = config.giou_weight
        self.infonce_weight = config.infonce_weight

        self.l1_loss = BoxL1Loss(reduction=Reduction.NONE)
        self.giou_loss = GeneralizedBoxIoULoss(reduction=Reduction.NONE)
        self.infonce_loss = InfoNCELoss(
            temperature=config.temperature, reduction=Reduction.MEAN
        )

    def losses_names(self) -> list[str]:
        losses = []
        if self.auxiliary:
            losses += [f"L1_{i}" for i in range(self.num_layers)]
            losses += [f"GIoU_{i}" for i in range(self.num_layers)]
            losses += [f"InfoNCE_{i}" for i in range(self.num_layers)]
        else:
            losses += ["L1", "GIoU", "InfoNCE"]

        return losses

    def _compute_layer_loss(
        self,
        output: ModelOutput,
        tgt_boxes: BoundingBoxes,
        layer_idx: int,
    ) -> list[Loss]:
        """Computes the loss for the output of a single layer.

        Parameters
        ----------
        output : ModelOutput
            The output of the model.
        tgt_boxes : BoundingBoxes
            The target boxes. The tensor has shape (B, N, 4).
        layer_idx : int
            The index of the layer.

        Returns
        -------
        list[Loss]
            A list of the computed losses.
        """

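        B, N = output.padded_entities.shape
        # True for every node that is NOT an instance of the subject entity
        # (entity index 0), padding included; these nodes are excluded from matching.
        subject_mask = output.padded_entities != 0
        # Padded positions are filled with N (the maximum number of detections).
        padding_mask = output.padded_entities == N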

        out_boxes = output.boxes[layer_idx].to_cxcywh().normalize()  # (B, N, 4)

        l1_loss = self.l1_loss(out_boxes, tgt_boxes)  # (B, N)
        giou_loss = self.giou_loss(out_boxes, tgt_boxes)  # (B, N)

        nodes = output.graphs[layer_idx].nodes(pad_value=0.0)  # (B, N, D)
        captions = output.captions.unsqueeze(1).expand(-1, N, -1)  # (B, N, D)
        similarity = torch.cosine_similarity(nodes, captions, dim=-1)  # (B, N)

        cost = (
            self.l1_cost * l1_loss
            + self.giou_cost * giou_loss
            - self.similarity_cost * similarity
        )

        cost = cost.masked_fill_(subject_mask, torch.inf)  # (B, N)
        idx = cost.min(dim=1)[1]  # (B,)

        pos_mask = torch.zeros_like(output.padded_entities, dtype=torch.bool)  # (B, N)
        pos_mask[torch.arange(B), idx] = True

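        # `nodes` (B, N, D) was already computed above and is reused here
        queries = output.captions  # (B, D)
        pos_keys = nodes[pos_mask]  # one positive node per sample
        # negatives: every non-padded node except the selected positive one
        neg_mask = torch.logical_xor(pos_mask, ~padding_mask)
        neg_keys = nodes[neg_mask]
        # the loss is created with Reduction.MEAN, so it is averaged over the batch
        infonce_loss = self.infonce_loss(queries, pos_keys, neg_keys)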

        l1_loss = l1_loss[pos_mask].mean()
        giou_loss = giou_loss[pos_mask].mean()

        if layer_idx == -1:
            return [
                Loss("L1", l1_loss, self.l1_weight),
                Loss("GIoU", giou_loss, self.giou_weight),
                Loss("InfoNCE", infonce_loss, self.infonce_weight),
            ]
        else:
            return [
                Loss(f"L1_{layer_idx}", l1_loss, self.l1_weight),
                Loss(f"GIoU_{layer_idx}", giou_loss, self.giou_weight),
                Loss(f"InfoNCE_{layer_idx}", infonce_loss, self.infonce_weight),
            ]

    def forward(self, output: ModelOutput, targets: Batch[RECOutput]) -> list[Loss]:
        B, N = output.padded_entities.shape

        tgt_boxes = BoundingBoxes.stack([tgt.box for tgt in targets], dim=0)  # (B, 4)
        tgt_boxes = tgt_boxes.to_cxcywh().normalize()  # (B, 4)
        tgt_boxes = tgt_boxes.unsqueeze(1).expand(-1, N, -1)  # (B, N, 4)

        if self.auxiliary:
            losses = []
            for i in range(self.num_layers):
                losses += self._compute_layer_loss(output, tgt_boxes, i)
        else:
            losses = self._compute_layer_loss(output, tgt_boxes, -1)

        return losses
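
The implementation of `InfoNCELoss` is not shown in this notebook. As a rough illustration of the idea, the following framework-free sketch computes an InfoNCE-style loss for a single caption, one positive node and a few negative nodes. The function names, the toy embeddings and the temperature of 0.1 are assumptions made for this example and are not taken from `deepsight`.

```python
import math


def cosine_similarity(u: list[float], v: list[float]) -> float:
    """Cosine similarity between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(query, pos_key, neg_keys, temperature=0.1) -> float:
    """Negative log-softmax of the positive logit among all logits."""
    logits = [cosine_similarity(query, pos_key) / temperature]
    logits += [cosine_similarity(query, key) / temperature for key in neg_keys]
    # log-sum-exp with the maximum subtracted for numerical stability
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[0]


# toy 2-dimensional embeddings (hypothetical values)
caption = [1.0, 0.0]
subject_node = [0.9, 0.1]
other_nodes = [[0.0, 1.0], [-1.0, 0.2]]

# the loss is small when the positive node is aligned with the caption ...
loss_correct = info_nce(caption, subject_node, other_nodes)
# ... and large when a node unrelated to the caption is used as the positive
loss_wrong = info_nce(caption, other_nodes[0], [subject_node, other_nodes[1]])
print(f"aligned positive: {loss_correct:.4f}, misaligned positive: {loss_wrong:.4f}")
```

Minimizing this quantity pulls the caption embedding towards the selected subject node and pushes it away from the other nodes, which is the role the InfoNCE term plays next to the L1 and GIoU box losses.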


##### PostProcessor

As previously described, the candidate bounding box is obtained from the graph output by the last decoder layer. We first select the nodes that refer to the subject of the region description, then compute the cosine similarity between their embeddings and the text embedding of the description. The bounding box associated with the node with the highest similarity is returned as the candidate bounding box.

In [None]:
import torch

from deepsight.data.structs import Batch, RECOutput
from deepsight.modeling.pipeline import PostProcessor as _PostProcessor

from projects.sgg.modeling._structs import ModelOutput


class PostProcessor(_PostProcessor[ModelOutput, RECOutput]):
    def forward(self, output: ModelOutput) -> Batch[RECOutput]:
        B, N = output.padded_entities.shape
        # True for nodes that do NOT refer to the subject (entity index 0); they are excluded below
        subject_mask = output.padded_entities != 0  # (B, N)

        queries = output.captions.unsqueeze(1)  # (B, 1, D)
        keys = output.graphs[-1].nodes(pad_value=0.0)  # (B, N, D)

        similarity = torch.cosine_similarity(queries, keys, dim=-1)  # (B, N)
        similarity.masked_fill_(subject_mask, -torch.inf)  # (B, N)

        idx = similarity.max(dim=1)[1]  # (B,)

        boxes = output.boxes[-1][torch.arange(B), idx]  # (B, 4)

        return Batch([RECOutput(box=boxes[i]) for i in range(B)])


## Experiments

### Training

#### Implementation details

Most of the hyperparameters used to train the network follow those of the detection papers our work builds on (DINO, Grounding DINO and OwlViT). The comments in the configuration below motivate each choice.

In [None]:
from pathlib import Path

from deepsight.modeling.layers.clip import Models
from projects.sgg.modeling import (
    Config,
    CriterionConfig,
    DecoderConfig,
    DetectorConfig,
    EncodersConfig,
    PreprocessorConfig,
)

pipeline_config = Config(
    preprocessor=PreprocessorConfig(
        file=Path("./data/refcocog/annotations/scene_graphs.json"),
        token="",
        # the same resolution is used by Grounding DINO
        side=800,
        max_side=1333,
        # mean and std of the CLIP model
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    ),
    encoders=EncodersConfig(
        output_dim=256,
        # we choose the smallest ViT model to reduce training times and memory usage
        model=Models.ViT_B_32_224,
    ),
    detector=DetectorConfig(
        # we found that a threshold of 0.25 is the smallest value that allows the model to detect most of the entities
        # without returning bounding boxes for instances that are clearly not present in the image
        box_threshold=0.25,
        # the number of bounding boxes to return for an entity that was not found with the box_threshold
        # we set this number to 4 so that the returned bounding boxes can uniformly cover the whole image
        num_queries=4,
    ),
    decoder=DecoderConfig(
        # 256 is the typical hidden dimension used by all object detection models inspired by DETR
        hidden_dim=256,
        # In DETR-like models, the decoder usually consists of 6 layers. Here, to reduce training times and memory usage,
        # we use only 3 layers. It has been shown that the performance of such models is not extremely sensitive
        # to the number of decoder layers. Notice that the number of decoder layers also limits the set of nodes
        # with which a graph node can exchange information (a known limitation of traditional MPNNs).
        # In this case, each node can exchange information, directly or indirectly, with all the nodes at a maximum distance of 3.
        # This is not a problem since there are no scene graphs with a depth greater than 3.
        num_layers=3,
        num_heads=8,
        # Since our decoder has less than 18 layers, the initial epsilon value is set to 0.1
        # (see https://paperswithcode.com/method/layerscale)
        epsilon_layer_scale=0.1,
    ),
    criterion=CriterionConfig(
        # we use the same cost and loss weights employed by DINO
        # since we replace the classification loss with the InfoNCE loss,
        # we reuse the weights of the classification loss for the InfoNCE loss
        # indeed, in our case the InfoNCE loss acts as a sort of classification loss:
        # it pulls together the embeddings of the query and of the correct subject node,
        # and the similarity of these embeddings is used to "classify" the right node against the wrong ones
        l1_cost=5.0,
        giou_cost=2.0,
        similarity_cost=2.0,
        l1_weight=5.0,
        giou_weight=2.0,
        infonce_weight=1.0,
        # whether to compute the loss also for intermediate layers
        # DETR-like models usually compute the loss for all layers
        auxiliary=True,
    ),
)

In [None]:
from pathlib import Path

from datasets.refcocog import Config as RefCOCOGConfig
from deepsight.engines.trainer import Config, Params
from deepsight.lr_schedulers.torch import Config as LRSchedulerConfig
from deepsight.optimizers.torch import Config as OptimizerConfig
from deepsight.optimizers.torch import ParamGroupConfig
from deepsight.utils import wandb
from deepsight.utils.torch import FloatType

train_config = Config(
    dir=Path("./output/sgg"),
    wandb=wandb.Config(
        job_type="train",
        enabled=False,
        project="sgg",
        entity="visgator",
    ),
    debug=False,
    params=Params(
        num_epochs=24,
        # the total batch size is set to 256; we use gradient accumulation
        # since 256 samples do not fit in the memory of a single H100 GPU
        train_batch_size=256,
        eval_batch_size=16,
        gradient_accumulation_steps=16,
        max_grad_norm=5.0,
        # we use a smaller init_scale than the default one (2**16) since the default one leads to overflows
        init_scale=2**12,
        seed=3407, # https://arxiv.org/abs/2109.08203
        # we use mixed precision training to reduce memory usage and training times
        dtype=FloatType.FLOAT16,
        dataset=RefCOCOGConfig(path=Path("./data/refcocog")),
        pipeline=pipeline_config,
        # We use AdamW as optimizer since it is recommended by FastAI and it is the same optimizer used by DINO and Grounding DINO.
        # CLIP is trained on an image-text association task, while we want the vision encoder to encode local information
        # in each patch, so we have decided to fine-tune the CLIP encoders. Since the same reasoning is at the base of OwlViT, we
        # use the same learning rate and weight decay values used by OwlViT to fine-tune the image-text contrastive pretrained text and vision encoders.
        # In particular, they show that it is fundamental to use a much smaller learning rate (100x smaller) for the text encoder
        # to reduce overfitting, possibly by preventing the text encoder from “forgetting” the semantics learned during pre-training
        # while fine-tuning on the small space of detection labels.
        # The weight decay value for the decoder is the same used by DINO. Whereas DINO uses a learning rate of 1e-4,
        # we increase it to 5e-4 since we use OneCycleLR and thus the learning rate is much smaller than the maximum value for most of the training.
        # Since the projections of the vision and text encoders must be trained from scratch, we use for them the same learning rate and weight decay
        # as the decoder.
        optimizer=OptimizerConfig(
            "AdamW",
            groups=[
                ParamGroupConfig(
                    regex=r"model.vision_encoder.projection.*",
                    args={"lr": 5e-4, "weight_decay": 1e-4},
                ),
                ParamGroupConfig(
                    regex=r"model.text_encoder.projection.*",
                    args={"lr": 5e-4, "weight_decay": 1e-4},
                ),
                ParamGroupConfig(
                    regex=r"model.vision_encoder.*",
                    args={"lr": 2e-4, "weight_decay": 0.0},
                ),
                ParamGroupConfig(
                    regex=r"model.text_encoder.*",
                    args={"lr": 2e-6, "weight_decay": 0.0},
                ),
                ParamGroupConfig(
                    regex=r"model.*", args={"lr": 5e-4, "weight_decay": 1e-4}
                ),
            ],
            # we do not fine-tune the detector since it is pretrained on much more data than RefCOCOg:
            # fine-tuning could make it lose its detection capabilities on the variety of scenes
            # not present in the training set of RefCOCOg
            freeze=["model.detector.*"],
        ),
        # as learning rate scheduler we use the OneCycleLR scheduler since it is the one recommended by FastAI
        # we also use the same default behaviour of the FastAI implementation
        lr_scheduler=LRSchedulerConfig(
            "OneCycleLR", args={"max_lr": [5e-4, 5e-4, 2e-4, 2e-6, 5e-4]}
        ),
    ),
)
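
The parameter groups above go from the most specific pattern to the catch-all `model.*`, and the `max_lr` list follows the same order. The sketch below shows why this order matters under the assumption that each parameter is assigned to the first group whose regex matches its name. The matching rule and the parameter names are hypothetical: the actual logic lives in `deepsight.optimizers.torch` and is not shown in this notebook.

```python
import re

# (pattern, optimizer arguments) pairs, in the same order as the configuration above
PARAM_GROUPS = [
    (r"model.vision_encoder.projection.*", {"lr": 5e-4, "weight_decay": 1e-4}),
    (r"model.text_encoder.projection.*", {"lr": 5e-4, "weight_decay": 1e-4}),
    (r"model.vision_encoder.*", {"lr": 2e-4, "weight_decay": 0.0}),
    (r"model.text_encoder.*", {"lr": 2e-6, "weight_decay": 0.0}),
    (r"model.*", {"lr": 5e-4, "weight_decay": 1e-4}),
]


def assign_group(name, groups):
    """Return the index of the first group whose pattern matches `name` (assumed rule)."""
    for index, (pattern, _args) in enumerate(groups):
        if re.match(pattern, name):
            return index
    return None


# hypothetical parameter names, chosen only for illustration
names = [
    "model.vision_encoder.projection.weight",
    "model.text_encoder.projection.bias",
    "model.vision_encoder.blocks.0.attn.weight",
    "model.text_encoder.embeddings.weight",
    "model.decoder.layers.0.weight",
]

assignment = {name: assign_group(name, PARAM_GROUPS) for name in names}
for name, index in assignment.items():
    print(f"{name} -> group {index}, lr={PARAM_GROUPS[index][1]['lr']}")

# with the catch-all pattern first, every parameter would end up in the same group
reversed_assignment = {name: assign_group(name, PARAM_GROUPS[::-1]) for name in names}
```

Under this first-match assumption, placing `model.*` before the encoder patterns would silently give the CLIP encoders the decoder learning rate, which is why the specific patterns come first.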

We do not report the code of the trainer here since it would only clutter the notebook and its implementation follows the standard PyTorch training logic. The code can be found at `deepsight.engines.trainer._trainer`.

In [None]:
%cd /content/visgator

/content/visgator


In [None]:
# To start the training, run:

from typing import Any

from deepsight.engines.trainer import Trainer

# note: the training will fail since the batch size is too high for Colab
# we do not modify the hyperparameters since they are the ones used in the training run
# if needed, it is sufficient to set a higher value for gradient_accumulation_steps
trainer: Trainer[Any, Any] = Trainer.new(train_config)
trainer.run()

[2023-07-11 20:47:21] INFO: Using device cuda:0.
[2023-07-11 20:47:54] INFO: Using RefCOCOg dataset.
[2023-07-11 20:47:54] INFO: 	(train) size: 80506 | (eval) size: 4896
[2023-07-11 20:47:54] INFO: 	(train) batch size: 256 | (eval) batch size: 16

OutOfMemoryError: ignored
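
To make the trade-off behind `gradient_accumulation_steps` concrete, the following sketch computes the number of samples processed per forward/backward pass and the number of optimizer steps per epoch. It assumes that `train_batch_size` is the total (effective) batch size and that it is split evenly across the accumulation steps; the helper `accumulation_plan` is hypothetical and is not part of `deepsight`.

```python
import math


def accumulation_plan(total_batch_size: int, accumulation_steps: int, dataset_size: int):
    """Return (micro-batch size, optimizer steps per epoch)."""
    if total_batch_size % accumulation_steps != 0:
        raise ValueError("total_batch_size must be divisible by accumulation_steps")
    # samples that must fit in GPU memory at once
    micro_batch_size = total_batch_size // accumulation_steps
    # one optimizer step per total batch (the last, smaller batch is kept)
    steps_per_epoch = math.ceil(dataset_size / total_batch_size)
    return micro_batch_size, steps_per_epoch


# values from the configuration and from the training log above
print(accumulation_plan(256, 16, 80506))  # (16, 315)
# a smaller GPU can keep the same total batch size by accumulating over more steps
print(accumulation_plan(256, 64, 80506))  # (4, 315)
```

Increasing the number of accumulation steps lowers the peak memory usage without changing the total batch size seen by the optimizer, at the price of more forward/backward passes per update.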

#### Results

We trained **Scene Graph Grounder** (SGG) on a single NVIDIA H100. Due to computational and cost constraints, SGG was trained only once, for 15 of the 24 configured epochs, which took 35 hours (the full 24 epochs would have taken an estimated 56 hours). Hereafter, we report some plots that summarize the training session (for all the logged statistics, see the run on [Weights & Biases](https://wandb.ai/visgator/sgg/runs/z3ruios2)).

<img src="https://github.com/FrancescoGentile/visgator/blob/deepsight/docs/img/total_loss.png?raw=true" width="500px">

<img src="https://github.com/FrancescoGentile/visgator/blob/deepsight/docs/img/giou.png?raw=true" width="500px">

<img src="https://github.com/FrancescoGentile/visgator/blob/deepsight/docs/img/iou.png?raw=true" width="500px">

<img src="https://github.com/FrancescoGentile/visgator/blob/deepsight/docs/img/accuracy_50.png?raw=true" width="500px">

<img src="https://github.com/FrancescoGentile/visgator/blob/deepsight/docs/img/accuracy_75.png?raw=true" width="500px">

<img src="https://github.com/FrancescoGentile/visgator/blob/deepsight/docs/img/accuracy_90.png?raw=true" width="500px">


From the reported plots we can make two main observations:
1. The performance of the model improves slowly: both training and evaluation metrics improve by approximately 0.01 per epoch. One possible cause is a learning rate that is too small to let the model update its parameters fast enough. Another is that the model already starts from a good initialization point (thanks to the position bias given by the detector) and does not have enough capacity to improve significantly. Future work could therefore also focus on scaling the architecture: the current model has **154,706,948 (155M)** trainable parameters, but only 3.8M of them belong to the decoder (all the others come from CLIP), so the decoder may lack the capacity to refine the predictions made by OwlViT. Another option is to use a better object detector such as Grounding DINO. Finally, note that the limited capacity of the model may not be due to the decoder alone but also to the vision encoder: the CLIP vision encoder was trained on image-level tasks, so it was not trained to extract the region-level features needed for object detection.
2. The model clearly starts overfitting after epoch 8: while all training metrics keep improving, all evaluation metrics start to decrease. The same can be seen in the losses: all training losses keep decreasing, while most evaluation losses start to increase or plateau after epoch 8. In particular, the GIoU loss of every layer increases sharply after epoch 13, while the InfoNCE loss increases much more slowly. The only evaluation loss that keeps decreasing is the L1 loss. Unfortunately, we were not able to run more experiments to address this problem. Possible remedies include increasing the weight decay, adding data augmentations, increasing the dropout rate (in our experiment we used a relatively low dropout of 0.1) or adopting dropout strategies like DropHead ([Zhou et al. 2020](https://arxiv.org/abs/2004.13342)) or structured dropout ([Fan et al. 2019](https://arxiv.org/abs/1909.11556)), even though the number of layers should not be the cause of the overfitting since we only have 3 decoder layers.
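
Regarding the learning rate mentioned in the first observation, a simplified sketch of a one-cycle schedule with cosine annealing helps to inspect which learning rate is in effect at a given point of the training. The default values of `pct_start`, `div_factor` and `final_div_factor` below are the documented defaults of PyTorch's `OneCycleLR`; whether the training run used exactly these values is an assumption, and the function works on a continuous training progress rather than on discrete optimizer steps.

```python
import math


def one_cycle_lr(progress, max_lr, pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Approximate one-cycle learning rate at a training progress in [0, 1]."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor

    def cos_anneal(start, end, pct):
        # cosine interpolation from `start` (pct=0) to `end` (pct=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if progress <= pct_start:
        # warm-up phase: from initial_lr up to max_lr
        return cos_anneal(initial_lr, max_lr, progress / pct_start)
    # annealing phase: from max_lr down to min_lr
    return cos_anneal(max_lr, min_lr, (progress - pct_start) / (1.0 - pct_start))


max_lr = 5e-4  # peak learning rate of the decoder group
for epoch in (0, 4, 8, 15, 24):
    lr = one_cycle_lr(epoch / 24, max_lr)
    print(f"epoch {epoch:2d}/24 -> lr = {lr:.2e}")
```

Under these assumptions, a run that is stopped at epoch 15 of 24 never reaches the final low-learning-rate part of the schedule, which may be worth taking into account when interpreting the curves.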

As for data augmentations, the repository already implements several of them using `torchvision` and `albumentations` (see `deepsight.data.transformations`). We decided not to use them in order to first verify whether the model was able to learn in the simplest setting. We believe that data augmentations could considerably improve the performance of the model, since they are commonly used in object detection. However, because of how the proposed model works, many traditional augmentations cannot be applied as they are. For example, random cropping cannot be used since it may remove entities referred to in the sentence. Similarly, geometric operations like random rotations, affine transformations or perspective transformations can be applied only on a small scale, since they may alter the spatial position of an entity and thus invalidate the spatial attributes contained in the input sentence. Lastly, pixel-level transformations like hue or color jittering cannot be applied too strongly, since they may completely alter the colours in the image and thus invalidate the colour attributes contained in the input sentence (for example, if the sentence says "the red car" and the image is altered too much, the car may no longer be red).

### Comparison

To evaluate the effectiveness of the proposed model, we define two different baselines.

The first baseline is the one proposed in the assignment: we use YOLOv8 extra (YOLOv8x) to detect a set of bounding boxes in the input image. We then crop the image using the detected bounding boxes and use CLIP to compute the similarity between each cropped region and the text embedding of the region description. The bounding box with the highest similarity is returned as the candidate bounding box.

In [None]:
import torch
from torch import Tensor, nn
from transformers.models.clip import CLIPModel, CLIPProcessor

from deepsight.data.structs import Batch, BoundingBoxes, ODInput, RECInput, RECOutput
from deepsight.modeling.detectors import YOLO
from deepsight.modeling.pipeline import Model as _Model

from projects.yoloclip.modeling import Config


class Model(_Model[Batch[RECInput], Batch[RECOutput]]):
    def __init__(self, config: Config) -> None:
        super().__init__()

        self._dummy = nn.Parameter(torch.empty(0))

        self.yolo = YOLO(config.yolo, config.box_threshold)

        self.clip = CLIPModel.from_pretrained(config.clip.weights())
        self.processor = CLIPProcessor.from_pretrained(config.clip.weights())

    def forward(self, inputs: Batch[RECInput]) -> Batch[RECOutput]:
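        # detect candidate boxes with YOLO (no text prompt is passed to the detector)
        det_results = self.yolo(Batch([ODInput(inp.image, []) for inp in inputs]))

        outputs = []
        for sample_idx, result in enumerate(det_results):
            # nothing was detected: return the detector output unchanged
            if len(result.entities) == 0:
                outputs.append(RECOutput(result.boxes))
                continue

            # crop each detected region from the denormalized image
            image = inputs[sample_idx].image.denormalize().data
            cropped_regions: list[Tensor] = []
            boxes = result.boxes.to_xyxy().denormalize()
            for bbox in boxes.tensor:
                x1, y1, x2, y2 = bbox.int()
                cropped_regions.append(image[:, y1:y2, x1:x2])

            # preprocess the description and the crops for CLIP
            tmp = self.processor(
                text=inputs[sample_idx].description,
                images=cropped_regions,
                return_tensors="pt",
            ).to(self._dummy.device)

            # logits_per_image has shape (num_crops, 1): pick the crop most similar to the text
            output = self.clip(**tmp)
            idx = output.logits_per_image.argmax(0).item()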

            outputs.append(
                RECOutput(
                    BoundingBoxes(
                        boxes.tensor[idx],
                        images_size=boxes.images_size[idx],
                        format=boxes.format,
                        normalized=boxes.normalized,
                    )
                )
            )

        return Batch(outputs)

The second baseline simply feeds the image and the region description to OwlViT and selects the bounding box with the highest confidence score.

In [None]:
from deepsight.data.structs import Batch, ODInput, RECInput, RECOutput
from deepsight.modeling.detectors import OwlViT
from deepsight.modeling.pipeline import Model as _Model

from projects.owlvit.modeling import Config


class Model(_Model[Batch[RECInput], Batch[RECOutput]]):
    def __init__(self, config: Config) -> None:
        super().__init__()

        self.detector = OwlViT(config.box_threshold, 1)

    def forward(self, inputs: Batch[RECInput]) -> Batch[RECOutput]:
        tmp = Batch([ODInput(inp.image, [inp.description]) for inp in inputs])
        results = self.detector(tmp)

        outputs = []
        for result in results:
            idx = result.scores.argmax()
            box = result.boxes[idx]
            outputs.append(RECOutput(box))

        return Batch(outputs)


These are the configurations used to test our model and the two baselines.

In [None]:
# download the trained weights

!gdown 1NsFEpRxwO8lI22K4dkvH_n9RXcls0pGO
!mkdir weights
!mv weights.pt weights/

In [None]:
from pathlib import Path

from datasets.refcocog import Config as RefCOCOGConfig
from deepsight.engines.tester import Config
from deepsight.utils.wandb import Config as WandbConfig

sgg_test_config = Config(
    dataset=RefCOCOGConfig(Path("data/refcocog")),
    pipeline=pipeline_config,
    wandb=WandbConfig(
        enabled=False,
        job_type="test",
        project="tests",
        entity="visgator",
        save=False,
    ),
    weights=Path("weights/weights.pt"),
)


In [None]:
from pathlib import Path

from datasets.refcocog import Config as RefCOCOGConfig
from deepsight.engines.tester import Config
from deepsight.modeling.detectors import YOLOModel
from deepsight.utils.wandb import Config as WandbConfig
from projects.yoloclip.modeling import Config as PipelineConfig

yoloclip_test_config = Config(
    dataset=RefCOCOGConfig(Path("data/refcocog")),
    pipeline=PipelineConfig(yolo=YOLOModel.EXTRA),
    wandb=WandbConfig(
        enabled=False,
        job_type="test",
        project="tests",
        entity="visgator",
        save=False,
    ),
)


In [None]:
from pathlib import Path

from datasets.refcocog import Config as RefCOCOGConfig
from deepsight.engines.tester import Config
from deepsight.utils.wandb import Config as WandbConfig
from projects.owlvit.modeling import Config as PipelineConfig

owlvit_test_config = Config(
    dataset=RefCOCOGConfig(Path("data/refcocog")),
    pipeline=PipelineConfig(),
    wandb=WandbConfig(
        enabled=False,
        job_type="test",
        project="tests",
        entity="visgator",
        save=False,
    ),
)


In [None]:
# To start the testing, run:

from typing import Any

from deepsight.engines.tester import Tester

tester: Tester[Any, Any] = Tester.new(sgg_test_config) # substitute the passed config with the one of the model to test
tester.run()

To compare the three models we use the same metrics used in the training phase:
1. Intersection over Union (IoU);
2. Generalized IoU (GIoU);
3. Accuracy@50, the percentage of predictions with an IoU greater than 0.5;
4. Accuracy@75, the percentage of predictions with an IoU greater than 0.75;
5. Accuracy@90, the percentage of predictions with an IoU greater than 0.9.
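
As a reference for how these metrics relate to each other, the following framework-free sketch computes IoU and GIoU for two boxes in `(x1, y1, x2, y2)` format and Accuracy@k over a list of IoU values. It is an illustration written for this notebook, not the implementation in `deepsight.measures`, and it assumes boxes with strictly positive area.

```python
def iou_and_giou(box_a, box_b):
    """IoU and GIoU of two boxes in (x1, y1, x2, y2) format."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)

    # intersection (zero if the boxes do not overlap)
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = area_a + area_b - inter
    iou = inter / union

    # smallest axis-aligned box enclosing both boxes
    enclosing = (max(ax2, bx2) - min(ax1, bx1)) * (max(ay2, by2) - min(ay1, by1))
    giou = iou - (enclosing - union) / enclosing
    return iou, giou


def accuracy_at(ious, threshold):
    """Percentage of predictions whose IoU is strictly greater than `threshold`."""
    return 100.0 * sum(iou > threshold for iou in ious) / len(ious)


print(iou_and_giou((0, 0, 2, 2), (0, 0, 2, 2)))  # (1.0, 1.0)
print(accuracy_at([0.6, 0.4, 0.95, 0.8], 0.5))  # 75.0
```

Since the penalty term subtracted from the IoU is never negative, GIoU is always less than or equal to IoU. Unlike IoU, it stays informative for non-overlapping boxes, where it becomes negative.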

Hereafter, we report the results, with the best value for each metric underlined (these results can also be seen in [Weights & Biases](https://wandb.ai/visgator/tests)):

|Method|Accuracy@50|Accuracy@75|Accuracy@90|GIoU|IoU|
|---|---|---|---|---|---|
|SGG|<ins>65.18</ins>|<ins>54.36</ins>|29.33|<ins>0.5293</ins>|<ins>0.5998</ins>|
|YOLOCLip|56.05|52.78|<ins>45.87</ins>|0.4720|0.5661|
|OwlViT|48.55|38.87|21.20|0.3777|0.4718|

We also report the results of current or past state of the art models. In particular, for each model we report the datasets on which it was trained, whether it was finetuned on RefCOCOg and its Accuracy@50.

As we can see, our model performs worse than all these models except OwlViT. This should not come as a surprise, since such models are trained on far more data than ours and have many more parameters. Interestingly, GLIP and OwlViT, the only models not fine-tuned on RefCOCOg, perform much worse than the fine-tuned models from the literature. This may indicate that visual grounding requires task-specific supervision.

|Model|Pre-training data|Fine-tuned|Accuracy@50|
|---|---|---|---|
|RefTR ([Li et al. 2022](https://arxiv.org/abs/2106.03089))|VG|yes|80.01|
|mDETR-ENB3 ([Kamath et al. 2021](https://arxiv.org/abs/2104.12763))|GoldG, RefC|yes|83.31|
|DQ-DETR ([Liu et al. 2022](https://arxiv.org/abs/2211.15516))|GoldG, RefC|yes|83.44|
|mPLUG-2 ([Xu et al. 2023](https://arxiv.org/abs/2302.00402))|COCO, VG, CC3M, CC12M, SBU|yes|85.14|
|GLIP-T ([Li et al. 2021](https://arxiv.org/abs/2112.03857))|O365, GoldG, Cap4M|no|66.89|
|Grounding-DINO-L ([Liu et al. 2023](https://arxiv.org/abs/2303.05499))|O365, OI, GoldG, Cap4M, COCO, RefC|yes|87.02|
|OwlViT|O365, VG|no|48.55|
|__Scene Graph Grounder__ (Ours)|None|yes|65.18|

where O365 stands for Objects365, OI for OpenImages, CC for Conceptual Captions, RefC for all three RefCOCO datasets (RefCOCO, RefCOCO+ and RefCOCOg) and VG for Visual Genome.


## Future work

We report here three interesting ways in which the model could be improved in the future:
1. Use a better vision encoder. As we have seen, the CLIP vision encoder is not trained to extract the region-level features needed for object detection, so it would be interesting to replace it with a vision encoder trained to extract such features.
2. Remove the detector and make the decoder perform the detection task from scratch. First of all, this would reduce training and inference time since the model would be much smaller. Furthermore, this would require training the model on larger datasets (like LVIS, Objects365, etc.), so the model would be able to learn better features. Training on Visual Genome would be particularly interesting since it already provides the scene graph of each region. This would allow training the model on higher-quality data and supervising it not only with the bounding box of the subject but with the bounding boxes of all the entities referred to in the region description.
3. Train a new text encoder that not only encodes the text but also builds the scene graph of the input sentence (similarly to how [Attardi et al. 2022](https://github.com/Unipisa/diaparser) use a transformer to obtain the dependency graph of the input sentence). This would not only produce more accurate scene graphs but would also lead the model to generate text encodings that contain high-level information about the context described in the sentence.