Move f/utils => f/internal; move testing libs to fair_dev/testing (#1004)
crutcher committed Jun 12, 2022
1 parent 3b72794 commit 2350968
Showing 83 changed files with 115 additions and 115 deletions.
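
In short, the commit mechanically renames import paths: fairscale.utils.* becomes fairscale.internal.*, and the shared test helpers move from fairscale.utils.testing to the new fair_dev.testing.testing module. A minimal before/after sketch in Python (the specific symbols are just examples picked from the hunks below):

# Before this commit
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, skip_if_no_cuda

# After this commit
from fairscale.internal import torch_version
from fair_dev.testing.testing import dist_init, skip_if_no_cuda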
2 changes: 1 addition & 1 deletion benchmarks/experimental/experimental_async_approaches.py
@@ -21,12 +21,12 @@
 import torchtext
 from torchtext.data.utils import get_tokenizer

+from fair_dev.testing.testing import dist_init, get_worker_map
 from fairscale.experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
-from fairscale.utils.testing import dist_init, get_worker_map

 try:
     from fairscale.optim import Adam  # type: ignore
2 changes: 1 addition & 1 deletion benchmarks/pipe.py
@@ -16,9 +16,9 @@
 import utils

 from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
+from fair_dev.testing.testing import dist_init
 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
-from fairscale.utils.testing import dist_init

 MPI_PORT = 29500
 RPC_PORT = 29501
File renamed without changes.
2 changes: 1 addition & 1 deletion fairscale/utils/testing.py → fair_dev/testing/testing.py
@@ -49,9 +49,9 @@
 import torch.multiprocessing as mp
 import torch.nn as nn

+from fairscale.internal import torch_version
 from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
-from fairscale.utils import torch_version

 if TYPE_CHECKING:
     Base = nn.Module[Tensor]
File renamed without changes.
2 changes: 1 addition & 1 deletion fairscale/experimental/nn/distributed_pipeline/pipeline.py
@@ -10,8 +10,8 @@
 from torch import Tensor, nn
 from torch.distributed import rpc

+from fairscale.internal import torch_version
 from fairscale.nn.pipe import microbatch
-from fairscale.utils import torch_version

 from .data import DataConsumer
 from .graph import Node, PipelineModulesGraph
2 changes: 1 addition & 1 deletion fairscale/experimental/nn/ssd_offload.py
@@ -17,7 +17,7 @@
 import torch
 from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL

-from fairscale.utils import torch_version
+from fairscale.internal import torch_version

 try:
     from torch.utils._pytree import tree_map
2 changes: 1 addition & 1 deletion fairscale/experimental/nn/sync_batchnorm.py
@@ -10,8 +10,8 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

+from fairscale.internal import torch_version
 from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
-from fairscale.utils import torch_version


 def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion fairscale/nn/checkpoint/checkpoint_activations.py
@@ -14,7 +14,7 @@
 import torch.nn as nn
 import torch.utils.checkpoint as torch_checkpoint

-from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
+from fairscale.internal.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors

 from .checkpoint_utils import patch_batchnorm

14 changes: 7 additions & 7 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -40,19 +40,19 @@
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter

-from fairscale.nn.misc import FlattenParamsWrapper
-from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
-from fairscale.utils.containers import apply_to_tensors
-from fairscale.utils.parallel import (
+from fairscale.internal.containers import apply_to_tensors
+from fairscale.internal.parallel import (
     ProcessGroupName,
     chunk_and_pad,
     enable_pytorch_sync_bn,
     get_process_group_cached,
     validate_process_group,
 )
-from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
-from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
-from fairscale.utils.state_dict import replace_by_prefix_
+from fairscale.internal.params import calc_grad_norm, recursive_copy_to_device
+from fairscale.internal.reduce_scatter_bucketer import ReduceScatterBucketer
+from fairscale.internal.state_dict import replace_by_prefix_
+from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap

 from . import fsdp_optim_utils as ou

2 changes: 1 addition & 1 deletion fairscale/nn/data_parallel/sharded_ddp.py
@@ -21,9 +21,9 @@
 import torch.autograd.profiler as profiler
 import torch.distributed as dist

+from fairscale.internal.params import Workhandle, get_global_rank
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.utils.params import Workhandle, get_global_rank


 def _trainable(param: torch.Tensor) -> bool:
2 changes: 1 addition & 1 deletion fairscale/nn/misc/flatten_params_wrapper.py
@@ -44,7 +44,7 @@
     import_ssd_offload = False
     pass

-from fairscale.utils.state_dict import replace_by_prefix_
+from fairscale.internal.state_dict import replace_by_prefix_

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401
2 changes: 1 addition & 1 deletion fairscale/nn/pipe/messages.py
@@ -11,8 +11,8 @@

 import torch

+from fairscale.internal.object import pyobject_to_tensor, tensor_to_pyobject
 from fairscale.nn.model_parallel import get_pipeline_parallel_group
-from fairscale.utils.object import pyobject_to_tensor, tensor_to_pyobject

 from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors

2 changes: 1 addition & 1 deletion fairscale/nn/pipe/pipe.py
@@ -27,7 +27,7 @@
 import torch.autograd
 import torch.cuda

-from fairscale.utils import torch_version
+from fairscale.internal import torch_version

 from . import microbatch
 from .batchnorm import DeferredBatchNorm
2 changes: 1 addition & 1 deletion fairscale/optim/grad_scaler.py
@@ -18,7 +18,7 @@
 from torch.optim import Optimizer
 from torch.optim.sgd import SGD

-from fairscale.utils import torch_version
+from fairscale.internal import torch_version


 class _GeneralMultiDeviceReplicator(object):
2 changes: 1 addition & 1 deletion fairscale/optim/oss.py
@@ -17,8 +17,8 @@
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

+from fairscale.internal.params import calc_grad_norm, get_global_rank, recursive_copy_to_device
 from fairscale.nn.misc import ParamBucket
-from fairscale.utils.params import calc_grad_norm, get_global_rank, recursive_copy_to_device

 __all__ = ["OSS"]

@@ -22,8 +22,8 @@
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader, Dataset

+from fair_dev.testing.testing import get_worker_map, torch_spawn
 from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe
-from fairscale.utils.testing import get_worker_map, torch_spawn


 class MySGD(Optimizer):
2 changes: 1 addition & 1 deletion tests/experimental/nn/data_parallel/test_gossip.py
@@ -15,8 +15,8 @@
 import torch.distributed
 import torch.nn.functional as F

+from fair_dev.testing.testing import skip_if_single_gpu, spawn_for_all_world_sizes
 import fairscale.experimental.nn.data_parallel.gossip as gossip
-from fairscale.utils.testing import skip_if_single_gpu, spawn_for_all_world_sizes

 # Enfore CUBLAS reproducibility, see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
2 changes: 1 addition & 1 deletion tests/experimental/nn/test_auto_shard.py
@@ -14,7 +14,7 @@
 import torch.nn
 import torch.nn as nn

-from fairscale.utils import torch_version
+from fairscale.internal import torch_version


 class PositionalEncoding(nn.Module):
2 changes: 1 addition & 1 deletion tests/experimental/nn/test_mevo.py
@@ -12,9 +12,9 @@
 import pytest
 import torch

+from fair_dev.testing.testing import skip_if_no_cuda
 from fairscale.experimental.nn import MEVO
 from fairscale.experimental.nn.mevo import BaselineSoftmaxNllLoss, get_data
-from fairscale.utils.testing import skip_if_no_cuda


 @pytest.fixture(scope="session", params=[torch.float16, torch.float32])
4 changes: 2 additions & 2 deletions tests/experimental/nn/test_multiprocess_pipe.py
@@ -20,9 +20,9 @@
 import torch.multiprocessing as mp
 import torch.nn as nn

+from fair_dev.testing.testing import skip_if_single_gpu
 from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
-from fairscale.utils import torch_version
-from fairscale.utils.testing import skip_if_single_gpu
+from fairscale.internal import torch_version

 pytestmark = pytest.mark.skipif(
     not torch.cuda.is_available() or torch_version() < (1, 9, 0),
4 changes: 2 additions & 2 deletions tests/experimental/nn/test_offload.py
@@ -14,9 +14,9 @@
 import pytest
 import torch

+from fair_dev.testing.testing import skip_if_no_cuda
 from fairscale.experimental.nn.offload import OffloadModel
-from fairscale.utils import torch_version
-from fairscale.utils.testing import skip_if_no_cuda
+from fairscale.internal import torch_version

 if torch_version() >= (1, 8, 0):
     from fairscale.experimental.nn.auto_shard import shard_model
2 changes: 1 addition & 1 deletion tests/experimental/tooling/test_layer_memory_tracker.py
@@ -10,13 +10,13 @@
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel

+from fair_dev.testing.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
 from fairscale.experimental.tooling.layer_memory_tracker import (
     LayerwiseMemoryTracker,
     ProcessGroupTracker,
     find_best_reset_points,
 )
 from fairscale.nn import FullyShardedDataParallel
-from fairscale.utils.testing import GPT2, dist_init, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx


 @skip_if_no_cuda()
4 changes: 2 additions & 2 deletions tests/nn/checkpoint/test_checkpoint_activations.py
@@ -10,11 +10,11 @@
 import torch.nn as nn
 from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper

+from fair_dev.testing.testing import skip_if_no_cuda
+from fairscale.internal import torch_version
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper, disable_checkpointing
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
-from fairscale.utils import torch_version
-from fairscale.utils.testing import skip_if_no_cuda


 def get_cuda_mem_allocated():
4 changes: 2 additions & 2 deletions tests/nn/checkpoint/test_checkpoint_activations_norm.py
@@ -14,9 +14,9 @@
 from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
 from torch.optim import SGD

+from fair_dev.testing.testing import objects_are_equal
+from fairscale.internal import torch_version
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.utils import torch_version
-from fairscale.utils.testing import objects_are_equal

 NORM_TYPES = [LayerNorm, BatchNorm2d]
 MP_TYPES = ["fp32", "fp16", "call_half"]
8 changes: 4 additions & 4 deletions tests/nn/data_parallel/test_fsdp.py
@@ -18,10 +18,7 @@
 from torch import nn
 import torch.distributed

-from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     DeviceAndTypeCheckModule,
     DummyProcessGroup,
     dist_init,
@@ -30,6 +27,9 @@
     skip_a_test_if_in_CI,
     spawn_for_all_world_sizes,
 )
+from fairscale.internal import torch_version
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState

 if torch_version() >= (1, 8, 0):
     from fairscale.optim.grad_scaler import ShardedGradScaler
2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_fsdp_apply.py
@@ -10,7 +10,7 @@
 import pytest
 import torch.nn as nn

-from fairscale.utils import torch_version
+from fairscale.internal import torch_version

 from .test_fsdp import (
     CONFIG_OPTIONS,
2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_fsdp_freezing_weights.py
@@ -21,8 +21,8 @@
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim

+from fair_dev.testing.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, objects_are_equal, rmf, skip_if_single_gpu, teardown


 class FreezeModel(nn.Module):
2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_fsdp_grad_acc.py
@@ -12,8 +12,8 @@
 from parameterized import parameterized
 import torch

+from fair_dev.testing.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal
 from fairscale.nn.data_parallel import FullyShardedDataParallel
-from fairscale.utils.testing import DummyProcessGroup, make_cudnn_deterministic, objects_are_equal

 from .test_fsdp import DistributedTest, NestedWrappedModule, rename_test, spawn_and_init

2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_fsdp_hf_transformer_eval.py
@@ -6,9 +6,9 @@
 import torch
 from torch import nn

+from fair_dev.testing.testing import dist_init
 from fairscale.nn import FullyShardedDataParallel as FSDP
 from fairscale.nn import auto_wrap, enable_wrap
-from fairscale.utils.testing import dist_init


 def wrap_transformer_only(module, recurse, **kwargs):
4 changes: 2 additions & 2 deletions tests/nn/data_parallel/test_fsdp_input.py
@@ -16,10 +16,10 @@
 from torch.nn import Linear, Module
 from torch.optim import SGD

+from fair_dev.testing.testing import dist_init, rmf, skip_if_no_cuda, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown


 # A fixture to get tempfiles and ensure they are cleaned up.
6 changes: 3 additions & 3 deletions tests/nn/data_parallel/test_fsdp_memory.py
@@ -18,12 +18,12 @@
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim

+from fair_dev.testing.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.internal import torch_version
+from fairscale.internal.parallel import get_process_group_cached
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
-from fairscale.utils import torch_version
-from fairscale.utils.parallel import get_process_group_cached
-from fairscale.utils.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx


 def to_fsdp(module, fsdp_config):
2 changes: 1 addition & 1 deletion tests/nn/data_parallel/test_fsdp_metadata.py
@@ -14,8 +14,8 @@
 import torch.nn as nn
 from torch.optim import Adam

+from fair_dev.testing.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
 from fairscale.nn import FullyShardedDataParallel
-from fairscale.utils.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
 from tests.nn.data_parallel.test_fsdp import DistributedTest, MixtureOfExperts, rename_test, spawn_and_init

 USE_TEMPFILE = True  # False for debugging
4 changes: 2 additions & 2 deletions tests/nn/data_parallel/test_fsdp_multiple_forward.py
@@ -17,10 +17,10 @@
 from torch.nn import Linear, Module
 from torch.optim import SGD

+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown


 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
@@ -20,12 +20,12 @@
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim

+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
 from fairscale.nn.wrap import enable_wrap, wrap
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx


 class Model(nn.Module):
