feat(LAB-2944): add file data conversion prior to import for new llm format #1714

Merged
Changes from 2 commits
4 changes: 4 additions & 0 deletions src/kili/services/asset_import/exceptions.py
@@ -13,6 +13,10 @@ class ImportValidationError(Exception):
"""Raised when data given to import does not follow a right format."""


class ImportFileConversionError(Exception):
"""Raised when an error occurs during processing a llm file for conversion."""


class UploadFromLocalDataForbiddenError(Exception):
"""Raised when data given to import does not follow a right format."""

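Illustrative note, not part of the diff: a minimal sketch of how calling code could surface the new exception, assuming the module is importable at the path shown above. The load_chat_file helper and the file path are hypothetical; they only mirror the error-wrapping pattern used in import_assets further down.

import json

from kili.services.asset_import.exceptions import ImportFileConversionError


def load_chat_file(file_path: str):
    """Hypothetical helper mirroring the try/except wrapping used in import_assets."""
    try:
        with open(file_path, encoding="utf-8") as file:
            return json.load(file)
    except Exception as exception:
        raise ImportFileConversionError(f"Error processing file: {exception}") from exception


try:
    load_chat_file("missing_or_invalid.json")  # hypothetical path
except ImportFileConversionError as error:
    print(f"LLM file conversion failed: {error}")
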
82 changes: 82 additions & 0 deletions src/kili/services/asset_import/helpers.py
@@ -0,0 +1,82 @@
"""Helpers for the asset_import."""

import warnings


def is_chat_format(data, required_keys):
    """Checks if LLM file data is in chat format."""
    if isinstance(data, dict):
        return False

    if not isinstance(data, list):
        warnings.warn("JSON file is not an array.")
        return False

    # Check each item in the array
    for item in data:
        # Ensure each item is a dictionary with the required keys
        if not isinstance(item, dict) or not required_keys.issubset(item.keys()):
            missing_keys = required_keys - set(item.keys())
            warnings.warn(f"Array item missing keys: {missing_keys}", stacklevel=3)
            return False
    return True


def process_json(data):
    """Processes the LLM file data: converts it to the Kili format if it is in chat format."""
    # Initialize the transformed structure
    transformed_data = {"prompts": [], "type": "markdown", "version": "0.1"}

    # Temporary variables for processing
    current_prompt = None
    completions = []
    models = []  # To store models for determining the last two
    item_ids = []  # To store all item IDs for concatenation
    chat_id = None

    for item in data:
        chat_id = item.get("chat_id", None)
        item_ids.append(item["id"])

        # Check if the model is null (indicating a prompt)
        if item["model"] is None:
            # If there's an existing prompt being processed, add it to the prompts list
            if current_prompt is not None:
                transformed_data["prompts"].append(
                    {
                        "completions": completions,
                        "prompt": current_prompt["content"],
                    }
                )
                completions = []  # Reset completions for the next prompt

            # Update the current prompt
            current_prompt = item
        else:
            # Add completion to the current prompt
            completions.append(
                {
                    "content": item["content"],
                    "title": item["role"],
                }
            )
            # Collect model for this item
            models.append(item["model"])

    # Add the last prompt if it exists
    if current_prompt is not None:
        transformed_data["prompts"].append(
            {
                "completions": completions,
                "prompt": current_prompt["content"],
            }
        )

    # Prepare additional_json_metadata
    additional_json_metadata = {
        "chat_id": chat_id,
        "models": "_".join(models[-2:]),  # Join the last two models
        "chat_item_ids": "_".join(item_ids),  # Concatenate all item IDs
    }

    return transformed_data, additional_json_metadata
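Illustrative sketch, not part of the diff: how the two helpers above behave on a minimal chat-format payload, assuming the module is importable at the path shown in this file header. The items and IDs are made up; the expected outputs in the comments follow from the code above.

from kili.services.asset_import.helpers import is_chat_format, process_json

REQUIRED_KEYS = {"role", "content", "id", "chat_id", "model"}

# Hypothetical conversation: one prompt (model is None) and two model completions.
chat_items = [
    {"id": "item-1", "chat_id": "chat-42", "role": "user",
     "content": "What is Kili?", "model": None},
    {"id": "item-2", "chat_id": "chat-42", "role": "assistant",
     "content": "Kili is a data labeling platform.", "model": "model-a"},
    {"id": "item-3", "chat_id": "chat-42", "role": "assistant",
     "content": "Kili helps teams annotate training data.", "model": "model-b"},
]

if is_chat_format(chat_items, REQUIRED_KEYS):
    content, metadata = process_json(chat_items)
    # content == {"prompts": [{"completions": [<two completions>],
    #                          "prompt": "What is Kili?"}],
    #             "type": "markdown", "version": "0.1"}
    # metadata == {"chat_id": "chat-42", "models": "model-a_model-b",
    #              "chat_item_ids": "item-1_item-2_item-3"}
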
64 changes: 62 additions & 2 deletions src/kili/services/asset_import/llm.py
@@ -12,14 +12,16 @@
    BatchParams,
    ContentBatchImporter,
)
from .exceptions import ImportValidationError
from .exceptions import ImportFileConversionError, ImportValidationError
from .helpers import is_chat_format, process_json
from .types import AssetLike


class LLMDataType(Enum):
    """LLM data type."""

    DICT = "DICT"
    LIST = "LIST"
    LOCAL_FILE = "LOCAL_FILE"
    HOSTED_FILE = "HOSTED_FILE"

@@ -45,28 +47,86 @@ def get_data_type(assets: List[AssetLike]) -> LLMDataType:
            return LLMDataType.LOCAL_FILE
        if all(isinstance(content, dict) for content in content_array):
            return LLMDataType.DICT
        if all(isinstance(content, list) for content in content_array):
            return LLMDataType.LIST
        raise ImportValidationError("Invalid value in content for LLM project.")

    @staticmethod
    def transform_asset_content(asset_content, json_metadata):
        """Transform asset content."""
        content, additional_json_metadata = process_json(asset_content)
        transformed_asset_content = json.dumps(content).encode("utf-8")

        json_metadata_dict = {}
        if json_metadata and isinstance(json_metadata, str):
            json_metadata_dict = json.loads(json_metadata)
        elif json_metadata:
            json_metadata_dict = json_metadata

        merged_json_metadata = {
            **json_metadata_dict,
            **additional_json_metadata,
        }
        changed_json_metadata = json.dumps(merged_json_metadata)

        return transformed_asset_content, changed_json_metadata

    def import_assets(self, assets: List[AssetLike]):
        """Import LLM assets into Kili."""
        self._check_upload_is_allowed(assets)
        data_type = self.get_data_type(assets)
        assets = self.filter_duplicate_external_ids(assets)

        if data_type == LLMDataType.LOCAL_FILE:
            assets = self.filter_local_assets(assets, self.raise_error)
            batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
            batch_importer = ContentBatchImporter(
                self.kili, self.project_params, batch_params, self.pbar
            )
            for asset in assets:
file_path = asset.get("content", None)
json_metadata = asset.get("json_metadata", "{}")
if file_path and isinstance(file_path, str):
try:
with open(file_path, encoding="utf-8") as file:
data = json.load(file)

if is_chat_format(data, {"role", "content", "id", "chat_id", "model"}):
(
asset["content"],
asset["json_metadata"],
) = self.transform_asset_content(data, json_metadata)

batch_importer = JSONBatchImporter(
self.kili, self.project_params, batch_params, self.pbar
)

except Exception as exception:
raise ImportFileConversionError(
f"Error processing file: {exception}"
) from exception

elif data_type == LLMDataType.HOSTED_FILE:
batch_params = BatchParams(is_hosted=True, is_asynchronous=False)
batch_importer = ContentBatchImporter(
self.kili, self.project_params, batch_params, self.pbar
)
elif data_type == LLMDataType.DICT:
elif data_type in (LLMDataType.DICT, LLMDataType.LIST):
for asset in assets:
if "content" in asset and isinstance(asset["content"], dict):
asset["content"] = json.dumps(asset["content"]).encode("utf-8")
elif (
"content" in asset
and isinstance(asset["content"], list)
and is_chat_format(
asset["content"], {"role", "content", "id", "chat_id", "model"}
)
):
json_metadata = asset.get("json_metadata", "{}")
asset["content"], asset["json_metadata"] = self.transform_asset_content(
asset["content"], json_metadata
)

batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
batch_importer = JSONBatchImporter(
self.kili, self.project_params, batch_params, self.pbar
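Illustrative sketch, not part of the diff: the conversion and metadata merge that transform_asset_content applies when an asset's content arrives in memory as a chat-format list (the new LIST branch above). The asset values are made up, and the merge is redone inline so the sketch stays self-contained; only the helpers shown earlier in this diff are imported.

import json

from kili.services.asset_import.helpers import is_chat_format, process_json

REQUIRED_KEYS = {"role", "content", "id", "chat_id", "model"}

# Hypothetical in-memory asset, as handled by the DICT/LIST branch above.
asset = {
    "content": [
        {"id": "p1", "chat_id": "c7", "role": "user",
         "content": "Summarize this document.", "model": None},
        {"id": "c1", "chat_id": "c7", "role": "assistant",
         "content": "Here is a summary.", "model": "model-a"},
    ],
    # User metadata may be a JSON string or a dict; both are accepted upstream.
    "json_metadata": '{"source": "demo"}',
}

if is_chat_format(asset["content"], REQUIRED_KEYS):
    content, extra_metadata = process_json(asset["content"])

    raw_metadata = asset.get("json_metadata", "{}")
    user_metadata = json.loads(raw_metadata) if isinstance(raw_metadata, str) else raw_metadata

    asset["content"] = json.dumps(content).encode("utf-8")
    asset["json_metadata"] = json.dumps({**user_metadata, **extra_metadata})
    # json_metadata now holds "source" plus the generated chat_id, models, and chat_item_ids.

Keeping the generated chat_id, models, and chat_item_ids in json_metadata preserves the conversation's provenance on the asset without altering the prompt/completion payload itself.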