[refactory] add nn.parallel module (#1068)
feifeibear authored Jun 6, 2022
Parent: 6754f1b · Commit: 49832b2
Showing 22 changed files with 44 additions and 46 deletions.
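In short, the commit moves the ColoDDP wrappers and the Colo module helpers under a new colossalai.nn.parallel package and stops re-exporting them from colossalai.nn (ColoInitContext is likewise no longer re-exported from colossalai.utils; see the hunks below). A sketch of the resulting import paths, consistent with the new __init__ files and the updated call sites in this diff:

    from colossalai.nn.parallel import ColoDDP, ColoDDPV2                      # re-exported from .data_parallel
    from colossalai.nn.parallel.layers import init_colo_module, ColoLinear, ColoEmbedding   # moved out of colossalai.nn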
3 changes: 0 additions & 3 deletions colossalai/nn/__init__.py
@@ -5,6 +5,3 @@
from .model import *
from .optimizer import *
from ._ops import *

-from .modules import ColoLinear, ColoEmbedding
-from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
3 changes: 0 additions & 3 deletions colossalai/nn/modules/__init__.py

This file was deleted.

3 changes: 3 additions & 0 deletions colossalai/nn/parallel/__init__.py
@@ -0,0 +1,3 @@
+from .data_parallel import ColoDDP, ColoDDPV2
+
+__all__ = ['ColoDDP', 'ColoDDPV2']
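A minimal usage sketch for the relocated wrapper. The constructor is not shown in this diff, so the single-argument call below is an assumption modeled on torch's DistributedDataParallel; it also presumes an already-initialized Colossal-AI distributed context.

    import torch.nn as nn
    from colossalai.nn.parallel import ColoDDP   # or: from colossalai.nn.parallel.data_parallel import ColoDDP

    model = nn.Linear(16, 16).cuda()   # any torch module; .cuda() assumes a GPU is available
    model = ColoDDP(model)             # assumed signature; check data_parallel.py for the actual arguments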
colossalai/nn/parallel/data_parallel.py
@@ -7,8 +7,6 @@
from colossalai.tensor.chunk import ChunkManager, TensorState
from colossalai.tensor.param_op_hook import use_param_op_hooks

-__all__ = ['ColoDDP', 'ColoDDPV2']


def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
15 changes: 15 additions & 0 deletions colossalai/nn/parallel/layers/__init__.py
@@ -0,0 +1,15 @@
+from .colo_module import ColoModule
+from .linear import ColoLinear
+from .embedding import ColoEmbedding
+from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
+
+__all__ = [
+    'ColoModule',
+    'register_colo_module',
+    'is_colo_module',
+    'get_colo_module',
+    'init_colo_module',
+    'check_colo_module',
+    'ColoLinear',
+    'ColoEmbedding',
+]
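For reference, the helper entry points now resolve from this subpackage; the updated call sites later in the diff (colo_init_context.py, test_module_spec.py, test_hybrid_device.py) import them as:

    from colossalai.nn.parallel.layers import register_colo_module, init_colo_module, check_colo_module
    from colossalai.nn.parallel.layers import ColoModule, ColoLinear, ColoEmbedding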
File renamed without changes.
File renamed without changes.
File renamed without changes.
colossalai/nn/parallel/layers/module_utils.py
@@ -1,6 +1,6 @@
from typing import Dict
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
-from .modules import ColoModule
+from . import ColoModule
import torch

_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
4 changes: 0 additions & 4 deletions colossalai/utils/__init__.py
@@ -11,8 +11,6 @@
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector
-from .model.utils import InsertPostInitMethodToModuleSubClasses
-from .model.colo_init_context import ColoInitContext

__all__ = [
'checkpoint',
@@ -52,6 +50,4 @@
'disposable',
'colo_set_cpu_memory_capacity',
'colo_get_cpu_memory_capacity',
-    'InsertPostInitMethodToModuleSubClasses',
-    'ColoInitContext',
]
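Because colossalai.utils no longer re-exports these two symbols, callers pull them from the concrete modules instead, exactly as the updated files below do:

    from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
    from colossalai.utils.model.colo_init_context import ColoInitContext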
2 changes: 1 addition & 1 deletion colossalai/utils/model/colo_init_context.py
@@ -2,7 +2,7 @@
import torch
from colossalai.tensor import ColoTensor, ColoParameter

-from colossalai.nn import register_colo_module, init_colo_module, \
+from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding

from torch import nn
5 changes: 1 addition & 4 deletions colossalai/utils/model/pipelinable.py
@@ -1,10 +1,7 @@
import torch
import functools
import inspect
from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.builder.pipeline import partition_uniform, partition_balanced
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor

6 changes: 4 additions & 2 deletions colossalai/zero/init_ctx/init_context.py
@@ -1,9 +1,12 @@
import contextlib
import functools
from typing import Optional
from contextlib import AbstractContextManager

import torch
import torch.nn as nn
import torch.distributed as dist

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
@@ -12,8 +15,7 @@
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
-from contextlib import AbstractContextManager
-from colossalai.utils import InsertPostInitMethodToModuleSubClasses
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses


class ZeroContextConfig(object):
2 changes: 1 addition & 1 deletion colossalai/zero/zero_optimizer.py
@@ -2,7 +2,7 @@
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
-from colossalai.nn.parallel import ColoDDPV2
+from colossalai.nn.parallel.data_parallel import ColoDDPV2
from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.core import global_context as gpc
5 changes: 1 addition & 4 deletions tests/test_tensor/test_context.py
@@ -1,10 +1,7 @@
import pytest
-from colossalai.utils import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext

from numpy import allclose, require
import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy

from colossalai.utils.cuda import get_current_device

4 changes: 2 additions & 2 deletions tests/test_tensor/test_gpt.py
@@ -5,14 +5,14 @@
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
-from colossalai.utils import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDP
+from colossalai.nn.parallel.data_parallel import ColoDDP


def init_1d_row_spec(model):
9 changes: 5 additions & 4 deletions tests/test_tensor/test_hybrid_device.py
@@ -1,13 +1,14 @@
-from colossalai.utils import free_port, ColoInitContext, get_current_device
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern, ParallelAction

from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

-from colossalai.nn import init_colo_module
-from colossalai.nn.parallel import ColoDDP
+from colossalai.nn.parallel.layers import init_colo_module
+from colossalai.nn.parallel.data_parallel import ColoDDP

import colossalai
import torch
6 changes: 3 additions & 3 deletions tests/test_tensor/test_model.py
@@ -5,11 +5,11 @@
import pytest
import torch
import torch.multiprocessing as mp
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
-from colossalai.utils import ColoInitContext
-from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ParallelAction, ColoTensor, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
2 changes: 1 addition & 1 deletion tests/test_tensor/test_module_spec.py
@@ -6,7 +6,7 @@
import torch.multiprocessing as mp

from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
-from colossalai.nn import init_colo_module, check_colo_module
+from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
from _utils import tensor_equal, tensor_shard_equal, set_seed

import colossalai
8 changes: 4 additions & 4 deletions tests/test_tensor/test_zero.py
@@ -5,14 +5,14 @@
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
-from colossalai.utils import ColoInitContext
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec, ColoParameter, ChunkManager
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
-from colossalai.nn.parallel import ColoDDP, ColoDDPV2
+from colossalai.nn.parallel import ColoDDPV2
from colossalai.testing import parameterize


4 changes: 2 additions & 2 deletions tests/test_tensor/test_zero_optim.py
@@ -6,11 +6,11 @@
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
-from colossalai.utils import ColoInitContext
+from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDPV2
5 changes: 0 additions & 5 deletions tests/test_utils/test_pipelinable.py
@@ -1,13 +1,8 @@
import os.path as osp

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.utils.model.pipelinable import PipelinableContext

from functools import partial
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception

NUM_CHUNKS = 1
