diff --git a/examples/code_embedding/main.py b/examples/code_embedding/main.py index 053eacf9..65c3943e 100644 --- a/examples/code_embedding/main.py +++ b/examples/code_embedding/main.py @@ -1,7 +1,6 @@ from dotenv import load_dotenv from psycopg_pool import ConnectionPool from pgvector.psycopg import register_vector -from typing import Any import functools import cocoindex import os @@ -9,12 +8,6 @@ import numpy as np -@cocoindex.op.function() -def extract_extension(filename: str) -> str: - """Extract the extension of a filename.""" - return os.path.splitext(filename)[1] - - @cocoindex.transform_flow() def code_to_embedding( text: cocoindex.DataSlice[str], @@ -53,10 +46,12 @@ def code_embedding_flow( code_embeddings = data_scope.add_collector() with data_scope["files"].row() as file: - file["extension"] = file["filename"].transform(extract_extension) + file["language"] = file["filename"].transform( + cocoindex.functions.DetectProgrammingLanguage() + ) file["chunks"] = file["content"].transform( cocoindex.functions.SplitRecursively(), - language=file["extension"], + language=file["language"], chunk_size=1000, min_chunk_size=300, chunk_overlap=300,