Skip to content

Commit 3a5321b

Browse files
committed
agent with rich display
1 parent e3cca57 commit 3a5321b

File tree

9 files changed

+761
-458
lines changed

9 files changed

+761
-458
lines changed

agent/commit0_utils.py renamed to agent/agent_utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66
from typing import List
77
import fitz
8+
import yaml
89

910
from agent.class_types import AgentConfig
1011

@@ -118,7 +119,7 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
118119
return "\n".join(filter(None, tree_string))
119120

120121

121-
def get_target_edit_files(target_dir: str) -> list[str]:
122+
def get_target_edit_files(target_dir: str, src_prefix: str) -> list[str]:
122123
"""Find the files with functions with the pass statement."""
123124
files = []
124125
for root, _, filenames in os.walk(target_dir):
@@ -131,7 +132,7 @@ def get_target_edit_files(target_dir: str) -> list[str]:
131132

132133
# Remove the base_dir prefix
133134
files = [file.replace(target_dir, "").lstrip("/") for file in files]
134-
135+
files = [src_prefix + file for file in files]
135136
# Only keep python files
136137
files = [file for file in files if file.endswith(".py")]
137138

@@ -308,3 +309,17 @@ def get_lint_cmd(repo_name: str, use_lint_info: bool) -> str:
308309
else:
309310
lint_cmd = ""
310311
return lint_cmd
312+
313+
314+
def write_agent_config(agent_config_file: str, agent_config: dict) -> None:
315+
"""Write the agent config to the file."""
316+
with open(agent_config_file, "w") as f:
317+
yaml.dump(agent_config, f)
318+
319+
320+
def read_yaml_config(config_file: str) -> dict:
321+
"""Read the yaml config from the file."""
322+
if not os.path.exists(config_file):
323+
raise FileNotFoundError(f"The config file '{config_file}' does not exist.")
324+
with open(config_file, "r") as f:
325+
return yaml.load(f, Loader=yaml.FullLoader)

agent/agents.py

Lines changed: 91 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
11
import sys
2-
import os
32
from abc import ABC, abstractmethod
43
from pathlib import Path
54
import logging
65

76
from aider.coders import Coder
87
from aider.models import Model
98
from aider.io import InputOutput
10-
from tenacity import retry, wait_exponential, RetryCallState, retry_if_exception_type
9+
import re
1110

1211

13-
class APIError(Exception):
14-
def __init__(self, status_code: int, message: str):
15-
self.status_code = status_code
16-
self.message = message
17-
super().__init__(f"API Error: {status_code} - {message}")
18-
19-
def handle_logging(logging_name: str, log_file: Path):
12+
def handle_logging(logging_name: str, log_file: Path) -> None:
13+
"""Handle logging for agent"""
2014
logger = logging.getLogger(logging_name)
2115
logger.setLevel(logging.INFO)
2216
logger.propagate = False
@@ -26,17 +20,32 @@ def handle_logging(logging_name: str, log_file: Path):
2620
)
2721
logger.addHandler(logger_handler)
2822

23+
24+
class AgentReturn(ABC):
25+
def __init__(self, log_file: Path):
26+
self.log_file = log_file
27+
self.last_cost = self.get_money_cost()
28+
29+
def get_money_cost(self) -> float:
30+
"""Get accumulated money cost from log file"""
31+
last_cost = 0.0
32+
with open(self.log_file, "r") as file:
33+
for line in file:
34+
if "Tokens:" in line and "Cost:" in line:
35+
match = re.search(
36+
r"Cost: \$\d+\.\d+ message, \$(\d+\.\d+) session", line
37+
)
38+
if match:
39+
last_cost = float(match.group(1))
40+
return last_cost
41+
42+
2943
class Agents(ABC):
30-
def __init__(self, max_iteration: int, retry_if_api_error_codes: tuple[int, ...] = (429, 503, 529)):
44+
def __init__(self, max_iteration: int):
3145
self.max_iteration = max_iteration
3246

33-
# error code 429 is rate limit exceeded for openai and anthropic
34-
# error code 503 is service overloaded for openai
35-
# error code 529 is service overloaded for anthropic
36-
self.retry_if_api_error_codes = retry_if_api_error_codes
37-
3847
@abstractmethod
39-
def run(self) -> None:
48+
def run(self) -> AgentReturn:
4049
"""Start agent"""
4150
raise NotImplementedError
4251

@@ -46,84 +55,81 @@ def __init__(self, max_iteration: int, model_name: str):
4655
super().__init__(max_iteration)
4756
self.model = Model(model_name)
4857

