-
-
Notifications
You must be signed in to change notification settings - Fork 44
/
config.py
350 lines (269 loc) · 10.9 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import json
import shutil
from functools import wraps
from inspect import Parameter, signature
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Optional
import rtoml
from pydantic import (
BaseModel,
Field,
FilePath,
PositiveInt,
PrivateAttr,
field_validator,
model_validator,
)
from pydantic_extra_types.color import Color
from .logger import logger
# Type of the callbacks a Signal notifies: any callable, accepting any
# arguments; its return value is discarded by Signal.emit.
Receiver = Callable[..., Any]
class Signal(BaseModel):  # type: ignore[misc]
    """Minimal observer: a list of receivers called, in order, by :meth:`emit`."""

    # Registered callbacks; a pydantic private attribute, so it is not a
    # model field.
    __receivers: list[Receiver] = PrivateAttr(default_factory=list)

    def connect(self, receiver: Receiver) -> None:
        """Register ``receiver``; connecting the same callable twice makes it fire twice."""
        self.__receivers.append(receiver)

    def disconnect(self, receiver: Receiver) -> None:
        """Unregister the first occurrence of ``receiver``.

        :raises ValueError: if ``receiver`` is not connected.
        """
        self.__receivers.remove(receiver)

    def emit(self, *args: Any) -> None:
        """Call every connected receiver with ``*args``, in connection order.

        A snapshot of the receivers is iterated so that a receiver which
        connects or disconnects callbacks while being notified cannot cause
        another receiver to be skipped, or a new one to fire mid-emission.
        """
        for receiver in tuple(self.__receivers):
            receiver(*args)
def key_id(name: str) -> PositiveInt:
    """Look up the Qt key code called ``Key_<name>`` (e.g. ``"Q"`` -> ``Qt.Key_Q``).

    Qt is imported inside the function so that importing this module does
    not import Qt too early.
    """
    from qtpy.QtCore import Qt

    attribute_name = f"Key_{name}"
    code = getattr(Qt, attribute_name)
    return code
class Key(BaseModel):  # type: ignore[misc]
    """Represents a list of key codes, with optionally a name.

    Each key owns a :class:`Signal`; receivers attached with :meth:`connect`
    are notified through that signal.
    """

    # Key codes; the validator below rejects an empty list and drops
    # duplicates. ``Field(unique=True)`` was not enforced by pydantic v2 and
    # extra ``Field`` kwargs are deprecated, so the marker now lives in
    # ``json_schema_extra`` (same JSON schema output as before).
    ids: list[PositiveInt] = Field(json_schema_extra={"unique": True})
    # Optional label, used in the debug log emitted by ``match``.
    name: Optional[str] = None
    __signal: Signal = PrivateAttr(default_factory=Signal)

    @field_validator("ids")
    @classmethod
    def ids_is_non_empty_set(cls, ids: list[int]) -> list[int]:
        """Reject an empty list and de-duplicate, keeping first-seen order.

        :raises ValueError: if ``ids`` is empty.
        """
        if not ids:
            raise ValueError("Key's ids must be a non-empty set")
        return list(dict.fromkeys(ids))

    def set_ids(self, *ids: int) -> None:
        """Replace the key codes with ``ids`` (de-duplicated, order kept).

        Attribute assignment is not validated by pydantic here (the model does
        not enable ``validate_assignment``), so the non-empty invariant that the
        field validator enforces is checked explicitly.

        :raises ValueError: if no id is given.
        """
        if not ids:
            raise ValueError("Key's ids must be a non-empty set")
        self.ids = list(dict.fromkeys(ids))

    def match(self, key_id: int) -> bool:
        """Return whether ``key_id`` is one of this key's codes (logged at debug level)."""
        m = key_id in self.ids
        if m:
            logger.debug(f"Pressed key: {self.name}")
        return m

    @property
    def signal(self) -> Signal:
        """The signal associated with this key."""
        return self.__signal

    def connect(self, function: Receiver) -> None:
        """Attach ``function`` as a receiver of this key's signal."""
        self.__signal.connect(function)
