9
9
from pathlib import Path
10
10
11
11
from ..message import Message
12
- from ..util import ask_execute
12
+ from ..util import ask_execute , print_preview
13
13
from .base import ToolSpec , ToolUse
14
14
15
15
@@ -76,17 +76,18 @@ def diff_minimal(self, strip_context=False) -> str:
76
76
# TODO: show this when previewing the patch
77
77
# TODO: replace previous patches with the minimal version
78
78
79
- diff = difflib .unified_diff (
80
- self .original .splitlines (),
81
- self .updated .splitlines (),
82
- lineterm = "" ,
83
- fromfile = "original" ,
84
- tofile = "updated" ,
85
- )
86
- diff = list (diff )[3 :]
79
+ diff = list (
80
+ difflib .unified_diff (
81
+ self .original .splitlines (),
82
+ self .updated .splitlines (),
83
+ lineterm = "" ,
84
+ fromfile = "original" ,
85
+ tofile = "updated" ,
86
+ )
87
+ )[3 :]
87
88
if strip_context :
88
89
# find first and last lines with changes
89
- markers = [l [0 ] for l in diff ]
90
+ markers = [line [0 ] for line in diff ]
90
91
start = min (
91
92
markers .index ("+" ) if "+" in markers else len (markers ),
92
93
markers .index ("-" ) if "-" in markers else len (markers ),
@@ -95,6 +96,7 @@ def diff_minimal(self, strip_context=False) -> str:
95
96
markers [::- 1 ].index ("+" ) if "+" in markers else len (markers ),
96
97
markers [::- 1 ].index ("-" ) if "-" in markers else len (markers ),
97
98
)
99
+ len (diff ) - start - end
98
100
diff = diff [start : len (diff ) - end ]
99
101
return "\n " .join (diff )
100
102
@@ -170,7 +172,13 @@ def execute_patch(
170
172
path = Path (fn ).expanduser ()
171
173
if not path .exists ():
172
174
raise ValueError (f"file not found: { fn } " )
175
+
176
+ patches = Patch .from_codeblock (code )
177
+ patches_str = "\n \n " .join (p .diff_minimal () for p in patches )
178
+ print_preview (patches_str , lang = "diff" )
179
+
173
180
if ask :
181
+ # TODO: display minimal patches
174
182
confirm = ask_execute (f"Apply patch to { fn } ?" )
175
183
if not confirm :
176
184
print ("Patch not applied" )
@@ -186,13 +194,13 @@ def execute_patch(
186
194
f .write (patched_content )
187
195
188
196
# Compare token counts
189
- patch_tokens = len (code )
190
- full_file_tokens = len (patched_content )
197
+ patch_len = len (code )
198
+ full_file_len = len (patched_content )
191
199
192
200
warnings = []
193
- if full_file_tokens < patch_tokens :
201
+ if 1000 < full_file_len < patch_len :
194
202
warnings .append (
195
- "Note: The patch was larger than the file. In the future, try writing smaller patches or use the save tool instead."
203
+ "Note: The patch was big and larger than the file. In the future, try writing smaller patches or use the save tool instead."
196
204
)
197
205
warnings_str = ("\n " .join (warnings ) + "\n " ) if warnings else ""
198
206
0 commit comments