diff --git a/.github/workflows/ufmt.yml b/.github/workflows/ufmt.yml index 20cec70d..f4fa6be2 100644 --- a/.github/workflows/ufmt.yml +++ b/.github/workflows/ufmt.yml @@ -20,7 +20,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - pip install ufmt + pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1 - name: Analyzing the code with ufmt run: | ufmt check . diff --git a/README.md b/README.md index c18193f5..b094c4f0 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,9 @@ will change rapidly. pip install . # Optionally install editable pip install -e . + +# Optionally Install dev tooling +pip install -e ".[dev]" ``` # User API, subject to change diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 12bd9d76..54287a3f 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -111,7 +111,6 @@ def sync_float8_amax_and_scale_history( # Lazy import to avoid circular dependency if fp8_classes is None: - fp8_classes = Float8Linear for name, child in model.named_modules(): diff --git a/pyproject.toml b/pyproject.toml index f687defd..8d084b4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,19 +16,23 @@ classifiers = [ ] dependencies = [ - "torch >= 2.0.1", - "transformers==4.32.0", - "fire==0.5.0", - "fairscale==0.4.13", - "tqdm==4.66.1", - "pandas >= 2.0", + "torch >= 2.1", + "fairscale==0.4.13" ] [project.optional-dependencies] +test = [ + "transformers==4.32.0", + "pandas >= 2.0", + "tqdm==4.66.1", + "fire==0.5.0" +] dev = [ - "black", - "usort", - "libcst", + "black==23.3.0", + "usort==1.0.6", + "ufmt==2.1.0", + "libcst==1.0.1", + "pytest==7.4.0", "bumpver", "pip-tools" ]