Skip to content

Commit

Permalink
feat(LAB-2944): add file data conversion prior to import for new llm …
Browse files Browse the repository at this point in the history
…format (#1714)
  • Loading branch information
BlueGrizzliBear committed Jun 24, 2024
1 parent dc39a8d commit 88338cd
Show file tree
Hide file tree
Showing 4 changed files with 367 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/kili/services/asset_import/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ class ImportValidationError(Exception):
"""Raised when data given to import does not follow a right format."""


class ImportFileConversionError(Exception):
    """Raised when an error occurs while converting an LLM file to the Kili format."""


class UploadFromLocalDataForbiddenError(Exception):
    """Raised when uploading assets from local data is not allowed for the project."""

Expand Down
98 changes: 98 additions & 0 deletions src/kili/services/asset_import/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Helpers for the asset_import."""

import warnings


def is_chat_format(data, required_keys):
    """Check whether llm file data is in chat format.

    Args:
        data: Parsed JSON payload of an llm file.
        required_keys: Set of keys every chat item must contain.

    Returns:
        True if ``data`` is a list whose items are all dicts containing
        ``required_keys``; False if ``data`` is a dict (assumed to already be
        in Kili format) or not a list at all (a warning is emitted).

    Raises:
        ValueError: If a list item is not a dict or is missing required keys.
    """
    if isinstance(data, dict):
        # A dict payload is assumed to already be in the Kili asset format.
        return False

    if not isinstance(data, list):
        warnings.warn("Json file is not an array.")
        return False

    # Check each item in the array
    for item in data:
        # Check dict-ness first: calling .keys() on a non-dict item would
        # raise AttributeError instead of the intended ValueError.
        if not isinstance(item, dict):
            raise ValueError(f"Chat item is not an object : {item}")
        if not required_keys.issubset(item.keys()):
            missing_keys = required_keys - set(item.keys())
            raise ValueError(f"Chat item missing keys : {missing_keys}")
    return True


def process_json(data):
    """Process the llm file data: convert it to the Kili format if chat format is present.

    Items with a null ``model`` are treated as user prompts; every other item
    is a completion attached to the most recent prompt.

    Args:
        data: List of chat items, each a dict with at least the ``id``,
            ``chat_id``, ``model``, ``content`` and ``role`` keys.

    Returns:
        A tuple ``(transformed_data, additional_json_metadata)`` where
        ``transformed_data`` is the Kili-format payload and
        ``additional_json_metadata`` carries the chat id, the last two models,
        and the concatenated chat item ids.

    Raises:
        ValueError: If an item's content is null, a completion's role is null,
            or no prompt item is found in the payload.
    """
    # Initialize the transformed structure
    transformed_data = {"prompts": [], "type": "markdown", "version": "0.1"}

    # Temporary variables for processing
    current_prompt = None  # prompt item currently being accumulated
    completions = []  # completions belonging to current_prompt
    models = []  # To store models for determining the last two
    item_ids = []  # To store all item IDs for concatenation
    chat_id = None  # the last seen chat_id wins

    for item in data:
        chat_id = item.get("chat_id", None)
        if item["id"] is not None:
            item_ids.append(item["id"])
        else:
            warnings.warn(f"No id value for chat item {item}.")

        if item["content"] is None:
            raise ValueError("Chat item content cannot be null.")

        # A null model marks a user prompt; anything else is a completion.
        if item["model"] is None:
            # Flush the prompt being processed (if any) before starting a new one.
            if current_prompt is not None:
                transformed_data["prompts"].append(
                    {
                        "completions": completions,
                        "prompt": current_prompt["content"],
                    }
                )
                completions = []  # Reset completions for the next prompt

            # Update the current prompt
            current_prompt = item
        else:
            if item["role"] is None:
                raise ValueError("Chat item role cannot be null.")

            # Add completion to the current prompt
            completions.append(
                {
                    "content": item["content"],
                    "title": item["role"],
                }
            )
            # Collect model for this item
            models.append(item["model"])

    if current_prompt is None:
        raise ValueError(
            "No user prompt found in payload ('model' key set to None) : need at least one."
        )

    # Flush the last prompt; current_prompt is guaranteed non-None by the check above.
    transformed_data["prompts"].append(
        {
            "completions": completions,
            "prompt": current_prompt["content"],
        }
    )

    chat_item_ids = "_".join(item_ids)

    # Prepare additional_json_metadata
    additional_json_metadata = {
        "chat_id": chat_id,
        "models": "_".join(models[-2:]),  # Join the last two models
        "chat_item_ids": chat_item_ids,  # Concatenate all item IDs
        "text": f"Chat_id: {chat_id}\n\nChat_item_ids: {chat_item_ids}",
    }

    return transformed_data, additional_json_metadata
64 changes: 62 additions & 2 deletions src/kili/services/asset_import/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
BatchParams,
ContentBatchImporter,
)
from .exceptions import ImportValidationError
from .exceptions import ImportFileConversionError, ImportValidationError
from .helpers import is_chat_format, process_json
from .types import AssetLike


class LLMDataType(Enum):
    """Kind of content provided for an LLM asset import."""

    DICT = "DICT"  # content already given as a parsed dict payload
    LIST = "LIST"  # content given as a list (chat items)
    LOCAL_FILE = "LOCAL_FILE"  # content is a path to a local file
    HOSTED_FILE = "HOSTED_FILE"  # content is hosted remotely

Expand All @@ -45,28 +47,86 @@ def get_data_type(assets: List[AssetLike]) -> LLMDataType:
return LLMDataType.LOCAL_FILE
if all(isinstance(content, dict) for content in content_array):
return LLMDataType.DICT
if all(isinstance(content, list) for content in content_array):
return LLMDataType.LIST
raise ImportValidationError("Invalid value in content for LLM project.")

@staticmethod
def transform_asset_content(asset_content, json_metadata):
"""Transform asset content."""
content, additional_json_metadata = process_json(asset_content)
transformed_asset_content = json.dumps(content).encode("utf-8")

json_metadata_dict = {}
if json_metadata and isinstance(json_metadata, str):
json_metadata_dict = json.loads(json_metadata)
elif json_metadata:
json_metadata_dict = json_metadata

merged_json_metadata = {
**json_metadata_dict,
**additional_json_metadata,
}
changed_json_metadata = json.dumps(merged_json_metadata)

return transformed_asset_content, changed_json_metadata

def import_assets(self, assets: List[AssetLike]):
"""Import LLM assets into Kili."""
self._check_upload_is_allowed(assets)
data_type = self.get_data_type(assets)
assets = self.filter_duplicate_external_ids(assets)

if data_type == LLMDataType.LOCAL_FILE:
assets = self.filter_local_assets(assets, self.raise_error)
batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
batch_importer = ContentBatchImporter(
self.kili, self.project_params, batch_params, self.pbar
)
for asset in assets:
file_path = asset.get("content", None)
json_metadata = asset.get("json_metadata", "{}")
if file_path and isinstance(file_path, str):
try:
with open(file_path, encoding="utf-8") as file:
data = json.load(file)

if is_chat_format(data, {"role", "content", "id", "chat_id", "model"}):
(
asset["content"],
asset["json_metadata"],
) = self.transform_asset_content(data, json_metadata)

batch_importer = JSONBatchImporter(
self.kili, self.project_params, batch_params, self.pbar
)

except Exception as exception:
raise ImportFileConversionError(
f"Error processing file: {exception}"
) from exception

elif data_type == LLMDataType.HOSTED_FILE:
batch_params = BatchParams(is_hosted=True, is_asynchronous=False)
batch_importer = ContentBatchImporter(
self.kili, self.project_params, batch_params, self.pbar
)
elif data_type == LLMDataType.DICT:
elif data_type in (LLMDataType.DICT, LLMDataType.LIST):
for asset in assets:
if "content" in asset and isinstance(asset["content"], dict):
asset["content"] = json.dumps(asset["content"]).encode("utf-8")
elif (
"content" in asset
and isinstance(asset["content"], list)
and is_chat_format(
asset["content"], {"role", "content", "id", "chat_id", "model"}
)
):
json_metadata = asset.get("json_metadata", "{}")
asset["content"], asset["json_metadata"] = self.transform_asset_content(
asset["content"], json_metadata
)

batch_params = BatchParams(is_hosted=False, is_asynchronous=False)
batch_importer = JSONBatchImporter(
self.kili, self.project_params, batch_params, self.pbar
Expand Down
Loading

0 comments on commit 88338cd

Please sign in to comment.