
[BUG]: Memory consumption by fp16 is not normal #1083

Closed

powermano opened this issue Jun 9, 2022 · 26 comments
Labels
bug Something isn't working

Comments

@powermano

🐛 Describe the bug

When I use PyTorch's native AMP, the GPU memory usage is much smaller than with Colossal-AI. Why?
The config is:

from colossalai.amp import AMP_TYPE
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.nn.optimizer import HybridAdam

fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

optimizer = dict(
    type=HybridAdam,
    lr=0.001,
    # weight_decay=1e-2,
)
| model | dataset | machine | batch | gradient accumulate size | ZeRO | speed | GPU memory | OPT | tensor_placement_policy | setup |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ir18 | private dataset | 1 | 64 | 1 | no ZeRO | 24%, 2089/8549 [02:51<08:39, 12.43it/s] | 8703M | HybridAdam | - | single machine + Engine |
| ir18 | private dataset | 1 | 64 | 1 | no ZeRO | 19%, 1599/8549 [02:24<10:21, 11.17it/s] | 5769M | HybridAdam | - | single machine + w/o Engine + PyTorch native fp16 |

Environment

No response

@powermano powermano added the bug Something isn't working label Jun 9, 2022
@powermano
Author

The difference between my original PyTorch implementation and Colossal-AI is the convert_to_amp API, which uses TorchAMPModel to decorate the original model.
I have tested three different cases:

1. Using torch.cuda.amp.autocast(True) inside the model's forward function:

class model(nn.Module):
    def __init__():
        ....
    def forward(self, x, label):
        with torch.cuda.amp.autocast(True):
              .....
              .....
        return x

2. Using the @torch.cuda.amp.autocast() decorator:

class model(nn.Module):
    def __init__():
        ....
    @torch.cuda.amp.autocast()
    def forward(self, x, label):
              .....
              .....
        return x

3. Using TorchAMPModel:

class model(nn.Module):
    def __init__():
        ....
    def forward(self, x, label):
              .....
              .....
        return x

model = model()
model = TorchAMPModel(model)

The first two are normal and only need 5769M of GPU memory, but the third one needs 8703M of GPU memory.

@feifeibear
Contributor

TorchAMPModel in Colossal-AI is the same as @torch.cuda.amp.autocast(cache_enabled=True).
You can set cache_enabled=False when aligning the memory usage.
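
A minimal sketch of such a memory comparison (hypothetical toy layer for illustration, not the model from this issue):

import torch
from torch import nn

model = nn.Linear(1024, 1024).cuda()
x = torch.randn(64, 1024, device='cuda')

# cache_enabled=False disables autocast's fp16 weight cache, which otherwise keeps
# half-precision copies of the weights alive between casts inside the region.
with torch.cuda.amp.autocast(cache_enabled=False):
    out = model(x)

print(torch.cuda.max_memory_allocated())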

@powermano
Author

TorchAMPModel in Colossal-AI is the same as @torch.cuda.amp.autocast(cache_enabled=True). You can set cache_enabled=False when aligning the memory usage.

The problem now is that, with the same model, PyTorch + fp16 uses less memory than Colossal-AI + fp16 (5769M vs. 8703M).

@feifeibear
Contributor

I tried to reproduce your problem but failed. Could you provide your code for reference?

from torch import nn
import torch
from colossalai.amp.torch_amp import TorchAMPModel

class SimpleNet(nn.Module):
    """
    In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(20, 4)
        self.proj1 = nn.Linear(4, 8)
        self.ln1 = nn.LayerNorm(8)
        self.proj2 = nn.Linear(8, 4)
        self.ln2 = nn.LayerNorm(4)
        self.classifier = nn.Linear(4, 4)

    def forward(self, x):
        x = self.embed(x)
        x = self.proj1(x)
        x = self.ln1(x)
        x = self.proj2(x)
        x = self.ln2(x)
        x = self.classifier(x)
        return x

class SimpleNetAMP(nn.Module):
    """
    In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(20, 4)
        self.proj1 = nn.Linear(4, 8)
        self.ln1 = nn.LayerNorm(8)
        self.proj2 = nn.Linear(8, 4)
        self.ln2 = nn.LayerNorm(4)
        self.classifier = nn.Linear(4, 4)

    @torch.cuda.amp.autocast()
    def forward(self, x):
        x = self.embed(x)
        x = self.proj1(x)
        x = self.ln1(x)
        x = self.proj2(x)
        x = self.ln2(x)
        x = self.classifier(x)
        return x
use_colo = True # False
if use_colo:
    model = SimpleNet().cuda()
    model = TorchAMPModel(model)
    # 5632 B
else:
    model = SimpleNetAMP().cuda()
    # 5632 B
print(torch.cuda.max_memory_allocated())

@powermano
Author

powermano commented Jun 9, 2022

I tried to reproduce your problem but failed. Could you provide your code for reference?

My test code is as follows; use_colossai_engine controls whether the Colossal-AI engine is used.

colossalai run --nproc_per_node 1 train_debug.py --config_dir PATH_TO_YOUR_CONFIG

The config file is:

from colossalai.amp import AMP_TYPE
fp16 = dict(
    mode=AMP_TYPE.TORCH,
)
gradient_clipping = 5.0

train_debug.py

from colossalai.logging import get_dist_logger
import colossalai
import torch
import os
from colossalai.core import global_context as gpc
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.amp.torch_amp import TorchAMPModel
import psutil
from colossalai.nn.optimizer import HybridAdam


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn.functional import linear, normalize

using_ckpt = False

class CosFace(torch.nn.Module):
    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        # labels = labels.squeeze_()
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        final_target_logit = target_logit - self.m
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.s
        return logits

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out        

    def forward(self, x):
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False, input_size=(112,112), head=None, num_class=1000):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        self.fc_input_dim = 132  # for input=(192, 168)

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        # self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.fc = nn.Linear(512 * block.expansion * self.fc_input_dim, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        if head is not None:
            assert head == 'cosface'
            self.head = CosFace(s=64, m=0.35)
            self.num_class = num_class
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_class, num_features)))
        else:
            self.head = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    # @torch.cuda.amp.autocast()
    def forward(self, x, label):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        out_fc = self.features(x)
        norm_embeddings = normalize(out_fc)
        norm_weight_activated = normalize(self.weight_activated)
        x = linear(norm_embeddings, norm_weight_activated)
        x = self.head(x, label)
        return x, out_fc
        

def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)
                

BATCH_SIZE = 32
NUM_EPOCHS = 20
NUM_STEPS = 2000
LOGGING_FREQUENCY = 100
OUTPUT = './work_dir/test_zero_bs_64/'
num_classes = 20000


use_colossai_engine = False

def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--use_trainer', action='store_true', help='whether to use trainer')
    parser.add_argument('--config_dir', type=str, default='./config.py', help='config for train fr')
    args = parser.parse_args()

    if use_colossai_engine:
        colossalai.launch_from_torch(config=args.config_dir)
    else:
        colossalai.launch_from_torch(config={})

    logger = get_dist_logger()

    # get dist info
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    torch.cuda.set_device(local_rank)
    os.makedirs(OUTPUT, exist_ok=True)

    # build resnet
    model = iresnet18(head='cosface', num_class=num_classes)

    if not use_colossai_engine:
        model = TorchAMPModel(model)

    model.train().cuda()
    
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build dataloader
    def get_data(batch_size, num_classes, shape=(3, 192, 168)):
        datas = torch.randn((batch_size, )  + shape, device=torch.cuda.current_device())
        labels = torch.randint(0, num_classes - 1, (batch_size,), device=torch.cuda.current_device())
        return datas, labels

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = HybridAdam(model.parameters(), lr=0.001)

    # lr_scheduler
    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)


    if use_colossai_engine:
        engine, _, _, _ = colossalai.initialize(
            model,
            optimizer,
            criterion,
            )
        engine.train()
    else:
        amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=1000)
    global_step = 0

    optimizer.zero_grad()
    for _ in range(NUM_STEPS):
        
        img, label = get_data(BATCH_SIZE, num_classes)
        global_step += 1
        img = img.cuda()
        label = label.cuda()
        if use_colossai_engine:
            engine.zero_grad()
            output, _ = engine(img, label)
  
            train_loss = engine.criterion(output, label)
            engine.backward(train_loss)
            engine.step()   
        else:
            output, _ = model(img, label)

            train_loss = criterion(output, label)
            amp.scale(train_loss).backward()
            amp.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            amp.step(optimizer)
            amp.update()   

            optimizer.zero_grad()

        if global_step % LOGGING_FREQUENCY == 0:
            logger.info(
                f"global_step {global_step} -train loss: {train_loss:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
                ranks=[0])


if __name__ == '__main__':
    main()

@powermano
Author

With use_colossai_engine = True, the GPU memory during training is 8200M.
With use_colossai_engine = False, the GPU memory during training is 5281M.

@feifeibear
Contributor

See this line. So you are training the torch version without amp. Is that what you want?

# @torch.cuda.amp.autocast()

@powermano
Author

powermano commented Jun 9, 2022

See this line. So you are training the torch version without amp. Is that what you want?

# @torch.cuda.amp.autocast()

I have verified that using TorchAMPModel and directly applying @torch.cuda.amp.autocast(cache_enabled=True) behave the same.

So the conclusion of #1083 (comment) is wrong.

I do want to train the torch version with AMP, as you can see in this line:

if not use_colossai_engine:
        model = TorchAMPModel(model)

@feifeibear
Contributor

That makes sense. Thanks for helping us check the functionality of AMP. I will close this issue.

@powermano
Author

That makes sense. Thanks for helping us check the functionality of AMP. I will close this issue.

That was not what I meant. What I want to express is that defining fp16 in the config and then initializing the model with colossalai.initialize consumes about 8 GB of memory, whereas training directly with TorchAMPModel, without colossalai.initialize, only consumes 5.7 GB.

I would like to know the reason for this extra memory consumption. Is it a problem with using the Engine?

You can test this with the code above.

@feifeibear
Contributor

As we have checked that AMP is correct, I closed the issue.
You now argue that the engine is using more memory during training, right?
I suppose you opened issue #1082 to discuss that problem? If that's not the case, we can just reopen this issue.

@powermano
Author

powermano commented Jun 10, 2022

Issue #1082 is related to ZeRO and has been resolved.
I now argue that the engine uses more memory during training.
Can you help check this phenomenon?

@feifeibear
Contributor

I cannot run your code. What is your config.py? How did you inspect the CUDA memory?

@powermano
Author

config.py

from colossalai.amp import AMP_TYPE
fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

I inspected the CUDA memory using watch -n 1 nvidia-smi.
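
For reference, the allocator statistics can also be logged from inside the script; a sketch along the lines of the get_mem_info helper already defined in train_debug.py above (note that nvidia-smi additionally counts the CUDA context and the caching allocator's reserved-but-unused memory):

import torch

def log_cuda_mem(prefix=''):
    # Tensors currently allocated by PyTorch.
    alloc = torch.cuda.memory_allocated() / 1024**2
    # Memory held by the caching allocator (closer to what nvidia-smi reports).
    reserved = torch.cuda.memory_reserved() / 1024**2
    # Peak allocation so far; reset with torch.cuda.reset_peak_memory_stats().
    peak = torch.cuda.max_memory_allocated() / 1024**2
    print(f'{prefix}allocated={alloc:.0f} MB, reserved={reserved:.0f} MB, peak={peak:.0f} MB')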

@powermano
Author

The command is:

colossalai run --nproc_per_node 1 train_debug.py --config_dir $PATH_TO_YOUR_CONFIG

@FrankLeeeee
Contributor

Hi, did you apply colossalai.amp.convert_to_torch_amp? This API will wrap the model, optimizer, and criterion.

@feifeibear
Contributor

@FrankLeeeee is right. When the Colossal-AI engine is not used, you should also wrap the optimizer. That will align the memory usage.

@powermano
Author

Hi, did you apply colossalai.amp.convert_to_torch_amp? This API will wrap the model, optimizer, and criterion.

colossalai.initialize automatically uses colossalai.amp.convert_to_torch_amp to wrap the model, optimizer, and criterion.

As you can see in colossalai.initialize:

# check amp and zero
fp16_cfg = gpc.config.get('fp16', None)

if fp16_cfg is not None and fp16_cfg.mode is not None and use_zero:
    raise ConfigException(
        "It is not allowed to set fp16 and zero configuration in your config file at the same time")

# clip grad norm
clip_grad_norm = gpc.config.get('clip_grad_norm', 0.0)

# initialize amp
amp_mode = None
if fp16_cfg is not None and fp16_cfg.mode is not None:
    cfg_ = fp16_cfg.copy()
    amp_mode = cfg_.pop('mode')
    if is_using_pp():
        assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
    if amp_mode == AMP_TYPE.NAIVE:
        cfg_['clip_grad_norm'] = clip_grad_norm
    model, optimizer, criterion = convert_to_amp(model=model,
                                                 optimizer=optimizer,
                                                 criterion=criterion,
                                                 mode=amp_mode,
                                                 amp_config=cfg_)

@FrankLeeeee
Contributor

FrankLeeeee commented Jun 10, 2022

Yes, I mean in your experiment without using colossalai.initialize, did you apply torch amp to the optimizer?

@powermano
Author

Yes, I mean in your two experiments without using colossalai.initialize, did you apply torch amp to the optimizer?

Of course.

if not use_colossai_engine:
    model = TorchAMPModel(model)

Without colossalai.initialize, the training process only consumes 5700M of GPU memory. With colossalai.initialize, it takes 8000M of GPU memory.

@powermano
Author

powermano commented Jun 10, 2022

I inserted a print call into colossalai.initialize and it printed what I expected.

print(22222222222222222222222222222222222222222222, 'convert_to_amp')
model, optimizer, criterion = convert_to_amp(model=model,
                                             optimizer=optimizer,
                                             criterion=criterion,
                                             mode=amp_mode,
                                             amp_config=cfg_)

Console output:

WARNING  colossalai - colossalai - WARNING: /usr/local/python/lib/python3.6/site-packages/colossalai/initialize.py:304 initialize
WARNING  colossalai - colossalai - WARNING: Initializing an non ZeRO model with optimizer class
22222222222222222222222222222222222222222222 convert_to_amp
WARNING  colossalai - colossalai - WARNING: /usr/local/python/lib/python3.6/site-packages/colossalai/initialize.py:440 initialize
WARNING  colossalai - colossalai - WARNING: No PyTorch DDP or gradient handler is set up, please make sure you do not need to all-reduce the gradients after a training step.
[06/10/22 03:15:42] INFO  colossalai - colossalai - INFO: train_debug.py:328 main
INFO  colossalai - colossalai - INFO: global_step 100 -train loss: 36.324, lr: 0.001

This means that the colossalai.amp.convert_to_torch_amp conversion is indeed performed inside colossalai.initialize. But it consumes more GPU memory.

@FrankLeeeee
Contributor

I think there is some misunderstanding. What I mean is that you should wrap not only your model with torch amp but also your optimizer. In AMP, we need to handle gradients with mixed precision as well. Wrapping only your model but not your optimizer will lose the gradient-handling part. That's why you see lower memory usage when colossalai.initialize is not used. If you include your optimizer, it will be the same. I have attached a sample code below for your reference.

if not use_colossai_engine:
    model, optimizer, criterion =  colossalai.amp.convert_to_torch_amp(model, optimizer, criterion)
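
The corresponding training step with the three wrapped objects would then look roughly like this (a sketch; the wrapped optimizer is assumed to handle loss scaling and unscaling internally):

# One step after convert_to_torch_amp has wrapped model, optimizer and criterion.
output, _ = model(img, label)           # forward runs under autocast via TorchAMPModel
train_loss = criterion(output, label)   # loss goes through the wrapped criterion
optimizer.backward(train_loss)          # scales the loss and runs backward
optimizer.step()                        # unscales the gradients and steps via GradScaler
optimizer.zero_grad()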

@powermano
Author

powermano commented Jun 10, 2022

I do not think so; I have wrapped the criterion.

if not use_colossai_engine:
        criterion = TorchAMPLoss(criterion)

And the code below is the plain PyTorch way to wrap the optimizer.

if use_colossai_engine:
    engine, _, _, _ = colossalai.initialize(
        model,
        optimizer,
        criterion,
        )
    engine.train()
else:
    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=1000)

global_step = 0

optimizer.zero_grad()
for _ in range(NUM_STEPS):
    
    img, label = get_data(BATCH_SIZE, num_classes)
    global_step += 1
    img = img.cuda()
    label = label.cuda()
    if use_colossai_engine:
        engine.zero_grad()
        output, _ = engine(img, label)
        # output = engine(img)
        # train_loss = module_partial_fc(output, label)
        train_loss = engine.criterion(output, label)
        engine.backward(train_loss)
        engine.step()   
    else:
        output, _ = model(img, label)

        train_loss = criterion(output, label)
        amp.scale(train_loss).backward()
        amp.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        amp.step(optimizer)
        amp.update()

        optimizer.zero_grad()

In this way, it still consumes less GPU memory. But this is the same as TorchAMPOptimizer(ColossalaiOptimizer).
I am confused. What is the problem?

@powermano
Author

I may not have expressed myself clearly. As you said above, I used colossalai.amp.convert_to_torch_amp to wrap the model, optimizer and criterion,

if not use_colossai_engine:
    model, optimizer, criterion =  colossalai.amp.convert_to_torch_amp(model, optimizer, criterion)

and then trained normally, which also only consumes 4700M of memory.

output, _ = model(img, label)
train_loss = criterion(output, label)
optimizer.backward(train_loss)
optimizer.step()
optimizer.zero_grad()

But if you use colossalai.initialize, it consumes 7700M of memory. We did verify that colossalai.initialize reads the fp16 entry in the config and performs the colossalai.amp.convert_to_torch_amp conversion; yet when we then train with the Engine, it consumes 7700M of memory. This is where I get confused.

engine.zero_grad()
output, _ = engine(img, label)
train_loss = engine.criterion(output, label)
engine.backward(train_loss)
engine.step()   

@FrankLeeeee
Contributor

I see, I am writing a script to reproduce this problem. Can you open a new issue so that we can move our discussion there?

@powermano
Author

OK, I will open a new issue: #1095
