Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Integration test issues (#129)
Browse files Browse the repository at this point in the history
Summary:
Corrections:
- upgrade to fairscale 0.3.6 to remove assertion errors corrected since 0.3.5
- make sure that integration test errors are raised in the CI
- disable the failing integration tests due to issue facebookresearch/fairscale#643
- minor refactor of the regnet_fsdp to use the fsdp_wrapper directly

Pull Request resolved: fairinternal/ssl_scaling#129

Reviewed By: prigoyal

Differential Revision: D28120296

Pulled By: QuentinDuval

fbshipit-source-id: 0c23b639a9d7103aafa5f25b14c05042269df061
  • Loading branch information
QuentinDuval authored and facebook-github-bot committed Apr 30, 2021
1 parent 68ebbf9 commit a6ac886
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 15 deletions.
2 changes: 1 addition & 1 deletion dev/run_quick_tests.sh
Expand Up @@ -22,7 +22,7 @@ echo "========================================================================"

pushd "${SRC_DIR}/tests"
for test_file in "${TEST_LIST[@]}"; do
python -m unittest $test_file
python -m unittest $test_file || exit
done
popd

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
# Keep this sorted for easy search.
# Keep versions pinned for testing reproducibility.
cython==0.29.22
fairscale==0.3.5
fairscale==0.3.6
fvcore==0.1.3.post20210317
hydra-core==1.0.6
numpy==1.19.5
Expand Down
6 changes: 6 additions & 0 deletions tests/test_state_checkpointing.py
Expand Up @@ -159,6 +159,9 @@ def run_benchmarking(self, checkpoint_path: str, with_fsdp: bool):
results = run_integration_test(config)
return results.get_losses(), results.get_accuracies()

@unittest.skip(
"FAILING due to https://github.com/facebookresearch/fairscale/issues/643"
)
@gpu_test(gpu_count=2)
def test_benchmarking_from_a_consolidated_checkpoint(self):
with in_temporary_directory() as checkpoint_folder:
Expand All @@ -177,6 +180,9 @@ def test_benchmarking_from_a_consolidated_checkpoint(self):
self.assertEqual(ddp_losses, fsdp_losses)
self.assertEqual(ddp_accuracies, fsdp_accuracies)

@unittest.skip(
"FAILING due to https://github.com/facebookresearch/fairscale/issues/643"
)
@gpu_test(gpu_count=2)
def test_benchmarking_from_sharded_checkpoint(self):
with in_temporary_directory() as checkpoint_folder:
Expand Down
18 changes: 5 additions & 13 deletions vissl/models/trunks/regnet_fsdp.py
Expand Up @@ -37,14 +37,13 @@
from classy_vision.models.regnet import RegNetParams
from fairscale.nn import checkpoint_wrapper
from fairscale.nn.data_parallel import auto_wrap_bn
from fairscale.nn.wrap import enable_wrap, wrap
from vissl.config import AttrDict
from vissl.data.collators.collator_helper import MultiDimensionalTensor
from vissl.models.model_helpers import (
Flatten,
get_trunk_forward_outputs,
transform_model_input_data_type,
get_tunk_forward_interpolated_outputs,
transform_model_input_data_type,
)
from vissl.models.trunks import register_model_trunk
from vissl.utils.fsdp_utils import fsdp_wrapper
Expand Down Expand Up @@ -174,8 +173,7 @@ def create_block(
width_in, width_out, stride, params, bottleneck_multiplier, group_width
)
block = auto_wrap_bn(block, single_rank_pg=False)
with enable_wrap(wrapper_cls=fsdp_wrapper, **self.fsdp_config):
block = wrap(block)
block = fsdp_wrapper(module=block, **self.fsdp_config)
return block


Expand Down Expand Up @@ -284,20 +282,14 @@ def create_regnet_feature_blocks(factory: RegnetBlocksFactory, model_config):
params=params,
stage_index=i + 1,
)

if isinstance(factory, RegnetFSDPBlocksFactory):
if model_config.ACTIVATION_CHECKPOINTING.USE_ACTIVATION_CHECKPOINTING:
logging.info("Using activation checkpointing")
new_stage = checkpoint_wrapper(new_stage, offload_to_cpu=False)
new_stage = fsdp_wrapper(module=new_stage, **model_config.FSDP_CONFIG)

with enable_wrap(wrapper_cls=fsdp_wrapper, **model_config.FSDP_CONFIG):
new_stage = wrap(new_stage)

blocks.append(
(
f"block{i + 1}",
new_stage,
)
)
blocks.append((f"block{i + 1}", new_stage))
trunk_depth += blocks[-1][1].stage_depth
current_width = width_out

Expand Down

0 comments on commit a6ac886

Please sign in to comment.