From 3d758a9c9458631e99606169e24df777780a016e Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 1 Jun 2024 11:58:59 -0700 Subject: [PATCH] feat: persistent valves --- main.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/main.py b/main.py index ebd41c7d..dd3381f5 100644 --- a/main.py +++ b/main.py @@ -124,8 +124,37 @@ async def load_modules_from_directory(directory): if filename.endswith(".py"): module_name = filename[:-3] # Remove the .py extension module_path = os.path.join(directory, filename) + + # Create subfolder matching the filename without the .py extension + subfolder_path = os.path.join(directory, module_name) + if not os.path.exists(subfolder_path): + os.makedirs(subfolder_path) + logging.info(f"Created subfolder: {subfolder_path}") + + # Create a valves.json file if it doesn't exist + valves_json_path = os.path.join(subfolder_path, "valves.json") + if not os.path.exists(valves_json_path): + with open(valves_json_path, "w") as f: + json.dump({}, f) + logging.info(f"Created valves.json in: {subfolder_path}") + pipeline = await load_module_from_path(module_name, module_path) if pipeline: + # Overwrite pipeline.valves with values from valves.json + if os.path.exists(valves_json_path): + with open(valves_json_path, "r") as f: + valves_json = json.load(f) + ValvesModel = pipeline.valves.__class__ + # Create a ValvesModel instance using default values and overwrite with valves_json + combined_valves = { + **pipeline.valves.model_dump(), + **valves_json, + } + valves = ValvesModel(**combined_valves) + pipeline.valves = valves + + logging.info(f"Updated valves for module: {module_name}") + pipeline_id = pipeline.id if hasattr(pipeline, "id") else module_name PIPELINE_MODULES[pipeline_id] = pipeline PIPELINE_NAMES[pipeline_id] = module_name @@ -441,6 +470,14 @@ async def update_valves(pipeline_id: str, form_data: dict): valves = ValvesModel(**form_data) pipeline.valves = valves + # Determine the directory path for the valves.json file + subfolder_path = os.path.join(PIPELINES_DIR, PIPELINE_NAMES[pipeline_id]) + valves_json_path = os.path.join(subfolder_path, "valves.json") + + # Save the updated valves data back to the valves.json file + with open(valves_json_path, "w") as f: + json.dump(valves.model_dump(), f) + if hasattr(pipeline, "on_valves_updated"): await pipeline.on_valves_updated() except Exception as e: