# Recall@x for visual relationships

From Wikipedia (CC BY-SA 4.0):

<img 
     src="https://upload.wikimedia.org/wikipedia/commons/2/26/Precisionrecall.svg" 
     alt="Precision and Recall"
     width="25%"
/>

```python
%matplotlib inline

import torch
import numpy as np
import pandas as pd
import sklearn.metrics
import matplotlib.pyplot as plt

from itertools import product
from IPython.display import display, Markdown

np.set_printoptions(precision=2)
# Silence NumPy warnings for divisions by zero (e.g. F1 when precision = recall = 0); results become nan or inf
_ = np.seterr(invalid='ignore', divide='ignore')
```

## Dataset

A toy dataset in the style of ["Visual Relationship Detection with Language Priors" (Lu et al.)](https://arxiv.org/abs/1608.00187):

- Each image can have multiple annotated relationships.
- Predictions are made for every (subject, predicate, object) triplet,
  where the subject and object come from a detection system applied to the image,
  and the predicate comes from the predicate vocabulary.
- A metric is computed for every image, then aggregated across all images.
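
The candidate triplets of a single image can be enumerated with `itertools.product`, in the same order as the list comprehension used to build the toy DataFrame below. A minimal sketch; the object and predicate lists are copied from the first toy image, and a subject may be paired with itself, as in the DataFrame:

```python
from itertools import product

objects = ['dog', 'person']                       # detected objects, referred to by index
predicates = ['talk to', 'look at', 'next to']    # predicate vocabulary

# One candidate per (subject index, predicate, object index)
candidates = list(product(range(len(objects)), predicates, range(len(objects))))

# n_objects * n_predicates * n_objects = 2 * 3 * 2 candidates for this image
print(len(candidates))   # 12
print(candidates[:3])    # [(0, 'talk to', 0), (0, 'talk to', 1), (0, 'look at', 0)]
```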

```python
rg = np.random.default_rng(42)

objects = ['person', 'dog']
predicates = ['talk to', 'look at', 'next to']

images = [
    {
        # Few annotated relationships, 
        # all get a high score,
        # all others get low scores
        'objects': ['dog', 'person'],
        'gt_relationships': {
            (0, 'talk to', 1),
            (1, 'next to', 0),
            (1, 'look at', 0),
        }
    },
    {
        # Many annotated relationships, 
        # more than the number considered for top-x,
        # all get a high score
        'objects': ['dog', 'person', 'person'],
        'gt_relationships': {
            (1, 'talk to', 2),
            (2, 'look at', 0),
            (2, 'next to', 0),
            (0, 'next to', 2),
            (1, 'look at', 0),
            (1, 'next to', 0),
            (0, 'next to', 1),
        }
    },
    {
        # Few annotated relationships,
        # all scores are random
        'objects': ['dog', 'dog'],
        'gt_relationships': {
            (0, 'look at', 1),
            (1, 'next to', 0),
        }
    },
]

data = pd.DataFrame(
    [
        (
            img_idx, subj_idx, pred, obj_idx, 
            f'{subj}, {pred}, {obj}', 
            rg.integers(0, 100), 
            (subj_idx, pred, obj_idx) in img_d['gt_relationships']
        )
        for img_idx, img_d in enumerate(images)
        for subj_idx, subj in enumerate(img_d['objects'])
        for pred in predicates
        for obj_idx, obj in enumerate(img_d['objects'])
    ],
    columns=['img', 'subj_idx', 'predicate', 'obj_idx', 'relationship', 'score', 'annotated']
)

# Image 0: annotated triplets score 90-99, all others 0-19
data.loc[(data.img == 0) & (data.annotated), 'score'] = rg.integers(90, 100, size=np.count_nonzero((data.img == 0) & (data.annotated)))
data.loc[(data.img == 0) & (~data.annotated), 'score'] = rg.integers(0, 20, size=np.count_nonzero((data.img == 0) & (~data.annotated)))

# Image 1: annotated triplets score 80-99, the others keep their random score
data.loc[(data.img == 1) & (data.annotated), 'score'] = rg.integers(80, 100, size=np.count_nonzero((data.img == 1) & (data.annotated)))

data.set_index(['img', 'subj_idx', 'predicate', 'obj_idx'], inplace=True)

display(
    data.style
    .format({True: 'v', False: 'x'}.get, subset='annotated')
    .bar(color='green', subset='score', vmin=0, vmax=100)
    .highlight_max(axis=0, color='""; color: green; font-weight: bold;', subset='annotated')
    .highlight_min(axis=0, color='""; color: red;', subset='annotated')
)
```

| img | subj_idx | predicate | obj_idx | relationship | score | annotated |
|---|---|---|---|---|---|---|
| 0 | 0 | talk to | 0 | dog, talk to, dog | 9 | x |
| 0 | 0 | talk to | 1 | dog, talk to, person | 91 | v |
| 0 | 0 | look at | 0 | dog, look at, dog | 0 | x |
| 0 | 0 | look at | 1 | dog, look at, person | 10 | x |
| 0 | 0 | next to | 0 | dog, next to, dog | 3 | x |
| 0 | 0 | next to | 1 | dog, next to, person | 14 | x |
| 0 | 1 | talk to | 0 | person, talk to, dog | 13 | x |
| 0 | 1 | talk to | 1 | person, talk to, person | 18 | x |
| 0 | 1 | look at | 0 | person, look at, dog | 93 | v |
| 0 | 1 | look at | 1 | person, look at, person | 14 | x |


## Recall@x

### Definitions

["Visual Relationship Detection with Language Priors" (Lu et al.)](https://arxiv.org/abs/1608.00187) 
describes `recall@x` as:

> The evaluation metrics we report is recall @ 100 and recall @ 50.
Recall @ x computes the fraction of times the correct relationship is predicted
in the top x confident relationship predictions. Since we have 70 predicates and
an average of 18 objects per image, the total possible number of relationship
predictions is 100×70×100, which implies that the random guess will result in a
recall @ 100 of 0.00014.
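
The random-guess figure in the quote is a simple ratio: a random ranking puts each candidate in the top x with probability x / (number of candidates), and that is also its expected recall. A quick check of the arithmetic, taking the quoted 100 × 70 × 100 candidate count at face value:

```python
n_candidates = 100 * 70 * 100   # (subject, predicate, object) candidates per image, as quoted
x = 100                         # size of the retrieved list for recall@100

# Each ground-truth triplet lands in a random top x with probability x / n_candidates
expected_recall = x / n_candidates
print(n_candidates, f'{expected_recall:.5f}')   # 700000 0.00014
```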

[Weakly-supervised learning of visual relations (Peyre et al.)](https://arxiv.org/abs/1707.09472)
cites the previous paper and describes `recall@x` as:

>  We compute recall @ x which corresponds to the proportion of ground truth 
pairs among the x top scored candidate pairs in each image.

The book [Introduction to Information Retrieval, Cambridge University Press. 2008 (Manning, Raghavan, Schütze)](https://nlp.stanford.edu/IR-book/information-retrieval-book.html) defines:

**Precision** is the fraction of retrieved documents that are relevant:

                # ( relevant items retrieved )
    Precision = ------------------------------ = P ( relevant | retrieved )
                    # ( retrieved items )

**Recall** is the fraction of relevant documents that are retrieved:

                # ( relevant items retrieved )
      Recall  = ------------------------------ = P ( retrieved | relevant )
                     # ( relevant items )
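
The two definitions translate directly into set operations. A minimal sketch on hypothetical document sets (the function name `precision_recall` is illustrative):

```python
def precision_recall(retrieved, relevant):
    """Precision and recall for sets of retrieved and relevant items."""
    retrieved, relevant = set(retrieved), set(relevant)
    hits = len(retrieved & relevant)    # relevant items retrieved
    precision = hits / len(retrieved)   # P(relevant | retrieved)
    recall = hits / len(relevant)       # P(retrieved | relevant)
    return precision, recall

# 4 retrieved documents, 3 relevant ones, 2 in common
p, r = precision_recall(retrieved={'a', 'b', 'c', 'd'}, relevant={'b', 'd', 'e'})
print(p, r)   # 0.5 0.6666666666666666
```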

A simple implementation of `recall@50` could be:

```python
def recall_at(top_50_triplets, annotated_triplets):
    # Count the annotated (ground-truth) triplets that appear among the top 50 predictions
    count = 0
    for gt_triplet in annotated_triplets:
        if gt_triplet in top_50_triplets:
            count += 1
    # Normalize by the number of annotated triplets, not by the 50 retrieved ones
    return count / len(annotated_triplets)
```

The key element is that the percentage is computed by dividing by `len(annotated_triplets)`, i.e. the number of _annotated_ triplets:
- We assume a noisy annotation process, i.e. the annotators might annotate only a portion of the relationships that are truly present in an image.<br/>
  Therefore we can only check that the network is able to retrieve _at least_ the annotated relationships.<br/>
  Predicting a relationship that is not in the annotated set is not penalized.
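
The function above expects the top 50 to be already selected. A sketch of the full step, assuming predictions arrive as a dict mapping each triplet to a score (`recall_at_x` and the dict layout are illustrative), also shows that an unannotated prediction costs nothing unless it pushes an annotated one out of the top x:

```python
def recall_at_x(scored_triplets, annotated_triplets, x):
    """Hypothetical helper: recall@x from a dict {triplet: score}."""
    # Keep the x highest-scoring triplets (sorted is stable, so ties keep insertion order)
    top_x = set(sorted(scored_triplets, key=scored_triplets.get, reverse=True)[:x])
    hits = sum(1 for t in annotated_triplets if t in top_x)
    return hits / len(annotated_triplets)

annotated = {(0, 'talk to', 1), (1, 'look at', 0)}
scores = {
    (0, 'talk to', 1): 0.9,
    (1, 'look at', 0): 0.8,
    (1, 'next to', 0): 0.95,   # not annotated, but possibly a true relationship
    (0, 'look at', 0): 0.1,
}

print(recall_at_x(scores, annotated, x=3))  # 1.0: the extra high-scoring triplet is not penalized
print(recall_at_x(scores, annotated, x=2))  # 0.5: with a tighter budget it pushes out (1, 'look at', 0)
```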

How the `gt_triplet in top_50_triplets` operation is performed depends on the evaluation task:
- **Predicate detection**<br/>
  Subj/obj class/box are provided<br/>
  We only need to match boxes on their ids (e.g. sequence number of the box).
- **Phrase detection**<br/>
  Subj/obj class/box are not provided<br/>
  For every pair of predicted boxes we need to compute the enclosing box and match it with the ground-truth enclosing box;
  we also need to match the predicted subj/pred/obj labels with the ground-truth labels.
- **Relationship detection**<br/>
  Subj/obj class/box are not provided<br/>
  For every pair of predicted boxes we need to match the subj and obj boxes individually with the ground-truth subj and obj boxes;
  we also need to match the predicted subj/pred/obj labels with the ground-truth labels.
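
The box comparisons can be sketched with intersection over union (IoU). Assumptions: boxes are `(x1, y1, x2, y2)` tuples with positive area, a detection is a `(subj, pred, obj, subj_box, obj_box)` tuple, and a match requires IoU ≥ 0.5 (the value usually reported for this benchmark; treat it as an assumption here). All function names are illustrative:

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


def union_box(a, b):
    """Smallest box enclosing both a and b."""
    return (min(a[0], b[0]), min(a[1], b[1]), max(a[2], b[2]), max(a[3], b[3]))


IOU_THRESHOLD = 0.5  # assumed matching threshold


def relationship_match(pred, gt, thr=IOU_THRESHOLD):
    """Relationship detection: same labels, and subj and obj boxes each overlap their ground truth."""
    (ps, pp, po, ps_box, po_box), (gs, gp, go, gs_box, go_box) = pred, gt
    same_labels = (ps, pp, po) == (gs, gp, go)
    return same_labels and iou(ps_box, gs_box) >= thr and iou(po_box, go_box) >= thr


def phrase_match(pred, gt, thr=IOU_THRESHOLD):
    """Phrase detection: same labels, and the enclosing boxes overlap."""
    (ps, pp, po, ps_box, po_box), (gs, gp, go, gs_box, go_box) = pred, gt
    same_labels = (ps, pp, po) == (gs, gp, go)
    return same_labels and iou(union_box(ps_box, po_box), union_box(gs_box, go_box)) >= thr


# Hypothetical detections: the subject box is accurate, the object box is off target
gt = ('person', 'look at', 'dog', (0, 0, 10, 10), (20, 0, 30, 10))
pred = ('person', 'look at', 'dog', (1, 0, 11, 10), (40, 0, 50, 10))

print(relationship_match(pred, gt))  # False: object IoU is 0
print(phrase_match(pred, gt))        # True: enclosing-box IoU is 290 / 500 = 0.58
```

The same prediction can count as a hit for phrase detection and a miss for relationship detection. For predicate detection the boxes are given, so the comparison reduces to equality of `(subj_id, pred, obj_id)` tuples, as in `recall_at` above.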

However, we always retrieve exactly 50 triplets, regardless of whether the model assigns a high score to only ~5 of them or to more than 50:
- `len(annotated_triplets) <= 50`<br/>
  If the number of annotated triplets is smaller than or equal to the number of retrieved elements, <br/>
  there is no problem: the model can reach 100% recall by ranking all the annotated triplets within the top 50.
- `len(annotated_triplets) > 50`<br/>
  If the number of annotated triplets is larger than the number of retrieved elements, <br/>
  it is impossible to achieve 100% recall: at best, recall is `50 / len(annotated_triplets)`.
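
The best achievable recall@x therefore depends only on x and on the number of annotated triplets. A small sketch (`max_recall_at` is an illustrative name):

```python
def max_recall_at(x, n_annotated):
    """Best achievable recall@x for an image with n_annotated ground-truth triplets."""
    # At most x annotated triplets fit in the retrieved list
    return min(x, n_annotated) / n_annotated

for n_annotated in (5, 50, 80, 200):
    print(n_annotated, f'{max_recall_at(50, n_annotated):.1%}')
# 5 100.0%
# 50 100.0%
# 80 62.5%
# 200 25.0%
```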

```python
def pr(df, *, x=None):
    """Precision-recall summary for a single image.

    Args:
        df: a dataframe containing a binary `annotated` column and a `score` column.
        x: consider only the top-x samples according to their score

    Returns: a dataframe with a single row and various statistics as columns.
    """
    relevant_tot = df['annotated'].sum()
    if x is not None:
        df = df.nlargest(x, 'score')

    retrieved = len(df)
    retrieved_and_relevant = df['annotated'].sum()

    # How many retrieved items are relevant
    precision = retrieved_and_relevant / retrieved

    # How many relevant items are retrieved
    recall = retrieved_and_relevant / relevant_tot

    f1 = 2 * precision * recall / (precision + recall)

    # Return a df so each column has a different dtype
    return pd.DataFrame({
        'relevant_tot': [relevant_tot],
        'retrieved': [retrieved],
        'retrieved_and_relevant': [retrieved_and_relevant],
        'precision': [precision],
        'recall': [recall],
        'f1': [f1]
    })
```

### Example

Here we consider the top 5 predicted relationships per image:
- These 5 relationships are always considered _retrieved_, even if their score is low.
- Some images have fewer than 5 _relevant_ relationships, so 100% precision is impossible.
- Other images have more than 5 _relevant_ relationships, so 100% recall is impossible.

```python
(
    data
    .groupby('img')
    .apply(lambda img_df: img_df.droplevel(0, axis=0).nlargest(5, 'score'))
    .style
    .format(['x', 'v'].__getitem__, subset='annotated')
    .bar(color='green', subset='score', vmin=0, vmax=100)
    .highlight_max(axis=0, color='""; color: green; font-weight: bold;', subset='annotated')
    .highlight_min(axis=0, color='""; color: red;', subset='annotated')
)
```

| img | subj_idx | predicate | obj_idx | relationship | score | annotated |
|---|---|---|---|---|---|---|
| 0 | 1 | next to | 0 | person, next to, dog | 94 | v |
| 0 | 1 | look at | 0 | person, look at, dog | 93 | v |
| 0 | 0 | talk to | 1 | dog, talk to, person | 91 | v |
| 0 | 1 | talk to | 1 | person, talk to, person | 18 | x |
| 0 | 0 | next to | 1 | dog, next to, person | 14 | x |
| 1 | 0 | next to | 1 | dog, next to, person | 99 | v |
| 1 | 1 | look at | 0 | person, look at, dog | 98 | v |
| 1 | 2 | next to | 0 | person, next to, dog | 89 | v |
| 1 | 0 | next to | 2 | dog, next to, person | 88 | v |
| 1 | 2 | look at | 1 | person, look at, person | 88 | x |


**Image 0:** The number of annotated relationships (`relevant_tot = 3`) is lower than the number of selected relationships (`retrieved = 5`),<br />
therefore it is impossible to achieve 100% precision (even if the system ranks all 3 relevant documents within the top 5, it is also forced to retrieve 2 irrelevant ones).

**Image 1:** The number of annotated relationships (`relevant_tot = 7`) is higher than the number of selected relationships (`retrieved = 5`),<br />
therefore it is impossible to achieve 100% recall (at most 5 of the 7 relevant documents fit in the top 5).
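
Symmetrically, the best achievable precision@x depends only on x and on the number of relevant items. A small sketch (`max_precision_at` is an illustrative name) applied to the `relevant_tot` values of the three toy images with x = 5:

```python
def max_precision_at(x, n_relevant):
    """Best achievable precision@x: at most n_relevant of the x retrieved items can be hits."""
    return min(x, n_relevant) / x

for img, n_relevant in [(0, 3), (1, 7), (2, 2)]:
    print(img, f'{max_precision_at(5, n_relevant):.1%}')
# 0 60.0%
# 1 100.0%
# 2 40.0%
```

Image 0 reaches its bound (60.0% precision in the table below), while image 2 stays at 20.0% against a bound of 40.0%.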

```python
(
    data
    .groupby('img')
    .apply(pr, x=5)
    .droplevel(1, axis=0)
    .style
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| img | relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|---|
| 0 | 3 | 5 | 3 | 60.0% | 100.0% | 75.0% |
| 1 | 7 | 5 | 4 | 80.0% | 57.1% | 66.7% |
| 2 | 2 | 5 | 1 | 20.0% | 50.0% | 28.6% |


### Comment

In a real dataset, we can assume that the number of _annotated_ relationships is much lower than the number of _true_ relationships present in an image.<br/>
Therefore, it's very likely that among the top 50 relationships retrieved, many relationships are _truly relevant_, even if the corresponding annotation is missing/negative.

```python
rg = np.random.default_rng(42)

truly_relevant = rg.random(1000) < .14
annotated = np.where(rg.random(1000) < .25, truly_relevant, False)
score = rg.normal(truly_relevant * 100, scale=40).astype(int)
random_score = rg.integers(0, 100, 1000)

display(pd.DataFrame({
    'score': score,
    'random_score': random_score,
    'annotated': annotated,
    'truly relevant': truly_relevant,
}))

display(
    pd.DataFrame({
        'annotated': annotated,
        'truly relevant': truly_relevant,
    })
    .transpose()
    .agg(['sum', 'mean'], axis=1)
    .rename(columns={'sum': 'count', 'mean': 'percent'})
    .style
    .format('{:.2%}', subset='percent')
)
```

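| | score | random_score | annotated | truly relevant |
|---|---|---|---|---|
| 0 | -40 | 11 | False | False |
| 1 | 36 | 57 | False | False |
| 2 | -48 | 17 | False | False |
| 3 | -11 | 31 | False | False |
| 4 | 76 | 41 | False | True |
| ... | ... | ... | ... | ... |
| 995 | -46 | 52 | False | False |
| 996 | 30 | 60 | False | False |
| 997 | 88 | 86 | False | True |
| 998 | -100 | 84 | False | False |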


| | count | percent |
|---|---|---|
| annotated | 44 | 4.40% |
| truly relevant | 152 | 15.20% |


**Random classifier**

If we use the annotations as the indicator of relevance, we get:
- low precision (many documents are retrieved, but only 4.4% of all documents are annotated as relevant),
- low recall (a random top 50 covers only 50 of the 1000 documents, so it is expected to contain only ~5% of the annotated ones).
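
For a uniformly random ranking, the expected values have a closed form (derived here from uniform sampling, not taken from the source): each retrieved item is relevant with probability `n_relevant / n_items`, and each relevant item lands in the top x with probability `x / n_items`. Plugging in the counts from the table above:

```python
def expected_random_pr(n_items, n_relevant, x):
    """Expected precision@x and recall@x of a uniformly random ranking."""
    return n_relevant / n_items, x / n_items

for label, n_relevant in [('annotated', 44), ('truly relevant', 152)]:
    p, r = expected_random_pr(1000, n_relevant, 50)
    print(f'{label}: precision {p:.1%}, recall {r:.1%}')
# annotated: precision 4.4%, recall 5.0%
# truly relevant: precision 15.2%, recall 5.0%
```

The measured values in the cells below are single random draws around these expectations.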

```python
(
    pr(pd.DataFrame({'score': random_score, 'annotated': annotated}), x=50)
    .style
    .hide(axis='index')  # replaces Styler.hide_index(), removed in pandas 2.0
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|
| 44 | 50 | 3 | 6.0% | 6.8% | 6.4% |


If we use the _true relevancy_ instead, we get:
- higher precision (we are still retrieving at random, but 15.2% of the documents are now relevant instead of 4.4%, so more of the random top 50 count as hits),
- similar recall (for a random ranking the expected recall is 50 / 1000 = 5% no matter how many documents are relevant; the drop from 6.8% to 5.9% is sampling noise).
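
A quick simulation supports the claim that the recall of a random ranking does not depend on the number of relevant items. A sketch with sizes matching the toy data (1000 items, 44 or 152 relevant, top 50); `simulated_random_recall` is an illustrative name:

```python
import random

def simulated_random_recall(n_items, n_relevant, x, trials, seed=0):
    """Average recall@x of random rankings, estimated by simulation."""
    rng = random.Random(seed)
    relevant = set(range(n_relevant))          # which items are relevant does not matter
    total = 0.0
    for _ in range(trials):
        top_x = rng.sample(range(n_items), x)  # a random ranking, truncated to x items
        total += sum(item in relevant for item in top_x) / n_relevant
    return total / trials

for n_relevant in (44, 152):
    print(n_relevant, f'{simulated_random_recall(1000, n_relevant, 50, trials=2000):.2%}')
```

Both averages come out close to 5%.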

```python
(
    pr(pd.DataFrame({'score': random_score, 'annotated': truly_relevant}), x=50)
    .style
    .hide(axis='index')  # replaces Styler.hide_index(), removed in pandas 2.0
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|
| 152 | 50 | 9 | 18.0% | 5.9% | 8.9% |


**Good classifier**

The good classifier has learned to classify according to the true relevancy of the data, even if it was trained on a dataset with potentially missing annotations.

If we use the annotations as the indicator of relevance:

```python
(
    pr(pd.DataFrame({'score': score, 'annotated': annotated}), x=50)
    .style
    .hide(axis='index')  # replaces Styler.hide_index(), removed in pandas 2.0
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|
| 44 | 50 | 13 | 26.0% | 29.5% | 27.7% |


If we use the _true relevancy_, both precision and recall go up because the ground-truth labels are of higher quality.

```python
(
    pr(pd.DataFrame({'score': score, 'annotated': truly_relevant}), x=50)
    .style
    .hide(axis='index')  # replaces Styler.hide_index(), removed in pandas 2.0
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|
| 152 | 50 | 49 | 98.0% | 32.2% | 48.5% |


**Oracle (perfect classifier)**

Even an oracle, which ranks the truly relevant documents first, scores poorly against the incomplete annotations.

```python
(
    pr(pd.DataFrame({'score': truly_relevant, 'annotated': annotated}), x=50)
    .style
    .hide(axis='index')  # replaces Styler.hide_index(), removed in pandas 2.0
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|
| 44 | 50 | 9 | 18.0% | 20.5% | 19.1% |


If we use the _true relevancy_ instead, we get:
- perfect precision (we know exactly which documents are relevant, and all 50 retrieved documents are among them),
- low recall (only 50 documents can be retrieved, so at most 50 of the 152 relevant ones are found).
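
The oracle's scores against the true relevancy follow from the counts alone (152 truly relevant documents, x = 50), as a worked check of the formulas in `pr`:

```python
n_relevant, x = 152, 50
precision = x / x               # every retrieved document is relevant
recall = x / n_relevant         # capped by the size of the retrieved list
f1 = 2 * precision * recall / (precision + recall)
print(f'{precision:.1%} {recall:.1%} {f1:.1%}')   # 100.0% 32.9% 49.5%
```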

```python
(
    pr(pd.DataFrame({'score': truly_relevant, 'annotated': truly_relevant}), x=50)
    .style
    .hide(axis='index')  # replaces Styler.hide_index(), removed in pandas 2.0
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|
| 152 | 50 | 50 | 100.0% | 32.9% | 49.5% |


Therefore:
- Computing **precision**, i.e. "the fraction of retrieved documents that are relevant", does not work well:
  the annotations are incomplete, so we cannot know whether a retrieved document is truly irrelevant or was simply missed during annotation.
- Computing **recall**, i.e. "the fraction of relevant documents that are retrieved", makes more sense,
  since it only relies on the relationships that are annotated as relevant and is not penalized by un-annotated relationships.<br/>
  However, we must keep in mind that `recall@k` is capped at `k / relevant_tot` whenever an image has more than `k` annotated relationships.

## Our dataset

We only predict _predicate labels_ at the image level, regardless of how many relationships with that predicate appear in the image and between which pairs of objects.
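
Collapsing triplet annotations into image-level predicate labels can be sketched as follows; it mirrors the `any(...)` expression used to build the DataFrame below, and `image_level_labels` is an illustrative name:

```python
def image_level_labels(gt_relationships, predicates):
    """Binary label per predicate: 1 if any annotated triplet in the image uses it."""
    present = {pred for _, pred, _ in gt_relationships}
    return [int(pred in present) for pred in predicates]

predicates = ['talk to', 'look at', 'next to', 'touching']

# A subset of the triplets of toy image 1 below: many relationships, a single predicate
gt = {(0, 'next to', 1), (1, 'next to', 0), (0, 'next to', 2), (2, 'next to', 0), (1, 'next to', 2)}
print(image_level_labels(gt, predicates))   # [0, 0, 1, 0]
```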

```python
rg = np.random.default_rng(123)

objects = ['person', 'dog']
predicates = ['talk to', 'look at', 'next to', 'touching']

images = [
    {
        # Many annotated relationships, 
        # covering all predicates
        'objects': ['dog', 'person', 'dog'],
        'gt_relationships': {
            (0, 'talk to', 1),
            (1, 'look at', 0),
            (1, 'next to', 2),
            (2, 'next to', 1),
            (1, 'look at', 2),
            (2, 'look at', 1),
            (1, 'touching', 2),
        }
    },
    {
        # Many annotated relationships, 
        # covering only 1 predicate
        'objects': ['dog', 'person', 'person'],
        'gt_relationships': {
            (0, 'next to', 1),
            (1, 'next to', 0),
            (0, 'next to', 2),
            (2, 'next to', 0),
            (1, 'next to', 2),
            (1, 'next to', 1),
        }
    },
    {
        # Few annotated relationships, 
        # covering all predicates
        'objects': ['person', 'dog'],
        'gt_relationships': {
            (0, 'look at', 1),
            (0, 'next to', 1),
            (0, 'talk to', 1),
            (0, 'touching', 1),
        }
    },
    {
        # Few annotated relationships, 
        # covering only 1 predicate
        'objects': ['dog', 'dog'],
        'gt_relationships': {
            (0, 'next to', 1),
        }
    },
]

# One row per (image, predicate): the predicate is annotated if any triplet in the image uses it
data = pd.DataFrame(
    [
        (
            img_idx, pred, 
            rg.integers(0, 100), 
            any(pred == p for _, p, _ in img_d['gt_relationships'])
        )
        for img_idx, img_d in enumerate(images)
        for pred in predicates
    ],
    columns=['img', 'predicate', 'score', 'annotated']
)

# Predictions for image 3 are very wrong
data.loc[(data.img == 3) & (data.annotated), 'score'] = rg.integers(0, 20, size=np.count_nonzero((data.img == 3) & (data.annotated)))
data.loc[(data.img == 3) & (~data.annotated), 'score'] = rg.integers(90, 100, size=np.count_nonzero((data.img == 3) & (~data.annotated)))

data.set_index(['img', 'predicate'], inplace=True)

display(
    data.style
    .format({True: 'v', False: 'x'}.get, subset='annotated')
    .bar(color='green', subset='score', vmin=0, vmax=100)
    .highlight_max(axis=0, color='""; color: green; font-weight: bold;', subset='annotated')
    .highlight_min(axis=0, color='""; color: red;', subset='annotated')
)
```

| img | predicate | score | annotated |
|---|---|---|---|
| 0 | talk to | 1 | v |
| 0 | look at | 68 | v |
| 0 | next to | 59 | v |
| 0 | touching | 5 | v |
| 1 | talk to | 90 | x |
| 1 | look at | 22 | x |
| 1 | next to | 25 | v |
| 1 | touching | 18 | x |
| 2 | talk to | 33 | v |
| 2 | look at | 17 | v |


### Database-like implementation

Top 3 relationships per image.
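
The same top-k selection can also be written as an actual database query. A sketch using Python's built-in `sqlite3` module and a window function (an illustration of the idea, not part of the original pipeline; the table layout is an assumption, and window functions need SQLite 3.25 or later). The rows are the scores and image-level annotations of toy images 0 and 1:

```python
import sqlite3

rows = [
    (0, 'talk to', 1, 1), (0, 'look at', 68, 1), (0, 'next to', 59, 1), (0, 'touching', 5, 1),
    (1, 'talk to', 90, 0), (1, 'look at', 22, 0), (1, 'next to', 25, 1), (1, 'touching', 18, 0),
]

con = sqlite3.connect(':memory:')
con.execute('CREATE TABLE predictions (img INTEGER, predicate TEXT, score INTEGER, annotated INTEGER)')
con.executemany('INSERT INTO predictions VALUES (?, ?, ?, ?)', rows)

# Rank predicates within each image by score, then count annotated ones inside the top k
query = '''
WITH ranked AS (
    SELECT img, annotated,
           ROW_NUMBER() OVER (PARTITION BY img ORDER BY score DESC) AS rnk
    FROM predictions
)
SELECT img,
       SUM(CASE WHEN rnk <= :k THEN annotated ELSE 0 END) AS retrieved_and_relevant,
       SUM(annotated) AS relevant_tot
FROM ranked
GROUP BY img
ORDER BY img
'''
k = 3
for img, hits, relevant_tot in con.execute(query, {'k': k}):
    # Assumes every image has at least k candidate predicates
    print(img, f'precision {hits / k:.1%}', f'recall {hits / relevant_tot:.1%}')
# 0 precision 100.0% recall 75.0%
# 1 precision 33.3% recall 100.0%
```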

```python
(
    data
    .groupby('img')
    .apply(lambda img_df: img_df.droplevel(0, axis=0).nlargest(3, 'score'))
    .style
    .format(['x', 'v'].__getitem__, subset='annotated')
    .bar(color='green', subset='score', vmin=0, vmax=100)
    .highlight_max(axis=0, color='""; color: green; font-weight: bold;', subset='annotated')
    .highlight_min(axis=0, color='""; color: red;', subset='annotated')
)
```

| img | predicate | score | annotated |
|---|---|---|---|
| 0 | look at | 68 | v |
| 0 | next to | 59 | v |
| 0 | touching | 5 | v |
| 1 | talk to | 90 | x |
| 1 | next to | 25 | v |
| 1 | look at | 22 | x |
| 2 | touching | 81 | v |
| 2 | next to | 34 | v |
| 2 | talk to | 33 | v |
| 3 | talk to | 98 | x |


```python
(
    data
    .astype({'annotated': int})
    .groupby('img')
    .apply(pr, x=3)
    .droplevel(1, axis=0)
    .style
    .format('{:.1%}', subset=['precision', 'recall', 'f1'])
)
```

| img | relevant_tot | retrieved | retrieved_and_relevant | precision | recall | f1 |
|---|---|---|---|---|---|---|
| 0 | 4 | 3 | 3 | 100.0% | 75.0% | 85.7% |
| 1 | 1 | 3 | 1 | 33.3% | 100.0% | 50.0% |
| 2 | 4 | 3 | 3 | 100.0% | 75.0% | 85.7% |
| 3 | 1 | 3 | 0 | 0.0% | 0.0% | nan% |


### Vectorized implementation

It can be easier to work with two matrices of size `n_images x n_classes`: one holding the scores and one holding the binary labels.

```python
predictions = data.score.unstack('predicate')
targets = data.annotated.unstack('predicate').astype(int)

display(predictions.style.set_caption('Predictions').background_gradient(axis=1))
display(targets.style.set_caption('Annotations'))
```

| img | look at | next to | talk to | touching |
|---|---|---|---|---|
| 0 | 68 | 59 | 1 | 5 |
| 1 | 22 | 25 | 90 | 18 |
| 2 | 17 | 34 | 33 | 81 |
| 3 | 98 | 15 | 98 | 98 |


| img | look at | next to | talk to | touching |
|---|---|---|---|---|
| 0 | 1 | 1 | 1 | 1 |
| 1 | 0 | 1 | 0 | 0 |
| 2 | 1 | 1 | 1 | 1 |
| 3 | 0 | 1 | 0 | 0 |


```python
top_3 = np.argsort(predictions.values)[:, -3:]
pd.DataFrame(
    top_3,
    index = predictions.index,
    columns=pd.RangeIndex(3, 0, -1, name='position')
)
```

| img | position 3 | position 2 | position 1 |
|---|---|---|---|
| 0 | 3 | 1 | 0 |
| 1 | 0 | 1 | 2 |
| 2 | 2 | 1 | 3 |
| 3 | 0 | 2 | 3 |


```python
is_prediction_relevant = np.take_along_axis(targets.values, top_3, axis=1)
pd.DataFrame(
    is_prediction_relevant,
    index = predictions.index,
    columns=pd.RangeIndex(3, 0, -1, name='position')
)
```

| img | position 3 | position 2 | position 1 |
|---|---|---|---|
| 0 | 1 | 1 | 1 |
| 1 | 0 | 1 | 0 |
| 2 | 1 | 1 | 1 |
| 3 | 0 | 0 | 0 |


```python
precision = is_prediction_relevant.mean(axis=1)
recall = is_prediction_relevant.sum(axis=1) / targets.sum(axis=1)

prf = pd.DataFrame(
    {
        'precision': precision,
        'recall': recall,
        'f1': 2 * precision * recall / (precision + recall)
    }, index=predictions.index
)
prf.loc['avg', :] = np.mean(prf.values, axis=0)

prf.style.format('{:.1%}')
```

| img | precision | recall | f1 |
|---|---|---|---|
| 0 | 100.0% | 75.0% | 85.7% |
| 1 | 33.3% | 100.0% | 50.0% |
| 2 | 100.0% | 75.0% | 85.7% |
| 3 | 0.0% | 0.0% | nan% |
| avg | 58.3% | 62.5% | nan% |
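
A plain-Python reference implementation can be handy for cross-checking the vectorized version on small inputs. A sketch (`precision_recall_at_k` is an illustrative name); unlike the cells above, it reports F1 = 0 instead of nan when precision and recall are both 0, which is a design choice, not something the notebook does:

```python
def precision_recall_at_k(scores, targets, k):
    """Per-image (precision@k, recall@k, f1) from score and binary-label matrices given as lists of rows."""
    results = []
    for row_scores, row_targets in zip(scores, targets):
        # Indices of the k highest scores (sorted is stable, so ties keep column order)
        top_k = sorted(range(len(row_scores)), key=lambda j: row_scores[j], reverse=True)[:k]
        hits = sum(row_targets[j] for j in top_k)
        n_relevant = sum(row_targets)
        precision = hits / k
        recall = hits / n_relevant if n_relevant else float('nan')
        # F1 = 0 instead of 0/0 when precision + recall == 0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, f1))
    return results

# Scores and labels from the tables above (columns: look at, next to, talk to, touching)
predictions = [[68, 59, 1, 5], [22, 25, 90, 18], [17, 34, 33, 81], [98, 15, 98, 98]]
targets = [[1, 1, 1, 1], [0, 1, 0, 0], [1, 1, 1, 1], [0, 1, 0, 0]]

per_image = precision_recall_at_k(predictions, targets, k=3)
for img, (p, r, f1) in enumerate(per_image):
    print(img, f'{p:.1%} {r:.1%} {f1:.1%}')
mean_precision = sum(p for p, _, _ in per_image) / len(per_image)
mean_recall = sum(r for _, r, _ in per_image) / len(per_image)
print(f'avg {mean_precision:.1%} {mean_recall:.1%}')   # avg 58.3% 62.5%
```

The per-image rows and the averages match the table above; only image 3's F1 differs (0.0% instead of nan).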


### PyTorch example

```python
torch.manual_seed(0)

scores = torch.randint(0, 100, size=(4, 7))
print('scores', scores, sep='\n', end='\n\n')
print('labels', torch.arange(7), sep='\n', end='\n\n')

annotations = (torch.rand(4, 7) > .7)
print('annotations', annotations.long(), sep='\n', end='\n\n')

sorted_labels = torch.argsort(scores, dim=1, descending=True)
print('sorted labels', sorted_labels, sep='\n', end='\n\n')
print('sorted scores', torch.gather(scores, index=sorted_labels, dim=1), sep='\n', end='\n\n')

annotations_of_top_3 = torch.gather(annotations, index=sorted_labels[:, :3], dim=1)
print('annotations_of_top_3', annotations_of_top_3.long(), sep='\n', end='\n\n')

print('Precision@3 per image', annotations_of_top_3.float().mean(dim=1, keepdims=True), sep='\n', end='\n\n')
print('Recall@3 per image', annotations_of_top_3.sum(dim=1, keepdims=True).float() / annotations.sum(dim=1, keepdims=True), sep='\n', end='\n\n')

print(f'Precision@3: {annotations_of_top_3.float().mean().item():.2%}')
print(f'Recall@3   : {(annotations_of_top_3.sum(dim=1, keepdims=True).float() / annotations.sum(dim=1, keepdims=True)).mean(axis=0).item():.2%}')
```

```
scores
tensor([[44, 39, 33, 60, 63, 79, 27],
        [ 3, 97, 83,  1, 66, 56, 99],
        [78, 76, 56, 68, 94, 33, 26],
        [19, 91, 54, 24, 41, 69, 69]])

labels
tensor([0, 1, 2, 3, 4, 5, 6])

annotations
tensor([[0, 0, 0, 1, 0, 0, 0],
        [0, 0, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 1, 1, 0],
        [0, 1, 1, 0, 0, 1, 0]])

sorted labels
tensor([[5, 4, 3, 0, 1, 2, 6],
        [6, 1, 2, 4, 5, 0, 3],
        [4, 0, 1, 3, 2, 5, 6],
        [1, 5, 6, 2, 4, 3, 0]])

sorted scores
tensor([[79, 63, 60, 44, 39, 33, 27],
        [99, 97, 83, 66, 56,  3,  1],
        [94, 78, 76, 68, 56, 33, 26],
        [91, 69, 69, 54, 41, 24, 19]])

annotations_of_top_3
tensor([[0, 0, 1],
        [0, 0, 1],
        [1, 0, 0],
        [1, 1, 0]])

Precision@3 per image
tensor([[0.3333],
        [0.3333],
        [0.3333],
        [0.6667]])

Recall@3 per image
tensor([[1.0000],
        [0.3333],
        [0.5000],
        [0.6667]])

Precision@3: 41.67%
Recall@3   : 62.50%
```


If we want to compute precision at multiple sizes efficiently,
we can take a cumulative sum of the sorted annotations up to the maximum size needed:
the entry at position `s - 1` is the number of relevant items in the top `s`, and dividing it by `s` gives precision@s.

Recall is the same, but we divide by the number of relevant elements in that row instead of by `s`.
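
The same cumulative-sum idea without PyTorch, as a plain-Python sketch for checking results on small inputs. It assumes every image has at least one annotated label and at least `max(sizes)` candidates; `pr_at_sizes` is an illustrative name, and the inputs are the scores and annotations printed by the PyTorch cell above:

```python
from itertools import accumulate

def pr_at_sizes(scores, annotations, sizes):
    """Mean precision@s and recall@s over images for every s in sizes, using one cumulative sum per image."""
    max_size = max(sizes)
    precision = {s: 0.0 for s in sizes}
    recall = {s: 0.0 for s in sizes}
    for row_scores, row_ann in zip(scores, annotations):
        order = sorted(range(len(row_scores)), key=lambda j: row_scores[j], reverse=True)
        # cumsum[j] = number of relevant items within the top j+1 retrieved items
        cumsum = list(accumulate(row_ann[j] for j in order[:max_size]))
        n_rel = sum(row_ann)
        for s in sizes:
            precision[s] += cumsum[s - 1] / s
            recall[s] += cumsum[s - 1] / n_rel
    n = len(scores)
    return {s: v / n for s, v in precision.items()}, {s: v / n for s, v in recall.items()}

scores = [[44, 39, 33, 60, 63, 79, 27], [3, 97, 83, 1, 66, 56, 99],
          [78, 76, 56, 68, 94, 33, 26], [19, 91, 54, 24, 41, 69, 69]]
annotations = [[0, 0, 0, 1, 0, 0, 0], [0, 0, 1, 1, 1, 0, 0],
               [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0]]

precision, recall = pr_at_sizes(scores, annotations, sizes=[1, 3, 7])
print({s: f'{v:.2%}' for s, v in precision.items()})
print({s: f'{v:.2%}' for s, v in recall.items()})
```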

```python
sizes = [1, 2, 3, 4, 5, 6, 7]

sorted_labels = torch.argsort(scores, dim=1, descending=True)
annotations_of_top_S = torch.gather(annotations, index=sorted_labels[:, :max(sizes)], dim=1)

# cumsum[i, j] = number of relevant items within the top j+1 retrieved items
cumsum = annotations_of_top_S.cumsum(dim=1)

print('annotations_of_top_S', annotations_of_top_S.long(), sep='\n')
print('cumsum', cumsum, sep='\n', end='\n\n')
print('j+1    ', list(range(1, cumsum.shape[1] + 1)), end='\n\n')

# Cast to float to avoid int/int division.
cumsum = cumsum.float()

# Given a size s, cumsum[i, s-1] / s gives the precision for sample i.
# Then we take the batch mean.
print('Precision@:')

for s in sizes:
    print(f'{s}: {(cumsum[:, (s - 1)] / s).mean().item():.2%}')

# Divide each row by the total number of relevant documents for that row to get the recall per sample.
# Then take the batch mean.
print('\nRecall@:')
num_rel = annotations.sum(dim=1, keepdims=True)
recall = (cumsum / num_rel).mean(axis=0)
for s in sizes:
    print(f'{s}: {recall[s - 1].item():.2%}')
```

```
annotations_of_top_S
tensor([[0, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 1, 0, 0, 1],
        [1, 0, 0, 0, 0, 1, 0],
        [1, 1, 0, 1, 0, 0, 0]])
cumsum
tensor([[0, 0, 1, 1, 1, 1, 1],
        [0, 0, 1, 2, 2, 2, 3],
        [1, 1, 1, 1, 1, 2, 2],
        [1, 2, 2, 3, 3, 3, 3]])

j+1     [1, 2, 3, 4, 5, 6, 7]

Precision@:
1: 50.00%
2: 37.50%
3: 41.67%
4: 43.75%
5: 35.00%
6: 33.33%
7: 32.14%

Recall@:
1: 20.83%
2: 29.17%
3: 62.50%
4: 79.17%
5: 79.17%
6: 91.67%
7: 100.00%
```
