Commit 615f713
Merge 86166b1 into 7d850df
wangbo-zhao committed Jun 29, 2023
2 parents 7d850df + 86166b1 commit 615f713
Showing 4 changed files with 192 additions and 2 deletions.
3 changes: 2 additions & 1 deletion mmpretrain/datasets/__init__.py
@@ -46,6 +46,7 @@
from .refcoco import RefCOCO
from .scienceqa import ScienceQA
from .textvqa import TextVQA
from .visdial import VisDial
from .visual_genome import VisualGenomeQA
from .vizwiz import VizWiz
from .vsr import VSR

@@ -54,5 +55,5 @@
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA',
'VSR', 'VizWiz', 'OCRVQA'
'VSR', 'VizWiz', 'OCRVQA', 'VisDial'
])
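
With the new dataset exported and registered, a config can reference it by type name. A minimal, hypothetical dataloader fragment (all paths and the empty pipeline are placeholders, not part of this commit):

# Hypothetical config fragment using the newly registered dataset.
val_dataloader = dict(
    batch_size=8,
    dataset=dict(
        type='VisDial',                       # resolved via the DATASETS registry
        data_root='data/visdial',             # placeholder root
        data_prefix='VisualDialog_val2018',   # placeholder image folder
        ann_file='visdial_1.0_val.json',      # placeholder annotation file
        pipeline=[]))                         # model-specific transforms omitted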
96 changes: 96 additions & 0 deletions mmpretrain/datasets/visdial.py
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class VisDial(BaseDataset):
    """VisDial dataset.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        annotations = mmengine.load(self.ann_file)['data']

        dialogs = annotations['dialogs']
        answers = annotations['answers']
        questions = annotations['questions']

        data_list = []

        for dialog in dialogs:
            image_id = dialog['image_id']
            caption = dialog['caption']

            # The history for round i is the caption followed by the
            # question-answer pairs of rounds 0..i-1.
            historys = ['Caption:' + caption + '.']

            for i in range(1, len(dialog['dialog'])):
                historys.append('')

                previous_idx = i - 1
                question_id = dialog['dialog'][previous_idx]['question']
                answer_id = dialog['dialog'][previous_idx]['answer']

                history = ' Question:{question}? Answer:{answer}.' \
                    .format(question=questions[question_id],
                            answer=answers[answer_id])

                historys[i] = historys[previous_idx] + history

            # get question and answer options for each dialog round
            for dialog_id, dialog_round in enumerate(dialog['dialog']):
                question_id = dialog_round['question']
                answer_id = dialog_round['answer']
                answer_options = [
                    answers[answer_id]
                    for answer_id in dialog_round['answer_options']
                ]

                data_info = dict(image_id=image_id)

                img_prefix = self.data_prefix['img_path']
                file_backend = get_file_backend(img_prefix)

                # Image files follow the official VisDial naming scheme:
                # <folder name>_<12-digit zero-padded image id>.jpg
                data_info['img_path'] = file_backend.join_path(
                    img_prefix,
                    img_prefix.split('/')[-1] + '_' + str(image_id).zfill(12) +
                    '.jpg')

                data_info['dialog_history'] = historys[dialog_id]

                data_info['question'] = questions[question_id] + '?'
                data_info['answer'] = answers[answer_id]
                data_info['answer_options'] = answer_options
                data_info['gt_answer_index'] = data_info[
                    'answer_options'].index(data_info['answer'])

                data_list.append(data_info)

        return data_list
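
For reference, a runnable sketch of the annotation layout that load_data_list() consumes, inferred from the field reads above; the one-round toy dialog below is invented for illustration, not part of this commit:

import os
import tempfile

import mmengine

from mmpretrain.datasets import VisDial

# Toy annotation mirroring the fields read by load_data_list():
# data -> dialogs / answers / questions, with per-round indices.
ann = {
    'data': {
        'questions': ['is it sunny'],
        'answers': ['yes', 'no'],
        'dialogs': [{
            'image_id': 123,
            'caption': 'a dog on a beach',
            'dialog': [{
                'question': 0,             # index into `questions`
                'answer': 0,               # index into `answers`
                'answer_options': [0, 1],  # indices into `answers`
            }],
        }],
    }
}

ann_file = os.path.join(tempfile.gettempdir(), 'visdial_toy.json')
mmengine.dump(ann, ann_file)

dataset = VisDial(
    data_root='',
    data_prefix='VisualDialog_val2018',  # placeholder image folder
    ann_file=ann_file)

info = dataset.get_data_info(0)
print(info['img_path'])
# VisualDialog_val2018/VisualDialog_val2018_000000000123.jpg
print(info['dialog_history'])   # Caption:a dog on a beach.
print(info['question'])         # is it sunny?
print(info['gt_answer_index'])  # 0

Each round of each dialog becomes one sample, so an n-round dialog yields n entries whose dialog_history grows round by round.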
3 changes: 2 additions & 1 deletion mmpretrain/evaluation/metrics/__init__.py
@@ -7,6 +7,7 @@
from .retrieval import RetrievalAveragePrecision, RetrievalRecall
from .scienceqa import ScienceQAMetric
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .visual_dialog import SparseGTMetrics
from .visual_grounding_eval import VisualGroundingMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric
from .vqa import ReportVQA, VQAAcc
@@ -16,5 +17,5 @@
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc', 'NocapsSave',
'RetrievalAveragePrecision'
'RetrievalAveragePrecision', 'SparseGTMetrics'
]
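
This is the matching export on the metrics side; once registered, an evaluator config can reference the metric by name. A hypothetical fragment (not part of this commit):

# Hypothetical evaluator fragment; the type name is resolved via METRICS.
val_evaluator = dict(type='SparseGTMetrics')
test_evaluator = val_evaluator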
92 changes: 92 additions & 0 deletions mmpretrain/evaluation/metrics/visual_dialog.py
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from mmengine.evaluator import BaseMetric

from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
                                               _process_punctuation)
from mmpretrain.registry import METRICS


@METRICS.register_module()
class SparseGTMetrics(BaseMetric):
    """Visual Dialog accuracy metric.

    Computes ranking-based metrics for Visual Dialog: recall of the
    ground-truth answer at ranks 1, 5 and 10, its mean rank, and the mean
    reciprocal rank (MRR).

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
    """
    default_prefix = 'Visual Dialog'

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            answer_options = sample.get('answer_options')

            # Assign every option a deterministic random rank, then promote
            # the predicted answer to rank 1 if it matches an option.
            G = torch.Generator()
            G.manual_seed(0)
            rank = 1 + torch.randperm(len(answer_options), generator=G)

            pred_answer = sample.get('pred_answer')

            if pred_answer in answer_options:
                answer_index = answer_options.index(pred_answer)
                rank[answer_index] = 1

            gt_index = sample.get('gt_answer_index')
            gt_rank = rank[gt_index]

            self.results.append(gt_rank)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        R1 = (torch.tensor(results) <= 1).float().mean()
        R5 = (torch.tensor(results) <= 5).float().mean()
        R10 = (torch.tensor(results) <= 10).float().mean()
        Mean = torch.tensor(results).float().mean()
        MRR = torch.tensor(results).reciprocal().mean()

        metrics = {
            'R@1': R1.item(),
            'R@5': R5.item(),
            'R@10': R10.item(),
            'Mean': Mean.item(),
            'MRR': MRR.item()
        }
        return metrics

    def _process_answer(self, answer) -> str:
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer