-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
install.py
executable file
·272 lines (224 loc) · 9.26 KB
/
install.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
#!/usr/bin/env python
"""Configuration script for setting up tab completion for nerfstudio in bash and zsh."""
import concurrent.futures
import itertools
import os
import pathlib
import random
import shutil
import stat
import subprocess
import sys
from typing import List, Literal, Union
import tyro
from rich.console import Console
from rich.prompt import Confirm
from typing_extensions import assert_never
# Operations the installer supports; `main` takes one as its positional argument.
ConfigureMode = Literal["install", "uninstall"]
# Shells whose rc files this script knows how to edit.
ShellType = Literal["zsh", "bash"]

# Shared console used for all log, status and prompt output in this script.
CONSOLE = Console(width=120)

# Entry-point commands to generate completion scripts for. Each name is executed
# directly (not through `sys.executable`), so it must be resolvable on $PATH.
ENTRYPOINTS: List[str] = [
    "ns-install-cli",
    "ns-process-data",
    "ns-download-data",
    "ns-train",
    "ns-eval",
    "ns-render",
    "ns-dev-test",
]
def _check_tyro_cli(script_path: pathlib.Path) -> bool:
"""Check if a path points to a script containing a tyro.cli() call. Also checks
for any permissions/shebang issues.
Args:
script_path: Path to prospective CLI.
Returns:
True if a completion is can be generated.
"""
assert script_path.suffix == ".py"
script_src = script_path.read_text()
if '\nif __name__ == "__main__":\n' in script_src:
# Check script for execute permissions. For consistency, note that we apply this
# and the shebang check even if tyro isn't used.
execute_flags = stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH
if not script_path.stat().st_mode & execute_flags and Confirm.ask(
f"[yellow]:warning: {script_path} is not marked as executable. Fix?[/yellow]",
default=True,
):
script_path.chmod(script_path.stat().st_mode | execute_flags)
# Check that script has a shebang.
if not script_src.startswith("#!/") and Confirm.ask(
f"[yellow]:warning: {script_path} is missing a shebang. Fix?[/yellow]",
default=True,
):
script_path.write_text("#!/usr/bin/env python\n" + script_src)
# Return True only if compatible with tyro.
return "import tyro" in script_src and "tyro.cli" in script_src
return False
def _generate_completion(
    path_or_entrypoint: Union[pathlib.Path, str], shell: ShellType, completions_dir: pathlib.Path
) -> pathlib.Path:
    """Given a path to a tyro CLI (or an entry-point name), write a completion script to a target directory.

    Args:
        path_or_entrypoint: Either a path to a Python CLI script, which is run with the current
            interpreter, or the name of an entry-point command, which is run directly and so must
            be on $PATH.
        shell: Shell to generate completion script for.
        completions_dir: Directory to write completion script to; the script is placed under
            `completions_dir / shell`.

    Returns:
        Path of the completion script, whether it was newly written, updated, or left unchanged.

    Raises:
        subprocess.CalledProcessError: If the CLI exits non-zero while printing its completion.
            Its stderr is logged before the error is re-raised.
    """
    if isinstance(path_or_entrypoint, pathlib.Path):
        # Scripts: e.g. `train.py` -> `_train_py`.
        target_name = "_" + path_or_entrypoint.name.replace(".", "_")
        args = [sys.executable, str(path_or_entrypoint), "--tyro-print-completion", shell]
    elif isinstance(path_or_entrypoint, str):
        # Entry points: e.g. `ns-train` -> `_ns-train`.
        target_name = "_" + path_or_entrypoint
        args = [path_or_entrypoint, "--tyro-print-completion", shell]
    else:
        assert_never(path_or_entrypoint)
    target_path = completions_dir / shell / target_name

    # Generate the new completion. stderr is captured, so surface it on failure; otherwise the
    # CalledProcessError alone gives no hint of what went wrong.
    try:
        new = subprocess.run(
            args=args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf8",
            check=True,
        ).stdout
    except subprocess.CalledProcessError as e:
        CONSOLE.log(f"[bold red]Failed to generate {shell} completion for {path_or_entrypoint}.")
        # Arbitrary program output may contain `[...]`; don't let rich parse it as markup.
        CONSOLE.log(e.stderr, markup=False)
        raise

    # Write the new completion, matching the UTF-8 encoding used to decode it above.
    target_path.parent.mkdir(parents=True, exist_ok=True)
    if not target_path.exists():
        target_path.write_text(new, encoding="utf-8")
        CONSOLE.log(f":heavy_check_mark: Wrote new completion to {target_path}!")
    elif target_path.read_text(encoding="utf-8").strip() != new.strip():
        target_path.write_text(new, encoding="utf-8")
        CONSOLE.log(f":heavy_check_mark: Updated completion at {target_path}!")
    else:
        CONSOLE.log(f"[dim]:heavy_check_mark: Nothing to do for {target_path}[/dim].")
    return target_path
def _exclamation() -> str:
return random.choice(["Cool", "Nice", "Neat", "Great", "Exciting", "Excellent", "Ok"]) + "!"
def _update_rc(
    completions_dir: pathlib.Path,
    mode: ConfigureMode,
    shell: ShellType,
) -> None:
    """Try to add a `source /.../completions/setup.{shell}` line automatically to a user's zshrc or bashrc.

    The user is asked to confirm before `~/.{shell}rc` is modified. Uninstalling removes exactly
    the lines that installing appends.

    Args:
        completions_dir: Path to location of this script (the directory containing `setup.{shell}`).
        mode: Install or uninstall completions.
        shell: Shell to install completion scripts for; selects `~/.{shell}rc`.
    """
    # Block appended on install and searched for / removed on uninstall. It is kept without a
    # trailing newline so installs made by earlier versions of this script are still recognized.
    source_lines = "\n".join(
        [
            "",
            "# Source nerfstudio autocompletions.",
            f"source {completions_dir / 'setup'}.{shell}",
        ]
    )
    rc_path = pathlib.Path(os.environ["HOME"]) / f".{shell}rc"
    rc_text = rc_path.read_text()

    if mode == "install":
        if source_lines in rc_text:
            CONSOLE.log(f":call_me_hand: Completions are already installed in {rc_path}. {_exclamation()}")
            return
        if not Confirm.ask(f"[bold yellow]Install to {rc_path}?", default=True):
            CONSOLE.log(f"[bold red]Skipping install for {rc_path.name}.")
            return
        # Terminate the `source` line so anything appended to the rc file later (e.g. with
        # `echo ... >> ~/.zshrc`) starts on its own line instead of extending the command.
        rc_path.write_text(rc_text + source_lines + "\n")
        CONSOLE.log(
            f":person_gesturing_ok: Completions installed to {rc_path}. {_exclamation()} Open a new shell to try them"
            " out."
        )
    elif mode == "uninstall":
        if source_lines not in rc_text:
            CONSOLE.log(f":heavy_check_mark: No completions to uninstall from {rc_path.name}.")
            return
        if not Confirm.ask(f"[bold yellow]Uninstall from {rc_path}?", default=True):
            CONSOLE.log(f"[dim red]Skipping uninstall for {rc_path.name}.")
            return
        # Remove the newline-terminated form written by install first, then any legacy copy
        # that was appended without a trailing newline.
        rc_path.write_text(rc_text.replace(source_lines + "\n", "").replace(source_lines, ""))
        CONSOLE.log(f":broom: Completions uninstalled from {rc_path}.")
    else:
        assert_never(mode)
