[BUG]: Memory consumption with fp16 is abnormal when using Engine #1095

Closed
powermano opened this issue Jun 10, 2022 · 14 comments · Fixed by #1096
Labels
bug Something isn't working

Comments

@powermano

🐛 Describe the bug

When using colossalai.amp.convert_to_torch_amp to wrap the model, optimizer, and criterion,

if not use_colossai_engine:
    model, optimizer, criterion =  colossalai.amp.convert_to_torch_amp(model, optimizer, criterion)

and then training normally, the run consumes only about 4700 MB of memory.

output, _ = model(img, label)
train_loss = criterion(output, label)
optimizer.backward(train_loss)
optimizer.step()
optimizer.zero_grad()

But if I use colossalai.initialize instead, training consumes about 7700 MB of memory. I did check that colossalai.initialize reads the fp16 setting from the config and performs the same colossalai.amp.convert_to_torch_amp conversion during initialization, yet when I then train through the Engine it still needs about 7700 MB. This is where I get confused; my understanding of what colossalai.initialize should be doing is sketched after the snippet below.

engine.zero_grad()
output, _ = engine(img, label)
train_loss = engine.criterion(output, label)
engine.backward(train_loss)
engine.step()   
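
My understanding of what colossalai.initialize does when it sees the fp16 section of the config is roughly the following. This is only a sketch of my mental model, not the actual library code, and it assumes the imports used in the config and script further down; with this reading, both code paths should end up with the same wrapped objects and the same memory footprint.

# Sketch of what I assume colossalai.initialize does internally when the
# config contains fp16 = dict(mode=AMP_TYPE.TORCH). This is NOT the real
# library code, only my mental model of it; model/optimizer/criterion are
# the ones built in the training script.
import colossalai
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc

fp16_cfg = getattr(gpc.config, 'fp16', None)
if fp16_cfg is not None and fp16_cfg['mode'] == AMP_TYPE.TORCH:
    model, optimizer, criterion = colossalai.amp.convert_to_torch_amp(
        model, optimizer, criterion)
# the Engine then simply wraps the converted model/optimizer/criterion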

Environment

No response

@FrankLeeeee
Contributor

Thanks! Give me some time to verify this. :)

@FrankLeeeee
Contributor

Hi @powermano, thanks for spotting this tricky bug; it has been fixed in #1096.

@powermano
Author

You're welcome, I'm looking forward to your new ZeRO.

@FrankLeeeee
Contributor

I will close this issue for now. Do keep us informed if you have further questions. :)

@powermano
Author

@FrankLeeeee Sorry to bother you, but have you tested with my training code? After modifying a few lines of code, memory usage on my side dropped from 8.5 GB to 7.5 GB, but it still cannot reach the 5.8 GB I get without the Engine.

@FrankLeeeee
Contributor

Hi, I don't know which model/optimizer/loss/dataloader you use, so in #1096 I used my own experiment configuration, but the main code is the same as yours.

@FrankLeeeee
Contributor

If you wish, you can share your full script with me and I can test it.

@powermano
Author

The config is:

from colossalai.amp import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.TORCH,
)

train_debug.py

from colossalai.logging import get_dist_logger
import colossalai
import torch
import os
from colossalai.core import global_context as gpc
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.amp.torch_amp import TorchAMPModel, TorchAMPLoss
import psutil
from colossalai.nn.optimizer import HybridAdam


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.nn.functional import linear, normalize

using_ckpt = False

class CosFace(torch.nn.Module):
    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        # labels = labels.squeeze_()
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        final_target_logit = target_logit - self.m
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.s
        return logits

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out        

    def forward(self, x):
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False, input_size=(112,112), head=None, num_class=1000):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        self.fc_input_dim = 132  # for input=(192, 168)

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        # self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.fc = nn.Linear(512 * block.expansion * self.fc_input_dim, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        if head is not None:
            assert head == 'cosface'
            self.head = CosFace(s=64, m=0.35)
            self.num_class = num_class
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_class, num_features)))
        else:
            self.head = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    # @torch.cuda.amp.autocast()
    def forward(self, x, label):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        out_fc = self.features(x)
        norm_embeddings = normalize(out_fc)
        norm_weight_activated = normalize(self.weight_activated)
        x = linear(norm_embeddings, norm_weight_activated)
        x = self.head(x, label)
        return x, out_fc
        

def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)
                

BATCH_SIZE = 32
NUM_EPOCHS = 20
NUM_STEPS = 2000
LOGGING_FREQUNCE = 100
OUTPUT = './work_dir/test_zero_bs_64/'
num_classes = 20000


use_colossai_engine = True

def main():
    parser = colossalai.get_default_parser()
    parser.add_argument('--use_trainer', action='store_true', help='whether to use trainer')
    parser.add_argument('--config_dir', type=str, default='./config.py', help='config for train fr')
    args = parser.parse_args()

    if use_colossai_engine:
        colossalai.launch_from_torch(config=args.config_dir)
    else:
        colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    
    # print(111111111111111111111111111111, gpc.config.fp16)
    # get dist info
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    torch.cuda.set_device(local_rank)
    os.makedirs(OUTPUT, exist_ok=True)

    # build resnet
    model = iresnet18(head='cosface', num_class=num_classes)

    # if not use_colossai_engine:
    #     model = TorchAMPModel(model)

    model.train().cuda()
    
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build dataloader
    def get_data(batch_size, num_classes, shape=(3, 192, 168)):
        datas = torch.randn((batch_size, )  + shape, device=torch.cuda.current_device())
        labels = torch.randint(0, num_classes - 1, (batch_size,), device=torch.cuda.current_device())
        return datas, labels

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

    # if not use_colossai_engine:
    #     criterion = TorchAMPLoss(criterion)

    optimizer = HybridAdam(model.parameters(), lr=0.001)

    # lr_scheduler
    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)


    if use_colossai_engine:
        engine, _, _, _ = colossalai.initialize(
            model,
            optimizer,
            criterion,
            )
        engine.train()
    else:
        # amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=1000)
        model, optimizer, criterion =  colossalai.amp.convert_to_torch_amp(model, optimizer, criterion)
    global_step = 0

    optimizer.zero_grad()
    for _ in range(NUM_STEPS):
        
        img, label = get_data(BATCH_SIZE, num_classes)
        global_step += 1
        img = img.cuda()
        label = label.cuda()
        if use_colossai_engine:
            engine.zero_grad()
            output, _ = engine(img, label)
            # output = engine(img)
            # train_loss = module_partial_fc(output, label)
            train_loss = engine.criterion(output, label)
            engine.backward(train_loss)
            engine.step()   
        else:
            output, _ = model(img, label)

            train_loss = criterion(output, label)
            # amp.scale(train_loss).backward()
            # amp.unscale_(optimizer)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            # amp.step(optimizer)
            # amp.update()   

            optimizer.backward(train_loss)
            optimizer.step()

            optimizer.zero_grad()

        if global_step % LOGGING_FREQUNCE == 0:
            logger.info(
                f"global_step {global_step} -train loss: {train_loss:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
                ranks=[0])


if __name__ == '__main__':
    main()

@FrankLeeeee
Contributor

Thanks, give me some time. I will keep you updated :)

@powermano
Author

The command is:

colossalai run --nproc_per_node 1 train_debug.py --config_dir $PATH_TO_YOUR_CONFIG

@FrankLeeeee reopened this Jun 13, 2022
@FrankLeeeee
Contributor

Hi @powermano, I have run your code with and without the engine, but I do not observe any difference. I logged the memory usage as below.

if global_step % LOGGING_FREQUNCE == 0:
    logger.info(
        f"global_step {global_step} -train loss: {train_loss:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}",
        ranks=[0])
    logger.info(get_mem_info(), ranks=[0])

I have added torch.cuda.max_memory_allocated() to the log string.
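
For reference, the extended logging helper looks roughly like this. It is only a sketch that reuses the get_gpu_mem/get_cpu_mem helpers from the script above; the extra torch.cuda.max_memory_allocated() call is the only new piece.

# Sketch of the extended memory logging (reuses get_gpu_mem/get_cpu_mem from
# the script above). torch.cuda.max_memory_allocated() reports the peak GPU
# allocation observed so far, which is the "max GPU memory usage" number below.
def get_gpu_max_mem():
    return torch.cuda.max_memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return (f'{prefix}Current GPU memory usage: {get_gpu_mem():.2f} MB, '
            f'max GPU memory usage: {get_gpu_max_mem():.2f} MB, '
            f'CPU memory usage: {get_cpu_mem():.2f} MB')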

The results are:

with engine:
Current GPU memory usage: 875.88 MB
max GPU memory usage: 2979.18 MB,
CPU memory usage: 4459.73 MB

without engine:
Current GPU memory usage: 875.88 MB, 
max GPU memory usage: 2979.18 MB,
CPU memory usage: 4462.37 MB

@powermano
Author

I downloaded the latest version and the test results are normal now; it must have been a problem with my older version.

Installing collected packages: colossalai
  Attempting uninstall: colossalai
    Found existing installation: colossalai 0.1.4+torch1.10cu10.2
    Uninstalling colossalai-0.1.4+torch1.10cu10.2:
      Successfully uninstalled colossalai-0.1.4+torch1.10cu10.2
Successfully installed colossalai-0.1.6

The results are:

with engine:

[06/14/22 05:40:28] INFO     colossalai - colossalai - INFO: train_debug.py:337 
                             main                                               
                    INFO     colossalai - colossalai - INFO: global_step 200    
                             -train loss: 35.863, lr: 0.001                     
                    INFO     colossalai - colossalai - INFO: train_debug.py:339 
                             main                                               
                    INFO     colossalai - colossalai - INFO: GPU memory usage:  
                             2978.99 MB, CPU memory usage: 3120.56 MB

without engine:

[06/14/22 05:30:27] INFO     colossalai - colossalai - INFO: train_debug.py:336
                             main
                    INFO     colossalai - colossalai - INFO: global_step 200
                             -train loss: 35.667, lr: 0.001
                    INFO     colossalai - colossalai - INFO: train_debug.py:338
                             main
                    INFO     colossalai - colossalai - INFO: GPU memory usage:
                             2978.99 MB, CPU memory usage: 3122.28 MB

@powermano
Author

I will close this issue now.

@FrankLeeeee
Contributor

Great, thanks!
