diff --git a/prompting/datasets/huggingface_github.py b/prompting/datasets/huggingface_github.py
index 74e5d19fa..b65e3fc20 100644
--- a/prompting/datasets/huggingface_github.py
+++ b/prompting/datasets/huggingface_github.py
@@ -20,6 +20,7 @@ class HuggingFaceGithubDatasetEntry(DatasetEntry):
     github_url: str
     file_path: str
     file_content: str
+    source: str | None = None


 class HuggingFaceGithubDataset(BaseDataset):
@@ -46,8 +47,9 @@ def _filter_function(self, example):

     def _process_entry(self, entry: dict) -> HuggingFaceGithubDatasetEntry:
         file_content = "\n".join(entry["content"].split("\n")[:MAX_LINES])
+        url = f"https://github.com/{entry['repo_name']}"
         return HuggingFaceGithubDatasetEntry(
-            github_url=f"https://github.com/{entry['repo_name']}", file_path=entry["path"], file_content=file_content
+            github_url=url, file_path=entry["path"], file_content=file_content, source=url
         )

     def get(self) -> HuggingFaceGithubDatasetEntry:
diff --git a/prompting/datasets/random_website.py b/prompting/datasets/random_website.py
index 8496711c1..ae70ed41f 100644
--- a/prompting/datasets/random_website.py
+++ b/prompting/datasets/random_website.py
@@ -18,6 +18,7 @@ class DDGDatasetEntry(DatasetEntry):
     website_url: str = None
     website_content: str = None
     query: str | None = None
+    source: str | None = None


 class DDGDataset(BaseDataset):
@@ -55,7 +56,9 @@ def next(self) -> Optional[DDGDatasetEntry]:
             logger.debug(f"Failed to extract content from website {website_url}")
             return None

-        return DDGDatasetEntry(search_term=search_term, website_url=website_url, website_content=website_content)
+        return DDGDatasetEntry(
+            search_term=search_term, website_url=website_url, website_content=website_content, source=website_url
+        )

     def get(self) -> Optional[DDGDatasetEntry]:
         return self.next()
diff --git a/prompting/datasets/wiki.py b/prompting/datasets/wiki.py
index e07d16109..56e1a77d6 100644
--- a/prompting/datasets/wiki.py
+++ b/prompting/datasets/wiki.py
@@ -2,20 +2,15 @@
 import re
 import sys
 from functools import lru_cache
-from queue import Empty, Full, Queue
-from typing import ClassVar, Optional
+from typing import ClassVar

 import requests
 import wikipedia
 from bs4 import BeautifulSoup
 from loguru import logger
-from pydantic import ConfigDict, model_validator

 from shared.base import BaseDataset, Context

-# Create a queue called CACHED_ARTICLES to store wikipedia articles that have been fetched
-CACHED_ARTICLES: Queue[Context] = Queue(maxsize=300)
-

 # speed up page loading
 @lru_cache(maxsize=1000)
@@ -183,17 +178,13 @@ def get(
             internal_links=list(filter(lambda x: x not in exclude, page.sections)),
             external_links=most_relevant_links(page, num_links=self.max_links),
             tags=filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES),
-            source="Wikipedia",
+            source=page.url,
             extra={
                 "url": page.url,
                 "page_length": len(page.content.split()),
                 "section_length": section_length,
             },
         )
-        try:
-            CACHED_ARTICLES.put(context, block=False)
-        except Full:
-            logger.debug("Cache is full. Skipping article until cache is emptied.")
         return context

     def search(self, name, results=3) -> Context:
@@ -207,111 +198,3 @@ def random(self, pages=10) -> dict:
         if context := self.get(title):
             return context
     return None
-
-
-class DateContext(Context):
-    date: str = None
-
-    @classmethod
-    def from_context(cls, context: Context, date: str) -> "DateContext":
-        return cls(
-            **context.model_dump(),
-            date=date,
-        )
-
-
-class WikiDateDataset(BaseDataset):
-    name: ClassVar[str] = "wikipedia_date"
-    INCLUDE_HEADERS: tuple = ("Events", "Births", "Deaths")
-    MONTHS: tuple = (
-        "January",
-        "February",
-        "March",
-        "April",
-        "May",
-        "June",
-        "July",
-        "August",
-        "September",
-        "October",
-        "November",
-        "December",
-    )
-    EXCLUDE_CATEGORIES: tuple = ("articles", "wikipedia", "pages", "cs1")
-    seed: int | None = None
-    rng: Optional[random.Random] = None
-    model_config = ConfigDict(arbitrary_types_allowed=True)
-
-    @model_validator(mode="after")
-    def create_rng(self) -> "WikiDateDataset":
-        self.rng = random.Random(self.seed)
-        return self
-
-    def _extract_dates_and_sentences(self, text: str) -> tuple[str, str]:
-        # Regular expression to find dates in various formats
-        date_pattern = r"\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?(?:,)?\s+\d{4}\b|\b\d{1,2}\s+(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember))\s+\d{4}\b|\b\d{4}\b"
-
-        # Compile the regex pattern
-        date_regex = re.compile(date_pattern)
-
-        # Split text into sentences
-        sentences = re.split(r"(?<=[.!?])\s+", text)
-
-        # Iterate through each sentence
-        for sentence in sentences:
-            # Find all dates in the sentence
-            dates = date_regex.findall(sentence)
-            # If dates are found, add them to the result dictionary with the corresponding sentence
-            if dates:
-                for date in dates:
-                    # Return the first date found
-                    return (str(date), sentence.replace(str(date), "").strip())
-
-        # If no dates are found, search for dates in the form of "Month DD"
-        secondary_date_pattern = r"\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:st|nd|rd|th)?\b"
-        secondary_date_regex = re.compile(secondary_date_pattern)
-
-        for sentence in sentences:
-            # Find all dates in the sentence
-            dates = secondary_date_regex.findall(sentence)
-            # If dates are found, add them to the result dictionary with the corresponding sentence
-            if dates:
-                for date in dates:
-                    # Return the first date found
-                    return (str(date), sentence.replace(str(date), "").strip())
-
-        return None
-
-    def _random_date(self) -> DateContext:
-        for i in range(self.max_tries):
-            try:
-                context = CACHED_ARTICLES.get(block=False)
-                if not context:
-                    continue
-
-                date_sentence = self._extract_dates_and_sentences(context.content)
-
-                if date_sentence and all(date_sentence):
-                    content, date = date_sentence
-                    date_context = DateContext.from_context(context, date=date)
-                    date_context.content = content
-                    return date_context
-
-            except Empty:
-                logger.debug(f"Retry {i} Cache is empty. Skipping date until cache is filled.")
-                return None
-
-            except Exception as e:
-                logger.exception(f"Error fetching date: {e}")
-                continue
-
-    def get(
-        self,
-    ) -> dict:
-        raise NotImplementedError(f"Search is not implemented for {self.__class__.__name__}")
-
-    def search(self, name: str, results: int = 5) -> dict:
-        raise NotImplementedError(f"Search is not implemented for {self.__class__.__name__}")
-
-    def random(self) -> DateContext:
-        return self._random_date()
diff --git a/prompting/miner_availability/miner_availability.py b/prompting/miner_availability/miner_availability.py
index dfecac3f9..e1308d73b 100644
--- a/prompting/miner_availability/miner_availability.py
+++ b/prompting/miner_availability/miner_availability.py
@@ -79,7 +79,7 @@ async def run_step(self):
                     llm_model_availabilities=response["llm_model_availabilities"],
                 )
             except Exception:
-                logger.debug("Availability Response Invalid")
+                # logger.debug("Availability Response Invalid")
                 miner_availabilities.miners[uid] = MinerAvailability(
                     task_availabilities={task: True for task in task_config},
                     llm_model_availabilities={model: False for model in model_config},
diff --git a/prompting/rewards/scoring.py b/prompting/rewards/scoring.py
index c314b6efa..8c33775d7 100644
--- a/prompting/rewards/scoring.py
+++ b/prompting/rewards/scoring.py
@@ -106,6 +106,7 @@ async def run_step(self) -> RewardLoggingEvent:
                     step=scoring_config.step,
                     task_id=scoring_config.task_id,
                     task_dict=scoring_config.task.model_dump(),
+                    source=scoring_config.dataset_entry.source,
                 )
             )
             logger.info("Adding scores to rewards_and_uids")
diff --git a/shared/base.py b/shared/base.py
index 1d851aea3..91ae84525 100644
--- a/shared/base.py
+++ b/shared/base.py
@@ -79,7 +79,6 @@ def next(self, method: Literal["random", "search", "get"] = "random", **kwargs)
             logger.error(f"Failed to fetch context after {RETRIES} tries.")
             return None

-        context.source = self.__class__.__name__
         context.stats = {
             "fetch_time": timer.final_time,
             "num_tries": tries,
diff --git a/shared/logging.py b/shared/logging.py
index d04d68fb1..4b24735a3 100644
--- a/shared/logging.py
+++ b/shared/logging.py
@@ -181,6 +181,7 @@ class RewardLoggingEvent(BaseEvent):
     challenge: str | list[dict]
     task: str
     task_dict: dict
+    source: str | None = None

     model_config = ConfigDict(arbitrary_types_allowed=True)
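Review note: taken together, these hunks move `source` from a value stamped by `BaseDataset.next()` (the dataset class name) to a concrete per-entry URL that `scoring.py` forwards into `RewardLoggingEvent`. A minimal sketch of the resulting flow, using simplified stand-in models rather than the repo's actual classes (the field names mirror the diff; the URL, task name, and values are illustrative):

# Stand-ins for the repo's DatasetEntry subclass and logging event;
# only the `source` field reflects the change in this diff.
from pydantic import BaseModel


class DDGDatasetEntry(BaseModel):
    search_term: str
    website_url: str | None = None
    website_content: str | None = None
    source: str | None = None  # new: carries the concrete URL, not a class name


class RewardLoggingEvent(BaseModel):
    task: str
    source: str | None = None  # propagated from the dataset entry at scoring time


url = "https://example.com/article"  # hypothetical URL for illustration
entry = DDGDatasetEntry(
    search_term="bittensor",
    website_url=url,
    website_content="...",
    source=url,  # set where the URL is known, as in DDGDataset.next()
)

# Scoring no longer relies on BaseDataset.next() stamping the dataset class
# name; it reads the source straight off the entry.
event = RewardLoggingEvent(task="web_qa", source=entry.source)
assert event.source == url

One caveat worth flagging in review: `BaseDataset.next()` previously guaranteed a non-null `source` for every dataset, whereas after this change any dataset entry that does not set the field explicitly will log `source=None`.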