
Performance degradation #8

Closed
baoruxiao opened this issue Apr 17, 2018 · 13 comments

@baoruxiao

Hi,

I'm using the converted pretrained weights to measure mIoU and observed a 10% drop. Could you share your measurement results? Did you still get 84% mIoU on PASCAL using the converted model and weights?

Thanks

@bonlime
Owner

bonlime commented Apr 17, 2018

Hi,
I didn't check the mIoU on PASCAL. The only thing I checked is that the output of this model at a single image scale is identical to the original model's.
I used the original repo's deeplab_demo.ipynb to check that, and found that all pixels are identical on all 3 provided images.
To get semantic predictions add:
prediction = np.argmax(prediction, axis=-1), where prediction is the model's output.
According to the original paper you should get 80.22% mIoU with single-crop predictions; how much do you get?
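For reference, a minimal sketch of running the converted model and taking that argmax, assuming this repo's Deeplabv3 constructor and a placeholder input (the input size and dummy batch are assumptions):

import numpy as np
from model import Deeplabv3  # this repo's model definition

deeplab = Deeplabv3()  # default PASCAL VOC weights
batch = np.zeros((1, 512, 512, 3), dtype='float32')  # placeholder input, already scaled to [-1, 1]
prediction = deeplab.predict(batch)                  # (1, 512, 512, num_classes)
prediction = np.argmax(prediction, axis=-1)          # per-pixel class ids, shape (1, 512, 512)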

@baoruxiao
Author

Hi bonlime,

Thanks for the quick response. Actually I used your conversion to convert the TensorFlow DeepLabv3 model trained on Cityscapes, and measured the mIoU on the Cityscapes validation set. They report ~80% mIoU but I only got ~67%.

Is there anything I need to be aware of when using your conversion on a Cityscapes-pretrained model?

Comparing the outputs on the same image is a brilliant strategy; I will try that.

Thanks

@bonlime
Owner

bonlime commented Apr 17, 2018

How do you preprocess your image? The right way to do it is as follows (see the sketch after the list):

  1. resize the image so that max(height, width) = 513
  2. scale pixel values: x / 127.5 - 1.
  3. pad the smaller dimension from one side with zeros so that the image shape is (513, 513)
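A minimal sketch of these three steps, assuming a PIL-based pipeline (the interpolation choice and padding side are assumptions):

import numpy as np
from PIL import Image

def preprocess(path, target=513):
    img = Image.open(path).convert('RGB')
    w, h = img.size
    ratio = target / max(w, h)                            # step 1: longest side -> 513
    img = img.resize((int(w * ratio), int(h * ratio)), Image.BILINEAR)
    x = np.asarray(img).astype('float32') / 127.5 - 1.0   # step 2: scale to [-1, 1]
    pad_h = target - x.shape[0]                           # step 3: zero-pad to (513, 513)
    pad_w = target - x.shape[1]
    x = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
    return x[np.newaxis]                                  # add batch dimension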

@baoruxiao
Author

baoruxiao commented Apr 17, 2018

I used the same preprocessing as you did, except that for Cityscapes they evaluate at the original image size (1025, 2049), so the only step I applied was x / 127.5 - 1.
I refer to:

# From tensorflow/models/research/
python deeplab/eval.py \
    --logtostderr \
    --eval_split="val" \
    --model_variant="xception_65" \
    --atrous_rates=6 \
    --atrous_rates=12 \
    --atrous_rates=18 \
    --output_stride=16 \
    --decoder_output_stride=4 \
    --eval_crop_size=1025 \
    --eval_crop_size=2049 \
    --dataset="cityscapes" \
    --checkpoint_dir=${PATH_TO_CHECKPOINT} \
    --eval_logdir=${PATH_TO_EVAL_DIR} \
    --dataset_dir=${PATH_TO_DATASET}

No worries, let me quickly check the output mask and get back to you. Thanks

@bonlime
Owner

bonlime commented Apr 17, 2018

--decoder_output_stride=4
Just want to check: do you take this into account too?
You need to take the output before the bilinear resize, and resize the true labels accordingly.

@baoruxiao
Author

baoruxiao commented Apr 17, 2018

I didn't. Could you elaborate a little more on why?
Just to confirm, do you mean the 'x' right after the Dropout?

x = Concatenate()([b4, b0, b1, b2, b3])
x = Conv2D(256, (1, 1), padding='same',
           use_bias=False, name='concat_projection')(x)
x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
x = Activation('relu')(x)
x = Dropout(0.1)(x)

# DeepLab v.3+ decoder

# Feature projection
# x4 (x2) block
x = BilinearUpsampling(output_size=(int(np.ceil(input_shape[0] / 4)),
                                    int(np.ceil(input_shape[1] / 4))))(x)

@bonlime
Owner

bonlime commented Apr 17, 2018

decoder_output_stride=4 means they use only one x4 (or x2 in the case of OS=8) bilinear upsample, but don't use the last one.
You need to use the "logits_semantic" layer output instead of the original model output.
You can define such a new model like this:

from keras import Model

deeplab = Deeplabv3()
model = Model(deeplab.input, deeplab.layers[-2].output)  # output before the final bilinear upsample

And in order to compute IoU you also need to downsize the true labels to 0.25 of their original shape.
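A hedged sketch of that label downsizing, using nearest-neighbor resizing so class ids are never interpolated (the file name is hypothetical):

import numpy as np
from PIL import Image

labels = Image.open('gt_mask.png')                      # hypothetical ground-truth label map
w, h = labels.size
small = labels.resize((w // 4, h // 4), Image.NEAREST)  # x0.25 of the original shape
small = np.asarray(small)                               # compare against deeplab.layers[-2] output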

@baoruxiao
Author

baoruxiao commented Apr 17, 2018

Hi bonlime,

I notice that deeplab_demo.ipynb doesn't do any input preprocessing beyond resizing. Am I correct?

width, height = image.size
resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
target_size = (int(resize_ratio * width), int(resize_ratio * height))
resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
batch_seg_map = self.sess.run(
    self.OUTPUT_TENSOR_NAME,
    feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
seg_map = batch_seg_map[0]
return resized_image, seg_map

Or do they do it inside the TensorFlow graph?

@baoruxiao
Author

I ran both models with my metric and got the same outputs, so it seems my mIoU implementation has an issue and doesn't match the one used by the TensorFlow team.

@baoruxiao
Author

The issue I found is that the official TensorFlow code evaluates mIoU while ignoring areas labeled as unknown (label_id = 255). See:
https://github.com/tensorflow/models/blob/5be198eca885fd967426ad5551c8374f1c3e889e/research/deeplab/eval.py#L126-L143
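A minimal sketch of that masking idea in TF 1.x style, mirroring how eval.py weights out the ignored pixels (the placeholders and num_classes are assumptions):

import tensorflow as tf

num_classes = 19                                # e.g. Cityscapes
ignore_label = 255
labels = tf.placeholder(tf.int32, [None])       # flattened ground-truth labels
predictions = tf.placeholder(tf.int32, [None])  # flattened predicted class ids

# pixels labeled 255 get zero weight, so they never enter the confusion matrix
weights = tf.cast(tf.not_equal(labels, ignore_label), tf.float32)
# remap ignored pixels to a valid class id; their zero weight already excludes them
labels_clean = tf.where(tf.equal(labels, ignore_label),
                        tf.zeros_like(labels), labels)
miou, update_op = tf.metrics.mean_iou(labels_clean, predictions,
                                      num_classes, weights=weights)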

@bonlime
Owner

bonlime commented Apr 18, 2018

Nice! So without the unknown areas, the mIoU is the same?

@baoruxiao
Author

baoruxiao commented Apr 18, 2018

Yes, the same or just 1~2% lower than what they report, depending on whether the input preprocessing is implemented exactly as they do it.
I got 75.16% from the Keras model, while they report 78.79% with OS = 16.
I got 75% by evaluating their TensorFlow model with my code.

@lucasdavid

I was also having some trouble related to this... I'm not sure what the issue was, but I managed to get the expected mIoU. I'll leave my comments here in case they help anyone:

My input comprises batches of 32 images of size 513x513. Images are scaled with nearest-neighbor interpolation so that their largest dimension is 513, and the shorter side is padded with pixel intensity 128. They are then processed with f(x): return x / 127.5 - 1.

I used the default MeanIoU metric implementation from Keras, with a small change to handle the unknown label (-1, in my case) and a y_pred that is a probability vector:

class MeanIoU(tf.keras.metrics.MeanIoU):

  def __init__(self, num_classes, name=None, dtype=None,
               ignore_index: int = -1, sparse_predictions: bool = False):
    super().__init__(num_classes, name, dtype)
    self.ignore_index = ignore_index
    self.sparse_predictions = sparse_predictions

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)

    if not self.sparse_predictions:
      # collapse the probability vector into hard per-pixel class ids
      y_pred = tf.argmax(y_pred, axis=-1)

    # drop pixels carrying the ignore label before updating the confusion matrix
    valid_mask = y_true != self.ignore_index
    y_true = y_true[valid_mask]
    y_pred = y_pred[valid_mask]

    return super().update_state(y_true, y_pred, sample_weight)

  def get_config(self):
    config = super().get_config()
    config.update({
      'ignore_index': self.ignore_index,
      'sparse_predictions': self.sparse_predictions,
    })
    return config
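As a hypothetical usage example on random data (shapes assumed from the setup above):

import numpy as np

metric = MeanIoU(num_classes=21, ignore_index=-1)
y_true = np.random.randint(-1, 21, size=(2, 513, 513))                 # -1 marks ignored pixels
y_pred = np.random.uniform(size=(2, 513, 513, 21)).astype('float32')   # per-pixel probabilities
metric.update_state(y_true, y_pred)
print(float(metric.result()))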

In this setting, I get mIoU = 66.62% on the Pascal VOC 2012 validation set.

Then I fine-tuned the model for 20 epochs with learning_rate=0.0001 (the value recommended in DeepLab's repo), Nesterov momentum, and OS=8 on Pascal VOC 2012 augmented with the Berkeley samples.
It converges to its lowest loss within the first epoch:

[figure: loss over time]

And the final mIoU is 84.78% (the value reported in the paper was 84.56%).
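A hedged sketch of that fine-tuning setup; the constructor arguments, the loss choice, and the dummy dataset are assumptions, and MeanIoU is the subclass defined above:

import tensorflow as tf
from model import Deeplabv3  # this repo's model definition

model = Deeplabv3(input_shape=(513, 513, 3), classes=21, backbone='xception', OS=8)
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9, nesterov=True),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[MeanIoU(num_classes=21, ignore_index=-1)])

# dummy stand-in for the real VOC 2012 + Berkeley input pipeline
train_ds = tf.data.Dataset.from_tensor_slices(
    (tf.zeros((2, 513, 513, 3)), tf.zeros((2, 513, 513), tf.int32))).batch(2)
model.fit(train_ds, epochs=20)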
