Skip to content

Commit d231311

Browse files
committed
fix: refactored get_project_dir helper function
1 parent 24b9dc9 commit d231311

File tree

4 files changed

+52
-20
lines changed

4 files changed

+52
-20
lines changed

gptme/dirs.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import subprocess
23
from pathlib import Path
34

45
from platformdirs import user_config_dir, user_data_dir
@@ -40,6 +41,51 @@ def get_logs_dir() -> Path:
4041
return path
4142

4243

44+
def get_project_gptme_dir() -> Path | None:
45+
"""
46+
Walks up the directory tree from the working dir to find the project root,
47+
which is a directory containing a `gptme.toml` file.
48+
Or if none exists, the first parent directory with a git repo.
49+
50+
Meant to be used in scripts/tools to detect a suitable location to store agent data/logs.
51+
"""
52+
path = Path.cwd()
53+
while path != Path("/"):
54+
if (path / "gptme.toml").exists():
55+
return path
56+
path = path.parent
57+
58+
# if no gptme.toml file was found, look for a git repo
59+
return _get_project_git_dir_walk()
60+
61+
62+
def get_project_git_dir() -> Path | None:
63+
return _get_project_git_dir_walk()
64+
65+
66+
def _get_project_git_dir_walk() -> Path | None:
67+
# if no gptme.toml file was found, look for a git repo
68+
path = Path.cwd()
69+
while path != Path("/"):
70+
if (path / ".git").exists():
71+
return path
72+
path = path.parent
73+
return None
74+
75+
76+
def _get_project_git_dir_call() -> Path | None:
77+
try:
78+
projectdir = subprocess.run(
79+
["git", "rev-parse", "--show-toplevel"],
80+
capture_output=True,
81+
text=True,
82+
check=True,
83+
).stdout.strip()
84+
return Path(projectdir)
85+
except subprocess.CalledProcessError:
86+
return None
87+
88+
4389
def _init_paths():
4490
# create all paths
4591
for path in [get_config_dir(), get_data_dir(), get_logs_dir()]:

gptme/prompts.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
from .__version__ import __version__
1616
from .config import get_config, get_project_config
17+
from .dirs import get_project_git_dir
1718
from .message import Message
1819
from .tools import ToolFormat
19-
from .util import document_prompt_function, get_project_dir
20+
from .util import document_prompt_function
2021

2122
PromptType = Literal["full", "short"]
2223

@@ -142,7 +143,7 @@ def prompt_gptme(interactive: bool) -> Generator[Message, None, None]:
142143
Proceed directly with the most appropriate actions to complete the task.
143144
""".strip()
144145

145-
projectdir = get_project_dir()
146+
projectdir = get_project_git_dir()
146147
project_config = get_project_config(projectdir)
147148
base_prompt = (
148149
project_config.base_prompt
@@ -189,7 +190,7 @@ def prompt_project() -> Generator[Message, None, None]:
189190
"""
190191
Generate the project-specific prompt based on the current Git repository.
191192
"""
192-
projectdir = get_project_dir()
193+
projectdir = get_project_git_dir()
193194
if not projectdir:
194195
return
195196

gptme/tools/rag.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@
4242
from pathlib import Path
4343

4444
from ..config import RagConfig, get_project_config
45+
from ..dirs import get_project_gptme_dir
4546
from ..llm import _chat_complete
4647
from ..message import Message
47-
from ..util import get_project_dir
4848
from .base import ToolSpec, ToolUse
4949

5050
logger = logging.getLogger(__name__)
@@ -163,7 +163,7 @@ def init() -> ToolSpec:
163163
return replace(tool, available=False)
164164

165165
# Check project configuration
166-
project_dir = get_project_dir()
166+
project_dir = get_project_gptme_dir()
167167
if project_dir and (config := get_project_config(project_dir)):
168168
enabled = config.rag.enabled
169169
if not enabled:

gptme/util/__init__.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import logging
77
import re
88
import shutil
9-
import subprocess
109
import sys
1110
import textwrap
1211
from datetime import datetime, timedelta
@@ -216,17 +215,3 @@ def get_installed_programs(candidates: tuple[str, ...]) -> set[str]:
216215
if shutil.which(candidate) is not None:
217216
installed.add(candidate)
218217
return installed
219-
220-
221-
def get_project_dir() -> Path | None:
222-
try:
223-
projectdir = subprocess.run(
224-
["git", "rev-parse", "--show-toplevel"],
225-
capture_output=True,
226-
text=True,
227-
check=True,
228-
).stdout.strip()
229-
return Path(projectdir)
230-
except subprocess.CalledProcessError:
231-
logger.debug("Unable to determine Git repository root.")
232-
return None

0 commit comments

Comments
 (0)