Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

long time persist model parameters #15

Open
cyrushu opened this issue Nov 22, 2023 · 0 comments
Open

long time persist model parameters #15

cyrushu opened this issue Nov 22, 2023 · 0 comments

Comments

@cyrushu
Copy link

cyrushu commented Nov 22, 2023

Dear developers,

Here is a solution for persisting model parameters long-term, which would save some network traffic by avoiding repeated downloads. It would be better to have a SHA-256 check inside the cache-check process.

diff --git a/chroma/utility/api.py b/chroma/utility/api.py
index 902b776..ce996c8 100644
--- a/chroma/utility/api.py
+++ b/chroma/utility/api.py
@@ -21,7 +21,11 @@ import requests

 import chroma

-ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__))
+# SETTING CHROMA_ROOT_DIR or use default directory: ~/.config/chroma
+ROOT_DIR = os.environ.get(
+    "CHROMA_ROOT_DIR",
+    os.path.join(os.path.expanduser("~"), ".config", "chroma"))
+os.makedirs(ROOT_DIR, exist_ok=True)


 def register_key(key: str, key_directory=ROOT_DIR) -> None:
@@ -92,11 +96,8 @@ def download_from_generate(

     # Create a hash of the URL + weight name to determine the path for the cached/temporary file
     url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest()
-    temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash)
-    destination = os.path.join(temp_dir, "weights.pt")
-
-    # Ensure the directory exists
-    os.makedirs(temp_dir, exist_ok=True)
+    os.makedirs(os.path.join(ROOT_DIR, "weights"), exist_ok=True)
+    destination = os.path.join(ROOT_DIR, "weights", f"{url_hash}.pt")

     # Check if cache exists
     cache_exists = os.path.exists(destination)
@@ -117,8 +118,14 @@ def download_from_generate(
     response = requests.get(base_url, params=params)
     response.raise_for_status()  # Raise an error for HTTP errors

-    with open(destination, "wb") as file:
-        file.write(response.content)
+    # Write into temp_file
+    temp_file = tempfile.TemporaryFile()
+    temp_file.write(response.content)
+
+    # Write into cached destination
+    with open(destination, "wb") as f:
+        temp_file.seek(0)
+        f.write(temp_file.read())

     print(f"Data saved to {destination}")
     return destination
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant