Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions .github/workflows/python-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,18 @@ jobs:
run: |
poetry run poe test_integration

- name: Smoke Test
if: steps.changes.outputs.python == 'true'
run: |
poetry run poe test_smoke

- uses: actions/upload-artifact@v4
if: always()
with:
name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }}
path: tests/fixtures/*/output

- name: E2E Test
if: steps.changes.outputs.python == 'true'
run: |
./scripts/e2e-test.sh
# - name: Smoke Test
# if: steps.changes.outputs.python == 'true'
# run: |
# poetry run poe test_smoke

# - uses: actions/upload-artifact@v4
# if: always()
# with:
# name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }}
# path: tests/fixtures/*/output

# - name: E2E Test
# if: steps.changes.outputs.python == 'true'
# run: |
# ./scripts/e2e-test.sh
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240726181256417715.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "add encoding-model to entity/claim extraction config"
}
14 changes: 8 additions & 6 deletions docsite/posts/config/env_vars.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,12 @@ These settings control the data input used by the pipeline. Any settings with a

## Data Chunking

| Parameter | Description | Type | Required or Optional | Default |
| --------------------------- | ------------------------------------------------------------------------------------------- | ----- | -------------------- | ------- |
| `GRAPHRAG_CHUNK_SIZE` | The chunk size in tokens for text-chunk analysis windows. | `str` | optional | 1200 |
| `GRAPHRAG_CHUNK_OVERLAP` | The chunk overlap in tokens for text-chunk analysis windows. | `str` | optional | 100 |
| `GRAPHRAG_CHUNK_BY_COLUMNS` | A comma-separated list of document attributes to groupby when performing TextUnit chunking. | `str` | optional | `id` |
| `GRAPHRAG_CHUNK_ENCODING_MODEL` | The encoding model to use for chunking. | `str` | optional | `None` |
| Parameter | Description | Type | Required or Optional | Default |
| ------------------------------- | ------------------------------------------------------------------------------------------- | ----- | -------------------- | ----------------------------- |
| `GRAPHRAG_CHUNK_SIZE` | The chunk size in tokens for text-chunk analysis windows. | `str` | optional | 1200 |
| `GRAPHRAG_CHUNK_OVERLAP` | The chunk overlap in tokens for text-chunk analysis windows. | `str` | optional | 100 |
| `GRAPHRAG_CHUNK_BY_COLUMNS` | A comma-separated list of document attributes to groupby when performing TextUnit chunking. | `str` | optional | `id` |
| `GRAPHRAG_CHUNK_ENCODING_MODEL` | The encoding model to use for chunking. | `str` | optional | The top-level encoding model. |

## Prompting Overrides

Expand All @@ -146,12 +146,14 @@ These settings control the data input used by the pipeline. Any settings with a
| `GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE` | The path (relative to the root) of an entity extraction prompt template text file. | `str` | optional | `None` |
| `GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS` | The maximum number of redrives (gleanings) to invoke when extracting entities in a loop. | `int` | optional | 1 |
| `GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES` | A comma-separated list of entity types to extract. | `str` | optional | `organization,person,event,geo` |
| `GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL` | The encoding model to use for entity extraction. | `str` | optional | The top-level encoding model. |
| `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE` | The path (relative to the root) of an description summarization prompt template text file. | `str` | optional | `None` |
| `GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH` | The maximum number of tokens to generate per description summarization. | `int` | optional | 500 |
| `GRAPHRAG_CLAIM_EXTRACTION_ENABLED` | Whether claim extraction is enabled for this pipeline. | `bool` | optional | `False` |
| `GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION` | The claim_description prompting argument to utilize. | `string` | optional | "Any claims or facts that could be relevant to threat analysis." |
| `GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE` | The claim extraction prompt to utilize. | `string` | optional | `None` |
| `GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS` | The maximum number of redrives (gleanings) to invoke when extracting claims in a loop. | `int` | optional | 1 |
| `GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL` | The encoding model to use for claim extraction. | `str` | optional | The top-level encoding model |
| `GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE` | The community reports extraction prompt to utilize. | `string` | optional | `None` |
| `GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH` | The maximum number of tokens to generate per community reports. | `int` | optional | 1500 |

Expand Down
2 changes: 2 additions & 0 deletions docsite/posts/config/json_yaml.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ This is the base LLM configuration section. Other steps may override this config
- `prompt` **str** - The prompt file to use.
- `entity_types` **list[str]** - The entity types to identify.
- `max_gleanings` **int** - The maximum number of gleaning cycles to use.
- `encoding_model` **str** - The text encoding model to use. By default, this will use the top-level encoding model.
- `strategy` **dict** - Fully override the entity extraction strategy.

## summarize_descriptions
Expand All @@ -169,6 +170,7 @@ This is the base LLM configuration section. Other steps may override this config
- `prompt` **str** - The prompt file to use.
- `description` **str** - Describes the types of claims we want to extract.
- `max_gleanings` **int** - The maximum number of gleaning cycles to use.
- `encoding_model` **str** - The text encoding model to use. By default, this will use the top-level encoding model.
- `strategy` **dict** - Fully override the claim extraction strategy.

