diff --git a/prompting/cleaners/all_cleaners.py b/prompting/cleaners/all_cleaners.py index 4a7c46dbd..d48119bf3 100644 --- a/prompting/cleaners/all_cleaners.py +++ b/prompting/cleaners/all_cleaners.py @@ -56,6 +56,8 @@ def capitalize_sentences(self, input_string): sentences = re.split(r"(?<=[.!?])\s+", input_string) capitalized_sentences = [sentence.capitalize() for sentence in sentences] result_string = " ".join(capitalized_sentences) + # Capitalize the first letter in result_string + result_string.capitalize() return result_string def apply(self, generation: str) -> str: @@ -101,4 +103,27 @@ def apply(self, generation: str, min_pos: Union[int,float] = 5, max_pos: Union[i # drop everything after the last question mark. Alternatively, we can just extract the first question. generation = generation.rsplit("?",1) + '?' - return generation \ No newline at end of file + return generation + +class RemoveTags(BaseCleaner): + def __init__(self, **kwargs): + pass + + def apply(self, generation: str) -> str: + tags = [ + "",] + for tag in tags: + if tag in generation: + generation = generation.replace(tag, "") + return generation + +class FirstQuestion(BaseCleaner): + def __init__(self, **kwargs): + pass + + def apply(self, generation: str) -> str: + if "?" in generation: + if ':' in generation: + generation = generation.split(':')[1] + generation = generation.split("?")[0] + "?" + return generation \ No newline at end of file diff --git a/prompting/cleaners/cleaner.py b/prompting/cleaners/cleaner.py index 931d9de3d..e612cf09b 100644 --- a/prompting/cleaners/cleaner.py +++ b/prompting/cleaners/cleaner.py @@ -2,13 +2,15 @@ import bittensor as bt -from prompting.cleaners.all_cleaners import RemoveQuotes, RemoveRoles, PruneEnding, PrunePostQuestionText +from prompting.cleaners.all_cleaners import RemoveQuotes, RemoveRoles, PruneEnding, PrunePostQuestionText, RemoveTags, FirstQuestion SUPPORTED_CLEANERS = { "remove_quotes": RemoveQuotes, "remove_roles": RemoveRoles, "prune_ending": PruneEnding, "remove_post_question_text": PrunePostQuestionText, + "first_question": FirstQuestion, + "remove_tags": RemoveTags, } diff --git a/prompting/rewards/date.py b/prompting/rewards/date.py index 463eaf880..9cc3bf525 100644 --- a/prompting/rewards/date.py +++ b/prompting/rewards/date.py @@ -20,50 +20,48 @@ def date_diff(self, ref_date: tuple, comp_date: tuple) -> int: """ Calculates the absolute difference in days between two dates. """ + DATE_NOT_FOUND_CODE = 9999 + if not comp_date: + return DATE_NOT_FOUND_CODE + # Check if ref date is just a year + if ref_date.isdigit(): + # Extract the last 3-4 digits from the completion date using a regex pattern that would detect 3 or 4 digit years + comp_year = re.findall(r'\b\d{3,4}\b', comp_date) + if comp_year: + return abs(int(ref_date) - int(comp_year[0])*365) + else: + return DATE_NOT_FOUND_CODE + # If the reference date is not only a year, take the difference between the two dates try: - return abs(ref_date[0] - comp_date[0]).days + 365 * abs( - int(ref_date[1]) - int(comp_date[1]) - ) - except Exception as e: - return 500 + ref_date = pd.to_datetime(ref_date) + comp_date = pd.to_datetime(comp_date) + return abs((ref_date - comp_date).days) + except: + if ref_date == comp_date: + return 0 + else: + return DATE_NOT_FOUND_CODE def parse_dates_from_text(self, text: str) -> tuple: - """ - Parses dates from a body of text, handling various formats, and returns pandas datetime objects. + # Regular expression to find dates in various formats + date_pattern = r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember))\s+\d{4}\b|\b\d{4}\b' - Args: - text (str): The text to parse. + # Compile the regex pattern + date_regex = re.compile(date_pattern) - Returns: - tuple: A tuple containing a datemtime object with they year set at 2000 and the actual year. - """ + # Split text into sentences + sentences = re.split(r'(? float: """Assign a score based on the difference between two dates using a negative exponential function. diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index bee5776b2..e4a253368 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -2,34 +2,48 @@ from prompting.tasks import Task from prompting.cleaners.cleaner import CleanerPipeline - -SECTION_MESSAGES = {"Births": " was born ", "Deaths": " died ", "Events": " "} - +QUERY_SYSTEM_PROMPT = """You are a question creation expert. When asked to create a question, you use the context to make a specific question that would have the answer . Your question should contain the topic.""" +QUERY_PROMPT_TEMPLATE = """\ +Create a question about {topic} that would have as the answer using the following context: +topic: {topic} +context: {context} +""" +REFERENCE_PROMPT_TEMPLATE = """\ +Your answer must include the following date: {date}. +Answer the following question using the provided context. +Question: {query} +Context: {context} +""" @dataclass class DateQuestionAnsweringTask(Task): name = "date_qa" + challenge_type = 'query' + clean_reference = False desc = "get help answering a specific date-based question" goal = "to get the answer to the following date-based question" reward_definition = [ - dict(name="date", weight=1.0), + dict(name="date", weight=0.7), + dict(name="rouge", weight=0.3), ] penalty_definition = [] cleaning_pipeline = [ - dict(name="remove_quotes"), - dict(name="remove_roles"), + #dict(name="remove_quotes"), + #dict(name="remove_roles"), + dict(name="remove_tags"), + dict(name="first_question"), ] - static_reference = True - static_query = True + static_reference = False def __init__(self, llm_pipeline, context, create_reference =True): self.context = context - - self.query = ( - context.content + SECTION_MESSAGES[context.topic] + "on what exact date?" - ) - self.reference = self.context.title.replace("_", " ") + ", " + context.subtopic - + self.query_system_prompt = QUERY_SYSTEM_PROMPT + self.query_prompt = QUERY_PROMPT_TEMPLATE.format(topic = context.title, context=context.content) + self.query = self.generate_query(llm_pipeline) + date = self.context.extra.get('date', None) + self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(date = date, query = self.query, context = context.content) + if create_reference: + self.reference = self.generate_reference(llm_pipeline) self.topic = context.title - self.subtopic = context.topic + self.subtopic = date self.tags = context.tags diff --git a/prompting/tasks/task.py b/prompting/tasks/task.py index d837aac46..041798d3e 100644 --- a/prompting/tasks/task.py +++ b/prompting/tasks/task.py @@ -49,6 +49,7 @@ class Task(ABC): query_system_prompt = "" query_prompt = "" cleaner = None + clean_reference = True challenge_type = 'inference' def __str__(self): @@ -91,8 +92,9 @@ def generate_reference(self, pipeline: BasePipeline, clean=True) -> str: """Generates a reference answer to be used for scoring miner completions""" t0 = time.time() if not self.static_reference: + if not self.clean_reference: + clean = False bt.logging.info("🤖 Generating reference...") - self.reference = self.generate( system=make_system_prompt(), prompt=self.reference_prompt, diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index 6da057f64..a125630aa 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -23,12 +23,16 @@ import bittensor as bt import wikipedia as wiki from typing import Dict, Union, List, Tuple - +from queue import Queue, Full, Empty from functools import lru_cache from .base import Dataset from ..selector import Selector +# Create a queue called CACHED_ARTICLES to store wikipedia articles that have been fetched +CACHED_ARTICLES = Queue(maxsize=300) + + # speed up page loading @lru_cache(maxsize=1000) def _get_page( @@ -208,7 +212,7 @@ def get( key = header, section_title = selector(list(sections.keys())) content = "\n".join(sections[key]) section_length = len(content.split()) - return { + context = { "title": name, # title of wiki article "topic": header or section_title, # title of wiki section "subtopic": section_title, @@ -223,6 +227,11 @@ def get( "section_length": section_length, }, } + try: + CACHED_ARTICLES.put(context, block=False) + except Full: + bt.logging.debug("Cache is full. Skipping article until cache is emptied.") + return context def search(self, name, results=3, selector: Selector = None) -> Dict: titles = _wiki_search(name, results=results) @@ -262,22 +271,43 @@ def __init__(self, max_tries: int = 10, seed=None): self.max_tries = max_tries self.seed = seed self.rng = random.Random(seed) - - def _random_date(self, year: int = None, month: int = None) -> int: - """Returns a random date in the format "Month_DD" (e.g., "January_01").""" - if year is None: - year = self.rng.randint(0, 2024) - if month is None: - month = self.rng.randint(1, 12) - - max_days = 31 if month in (1, 3, 5, 7, 8, 10, 12) else 30 - max_days = max_days if month != 2 else 29 - - day = self.rng.randint(1, max_days) - - random_date = datetime.date(year, month, day) - # Step 2: Format the date for Wikipedia URL - return random_date.strftime("%B %-d") # E.g., "January 1" + + def extract_dates_and_sentences(self, text: str) -> Tuple[str, str]: + # Regular expression to find dates in various formats + date_pattern = r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember))\s+\d{4}\b|\b\d{4}\b' + + # Compile the regex pattern + date_regex = re.compile(date_pattern) + + # Split text into sentences + sentences = re.split(r'(?').strip()) + return None + + def _random_date(self) -> str: + for _ in range(self.max_tries): + try: + context = CACHED_ARTICLES.get(block=False) + date_sentence = self.extract_dates_and_sentences(context['content']) + context['content'] = date_sentence[1] + context['extra']['date'] = date_sentence[0] + if context['content'] is None: + continue + else: + return context + + except Empty: + bt.logging.debug("Cache is empty. Skipping date until cache is filled.") + return None def get( self, @@ -287,59 +317,8 @@ def get( redirect=False, selector: Selector = None, ) -> Dict: - # Check that name is correctly formatted e.g., "January 1" - date = name.split(" ") - assert ( - len(date) == 2 - ), f"Date should be in the format 'Month D[D]' (e.g., 'January 1' or 'March 28'), but got {name!r}" - assert ( - date[0] in self.MONTHS - ), f"Month should be one of {self.MONTHS}, but got {date[0]!r}" - assert date[1].isdigit(), f"Day should be a number, but got {date[1]!r}" - - page = _get_page( - title=name, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect - ) - if page is None: - return None - - # Only return a sections which contain event-like format - # e.g. "1999 - Some event happened" - sections = process_page( - page, - valid_header=lambda x: x in self.INCLUDE_HEADERS, - valid_content=lambda x: any( - [re.search(r"^\d+", line) for line in x.splitlines()] - ), - ) - if not sections: - return None - - key = header, section_title = selector(list(sections.keys())) - line = selector(sections[key]) - year, *event = line.replace("\u2013", "-").split("-") - links = [link for link in page.links if link in line] - - return { - "title": name, # title of wiki article - "topic": header or section_title, # title of wiki section - "subtopic": year.strip(), - "content": "-".join(event).strip(". "), - "internal_links": list(sections.keys()), - "external_links": links, - "tags": filter_categories( - page.categories, exclude=WikiDataset.EXCLUDE_CATEGORIES - ), - "source": "Wikipedia", - "extra": { - "url": page.url, - "year": year, - "event": event, - "line": line, - "date": date + [year], - "section_title": section_title, - }, - } + #TODO: Implement deterministic get method + return self.random() def search(self, name, results=5, selector: Selector = None) -> Dict: raise NotImplementedError( @@ -347,5 +326,4 @@ def search(self, name, results=5, selector: Selector = None) -> Dict: ) def random(self, selector: Selector = None, **kwargs) -> Dict: - date = self._random_date() - return self.get(date, selector=selector) + return self._random_date()