diff --git a/src/datacustomcode/scan.py b/src/datacustomcode/scan.py index 7ca63b1..5e50c5d 100644 --- a/src/datacustomcode/scan.py +++ b/src/datacustomcode/scan.py @@ -368,6 +368,8 @@ def update_config(file_path: str) -> dict[str, Any]: base_directory = find_base_directory(file_path) package_type = get_package_type(base_directory) + existing_config["entryPoint"] = os.path.basename(file_path) + if package_type == "script": existing_config["dataspace"] = get_dataspace(existing_config) output = scan_file(file_path) diff --git a/tests/test_scan.py b/tests/test_scan.py index 528f728..5d7e2ff 100644 --- a/tests/test_scan.py +++ b/tests/test_scan.py @@ -565,6 +565,134 @@ def test_raises_error_on_invalid_json(self): if os.path.exists(config_path): os.remove(config_path) + def test_update_config_updates_entrypoint(self): + """ + Test that update_config() updates the entryPoint field + when scanning a renamed file. + """ + content = textwrap.dedent( + """ + from datacustomcode.client import Client + + client = Client() + df = client.read_dlo("input_dlo") + client.write_to_dlo("output_dlo", df, "overwrite") + """ + ) + + temp_path = create_test_script(content) + file_dir = os.path.dirname(temp_path) + config_path = os.path.join(file_dir, "config.json") + + try: + sdk_config_path = create_sdk_config(file_dir, "script") + + initial_config = { + "sdkVersion": "1.0.0", + "entryPoint": "old_entrypoint.py", + "dataspace": "custom_dataspace", + "permissions": { + "read": {"dlo": ["old_dlo"]}, + "write": {"dlo": ["old_output"]}, + }, + } + with open(config_path, "w") as f: + json.dump(initial_config, f) + + updated_config = update_config(temp_path) + + assert updated_config["entryPoint"] == os.path.basename(temp_path) + assert updated_config["dataspace"] == "custom_dataspace" + assert updated_config["permissions"]["read"]["dlo"] == ["input_dlo"] + assert updated_config["permissions"]["write"]["dlo"] == ["output_dlo"] + + finally: + os.remove(temp_path) + if os.path.exists(config_path): + os.remove(config_path) + if os.path.exists(sdk_config_path): + os.remove(sdk_config_path) + os.rmdir(os.path.dirname(sdk_config_path)) + + def test_update_entrypoint_with_absolute_path(self): + """Test that entryPoint uses basename even when file_path is absolute.""" + content = textwrap.dedent( + """ + from datacustomcode.client import Client + + client = Client() + df = client.read_dlo("input_dlo") + client.write_to_dlo("output_dlo", df, "overwrite") + """ + ) + + temp_path = create_test_script(content) + assert os.path.isabs(temp_path), "Test requires absolute path" + + file_dir = os.path.dirname(temp_path) + config_path = os.path.join(file_dir, "config.json") + + try: + sdk_config_path = create_sdk_config(file_dir, "script") + + initial_config = { + "sdkVersion": "1.0.0", + "entryPoint": "old.py", + "dataspace": "default", + "permissions": {"read": {}, "write": {}}, + } + with open(config_path, "w") as f: + json.dump(initial_config, f) + + updated_config = update_config(temp_path) + + assert updated_config["entryPoint"] == os.path.basename(temp_path) + assert "/" not in updated_config["entryPoint"] + + finally: + os.remove(temp_path) + if os.path.exists(config_path): + os.remove(config_path) + if os.path.exists(sdk_config_path): + os.remove(sdk_config_path) + os.rmdir(os.path.dirname(sdk_config_path)) + + def test_update_entrypoint_preserves_function_type(self): + """Test that entryPoint update works for 'function' package type.""" + content = textwrap.dedent( + """ + from datacustomcode.client import Client + + def my_function(event, context): + return {"statusCode": 200} + """ + ) + + temp_path = create_test_script(content) + file_dir = os.path.dirname(temp_path) + config_path = os.path.join(file_dir, "config.json") + + try: + sdk_config_path = create_sdk_config(file_dir, "function") + + initial_config = { + "entryPoint": "old_function.py", + } + with open(config_path, "w") as f: + json.dump(initial_config, f) + + updated_config = update_config(temp_path) + + assert updated_config["entryPoint"] == os.path.basename(temp_path) + + finally: + os.remove(temp_path) + if os.path.exists(config_path): + os.remove(config_path) + if os.path.exists(sdk_config_path): + os.remove(sdk_config_path) + os.rmdir(os.path.dirname(sdk_config_path)) + class TestDataAccessLayerCalls: """Tests for the DataAccessLayerCalls class directly."""