diff --git a/src/codegen/sdk/codebase/codebase_context.py b/src/codegen/sdk/codebase/codebase_context.py
index 957efe708..c4337683d 100644
--- a/src/codegen/sdk/codebase/codebase_context.py
+++ b/src/codegen/sdk/codebase/codebase_context.py
@@ -474,7 +474,12 @@ def _process_diff_files(self, files_to_sync: Mapping[SyncType, list[Path]], incr
         task = self.progress.begin("Adding new files", count=len(files_to_sync[SyncType.ADD]))
         for idx, filepath in enumerate(files_to_sync[SyncType.ADD]):
             task.update(f"Adding {self.to_relative(filepath)}", count=idx)
-            content = self.io.read_text(filepath)
+            # Skip files whose contents cannot be decoded instead of letting the error propagate out of this loop.
+            try:
+                content = self.io.read_text(filepath)
+            except UnicodeDecodeError:
+                logger.warning(f"Can't read file at:{filepath} since it contains non-unicode characters. File will be ignored!")
+                continue
             # TODO: this is wrong with context changes
             if filepath.suffix in self.extensions:
                 file_cls = self.node_classes.file_cls
diff --git a/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_graph.py b/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_graph.py
index a5e8f5d5f..61c8e5534 100644
--- a/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_graph.py
+++ b/tests/unit/codegen/sdk/codebase/codebase_graph/test_codebase_graph.py
@@ -58,3 +58,41 @@ def __init__(self):
         assert len(import_resolution_edges) == 4
         assert len(file_contains_node_edges) == 14
         assert len(symbol_usage_edges) == 6
+
+
+def test_codebase_broken_file(tmpdir) -> None:
+    """A file holding Big5-HKSCS bytes must not break codebase construction; graph counts are asserted for the session."""
+    # language=python
+    content = """
+from some_file import x, y, z
+import numpy as np
+
+global_var_1 = 1
+global_var_2 = 2
+
+def foo():
+    return bar()
+
+def bar():
+    return 42
+
+class MyClass:
+    def __init__(self):
+        pass
+
+class MySubClass(MyClass):
+    def __init__(self):
+        super().__init__()
+        pass
+    """
+    content_broken = bytes("你好", "big5hkscs")
+    with get_codebase_session(tmpdir=tmpdir, files={"test.py": content, "test2.py": content_broken}) as codebase:
+        assert codebase is not None
+        assert isinstance(codebase.ctx, CodebaseContext)
+        import_resolution_edges = [edge for edge in codebase.ctx.edges if edge[2].type == EdgeType.IMPORT_SYMBOL_RESOLUTION]
+        file_contains_node_edges = list(itertools.chain.from_iterable(file.get_nodes() for file in codebase.files))
+        symbol_usage_edges = [edge for edge in codebase.ctx.edges if edge[2].type == EdgeType.SYMBOL_USAGE]
+
+        assert len(import_resolution_edges) == 4
+        assert len(file_contains_node_edges) == 14
+        assert len(symbol_usage_edges) == 6