class Keys(BaseModel):  # type: ignore[misc]
    """Key bindings for every presentation action, each with Qt default codes."""

    QUIT: Key = Field(default_factory=lambda: Key(ids=[key_id("Q")], name="QUIT"))
    PLAY_PAUSE: Key = Field(
        default_factory=lambda: Key(ids=[key_id("Space")], name="PLAY / PAUSE")
    )
    NEXT: Key = Field(default_factory=lambda: Key(ids=[key_id("Right")], name="NEXT"))
    PREVIOUS: Key = Field(
        default_factory=lambda: Key(ids=[key_id("Left")], name="PREVIOUS")
    )
    REVERSE: Key = Field(default_factory=lambda: Key(ids=[key_id("V")], name="REVERSE"))
    REPLAY: Key = Field(default_factory=lambda: Key(ids=[key_id("R")], name="REPLAY"))
    FULL_SCREEN: Key = Field(
        default_factory=lambda: Key(ids=[key_id("F")], name="TOGGLE FULL SCREEN")
    )
    HIDE_MOUSE: Key = Field(
        default_factory=lambda: Key(ids=[key_id("H")], name="HIDE / SHOW MOUSE")
    )

    @model_validator(mode="before")
    @classmethod
    def ids_are_unique_across_keys(cls, values: Any) -> Any:
        """Reject raw input in which two keys share a key code.

        This runs *before* field validation, so ``values`` is raw input: usually
        a dict whose entries are plain dicts (e.g. parsed TOML) or already-built
        :class:`Key` instances, but possibly a ``Keys`` instance or anything else.
        Input that cannot be inspected is passed through unchanged, so that
        field validation reports a proper ``ValidationError`` instead of a raw
        ``TypeError``/``KeyError`` escaping from here. Only keys present in the
        input are compared; defaults filled in later are not.

        :raises ValueError: if two provided keys share a key code.
        """
        if not isinstance(values, dict):
            return values

        seen: set[int] = set()
        for key in values.values():
            if isinstance(key, Key):
                raw_ids = key.ids
            elif isinstance(key, dict):
                raw_ids = key.get("ids")
            else:
                continue
            try:
                key_ids = set(raw_ids)
            except TypeError:
                # Missing or malformed ``ids``: leave it to field validation.
                continue
            if not seen.isdisjoint(key_ids):
                raise ValueError(
                    "Two or more keys share a common key code: please make sure each key has distinct key codes"
                )
            seen.update(key_ids)
        return values

    def merge_with(self, other: "Keys") -> "Keys":
        """Add ``other``'s codes to each key, in place, and return ``self``.

        Codes are de-duplicated in first-seen order. A key's name is replaced by
        ``other``'s name when that one is truthy.
        """
        for key_name, key in self:
            other_key = getattr(other, key_name)
            key.ids = list(dict.fromkeys([*key.ids, *other_key.ids]))
            key.name = other_key.name or key.name
        return self

    def dispatch_key_function(self) -> Callable[[PositiveInt], None]:
        """Return a function that emits the signal of the key bound to a code.

        The code -> signal table is built once, when this method is called;
        unknown codes are ignored by the returned function.
        """
        _dispatch: dict[int, Signal] = {}
        for _, key in self:
            for _id in key.ids:
                _dispatch[_id] = key.signal

        def dispatch(key: PositiveInt) -> None:
            if signal := _dispatch.get(key, None):
                signal.emit()

        return dispatch
class Config(BaseModel):  # type: ignore[misc]
    """General Manim Slides config."""

    # Key bindings; each config gets its own default ``Keys`` instance.
    keys: Keys = Field(default_factory=Keys)

    @classmethod
    def from_file(cls, path: Path) -> "Config":
        """Load and validate a configuration stored (with rtoml) at ``path``."""
        raw = rtoml.load(path)
        return cls.model_validate(raw)  # type: ignore

    def to_file(self, path: Path) -> None:
        """Serialize this configuration to ``path`` with rtoml, pretty-printed."""
        payload = self.model_dump()
        rtoml.dump(payload, path, pretty=True)

    def merge_with(self, other: "Config") -> "Config":
        """Merge ``other``'s key bindings into this config and return ``self``."""
        merged = self.keys.merge_with(other.keys)
        self.keys = merged
        return self
