From 556341c7bd4b2dc20d5e4d810ab631a0e8328747 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 23 Sep 2025 16:50:46 +0800 Subject: [PATCH 1/4] feat: support Reader classes --- README.md | 4 +- README_ZH.md | 4 +- baselines/EntiGraph/entigraph.py | 2 +- baselines/Genie/genie.py | 2 +- baselines/LongForm/longform.py | 2 +- baselines/SELF-QA/self-qa.py | 2 +- baselines/Wrap/wrap.py | 2 +- graphgen/{version.py => _version.py} | 11 +- graphgen/{models/embed => bases}/__init__.py | 0 graphgen/bases/base_reader.py | 20 +++ .../{models/storage => bases}/base_storage.py | 4 - graphgen/configs/aggregated_config.yaml | 3 +- graphgen/configs/atomic_config.yaml | 3 +- graphgen/configs/cot_config.yaml | 3 +- graphgen/configs/multi_hop_config.yaml | 3 +- graphgen/graphgen.py | 127 ++++++------------ graphgen/judge.py | 60 --------- graphgen/models/__init__.py | 1 + graphgen/models/embed/embedding.py | 29 ---- graphgen/models/reader/__init__.py | 22 +++ graphgen/models/reader/csv_reader.py | 13 ++ graphgen/models/reader/json_reader.py | 18 +++ graphgen/models/reader/jsonl_reader.py | 22 +++ graphgen/models/reader/txt_reader.py | 14 ++ graphgen/models/storage/json_storage.py | 2 +- graphgen/models/storage/networkx_storage.py | 34 +++-- graphgen/models/strategy/base_strategy.py | 5 - .../models/strategy/travserse_strategy.py | 12 +- graphgen/operators/kg/extract_kg.py | 2 +- graphgen/operators/kg/merge_kg.py | 2 +- graphgen/utils/__init__.py | 1 - graphgen/utils/file.py | 24 ---- resources/input_examples/chunked_demo.json | 14 -- resources/input_examples/csv_demo.csv | 5 + resources/input_examples/json_demo.json | 6 + .../{raw_demo.jsonl => jsonl_demo.jsonl} | 0 resources/input_examples/keywords_demo.txt | 5 - setup.py | 54 ++++---- webui/app.py | 38 +----- webui/count_tokens.py | 33 +++-- webui/examples/chunked_demo.json | 14 -- webui/examples/csv_demo.csv | 5 + webui/examples/json_demo.json | 6 + .../{raw_demo.jsonl => jsonl_demo.jsonl} | 0 44 files changed, 273 insertions(+), 360 deletions(-) rename graphgen/{version.py => _version.py} (73%) rename graphgen/{models/embed => bases}/__init__.py (100%) create mode 100644 graphgen/bases/base_reader.py rename graphgen/{models/storage => bases}/base_storage.py (96%) delete mode 100644 graphgen/judge.py delete mode 100644 graphgen/models/embed/embedding.py create mode 100644 graphgen/models/reader/__init__.py create mode 100644 graphgen/models/reader/csv_reader.py create mode 100644 graphgen/models/reader/json_reader.py create mode 100644 graphgen/models/reader/jsonl_reader.py create mode 100644 graphgen/models/reader/txt_reader.py delete mode 100644 graphgen/models/strategy/base_strategy.py delete mode 100644 graphgen/utils/file.py delete mode 100644 resources/input_examples/chunked_demo.json create mode 100644 resources/input_examples/csv_demo.csv create mode 100644 resources/input_examples/json_demo.json rename resources/input_examples/{raw_demo.jsonl => jsonl_demo.jsonl} (100%) delete mode 100644 resources/input_examples/keywords_demo.txt delete mode 100644 webui/examples/chunked_demo.json create mode 100644 webui/examples/csv_demo.csv create mode 100644 webui/examples/json_demo.json rename webui/examples/{raw_demo.jsonl => jsonl_demo.jsonl} (100%) diff --git a/README.md b/README.md index dec91996..c24cd861 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ For any questions, please check [FAQ](https://github.com/open-sciencelab/GraphGe ### Run Gradio Demo ```bash - python -m webui.app.py + python -m webui.app ``` ![ui](https://github.com/user-attachments/assets/3024e9bc-5d45-45f8-a4e6-b57bd2350d84) @@ -148,7 +148,7 @@ For any questions, please check [FAQ](https://github.com/open-sciencelab/GraphGe ```yaml # configs/cot_config.yaml input_data_type: raw - input_file: resources/input_examples/raw_demo.jsonl + input_file: resources/input_examples/jsonl_demo.jsonl output_data_type: cot tokenizer: cl100k_base # additional settings... diff --git a/README_ZH.md b/README_ZH.md index b577b0d6..151b7638 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -99,7 +99,7 @@ GraphGen 首先根据源文本构建细粒度的知识图谱,然后利用期 ### 运行 Gradio 演示 ```bash - python -m webui.app.py + python -m webui.app ``` ![ui](https://github.com/user-attachments/assets/3024e9bc-5d45-45f8-a4e6-b57bd2350d84) @@ -147,7 +147,7 @@ GraphGen 首先根据源文本构建细粒度的知识图谱,然后利用期 ```yaml # configs/cot_config.yaml input_data_type: raw - input_file: resources/input_examples/raw_demo.jsonl + input_file: resources/input_examples/jsonl_demo.jsonl output_data_type: cot tokenizer: cl100k_base # 其他设置... diff --git a/baselines/EntiGraph/entigraph.py b/baselines/EntiGraph/entigraph.py index 3020c71d..07d3d5dc 100644 --- a/baselines/EntiGraph/entigraph.py +++ b/baselines/EntiGraph/entigraph.py @@ -232,7 +232,7 @@ async def generate_qa_sft(content): parser.add_argument( "--input_file", help="Raw context jsonl path.", - default="resources/input_examples/chunked_demo.json", + default="resources/input_examples/json_demo.json", type=str, ) parser.add_argument( diff --git a/baselines/Genie/genie.py b/baselines/Genie/genie.py index 3ca529af..75e713e7 100644 --- a/baselines/Genie/genie.py +++ b/baselines/Genie/genie.py @@ -100,7 +100,7 @@ async def process_chunk(content: str): parser.add_argument( "--input_file", help="Raw context jsonl path.", - default="resources/input_examples/chunked_demo.json", + default="resources/input_examples/json_demo.json", type=str, ) parser.add_argument( diff --git a/baselines/LongForm/longform.py b/baselines/LongForm/longform.py index db352fed..31feb01a 100644 --- a/baselines/LongForm/longform.py +++ b/baselines/LongForm/longform.py @@ -67,7 +67,7 @@ async def process_chunk(content: str): parser.add_argument( "--input_file", help="Raw context jsonl path.", - default="resources/input_examples/chunked_demo.json", + default="resources/input_examples/json_demo.json", type=str, ) parser.add_argument( diff --git a/baselines/SELF-QA/self-qa.py b/baselines/SELF-QA/self-qa.py index d0a8b878..8ee0307f 100644 --- a/baselines/SELF-QA/self-qa.py +++ b/baselines/SELF-QA/self-qa.py @@ -134,7 +134,7 @@ async def process_chunk(content: str): parser.add_argument( "--input_file", help="Raw context jsonl path.", - default="resources/input_examples/chunked_demo.json", + default="resources/input_examples/json_demo.json", type=str, ) parser.add_argument( diff --git a/baselines/Wrap/wrap.py b/baselines/Wrap/wrap.py index 8618d613..3f71b2f4 100644 --- a/baselines/Wrap/wrap.py +++ b/baselines/Wrap/wrap.py @@ -87,7 +87,7 @@ async def process_chunk(content: str): parser.add_argument( "--input_file", help="Raw context jsonl path.", - default="resources/input_examples/chunked_demo.json", + default="resources/input_examples/json_demo.json", type=str, ) parser.add_argument( diff --git a/graphgen/version.py b/graphgen/_version.py similarity index 73% rename from graphgen/version.py rename to graphgen/_version.py index 73315e64..70316f9c 100644 --- a/graphgen/version.py +++ b/graphgen/_version.py @@ -1,7 +1,6 @@ - from typing import Tuple -__version__ = '20250416' +__version__ = "20250416" short_version = __version__ @@ -15,13 +14,13 @@ def parse_version_info(version_str: str) -> Tuple: tuple: A sequence of integer and string represents version. """ _version_info = [] - for x in version_str.split('.'): + for x in version_str.split("."): if x.isdigit(): _version_info.append(int(x)) - elif x.find('rc') != -1: - patch_version = x.split('rc') + elif x.find("rc") != -1: + patch_version = x.split("rc") _version_info.append(int(patch_version[0])) - _version_info.append(f'rc{patch_version[1]}') + _version_info.append(f"rc{patch_version[1]}") return tuple(_version_info) diff --git a/graphgen/models/embed/__init__.py b/graphgen/bases/__init__.py similarity index 100% rename from graphgen/models/embed/__init__.py rename to graphgen/bases/__init__.py diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py new file mode 100644 index 00000000..118b5258 --- /dev/null +++ b/graphgen/bases/base_reader.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List + + +class BaseReader(ABC): + """ + Abstract base class for reading and processing data. + """ + + def __init__(self, text_column: str = "content"): + self.text_column = text_column + + @abstractmethod + def read(self, file_path: str) -> List[Dict[str, Any]]: + """ + Read data from the specified file path. + + :param file_path: Path to the input file. + :return: List of dictionaries containing the data. + """ diff --git a/graphgen/models/storage/base_storage.py b/graphgen/bases/base_storage.py similarity index 96% rename from graphgen/models/storage/base_storage.py rename to graphgen/bases/base_storage.py index c09df074..dff83778 100644 --- a/graphgen/models/storage/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -1,8 +1,6 @@ from dataclasses import dataclass from typing import Generic, TypeVar, Union -from graphgen.models.embed.embedding import EmbeddingFunc - T = TypeVar("T") @@ -62,8 +60,6 @@ async def drop(self): @dataclass class BaseGraphStorage(StorageNameSpace): - embedding_func: EmbeddingFunc = None - async def has_node(self, node_id: str) -> bool: raise NotImplementedError diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml index e13a6606..a65cf2ac 100644 --- a/graphgen/configs/aggregated_config.yaml +++ b/graphgen/configs/aggregated_config.yaml @@ -1,5 +1,4 @@ -input_data_type: raw # raw, chunked -input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples +input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples output_data_type: aggregated # atomic, aggregated, multi_hop, cot output_data_format: ChatML # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml index 8e2c081f..4e8c4e29 100644 --- a/graphgen/configs/atomic_config.yaml +++ b/graphgen/configs/atomic_config.yaml @@ -1,5 +1,4 @@ -input_data_type: raw # raw, chunked -input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples +input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples output_data_type: atomic # atomic, aggregated, multi_hop, cot output_data_format: Alpaca # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml index 1073e97d..6aa6bf52 100644 --- a/graphgen/configs/cot_config.yaml +++ b/graphgen/configs/cot_config.yaml @@ -1,5 +1,4 @@ -input_data_type: raw # raw, chunked -input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples +input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples output_data_type: cot # atomic, aggregated, multi_hop, cot output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml index bb75d0a9..02e5e787 100644 --- a/graphgen/configs/multi_hop_config.yaml +++ b/graphgen/configs/multi_hop_config.yaml @@ -1,5 +1,4 @@ -input_data_type: raw # raw, chunked -input_file: resources/input_examples/raw_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples +input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples output_data_type: multi_hop # atomic, aggregated, multi_hop, cot output_data_format: ChatML # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 8521e744..fcb62387 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -7,7 +7,8 @@ import gradio as gr from tqdm.asyncio import tqdm as tqdm_async -from .models import ( +from graphgen.bases.base_storage import StorageNameSpace +from graphgen.models import ( Chunk, JsonKVStorage, JsonListStorage, @@ -15,8 +16,9 @@ OpenAIModel, Tokenizer, TraverseStrategy, + read_file, ) -from .models.storage.base_storage import StorageNameSpace + from .operators import ( extract_kg, generate_cot, @@ -32,7 +34,6 @@ create_event_loop, format_generation_results, logger, - read_file, ) sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -108,94 +109,54 @@ def __post_init__(self): namespace=f"qa-{self.unique_id}", ) - async def async_split_chunks( - self, data: List[Union[List, Dict]], data_type: str - ) -> dict: + async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict: # TODO: configurable whether to use coreference resolution if len(data) == 0: return {} inserting_chunks = {} - if data_type == "raw": - assert isinstance(data, list) and isinstance(data[0], dict) - # compute hash for each document - new_docs = { - compute_content_hash(doc["content"], prefix="doc-"): { - "content": doc["content"] - } - for doc in data - } - _add_doc_keys = await self.full_docs_storage.filter_keys( - list(new_docs.keys()) - ) - new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} - if len(new_docs) == 0: - logger.warning("All docs are already in the storage") - return {} - logger.info("[New Docs] inserting %d docs", len(new_docs)) - - cur_index = 1 - doc_number = len(new_docs) - async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): - chunks = { - compute_content_hash(dp["content"], prefix="chunk-"): { - **dp, - "full_doc_id": doc_key, - } - for dp in self.tokenizer_instance.chunk_by_token_size( - doc["content"], self.chunk_overlap_size, self.chunk_size - ) - } - inserting_chunks.update(chunks) - - if self.progress_bar is not None: - self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}") - cur_index += 1 + assert isinstance(data, list) and isinstance(data[0], dict) - _add_chunk_keys = await self.text_chunks_storage.filter_keys( - list(inserting_chunks.keys()) - ) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + # compute hash for each document + new_docs = { + compute_content_hash(doc["content"], prefix="doc-"): { + "content": doc["content"] } - elif data_type == "chunked": - assert isinstance(data, list) and isinstance(data[0], list) - new_docs = { - compute_content_hash("".join(chunk["content"]), prefix="doc-"): { - "content": "".join(chunk["content"]) + for doc in data + } + _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys())) + new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} + if len(new_docs) == 0: + logger.warning("All docs are already in the storage") + return {} + logger.info("[New Docs] inserting %d docs", len(new_docs)) + + cur_index = 1 + doc_number = len(new_docs) + async for doc_key, doc in tqdm_async( + new_docs.items(), desc="[1/4]Chunking documents", unit="doc" + ): + chunks = { + compute_content_hash(dp["content"], prefix="chunk-"): { + **dp, + "full_doc_id": doc_key, } - for doc in data - for chunk in doc - } - _add_doc_keys = await self.full_docs_storage.filter_keys( - list(new_docs.keys()) - ) - new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} - if len(new_docs) == 0: - logger.warning("All docs are already in the storage") - return {} - logger.info("[New Docs] inserting %d docs", len(new_docs)) - async for doc in tqdm_async( - data, desc="[1/4]Chunking documents", unit="doc" - ): - doc_str = "".join([chunk["content"] for chunk in doc]) - for chunk in doc: - chunk_key = compute_content_hash(chunk["content"], prefix="chunk-") - inserting_chunks[chunk_key] = { - **chunk, - "full_doc_id": compute_content_hash(doc_str, prefix="doc-"), - } - _add_chunk_keys = await self.text_chunks_storage.filter_keys( - list(inserting_chunks.keys()) - ) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + for dp in self.tokenizer_instance.chunk_by_token_size( + doc["content"], self.chunk_overlap_size, self.chunk_size + ) } - else: - raise ValueError(f"Unknown data type: {data_type}") + inserting_chunks.update(chunks) + + if self.progress_bar is not None: + self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}") + cur_index += 1 + _add_chunk_keys = await self.text_chunks_storage.filter_keys( + list(inserting_chunks.keys()) + ) + inserting_chunks = { + k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + } await self.full_docs_storage.upsert(new_docs) await self.text_chunks_storage.upsert(inserting_chunks) @@ -211,10 +172,8 @@ async def async_insert(self): """ input_file = self.config["input_file"] - data_type = self.config["input_data_type"] data = read_file(input_file) - - inserting_chunks = await self.async_split_chunks(data, data_type) + inserting_chunks = await self.async_split_chunks(data) if len(inserting_chunks) == 0: logger.warning("All chunks are already in the storage") diff --git a/graphgen/judge.py b/graphgen/judge.py deleted file mode 100644 index f05bdf1d..00000000 --- a/graphgen/judge.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -import argparse -import asyncio -from dotenv import load_dotenv - -from .models import NetworkXStorage, JsonKVStorage, OpenAIModel -from .operators import judge_statement - -sys_path = os.path.abspath(os.path.dirname(__file__)) - -load_dotenv() - -def calculate_average_loss(graph: NetworkXStorage): - """ - Calculate the average loss of the graph. - - :param graph: NetworkXStorage - :return: float - """ - edges = asyncio.run(graph.get_all_edges()) - total_loss = 0 - for edge in edges: - total_loss += edge[2]['loss'] - return total_loss / len(edges) - - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--input', type=str, default=os.path.join(sys_path, "cache"), help='path to load input graph') - parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output') - - args = parser.parse_args() - - llm_client = OpenAIModel( - model_name=os.getenv("TRAINEE_MODEL"), - api_key=os.getenv("TRAINEE_API_KEY"), - base_url=os.getenv("TRAINEE_BASE_URL") - ) - - graph_storage = NetworkXStorage( - args.input, - namespace="graph" - ) - average_loss = calculate_average_loss(graph_storage) - print(f"Average loss of the graph: {average_loss}") - - rephrase_storage = JsonKVStorage( - os.path.join(sys_path, "cache"), - namespace="rephrase" - ) - - new_graph = asyncio.run(judge_statement(llm_client, graph_storage, rephrase_storage, re_judge=True)) - - graph_file = asyncio.run(graph_storage.get_graph()) - - new_graph.write_nx_graph(graph_file, args.output) - - average_loss = calculate_average_loss(new_graph) - print(f"Average loss of the graph: {average_loss}") diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index f7153358..79111b00 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -6,6 +6,7 @@ from .llm.openai_model import OpenAIModel from .llm.tokenizer import Tokenizer from .llm.topk_token_model import Token, TopkTokenModel +from .reader import read_file from .search.db.uniprot_search import UniProtSearch from .search.kg.wiki_search import WikiSearch from .search.web.bing_search import BingSearch diff --git a/graphgen/models/embed/embedding.py b/graphgen/models/embed/embedding.py deleted file mode 100644 index 8213b90f..00000000 --- a/graphgen/models/embed/embedding.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass -import asyncio -import numpy as np - -class UnlimitedSemaphore: - """A context manager that allows unlimited access.""" - - async def __aenter__(self): - pass - - async def __aexit__(self, exc_type, exc, tb): - pass - -@dataclass -class EmbeddingFunc: - embedding_dim: int - max_token_size: int - func: callable - concurrent_limit: int = 16 - - def __post_init__(self): - if self.concurrent_limit != 0: - self._semaphore = asyncio.Semaphore(self.concurrent_limit) - else: - self._semaphore = UnlimitedSemaphore() - - async def __call__(self, *args, **kwargs) -> np.ndarray: - async with self._semaphore: - return await self.func(*args, **kwargs) diff --git a/graphgen/models/reader/__init__.py b/graphgen/models/reader/__init__.py new file mode 100644 index 00000000..fde3962d --- /dev/null +++ b/graphgen/models/reader/__init__.py @@ -0,0 +1,22 @@ +from .csv_reader import CsvReader +from .json_reader import JsonReader +from .jsonl_reader import JsonlReader +from .txt_reader import TxtReader + +_MAPPING = { + "jsonl": JsonlReader, + "json": JsonReader, + "txt": TxtReader, + "csv": CsvReader, +} + + +def read_file(file_path: str): + suffix = file_path.split(".")[-1] + if suffix in _MAPPING: + reader = _MAPPING[suffix]() + else: + raise ValueError( + f"Unsupported file format: {suffix}. Supported formats are: {list(_MAPPING.keys())}" + ) + return reader.read(file_path) diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py new file mode 100644 index 00000000..f46c357e --- /dev/null +++ b/graphgen/models/reader/csv_reader.py @@ -0,0 +1,13 @@ +from typing import Any, Dict, List + +from graphgen.bases.base_reader import BaseReader + + +class CsvReader(BaseReader): + def read(self, file_path: str) -> List[Dict[str, Any]]: + import pandas as pd + + df = pd.read_csv(file_path) + if self.text_column not in df.columns: + raise ValueError(f"Missing '{self.text_column}' column in CSV file.") + return df.to_dict(orient="records") diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py new file mode 100644 index 00000000..98e1e16a --- /dev/null +++ b/graphgen/models/reader/json_reader.py @@ -0,0 +1,18 @@ +import json +from typing import Any, Dict, List + +from graphgen.bases.base_reader import BaseReader + + +class JsonReader(BaseReader): + def read(self, file_path: str) -> List[Dict[str, Any]]: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + if isinstance(data, list): + for doc in data: + if self.text_column not in doc: + raise ValueError( + f"Missing '{self.text_column}' in document: {doc}" + ) + return data + raise ValueError("JSON file must contain a list of documents.") diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py new file mode 100644 index 00000000..d923d8eb --- /dev/null +++ b/graphgen/models/reader/jsonl_reader.py @@ -0,0 +1,22 @@ +import json +from typing import Any, Dict, List + +from graphgen.bases.base_reader import BaseReader + + +class JsonlReader(BaseReader): + def read(self, file_path: str) -> List[Dict[str, Any]]: + docs = [] + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + try: + doc = json.loads(line) + if self.text_column in doc: + docs.append(doc) + else: + raise ValueError( + f"Missing '{self.text_column}' in document: {doc}" + ) + except json.JSONDecodeError as e: + print(f"Error decoding JSON line: {line}. Error: {e}") + return docs diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py new file mode 100644 index 00000000..f9419ebd --- /dev/null +++ b/graphgen/models/reader/txt_reader.py @@ -0,0 +1,14 @@ +from typing import Any, Dict, List + +from graphgen.bases.base_reader import BaseReader + + +class TxtReader(BaseReader): + def read(self, file_path: str) -> List[Dict[str, Any]]: + docs = [] + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + docs.append({self.text_column: line}) + return docs diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index b61572f5..e37d033b 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -1,7 +1,7 @@ import os from dataclasses import dataclass -from graphgen.models.storage.base_storage import BaseKVStorage, BaseListStorage +from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage from graphgen.utils import load_json, logger, write_json diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/networkx_storage.py index 92643760..28baebda 100644 --- a/graphgen/models/storage/networkx_storage.py +++ b/graphgen/models/storage/networkx_storage.py @@ -1,11 +1,13 @@ -import os import html -from typing import Any, Union, cast, Optional +import os from dataclasses import dataclass +from typing import Any, Optional, Union, cast + import networkx as nx +from graphgen.bases.base_storage import BaseGraphStorage from graphgen.utils import logger -from .base_storage import BaseGraphStorage + @dataclass class NetworkXStorage(BaseGraphStorage): @@ -17,7 +19,11 @@ def load_nx_graph(file_name) -> Optional[nx.Graph]: @staticmethod def write_nx_graph(graph: nx.Graph, file_name): - logger.info("Writing graph with %d nodes, %d edges", graph.number_of_nodes(), graph.number_of_edges()) + logger.info( + "Writing graph with %d nodes, %d edges", + graph.number_of_nodes(), + graph.number_of_edges(), + ) nx.write_graphml(graph, file_name) @staticmethod @@ -77,8 +83,10 @@ def __post_init__(self): preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) if preloaded_graph is not None: logger.info( - "Loaded graph from %s with %d nodes, %d edges", self._graphml_xml_file, - preloaded_graph.number_of_nodes(), preloaded_graph.number_of_edges() + "Loaded graph from %s with %d nodes, %d edges", + self._graphml_xml_file, + preloaded_graph.number_of_nodes(), + preloaded_graph.number_of_edges(), ) self._graph = preloaded_graph or nx.Graph() @@ -111,7 +119,9 @@ async def get_edge( async def get_all_edges(self) -> Union[list[dict], None]: return self._graph.edges(data=True) - async def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]: + async def get_node_edges( + self, source_node_id: str + ) -> Union[list[tuple[str, str]], None]: if self._graph.has_node(source_node_id): return list(self._graph.edges(source_node_id, data=True)) return None @@ -133,11 +143,17 @@ async def upsert_edge( ): self._graph.add_edge(source_node_id, target_node_id, **edge_data) - async def update_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]): + async def update_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): if self._graph.has_edge(source_node_id, target_node_id): self._graph.edges[(source_node_id, target_node_id)].update(edge_data) else: - logger.warning("Edge %s -> %s not found in the graph for update.", source_node_id, target_node_id) + logger.warning( + "Edge %s -> %s not found in the graph for update.", + source_node_id, + target_node_id, + ) async def delete_node(self, node_id: str): """ diff --git a/graphgen/models/strategy/base_strategy.py b/graphgen/models/strategy/base_strategy.py deleted file mode 100644 index 70e0cc54..00000000 --- a/graphgen/models/strategy/base_strategy.py +++ /dev/null @@ -1,5 +0,0 @@ -from dataclasses import dataclass - -@dataclass -class BaseStrategy: - pass diff --git a/graphgen/models/strategy/travserse_strategy.py b/graphgen/models/strategy/travserse_strategy.py index 06882c5f..5739dea8 100644 --- a/graphgen/models/strategy/travserse_strategy.py +++ b/graphgen/models/strategy/travserse_strategy.py @@ -1,14 +1,12 @@ from dataclasses import dataclass, fields -from graphgen.models.strategy.base_strategy import BaseStrategy - @dataclass -class TraverseStrategy(BaseStrategy): +class TraverseStrategy: # 生成的QA形式:原子、多跳、聚合型 - qa_form: str = "atomic" # "atomic" or "multi_hop" or "aggregated" + qa_form: str = "atomic" # "atomic" or "multi_hop" or "aggregated" # 最大边数和最大token数方法中选择一个生效 - expand_method: str = "max_tokens" # "max_width" or "max_tokens" + expand_method: str = "max_tokens" # "max_width" or "max_tokens" # 单向拓展还是双向拓展 bidirectional: bool = True # 每个方向拓展的最大边数 @@ -18,9 +16,9 @@ class TraverseStrategy(BaseStrategy): # 每个方向拓展的最大深度 max_depth: int = 2 # 同一层中选边的策略(如果是双向拓展,同一层指的是两边连接的边的集合) - edge_sampling: str = "max_loss" # "max_loss" or "min_loss" or "random" + edge_sampling: str = "max_loss" # "max_loss" or "min_loss" or "random" # 孤立节点的处理策略 - isolated_node_strategy: str = "add" # "add" or "ignore" + isolated_node_strategy: str = "add" # "add" or "ignore" loss_strategy: str = "only_edge" # only_edge, both def to_yaml(self): diff --git a/graphgen/operators/kg/extract_kg.py b/graphgen/operators/kg/extract_kg.py index 406e400b..ec1f959c 100644 --- a/graphgen/operators/kg/extract_kg.py +++ b/graphgen/operators/kg/extract_kg.py @@ -6,8 +6,8 @@ import gradio as gr from tqdm.asyncio import tqdm as tqdm_async +from graphgen.bases.base_storage import BaseGraphStorage from graphgen.models import Chunk, OpenAIModel, Tokenizer -from graphgen.models.storage.base_storage import BaseGraphStorage from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes from graphgen.templates import KG_EXTRACTION_PROMPT from graphgen.utils import ( diff --git a/graphgen/operators/kg/merge_kg.py b/graphgen/operators/kg/merge_kg.py index 30379e66..fca35f3d 100644 --- a/graphgen/operators/kg/merge_kg.py +++ b/graphgen/operators/kg/merge_kg.py @@ -3,8 +3,8 @@ from tqdm.asyncio import tqdm as tqdm_async +from graphgen.bases.base_storage import BaseGraphStorage from graphgen.models import Tokenizer, TopkTokenModel -from graphgen.models.storage.base_storage import BaseGraphStorage from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT from graphgen.utils import detect_main_language, logger from graphgen.utils.format import split_string_by_multi_markers diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index b3c8e1e6..a3bf4965 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -1,6 +1,5 @@ from .calculate_confidence import yes_no_loss_entropy from .detect_lang import detect_if_chinese, detect_main_language -from .file import read_file from .format import ( format_generation_results, handle_single_entity_extraction, diff --git a/graphgen/utils/file.py b/graphgen/utils/file.py deleted file mode 100644 index 11298616..00000000 --- a/graphgen/utils/file.py +++ /dev/null @@ -1,24 +0,0 @@ -import json - - -def read_file(input_file: str) -> list: - """ - Read data from a file based on the specified data type. - :param input_file - :return: - """ - - if input_file.endswith(".jsonl"): - with open(input_file, "r", encoding="utf-8") as f: - data = [json.loads(line) for line in f] - elif input_file.endswith(".json"): - with open(input_file, "r", encoding="utf-8") as f: - data = json.load(f) - elif input_file.endswith(".txt"): - with open(input_file, "r", encoding="utf-8") as f: - data = [line.strip() for line in f if line.strip()] - data = [{"content": line} for line in data] - else: - raise ValueError(f"Unsupported file format: {input_file}") - - return data diff --git a/resources/input_examples/chunked_demo.json b/resources/input_examples/chunked_demo.json deleted file mode 100644 index ad7219a3..00000000 --- a/resources/input_examples/chunked_demo.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - [ - {"content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"} - ], - [ - {"content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"} - ], - [ - {"content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."} - ], - [ - {"content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} - ] -] diff --git a/resources/input_examples/csv_demo.csv b/resources/input_examples/csv_demo.csv new file mode 100644 index 00000000..11e6dde3 --- /dev/null +++ b/resources/input_examples/csv_demo.csv @@ -0,0 +1,5 @@ +content +"云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。" +"隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。" +"Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity." +"Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture." diff --git a/resources/input_examples/json_demo.json b/resources/input_examples/json_demo.json new file mode 100644 index 00000000..b496c16f --- /dev/null +++ b/resources/input_examples/json_demo.json @@ -0,0 +1,6 @@ +[ + {"content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"}, + {"content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"}, + {"content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."}, + {"content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} +] diff --git a/resources/input_examples/raw_demo.jsonl b/resources/input_examples/jsonl_demo.jsonl similarity index 100% rename from resources/input_examples/raw_demo.jsonl rename to resources/input_examples/jsonl_demo.jsonl diff --git a/resources/input_examples/keywords_demo.txt b/resources/input_examples/keywords_demo.txt deleted file mode 100644 index 0d4c47eb..00000000 --- a/resources/input_examples/keywords_demo.txt +++ /dev/null @@ -1,5 +0,0 @@ -TATA Box -TFBS -E-box -Enhancer -AP-1 diff --git a/setup.py b/setup.py index 3dee7f8b..d517fc8f 100644 --- a/setup.py +++ b/setup.py @@ -1,30 +1,31 @@ +# pylint: skip-file import os from setuptools import find_packages, setup pwd = os.path.dirname(__file__) -version_file = 'graphgen/version.py' +version_file = "graphgen/_version.py" def readme(): - with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f: + with open(os.path.join(pwd, "README.md"), encoding="utf-8") as f: content = f.read() return content def get_version(): - with open(os.path.join(pwd, version_file), 'r') as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] + with open(os.path.join(pwd, version_file), "r") as f: + exec(compile(f.read(), version_file, "exec")) + return locals()["__version__"] def read_requirements(): lines = [] - with open('requirements.txt', 'r') as f: + with open("requirements.txt", "r") as f: for line in f.readlines(): - if line.startswith('#'): + if line.startswith("#"): continue - if 'textract' in line: + if "textract" in line: continue if len(line) > 0: lines.append(line) @@ -33,32 +34,29 @@ def read_requirements(): install_packages = read_requirements() -if __name__ == '__main__': +if __name__ == "__main__": setup( - name='graphg', + name="graphg", version=get_version(), - url='https://github.com/open-sciencelab/GraphGen', - description= # noqa E251 - 'GraphGen: Enhancing Supervised Fine-Tuning for LLMs with Knowledge-Driven Synthetic Data Generation', # noqa E501 + url="https://github.com/open-sciencelab/GraphGen", + description="GraphGen: Enhancing Supervised Fine-Tuning for LLMs with Knowledge-Driven Synthetic Data Generation", long_description=readme(), - long_description_content_type='text/markdown', - author='open-sciencelab', - author_email='open-sciencelab@pjlab.org.cn', + long_description_content_type="text/markdown", + author="open-sciencelab", + author_email="open-sciencelab@pjlab.org.cn", packages=find_packages(exclude=["models"]), - package_data={ - 'GraphGen': ['configs/*'] - }, + package_data={"GraphGen": ["configs/*"]}, include_package_data=True, install_requires=install_packages, classifiers=[ - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3.12', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", ], - entry_points={'console_scripts': ['graphgen=graphgen.generate:main']}, + entry_points={"console_scripts": ["graphgen=graphgen.generate:main"]}, ) diff --git a/webui/app.py b/webui/app.py index f8b8f4cc..6822405f 100644 --- a/webui/app.py +++ b/webui/app.py @@ -116,35 +116,6 @@ def sum_tokens(client): env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"] ) - # Load input data - file = config["input_file"] - if isinstance(file, list): - file = file[0] - - data = [] - - if file.endswith(".jsonl"): - config["input_data_type"] = "raw" - with open(file, "r", encoding="utf-8") as f: - data.extend(json.loads(line) for line in f) - elif file.endswith(".json"): - config["input_data_type"] = "chunked" - with open(file, "r", encoding="utf-8") as f: - data.extend(json.load(f)) - elif file.endswith(".txt"): - # 读取文件后根据chunk_size转成raw格式的数据 - config["input_data_type"] = "raw" - content = "" - with open(file, "r", encoding="utf-8") as f: - lines = f.readlines() - for line in lines: - content += line.strip() + " " - size = int(config.get("chunk_size", 512)) - chunks = [content[i : i + size] for i in range(0, len(content), size)] - data.extend([{"content": chunk} for chunk in chunks]) - else: - raise ValueError(f"Unsupported file type: {file}") - # Initialize GraphGen graph_gen = init_graph_gen(config, env) graph_gen.clear() @@ -436,19 +407,20 @@ def sum_tokens(client): upload_file = gr.File( label=_("Upload File"), file_count="single", - file_types=[".txt", ".json", ".jsonl"], + file_types=[".txt", ".json", ".jsonl", ".csv"], interactive=True, ) examples_dir = os.path.join(root_dir, "webui", "examples") gr.Examples( examples=[ [os.path.join(examples_dir, "txt_demo.txt")], - [os.path.join(examples_dir, "raw_demo.jsonl")], - [os.path.join(examples_dir, "chunked_demo.json")], + [os.path.join(examples_dir, "jsonl_demo.jsonl")], + [os.path.join(examples_dir, "json_demo.json")], + [os.path.join(examples_dir, "csv_demo.csv")], ], inputs=upload_file, label=_("Example Files"), - examples_per_page=3, + examples_per_page=4, ) with gr.Column(scale=1): output = gr.File( diff --git a/webui/count_tokens.py b/webui/count_tokens.py index 53bed59a..210bd267 100644 --- a/webui/count_tokens.py +++ b/webui/count_tokens.py @@ -1,6 +1,7 @@ +import json import os import sys -import json + import pandas as pd # pylint: disable=wrong-import-position @@ -8,24 +9,29 @@ sys.path.append(root_dir) from graphgen.models import Tokenizer + def count_tokens(file, tokenizer_name, data_frame): if not file or not os.path.exists(file): return data_frame if file.endswith(".jsonl"): - with open(file, "r", encoding='utf-8') as f: + with open(file, "r", encoding="utf-8") as f: data = [json.loads(line) for line in f] elif file.endswith(".json"): - with open(file, "r", encoding='utf-8') as f: + with open(file, "r", encoding="utf-8") as f: data = json.load(f) data = [item for sublist in data for item in sublist] elif file.endswith(".txt"): - with open(file, "r", encoding='utf-8') as f: + with open(file, "r", encoding="utf-8") as f: data = f.read() - chunks = [ - data[i:i + 512] for i in range(0, len(data), 512) - ] + chunks = [data[i : i + 512] for i in range(0, len(data), 512)] data = [{"content": chunk} for chunk in chunks] + elif file.endswith(".csv"): + df = pd.read_csv(file) + if "content" in df.columns: + data = df["content"].tolist() + else: + data = df.iloc[:, 0].tolist() else: raise ValueError(f"Unsupported file type: {file}") @@ -41,20 +47,13 @@ def count_tokens(file, tokenizer_name, data_frame): content = item token_count += len(tokenizer.encode_string(content)) - _update_data = [[ - str(token_count), - str(token_count * 50), - "N/A" - ]] + _update_data = [[str(token_count), str(token_count * 50), "N/A"]] try: - new_df = pd.DataFrame( - _update_data, - columns=data_frame.columns - ) + new_df = pd.DataFrame(_update_data, columns=data_frame.columns) data_frame = new_df - except Exception as e: # pylint: disable=broad-except + except Exception as e: # pylint: disable=broad-except print("[ERROR] DataFrame操作异常:", str(e)) return data_frame diff --git a/webui/examples/chunked_demo.json b/webui/examples/chunked_demo.json deleted file mode 100644 index ad7219a3..00000000 --- a/webui/examples/chunked_demo.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - [ - {"content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"} - ], - [ - {"content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"} - ], - [ - {"content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."} - ], - [ - {"content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} - ] -] diff --git a/webui/examples/csv_demo.csv b/webui/examples/csv_demo.csv new file mode 100644 index 00000000..11e6dde3 --- /dev/null +++ b/webui/examples/csv_demo.csv @@ -0,0 +1,5 @@ +content +"云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。" +"隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。" +"Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity." +"Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture." diff --git a/webui/examples/json_demo.json b/webui/examples/json_demo.json new file mode 100644 index 00000000..b496c16f --- /dev/null +++ b/webui/examples/json_demo.json @@ -0,0 +1,6 @@ +[ + {"content": "云南省农业科学院粮食作物研究所于2005年育成早熟品种云粳26号,该品种外观特点为: 颖尖无色、无芒,谷壳黄色,落粒性适中,米粒大,有香味,食味品质好,高抗稻瘟病,适宜在云南中海拔 1 500∼1 800 m 稻区种植。2012年被农业部列为西南稻区农业推广主导品种。"}, + {"content": "隆两优1212 于2017 年引入福建省龙岩市长汀县试种,在长汀县圣丰家庭农场(河田镇南塘村)种植,土壤肥力中等、排灌方便[2],试种面积 0.14 hm^2 ,作烟后稻种植,6 月15 日机播,7月5 日机插,10 月21 日成熟,产量 8.78 t/hm^2 。2018 和2019 年分别在长汀润丰优质稻专业合作社(濯田镇永巫村)和长汀县绿丰优质稻专业合作社(河田镇中街村)作烟后稻进一步扩大示范种植,均采用机播机插机收。2018 年示范面积 4.00 hm^2 ,平均产量 8.72 t/hm^2 ;2019 年示范面积 13.50 hm^2 ,平均产量 8.74 t/hm^2 。经3 a 试种、示范,隆两优1212 表现出分蘖力强、抗性好、抽穗整齐、后期转色好、生育期适中、产量高、适应性好等特点,可作为烟后稻在长汀县推广种植。"}, + {"content": "Grain size is one of the key factors determining grain yield. However, it remains largely unknown how grain size is regulated by developmental signals. Here, we report the identification and characterization of a dominant mutant big grain1 (Bg1-D) that shows an extra-large grain phenotype from our rice T-DNA insertion population. Overexpression of BG1 leads to significantly increased grain size, and the severe lines exhibit obviously perturbed gravitropism. In addition, the mutant has increased sensitivities to both auxin and N-1-naphthylphthalamic acid, an auxin transport inhibitor, whereas knockdown of BG1 results in decreased sensitivities and smaller grains. Moreover, BG1 is specifically induced by auxin treatment, preferentially expresses in the vascular tissue of culms and young panicles, and encodes a novel membrane-localized protein, strongly suggesting its role in regulating auxin transport. Consistent with this finding, the mutant has increased auxin basipetal transport and altered auxin distribution, whereas the knockdown plants have decreased auxin transport. Manipulation of BG1 in both rice and Arabidopsis can enhance plant biomass, seed weight, and yield. Taking these data together, we identify a novel positive regulator of auxin response and transport in a crop plant and demonstrate its role in regulating grain size, thus illuminating a new strategy to improve plant productivity."}, + {"content": "Tiller angle, an important component of plant architecture, greatly influences the grain yield of rice (Oryza sativa L.). Here, we identified Tiller Angle Control 4 (TAC4) as a novel regulator of rice tiller angle. TAC4 encodes a plant-specific, highly conserved nuclear protein. The loss of TAC4 function leads to a significant increase in the tiller angle. TAC4 can regulate rice shoot\n\ngravitropism by increasing the indole acetic acid content and affecting the auxin distribution. A sequence analysis revealed that TAC4 has undergone a bottleneck and become fixed in indica cultivars during domestication and improvement. Our findings facilitate an increased understanding of the regulatory mechanisms of tiller angle and also provide a potential gene resource for the improvement of rice plant architecture."} +] diff --git a/webui/examples/raw_demo.jsonl b/webui/examples/jsonl_demo.jsonl similarity index 100% rename from webui/examples/raw_demo.jsonl rename to webui/examples/jsonl_demo.jsonl From 1696f58b8ec974e639a57ba7eb7d7bbf82e7b209 Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:58:05 +0800 Subject: [PATCH 2/4] Update graphgen/models/reader/csv_reader.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- graphgen/models/reader/csv_reader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index f46c357e..5844a2e1 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,11 +1,10 @@ from typing import Any, Dict, List from graphgen.bases.base_reader import BaseReader - +import pandas as pd class CsvReader(BaseReader): def read(self, file_path: str) -> List[Dict[str, Any]]: - import pandas as pd df = pd.read_csv(file_path) if self.text_column not in df.columns: From ad3060d8c9b5441a2dead3ec7f4b70df34792c57 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 23 Sep 2025 16:59:48 +0800 Subject: [PATCH 3/4] fix: fix import --- graphgen/models/reader/csv_reader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index 5844a2e1..05960082 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,8 +1,10 @@ from typing import Any, Dict, List -from graphgen.bases.base_reader import BaseReader import pandas as pd +from graphgen.bases.base_reader import BaseReader + + class CsvReader(BaseReader): def read(self, file_path: str) -> List[Dict[str, Any]]: From 43ba53a45b41e4c994b1b10b7e421f54680fc58b Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Tue, 23 Sep 2025 17:02:20 +0800 Subject: [PATCH 4/4] fix: use logger instead of print --- graphgen/models/reader/jsonl_reader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py index d923d8eb..8904bbb3 100644 --- a/graphgen/models/reader/jsonl_reader.py +++ b/graphgen/models/reader/jsonl_reader.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class JsonlReader(BaseReader): @@ -18,5 +19,5 @@ def read(self, file_path: str) -> List[Dict[str, Any]]: f"Missing '{self.text_column}' in document: {doc}" ) except json.JSONDecodeError as e: - print(f"Error decoding JSON line: {line}. Error: {e}") + logger.error("Error decoding JSON line: %s. Error: %s", line, e) return docs