def main(
    mode: ConfigureMode = "install",
    /,
) -> None:
    """Main script.

    Finds the user's zsh/bash rc files. In install mode it (re)generates completion scripts for the
    entry points and deletes stale ones; in uninstall mode it deletes the generated completion
    directories. Finally it adds or removes the `source` line in each rc file that was found.

    Args:
        mode: Choose between installing or uninstalling completions.
    """
    if "HOME" not in os.environ:
        CONSOLE.log("[bold red]$HOME is not set. Exiting.")
        return

    # Try to locate the user's bashrc or zshrc; only shells with an existing rc file are configured.
    home = pathlib.Path(os.environ["HOME"])
    shells_supported: List[ShellType] = ["zsh", "bash"]
    shells_found: List[ShellType] = []
    for shell in shells_supported:
        rc_path = home / f".{shell}rc"
        if not rc_path.exists():
            CONSOLE.log(f":person_shrugging: {rc_path.name} not found, skipping.")
        else:
            CONSOLE.log(f":mag: Found {rc_path.name}!")
            shells_found.append(shell)

    # This script lives in `scripts/completions/`; candidate CLI scripts live under `scripts/`.
    completions_dir = pathlib.Path(__file__).absolute().parent
    scripts_dir = completions_dir.parent
    assert completions_dir.name == "completions"
    assert scripts_dir.name == "scripts"

    if mode == "uninstall":
        # Uninstall mode: remove the generated completion directory for every supported shell.
        for shell in shells_supported:
            target_dir = completions_dir / shell
            if target_dir.exists():
                assert target_dir.is_dir()
                shutil.rmtree(target_dir, ignore_errors=True)
                CONSOLE.log(f":broom: Deleted existing completion directory: {target_dir}.")
            else:
                CONSOLE.log(f":heavy_check_mark: No existing completions at: {target_dir}.")
    elif mode == "install":
        # Install mode: generate a completion for each entry point (and optionally each tyro script).
        # Set to True to install completions for scripts as well.
        include_scripts = False
        # Find tyro CLIs.
        script_paths = list(filter(_check_tyro_cli, scripts_dir.glob("**/*.py"))) if include_scripts else []
        script_names = tuple(p.name for p in script_paths)
        # Completion file names are derived from script file names, so those must be unique.
        assert len(set(script_names)) == len(script_names)

        # Snapshot existing completion files so stale ones can be deleted afterwards.
        existing_completions = set()
        for shell in shells_supported:
            target_dir = completions_dir / shell
            if target_dir.exists():
                existing_completions |= set(target_dir.glob("*"))

        # Run generation jobs in parallel (each one spawns a subprocess). Using the executor as a
        # context manager shuts its worker threads down once every job has finished.
        jobs = list(itertools.product(script_paths + ENTRYPOINTS, shells_found))
        with CONSOLE.status(
            "[bold]:writing_hand: Generating completions...", spinner="bouncingBall"
        ), concurrent.futures.ThreadPoolExecutor() as executor:
            completion_paths = list(
                executor.map(lambda job: _generate_completion(job[0], job[1], completions_dir), jobs)
            )

        # Delete obsolete completion files.
        expected_paths = {p.absolute() for p in completion_paths}
        for unexpected_path in {p.absolute() for p in existing_completions} - expected_paths:
            if unexpected_path.is_dir():
                shutil.rmtree(unexpected_path)
            elif unexpected_path.exists():
                unexpected_path.unlink()
            CONSOLE.log(f":broom: Deleted {unexpected_path}.")
    else:
        assert_never(mode)

    # Install or uninstall from bashrc/zshrc.
    for shell in shells_found:
        _update_rc(completions_dir, mode, shell)

    CONSOLE.print("[bold]All done![/bold]")
def entrypoint():
    """Entry point for pyproject console scripts.

    Sets tyro's accent color, then parses the command line into a call to `main`, using the
    module docstring as the CLI description.
    """
    accent_color = "bright_yellow"
    tyro.extras.set_accent_color(accent_color)
    tyro.cli(main, description=__doc__)


# Allow running this file directly as a script as well as via the installed entry point.
if __name__ == "__main__":
    entrypoint()