class BaseSlideConfig(BaseModel):  # type: ignore
    """Base class for slide config."""

    loop: bool = False
    auto_next: bool = False
    playback_rate: float = 1.0
    reversed_playback_rate: float = 1.0
    notes: str = ""
    # When true, ``notes`` is passed through ``textwrap.dedent`` after validation.
    dedent_notes: bool = True

    @classmethod
    def wrapper(cls, arg_name: str) -> Callable[..., Any]:
        """
        Wrap a function to transform keyword argument into an instance of this class.

        The function signature is updated to reflect the new keyword-only arguments.

        The wrapped function must follow two criteria:

        - its last parameter must be ``**kwargs`` (or equivalent);
        - and its second last parameter must be ``<arg_name>``.

        At call time, keyword arguments naming a field of this class are used to
        build one instance passed as ``<arg_name>``; all other keyword arguments
        (and all positional ones) are forwarded unchanged.
        """

        def _wrapper_(fun: Callable[..., Any]) -> Callable[..., Any]:
            # ``model_fields`` replaces the pydantic-v1 ``__fields__`` alias,
            # which is deprecated in pydantic v2.
            fields = cls.model_fields

            @wraps(fun)
            def __wrapper__(*args: Any, **kwargs: Any) -> Any:  # noqa: N807
                fun_kwargs = {
                    key: value for key, value in kwargs.items() if key not in fields
                }
                cls_kwargs = {
                    key: value for key, value in kwargs.items() if key in fields
                }
                fun_kwargs[arg_name] = cls(**cls_kwargs)
                return fun(*args, **fun_kwargs)

            sig = signature(fun)
            parameters = list(sig.parameters.values())
            # Replace ``<arg_name>`` (second to last) with one keyword-only
            # parameter per field, keeping ``**kwargs`` last.
            parameters[-2:-1] = [
                Parameter(
                    field_name,
                    Parameter.KEYWORD_ONLY,
                    # Required fields have no default; exposing pydantic's
                    # ``PydanticUndefined`` sentinel in the signature is misleading.
                    default=(
                        Parameter.empty
                        if field_info.is_required()
                        else field_info.default
                    ),
                    annotation=field_info.annotation,
                )
                for field_name, field_info in fields.items()
            ]
            sig = sig.replace(parameters=parameters)
            __wrapper__.__signature__ = sig  # type: ignore[attr-defined]
            return __wrapper__

        return _wrapper_

    @model_validator(mode="after")
    def apply_dedent_notes(self) -> "BaseSlideConfig":
        """Dedent ``notes`` in place when ``dedent_notes`` is true.

        Written as an instance method, the form pydantic v2 documents for
        ``mode="after"`` model validators.
        """
        if self.dedent_notes:
            self.notes = dedent(self.notes)
        return self
class PreSlideConfig(BaseSlideConfig):
    """Slide config to be used prior to rendering."""

    # Non-negative animation indices, with start strictly lower than end
    # (both enforced by the validators below); see ``slides_slice``.
    start_animation: int
    end_animation: int

    @classmethod
    def from_base_slide_config_and_animation_indices(
        cls,
        base_slide_config: BaseSlideConfig,
        start_animation: int,
        end_animation: int,
    ) -> "PreSlideConfig":
        """Build a pre-slide config from a base config plus animation indices."""
        return cls(
            start_animation=start_animation,
            end_animation=end_animation,
            # ``model_dump`` replaces the ``.dict()`` alias deprecated in pydantic v2.
            **base_slide_config.model_dump(),
        )

    @field_validator("start_animation", "end_animation")
    @classmethod
    def index_is_posint(cls, v: int) -> int:
        """Reject negative animation indices.

        :raises ValueError: if ``v`` is negative.
        """
        if v < 0:
            raise ValueError("Animation index (start or end) cannot be negative")
        return v

    @model_validator(mode="after")
    def start_animation_is_before_end(self) -> "PreSlideConfig":
        """Require ``start_animation < end_animation``.

        The ``0 == 0`` case gets a dedicated message, because it means no
        animation was played before the first pause.

        :raises ValueError: if the start index is not strictly lower than the end.
        """
        if self.start_animation >= self.end_animation:
            if self.start_animation == self.end_animation == 0:
                raise ValueError(
                    "You have to play at least one animation (e.g., `self.wait()`) "
                    "before pausing. If you want to start paused, use the appropriate "
                    "command-line option when presenting. "
                    "IMPORTANT: when using ManimGL, `self.wait()` is not considered "
                    "to be an animation, so prefer to directly use `self.play(...)`."
                )

            raise ValueError(
                "Start animation index must be strictly lower than end animation index"
            )
        return self

    @model_validator(mode="after")
    def loop_and_auto_next_disallowed(self) -> "PreSlideConfig":
        """Forbid combining ``loop`` and ``auto_next``.

        :raises ValueError: if both flags are true.
        """
        if self.loop and self.auto_next:
            raise ValueError(
                "You cannot have both `loop=True` and `auto_next=True`, "
                "because a looping slide has no ending. "
                "This may be supported in the future if "
                "https://github.com/jeertmans/manim-slides/pull/299 gets merged."
            )
        return self

    @property
    def slides_slice(self) -> slice:
        """Slice ``[start_animation, end_animation)`` over the animations."""
        return slice(self.start_animation, self.end_animation)
