Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 317 additions & 22 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions prompting/base/duckduckgo_patch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from threading import Event
from typing import cast

import httpx
Expand All @@ -13,6 +14,7 @@ def __init__(self, *args, **kwargs):
timeout=kwargs.get("timeout", 10),
verify=kwargs.get("verify", True),
)
self._exception_event = Event()

def _get_url(
self: DDGS,
Expand Down
8 changes: 5 additions & 3 deletions prompting/datasets/random_website.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class DDGDatasetEntry(DatasetEntry):
search_term: str
website_url: str = None
website_content: str = None
query: str | None = None


class DDGDataset(BaseDataset):
Expand All @@ -31,7 +32,8 @@ def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[
if results:
return random_words, results
except Exception as ex:
logger.error(f"Failed to get search results from DuckDuckGo: {ex}")
logger.debug(f"Failed to get search results from DuckDuckGo: {ex}")
logger.warning(f"Failed to get search results from DuckDuckGo after {retries} tries")
return None, None

@staticmethod
Expand All @@ -41,7 +43,7 @@ def extract_website_content(url: str) -> Optional[str]:
extracted = trafilatura.extract(website)
return extracted[:MAX_CHARS] if extracted else None
except Exception as ex:
logger.error(f"Failed to extract content from website {url}: {ex}")
logger.debug(f"Failed to extract content from website {url}: {ex}")

def next(self) -> Optional[DDGDatasetEntry]:
search_term, results = self.search_random_term(retries=5)
Expand All @@ -50,7 +52,7 @@ def next(self) -> Optional[DDGDatasetEntry]:
website_url = results[0]["href"]
website_content = self.extract_website_content(website_url)
if not website_content or len(website_content) == 0:
logger.error(f"Failed to extract content from website {website_url}")
logger.debug(f"Failed to extract content from website {website_url}")
return None

return DDGDatasetEntry(search_term=search_term, website_url=website_url, website_content=website_content)
Expand Down
2 changes: 1 addition & 1 deletion prompting/rewards/web_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs
"""Score response website content and URL based on the similarity to the search term and reference content."""
rewards: list[float] = []
timings: list[float] = []
dataset_entry = DDGDatasetEntry.model_validate_json(json.loads(reference))
for completion in response_event.completions:
timer_start = time.perf_counter()

Expand All @@ -59,7 +60,6 @@ def reward(self, reference: str, response_event: DendriteResponseEvent, **kwargs
rewards.append(0)
continue

dataset_entry = DDGDatasetEntry.model_validate_json(json.loads(reference))
query = dataset_entry.query
reference_content = dataset_entry.website_content

Expand Down
2 changes: 1 addition & 1 deletion prompting/tasks/web_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def make_query(self, dataset_entry: DDGDatasetEntry) -> str:
return self.query

def make_reference(self, dataset_entry: DDGDatasetEntry) -> str:
    """Build the serialized reference for this task.

    Attaches the task's query to the dataset entry's data and stores the
    JSON-encoded result on ``self.reference``.

    Args:
        dataset_entry: The DDG dataset entry (pydantic model) backing this task.

    Returns:
        The JSON string stored in ``self.reference``.
    """
    # BUG FIX: model_dump_json() returns a str, which does not support item
    # assignment ("ref_dict['query'] = ..." would raise TypeError).
    # Use model_dump() to get a plain dict, inject the query, then serialize.
    ref_dict = dataset_entry.model_dump()
    ref_dict["query"] = self.query
    self.reference = json.dumps(ref_dict)
    return self.reference
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,12 @@ numpy = { version = ">=2.0.1", optional = true }
rouge = { version = ">=1.0.1", optional = true }
bs4 = { version = ">=0.0.2", optional = true }
wikipedia = { version = ">=1.4.0", optional = true }
duckduckgo-search = { version = ">=6.3.7", optional = true }
duckduckgo-search = "^7.2.1"
huggingface-hub = { version = ">=0.25.2", optional = true }
pandas = { version = ">=2.2.1", optional = true }
trafilatura = { version = ">=1.12.1", optional = true }
datasets = { version = ">=3.1.0", optional = true }
primp = { version = "==0.8.1", optional = true }
primp = "^0.10.0"
nltk = { version = ">=3.8.1", optional = true }

[tool.poetry.extras]
Expand Down
Loading