diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
new file mode 100644
index 000000000..a8fab7333
--- /dev/null
+++ b/.github/workflows/pre-commit.yaml
@@ -0,0 +1,20 @@
+name: pre-commit
+
+on:
+  push:
+    branches: [master]
+  pull_request:
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+          architecture: x64
+      - name: Checkout Torchrec
+        uses: actions/checkout@v2
+      - name: Run pre-commit
+        uses: pre-commit/action@v2.0.3
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..c7f310e60
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,16 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.0.1
+    hooks:
+      - id: check-toml
+      - id: check-yaml
+        exclude: packaging/.*
+      - id: end-of-file-fixer
+
+  - repo: https://github.com/omnilib/ufmt
+    rev: v1.3.0
+    hooks:
+      - id: ufmt
+        additional_dependencies:
+          - black == 21.9b0
+          - usort == 0.6.4
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..9f90a09d2
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[tool.usort]
+
+first_party_detection = false
diff --git a/torchrec/distributed/planner/embedding_planner.py b/torchrec/distributed/planner/embedding_planner.py
index 30acb8e44..dfe648049 100644
--- a/torchrec/distributed/planner/embedding_planner.py
+++ b/torchrec/distributed/planner/embedding_planner.py
@@ -119,7 +119,10 @@ def plan(
             module=module,
             sharders=sharders,
         )
-        unplaced_param_infos: List[Tuple[ParamSortKey, ParameterInfo]] = [(param_sort_key(param_info, self._world_size), param_info) for param_info in param_infos ]
+        unplaced_param_infos: List[Tuple[ParamSortKey, ParameterInfo]] = [
+            (param_sort_key(param_info, self._world_size), param_info)
+            for param_info in param_infos
+        ]
         placed_param_infos: List[Tuple[ParamSortKey, ParameterInfo]] = []
 
         heapq.heapify(unplaced_param_infos)
@@ -172,13 +175,14 @@ def _log_stats(
         ]
         emb_dims = [param_info.param.shape[1]]
         if shard.sharding_type == ShardingType.ROW_WISE.value:
-            pooling_factor = [
-                pooling_factor[0] / self._world_size] * len(ranks)
+            pooling_factor = [pooling_factor[0] / self._world_size] * len(ranks)
             emb_dims = emb_dims * len(ranks)
         elif shard.sharding_type == ShardingType.TABLE_ROW_WISE.value:
             # pyre-ignore [16]
             host_id = shard.ranks[0] // self._local_size
-            ranks = list(range(host_id * self._local_size, (host_id + 1) * self._local_size))
+            ranks = list(
+                range(host_id * self._local_size, (host_id + 1) * self._local_size)
+            )
             pooling_factor = [pooling_factor[0] / self._local_size] * len(ranks)
             emb_dims = emb_dims * len(ranks)
         elif shard.sharding_type == ShardingType.COLUMN_WISE.value:
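
Note on the `embedding_planner.py` hunks: they are formatting-only (ufmt, i.e. usort import sorting plus black line wrapping); the per-rank load arithmetic they touch is unchanged. As a minimal standalone sketch of what that wrapped code in `_log_stats` computes — all concrete values here (`world_size`, `local_size`, the shard's first rank, the pooling factor) are hypothetical stand-ins, since the real planner derives them from the module and the sharding plan:

```python
world_size = 16           # total ranks across all hosts (hypothetical)
local_size = 8            # ranks per host (hypothetical)
pooling_factor = [100.0]  # avg lookups per sample for one table (hypothetical)

# ROW_WISE: the table's rows are spread over every rank, so each rank
# serves an equal slice of the lookups.
ranks = list(range(world_size))
rw_pooling = [pooling_factor[0] / world_size] * len(ranks)
assert rw_pooling == [6.25] * 16

# TABLE_ROW_WISE: the rows are spread only over the ranks of the host
# that owns the shard, identified from the shard's first rank.
first_shard_rank = 11
host_id = first_shard_rank // local_size  # -> host 1
ranks = list(range(host_id * local_size, (host_id + 1) * local_size))
assert ranks == [8, 9, 10, 11, 12, 13, 14, 15]
trw_pooling = [pooling_factor[0] / local_size] * len(ranks)
assert trw_pooling == [12.5] * 8
```

Locally, the same checks the new workflow runs in CI can be reproduced with `pre-commit install` (run on every commit) or `pre-commit run --all-files` (a one-off pass over the whole tree).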