Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 19 additions & 40 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions prompting/rewards/web_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@
TOP_DOMAINS = set(top_domains_df["Domain"].str.lower().values)

# Load past websites
if os.path.exists(PAST_WEBSITES_FILE):
if os.path.exists(PAST_WEBSITES_FILE) and os.path.getsize(PAST_WEBSITES_FILE) > 0:
past_websites_df = pd.read_csv(PAST_WEBSITES_FILE)
past_websites = defaultdict(list)
# Group by uid and take only the last N_PAST_URLS entries
for uid, group in past_websites_df.groupby("uid"):
past_websites[uid] = group["domain"].tolist()[-N_PAST_URLS:]
else:
logger.warning(f"Past websites file {PAST_WEBSITES_FILE} does not exist, creating new dictionary")
logger.warning(f"Past websites file {PAST_WEBSITES_FILE} does not exist or empty, creating new dictionary")
past_websites = defaultdict(list)
except Exception as e:
logger.exception(f"Failed to load domains data: {e}")
Expand Down
2 changes: 1 addition & 1 deletion prompting/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class InferenceTask(BaseTextTask):
reference: str | None = None
system_prompt: str | None = None
llm_model: ModelConfig | None = None
llm_model_id: ModelConfig | None = random.choice(ModelZoo.models_configs).llm_model_id
llm_model_id: str | None = random.choice(ModelZoo.models_configs).llm_model_id
seed: int = Field(default_factory=lambda: random.randint(0, 1_000_000), allow_mutation=False)
sampling_params: dict[str, float] = shared_settings.SAMPLING_PARAMS.copy()
messages: list[dict] | None = None
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ isort = "^5.13.2"
tiktoken = "^0.8.0"
pillow = "^11.0.0"
torch = { version = "2.5.1", optional = true }
# TODO: Switch to original repo when this PR to fix setup gets merged: https://github.com/casper-hansen/AutoAWQ/pull/715
autoawq = { git = "https://github.com/jiqing-feng/AutoAWQ.git", rev = "ae782a99df2f72a2c28764452844cb2d65bd8ffc", optional = true }
transformers = { version = "<=4.47.1", optional = true }
torchvision = { version = ">=0.20.1", optional = true }
accelerate = { version = ">=1.1.1", optional = true }
autoawq = { version = "0.2.0", optional = true }
angle-emb = { version = "0.4.3", optional = true }
numpy = { version = ">=2.0.1", optional = true }
rouge = { version = ">=1.0.1", optional = true }
Expand Down
6 changes: 0 additions & 6 deletions scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ poetry config virtualenvs.in-project true

# Install the project dependencies
poetry install --extras "validator"

# Build AutoAWQ==0.2.8 from source
if [ -d AutoAWQ ]; then rm -rf AutoAWQ; fi
git clone https://github.com/casper-hansen/AutoAWQ.git
cd AutoAWQ && git checkout 16335d087dd4f9cdc8933dd7a5681e4bf88311b6 && poetry run pip install -e . && cd ..

poetry run pip install flash-attn --no-build-isolation

# Check if jq is installed and install it if not
Expand Down
4 changes: 2 additions & 2 deletions shared/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def __hash__(self) -> int:


class ChatEntry(DatasetEntry):
messages: list[dict]
messages: list[dict[str, str]]
organic: bool
source: str | None = None
query: str | None = None
query: dict[str, str] | None = None

@model_validator(mode="after")
def check_query(self) -> "ChatEntry":
Expand Down
2 changes: 1 addition & 1 deletion shared/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class SharedSettings(BaseSettings):
)
TEST_MINER_IDS: list[int] = Field([], env="TEST_MINER_IDS")
SUBTENSOR_NETWORK: Optional[str] = Field(None, env="SUBTENSOR_NETWORK")
MAX_ALLOWED_VRAM_GB: int = Field(62, env="MAX_ALLOWED_VRAM_GB")
MAX_ALLOWED_VRAM_GB: float = Field(62, env="MAX_ALLOWED_VRAM_GB")
LLM_MAX_MODEL_LEN: int = Field(4096, env="LLM_MAX_MODEL_LEN")
PROXY_URL: Optional[str] = Field(None, env="PROXY_URL")
LLM_MODEL: str = Field("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", env="LLM_MODEL")
Expand Down