-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpatcher
More file actions
executable file
·178 lines (154 loc) · 5.66 KB
/
Copy pathpatcher
File metadata and controls
executable file
·178 lines (154 loc) · 5.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#!/usr/bin/env python3
"""Apply patches to a Python wheel file."""
from __future__ import annotations
import argparse
import os
import shlex
import shutil
import subprocess
import tempfile
from pathlib import Path
from wheel.wheelfile import WheelFile # type: ignore
def no_backup_option(patch_prog: Path) -> list[str]:
    """Pick the command-line flags for the detected flavour of ``patch``.

    Runs ``<patch_prog> --version`` and inspects the raw bytes it prints:

    * output beginning with ``GNU patch `` selects ``--no-backup-if-mismatch``;
    * output ending with ``Apple`` plus a newline selects
      ``--version-control none``;
    * anything else prints a notice and selects no extra flags.

    Args:
        patch_prog: Path of the ``patch`` executable to probe.

    Returns:
        The extra arguments to append to the ``patch`` command line
        (possibly empty).

    Raises:
        subprocess.CalledProcessError: If the version probe exits non-zero.
    """
    banner = subprocess.check_output([str(patch_prog), "--version"])

    if banner.startswith(b"GNU patch "):
        return ["--no-backup-if-mismatch"]
    if banner.endswith(b"Apple\n"):
        return ["--version-control", "none"]

    print("Unknown 'patch' tool. Cannot determine no-backup option.")
    return []
def update_metadata_version(
    wheel_dir: str, dist_info_path: str, version: str, suffix: str
) -> None:
    """Update dist-info/METADATA with the new version and suffix.

    The file is read, the ``Version`` header is replaced by
    ``Version: <version><suffix>``, and the file is written back in place.
    Only the header section (everything before the first blank line) is
    examined, so a description line that happens to start with ``Version:``
    is copied through unchanged.

    Args:
        wheel_dir: Root directory of the unpacked wheel.
        dist_info_path: Name of the ``.dist-info`` directory inside
            ``wheel_dir``.
        version: Version the existing header is expected to carry.
        suffix: Text appended to ``version`` to form the new header value.

    Raises:
        ValueError: If the existing header is not exactly
            ``Version: <version>``. The file is left untouched in that case.
    """
    metadata_path = Path(wheel_dir).joinpath(dist_info_path, "METADATA")
    expected_header = f"Version: {version}\n"
    new_header = f"Version: {version}{suffix}\n"

    metadata = []
    in_headers = True
    # Read as UTF-8 explicitly rather than relying on the platform's locale
    # encoding, so non-ASCII author names or descriptions round-trip intact.
    with metadata_path.open("r", encoding="utf-8") as f:
        for line in f:
            if in_headers and not line.strip():
                # First blank line ends the headers; the rest is free text.
                in_headers = False
            if in_headers and line.startswith("Version:"):
                # Explicit raise instead of ``assert`` so the check still
                # runs when Python is started with -O.
                if line != expected_header:
                    raise ValueError(
                        f"Unexpected version header {line!r} in {metadata_path}; "
                        f"expected {expected_header!r}"
                    )
                metadata.append(new_header)
                print(new_header)
            else:
                metadata.append(line)

    with metadata_path.open("w", encoding="utf-8") as f:
        f.writelines(metadata)
