12 changes: 4 additions & 8 deletions .github/workflows/python-package.yml
@@ -41,18 +41,14 @@ jobs:
           poetry run pip list

       # Style/format checks.
-      - name: Run Black (code formatter)
-        run: |
-          poetry run black --check --diff .
-
-      - name: Run isort (import sorting)
-        run: |
-          poetry run isort --check-only --diff --profile black .
-
       - name: Run Ruff (linter)
         run: |
           poetry run ruff check --diff .

+      - name: Run Ruff (formatter)
+        run: |
+          poetry run ruff format --check --diff .
+
       - name: Test with pytest
         run: |
           # run tests in tests/ dir and only fail if there are failures or errors
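Note: the two surviving steps can be reproduced locally with the same commands the workflow runs (assuming the project's Poetry environment is installed as in CI):

    poetry run ruff check --diff .
    poetry run ruff format --check --diff .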
16 changes: 1 addition & 15 deletions .pre-commit-config.yaml
@@ -20,18 +20,4 @@ repos:
     hooks:
       - id: ruff
         args: [--fix]
-
-  - repo: https://github.com/psf/black
-    rev: 23.7.0
-    hooks:
-      - id: black
-        name: black (code formatter)
-        language_version: python3.10
-        additional_dependencies: ["black[jupyter]"]
-
-  - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
-    hooks:
-      - id: isort
-        name: isort (import sorting)
-        args: ["--profile", "black"]
+      - id: ruff-format
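With the ruff-format hook standing in for the Black and isort hooks, the suite can still be exercised through pre-commit's standard entry points (a sketch, assuming pre-commit is installed):

    pre-commit install          # register the git hook once
    pre-commit run --all-files  # run ruff and ruff-format across the repo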
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -81,7 +81,8 @@ target-version = "py310"
 # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
 # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
 # McCabe complexity (`C901`) by default.
-select = ["C9", "E4", "E7", "E9", "F"]
+# Also enable isort (I) for import sorting
+select = ["C9", "E4", "E7", "E9", "F", "I"]
 # Ignore "Module top level import" due to the settings initialization.
 ignore = ["E402"]
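For illustration, the new I rules make ruff check flag unsorted imports (rule I001) and ruff check --fix reorder them the way the old isort black profile did; a minimal before/after sketch with hypothetical imports:

    # before: I001, imports not grouped or sorted
    import csv
    from shared.base import BaseDataset
    import argparse

    # after ruff check --fix: stdlib first, then first-party, each sorted
    import argparse
    import csv

    from shared.base import BaseDataset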
1 change: 1 addition & 0 deletions scripts/test_api_empty_responses.py
@@ -7,6 +7,7 @@
 Additional requirements:
     pip install plotly kaleido
 """
+
 import argparse
 import asyncio
 import csv
1 change: 1 addition & 0 deletions scripts/test_api_load.py
@@ -7,6 +7,7 @@
 Additional requirements:
     pip install plotly kaleido
 """
+
 import argparse
 import asyncio
 import csv
3 changes: 1 addition & 2 deletions shared/base.py
@@ -53,8 +53,7 @@ class BaseDataset(ABC, BaseModel):
     max_tries: int = 10

     @abstractmethod
-    def random(self) -> DatasetEntry:
-        ...
+    def random(self) -> DatasetEntry: ...

     def get(self) -> DatasetEntry:
         return self.next()
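The collapsed signature is Ruff's formatter style (mirroring Black's current stable style) for bodies consisting only of an ellipsis; any similar stub would be rewritten the same way, e.g. (hypothetical sibling method):

    @abstractmethod
    def next(self) -> DatasetEntry: ...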
10 changes: 7 additions & 3 deletions tests/prompting/llms/vllm_llm.py
@@ -9,7 +9,10 @@
 def _fake_tokenizer():
     tok = MagicMock()
     tok.apply_chat_template.side_effect = (
-        lambda conversation, tokenize, add_generation_prompt, continue_final_message: f"TEMPLATE::{conversation[-1]['role']}::{conversation[-1]['content']}"
+        lambda conversation,
+        tokenize,
+        add_generation_prompt,
+        continue_final_message: f"TEMPLATE::{conversation[-1]['role']}::{conversation[-1]['content']}"
     )
     tok.decode.side_effect = lambda ids: "<s>" if ids == [0] else f"tok{ids[0]}"
     return tok
@@ -41,8 +44,9 @@ async def test_generate_logits(monkeypatch, messages, continue_last):
     tokenizer_stub = _fake_tokenizer()
     llm_stub = _fake_llm(fake_logprobs)

-    with patch("prompting.llms.vllm_llm.LLM", return_value=llm_stub), patch(
-        "prompting.llms.vllm_llm.SamplingParams", lambda **kw: kw
+    with (
+        patch("prompting.llms.vllm_llm.LLM", return_value=llm_stub),
+        patch("prompting.llms.vllm_llm.SamplingParams", lambda **kw: kw),
     ):
         model = ReproducibleVLLM(model_id="mock-model")
         # Swap tokenizer (LLM stub has none).
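The rewritten with block uses parenthesized context managers, supported on Python 3.10+ (the project's target-version in pyproject.toml) and preferred by Ruff's formatter over splitting a single patch call across lines. A self-contained sketch of the construct:

    from contextlib import nullcontext

    # Each manager sits on its own line; the formatter keeps the trailing comma.
    with (
        nullcontext("first") as a,
        nullcontext("second") as b,
    ):
        print(a, b)  # -> first second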