Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/push-to-hf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
[[ -d hf-repo ]] && rm -rf hf-repo
git clone https://huggingface.co/${HF_REPO_TYPE}/${HF_REPO_ID} hf-repo

rsync -a --delete --exclude='.git' ./ hf-repo/ || true
rsync -a --delete --exclude='.git' --exclude='hf-repo' ./ hf-repo/

cd hf-repo
git add .
Expand Down
2 changes: 1 addition & 1 deletion graphgen/configs/multi_hop_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ search: # web search configuration
enabled: false # whether to enable web search
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
enabled: true
enabled: false
quiz_samples: 2 # number of quiz samples to generate
re_judge: false # whether to re-judge the existing quiz samples
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
Expand Down
9 changes: 5 additions & 4 deletions graphgen/graphgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
judge_statement,
quiz,
search_all,
traverse_graph_atomically,
traverse_graph_by_edge,
traverse_graph_for_aggregated,
traverse_graph_for_atomic,
traverse_graph_for_multi_hop,
)
from .utils import (
Expand Down Expand Up @@ -69,6 +69,7 @@ def __post_init__(self):
self.tokenizer_instance: Tokenizer = Tokenizer(
model_name=self.config["tokenizer"]
)
print(os.getenv("SYNTHESIZER_MODEL"), os.getenv("SYNTHESIZER_API_KEY"))
self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
model_name=os.getenv("SYNTHESIZER_MODEL"),
api_key=os.getenv("SYNTHESIZER_API_KEY"),
Expand Down Expand Up @@ -326,7 +327,7 @@ async def async_traverse(self):
output_data_type = self.config["output_data_type"]

if output_data_type == "atomic":
results = await traverse_graph_atomically(
results = await traverse_graph_for_atomic(
self.synthesizer_llm_client,
self.tokenizer_instance,
self.graph_storage,
Expand All @@ -344,7 +345,7 @@ async def async_traverse(self):
self.progress_bar,
)
elif output_data_type == "aggregated":
results = await traverse_graph_by_edge(
results = await traverse_graph_for_aggregated(
self.synthesizer_llm_client,
self.tokenizer_instance,
self.graph_storage,
Expand Down
8 changes: 4 additions & 4 deletions graphgen/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from .judge import judge_statement
from .quiz import quiz
from .traverse_graph import (
traverse_graph_atomically,
traverse_graph_by_edge,
traverse_graph_for_aggregated,
traverse_graph_for_atomic,
traverse_graph_for_multi_hop,
)

Expand All @@ -15,8 +15,8 @@
"quiz",
"judge_statement",
"search_all",
"traverse_graph_by_edge",
"traverse_graph_atomically",
"traverse_graph_for_aggregated",
"traverse_graph_for_atomic",
"traverse_graph_for_multi_hop",
"generate_cot",
]
46 changes: 28 additions & 18 deletions graphgen/operators/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def get_average_loss(batch: tuple, loss_strategy: str) -> float:
) / (len(batch[0]) + len(batch[1]))
raise ValueError("Invalid loss strategy")
except Exception as e: # pylint: disable=broad-except
logger.error("Error calculating average loss: %s", e)
logger.warning(
"Loss not found in some nodes or edges, setting loss to -1.0: %s", e
)
return -1.0


Expand All @@ -158,7 +160,7 @@ def _post_process_synthetic_data(data):
return qas


async def traverse_graph_by_edge(
async def traverse_graph_for_aggregated(
llm_client: OpenAIModel,
tokenizer: Tokenizer,
graph_storage: NetworkXStorage,
Expand Down Expand Up @@ -251,7 +253,6 @@ async def _process_single_batch(
qas = _post_process_synthetic_data(content)

if len(qas) == 0:
print(content)
logger.error(
"Error occurred while processing batch, question or answer is None"
)
Expand Down Expand Up @@ -307,7 +308,8 @@ async def _process_single_batch(
return results


async def traverse_graph_atomically(
# pylint: disable=too-many-branches, too-many-statements
async def traverse_graph_for_atomic(
llm_client: OpenAIModel,
tokenizer: Tokenizer,
graph_storage: NetworkXStorage,
Expand All @@ -328,17 +330,28 @@ async def traverse_graph_atomically(
:param max_concurrent: maximum number of concurrent LLM requests
:return: question and answer
"""
assert traverse_strategy.qa_form == "atomic"

assert traverse_strategy.qa_form == "atomic"
semaphore = asyncio.Semaphore(max_concurrent)

def _parse_qa(qa: str) -> tuple:
if "Question:" in qa and "Answer:" in qa:
question = qa.split("Question:")[1].split("Answer:")[0].strip()
answer = qa.split("Answer:")[1].strip()
elif "问题:" in qa and "答案:" in qa:
question = qa.split("问题:")[1].split("答案:")[0].strip()
answer = qa.split("答案:")[1].strip()
else:
return None, None
return question.strip('"'), answer.strip('"')

async def _generate_question(node_or_edge: tuple):
if len(node_or_edge) == 2:
des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
loss = node_or_edge[1]["loss"]
loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
else:
des = node_or_edge[2]["description"]
loss = node_or_edge[2]["loss"]
loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0

async with semaphore:
try:
Expand All @@ -350,13 +363,8 @@ async def _generate_question(node_or_edge: tuple):
)
)

if "Question:" in qa and "Answer:" in qa:
question = qa.split("Question:")[1].split("Answer:")[0].strip()
answer = qa.split("Answer:")[1].strip()
elif "问题:" in qa and "答案:" in qa:
question = qa.split("问题:")[1].split("答案:")[0].strip()
answer = qa.split("答案:")[1].strip()
else:
question, answer = _parse_qa(qa)
if question is None or answer is None:
return {}

question = question.strip('"')
Expand Down Expand Up @@ -386,16 +394,18 @@ async def _generate_question(node_or_edge: tuple):
if "<SEP>" in node[1]["description"]:
description_list = node[1]["description"].split("<SEP>")
for item in description_list:
tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
tasks.append((node[0], {"description": item}))
if "loss" in node[1]:
tasks[-1][1]["loss"] = node[1]["loss"]
else:
tasks.append((node[0], node[1]))
for edge in edges:
if "<SEP>" in edge[2]["description"]:
description_list = edge[2]["description"].split("<SEP>")
for item in description_list:
tasks.append(
(edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
)
tasks.append((edge[0], edge[1], {"description": item}))
if "loss" in edge[2]:
tasks[-1][2]["loss"] = edge[2]["loss"]
else:
tasks.append((edge[0], edge[1], edge[2]))

Expand Down
Loading