From d64eabf356d5f757d26eb09d075f92aed18f85f2 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 6 May 2024 14:40:04 +0000 Subject: [PATCH 01/13] Implement Cacheing and parseing for date context --- prompting/tasks/date_qa.py | 18 +-- prompting/tools/datasets/wiki.py | 203 +++++++++++++++++++------------ 2 files changed, 134 insertions(+), 87 deletions(-) diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index bee5776b2..e24a65d4e 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -2,13 +2,16 @@ from prompting.tasks import Task from prompting.cleaners.cleaner import CleanerPipeline - -SECTION_MESSAGES = {"Births": " was born ", "Deaths": " died ", "Events": " "} - +QUERY_SYSTEM_PROMPT = """You are a question creation expert. When asked to create a question, you use the context to make a specific question that would have the answer .""" +QUERY_PROMPT_TEMPLATE = """\ +Create a question that would have as the answer using the following context: +{context} +""" @dataclass class DateQuestionAnsweringTask(Task): name = "date_qa" + challenge_type = 'query' desc = "get help answering a specific date-based question" goal = "to get the answer to the following date-based question" reward_definition = [ @@ -24,11 +27,10 @@ class DateQuestionAnsweringTask(Task): def __init__(self, llm_pipeline, context, create_reference =True): self.context = context - - self.query = ( - context.content + SECTION_MESSAGES[context.topic] + "on what exact date?" - ) - self.reference = self.context.title.replace("_", " ") + ", " + context.subtopic + self.query_system_prompt = QUERY_SYSTEM_PROMPT + self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=context.content[1]) + self.query = self.generate_query(llm_pipeline) + self.reference = self.context['content'][0] self.topic = context.title self.subtopic = context.topic diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index 6da057f64..b12a51316 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -23,12 +23,16 @@ import bittensor as bt import wikipedia as wiki from typing import Dict, Union, List, Tuple - +from queue import Queue from functools import lru_cache from .base import Dataset from ..selector import Selector +# Create a queue called CACHED_ARTICLES to store wikipedia articles that have been fetched +CACHED_ARTICLES = Queue(maxsize=300) + + # speed up page loading @lru_cache(maxsize=1000) def _get_page( @@ -208,7 +212,7 @@ def get( key = header, section_title = selector(list(sections.keys())) content = "\n".join(sections[key]) section_length = len(content.split()) - return { + context = { "title": name, # title of wiki article "topic": header or section_title, # title of wiki section "subtopic": section_title, @@ -223,6 +227,11 @@ def get( "section_length": section_length, }, } + try: + CACHED_ARTICLES.put(page, block=False) + except CACHED_ARTICLES.Full: + bt.logging.debug("Cache is full. Skipping article until cache is emptied.") + return context def search(self, name, results=3, selector: Selector = None) -> Dict: titles = _wiki_search(name, results=results) @@ -263,83 +272,119 @@ def __init__(self, max_tries: int = 10, seed=None): self.seed = seed self.rng = random.Random(seed) - def _random_date(self, year: int = None, month: int = None) -> int: - """Returns a random date in the format "Month_DD" (e.g., "January_01").""" - if year is None: - year = self.rng.randint(0, 2024) - if month is None: - month = self.rng.randint(1, 12) - - max_days = 31 if month in (1, 3, 5, 7, 8, 10, 12) else 30 - max_days = max_days if month != 2 else 29 - - day = self.rng.randint(1, max_days) - - random_date = datetime.date(year, month, day) - # Step 2: Format the date for Wikipedia URL - return random_date.strftime("%B %-d") # E.g., "January 1" - - def get( - self, - name, - pageid=None, - auto_suggest=False, - redirect=False, - selector: Selector = None, - ) -> Dict: - # Check that name is correctly formatted e.g., "January 1" - date = name.split(" ") - assert ( - len(date) == 2 - ), f"Date should be in the format 'Month D[D]' (e.g., 'January 1' or 'March 28'), but got {name!r}" - assert ( - date[0] in self.MONTHS - ), f"Month should be one of {self.MONTHS}, but got {date[0]!r}" - assert date[1].isdigit(), f"Day should be a number, but got {date[1]!r}" - - page = _get_page( - title=name, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect - ) - if page is None: - return None - - # Only return a sections which contain event-like format - # e.g. "1999 - Some event happened" - sections = process_page( - page, - valid_header=lambda x: x in self.INCLUDE_HEADERS, - valid_content=lambda x: any( - [re.search(r"^\d+", line) for line in x.splitlines()] - ), - ) - if not sections: - return None - - key = header, section_title = selector(list(sections.keys())) - line = selector(sections[key]) - year, *event = line.replace("\u2013", "-").split("-") - links = [link for link in page.links if link in line] - - return { - "title": name, # title of wiki article - "topic": header or section_title, # title of wiki section - "subtopic": year.strip(), - "content": "-".join(event).strip(". "), - "internal_links": list(sections.keys()), - "external_links": links, - "tags": filter_categories( - page.categories, exclude=WikiDataset.EXCLUDE_CATEGORIES - ), - "source": "Wikipedia", - "extra": { - "url": page.url, - "year": year, - "event": event, - "line": line, - "date": date + [year], - "section_title": section_title, - }, - } + # def _random_date(self, year: int = None, month: int = None) -> int: + # """Returns a random date in the format "Month_DD" (e.g., "January_01").""" + # if year is None: + # year = self.rng.randint(0, 2024) + # if month is None: + # month = self.rng.randint(1, 12) + + # max_days = 31 if month in (1, 3, 5, 7, 8, 10, 12) else 30 + # max_days = max_days if month != 2 else 29 + + # day = self.rng.randint(1, max_days) + + # random_date = datetime.date(year, month, day) + # # Step 2: Format the date for Wikipedia URL + # return random_date.strftime("%B %-d") # E.g., "January 1" + + # def get( + # self, + # name, + # pageid=None, + # auto_suggest=False, + # redirect=False, + # selector: Selector = None, + # ) -> Dict: + # # Check that name is correctly formatted e.g., "January 1" + # date = name.split(" ") + # assert ( + # len(date) == 2 + # ), f"Date should be in the format 'Month D[D]' (e.g., 'January 1' or 'March 28'), but got {name!r}" + # assert ( + # date[0] in self.MONTHS + # ), f"Month should be one of {self.MONTHS}, but got {date[0]!r}" + # assert date[1].isdigit(), f"Day should be a number, but got {date[1]!r}" + + # page = _get_page( + # title=name, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect + # ) + # if page is None: + # return None + + # # Only return a sections which contain event-like format + # # e.g. "1999 - Some event happened" + # sections = process_page( + # page, + # valid_header=lambda x: x in self.INCLUDE_HEADERS, + # valid_content=lambda x: any( + # [re.search(r"^\d+", line) for line in x.splitlines()] + # ), + # ) + # if not sections: + # return None + + # key = header, section_title = selector(list(sections.keys())) + # line = selector(sections[key]) + # year, *event = line.replace("\u2013", "-").split("-") + # links = [link for link in page.links if link in line] + + # return { + # "title": name, # title of wiki article + # "topic": header or section_title, # title of wiki section + # "subtopic": year.strip(), + # "content": "-".join(event).strip(". "), + # "internal_links": list(sections.keys()), + # "external_links": links, + # "tags": filter_categories( + # page.categories, exclude=WikiDataset.EXCLUDE_CATEGORIES + # ), + # "source": "Wikipedia", + # "extra": { + # "url": page.url, + # "year": year, + # "event": event, + # "line": line, + # "date": date + [year], + # "section_title": section_title, + # }, + # } + def extract_dates_and_sentences(text) -> Tuple[str, str]: + # Regular expression to find dates in various formats + date_pattern = r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember))\s+\d{4}\b|\b\d{4}\b' + + # Compile the regex pattern + date_regex = re.compile(date_pattern) + + # Split text into sentences + sentences = re.split(r'(?').strip()) + return None + + def _random_date(self) -> str: + for _ in range(self.max_tries): + try: + page = CACHED_ARTICLES.get(block=False) + page['content'] = self.extract_dates_and_sentences(page['content']) + if page['content'] is None: + continue + else: + return page + + except CACHED_ARTICLES.Empty: + bt.logging.debug("Cache is empty. Skipping date until cache is filled.") + return None def search(self, name, results=5, selector: Selector = None) -> Dict: raise NotImplementedError( From 79d2d22c2b50d23a423a87ab492c8f067d465f63 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 6 May 2024 14:52:31 +0000 Subject: [PATCH 02/13] Add get --- prompting/tools/datasets/wiki.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index b12a51316..8a8d52824 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -386,6 +386,10 @@ def _random_date(self) -> str: bt.logging.debug("Cache is empty. Skipping date until cache is filled.") return None + def get(self) -> Dict: + # Currently not a way to specify which context to fetch with Queue + return self.random() + def search(self, name, results=5, selector: Selector = None) -> Dict: raise NotImplementedError( f"Search is not implemented for {self.__class__.__name__}" From 366b3df07cc7b8ba3ce71c65c2e8fe778c14d977 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 6 May 2024 15:02:19 +0000 Subject: [PATCH 03/13] Handle Queue errors --- prompting/tools/datasets/wiki.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index 8a8d52824..64b7c6bf7 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -23,7 +23,7 @@ import bittensor as bt import wikipedia as wiki from typing import Dict, Union, List, Tuple -from queue import Queue +from queue import Queue, Full, Empty from functools import lru_cache from .base import Dataset from ..selector import Selector @@ -229,7 +229,7 @@ def get( } try: CACHED_ARTICLES.put(page, block=False) - except CACHED_ARTICLES.Full: + except Full: bt.logging.debug("Cache is full. Skipping article until cache is emptied.") return context @@ -382,7 +382,7 @@ def _random_date(self) -> str: else: return page - except CACHED_ARTICLES.Empty: + except Empty: bt.logging.debug("Cache is empty. Skipping date until cache is filled.") return None From d3754b2a8c5b37a07749d6da03b299593d272c04 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 6 May 2024 17:00:46 +0000 Subject: [PATCH 04/13] Comply with unit tests --- prompting/tasks/date_qa.py | 6 ++---- prompting/tools/datasets/wiki.py | 30 +++++++++++++++++++----------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index e24a65d4e..61eef000c 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -23,15 +23,13 @@ class DateQuestionAnsweringTask(Task): dict(name="remove_roles"), ] static_reference = True - static_query = True def __init__(self, llm_pipeline, context, create_reference =True): self.context = context self.query_system_prompt = QUERY_SYSTEM_PROMPT - self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=context.content[1]) + self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=context.content) self.query = self.generate_query(llm_pipeline) - self.reference = self.context['content'][0] - + self.reference = self.context.extra.get('date', None) self.topic = context.title self.subtopic = context.topic self.tags = context.tags diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index 64b7c6bf7..74ca29b5d 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -228,7 +228,7 @@ def get( }, } try: - CACHED_ARTICLES.put(page, block=False) + CACHED_ARTICLES.put(context, block=False) except Full: bt.logging.debug("Cache is full. Skipping article until cache is emptied.") return context @@ -349,7 +349,7 @@ def __init__(self, max_tries: int = 10, seed=None): # "section_title": section_title, # }, # } - def extract_dates_and_sentences(text) -> Tuple[str, str]: + def extract_dates_and_sentences(self, text: str) -> Tuple[str, str]: # Regular expression to find dates in various formats date_pattern = r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember))\s+\d{4}\b|\b\d{4}\b' @@ -369,25 +369,34 @@ def extract_dates_and_sentences(text) -> Tuple[str, str]: if dates: for date in dates: # Return the first date found - return tuple(str(date), sentence.replace(str(date), '').strip()) + return (str(date), sentence.replace(str(date), '').strip()) return None def _random_date(self) -> str: for _ in range(self.max_tries): try: - page = CACHED_ARTICLES.get(block=False) - page['content'] = self.extract_dates_and_sentences(page['content']) - if page['content'] is None: + context = CACHED_ARTICLES.get(block=False) + date_sentence = self.extract_dates_and_sentences(context['content']) + context['content'] = date_sentence[1] + context['extra']['date'] = date_sentence[0] + if context['content'] is None: continue else: - return page + return context except Empty: bt.logging.debug("Cache is empty. Skipping date until cache is filled.") return None - def get(self) -> Dict: - # Currently not a way to specify which context to fetch with Queue + def get( + self, + name, + pageid=None, + auto_suggest=False, + redirect=False, + selector: Selector = None, + ) -> Dict: + #TODO: Implement deterministic get method return self.random() def search(self, name, results=5, selector: Selector = None) -> Dict: @@ -396,5 +405,4 @@ def search(self, name, results=5, selector: Selector = None) -> Dict: ) def random(self, selector: Selector = None, **kwargs) -> Dict: - date = self._random_date() - return self.get(date, selector=selector) + return self._random_date() From a1240f61c5daff0e512e97e1a536b06c16357863 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Mon, 6 May 2024 17:22:54 +0000 Subject: [PATCH 05/13] Adjust date scoring --- prompting/rewards/date.py | 75 +++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/prompting/rewards/date.py b/prompting/rewards/date.py index 463eaf880..0aeb7834f 100644 --- a/prompting/rewards/date.py +++ b/prompting/rewards/date.py @@ -20,50 +20,47 @@ def date_diff(self, ref_date: tuple, comp_date: tuple) -> int: """ Calculates the absolute difference in days between two dates. """ + if not comp_date: + return 9999 + # Check if ref date is just a year + if ref_date.isdigit(): + # Extract the last 3-4 digits from the completion date using a regex pattern that would detect 3 or 4 digit years + comp_year = re.findall(r'\b\d{3,4}\b', comp_date) + if comp_year: + return abs(int(ref_date) - int(comp_year[0])*365) + else: + return 9999 + # If the reference date is not only a year, take the difference between the two dates try: - return abs(ref_date[0] - comp_date[0]).days + 365 * abs( - int(ref_date[1]) - int(comp_date[1]) - ) - except Exception as e: - return 500 + ref_date = pd.to_datetime(ref_date) + comp_date = pd.to_datetime(comp_date) + return abs((ref_date - comp_date).days) + except: + if ref_date == comp_date: + return 0 + else: + return 9999 def parse_dates_from_text(self, text: str) -> tuple: - """ - Parses dates from a body of text, handling various formats, and returns pandas datetime objects. + # Regular expression to find dates in various formats + date_pattern = r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember))\s+\d{4}\b|\b\d{4}\b' - Args: - text (str): The text to parse. + # Compile the regex pattern + date_regex = re.compile(date_pattern) - Returns: - tuple: A tuple containing a datemtime object with they year set at 2000 and the actual year. - """ + # Split text into sentences + sentences = re.split(r'(? float: """Assign a score based on the difference between two dates using a negative exponential function. @@ -77,7 +74,7 @@ def date_score(self, reference: str, completion: str) -> float: score = 0 if not completion: return score - ref_date = self.parse_dates_from_text(reference) + ref_date = reference comp_date = self.parse_dates_from_text(completion) score = np.exp(-(self.date_diff(ref_date, comp_date) ** 2 / 1000)) # Clip any very small scores From 78e3bec5c4150f21aac1566297ef2762522cc96c Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Tue, 7 May 2024 16:21:45 +0000 Subject: [PATCH 06/13] Introduce new cleaning pipeline --- prompting/cleaners/all_cleaners.py | 22 +++++++++++++++++++++- prompting/cleaners/cleaner.py | 4 +++- prompting/tasks/date_qa.py | 6 ++++-- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/prompting/cleaners/all_cleaners.py b/prompting/cleaners/all_cleaners.py index 4a7c46dbd..d7d9cf727 100644 --- a/prompting/cleaners/all_cleaners.py +++ b/prompting/cleaners/all_cleaners.py @@ -101,4 +101,24 @@ def apply(self, generation: str, min_pos: Union[int,float] = 5, max_pos: Union[i # drop everything after the last question mark. Alternatively, we can just extract the first question. generation = generation.rsplit("?",1) + '?' - return generation \ No newline at end of file + return generation + +class RemoveTags(BaseCleaner): + def __init__(self, **kwargs): + pass + + def apply(self, generation: str) -> str: + tags = [ + "",] + for tag in tags: + if tag in generation: + generation = generation.replace(tag, "") + return generation + +class FirstQuestion(BaseCleaner): + def __init__(self, **kwargs): + pass + + def apply(self, generation: str) -> str: + generation = generation.split("?")[0] + "?" + return generation \ No newline at end of file diff --git a/prompting/cleaners/cleaner.py b/prompting/cleaners/cleaner.py index 931d9de3d..e612cf09b 100644 --- a/prompting/cleaners/cleaner.py +++ b/prompting/cleaners/cleaner.py @@ -2,13 +2,15 @@ import bittensor as bt -from prompting.cleaners.all_cleaners import RemoveQuotes, RemoveRoles, PruneEnding, PrunePostQuestionText +from prompting.cleaners.all_cleaners import RemoveQuotes, RemoveRoles, PruneEnding, PrunePostQuestionText, RemoveTags, FirstQuestion SUPPORTED_CLEANERS = { "remove_quotes": RemoveQuotes, "remove_roles": RemoveRoles, "prune_ending": PruneEnding, "remove_post_question_text": PrunePostQuestionText, + "first_question": FirstQuestion, + "remove_tags": RemoveTags, } diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index 61eef000c..f4b1271a6 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -19,8 +19,10 @@ class DateQuestionAnsweringTask(Task): ] penalty_definition = [] cleaning_pipeline = [ - dict(name="remove_quotes"), - dict(name="remove_roles"), + #dict(name="remove_quotes"), + #dict(name="remove_roles"), + dict(name="remove_tags"), + dict(name="first_question"), ] static_reference = True From a12765c43c580905d8d70dde8bb3a88e663a42ce Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Tue, 14 May 2024 12:53:21 +0000 Subject: [PATCH 07/13] Use text_based reference --- prompting/cleaners/all_cleaners.py | 2 ++ prompting/tasks/date_qa.py | 27 ++++++++++++++++++--------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/prompting/cleaners/all_cleaners.py b/prompting/cleaners/all_cleaners.py index d7d9cf727..1d184a284 100644 --- a/prompting/cleaners/all_cleaners.py +++ b/prompting/cleaners/all_cleaners.py @@ -56,6 +56,8 @@ def capitalize_sentences(self, input_string): sentences = re.split(r"(?<=[.!?])\s+", input_string) capitalized_sentences = [sentence.capitalize() for sentence in sentences] result_string = " ".join(capitalized_sentences) + # Capitalize the first letter in result_string + result_string.capitalize() return result_string def apply(self, generation: str) -> str: diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index f4b1271a6..df267a4ec 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -2,10 +2,17 @@ from prompting.tasks import Task from prompting.cleaners.cleaner import CleanerPipeline -QUERY_SYSTEM_PROMPT = """You are a question creation expert. When asked to create a question, you use the context to make a specific question that would have the answer .""" +QUERY_SYSTEM_PROMPT = """You are a question creation expert. When asked to create a question, you use the context to make a specific question that would have the answer . Your question should contain the topic.""" QUERY_PROMPT_TEMPLATE = """\ -Create a question that would have as the answer using the following context: -{context} +Create a question about {topic} that would have as the answer using the following context: +topic: {topic} +context: {context} +""" +REFERENCE_PROMPT_TEMPLATE = """\ +Your answer must include the following date: {date}. +Answer the following question using the provided context. +Question: {query} +Context: {context} """ @dataclass @@ -20,18 +27,20 @@ class DateQuestionAnsweringTask(Task): penalty_definition = [] cleaning_pipeline = [ #dict(name="remove_quotes"), - #dict(name="remove_roles"), + dict(name="remove_roles"), dict(name="remove_tags"), dict(name="first_question"), ] - static_reference = True + static_reference = False def __init__(self, llm_pipeline, context, create_reference =True): self.context = context self.query_system_prompt = QUERY_SYSTEM_PROMPT - self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=context.content) + self.query_prompt = QUERY_PROMPT_TEMPLATE.format(topic = context.title, context=context.content) self.query = self.generate_query(llm_pipeline) - self.reference = self.context.extra.get('date', None) - self.topic = context.title - self.subtopic = context.topic + date = self.context.extra.get('date', None) + self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(date = date, query = self.query, context = context.content) + self.reference = self.generate_reference(llm_pipeline) + self.topic = date + self.subtopic = context.title self.tags = context.tags From 07ef5354135c5e5215d1cc12b287bc452e196de6 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Tue, 14 May 2024 13:03:27 +0000 Subject: [PATCH 08/13] Include rouge in reward definition --- prompting/rewards/date.py | 2 +- prompting/tasks/date_qa.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/prompting/rewards/date.py b/prompting/rewards/date.py index 0aeb7834f..b1b077e3b 100644 --- a/prompting/rewards/date.py +++ b/prompting/rewards/date.py @@ -74,7 +74,7 @@ def date_score(self, reference: str, completion: str) -> float: score = 0 if not completion: return score - ref_date = reference + ref_date = self.parse_dates_from_text(reference) comp_date = self.parse_dates_from_text(completion) score = np.exp(-(self.date_diff(ref_date, comp_date) ** 2 / 1000)) # Clip any very small scores diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index df267a4ec..aefb4bf57 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -22,7 +22,8 @@ class DateQuestionAnsweringTask(Task): desc = "get help answering a specific date-based question" goal = "to get the answer to the following date-based question" reward_definition = [ - dict(name="date", weight=1.0), + dict(name="date", weight=0.7), + dict(name="rouge", weight=0.3), ] penalty_definition = [] cleaning_pipeline = [ @@ -41,6 +42,6 @@ def __init__(self, llm_pipeline, context, create_reference =True): date = self.context.extra.get('date', None) self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(date = date, query = self.query, context = context.content) self.reference = self.generate_reference(llm_pipeline) - self.topic = date - self.subtopic = context.title + self.topic = context.title + self.subtopic = date self.tags = context.tags From 6fd630bc36f952408ae64d53e48a76ceffcba20c Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Tue, 14 May 2024 16:09:02 -0400 Subject: [PATCH 09/13] Remove Unecessary Comments --- prompting/tools/datasets/wiki.py | 79 +------------------------------- 1 file changed, 1 insertion(+), 78 deletions(-) diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index 74ca29b5d..f99db1323 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -271,84 +271,7 @@ def __init__(self, max_tries: int = 10, seed=None): self.max_tries = max_tries self.seed = seed self.rng = random.Random(seed) - - # def _random_date(self, year: int = None, month: int = None) -> int: - # """Returns a random date in the format "Month_DD" (e.g., "January_01").""" - # if year is None: - # year = self.rng.randint(0, 2024) - # if month is None: - # month = self.rng.randint(1, 12) - - # max_days = 31 if month in (1, 3, 5, 7, 8, 10, 12) else 30 - # max_days = max_days if month != 2 else 29 - - # day = self.rng.randint(1, max_days) - - # random_date = datetime.date(year, month, day) - # # Step 2: Format the date for Wikipedia URL - # return random_date.strftime("%B %-d") # E.g., "January 1" - - # def get( - # self, - # name, - # pageid=None, - # auto_suggest=False, - # redirect=False, - # selector: Selector = None, - # ) -> Dict: - # # Check that name is correctly formatted e.g., "January 1" - # date = name.split(" ") - # assert ( - # len(date) == 2 - # ), f"Date should be in the format 'Month D[D]' (e.g., 'January 1' or 'March 28'), but got {name!r}" - # assert ( - # date[0] in self.MONTHS - # ), f"Month should be one of {self.MONTHS}, but got {date[0]!r}" - # assert date[1].isdigit(), f"Day should be a number, but got {date[1]!r}" - - # page = _get_page( - # title=name, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect - # ) - # if page is None: - # return None - - # # Only return a sections which contain event-like format - # # e.g. "1999 - Some event happened" - # sections = process_page( - # page, - # valid_header=lambda x: x in self.INCLUDE_HEADERS, - # valid_content=lambda x: any( - # [re.search(r"^\d+", line) for line in x.splitlines()] - # ), - # ) - # if not sections: - # return None - - # key = header, section_title = selector(list(sections.keys())) - # line = selector(sections[key]) - # year, *event = line.replace("\u2013", "-").split("-") - # links = [link for link in page.links if link in line] - - # return { - # "title": name, # title of wiki article - # "topic": header or section_title, # title of wiki section - # "subtopic": year.strip(), - # "content": "-".join(event).strip(". "), - # "internal_links": list(sections.keys()), - # "external_links": links, - # "tags": filter_categories( - # page.categories, exclude=WikiDataset.EXCLUDE_CATEGORIES - # ), - # "source": "Wikipedia", - # "extra": { - # "url": page.url, - # "year": year, - # "event": event, - # "line": line, - # "date": date + [year], - # "section_title": section_title, - # }, - # } + def extract_dates_and_sentences(self, text: str) -> Tuple[str, str]: # Regular expression to find dates in various formats date_pattern = r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember))\s+\d{4}\b|\b\d{4}\b' From acd20264ea9bea8e22f6dfb75265c888a690e56a Mon Sep 17 00:00:00 2001 From: bkb2135 <98138173+bkb2135@users.noreply.github.com> Date: Wed, 15 May 2024 08:24:46 -0400 Subject: [PATCH 10/13] Disable Cleaning on Reference --- prompting/tasks/date_qa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index aefb4bf57..1737cd530 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -41,7 +41,7 @@ def __init__(self, llm_pipeline, context, create_reference =True): self.query = self.generate_query(llm_pipeline) date = self.context.extra.get('date', None) self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(date = date, query = self.query, context = context.content) - self.reference = self.generate_reference(llm_pipeline) + self.reference = self.generate_reference(llm_pipeline, clean = False) self.topic = context.title self.subtopic = date self.tags = context.tags From cf0aa289f77cb7baa3e6d6382c0cc5de27433477 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Wed, 15 May 2024 13:02:58 +0000 Subject: [PATCH 11/13] Manually remove the question marks from references --- prompting/tasks/date_qa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index 1737cd530..e37024b46 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -41,7 +41,7 @@ def __init__(self, llm_pipeline, context, create_reference =True): self.query = self.generate_query(llm_pipeline) date = self.context.extra.get('date', None) self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(date = date, query = self.query, context = context.content) - self.reference = self.generate_reference(llm_pipeline, clean = False) + self.reference = self.generate_reference(llm_pipeline, clean = False).replace('?',"") self.topic = context.title self.subtopic = date self.tags = context.tags From 3bc8b0c48ab4e28e617c99146ea28292a3fa7b31 Mon Sep 17 00:00:00 2001 From: bkb2135 Date: Fri, 17 May 2024 16:20:51 +0000 Subject: [PATCH 12/13] Create a class attribute to determine whether the reference should be cleaned --- prompting/cleaners/all_cleaners.py | 5 ++++- prompting/tasks/date_qa.py | 6 ++++-- prompting/tasks/task.py | 4 +++- prompting/tools/datasets/wiki.py | 2 -- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/prompting/cleaners/all_cleaners.py b/prompting/cleaners/all_cleaners.py index 1d184a284..d48119bf3 100644 --- a/prompting/cleaners/all_cleaners.py +++ b/prompting/cleaners/all_cleaners.py @@ -122,5 +122,8 @@ def __init__(self, **kwargs): pass def apply(self, generation: str) -> str: - generation = generation.split("?")[0] + "?" + if "?" in generation: + if ':' in generation: + generation = generation.split(':')[1] + generation = generation.split("?")[0] + "?" return generation \ No newline at end of file diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index e37024b46..e4a253368 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -19,6 +19,7 @@ class DateQuestionAnsweringTask(Task): name = "date_qa" challenge_type = 'query' + clean_reference = False desc = "get help answering a specific date-based question" goal = "to get the answer to the following date-based question" reward_definition = [ @@ -28,7 +29,7 @@ class DateQuestionAnsweringTask(Task): penalty_definition = [] cleaning_pipeline = [ #dict(name="remove_quotes"), - dict(name="remove_roles"), + #dict(name="remove_roles"), dict(name="remove_tags"), dict(name="first_question"), ] @@ -41,7 +42,8 @@ def __init__(self, llm_pipeline, context, create_reference =True): self.query = self.generate_query(llm_pipeline) date = self.context.extra.get('date', None) self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(date = date, query = self.query, context = context.content) - self.reference = self.generate_reference(llm_pipeline, clean = False).replace('?',"") + if create_reference: + self.reference = self.generate_reference(llm_pipeline) self.topic = context.title self.subtopic = date self.tags = context.tags diff --git a/prompting/tasks/task.py b/prompting/tasks/task.py index d837aac46..041798d3e 100644 --- a/prompting/tasks/task.py +++ b/prompting/tasks/task.py @@ -49,6 +49,7 @@ class Task(ABC): query_system_prompt = "" query_prompt = "" cleaner = None + clean_reference = True challenge_type = 'inference' def __str__(self): @@ -91,8 +92,9 @@ def generate_reference(self, pipeline: BasePipeline, clean=True) -> str: """Generates a reference answer to be used for scoring miner completions""" t0 = time.time() if not self.static_reference: + if not self.clean_reference: + clean = False bt.logging.info("🤖 Generating reference...") - self.reference = self.generate( system=make_system_prompt(), prompt=self.reference_prompt, diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index f99db1323..a125630aa 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -282,8 +282,6 @@ def extract_dates_and_sentences(self, text: str) -> Tuple[str, str]: # Split text into sentences sentences = re.split(r'(? Date: Fri, 17 May 2024 17:32:19 +0000 Subject: [PATCH 13/13] Used DATE_NOT_FOUND_CODE instead of 9999 --- prompting/rewards/date.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/prompting/rewards/date.py b/prompting/rewards/date.py index b1b077e3b..9cc3bf525 100644 --- a/prompting/rewards/date.py +++ b/prompting/rewards/date.py @@ -20,8 +20,9 @@ def date_diff(self, ref_date: tuple, comp_date: tuple) -> int: """ Calculates the absolute difference in days between two dates. """ + DATE_NOT_FOUND_CODE = 9999 if not comp_date: - return 9999 + return DATE_NOT_FOUND_CODE # Check if ref date is just a year if ref_date.isdigit(): # Extract the last 3-4 digits from the completion date using a regex pattern that would detect 3 or 4 digit years @@ -29,7 +30,7 @@ def date_diff(self, ref_date: tuple, comp_date: tuple) -> int: if comp_year: return abs(int(ref_date) - int(comp_year[0])*365) else: - return 9999 + return DATE_NOT_FOUND_CODE # If the reference date is not only a year, take the difference between the two dates try: ref_date = pd.to_datetime(ref_date) @@ -39,7 +40,7 @@ def date_diff(self, ref_date: tuple, comp_date: tuple) -> int: if ref_date == comp_date: return 0 else: - return 9999 + return DATE_NOT_FOUND_CODE def parse_dates_from_text(self, text: str) -> tuple: # Regular expression to find dates in various formats