Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
wangbo-zhao committed Jun 29, 2023
1 parent 956f65c commit 86166b1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 37 deletions.
30 changes: 0 additions & 30 deletions a.py

This file was deleted.

2 changes: 1 addition & 1 deletion mmpretrain/datasets/visdial.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def load_data_list(self) -> List[dict]:

data_info['dialog_history'] = historys[dialog_id]

data_info['question'] = questions[question_id] + "?"
data_info['question'] = questions[question_id] + '?'
data_info['answer'] = answers[answer_id]
data_info['answer_options'] = answer_options
data_info['gt_answer_index'] = data_info[
Expand Down
16 changes: 10 additions & 6 deletions mmpretrain/evaluation/metrics/visual_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,16 @@ def process(self, data_batch, data_samples) -> None:
"""
for sample in data_samples:
answer_options = sample.get('answer_options')

G = torch.Generator()
G.manual_seed(0)
rank = 1 + torch.randperm(len(answer_options), generator=G)

pred_answer = sample.get('pred_answer')

if pred_answer in answer_options:
answer_index = answer_options.index(pred_answer)
rank[answer_index] = 1



gt_index = sample.get('gt_answer_index')
gt_rank = rank[gt_index]
Expand All @@ -79,7 +77,13 @@ def compute_metrics(self, results: List) -> dict:
Mean = torch.tensor(results).float().mean()
MRR = torch.tensor(results).reciprocal().mean()

metrics = {'R@1': R1.item(), 'R@5': R5.item(), 'R@10': R10.item(), 'Mean': Mean.item(), 'MRR': MRR.item()}
metrics = {
'R@1': R1.item(),
'R@5': R5.item(),
'R@10': R10.item(),
'Mean': Mean.item(),
'MRR': MRR.item()
}
return metrics

def _process_answer(self, answer) -> str:
Expand Down

0 comments on commit 86166b1

Please sign in to comment.