@cocoindex.op.function()
def extract_extension(filename: str) -> str:
    """Return the extension portion of ``filename``.

    The name is split with ``os.path.splitext``; the second element of
    that split (the suffix, including its leading dot, or an empty
    string when the name has no extension) is returned unchanged.
    """
    _stem, extension = os.path.splitext(filename)
    return extension