49-
@retry(
50-
wait=wait_exponential(multiplier=1, min=4, max=10),
51-
retry=retry_if_exception_type(APIError)
52-
)
5358
def run(
5459
self,
5560
message: str,
5661
test_cmd: str,
5762
lint_cmd: str,
5863
fnames: list[str],
5964
log_dir: Path,
60-
) -> None:
65+
) -> AgentReturn:
6166
"""Start aider agent"""
62-
try:
63-
if test_cmd:
64-
auto_test = True
65-
else:
66-
auto_test = False
67-
if lint_cmd:
68-
auto_lint = True
69-
else:
70-
auto_lint = False
71-
log_dir = log_dir.resolve()
72-
log_dir.mkdir(parents=True, exist_ok=True)
73-
input_history_file = log_dir / ".aider.input.history"
74-
chat_history_file = log_dir / ".aider.chat.history.md"
75-
76-
print(
77-
f"check {os.path.abspath(chat_history_file)} for prompts and lm generations",
78-
file=sys.stderr,
79-
)
80-
# Set up logging
81-
log_file = log_dir / "aider.log"
82-
logging.basicConfig(
83-
filename=log_file,
84-
level=logging.INFO,
85-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
86-
)
67+
if test_cmd:
68+
auto_test = True
69+
else:
70+
auto_test = False
71+
if lint_cmd:
72+
auto_lint = True
73+
else:
74+
auto_lint = False
75+
log_dir = log_dir.resolve()
76+
log_dir.mkdir(parents=True, exist_ok=True)
77+
input_history_file = log_dir / ".aider.input.history"
78+
chat_history_file = log_dir / ".aider.chat.history.md"
8779

88-
# Redirect print statements to the log file
89-
sys.stdout = open(log_file, "a")
90-
sys.stderr = open(log_file, "a")
80+
# Set up logging
81+
log_file = log_dir / "aider.log"
82+
logging.basicConfig(
83+
filename=log_file,
84+
level=logging.INFO,
85+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
86+
)
9187

92-
# Configure httpx and backoff logging
93-
handle_logging("httpx", log_file)
94-
handle_logging("backoff", log_file)
88+
# Redirect print statements to the log file
89+
sys.stdout = open(log_file, "a")
90+
sys.stderr = open(log_file, "a")
9591

96-
io = InputOutput(
97-
yes=True,
98-
input_history_file=input_history_file,
99-
chat_history_file=chat_history_file,
100-
)
101-
coder = Coder.create(
102-
main_model=self.model,
103-
fnames=fnames,
104-
auto_lint=auto_lint,
105-
auto_test=auto_test,
106-
lint_cmds={"python": lint_cmd},
107-
test_cmd=test_cmd,
108-
io=io,
92+
# Configure httpx and backoff logging
93+
handle_logging("httpx", log_file)
94+
handle_logging("backoff", log_file)
95+
96+
io = InputOutput(
97+
yes=True,
98+
input_history_file=input_history_file,
99+
chat_history_file=chat_history_file,
100+
)
101+
coder = Coder.create(
102+
main_model=self.model,
103+
fnames=fnames,
104+
auto_lint=auto_lint,
105+
auto_test=auto_test,
106+
lint_cmds={"python": lint_cmd},
107+
test_cmd=test_cmd,
108+
io=io,
109+
)
110+
coder.max_reflection = self.max_iteration
111+
coder.stream = True
112+
113+
# Run the agent
114+
# coder.run(message)
115+
116+
#### TMP
117+
import time
118+
import random
119+
120+
time.sleep(random.random() * 5)
121+
n = random.random() / 10
122+
with open(log_file, "a") as f:
123+
f.write(
124+
f"> Tokens: 33k sent, 1.3k received. Cost: $0.12 message, ${n} session. \n"
109125
)
110-
coder.max_reflection = self.max_iteration
111-
coder.stream = False
112-
113-
# Run the agent
114-
raise Exception("test")
115-
coder.run(message)
116-
117-
except Exception as e:
118-
# If the exception is related to API errors, raise an APIError
119-
if hasattr(e, 'status_code') and e.status_code in self.retry_if_api_error_codes:
120-
raise APIError(e.status_code, str(e))
121-
# For other exceptions, re-raise them
122-
raise
123-
finally:
124-
# Close redirected stdout and stderr
125-
sys.stdout.close()
126-
sys.stderr.close()
127-
# Restore original stdout and stderr
128-
sys.stdout = sys.__stdout__
129-
sys.stderr = sys.__stderr__
126+
#### TMP
127+
128+
# Close redirected stdout and stderr
129+
sys.stdout.close()
130+
sys.stderr.close()
131+
# Restore original stdout and stderr
132+
sys.stdout = sys.__stdout__
133+
sys.stderr = sys.__stderr__
134+
135+
return AgentReturn(log_file)

agent/class_types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
class AgentConfig:
66
agent_name: str
77
model_name: str
8-
backend: str
98
use_user_prompt: bool
109
user_prompt: str
1110
use_repo_info: bool

0 commit comments

Comments
 (0)