9 changes: 2 additions & 7 deletions neurons/miners/epistula_miner/web_retrieval.py
@@ -3,16 +3,11 @@

import numpy as np
import trafilatura
from duckduckgo_search.duckduckgo_search import DDGS
from openai import OpenAI

from prompting.base.duckduckgo_patch import PatchedDDGS
from shared import settings

# Import the patched DDGS and use that


# Import the patched DDGS and use that


async def fetch_url(url: str) -> str:
return trafilatura.fetch_url(url)
@@ -54,7 +49,7 @@ async def get_websites_with_similarity(
Returns:
List of dictionaries containing website URLs and their best matching chunks
"""
ddgs = PatchedDDGS(proxy=settings.shared_settings.PROXY_URL, verify=False)
ddgs = DDGS(proxy=settings.shared_settings.PROXY_URL, verify=False)
results = list(ddgs.text(query))
urls = [r["href"] for r in results][:n_results]

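For reference, a minimal sketch of the direct `DDGS` call pattern this file now uses in place of `PatchedDDGS`. The query string, proxy value, and result count below are illustrative; the import path and constructor arguments mirror the diff.

```python
# Sketch only: direct use of duckduckgo_search's DDGS, as the diff now does.
# Query, proxy (None here), and result count are illustrative placeholders.
from duckduckgo_search.duckduckgo_search import DDGS

ddgs = DDGS(proxy=None, verify=False)          # the diff passes settings.shared_settings.PROXY_URL
results = list(ddgs.text("open source llm"))   # list of dicts with "title", "href", "body"
urls = [r["href"] for r in results][:5]        # keep the first n_results URLs
print(urls)
```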
42 changes: 0 additions & 42 deletions prompting/base/duckduckgo_patch.py

This file was deleted.

80 changes: 60 additions & 20 deletions prompting/datasets/huggingface_github.py
@@ -1,4 +1,8 @@
import random
from typing import Any, ClassVar, Iterator

from datasets import load_dataset
from loguru import logger
from pydantic import ConfigDict, model_validator

from shared.base import BaseDataset, DatasetEntry
@@ -13,6 +17,8 @@
OUTPUT_LINES = 10
MAX_LINES = 500
RETRIES = 50 # Increased retry limit
DEFAULT_NUM_SHARDS = 1126
RANDOM_SKIP = 100


class HuggingFaceGithubDatasetEntry(DatasetEntry):
@@ -24,20 +30,45 @@ class HuggingFaceGithubDatasetEntry(DatasetEntry):

class HuggingFaceGithubDataset(BaseDataset):
language: str = "python"
dataset: any = None
iterator: any = None

base_dataset: ClassVar[Any] = None
num_shards: ClassVar[int] = DEFAULT_NUM_SHARDS

# Instance-level iterator over the current shard.
current_shard_iterator: Iterator[Any] | None = None

model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
def load_dataset(self) -> "HuggingFaceGithubDataset":
self.dataset = load_dataset(
"macrocosm-os/code-parrot-github-code", streaming=True, split="train", trust_remote_code=True
)
self.iterator = iter(self.dataset.filter(self._filter_function))
if HuggingFaceGithubDataset.base_dataset is None:
# Load the full streaming dataset.
HuggingFaceGithubDataset.base_dataset = load_dataset(
"macrocosm-os/code-parrot-github-code", streaming=True, split="train", trust_remote_code=True
)
# Try to determine the number of shards from the underlying file list.
files = HuggingFaceGithubDataset.base_dataset._ex_iterable.kwargs.get("files")
if files is not None:
HuggingFaceGithubDataset.num_shards = len(files)
else:
logger.warning("Cannot get number of shards")
HuggingFaceGithubDataset.num_shards = DEFAULT_NUM_SHARDS

# Select a random shard to begin iterating.
self._reset_shard()
return self

def _filter_function(self, example):
def _reset_shard(self) -> None:
"""Choose a new random shard and creates a fresh iterator over its filtered records."""
random_shard_index = random.randrange(HuggingFaceGithubDataset.num_shards)
shard_dataset = HuggingFaceGithubDataset.base_dataset.shard(
num_shards=HuggingFaceGithubDataset.num_shards, index=random_shard_index
)
# Apply filtering to the selected shard.
shard_dataset = shard_dataset.filter(self._filter_function)
HuggingFaceGithubDataset.current_shard_iterator = iter(shard_dataset)

def _filter_function(self, example: dict) -> bool:
return (
any(example["path"].endswith(ending) for ending in ALLOWED_FILE_ENDINGS[self.language])
and MIN_FILE_SIZE <= int(example["size"]) <= MAX_FILE_SIZE
@@ -51,25 +82,34 @@ def _process_entry(self, entry: dict) -> HuggingFaceGithubDatasetEntry:
github_url=url, file_path=entry["path"], file_content=file_content, source=url
)

def get(self) -> HuggingFaceGithubDatasetEntry:
return self.next()
def _try_sample(self) -> dict[str, str]:
"""Return the next record from the current shard.

When the shard is exhausted, it automatically resets to a new random shard.
"""
try:
entry = next(HuggingFaceGithubDataset.current_shard_iterator)
except StopIteration:
self._reset_shard()
entry = next(HuggingFaceGithubDataset.current_shard_iterator)
return entry

def next(self) -> HuggingFaceGithubDatasetEntry:
for _ in range(RETRIES):
for _ in range(random.randint(0, RANDOM_SKIP)):
self._try_sample()

while True:
try:
entry = next(self.iterator)
return self._process_entry(entry) # Throws failed to get a valid file after multiple attempts
except StopIteration:
self.reset()
raise Exception("Failed to get a valid file after multiple attempts")
entry = self._try_sample()
return self._process_entry(entry)
except BaseException as e:
logger.warning(f"Failed to sample from shard: {e}")

def random(self) -> HuggingFaceGithubDatasetEntry:
# Note: The dataset is streamed, so true random access is not possible.
# This method will just return the next item, similar to `next()`.
def get(self) -> HuggingFaceGithubDatasetEntry:
return self.next()

def reset(self):
self.iterator = iter(self.dataset.filter(self._filter_function))
def random(self) -> HuggingFaceGithubDatasetEntry:
return self.next()


if __name__ == "__main__":
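A condensed sketch of the shard-hopping pattern the class above implements: pick one random shard of the streaming dataset, iterate its filtered records, and reset to another shard on `StopIteration`. The dataset name, fallback shard count, and streaming `shard()` call mirror the diff; the filter is a simplified stand-in for `_filter_function`.

```python
# Sketch of the random-shard iteration used by HuggingFaceGithubDataset.
# The lambda filter below is a simplified stand-in for _filter_function.
import random
from datasets import load_dataset

stream = load_dataset(
    "macrocosm-os/code-parrot-github-code", streaming=True, split="train", trust_remote_code=True
)
num_shards = 1126  # DEFAULT_NUM_SHARDS fallback when the file list cannot be read

def new_shard_iterator():
    shard = stream.shard(num_shards=num_shards, index=random.randrange(num_shards))
    return iter(shard.filter(lambda ex: ex["path"].endswith(".py")))

iterator = new_shard_iterator()
try:
    entry = next(iterator)
except StopIteration:  # shard exhausted: hop to a fresh random shard
    iterator = new_shard_iterator()
    entry = next(iterator)
print(entry["path"])
```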
21 changes: 12 additions & 9 deletions prompting/datasets/random_website.py
@@ -3,9 +3,9 @@
from typing import Optional

import trafilatura
from duckduckgo_search.duckduckgo_search import DDGS
from loguru import logger

from prompting.base.duckduckgo_patch import PatchedDDGS
from prompting.datasets.utils import ENGLISH_WORDS
from shared import settings
from shared.base import BaseDataset, Context, DatasetEntry
@@ -25,7 +25,7 @@ class DDGDataset(BaseDataset):
english_words: list[str] = None

def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[list[dict[str, str]]]]:
ddg = PatchedDDGS(proxy=settings.shared_settings.PROXY_URL, verify=False)
ddg = DDGS(proxy=settings.shared_settings.PROXY_URL, verify=False)
exception: BaseException | None = None
for _ in range(retries):
random_words = " ".join(random.sample(ENGLISH_WORDS, 3))
@@ -40,13 +40,16 @@ def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[

@staticmethod
@lru_cache(maxsize=1000)
def extract_website_content(url: str) -> Optional[str]:
try:
website = trafilatura.fetch_url(url)
extracted = trafilatura.extract(website)
return extracted[:MAX_CHARS] if extracted else None
except Exception as ex:
logger.debug(f"Failed to extract content from website {url}: {ex}")
def extract_website_content(url: str, retries: int = 3) -> Optional[str]:
exception: Exception | None = None
for _ in range(retries):
try:
website = trafilatura.fetch_url(url)
extracted = trafilatura.extract(website)
return extracted[:MAX_CHARS] if extracted else None
except Exception as ex:
exception = ex
logger.debug(f"Failed to extract content from website {url} after {retries} retries: {exception}")

def next(self) -> Optional[DDGDatasetEntry]:
search_term, results = self.search_random_term(retries=5)
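One behavioral note, illustrated below: because `extract_website_content` is wrapped in `functools.lru_cache`, repeat calls with the same arguments return the memoized result (including `None`) without refetching the page. A toy sketch with placeholder URLs and a stand-in body:

```python
# Toy sketch of the lru_cache behavior around extract_website_content:
# repeat calls with the same URL return the cached value without refetching.
from functools import lru_cache

@lru_cache(maxsize=1000)
def extract_stub(url: str, retries: int = 3) -> str | None:
    print(f"fetching {url}")             # printed only on a cache miss
    return f"<extracted text of {url}>"  # stand-in for trafilatura fetch/extract

extract_stub("https://example.com")      # miss: the body runs
extract_stub("https://example.com")      # hit: served from the cache
```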
2 changes: 1 addition & 1 deletion prompting/tasks/web_retrieval.py
@@ -40,7 +40,7 @@ class WebRetrievalTask(BaseTextTask):
augmentation_system_prompt: ClassVar[str] = ""
query_system_prompt: ClassVar[Optional[str]] = QUERY_SYSTEM_PROMPT
target_results: int = Field(default_factory=lambda: random.randint(1, 10))
timeout: int = Field(default_factory=lambda: random.randint(5, 20))
timeout: int = Field(default_factory=lambda: random.randint(5, 15))

async def make_query(self, dataset_entry: DDGDatasetEntry) -> str:
self.query = await self.generate_query(
6 changes: 3 additions & 3 deletions shared/epistula.py
@@ -124,7 +124,7 @@ async def query_miners(
for uid in uids:
try:
timeout_connect = 10
timeout_postprocess = 5
timeout_postprocess = 1
response = asyncio.wait_for(
asyncio.create_task(
make_openai_query(
@@ -136,8 +136,8 @@
timeout_connect=timeout_connect,
)
),
# Give additional time for connect and result post-processings.
timeout=timeout_seconds + timeout_connect + timeout_postprocess,
# Give additional time for result post-processing.
timeout=timeout_seconds + timeout_postprocess,
)
except asyncio.TimeoutError:
logger.error(f"Timeout exceeded while querying miner {uid}")
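A minimal sketch of the overall-timeout wrapper shown above: `asyncio.wait_for` around the per-miner task, with a small post-processing allowance added to the generation budget. The coroutine and numbers are placeholders standing in for `make_openai_query` and the real settings.

```python
# Sketch of the asyncio.wait_for timeout pattern used in query_miners.
# fake_miner_query stands in for make_openai_query; the budgets are illustrative.
import asyncio

async def fake_miner_query() -> str:
    await asyncio.sleep(0.1)  # pretend streaming/generation work
    return "response"

async def main() -> None:
    timeout_seconds = 15      # generation budget passed to the miner query
    timeout_postprocess = 1   # extra head-room for result post-processing
    try:
        result = await asyncio.wait_for(
            asyncio.create_task(fake_miner_query()),
            timeout=timeout_seconds + timeout_postprocess,
        )
        print(result)
    except asyncio.TimeoutError:
        print("Timeout exceeded while querying miner")

asyncio.run(main())
```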