def patch_wheel(
    src_wheel: Path,
    dest_dir: Path,
    patch_dir: Path,
    suffix: str,
    patch_prog: Path,
    strip_count: int,
) -> tuple[Path, str]:
    """Unpack a wheel, apply every patch in a directory, and repack it.

    The wheel is extracted into a temporary directory, each entry of
    ``patch_dir`` (in sorted name order) is applied with ``patch_prog``, the
    ``.dist-info`` directory and its METADATA version are renamed to carry
    ``suffix``, and the tree is written out as a new wheel in ``dest_dir``.
    The temporary directory is removed afterwards, whether or not patching
    succeeded.

    Args:
        src_wheel: Wheel file to patch.
        dest_dir: Directory that receives the patched wheel.
        patch_dir: Directory whose entries are all passed to ``patch``.
        suffix: Local-version suffix; a leading ``+`` is added when missing.
        patch_prog: Path of the ``patch`` executable.
        strip_count: Value passed to ``patch --strip``.

    Returns:
        A ``(new_wheel_path, requirement)`` pair, where ``requirement`` has
        the form ``<name>==<version><suffix>``.

    Raises:
        subprocess.CalledProcessError: If ``patch`` exits non-zero.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        with WheelFile(src_wheel) as w:
            old_dist_info_path = w.dist_info_path
            grp = w.parsed_filename.group
            namever = grp("namever")
            version = grp("ver")
            # Optional build tag; None when the filename has none.
            build = grp("build")

            if not suffix.startswith("+"):
                suffix = "+" + suffix
            if "+" in version:
                # e.g., version='2.1.2+cu118'. Two '+'s are not allowed in the wheel filename.
                suffix = suffix.replace("+", "_")

            print(f"Extracting {src_wheel} to {temp_dir}")
            w.extractall(temp_dir)

        new_dist_info_path = "{}{}.dist-info".format(namever, suffix)
        requirement = f"{grp('name')}=={version}{suffix}"
        # Keep the build tag (if any) in its slot between version and python
        # tag so the new filename still agrees with the wheel's WHEEL file.
        build_part = f"-{build}" if build else ""
        new_wheel_filename = "{namever}{suffix}{build}-{pyver}-{abi}-{plat}.whl".format(
            namever=namever,
            suffix=suffix,
            build=build_part,
            pyver=grp("pyver"),
            abi=grp("abi"),
            plat=grp("plat"),
        )
        new_wheel_path = dest_dir.joinpath(new_wheel_filename)

        discard_backups_option = no_backup_option(patch_prog)
        # Absolute patch paths are required because patch runs with
        # cwd=temp_dir, not the caller's working directory.
        absolute_patch_dir = patch_dir.absolute()
        for patch in sorted(os.listdir(patch_dir)):
            patch_path = absolute_patch_dir.joinpath(patch)
            patch_cmd = [
                str(patch_prog),
                f"--strip={strip_count}",
                f"--input={patch_path}",
            ] + discard_backups_option
            print(f"\t{shlex.join(patch_cmd)}")
            subprocess.check_call(patch_cmd, cwd=temp_dir)

        os.rename(
            os.path.join(temp_dir, old_dist_info_path),
            os.path.join(temp_dir, new_dist_info_path),
        )
        update_metadata_version(temp_dir, new_dist_info_path, version, suffix)

        with WheelFile(new_wheel_path, "w") as w:
            print(f"Repacking wheel as {new_wheel_path}")
            w.write_files(temp_dir)

        return new_wheel_path, requirement
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the wheel patcher.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The parsed namespace. ``patch_prog`` is always a ``Path`` and
        ``strip_count`` is always an ``int``.

    Raises:
        ValueError: If no ``patch`` executable was given and none was found
            on ``PATH``.
        SystemExit: Raised by argparse on invalid or missing arguments.
    """
    patch_prog = shutil.which("patch")
    parser = argparse.ArgumentParser(description="Patch a Python wheel file")
    parser.set_defaults(
        patch_prog=Path(patch_prog) if patch_prog else None,
        dest_dir=".",
        strip_count=1,
    )
    parser.add_argument("--wheel", "-w", required=True, help="The wheel file to patch")
    parser.add_argument(
        "--patch-dir",
        "-p",
        required=True,
        help="Directory containing patches to apply to the wheel",
    )
    parser.add_argument(
        "--dest-dir",
        "-d",
        help="The directory to save the patched wheel. Default: %(default)r",
    )
    parser.add_argument(
        "--suffix",
        "-s",
        required=True,
        help="The suffix to append to the version in the wheel filename. "
        "A leading '+' will be added if not present. "
        "For example, '--wheel torch-2.1.0-cp38-cp38-manylinux1_x86_64.whl --suffix stripe.4' "
        "creates 'torch-2.1.0+stripe.4-cp38-cp38-manylinux1_x86_64.whl'. "
        "This patched wheel can be installed with the requirement 'torch==2.1.0+stripe.4'.",
    )
    parser.add_argument(
        "--patch-prog",
        "-P",
        help="Path to 'patch' executable. Default: '%(default)s'.",
    )
    parser.add_argument(
        "--strip-count",
        # Convert and validate here so a non-numeric value is rejected by
        # argparse instead of reaching the patch command line as a string.
        type=int,
        help="Number of leading path components to strip from filenames in the patch. "
        "Default: %(default)d.",
    )
    args = parser.parse_args(argv)

    if not args.patch_prog:
        raise ValueError("Can't find 'patch'")
    # A user-supplied value arrives as str; normalise it to Path.
    args.patch_prog = Path(args.patch_prog)
    return args
def main() -> int:
    """Run the patcher from the command line.

    Parses the arguments, patches the wheel, and prints where the result was
    written together with the ``pip install`` requirement for it.

    Returns:
        ``0``, for use as the process exit status.
    """
    options = parse_args()
    wheel_options = {
        "src_wheel": Path(options.wheel),
        "dest_dir": Path(options.dest_dir),
        "patch_dir": Path(options.patch_dir),
        "suffix": options.suffix,
        "patch_prog": Path(options.patch_prog),
        "strip_count": options.strip_count,
    }
    new_wheel_path, requirement = patch_wheel(**wheel_options)

    print(f"Patched wheel written to {new_wheel_path}\n")
    install_hint = (
        f"Use 'pip install {requirement}' to install the patched wheel\n"
        "after it has been uploaded to a Python package index."
    )
    print(install_hint)
    return 0
if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the bare ``exit`` helper:
    # that name is injected by the ``site`` module for interactive sessions
    # and is not available when site is disabled (e.g. ``python -S``).
    raise SystemExit(main())