How much CPU/GPU memory is needed to run 3D ResNeXt-101? #231

Open
lily2lab opened this issue Dec 29, 2020 · 4 comments

Comments

@lily2lab

Hi, thanks for the wonderful repo!
I want to extract video features using the pre-trained model (resnext-101-kinetics.pth), but I get an out-of-memory error (it says 263 GB of memory is needed). Is there anything wrong with what I did? I run it on 4 NVIDIA 1080Ti GPUs, my machine has 128 GB of RAM, and I use the following command:

python main.py --root_path data --video_path ucf101_jpg --annotation_path ucf101_01.json \
    --result_path results --dataset ucf101 --resume_path pretrained_models/resnext-101-kinetics.pth --model resnext \
    --model_depth 101 --n_classes 400 --n_threads 4 --no_train --no_val --inference --output_topk 5 --inference_batch_size 1

Can you help me? @kenshohara

@lily2lab
Author

lily2lab commented Jan 2, 2021

The error is shown below:

Namespace(accimage=False, annotation_path=PosixPath('data/ucf101_01.json'), arch='resnext-101', batch_size=32, batchnorm_sync=False, begin_epoch=1, checkpoint=5, colorjitter=False, conv1_t_size=7, conv1_t_stride=1, dampening=0.0, dataset='ucf101', dist_url='tcp://127.0.0.1:23456', distributed=False, file_type='jpg', ft_begin_module='fc', inference=False, inference_batch_size=32, inference_crop='center', inference_no_average=False, inference_stride=16, inference_subset='val', input_type='rgb', learning_rate=0.1, lr_scheduler='multistep', manual_seed=1, mean=[0.4345, 0.4051, 0.3775], mean_dataset='kinetics', model='resnext', model_depth=101, momentum=0.9, multistep_milestones=[50, 100, 150], n_classes=400, n_epochs=200, n_finetune_classes=101, n_input_channels=3, n_pretrain_classes=400, n_threads=4, n_val_samples=3, nesterov=False, no_cuda=False, no_hflip=False, no_max_pool=False, no_mean_norm=False, no_std_norm=False, no_train=False, no_val=False, optimizer='sgd', output_topk=5, overwrite_milestones=False, plateau_patience=10, pretrain_path=PosixPath('data/pretrained_models/resnext-101-kinetics.pth'), resnet_shortcut='B', resnet_widen_factor=1.0, resnext_cardinality=32, result_path=PosixPath('data/results'), resume_path=None, root_path=PosixPath('data'), sample_duration=16, sample_size=112, sample_t_stride=1, std=[0.2768, 0.2713, 0.2737], tensorboard=False, train_crop='random', train_crop_min_ratio=0.75, train_crop_min_scale=0.25, train_t_crop='random', value_scale=1, video_path=PosixPath('data/ucf101_jpg'), weight_decay=0.001, wide_resnet_k=2, world_size=-1)
Traceback (most recent call last):
File "main.py", line 432, in
main_worker(-1, opt)
File "main.py", line 341, in main_worker
model = generate_model(opt)
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/model.py", line 76, in generate_model
no_max_pool=opt.no_max_pool)
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/models/resnext.py", line 66, in generate_model
**kwargs)
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/models/resnext.py", line 53, in init
shortcut_type, n_classes)
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/models/resnet.py", line 132, in init
shortcut_type)
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/models/resnet.py", line 190, in _make_layer
downsample=downsample))
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/models/resnext.py", line 21, in init
super().init(inplanes, planes, stride, downsample)
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/models/resnet.py", line 71, in init
self.conv2 = conv3x3x3(planes, planes, stride)
File "/home/data1_4t/lxl/action_prediction/3D-ResNets-PyTorch/models/resnet.py", line 19, in conv3x3x3
bias=False)
File "/home/data1_4t/lxl/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 443, in init
False, _triple(0), groups, bias)
File "/home/data1_4t/lxl/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 38, in init
out_channels, in_channels // groups, *kernel_size))
RuntimeError: $ Torch: not enough memory: you tried to allocate 263GB. Buy new RAM! at /opt/conda/conda-bld/pytorch_1544176307774/work/aten/src/TH/THGeneral.cpp:201

@missbook520

I encountered the same problem, did you solve it?

@guilhermesurek

guilhermesurek commented Mar 3, 2021

Hello,
There are four problems that I have found.

Problem 1: In the file resnext.py, widen_factor (1.0) was being replaced by n_classes (400) inside the ResNet class in resnet.py. On line 118, [int(x * widen_factor) for x in block_inplanes] multiplied [128, 256, 512, 1024] (as returned by get_inplanes in resnext.py) by 400, or by whatever n_classes was passed, giving in_planes of [51200, 102400, 204800, 409600] and causing the out-of-memory error.
Hypothesis: this happened because WideResNet was probably set up after ResNeXt, and its widen_factor parameter was placed before n_classes in the signature.

How to solve:
On the file "resnext.py" line 51
Replace:

        super().__init__(block, layers, block_inplanes, n_input_channels,
                         conv1_t_size, conv1_t_stride, no_max_pool,
                         shortcut_type, n_classes)

By:

        super().__init__(block=block, layers=layers, block_inplanes=block_inplanes, n_input_channels=n_input_channels,
                         conv1_t_size=conv1_t_size, conv1_t_stride=conv1_t_stride, no_max_pool=no_max_pool,
                         shortcut_type=shortcut_type, n_classes=n_classes)

The Problem:

class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 block_inplanes,
                 n_input_channels=3,
                 conv1_t_size=7,
                 conv1_t_stride=1,
                 no_max_pool=False,
                 shortcut_type='B',
                 widen_factor=1.0,
                 n_classes=400):
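
To see concretely why this blows up memory, here is a minimal illustrative sketch (not code from the repo): with the signature above, the ninth positional argument of the super().__init__ call lands in widen_factor, so n_classes=400 silently becomes the widening factor.

# Illustrative sketch: n_classes=400 accidentally bound to widen_factor
block_inplanes = [128, 256, 512, 1024]  # as returned by get_inplanes() in resnext.py
widen_factor = 400                      # should have been 1.0
print([int(x * widen_factor) for x in block_inplanes])
# [51200, 102400, 204800, 409600] -> enormous Conv3d weights -> out of memory

Passing the arguments by keyword, as in the fix above, makes the binding explicit and immune to parameter reordering.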

Problem 2: The ResNeXtBottleneck block was defined with the parameter "inplanes", but ResNet's _make_layer function passes "in_planes".

_make_layer function in resnet.py file:

        layers.append(
            block(in_planes=self.in_planes,
                  planes=planes,
                  stride=stride,
                  downsample=downsample))

How to solve:
In the file resnext.py, line 16
Replace:

class ResNeXtBottleneck(Bottleneck):
    expansion = 2

    def __init__(self, inplanes, planes, cardinality, stride=1,
                 downsample=None):
        super().__init__(inplanes, planes, stride, downsample)

        mid_planes = cardinality * planes // 32
        self.conv1 = conv1x1x1(inplanes, mid_planes)
        self.bn1 = nn.BatchNorm3d(mid_planes)
        self.conv2 = nn.Conv3d(mid_planes,
                               mid_planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               groups=cardinality,
                               bias=False)
        self.bn2 = nn.BatchNorm3d(mid_planes)
        self.conv3 = conv1x1x1(mid_planes, planes * self.expansion)

By:

class ResNeXtBottleneck(Bottleneck):
    expansion = 2

    def __init__(self, in_planes, planes, cardinality, stride=1,
                 downsample=None):
        super().__init__(in_planes, planes, stride, downsample)

        mid_planes = cardinality * planes // 32
        self.conv1 = conv1x1x1(in_planes, mid_planes)
        self.bn1 = nn.BatchNorm3d(mid_planes)
        self.conv2 = nn.Conv3d(mid_planes,
                               mid_planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               groups=cardinality,
                               bias=False)
        self.bn2 = nn.BatchNorm3d(mid_planes)
        self.conv3 = conv1x1x1(mid_planes, planes * self.expansion)

Problem 3: Loading the pretrained resnext-101-kinetics.pth mismatches the state_dict of the new ResNeXt. The old format prefixes every key with "module." (e.g. 'module.conv1.weight'), while the new model expects 'conv1.weight'. So we need to remove every "module." prefix from the state_dict.

How to solve:
In the file model.py, line 97
Replace:

def load_pretrained_model(model, pretrain_path, model_name, n_finetune_classes):
    if pretrain_path:
        print('loading pretrained model {}'.format(pretrain_path))
        pretrain = torch.load(pretrain_path, map_location='cpu')

        model.load_state_dict(pretrain['state_dict'])
        tmp_model = model
        if model_name == 'densenet':
            tmp_model.classifier = nn.Linear(tmp_model.classifier.in_features,
                                             n_finetune_classes)
        else:
            tmp_model.fc = nn.Linear(tmp_model.fc.in_features,
                                     n_finetune_classes)

    return model

By:

def load_pretrained_model(model, pretrain_path, model_name, n_finetune_classes):
    if pretrain_path:
        print('loading pretrained model {}'.format(pretrain_path))
        pretrain = torch.load(pretrain_path, map_location='cpu')
        
        if pretrain_path.name == 'resnext-101-kinetics.pth':
            pretrain['state_dict'] = {str(key).replace("module.", "") : value for key, value in pretrain['state_dict'].items()}

        model.load_state_dict(pretrain['state_dict'])
        tmp_model = model
        if model_name == 'densenet':
            tmp_model.classifier = nn.Linear(tmp_model.classifier.in_features,
                                             n_finetune_classes)
        else:
            tmp_model.fc = nn.Linear(tmp_model.fc.in_features,
                                     n_finetune_classes)

    return model
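
To confirm whether your checkpoint actually carries the old prefix, a quick sketch (assuming the pretrain_path shown in the log above):

import torch

# Sketch: inspect the checkpoint keys; old checkpoints are prefixed with 'module.',
# which is typical of models saved while wrapped in nn.DataParallel.
ckpt = torch.load('data/pretrained_models/resnext-101-kinetics.pth', map_location='cpu')
print(list(ckpt['state_dict'].keys())[:3])
# e.g. ['module.conv1.weight', 'module.bn1.weight', 'module.bn1.bias']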

Problem 4: state_dict shape mismatch when loading the pretrained weights.

How to solve:
In the file resnet.py, line 119
Replace:

self.in_planes = block_inplanes[0]

By:

self.in_planes = 64 #block_inplanes[0]

Important: do this only to run the pretrained resnext-101 model, then remove the hard-coded 64.
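
If you want to see which parameters still mismatch, here is a sketch that compares shapes between the (prefix-stripped) checkpoint and the freshly built model; `model` is assumed to come from generate_model(opt):

import torch

# Sketch: list remaining shape mismatches between checkpoint and model
ckpt = torch.load('data/pretrained_models/resnext-101-kinetics.pth', map_location='cpu')
sd = {k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()}
for name, param in model.state_dict().items():
    if name in sd and sd[name].shape != param.shape:
        print(name, tuple(sd[name].shape), '!=', tuple(param.shape))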

@JingwWu

JingwWu commented Nov 18, 2023

This really helps me, thanks!
