add nn.parallel module #1068

Merged · 9 commits · Jun 6, 2022
@@ -2,7 +2,7 @@
import torch
from colossalai.tensor import ColoTensor, ColoParameter

from colossalai.nn import register_colo_module, init_colo_module, \
from colossalai.nn.parallel.layers import register_colo_module, \
ColoLinear, ColoEmbedding

from torch import nn
@@ -1,10 +1,7 @@
import torch
import functools
import inspect
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.context_manager.utils import InsertPostInitMethodToModuleSubClasses, call_to_str
from colossalai.builder.pipeline import partition_uniform, partition_balanced
from colossalai.core import global_context as gpc
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoTensor

File renamed without changes.
3 changes: 0 additions & 3 deletions colossalai/nn/__init__.py
@@ -5,6 +5,3 @@
from .model import *
from .optimizer import *
from ._ops import *

from .modules import ColoLinear, ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
3 changes: 0 additions & 3 deletions colossalai/nn/modules/__init__.py

This file was deleted.

3 changes: 3 additions & 0 deletions colossalai/nn/parallel/__init__.py
@@ -0,0 +1,3 @@
from .data_parallel import ColoDDP, ColoDDPV2

__all__ = ['ColoDDP', 'ColoDDPV2']
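For reference, a minimal usage sketch (not part of the diff) of the new package path; the wrapped module is a placeholder and the ColoDDP constructor signature is assumed to take the module, in the style of torch DDP:

import torch.nn as nn
from colossalai.nn.parallel import ColoDDP

model = nn.Linear(16, 16)    # placeholder module
ddp_model = ColoDDP(model)   # assumed signature: wrap a module for data parallelism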
@@ -7,8 +7,6 @@
from colossalai.tensor.chunk import ChunkManager, TensorState
from colossalai.tensor.param_op_hook import use_param_op_hooks

__all__ = ['ColoDDP', 'ColoDDPV2']


def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
15 changes: 15 additions & 0 deletions colossalai/nn/parallel/layers/__init__.py
@@ -0,0 +1,15 @@
from .colo_module import ColoModule
from .linear import ColoLinear
from .embedding import ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module

__all__ = [
'ColoModule',
'register_colo_module',
'is_colo_module',
'get_colo_module',
'init_colo_module',
'check_colo_module',
'ColoLinear',
'ColoEmbedding',
]
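A short sketch (hedged) of the updated import path that the test files below now use; the registration call shape is assumed, not taken from this diff:

import torch
from colossalai.nn.parallel.layers import (
    register_colo_module,
    init_colo_module,
    ColoLinear,
    ColoEmbedding,
)

register_colo_module(torch.nn.Linear, ColoLinear())   # assumed: map nn.Linear to its Colo module spec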
@@ -1,6 +1,6 @@
from typing import Dict
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
from .modules import ColoModule
from . import ColoModule
import torch

_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
4 changes: 0 additions & 4 deletions colossalai/utils/__init__.py
@@ -11,8 +11,6 @@
colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
from .timer import MultiTimer, Timer
from .tensor_detector import TensorDetector
from .model.utils import InsertPostInitMethodToModuleSubClasses
from .model.colo_init_context import ColoInitContext

__all__ = [
'checkpoint',
@@ -52,6 +50,4 @@
'disposable',
'colo_set_cpu_memory_capacity',
'colo_get_cpu_memory_capacity',
'InsertPostInitMethodToModuleSubClasses',
'ColoInitContext',
]
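ColoInitContext and InsertPostInitMethodToModuleSubClasses are no longer re-exported from colossalai.utils; a hedged sketch of the relocated import, with the device keyword assumed from the tests below:

import torch.nn as nn
from colossalai.context_manager.colo_init_context import ColoInitContext
from colossalai.utils.cuda import get_current_device

# parameters created inside the context are materialized as Colo tensors (assumed behavior)
with ColoInitContext(device=get_current_device()):
    model = nn.Linear(32, 32)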
6 changes: 4 additions & 2 deletions colossalai/zero/init_ctx/init_context.py
@@ -1,9 +1,12 @@
import contextlib
import functools
from typing import Optional
from contextlib import AbstractContextManager

import torch
import torch.nn as nn
import torch.distributed as dist

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
@@ -12,8 +15,7 @@
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
from contextlib import AbstractContextManager
from colossalai.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.context_manager.utils import InsertPostInitMethodToModuleSubClasses


class ZeroContextConfig(object):
2 changes: 1 addition & 1 deletion colossalai/zero/zero_optimizer.py
@@ -2,7 +2,7 @@
import torch.distributed as dist
from enum import Enum
from torch.optim import Optimizer
from colossalai.nn.parallel import ColoDDPV2
from colossalai.nn.parallel.data_parallel import ColoDDPV2
from typing import Dict
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.core import global_context as gpc
@@ -17,7 +17,7 @@
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader
from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.context_manager.pipelinable import PipelinableContext
from tqdm import tqdm

from titans.dataloader.cifar10 import build_cifar
5 changes: 1 addition & 4 deletions tests/test_tensor/test_context.py
@@ -1,10 +1,7 @@
import pytest
from colossalai.utils import ColoInitContext
from colossalai.context_manager.colo_init_context import ColoInitContext

from numpy import allclose, require
import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy

from colossalai.utils.cuda import get_current_device

4 changes: 2 additions & 2 deletions tests/test_tensor/test_gpt.py
@@ -5,14 +5,14 @@
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.context_manager.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP
from colossalai.nn.parallel.data_parallel import ColoDDP


def init_1d_row_spec(model):
9 changes: 5 additions & 4 deletions tests/test_tensor/test_hybrid_device.py
@@ -1,13 +1,14 @@
from colossalai.utils import free_port, ColoInitContext, get_current_device
from colossalai.utils import free_port, get_current_device
from colossalai.context_manager.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.tensor import ComputePattern, ParallelAction

from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

from colossalai.nn import init_colo_module
from colossalai.nn.parallel import ColoDDP
from colossalai.nn.parallel.layers import init_colo_module
from colossalai.nn.parallel.data_parallel import ColoDDP

import colossalai
import torch
6 changes: 3 additions & 3 deletions tests/test_tensor/test_model.py
@@ -5,11 +5,11 @@
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
from colossalai.context_manager.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
ParallelAction, ColoTensor, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
4 changes: 2 additions & 2 deletions tests/test_tensor/test_module_spec.py
@@ -6,12 +6,12 @@
import torch.multiprocessing as mp

from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from colossalai.nn import init_colo_module, check_colo_module
from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
from _utils import tensor_equal, tensor_shard_equal, set_seed

import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.context_manager.colo_init_context import ColoInitContext

from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import distspec
8 changes: 4 additions & 4 deletions tests/test_tensor/test_zero.py
@@ -5,14 +5,14 @@
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec, ColoParameter, ChunkManager
from colossalai.context_manager.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDP, ColoDDPV2
from colossalai.nn.parallel import ColoDDPV2
from colossalai.testing import parameterize


4 changes: 2 additions & 2 deletions tests/test_tensor/test_zero_optim.py
@@ -6,11 +6,11 @@
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.context_manager.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from _utils import tensor_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel import ColoDDPV2
7 changes: 1 addition & 6 deletions tests/test_utils/test_pipelinable.py
@@ -1,13 +1,8 @@
import os.path as osp

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.utils.model.pipelinable import PipelinableContext
from colossalai.context_manager.pipelinable import PipelinableContext

from functools import partial
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception

NUM_CHUNKS = 1
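Like the other relocations in this PR, PipelinableContext now lives under colossalai.context_manager; a hedged sketch of the new import path, with the tracing usage pattern assumed:

import torch.nn as nn
from colossalai.context_manager.pipelinable import PipelinableContext

pipelinable = PipelinableContext()
with pipelinable:   # assumed: record model construction for later pipeline partitioning
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))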