DenseNet in Chainer (caffemodel) CuDNN error #4426

Closed
ilkarman opened this issue Mar 3, 2018 · 5 comments
ilkarman commented Mar 3, 2018

I haven't been able to find a native DenseNet implementation in Chainer, so I have used the caffemodel from shicai:

sym = caffe.CaffeFunction("DenseNet_121.caffemodel")
chainer.cuda.get_device(0).use()  # Make a specified GPU current
sym.to_gpu()  # Copy the model to the GPU
print(list(sym.layers)[-1])
# ('fc6', ['pool5'], ['fc6'])
dta = fake_input_data_cf[:8]
pred = sym(inputs={'data':cuda.to_gpu(dta)}, outputs=['fc6'])

However, I get an error saying that cuDNN does not allow batch-norm with eps < 1e-5, even though no such value appears in the model's prototxt, e.g.:


layer {
  name: "conv1/bn"
  type: "BatchNorm"
  bottom: "conv1"
  top: "conv1/bn"
  batch_norm_param {
    eps: 1e-5
  }
}

RuntimeError                              Traceback (most recent call last)
<ipython-input-...> in <module>()
      1 dta = fake_input_data_cf[:8]
----> 2 pred = sym(inputs={'data':dta}, outputs=['fc6'])

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/caffe/caffe_function.py in __call__(self, inputs, outputs, disable, **kwargs)
    218             func = self.forwards[func_name]
    219             input_vars = tuple(variables[blob] for blob in bottom)
--> 220             output_vars = func(*input_vars)
    221             if not isinstance(output_vars, collections.Iterable):
    222                 output_vars = output_vars,

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/caffe/caffe_function.py in __call__(self, x)
    567 
    568     def __call__(self, x):
--> 569         return self.func(x, *self.args, **self.kwargs)
    570 
    571 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/caffe/caffe_function.py in __call__(self, *xs, **kwargs)
    606 
    607     def __call__(self, *xs, **kwargs):
--> 608         return self.caffe_func[self.name](*xs, **kwargs)
    609 
    610 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/normalization/batch_normalization.py in __call__(self, x, **kwargs)
    142             ret = functions.batch_normalization(
    143                 x, gamma, beta, eps=self.eps, running_mean=self.avg_mean,
--> 144                 running_var=self.avg_var, decay=decay)
    145         else:
    146             # Use running average statistics or fine-tuned statistics.

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/normalization/batch_normalization.py in batch_normalization(x, gamma, beta, **kwargs)
    527         ('running_var', None), ('decay', 0.9))
    528 
--> 529     return BatchNormalization(eps, running_mean, running_var, decay).apply(
    530         (x, gamma, beta))[0]
    531 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/normalization/batch_normalization.py in __init__(self, eps, mean, var, decay)
     30         if eps < 1e-5:
     31             msg = 'cuDNN does not allow an eps value less than 1e-5.'
---> 32             raise RuntimeError(msg)
     33         self.decay = decay
     34 

RuntimeError: cuDNN does not allow an eps value less than 1e-5.
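
Presumably the check still fires because the caffemodel stores eps as a 32-bit float: 1e-5 round-tripped through single precision reads back as roughly 9.9999997e-06, which is strictly less than 1e-5. A minimal sketch to see this, using only numpy:

import numpy as np

# 1e-5 stored in single precision (protobuf's `float eps`) reads back
# slightly below 1e-5, which trips Chainer's strict `eps < 1e-5` check.
eps32 = np.float32(1e-5)
print(float(eps32))         # ~9.9999997e-06
print(float(eps32) < 1e-5)  # True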

My details are:

OS: linux
Python: 3.5.2 |Anaconda custom (64-bit)| (default, Jul 2 2016, 17:53:06)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
Chainer: 3.4.0
CuPy: 2.4.0
Numpy: 1.14.1
GPU: ['Tesla P100-PCIE-16GB', 'Tesla P100-PCIE-16GB']
CUDA Version 8.0.61
CuDNN Version 6.0.21

It seems the model is being imported incorrectly. Are there plans to add a native pretrained DenseNet to Chainer? I would use ONNX, but unfortunately import into Chainer is not officially supported. Perhaps unofficially? #

ilkarman closed this as completed Mar 3, 2018
ilkarman reopened this Mar 3, 2018

ilkarman commented Mar 3, 2018

If I manually clamp the eps of the bn layers:

# Clamp any batch-norm eps below cuDNN's minimum of 1e-5
for layer in list(sym._children):
    if "bn" in layer:
        print(layer)
        if sym.__dict__[layer].eps < 1e-5:
            sym.__dict__[layer].eps = 1e-5
            print('truncated')
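
The same clamp can also be done by keying off the link type rather than the layer name; a sketch, assuming CaffeFunction registers each Caffe BatchNorm layer as a chainer.links.BatchNormalization child link:

import chainer.links as L

# Walk all child links and clamp eps on every BatchNormalization link.
for name, link in sym.namedlinks(skipself=True):
    if isinstance(link, L.BatchNormalization) and link.eps < 1e-5:
        link.eps = 1e-5
        print('clamped', name)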

I seem to get a completely different error:

CuDNNError                                Traceback (most recent call last)
<ipython-input-...> in <module>()
----> 1 out = predict_fn(sym, fake_input_data_cf, 4)

<ipython-input-...> in predict_fn(classifier, data, batchsize)
      6         pred = classifier(inputs={'data':cuda.to_gpu(dta)},
      7                           outputs=['fc6'])
----> 8         out[idx*batchsize:(idx+1)*batchsize] = cuda.to_cpu(pred['fc6'].data).squeeze()
      9     return out

/anaconda/envs/py35/lib/python3.5/contextlib.py in __exit__(self, type, value, traceback)
     75                 value = type()
     76             try:
---> 77                 self.gen.throw(type, value, traceback)
     78                 raise RuntimeError("generator didn't stop after throw()")
     79             except StopIteration as exc:

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/configuration.py in using_config(name, value, config)
    123     setattr(config, name, value)
    124     try:
--> 125         yield
    126     finally:
    127         delattr(config, name)

<ipython-input-...> in predict_fn(classifier, data, batchsize)
      6         pred = classifier(inputs={'data':cuda.to_gpu(dta)},
      7                           outputs=['fc6'])
----> 8         out[idx*batchsize:(idx+1)*batchsize] = cuda.to_cpu(pred['fc6'].data).squeeze()
      9     return out

/anaconda/envs/py35/lib/python3.5/contextlib.py in __exit__(self, type, value, traceback)
     75                 value = type()
     76             try:
---> 77                 self.gen.throw(type, value, traceback)
     78                 raise RuntimeError("generator didn't stop after throw()")
     79             except StopIteration as exc:

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/configuration.py in using_config(name, value, config)
    123     setattr(config, name, value)
    124     try:
--> 125         yield
    126     finally:
    127         delattr(config, name)

<ipython-input-...> in predict_fn(classifier, data, batchsize)
      5     for idx, dta in yield_mb_X(data, batchsize):
      6         pred = classifier(inputs={'data':cuda.to_gpu(dta)},
----> 7                           outputs=['fc6'])
      8         out[idx*batchsize:(idx+1)*batchsize] = cuda.to_cpu(pred['fc6'].data).squeeze()
      9     return out

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/caffe/caffe_function.py in __call__(self, inputs, outputs, disable, **kwargs)
    218             func = self.forwards[func_name]
    219             input_vars = tuple(variables[blob] for blob in bottom)
--> 220             output_vars = func(*input_vars)
    221             if not isinstance(output_vars, collections.Iterable):
    222                 output_vars = output_vars,

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/caffe/caffe_function.py in __call__(self, x)
    567 
    568     def __call__(self, x):
--> 569         return self.func(x, *self.args, **self.kwargs)
    570 
    571 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/average_pooling_2d.py in average_pooling_2d(x, ksize, stride, pad)
    161 
    162     """
--> 163     return AveragePooling2D(ksize, stride, pad, False).apply((x,))[0]

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/function_node.py in apply(self, inputs)
    243             self._input_indexes_to_retain = None
    244             self._output_indexes_to_retain = None
--> 245             outputs = self.forward(in_data)
    246             assert type(outputs) is tuple
    247 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/function_node.py in forward(self, inputs)
    335         assert len(inputs) > 0
    336         if isinstance(inputs[0], cuda.ndarray):
--> 337             return self.forward_gpu(inputs)
    338         return self.forward_cpu(inputs)
    339 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/average_pooling_2d.py in forward_gpu(self, x)
     25         if chainer.should_use_cudnn('>=auto'):
     26             self.retain_inputs((0,))
---> 27             return super(AveragePooling2D, self).forward_gpu(x)
     28 
     29         self._in_shape = x[0].shape

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/pooling_2d.py in forward_gpu(self, x)
     57 
     58         handle = cudnn.get_handle()
---> 59         pool_desc = self.create_pool_desc()
     60         x_desc = cudnn.create_tensor_descriptor(x)
     61         y_desc = cudnn.create_tensor_descriptor(y)

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/average_pooling_2d.py in create_pool_desc(self)
     67         return cuda.cudnn.create_pooling_descriptor(
     68             (self.kh, self.kw), (self.sy, self.sx), (self.ph, self.pw),
---> 69             cuda.cuda.cudnn.CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)
     70 
     71 

cupy/cudnn.pyx in cupy.cudnn.create_pooling_descriptor()

cupy/cuda/cudnn.pyx in cupy.cuda.cudnn.setPooling2dDescriptor_v4()

cupy/cuda/cudnn.pyx in cupy.cuda.cudnn.check_status()

CuDNNError: CUDNN_STATUS_BAD_PARAM: b'CUDNN_STATUS_BAD_PARAM'

kmaehashi (Member) commented

I confirmed the RuntimeError: cuDNN does not allow an eps value less than 1e-5 issue, but couldn't reproduce the CUDNN_STATUS_BAD_PARAM error.
Could you please share the full code to reproduce this?


ilkarman commented Mar 7, 2018

Sure, sorry about that:

import os
import sys
import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.links import caffe
from chainer import cuda

sym = caffe.CaffeFunction("DenseNet_121.caffemodel")
chainer.cuda.get_device(0).use()  # Make a specified GPU current
sym.to_gpu()  # Copy the model to the GPU
print(list(sym.layers)[-1])
#('fc6', ['pool5'], ['fc6'])
# Truncate batch-norm
for layer in list(sym._children):
    if "bn" in layer:
        #print(layer)
        if sym.__dict__[layer].eps < 1e-5:
            sym.__dict__[layer].eps = 1e-5
            #print('truncated')

dta = np.random.rand(8, 3, 224, 224).astype(np.float32)
# ('fc6', ['pool5'], ['fc6'])
pred = sym(inputs={'data':cuda.to_gpu(dta)}, 
           outputs=['pool5'])

# Produces
---------------------------------------------------------------------------
CuDNNError                                Traceback (most recent call last)
<ipython-input-6-76369d2f0f95> in <module>()
      1 # ('fc6', ['pool5'], ['fc6'])
      2 pred = sym(inputs={'data':cuda.to_gpu(dta)}, 
----> 3            outputs=['pool5'])

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/caffe/caffe_function.py in __call__(self, inputs, outputs, disable, **kwargs)
    218             func = self.forwards[func_name]
    219             input_vars = tuple(variables[blob] for blob in bottom)
--> 220             output_vars = func(*input_vars)
    221             if not isinstance(output_vars, collections.Iterable):
    222                 output_vars = output_vars,

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/links/caffe/caffe_function.py in __call__(self, x)
    567 
    568     def __call__(self, x):
--> 569         return self.func(x, *self.args, **self.kwargs)
    570 
    571 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/average_pooling_2d.py in average_pooling_2d(x, ksize, stride, pad)
    161 
    162     """
--> 163     return AveragePooling2D(ksize, stride, pad, False).apply((x,))[0]

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/function_node.py in apply(self, inputs)
    243             self._input_indexes_to_retain = None
    244             self._output_indexes_to_retain = None
--> 245             outputs = self.forward(in_data)
    246             assert type(outputs) is tuple
    247 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/function_node.py in forward(self, inputs)
    335         assert len(inputs) > 0
    336         if isinstance(inputs[0], cuda.ndarray):
--> 337             return self.forward_gpu(inputs)
    338         return self.forward_cpu(inputs)
    339 

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/average_pooling_2d.py in forward_gpu(self, x)
     25         if chainer.should_use_cudnn('>=auto'):
     26             self.retain_inputs((0,))
---> 27             return super(AveragePooling2D, self).forward_gpu(x)
     28 
     29         self._in_shape = x[0].shape

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/pooling_2d.py in forward_gpu(self, x)
     57 
     58         handle = cudnn.get_handle()
---> 59         pool_desc = self.create_pool_desc()
     60         x_desc = cudnn.create_tensor_descriptor(x)
     61         y_desc = cudnn.create_tensor_descriptor(y)

/anaconda/envs/py35/lib/python3.5/site-packages/chainer/functions/pooling/average_pooling_2d.py in create_pool_desc(self)
     67         return cuda.cudnn.create_pooling_descriptor(
     68             (self.kh, self.kw), (self.sy, self.sx), (self.ph, self.pw),
---> 69             cuda.cuda.cudnn.CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING)
     70 
     71 

cupy/cudnn.pyx in cupy.cudnn.create_pooling_descriptor()

cupy/cuda/cudnn.pyx in cupy.cuda.cudnn.setPooling2dDescriptor_v4()

cupy/cuda/cudnn.pyx in cupy.cuda.cudnn.check_status()

CuDNNError: CUDNN_STATUS_BAD_PARAM: b'CUDNN_STATUS_BAD_PARAM'

My versions:

OS:  linux
Python:  3.5.2 |Anaconda custom (64-bit)| (default, Jul  2 2016, 17:53:06) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
Chainer:  3.4.0
CuPy:  2.4.0
Numpy:  1.14.1
GPU:  ['Tesla P100-PCIE-16GB', 'Tesla P100-PCIE-16GB']
CUDA Version 8.0.61
CuDNN Version  6.0.21

Edit: I tried loading the Caffe ResNet-50 model (ResNet-50-model.caffemodel) this way. I again had to clamp the batch-norm eps, but after that it worked (without the error that the DenseNet model produces). Could this be because shicai mentions: "ceil_mode: false is used in the first pooling layers ('pool1')"?

[image: Architecture]


ilkarman commented Mar 7, 2018

Ah, whoops, it does seem like shicai had to modify the Caffe code.

kmaehashi (Member) commented

This issue occurs because global pooling is used in this layer:
https://github.com/shicai/DenseNet-Caffe/blob/master/DenseNet_121.prototxt#L4743-L4752

Global pooling is supported in Chainer v4.0.0b3 or later.
#3098
#4161
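
For anyone stuck on Chainer 3.x, one possible workaround (only a sketch; the blob name below is an assumption, so check the bottom of the pool5 layer in DenseNet_121.prototxt) is to request the feature map that feeds pool5 and apply the global average pooling manually:

import chainer.functions as F
from chainer import cuda

# Placeholder blob name: replace with the actual `bottom` of pool5 from the prototxt.
PRE_POOL_BLOB = 'relu5_blk'

# CaffeFunction returns a tuple of output variables.
feat, = sym(inputs={'data': cuda.to_gpu(dta)}, outputs=[PRE_POOL_BLOB])

# Global average pooling over the full spatial extent of the feature map.
pooled = F.average_pooling_2d(feat, ksize=feat.shape[2:])
# (The remaining fc6 classifier layer could then be applied by hand,
# e.g. via the sym.fc6 child link, if its input shape matches.)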
