Skip to content

Commit

Permalink
chore: apply ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenanht committed Jan 25, 2024
1 parent c878edf commit e5668b5
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 13 deletions.
19 changes: 14 additions & 5 deletions john_toolbox/train/transformers/from_scratch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_
prob = model.project(out[:, -1])
_, next_word = torch.max(prob, dim=1)
decoder_input = torch.cat(
[decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)],
[
decoder_input,
torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device),
],
dim=1,
)

Expand Down Expand Up @@ -97,7 +100,13 @@ def run_validation(
assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

model_out = greedy_decode(
model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device
model,
encoder_input,
encoder_mask,
tokenizer_src,
tokenizer_tgt,
max_len,
device,
)

source_text = batch["src_text"][0]
Expand All @@ -110,9 +119,9 @@ def run_validation(

# Print the source, target and model output
print_msg("-" * console_width)
print_msg(f"{f'SOURCE: ':>12}{source_text}")
print_msg(f"{f'TARGET: ':>12}{target_text}")
print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")
print_msg(f"{'SOURCE: ':>12}{source_text}")
print_msg(f"{'TARGET: ':>12}{target_text}")
print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

if count == num_examples:
print_msg("-" * console_width)
Expand Down
92 changes: 84 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,87 @@ indent = 4
color_output = true

[tool.black]
line-length = 80
target-version = ['py38']
include = '\.pyi?$'
extend-exclude = '''
# A regex preceded with ^/ will apply only to files and directories
# in the root of the project.
^/foo.py # exclude a file named foo.py in the root of the project (in addition to the defaults)
'''
# https://github.com/psf/black
target-version = ['py39']
line-length = 100
color = true

exclude = '''
/(
\.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
| env
| venv
)/
'''

[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

# Same as Black.
line-length = 100
indent-width = 4

# Assume Python 3.9
target-version = "py39"


[tool.ruff.lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
select = ["E4", "E7", "E9", "F", "I"]
ignore = ["E501", "I001"]

# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = []

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

0 comments on commit e5668b5

Please sign in to comment.