38 changes: 26 additions & 12 deletions README.md
@@ -56,7 +56,7 @@ Furthermore, GraphGen incorporates multi-hop neighborhood sampling to capture co

## 📌 Latest Updates

- **2025.07.31**: We have added Google, Bing, Wikipedia, and UniProt as search back-ends, perfect for closing data gaps.
- **2025.07.31**: We have added Google, Bing, Wikipedia, and UniProt as search back-ends.
- **2025.04.21**: We have released the initial version of GraphGen.

## 🚀 Quick Start
@@ -136,18 +136,31 @@ For any questions, please check [FAQ](https://github.com/open-sciencelab/GraphGe
TRAINEE_BASE_URL=your_base_url_for_trainee_model
TRAINEE_API_KEY=your_api_key_for_trainee_model
```
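For orientation only, a minimal sketch of reading these variables from the environment in Python; the actual loading logic inside GraphGen may differ (e.g. it may go through `python-dotenv`):

```python
# Illustrative only: read the trainee-model credentials from the environment.
# GraphGen's own loading code is not shown here and may work differently.
import os

trainee_base_url = os.environ["TRAINEE_BASE_URL"]
trainee_api_key = os.environ["TRAINEE_API_KEY"]
```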
2. (Optional) If you want to modify the default generated configuration, you can edit the content of the configs/graphgen_config.yaml file.
2. (Optional) Customize generation parameters in `graphgen/configs/` folder.

Edit the corresponding YAML file, e.g.:

```yaml
# configs/graphgen_config.yaml
# Example configuration
data_type: "raw"
input_file: "resources/examples/raw_demo.jsonl"
# more configurations...
# configs/cot_config.yaml
input_data_type: raw
input_file: resources/input_examples/raw_demo.jsonl
output_data_type: cot
tokenizer: cl100k_base
# additional settings...
```
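As a rough illustration (not part of the repo), such a config can be loaded with PyYAML before a run; the file path and key names mirror the example above and may differ in your checkout:

```python
# Sketch: load and inspect a GraphGen YAML config with PyYAML.
# The path and keys follow the example above and are assumptions,
# not a guaranteed part of the GraphGen API.
import yaml

with open("graphgen/configs/cot_config.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["input_file"])   # e.g. resources/input_examples/raw_demo.jsonl
print(cfg["tokenizer"])    # e.g. cl100k_base
```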
3. Run the generation script
```bash
bash scripts/generate.sh
```

3. Generate data

Pick the desired format and run the matching script:

| Format | Script to run | Notes |
| ------------ | ---------------------------------------------- |-------------------------------------------------------------------|
| `cot` | `bash scripts/generate/generate_cot.sh` | Chain-of-Thought Q\&A pairs |
| `atomic` | `bash scripts/generate/generate_atomic.sh` | Atomic Q\&A pairs covering basic knowledge |
| `aggregated` | `bash scripts/generate/generate_aggregated.sh` | Aggregated Q\&A pairs incorporating complex, integrated knowledge |
| `multi-hop` | `bash scripts/generate/generate_multihop.sh` | Multi-hop reasoning Q\&A pairs |
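Once a script finishes, step 4 below shows where the data lands; a rough, illustrative way to preview it (the output directory follows step 4, while the JSONL layout and file extension are assumptions):

```python
# Illustrative sketch (not part of the repo): list generated files and
# preview the keys of the first record in each JSONL file.
import json
from pathlib import Path

out_dir = Path("cache/data/graphgen")
for path in sorted(out_dir.rglob("*.jsonl")):
    with open(path, encoding="utf-8") as f:
        first_line = f.readline().strip()
    if first_line:
        record = json.loads(first_line)
        print(path.name, "->", list(record.keys()))
```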


4. Get the generated data
```bash
ls cache/data/graphgen
Expand Down Expand Up @@ -176,7 +189,8 @@ See [analysis](https://deepwiki.com/open-sciencelab/GraphGen) by deepwiki for a
## 🍀 Acknowledgements
- [SiliconFlow](https://siliconflow.cn) Abundant LLM API, some models are free
- [LightRAG](https://github.com/HKUDS/LightRAG) Simple and efficient graph retrieval solution
- [ROGRAG](https://github.com/tpoisonooo/ROGRAG) ROGRAG: A Robustly Optimized GraphRAG Framework
- [ROGRAG](https://github.com/tpoisonooo/ROGRAG) A robustly optimized GraphRAG framework
- [DB-GPT](https://github.com/eosphoros-ai/DB-GPT) An AI native data app development framework


## 📚 Citation
164 changes: 90 additions & 74 deletions baselines/EntiGraph/entigraph.py
@@ -1,11 +1,11 @@
# https://arxiv.org/abs/2409.07431
# https://github.com/zitongyang/synthetic_continued_pretraining

import os
import argparse
import asyncio
import json
import os
import random
import asyncio
import argparse
from hashlib import md5

from tqdm.asyncio import tqdm as tqdm_async
@@ -18,9 +18,9 @@ def compute_content_hash(content, prefix: str = ""):
return prefix + md5(content.encode()).hexdigest()
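# Note (illustrative): compute_content_hash("some text") returns the 32-char md5
# hex digest of the text; the optional `prefix` simply namespaces the key.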


async def generate_entities(document_content: str,
system_message: str,
openai_model: str):
async def generate_entities(
document_content: str, system_message: str, openai_model: str
):
prompt = f"""
### Document Content:
{document_content}
@@ -30,41 +30,44 @@ async def generate_entities(document_content: str,
max_tries = 5
while not can_read_entities and max_tries > 0:
try:
completion = await gptqa(prompt,
openai_model,
system_message,
json_format=False)
completion = completion[completion.find("{"): completion.rfind("}") + 1]
completion = await gptqa(
prompt, openai_model, system_message, json_format=False
)
completion = completion[completion.find("{") : completion.rfind("}") + 1]
response = json.loads(completion)
can_read_entities = response['entities']
can_read_entities = response["entities"]
return response
except Exception as e: # pylint: disable=broad-except
except Exception as e: # pylint: disable=broad-except
print(f"Failed to generate entities: {str(e)}")
max_tries -= 1

async def generate_two_entity_relations(document_content: str,
entity1: str,
entity2: str,
system_message: str,
openai_model: str):

async def generate_two_entity_relations(
document_content: str,
entity1: str,
entity2: str,
system_message: str,
openai_model: str,
):
prompt = f"""
### Document Content:
{document_content}
### Entities:
- {entity1}
- {entity2}
"""
completion = await gptqa(prompt,
openai_model,
system_message)
completion = await gptqa(prompt, openai_model, system_message)
return completion

async def generate_three_entity_relations(document_content: str,
entity1: str,
entity2: str,
entity3: str,
system_message: str,
openai_model: str):

async def generate_three_entity_relations(
document_content: str,
entity1: str,
entity2: str,
entity3: str,
system_message: str,
openai_model: str,
):
prompt = f"""
### Document Content:
{document_content}
@@ -73,11 +76,10 @@ async def generate_three_entity_relations(document_content: str,
- {entity2}
- {entity3}
"""
completion = await gptqa(prompt,
openai_model,
system_message)
completion = await gptqa(prompt, openai_model, system_message)
return completion


def _post_process_synthetic_data(data):
block = data.split("\n\n")
qas = {}
@@ -87,7 +89,7 @@ def _post_process_synthetic_data(data):
answer = line.split("Answer: ")[1]
qas[compute_content_hash(question)] = {
"question": question,
"answer": answer
"answer": answer,
}
break
return qas
@@ -105,25 +107,26 @@ async def generate_document_entities(doc):
async with semaphore:
try:
entities = await generate_entities(
doc.text,
task.openai_system_generate_entities,
model_name)
doc.text, task.openai_system_generate_entities, model_name
)
if not entities:
return None
return {
'document': doc.text,
'entities': entities['entities'],
'summary': entities['summary']
"document": doc.text,
"entities": entities["entities"],
"summary": entities["summary"],
}
except Exception as e: # pylint: disable=broad-except
except Exception as e: # pylint: disable=broad-except
print(f"Error: {e}")
return None

entities_list = []
for result in tqdm_async(
asyncio.as_completed([generate_document_entities(doc) for doc in task.documents]),
total=len(task.documents),
desc="Generating entities"
asyncio.as_completed(
[generate_document_entities(doc) for doc in task.documents]
),
total=len(task.documents),
desc="Generating entities",
):
result = await result
if result:
@@ -132,38 +135,42 @@ async def generate_document_entities(doc):
# iterate over triples of entities and generate relations
pair_list = []
for doc in entities_list:
entities = doc['entities']
entities = doc["entities"]
temp = []
for i, entity_i in enumerate(entities):
if i == len(entities) - 1:
break
for j in range(i + 1, len(entities)):
entity_j = entities[j]
pair = (doc['document'], entity_i, entity_j)
pair = (doc["document"], entity_i, entity_j)
temp.append(pair)

# Computing all possible combinations of entities is impractical, so we randomly sample 10 pairs
pair_list.extend(random.sample(temp, min(len(temp), 10)))


async def process_two_entity_relations(pair):
async with semaphore:
try:
document, entity1, entity2 = pair
response = await generate_two_entity_relations(
document, entity1, entity2,
document,
entity1,
entity2,
task.openai_system_generate_two_entity_relations,
model_name)
model_name,
)
return response
except Exception as e: # pylint: disable=broad-except
except Exception as e: # pylint: disable=broad-except
print(f"Error: {e}")
return None

corpus= []
corpus = []
for result in tqdm_async(
asyncio.as_completed([process_two_entity_relations(pair) for pair in pair_list]),
total=len(pair_list),
desc="Generating two entity relations"
asyncio.as_completed(
[process_two_entity_relations(pair) for pair in pair_list]
),
total=len(pair_list),
desc="Generating two entity relations",
):
result = await result
if result:
@@ -194,51 +201,60 @@ async def process_two_entity_relations(pair):
# ):
# corpus.append(await result)

corpus = [doc['summary'] for doc in entities_list] + corpus
corpus = [doc["summary"] for doc in entities_list] + corpus

qa_sft_results = {}

async def generate_qa_sft(content):
async with semaphore:
completion = await gptqa(content, model_name, task.openai_system_quality_qa_sft)
completion = await gptqa(
content, model_name, task.openai_system_quality_qa_sft
)
return completion


for result in tqdm_async(
asyncio.as_completed([generate_qa_sft(content) for content in corpus]),
total=len(corpus),
desc="Generating QA SFT"
asyncio.as_completed([generate_qa_sft(content) for content in corpus]),
total=len(corpus),
desc="Generating QA SFT",
):
try:
result = await result
if result:
qa_sft_results.update(_post_process_synthetic_data(result))
except Exception as e: # pylint: disable=broad-except
except Exception as e: # pylint: disable=broad-except
print(f"Error: {e}")

return qa_sft_results


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input_file',
help='Raw context jsonl path.',
default='resources/examples/chunked_demo.json',
type=str)
parser.add_argument('--data_type',
help='Data type of input file. (Raw context or chunked context)',
choices=['raw', 'chunked'],
default='raw',
type=str)
parser.add_argument('--output_file',
help='Output file path.',
default='cache/data/entigraph.json',
type=str)
parser.add_argument(
"--input_file",
help="Raw context jsonl path.",
default="resources/input_examples/chunked_demo.json",
type=str,
)
parser.add_argument(
"--data_type",
help="Data type of input file. (Raw context or chunked context)",
choices=["raw", "chunked"],
default="raw",
type=str,
)
parser.add_argument(
"--output_file",
help="Output file path.",
default="cache/data/entigraph.json",
type=str,
)

args = parser.parse_args()

results = asyncio.run(generate_synthetic_data_for_document(args.input_file, args.data_type))
results = asyncio.run(
generate_synthetic_data_for_document(args.input_file, args.data_type)
)

# Save results
with open(args.output_file, "w", encoding='utf-8') as f:
with open(args.output_file, "w", encoding="utf-8") as f:
json.dump(results, f, indent=4, ensure_ascii=False)
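For reference, a minimal sketch (not part of the repo) of running this baseline end-to-end and peeking at the output; the paths are the argparse defaults shown above, and the record layout follows `_post_process_synthetic_data` (a dict keyed by content hash with `question`/`answer` fields):

```python
# Sketch: run the EntiGraph baseline with its default paths and preview results.
# Assumes the model/API credentials that `gptqa` relies on are already configured.
import json
import subprocess

subprocess.run(
    [
        "python", "baselines/EntiGraph/entigraph.py",
        "--input_file", "resources/input_examples/chunked_demo.json",
        "--data_type", "chunked",
        "--output_file", "cache/data/entigraph.json",
    ],
    check=True,
)

with open("cache/data/entigraph.json", encoding="utf-8") as f:
    qa_pairs = json.load(f)

# Each entry maps a content hash to {"question": ..., "answer": ...}
for qa in list(qa_pairs.values())[:3]:
    print(qa["question"], "->", qa["answer"])
```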