
[BUG]: ZeRO without using shard_param #1082

Closed
powermano opened this issue Jun 9, 2022 · 15 comments
Labels: bug (Something isn't working)

@powermano

🐛 Describe the bug

When I use ZeRO without shard_param, the following error occurs:

Traceback (most recent call last):
  File "train.py", line 175, in <module>
    main()
  File "train.py", line 39, in main
    with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=False):
  File "/usr/local/Python-3.8.6/lib/python3.8/site-packages/colossalai/zero/init_ctx/init_context.py", line 75, in __init__
    self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
  File "/usr/local/Python-3.8.6/lib/python3.8/site-packages/colossalai/zero/init_ctx/init_context.py", line 37, in __init__
    assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
AttributeError: 'int' object has no attribute 'type'
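
For context, `torch.cuda.current_device()` returns a plain `int` index, while the assertion expects a `torch.device` (which has a `.type` attribute). A minimal illustration:

```python
import torch

idx = torch.cuda.current_device()  # int, e.g. 0 -- ints have no `.type` attribute
dev = torch.device('cuda', idx)    # torch.device, with dev.type == 'cuda'
print(type(idx), type(dev), dev.type)
```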

My init code is:

def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--use_trainer', action='store_true', help='whether to use trainer')
    args = parser.parse_args()

    colossalai.launch_from_torch(config='./config.py')

    logger = get_dist_logger()

    rank = int(os.environ['RANK'])
    # build resnet
    use_zero3 = hasattr(gpc.config, 'zero')
    if use_zero3:
        shard_strategy = TensorShardStrategy()
        with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=False):
            model = resnet34(num_classes=10)
    else:
        model = resnet34(num_classes=10)

My config is:

from colossalai.amp import AMP_TYPE
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import HybridAdam

zero = dict(
    model_config=dict(
        tensor_placement_policy='cuda',
        shard_strategy=TensorShardStrategy(),
        reuse_fp16_shard=False
    ),
    optimizer_config=dict()
)

optimizer = dict(
    type=HybridAdam,
    lr=0.001,
    # weight_decay=1e-2,
)

BATCH_SIZE = 64
NUM_EPOCHS = 20
LOGGING_FREQUENCY = 20
OUTPUT = './'

gradient_clipping = 5.0

Environment

pip install colossalai==0.1.5+torch1.10cu11.1 -f https://release.colossalai.org

Ubuntu 18.04

@powermano
Author

After I modified the code as follows, it actually worked:

rank = int(os.environ['RANK'])
# build resnet
use_zero3 = hasattr(gpc.config, 'zero')
if use_zero3:
    shard_strategy = TensorShardStrategy()

    # with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=False):
    #     model = resnet34(num_classes=10)
    with ZeroInitContext(target_device=torch.device('cuda', rank), shard_strategy=shard_strategy, shard_param=False):
        model = resnet34(num_classes=10)
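
An equivalent fix (a sketch, assuming the per-process CUDA device has already been set, e.g. by colossalai.launch_from_torch) is to wrap the integer index in a torch.device:

```python
# torch.cuda.current_device() returns an int; wrap it explicitly.
target_device = torch.device('cuda', torch.cuda.current_device())
```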

@powermano
Author

I do not know how to save the ZeRO model params. When using the save_checkpoint API, the saved file is pretty small.

@powermano
Author

powermano commented Jun 9, 2022

From the ZeRO paper:

> ZeRO-DP has three main optimization stages (as depicted in Figure 1), which correspond to the partitioning of optimizer states, gradients, and parameters. When enabled cumulatively:
> 1) Optimizer State Partitioning (Pos): 4x memory reduction, same communication volume as DP;
> 2) Add Gradient Partitioning (Pos+g): 8x memory reduction, **same communication volume as DP;**
> 3) Add Parameter Partitioning (Pos+g+p): Memory reduction is linear with DP degree Nd. For example, splitting across 64 GPUs (Nd = 64) will yield a 64x memory reduction. There is a modest 50% increase in communication volume.
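
To make the 4x/8x figures concrete, the paper's per-parameter accounting for mixed-precision Adam (2 bytes each for fp16 parameters and gradients, plus 12 bytes for the fp32 master copy, momentum, and variance) works out as in this sketch:

```python
# Bytes per parameter under mixed-precision Adam, DP degree Nd.
Nd = 64
baseline = 2 + 2 + 12          # 16 bytes/param without ZeRO
p_os     = 2 + 2 + 12 / Nd     # stage 1: -> ~4 bytes/param (~4x reduction)
p_os_g   = 2 + (2 + 12) / Nd   # stage 2: -> ~2 bytes/param (~8x reduction)
p_os_g_p = (2 + 2 + 12) / Nd   # stage 3: 16/Nd, reduction linear in Nd
print(baseline, p_os, p_os_g, p_os_g_p)  # 16 4.1875 2.21875 0.25
```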

If using ZeRO without Parameter Partitioning, the communication volume is the same as DP, but the training speed is lower than DP. The specific results are shown in the table below:

| model | dataset | machine | batch | gradient accumulate size | ZeRO | speed | GPU memory | optimizer | tensor_placement_policy |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ir18 | private dataset | 2 | 64 | 2 | no ZeRO | 10.50 it/s | 9011M | SGD | common data parallel |
| ir18 | private dataset | 2 | 64 | 1 | ZeRO + no shard params | 7.23 it/s | 9141M | HybridAdam | cuda |
| ir18 | private dataset | 2 | 64 | 1 | ZeRO + shard params | 5.85 it/s | 9073M | HybridAdam | cuda |
| ir18 | private dataset | 2 | 64 | 1 | ZeRO + shard params | 2.31 it/s | 6819M | HybridAdam | cpu |

@ver217
Member

ver217 commented Jun 9, 2022

> After I modified the code as follows, it actually worked:
>
>     rank = int(os.environ['RANK'])
>     # build resnet
>     use_zero3 = hasattr(gpc.config, 'zero')
>     if use_zero3:
>         shard_strategy = TensorShardStrategy()
>
>         # with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=False):
>         #     model = resnet34(num_classes=10)
>         with ZeroInitContext(target_device=torch.device('cuda', rank), shard_strategy=shard_strategy, shard_param=False):
>             model = resnet34(num_classes=10)

Yes, you are right. target_device should be a torch.device.

@ver217
Member

ver217 commented Jun 9, 2022

> I do not know how to save the ZeRO model params. When using the save_checkpoint API, the saved file is pretty small.

ZeRO saves fp16 parameters now, so the size of the saved checkpoint is half that of a normal one.
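
To see what actually landed in the file, a quick inspection sketch (the 'model.pt' path is hypothetical, and depending on the save_checkpoint version the state dict may be nested under a 'model' key):

```python
import torch

ckpt = torch.load('model.pt', map_location='cpu')  # hypothetical path
state = ckpt.get('model', ckpt) if isinstance(ckpt, dict) else ckpt
for name, t in state.items():
    # fp16 parameters should show up as torch.float16 here
    print(name, getattr(t, 'dtype', type(t)), getattr(t, 'shape', None))
```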

@powermano
Author

powermano commented Jun 9, 2022

> ZeRO saves fp16 parameters now, so the size of the saved checkpoint is half that of a normal one.

Thanks. But I checked the saved .pt file: only the BN params are there, while the other params are []. Maybe I made some mistakes; I will check it further.

@ver217
Member

ver217 commented Jun 9, 2022

> From the ZeRO paper:
>
> > ZeRO-DP has three main optimization stages (as depicted in Figure 1), which correspond to the partitioning of optimizer states, gradients, and parameters. When enabled cumulatively:
> > 1) Optimizer State Partitioning (Pos): 4x memory reduction, same communication volume as DP;
> > 2) Add Gradient Partitioning (Pos+g): 8x memory reduction, **same communication volume as DP;**
> > 3) Add Parameter Partitioning (Pos+g+p): Memory reduction is linear with DP degree Nd. For example, splitting across 64 GPUs (Nd = 64) will yield a 64x memory reduction. There is a modest 50% increase in communication volume.
>
> If using ZeRO without Parameter Partitioning, the communication volume is the same as DP, but the training speed is lower than DP. The specific results are shown in the table below:
>
> | model | dataset | machine | batch | gradient accumulate size | ZeRO | speed | GPU memory | optimizer | tensor_placement_policy |
> | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
> | ir18 | private dataset | 2 | 64 | 2 | no ZeRO | 10.50 it/s | 9011M | SGD | common data parallel |
> | ir18 | private dataset | 2 | 64 | 1 | ZeRO + no shard params | 7.23 it/s | 9141M | HybridAdam | cuda |
> | ir18 | private dataset | 2 | 64 | 1 | ZeRO + shard params | 5.85 it/s | 9073M | HybridAdam | cuda |
> | ir18 | private dataset | 2 | 64 | 1 | ZeRO + shard params | 2.31 it/s | 6819M | HybridAdam | cpu |

In theory, you are right. However, we don't optimize ZeRO without sharded parameters for now. If you don't need to shard parameters, you can just use DDP instead of ZeRO. We are implementing a new ZeRO, which is faster than the current implementation. You can also try it when our work is done.
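
For reference, a minimal sketch of the plain-DDP alternative mentioned above (assuming the process group is already initialized, e.g. by colossalai.launch_from_torch; torchvision's resnet34 stands in for the model from the earlier snippets):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import resnet34

# Plain data parallelism: every rank keeps full parameters, and
# gradients are synchronized with an all-reduce -- no sharding involved.
model = resnet34(num_classes=10).cuda()
model = DDP(model, device_ids=[torch.cuda.current_device()])
```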

@powermano
Author

Thanks. I have another question. As the table above shows, even when I use shard params, the GPU memory only drops from 9141M to 9073M (tensor_placement_policy='cuda'). Is this normal for 2 GPUs?

@ver217
Member

ver217 commented Jun 9, 2022

> ir18

Could you tell me the size of ir18?

@powermano
Author

> Could you tell me the size of ir18?

The ir18 model is as follows:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn.functional import linear, normalize

using_ckpt = False

class CosFace(torch.nn.Module):
    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        # labels = labels.squeeze_()
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        final_target_logit = target_logit - self.m
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.s
        return logits

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out        

    def forward(self, x):
        # use gradient checkpointing during training when enabled
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False, input_size=(112,112), head=None, num_class=1000):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        self.fc_input_dim = 132  # for input=(192, 168)

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        # self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.fc = nn.Linear(512 * block.expansion * self.fc_input_dim, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        if head is not None:
            assert head == 'cosface'
            self.head = CosFace(s=64, m=0.35)
            self.num_class = num_class
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_class, num_features)))
        else:
            self.head = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    # @torch.cuda.amp.autocast()
    def forward(self, x, label):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        out_fc = self.features(x)
        norm_embeddings = normalize(out_fc)
        norm_weight_activated = normalize(self.weight_activated)
        x = linear(norm_embeddings, norm_weight_activated)
        x = self.head(x, label)
        return x, out_fc
        

def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)
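
A quick shape sanity check (a minimal sketch using the definitions above; with four stride-2 stages a (192, 168) input downsamples to a 12 x 11 feature map, matching the hard-coded fc_input_dim = 132):

```python
import torch

model = iresnet18(num_features=512, head='cosface', num_class=1000)
x = torch.randn(2, 3, 192, 168)        # input size assumed by fc_input_dim
labels = torch.randint(0, 1000, (2,))
logits, embeddings = model(x, labels)
print(logits.shape, embeddings.shape)  # torch.Size([2, 1000]) torch.Size([2, 512])
```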

@ver217
Member

ver217 commented Jun 9, 2022

> Thanks. I have another question. As the table above shows, even when I use shard params, the GPU memory only drops from 9141M to 9073M (tensor_placement_policy='cuda'). Is this normal for 2 GPUs?

Hi, the ir18 model has only 45,783,552 elements, which is about 87MB in fp16. I think it's normal, since it's a small model: sharding parameters won't save much GPU memory. The activations take the most GPU memory. You can try activation offload in this case.
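
For reference, a quick way to reproduce that count (a sketch; model is the ir18 instance from above, and the 87MB figure follows from 2 bytes per fp16 element):

```python
n = sum(p.numel() for p in model.parameters())
print(n, f'{n * 2 / 2**20:.0f} MiB in fp16')  # 45783552 -> ~87 MiB
```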

@powermano
Author

> Hi, the ir18 model has only 45,783,552 elements, which is about 87MB in fp16. I think it's normal, since it's a small model: sharding parameters won't save much GPU memory. The activations take the most GPU memory. You can try activation offload in this case.

Thanks. Will your new ZeRO optimize ZeRO without sharded parameters?

@ver217
Member

ver217 commented Jun 9, 2022

Yes, we will.

@powermano
Author

Can't wait to see your new version of ZeRO.

@Zjh-819

Zjh-819 commented Jul 15, 2022

> > I do not know how to save the ZeRO model params. When using the save_checkpoint API, the saved file is pretty small.
>
> ZeRO saves fp16 parameters now, so the size of the saved checkpoint is half that of a normal one.

Hello, I've met the same problem: I don't know how to save the model's params properly when using ZeRO. I used the save_checkpoint API, but when I checked the .pt file manually, I found that all params have the value '[]' except some biases. How can I fix this?