class SlideConfig(BaseSlideConfig):
    """Slide config to be used after rendering."""

    # Paths validated by pydantic's ``FilePath`` (must point to an existing file).
    file: FilePath
    rev_file: FilePath

    @classmethod
    def from_pre_slide_config_and_files(
        cls, pre_slide_config: PreSlideConfig, file: Path, rev_file: Path
    ) -> "SlideConfig":
        """Build a slide config from a pre-slide config and its rendered files.

        ``start_animation``/``end_animation`` are part of the dump but are not
        fields of this class; with no ``model_config`` override, pydantic
        ignores such extra keys.
        """
        # ``model_dump`` replaces the ``.dict()`` alias deprecated in pydantic v2.
        return cls(file=file, rev_file=rev_file, **pre_slide_config.model_dump())
class PresentationConfig(BaseModel):  # type: ignore[misc]
    """Presentation config: the slides plus presentation-wide settings."""

    slides: list[SlideConfig] = Field(min_length=1)
    resolution: tuple[PositiveInt, PositiveInt] = (1920, 1080)
    # Pydantic does not validate defaults, so the former bare ``"black"`` string
    # stayed a ``str`` on default-constructed configs and lacked ``Color``
    # methods. Serializes to the same ``"black"`` as before.
    background_color: Color = Color("black")

    @classmethod
    def from_file(cls, path: Path) -> "PresentationConfig":
        """Read a presentation configuration from a UTF-8 JSON file.

        Slide ``file``/``rev_file`` entries are joined onto ``path.parent.parent``
        before validation.
        """
        # Explicit UTF-8: the locale default (e.g. cp1252 on Windows) breaks
        # non-ASCII notes and makes files non-portable across platforms.
        with open(path, encoding="utf-8") as f:
            obj = json.load(f)

        slides = obj.setdefault("slides", [])
        parent = path.parent.parent  # Never fails, but parents[1] can fail

        for slide in slides:
            if file := slide.get("file", None):
                slide["file"] = parent / file

            if rev_file := slide.get("rev_file", None):
                slide["rev_file"] = parent / rev_file

        return cls.model_validate(obj)  # type: ignore

    def to_file(self, path: Path) -> None:
        """Dump the presentation configuration to ``path`` as UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=2))

    def copy_to(
        self,
        folder: Path,
        use_cached: bool = True,
        include_reversed: bool = True,
        prefix: str = "",
    ) -> "PresentationConfig":
        """Copy the slide files into ``folder`` and point the slides at the copies.

        Slide paths are updated in place and ``self`` is returned. Each copy is
        named ``<prefix><original name>``. With ``use_cached``, files already
        present at the destination are not copied again. With
        ``include_reversed`` false, reversed files are not copied, but the
        slides' ``rev_file`` paths are still updated.
        """
        for slide_config in self.slides:
            file = slide_config.file
            rev_file = slide_config.rev_file

            dest = folder / f"{prefix}{file.name}"
            rev_dest = folder / f"{prefix}{rev_file.name}"

            slide_config.file = dest
            slide_config.rev_file = rev_dest

            if not use_cached or not dest.exists():
                shutil.copy(file, dest)

            if include_reversed and (not use_cached or not rev_dest.exists()):
                shutil.copy(rev_file, rev_dest)

        return self