Commit 4ae7a62

Run black/isort to fix linting (#117)

1 parent 3dd5e9f commit 4ae7a62

File tree

6 files changed: +74 -12 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ From here there are two options: (1) load weights in our train script and (2) lo
 In your terminal:
 
 ```bash
-python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b hub --prompt "My test prompt"
+python -m bytelatent.hf load-transformers --entropy-repo facebook/blt-entropy --blt-repo facebook/blt-1b --prompt "My test prompt" hub
 ```
 
 In your own code:

bytelatent/config_parser.py

Lines changed: 6 additions & 2 deletions

@@ -1,5 +1,5 @@
 import copy
-from typing import Type, TypeVar, Any
+from typing import Any, Type, TypeVar
 
 import omegaconf
 from omegaconf import DictConfig, OmegaConf
@@ -69,15 +69,19 @@ def parse_args_with_default(
 
 T = TypeVar("T", bound=BaseModel)
 
+
 def get_pydantic_default_args(args_cls: Type[T]) -> dict[str, Any]:
     defaults = {}
     for field, info in args_cls.model_fields.items():
         if info.default != PydanticUndefined:
             defaults[field] = info.default
     return defaults
 
+
 def parse_args_to_pydantic_model(
-    args_cls: Type[T], cli_args: DictConfig | None = None, instantiate_default_cls: bool = True
+    args_cls: Type[T],
+    cli_args: DictConfig | None = None,
+    instantiate_default_cls: bool = True,
 ) -> T:
     if instantiate_default_cls:
         default_cfg = OmegaConf.create(args_cls().model_dump())
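The hunk above shows `get_pydantic_default_args` in full, so it can be tried in isolation. A minimal runnable sketch of what it returns, using a hypothetical `DemoArgs` model (the real arg classes live elsewhere in bytelatent):

```python
# Standalone sketch of get_pydantic_default_args from the hunk above.
# DemoArgs is invented for illustration; only fields with an explicit
# default make it into the result.
from typing import Any, Type, TypeVar

from pydantic import BaseModel
from pydantic_core import PydanticUndefined

T = TypeVar("T", bound=BaseModel)


def get_pydantic_default_args(args_cls: Type[T]) -> dict[str, Any]:
    defaults = {}
    for field, info in args_cls.model_fields.items():
        if info.default != PydanticUndefined:
            defaults[field] = info.default
    return defaults


class DemoArgs(BaseModel):
    name: str  # required: no default, so it is omitted from the result
    ngpu: int = 8
    dry_run: bool = False


print(get_pydantic_default_args(DemoArgs))  # {'ngpu': 8, 'dry_run': False}
```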

bytelatent/stool.py

Lines changed: 9 additions & 6 deletions

@@ -26,8 +26,10 @@ class StoolArgs(BaseModel):
     dirs_exists_ok: bool = (
         False  # Wether to copy new code and config and run regardless that dir exists
     )
-    override: bool = False  # Whether to delete dump dir and restart, requires confirmation
-    force_override: bool = False  # Does not require interaction
+    override: bool = (
+        False  # Whether to delete dump dir and restart, requires confirmation
+    )
+    force_override: bool = False  # Does not require interaction
     nodes: int = -1  # The number of nodes to run the job on.
     ngpu: int = 8  # The number of GPUs required per node.
     ncpu: int = 16  # The number of CPUs allocated per GPU.
@@ -43,7 +45,6 @@ class StoolArgs(BaseModel):
     dry_run: bool = False
 
 
-
 def copy_dir(input_dir: str, output_dir: str) -> None:
     print(f"Copying : {input_dir}\n" f"to : {output_dir} ...")
     assert os.path.isdir(input_dir), f"{input_dir} is not a directory"
@@ -130,7 +131,9 @@ def launch_job(args: StoolArgs):
     job_name = args.name or args.model_conf["name"]
     dump_dir = os.path.join(args.dump_dir, job_name) or args.model_conf["dump_dir"]
     print("Creating directories...")
-    os.makedirs(dump_dir, exist_ok=args.dirs_exists_ok or args.override or args.force_override)
+    os.makedirs(
+        dump_dir, exist_ok=args.dirs_exists_ok or args.override or args.force_override
+    )
     if args.override or args.force_override:
         if args.force_override:
             shutil.rmtree(dump_dir)
@@ -161,10 +164,10 @@ def launch_job(args: StoolArgs):
         else ""
     )
     env = jinja2.Environment(
-        loader=jinja2.PackageLoader('bytelatent'),
+        loader=jinja2.PackageLoader("bytelatent"),
         autoescape=jinja2.select_autoescape(),
     )
-    template = env.get_template('stool_template.sh.jinja')
+    template = env.get_template("stool_template.sh.jinja")
     sbatch_jinja = template.render(
         name=job_name,
         script=args.script,
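The last hunk only normalizes quotes, but the jinja2 pattern it touches is easy to try outside the repo. A runnable sketch of the same Environment/get_template/render flow, with a DictLoader and an invented one-line template standing in for the package's stool_template.sh.jinja (whose real contents are not shown in this diff):

```python
# Same jinja2 call pattern as launch_job above; PackageLoader("bytelatent")
# is swapped for a DictLoader so this runs standalone. The template body
# here is a placeholder, not the real sbatch template.
import jinja2

env = jinja2.Environment(
    loader=jinja2.DictLoader(
        {"stool_template.sh.jinja": "#SBATCH --job-name={{ name }}\n{{ script }}"}
    ),
    autoescape=jinja2.select_autoescape(),
)
template = env.get_template("stool_template.sh.jinja")
print(template.render(name="blt_job", script="python -m bytelatent.train"))
```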

demo.py

Lines changed: 5 additions & 3 deletions

@@ -11,8 +11,8 @@
 
 
 def main(prompt: str, model_name: str = "blt-1b"):
-    assert model_name in ['blt-1b', 'blt-7b']
-    model_name = model_name.replace('-', '_')
+    assert model_name in ["blt-1b", "blt-7b"]
+    model_name = model_name.replace("-", "_")
     distributed_args = DistributedArgs()
     distributed_args.configure_world()
     if not torch.distributed.is_initialized():
@@ -27,7 +27,9 @@ def main(prompt: str, model_name: str = "blt-1b"):
     patcher_args = train_cfg.data.patcher_args.model_copy(deep=True)
     patcher_args.realtime_patching = True
     print("Loading entropy model and patcher")
-    patcher_args.entropy_model_checkpoint_dir = os.path.join("hf-weights", "entropy_model")
+    patcher_args.entropy_model_checkpoint_dir = os.path.join(
+        "hf-weights", "entropy_model"
+    )
     patcher = patcher_args.build()
     prompts = [prompt]
     outputs = generate_nocache(

pyproject.toml

Lines changed: 2 additions & 0 deletions

@@ -46,7 +46,9 @@ pre_build = [
 ]
 compile_xformers = ['xformers']
 dev = [
+    "black==24.8.0",
     "ipython>=9.2.0",
+    "isort>=6.0.1",
     "pudb>=2025.1",
 ]
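With both formatters pinned as dev dependencies, the formatting pass in this commit can be reproduced from a dev install. A hedged sketch using module execution, which assumes nothing about the repo's tooling beyond the two pins added above:

```python
# Re-run the formatting this commit applied; assumes black and isort are
# installed from the dev extras above. Both tools support python -m.
import subprocess
import sys

for tool in ("black", "isort"):
    subprocess.run([sys.executable, "-m", tool, "."], check=True)
```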

uv.lock

Lines changed: 51 additions & 0 deletions

Generated file; diff not rendered.
