-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
192 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import List | ||
|
||
import mmengine | ||
from mmengine.dataset import BaseDataset | ||
from mmengine.fileio import get_file_backend | ||
|
||
from mmpretrain.registry import DATASETS | ||
|
||
|
||
@DATASETS.register_module()
class VisDial(BaseDataset):
    """VisDial dataset.

    Each dialog round of every dialog becomes one data sample, carrying the
    image path, the accumulated dialog history up to that round, the current
    question, the ground-truth answer and the candidate answer options.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: One data info dict per dialog round of every dialog
            in the annotation file.
        """
        annotations = mmengine.load(self.ann_file)['data']

        dialogs = annotations['dialogs']
        answers = annotations['answers']
        questions = annotations['questions']

        # The image prefix and its file backend are invariant across all
        # dialogs, so resolve them once instead of inside the inner loop.
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)
        img_dir_name = img_prefix.split('/')[-1]

        data_list = []

        for dialog in dialogs:
            image_id = dialog['image_id']
            caption = dialog['caption']

            # historys[i] is the full conversation before round i: the image
            # caption followed by every earlier question/answer pair.
            historys = ['Caption:' + caption + '.']

            for i in range(1, len(dialog['dialog'])):
                previous_idx = i - 1
                question_id = dialog['dialog'][previous_idx]['question']
                answer_id = dialog['dialog'][previous_idx]['answer']

                history = ' Question:{question}? Answer:{answer}.'.format(
                    question=questions[question_id],
                    answer=answers[answer_id])

                historys.append(historys[previous_idx] + history)

            # Build one data sample for every dialog round.
            for dialog_id, dialog_round in enumerate(dialog['dialog']):
                question_id = dialog_round['question']
                answer_id = dialog_round['answer']
                answer_options = [
                    answers[option_id]
                    for option_id in dialog_round['answer_options']
                ]

                data_info = dict(image_id=image_id)

                # COCO-style image naming: <dir name>_<12-digit image id>.jpg
                data_info['img_path'] = file_backend.join_path(
                    img_prefix,
                    img_dir_name + '_' + str(image_id).zfill(12) + '.jpg')

                data_info['dialog_history'] = historys[dialog_id]
                data_info['question'] = questions[question_id] + '?'
                data_info['answer'] = answers[answer_id]
                data_info['answer_options'] = answer_options
                # Index of the ground-truth answer inside the options list.
                data_info['gt_answer_index'] = answer_options.index(
                    data_info['answer'])

                data_list.append(data_info)

        return data_list
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from typing import List, Optional | ||
|
||
import torch | ||
from mmengine.evaluator import BaseMetric | ||
|
||
from mmpretrain.evaluation.metrics.vqa import (_process_digit_article, | ||
_process_punctuation) | ||
from mmpretrain.registry import METRICS | ||
|
||
|
||
@METRICS.register_module()
class SparseGTMetrics(BaseMetric):
    """Visual Dialog accuracy metric.

    Ranks each ground-truth answer among its candidate answer options and
    reports recall@{1,5,10}, the mean rank and the mean reciprocal rank
    (MRR) over all processed samples.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """
    default_prefix = 'Visual Dialog'

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The ground-truth answer's rank for each sample is appended to
        ``self.results``, which is used to compute the metrics once all
        batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            answer_options = sample.get('answer_options')

            # Assign a deterministic (fixed seed 0) random rank to every
            # option as a baseline; the predicted answer, when it appears
            # among the options, is forced to rank 1.
            generator = torch.Generator()
            generator.manual_seed(0)
            rank = 1 + torch.randperm(
                len(answer_options), generator=generator)

            pred_answer = sample.get('pred_answer')
            if pred_answer in answer_options:
                rank[answer_options.index(pred_answer)] = 1

            gt_index = sample.get('gt_answer_index')
            self.results.append(rank[gt_index])

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The ground-truth ranks collected by ``process``.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        # Collect all ground-truth ranks into a single tensor once instead
        # of rebuilding it for every statistic.
        ranks = torch.tensor(results)

        return {
            'R@1': (ranks <= 1).float().mean().item(),
            'R@5': (ranks <= 5).float().mean().item(),
            'R@10': (ranks <= 10).float().mean().item(),
            'Mean': ranks.float().mean().item(),
            'MRR': ranks.reciprocal().mean().item(),
        }

    def _process_answer(self, answer) -> str:
        # Normalize an answer string (punctuation, digits, articles) before
        # comparison, using the shared VQA preprocessing helpers.
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer