
[Feat] Add an open-source dataset TabMWP #505

Merged
merged 8 commits on Nov 3, 2023
Changes from 6 commits
4 changes: 4 additions & 0 deletions configs/datasets/TabMWP/TabMWP_gen.py
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .TabMWP_gen_2aef96 import TabMWP_datasets  # noqa: F401, F403
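For context, a dataset config like this is usually pulled in by name when launching an evaluation run. A rough sketch of such a command is below; the model config name is only a placeholder and the exact flags can differ between OpenCompass versions:

    python run.py --models hf_llama_7b --datasets TabMWP_gen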
53 changes: 53 additions & 0 deletions configs/datasets/TabMWP/TabMWP_gen_2aef96.py
@@ -0,0 +1,53 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import TabMWPDataset, TabMWPEvaluator

# None of the TabMWP datasets on HuggingFace is parsed correctly, so we use our own dataset reader.
# Please download the dataset from https://github.com/lupantech/PromptPG/tree/main

# Compose the prompt from the element templates below; 'TQ' means the table
# comes first, followed by the question.
input_format = 'TQ'
output_format = 'A'
elements = {"Q": "Question: {question}",
            "T": "Table: {table}",
            "S": "Solution: {solution}",
            "A": "Answer: The answer is {answer}.",
            "AS": "Answer: The answer is {answer}. BECAUSE: {solution}",
            "SA": "Answer: {solution} The answer is {answer}."}


TabMWP_reader_cfg = dict(
    input_columns=["question", "table"],
    output_column="test_elements",
    train_split='dev',
)

TabMWP_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role="HUMAN",
                    prompt="\n".join(elements[label] for label in input_format)
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

TabMWP_eval_cfg = dict(
    evaluator=dict(type=TabMWPEvaluator)
)

TabMWP_datasets = [
    dict(
        type=TabMWPDataset,
        path="./data/tabmwp/",
        reader_cfg=TabMWP_reader_cfg,
        infer_cfg=TabMWP_infer_cfg,
        eval_cfg=TabMWP_eval_cfg,
    )
]
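As a minimal sketch of how the HUMAN prompt above gets assembled (using only names already defined in this config): with input_format = 'TQ', the template joins the 'T' and 'Q' element strings with newlines, so the table precedes the question.

    elements = {"Q": "Question: {question}", "T": "Table: {table}"}
    input_format = 'TQ'
    prompt = "\n".join(elements[label] for label in input_format)
    # prompt == "Table: {table}\nQuestion: {question}"
    # The {table} and {question} placeholders are later filled from the
    # reader's input_columns.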

1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
@@ -66,6 +66,7 @@
from .strategyqa import * # noqa: F401, F403
from .summedits import * # noqa: F401, F403
from .summscreen import * # noqa: F401, F403
from .tabmwp import * # noqa: F401, F403
from .TheoremQA import * # noqa: F401, F403
from .tnews import * # noqa: F401, F403
from .triviaqa import * # noqa: F401, F403
276 changes: 276 additions & 0 deletions opencompass/datasets/tabmwp.py
@@ -0,0 +1,276 @@
import json
import os.path as osp
import random
import re
from typing import List

import numpy as np
from datasets import Dataset, DatasetDict

from opencompass.openicl.icl_evaluator.icl_hf_evaluator import \
HuggingfaceEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET

from .base import BaseDataset


def get_table_text(problem):
    table = problem['table']
    title = problem['table_title']
    if title and len(title) > 0:
        table = f'[TITLE]: {title}\n{table}'
    return table


def get_question_text(problem, option_inds='ABCDEFGH'):
    question = problem['question']

    unit = problem['unit']
    if unit and len(unit) > 0:
        question = f'{question} (Unit: {unit})'

    choices = problem['choices']
    if choices and len(choices) > 0:
        choice_list = []
        for i, c in enumerate(choices):
            choice_list.append('({}) {}'.format(option_inds[i], c))
        options = ' '.join(choice_list)
        question = f'{question}\nOptions: {options}'

    return question


def get_answer(problem):
    return problem['answer']


def get_choices(problem):
    return problem['choices']


def get_unit(problem):
    return problem['unit']


def get_solution_text(problem):
    # Escape newlines as literal '\n' so the solution stays on one prompt line
    # (GPT-3 can generate the solution with more tokens).
    solution = problem['solution'].replace('\n', '\\n')
    return solution


def normalize_answer(text, unit):
    # ["1,000", "123", "3/4", "56.456", "$56.4", "-3", "-10.02", "-3/2"]

    text = re.sub(r'^[\$]', '', text)
    text = re.sub(r'[\,\.\,\/]$', '', text)

    result = re.match(r'^[-+]?[\d,./]+$', text)

    if result is not None:
        # is number?
        text = text.replace(',', '')
        result = re.match(r'[-+]?\d+$', text)

        if result is not None:
            number = int(text)
        elif '/' in text:
            nums = text.split('/')
            number = round(float(nums[0]) / float(nums[1]), 3)
        else:
            number = round(float(text), 3)
        number = str(number)
        number = re.sub(r'\.[0]+$', '', number)
        return number
    else:
        # is text
        if unit:
            text = text.replace(unit, '').strip()
        return text


def score_string_similarity(str1, str2):
    if str1 == str2:
        return 2.0
    if ' ' in str1 or ' ' in str2:
        str1_split = str1.split(' ')
        str2_split = str2.split(' ')
        overlap = list(set(str1_split) & set(str2_split))
        return len(overlap) / max(len(str1_split), len(str2_split))
    else:
        if str1 == str2:
            return 1.0
        else:
            return 0.0


def extract_prediction(output, options=None, option_inds='ABCDEFGH'):

    # $\\frac{16}{95}$ -> 16/95
    output = re.sub(r'\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?', r'\1/\2',
                    output)

    output = re.sub(r'(?<![AP]\.M)\.$', '', output)
    output = re.sub(r'(?<=\d)[\=](?=[\-\$\d])', ' = ', output)
    output = re.sub(r'\u2212', '-', output)

    # Multi-choice questions
    if options:
        patterns = [
            r'^\(([A-Za-z])\)$',  # "(b)", "(B)"
            r'^([A-Za-z])$',  # "b", "B"
            r'^([A-Za-z]). ',  # "b. xxx", "B. xxx"
            r'[Tt]he answer is ([A-Z])',  # "The answer is B"
            r'^\(([A-Za-z])\) [\s\S]+$',  # "(A) XXXXX"
            r'[Tt]he answer is \(([A-Za-z])\) [\s\S]+$'
        ]

        # look for an option letter "X" in the output
        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                pred = res[0].upper()  # e.g., "B"
                if pred in option_inds:
                    ind = option_inds.index(pred)  # 1
                    if ind >= len(options):
                        random.seed(123)
                        ind = random.choice(range(len(options)))
                    prediction = options[ind]
                    return prediction

        # find the most similar option
        scores = [score_string_similarity(x, output) for x in options]
        max_idx = int(
            np.argmax(scores))  # json does not recognize NumPy data types
        prediction = options[max_idx]
        return prediction

    else:
        # free-text QA problems, numeric answer
        patterns = [
            r'[Tt]he answer is ([\s\S]+)$',  # "The answer is XXXXX.",
            r'[Tt]he table shows that ([\d\$\.\,\/\:]+) ',
            r' = ([\d\$\.\,\/\:]+)',  # "= $1.40"
            r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "will be $1.40"
            r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "are $1.40"
            r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)',  # "were $1.40"
            r' ([\d\$\.\,\/\:]+ [AP]\.M\.)',  # 7:25 P.M.
            r'([\-\d\$\.\,\/\:]{0,}[\d]+)',  # 14.5
        ]

        for p in patterns:
            pattern = re.compile(p)
            res = pattern.findall(output)
            if len(res) > 0:
                prediction = res[-1].strip()
                if prediction.endswith('.') and '.M.' not in prediction:
                    prediction = prediction[:-1]
                return prediction

    return output


@ICL_EVALUATORS.register_module()
class TabMWPEvaluator(HuggingfaceEvaluator):
"""Accuracy evaluator."""

def __init__(self) -> None:
super().__init__(metric='accuracy')

    def _preprocess(self, predictions: List, references: List) -> dict:
        """Preprocess the final predictions and references to needed format.

        Args:
            predictions (List): List of predictions of each sample.
            references (List): List of targets for each sample.

        Returns:
            dict: preprocessed results.
        """
        preds, golds = [], []
        for idx in range(len(references)):
            pred = predictions[idx]
            unit = references[idx]['unit']
            answer = references[idx]['answer']
            choices = references[idx]['choices']
            preds.append(
                normalize_answer(extract_prediction(pred, choices),
                                 unit).lower())
            golds.append(normalize_answer(answer, unit).lower())
        predictions = preds
        references = golds
        mapping_to_int_dict = {
            label: idx
            for idx, label in enumerate(set(map(str, references)))
        }
        pred_set = set(predictions)
        for pred in pred_set:
            if str(pred) not in mapping_to_int_dict.keys():
                mapping_to_int_dict[str(pred)] = len(mapping_to_int_dict)
        golds = [mapping_to_int_dict[str(gold)] for gold in references]
        preds = [mapping_to_int_dict[str(pred)] for pred in predictions]
        return {
            'predictions': preds,
            'references': golds,
        }

    def _postprocess(self, scores: dict) -> dict:
        """Postprocess for final scores.

        Args:
            scores (dict): Dict of calculated scores of metrics.

        Returns:
            dict: postprocessed scores.
        """
        scores['accuracy'] *= 100
        return scores


@LOAD_DATASET.register_module()
class TabMWPDataset(BaseDataset):
    # The TabMWP dataset contains 38,431 tabular math word problems.
    # Each question in TabMWP is aligned with a tabular context,
    # which is presented as an image, semi-structured text, and a
    # structured table. There are two types of questions: free-text
    # and multi-choice, and each problem is annotated with gold
    # solutions to reveal the multi-step reasoning process.
    # To learn more about it, please follow:
    # https://github.com/lupantech/PromptPG/tree/main
    @staticmethod
    def load(path: str):
        dataset = DatasetDict()
        for split in ['dev', 'test', 'train']:
            raw_data = []
            filename = osp.join(path, f'problems_{split}.json')
            with open(filename, 'r', encoding='utf-8') as f:
                json_data = json.load(f)
                for idx in json_data:
                    problem = json_data[idx]
                    question = get_question_text(problem)
                    table = get_table_text(problem)
                    unit = get_unit(problem)
                    answer = get_answer(problem)
                    choices = get_choices(problem)
                    solution = get_solution_text(problem)
                    raw_data.append({
                        'question': question,
                        'table': table,
                        'test_elements': {
                            'answer': answer,
                            'unit': unit,
                            'choices': choices
                        },
                        'answer': f'Answer: The answer is {answer}.',
                        'solution': f'Solution: {solution}',
                        'answer_and_solution':
                        f'Answer: The answer is {answer}. BECAUSE: {solution}',
                        'solution_and_answer':
                        f'Answer: {solution} The answer is {answer}.'
                    })
            dataset[split] = Dataset.from_list(raw_data)
        return dataset
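A rough usage sketch of the pieces above, assuming the dataset has been downloaded from https://github.com/lupantech/PromptPG into ./data/tabmwp/ and that the HuggingFace evaluate 'accuracy' metric is available locally; the prediction strings and the toy reference below are made up for illustration:

    from opencompass.datasets.tabmwp import (TabMWPDataset, TabMWPEvaluator,
                                             extract_prediction,
                                             normalize_answer)

    # Answer extraction and normalization on model outputs.
    extract_prediction('So the total cost = $1.40.')                  # -> '$1.40'
    extract_prediction('The answer is (B) 12 apples.', ['10', '12'])  # -> '12'
    normalize_answer('$1.40', None)                                   # -> '1.4'
    normalize_answer('3/4', None)                                     # -> '0.75'

    # End-to-end: load the splits and score one toy prediction.
    ds = TabMWPDataset.load('./data/tabmwp/')  # DatasetDict with dev/test/train
    evaluator = TabMWPEvaluator()
    evaluator.score(
        predictions=['The answer is 1.40.'],
        references=[{'answer': '1.4', 'unit': None, 'choices': None}],
    )  # expected: {'accuracy': 100.0}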