From 86166b1c7800aa0c6b1b317c023ad03d8873800c Mon Sep 17 00:00:00 2001
From: wangbo-zhao
Date: Thu, 29 Jun 2023 14:57:10 +0800
Subject: [PATCH] pre-commit

---
 a.py                                    | 30 -------------------
 mmpretrain/datasets/visdial.py          |  2 +-
 .../evaluation/metrics/visual_dialog.py | 16 ++++++----
 3 files changed, 11 insertions(+), 37 deletions(-)
 delete mode 100644 a.py

diff --git a/a.py b/a.py
deleted file mode 100644
index 50cc68d64cf..00000000000
--- a/a.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from mmpretrain.datasets import VisDial
-
-test_pipeline = [
-    dict(type='mmpretrain.LoadImageFromFile'),
-    dict(
-        type='mmpretrain.ResizeEdge',
-        scale=224,
-        interpolation='bicubic',
-        backend='pillow'),
-    dict(type='mmpretrain.CenterCrop', crop_size=(224, 224)),
-    dict(
-        type='mmpretrain.PackInputs',
-        algorithm_keys=['question', 'gt_answer', 'sub_set'],
-        meta_keys=['image_id'],
-    ),
-]
-
-dataset = VisDial(
-    data_root='data/visualdialogue',
-    data_prefix='VisualDialog_val2018',
-    ann_file='visdial_1.0_val.json',
-    pipeline=test_pipeline)
-
-# dataset = ChartQA(
-#     data_root='data/chartqa/train',
-#     data_prefix='png',
-#     ann_file=['train_human.json', ],
-#     pipeline=test_pipeline)
-
-print('a')
diff --git a/mmpretrain/datasets/visdial.py b/mmpretrain/datasets/visdial.py
index c98fcc5c436..66f3379f8f8 100644
--- a/mmpretrain/datasets/visdial.py
+++ b/mmpretrain/datasets/visdial.py
@@ -85,7 +85,7 @@ def load_data_list(self) -> List[dict]:
 
             data_info['dialog_history'] = historys[dialog_id]
 
-            data_info['question'] = questions[question_id] + "?"
+            data_info['question'] = questions[question_id] + '?'
             data_info['answer'] = answers[answer_id]
             data_info['answer_options'] = answer_options
             data_info['gt_answer_index'] = data_info[
diff --git a/mmpretrain/evaluation/metrics/visual_dialog.py b/mmpretrain/evaluation/metrics/visual_dialog.py
index d877d95b24c..acfe0c5b2b7 100644
--- a/mmpretrain/evaluation/metrics/visual_dialog.py
+++ b/mmpretrain/evaluation/metrics/visual_dialog.py
@@ -44,18 +44,16 @@ def process(self, data_batch, data_samples) -> None:
         """
         for sample in data_samples:
             answer_options = sample.get('answer_options')
-
+
             G = torch.Generator()
             G.manual_seed(0)
             rank = 1 + torch.randperm(len(answer_options), generator=G)
-
+
             pred_answer = sample.get('pred_answer')
-
+
             if pred_answer in answer_options:
                 answer_index = answer_options.index(pred_answer)
                 rank[answer_index] = 1
-
-
             gt_index = sample.get('gt_answer_index')
             gt_rank = rank[gt_index]
@@ -79,7 +77,13 @@ def compute_metrics(self, results: List) -> dict:
 
         Mean = torch.tensor(results).float().mean()
         MRR = torch.tensor(results).reciprocal().mean()
-        metrics = {'R@1': R1.item(), 'R@5': R5.item(), 'R@10': R10.item(), 'Mean': Mean.item(), 'MRR': MRR.item()}
+        metrics = {
+            'R@1': R1.item(),
+            'R@5': R5.item(),
+            'R@10': R10.item(),
+            'Mean': Mean.item(),
+            'MRR': MRR.item()
+        }
         return metrics
 
     def _process_answer(self, answer) -> str:
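
Note (outside the patch): the visual_dialog.py hunks implement a retrieval-style evaluation. Each dialog turn gets a random permutation of candidate-answer ranks (seeded, so runs are reproducible), the predicted answer is promoted to rank 1 when it matches a candidate exactly, and the rank of the ground-truth answer is collected for aggregation. The sketch below restates that logic as standalone Python; the helper names are illustrative, and the R@k definitions (fraction of ground-truth ranks <= k, the usual VisDial convention) are an assumption, since the diff only shows how Mean and MRR are derived from the collected ranks.

    from typing import List

    import torch


    def rank_options(answer_options: List[str], pred_answer: str) -> torch.Tensor:
        # Mirrors the ranking step in process(): assign every candidate a
        # random rank with a fixed seed, then force the matching predicted
        # option (if any) to rank 1.
        g = torch.Generator()
        g.manual_seed(0)
        rank = 1 + torch.randperm(len(answer_options), generator=g)
        if pred_answer in answer_options:
            rank[answer_options.index(pred_answer)] = 1
        return rank


    def retrieval_metrics(gt_ranks: List[int]) -> dict:
        # Aggregates the collected ground-truth ranks. Mean and MRR follow
        # the patched compute_metrics(); the R@k lines are assumed.
        ranks = torch.tensor(gt_ranks).float()
        return {
            'R@1': (ranks <= 1).float().mean().item(),
            'R@5': (ranks <= 5).float().mean().item(),
            'R@10': (ranks <= 10).float().mean().item(),
            'Mean': ranks.mean().item(),
            'MRR': ranks.reciprocal().mean().item(),
        }


    # Hypothetical usage with a 4-option candidate list:
    options = ['yes', 'no', 'maybe', 'i think so']
    rank = rank_options(options, pred_answer='yes')
    gt_rank = int(rank[options.index('no')])  # rank of the ground-truth answer
    print(retrieval_metrics([gt_rank]))

One behavioral quirk worth noting: promoting the prediction to rank 1 can leave two options sharing rank 1, since another option may already hold it from the random permutation; the ground-truth rank is read off directly, which mirrors the patched code.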