## community_reports
Expand Down
3 changes: 3 additions & 0 deletions graphrag/config/create_graphrag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def hydrate_parallelization_params(
size=reader.int("size") or defs.CHUNK_SIZE,
overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
group_by_columns=group_by_columns,
encoding_model=reader.str(Fragment.encoding_model),
)
with (
reader.envvar_prefix(Section.snapshot),
Expand Down Expand Up @@ -428,6 +429,7 @@ def hydrate_parallelization_params(
or defs.ENTITY_EXTRACTION_ENTITY_TYPES,
max_gleanings=max_gleanings,
prompt=reader.str("prompt", Fragment.prompt_file),
encoding_model=reader.str(Fragment.encoding_model),
)

claim_extraction_config = values.get("claim_extraction") or {}
Expand All @@ -449,6 +451,7 @@ def hydrate_parallelization_params(
description=reader.str("description") or defs.CLAIM_DESCRIPTION,
prompt=reader.str("prompt", Fragment.prompt_file),
max_gleanings=max_gleanings,
encoding_model=reader.str(Fragment.encoding_model),
)

community_report_config = values.get("community_reports") or {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ class ClaimExtractionConfigInput(LLMConfigInput):
description: NotRequired[str | None]
max_gleanings: NotRequired[int | str | None]
strategy: NotRequired[dict | None]
encoding_model: NotRequired[str | None]
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ class EntityExtractionConfigInput(LLMConfigInput):
entity_types: NotRequired[list[str] | str | None]
max_gleanings: NotRequired[int | str | None]
strategy: NotRequired[dict | None]
encoding_model: NotRequired[str | None]
6 changes: 5 additions & 1 deletion graphrag/config/models/claim_extraction_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ class ClaimExtractionConfig(LLMConfig):
strategy: dict | None = Field(
description="The override strategy to use.", default=None
)
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
)

def resolved_strategy(self, root_dir: str) -> dict:
def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
"""Get the resolved claim extraction strategy."""
from graphrag.index.verbs.covariates.extract_covariates import (
ExtractClaimsStrategyType,
Expand All @@ -50,4 +53,5 @@ def resolved_strategy(self, root_dir: str) -> dict:
else None,
"claim_description": self.description,
"max_gleanings": self.max_gleanings,
"encoding_name": self.encoding_model or encoding_model,
}
5 changes: 4 additions & 1 deletion graphrag/config/models/entity_extraction_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class EntityExtractionConfig(LLMConfig):
strategy: dict | None = Field(
description="Override the default entity extraction strategy", default=None
)
encoding_model: str | None = Field(
default=None, description="The encoding model to use."
)

def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
"""Get the resolved entity extraction strategy."""
Expand All @@ -45,6 +48,6 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
else None,
"max_gleanings": self.max_gleanings,
# It's prechunked in create_base_text_units
"encoding_name": encoding_model,
"encoding_name": self.encoding_model or encoding_model,
"prechunked": True,
}
2 changes: 1 addition & 1 deletion graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _covariate_workflows(
"claim_extract": {
**settings.claim_extraction.parallelization.model_dump(),
"strategy": settings.claim_extraction.resolved_strategy(
settings.root_dir
settings.root_dir, settings.encoding_model
),
},
},
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/config/test_default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123",
"GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000",
"GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-a.txt",
"GRAPHRAG_CLAIM_EXTRACTION_ENCODING_MODEL": "encoding_a",
"GRAPHRAG_COMMUNITY_REPORTS_MAX_LENGTH": "23456",
"GRAPHRAG_COMMUNITY_REPORTS_PROMPT_FILE": "tests/unit/config/prompt-b.txt",
"GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17",
Expand All @@ -115,6 +116,7 @@
"GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant",
"GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112",
"GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/config/prompt-c.txt",
"GRAPHRAG_ENTITY_EXTRACTION_ENCODING_MODEL": "encoding_b",
"GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir",
"GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs",
"GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn",
Expand Down Expand Up @@ -543,6 +545,7 @@ def test_create_parameters_from_env_vars(self) -> None:
assert parameters.claim_extraction.description == "test 123"
assert parameters.claim_extraction.max_gleanings == 5000
assert parameters.claim_extraction.prompt == "tests/unit/config/prompt-a.txt"
assert parameters.claim_extraction.encoding_model == "encoding_a"
assert parameters.cluster_graph.max_cluster_size == 123
assert parameters.community_reports.max_length == 23456
assert parameters.community_reports.prompt == "tests/unit/config/prompt-b.txt"
Expand Down Expand Up @@ -572,6 +575,7 @@ def test_create_parameters_from_env_vars(self) -> None:
assert parameters.entity_extraction.llm.api_base == "http://some/base"
assert parameters.entity_extraction.max_gleanings == 112
assert parameters.entity_extraction.prompt == "tests/unit/config/prompt-c.txt"
assert parameters.entity_extraction.encoding_model == "encoding_b"
assert parameters.input.storage_account_blob_url == "input_account_blob_url"
assert parameters.input.base_dir == "/some/input/dir"
assert parameters.input.connection_string == "input_cs"
Expand Down Expand Up @@ -910,7 +914,7 @@ def test_prompt_file_reading(self):
assert strategy["extraction_prompt"] == "Hello, World! A"
assert strategy["encoding_name"] == "abc123"

strategy = config.claim_extraction.resolved_strategy(".")
strategy = config.claim_extraction.resolved_strategy(".", "encoding_b")
assert strategy["extraction_prompt"] == "Hello, World! B"

strategy = config.community_reports.resolved_strategy(".")
Expand Down