Getting nan prediction in demo_net.py with AVA on custom video #241

Closed
Serhii-Tiurin opened this issue Jul 11, 2020 · 16 comments
Labels
question Further information is requested

Comments

@Serhii-Tiurin

Serhii-Tiurin commented Jul 11, 2020

When running python tools/run_net.py --cfg /home/ubuntu/slowfast/configs/AVA/c2/SLOWFAST_32x2_R101_50_50.yaml I am getting NaN predictions at https://github.com/facebookresearch/SlowFast/blob/master/tools/demo_net.py#L242
Python version - 3.6.10 (it did not work on 3.7.7 either)
Pytorch - 1.5.1
Torchvision - 0.6.1
Cuda - 10.0
Detectron2 - 0.2
Here is the config:
TRAIN:
  ENABLE: False
  DATASET: ava
  BATCH_SIZE: 16
  EVAL_PERIOD: 1
  CHECKPOINT_PERIOD: 1
  AUTO_RESUME: True
  CHECKPOINT_FILE_PATH: "/home/ubuntu/SLOWFAST_32x2_R101_50_50.pkl"
  CHECKPOINT_TYPE: pytorch
DATA:
  NUM_FRAMES: 32
  SAMPLING_RATE: 2
  TRAIN_JITTER_SCALES: [256, 320]
  TRAIN_CROP_SIZE: 224
  TEST_CROP_SIZE: 256
  INPUT_CHANNEL_NUM: [3, 3]
DETECTION:
  ENABLE: True
  ALIGNED: False
AVA:
  BGR: False
  DETECTION_SCORE_THRESH: 0.8
  TEST_PREDICT_BOX_LISTS: ["person_box_67091280_iou90/ava_detection_val_boxes_and_labels.csv"]
SLOWFAST:
  ALPHA: 4
  BETA_INV: 8
  FUSION_CONV_CHANNEL_RATIO: 2
  FUSION_KERNEL_SZ: 5
RESNET:
  ZERO_INIT_FINAL_BN: True
  WIDTH_PER_GROUP: 64
  NUM_GROUPS: 1
  DEPTH: 101
  TRANS_FUNC: bottleneck_transform
  STRIDE_1X1: False
  NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]]
  SPATIAL_DILATIONS: [[1, 1], [1, 1], [1, 1], [2, 2]]
  SPATIAL_STRIDES: [[1, 1], [2, 2], [2, 2], [1, 1]]
NONLOCAL:
  LOCATION: [[[], []], [[], []], [[6, 13, 20], []], [[], []]]
  GROUP: [[1, 1], [1, 1], [1, 1], [1, 1]]
  INSTANTIATION: dot_product
  POOL: [[[2, 2, 2], [2, 2, 2]], [[2, 2, 2], [2, 2, 2]], [[2, 2, 2], [2, 2, 2]], [[2, 2, 2], [2, 2, 2]]]
BN:
  USE_PRECISE_STATS: False
  NUM_BATCHES_PRECISE: 200
SOLVER:
  MOMENTUM: 0.9
  WEIGHT_DECAY: 1e-7
  OPTIMIZING_METHOD: sgd
MODEL:
  NUM_CLASSES: 80
  ARCH: slowfast
  MODEL_NAME: SlowFast
  LOSS_FUNC: bce
  DROPOUT_RATE: 0.5
  HEAD_ACT: sigmoid
TEST:
  ENABLE: False
  DATASET: ava
  BATCH_SIZE: 8
DATA_LOADER:
  NUM_WORKERS: 2
  PIN_MEMORY: True
DEMO:
  ENABLE: True
  LABEL_FILE_PATH: "./demo/AVA/ava.names"
  DATA_SOURCE: "/home/ubuntu/gestures_dataset_right_wave_15.mp4"
  DISPLAY_WIDTH: 640
  DISPLAY_HEIGHT: 480
  DETECTRON2_OBJECT_DETECTION_MODEL_CFG: "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
  DETECTRON2_OBJECT_DETECTION_MODEL_WEIGHTS: "detectron2://COCO-Detection/faster_rcnn_R_50_FPN_3x/137849458/model_final_280758.pkl"
NUM_GPUS: 1
NUM_SHARDS: 1
RNG_SEED: 0
OUTPUT_DIR: .

And here is part of the logs (each message was emitted twice by two log handlers; duplicates are collapsed below):
/home/ubuntu/slowfast/slowfast/models/head_helper.py:111: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert out.shape[2] == 1
/home/ubuntu/detectron2_repo/detectron2/layers/roi_align.py:105: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert rois.dim() == 2 and rois.size(1) == 5
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::batch_norm 215 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::relu_ 204 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::max_pool3d 7 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::add 69 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::div 3 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::avg_pool3d 2 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation prim::PythonOp 2 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::max_pool2d 2 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::dropout 1 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.flop_count: 63: Skipped operation aten::sigmoid 1 time(s)
[07/11 21:40:19][INFO] slowfast.utils.misc: 160: Flops: 146.54916608 G
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::batch_norm 215 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::relu_ 204 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::max_pool3d 7 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::add 69 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::einsum 6 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::div 3 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::avg_pool3d 2 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation prim::PythonOp 2 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::max_pool2d 2 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::dropout 1 time(s)
[07/11 21:40:19][WARNING] fvcore.nn.activation_count: 54: Skipped operation aten::sigmoid 1 time(s)
[07/11 21:40:19][INFO] slowfast.utils.misc: 165: Activations: 293.60136 M
[07/11 21:40:19][INFO] slowfast.utils.misc: 168: nvidia-smi

Please feel free to ask for any additional information if needed.

haooooooqi added the question label Jul 12, 2020
@lequytra
Contributor

Hey @Serhii-Tiurin,
Can you tell me where you saw the NaN predictions? Did you see NaN predictions for the labels on the visualized video, or did the program stop with an error?

@littlefisherfisher

I got the same problem.

@haooooooqi
Contributor

Hi,
Thanks for playing with PySlowFast!
@littlefisherfisher @Serhii-Tiurin could you share more details to help us pinpoint the issue?
cc @lequytra

Thanks,
Haoqi

@irvingzhang0512

Same issue here.
In demo_net.py, I added print(preds) after preds = model(inputs, boxes) and got [[nan, ..., nan]].

@lequytra
Contributor

Thank you for bringing this issue to our attention.

I cannot reproduce the issue using my local environment, so I suspect this is an environment-dependent problem. My environment is:
Python 3.8.3
Pytorch - 1.5.0
Torchvision - 0.6.0
Cuda - 10.1
I installed the pre-built Detectron2 for the corresponding Pytorch and Cuda version.

What are inputs and boxes before they are passed into the model?

@Serhii-Tiurin Do you mind sharing your input video? I would like to test with it as well, to make sure we are running demo_net.py with the exact same config.

@littlefisherfisher

Thanks for your attention.
First, the CHECKPOINT_FILE_PATH checkpoint is downloaded from the model zoo (SlowFast | R101 | Kinetics 600 | 8 x 8 | 29.1 | 2.2 | link); is that right? The config file is SLOWFAST_32x2_R101_50_50.yaml from the demo AVA folder. I use a test video from AVA (DaUzhc9_6io.mp4) as the input. After preds = model(inputs, boxes), preds are NaN. The input sizes are [1, 3, 8, 256, 341] and [1, 3, 32, 256, 341], like the following:
[[[[[ -2.0000, -2.0000, -2.0000, ..., 33.5556, 11.3333,
-2.0000],
[ -2.0000, -2.0000, -2.0000, ..., 33.5556, 11.3333,
-2.0000],
[ -2.0000, -2.0000, -2.0000, ..., 33.5556, 6.8889,
-2.0000],
...,
[ -2.0000, -2.0000, -2.0000, ..., -2.0000, -2.0000,
-2.0000],
[ -2.0000, -2.0000, -2.0000, ..., -2.0000, -2.0000,
-2.0000],
[ -2.0000, -2.0000, -2.0000, ..., -2.0000, -2.0000,
-2.0000]],

      [[ -2.0000,  -2.0000,  -2.0000,  ...,  33.5556,   6.8889,
         -2.0000],
       [ -2.0000,  -2.0000,  -2.0000,  ...,  33.5556,   6.8889,
         -2.0000],
       [ -2.0000,  -2.0000,  -2.0000,  ...,  33.5556,   6.8889,
         -2.0000],
       ...,
       [ -2.0000,  -2.0000,  -2.0000,  ...,  -2.0000,  -2.0000,
         -2.0000],
       [ -2.0000,  -2.0000,  -2.0000,  ...,  -2.0000,  -2.0000,
         -2.0000],
       [ -2.0000,  -2.0000,  -2.0000,  ...,  -2.0000,  -2.0000,
         -2.0000]],

...
def forward(self, x, bboxes=None):
    x = self.s1(x)
    x = self.s1_fuse(x)
    x = self.s2(x)
    x = self.s2_fuse(x)
    for pathway in range(self.num_pathways):
        pool = getattr(self, "pathway{}_pool".format(pathway))
        x[pathway] = pool(x[pathway])
    x = self.s3(x)
    x = self.s3_fuse(x)
    x = self.s4(x)
    x = self.s4_fuse(x)

After x = self.s4(x), x is NaN.
Thanks.
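(As an aside, a quick way to localize such NaNs is to walk the stages one by one. The helper below is hypothetical, not part of the repo; the stage and pool attribute names are assumed from the forward snippet above.)

```python
import torch

def first_nan_stage(model, x):
    # Walk the SlowFast stages in forward order and return the name of the
    # first stage whose output contains NaNs (None if the output is clean).
    for name in ["s1", "s1_fuse", "s2", "s2_fuse"]:
        x = getattr(model, name)(x)
        if any(torch.isnan(t).any() for t in x):
            return name
    for pathway in range(model.num_pathways):
        pool = getattr(model, "pathway{}_pool".format(pathway))
        x[pathway] = pool(x[pathway])
    for name in ["s3", "s3_fuse", "s4", "s4_fuse"]:
        x = getattr(model, name)(x)
        if any(torch.isnan(t).any() for t in x):
            return name
    return None
```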

@irvingzhang0512

After some debugging, I think I can fix this issue.

Frames are preprocessed by frame_processed = scale(cfg.DATA.TEST_CROP_SIZE, frame_processed) and tensor_normalize(torch.as_tensor(frames), cfg.DATA.MEAN, cfg.DATA.STD).

After scale, we get resized float32 images whose pixels are in [0, 255.0]. But when we normalize the resized images with tensor_normalize, the mean and std are for pixels in the range [0, 1.0].

So adding a division by 255 after scale fixes the issue:
frame_processed = scale(cfg.DATA.TEST_CROP_SIZE, frame_processed) / 255.
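(To see why the wrong range blows up, here is a minimal sketch; the mean/std values are assumptions standing in for the [0, 1]-range defaults of cfg.DATA.MEAN and cfg.DATA.STD.)

```python
import torch

mean, std = 0.45, 0.225  # assumed [0, 1]-range defaults for cfg.DATA.MEAN/STD

pixel = 128.0
ok = (torch.tensor(pixel / 255.0) - mean) / std  # ~0.23, the expected scale
bad = (torch.tensor(pixel) - mean) / std         # ~566.9, orders of magnitude too large

# Inputs hundreds of times larger than the training distribution can
# overflow through the deep pathways, consistent with x becoming NaN
# after s4 in the trace above.
print(ok.item(), bad.item())
```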

@littlefisherfisher

@irvingzhang0512 Yes, that fixed the issue. Thanks.
Also, are your prediction boxes normal? I get many wrong person prediction boxes.

@irvingzhang0512

> @irvingzhang0512 Yes, that fixed the issue. Thanks.
> Also, are your prediction boxes normal? I get many wrong person prediction boxes.

By default, we get the person bbox from the middle frame (instead of the current/last frame) of all seq_len frames, so when visualizing the results we get a "wrong" bbox.

FYI, by default demo_net.py predicts action labels every seq_len frames.

So if you want "good" visualization results, try to:

1. Get detection results from the current frame: `outputs = object_predictor(frame)`
2. Predict action results every frame (instead of every seq_len frames): uncomment `frames.pop(0)` and comment out `frames = []`.

However, the demo then runs extremely slowly...
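(For illustration, here is a hypothetical sketch of that per-frame variant; `seq_len`, `object_predictor`, and `predict_actions` stand in for the corresponding pieces of demo_net.py and are not the upstream code.)

```python
import cv2

def run_demo(video_path, object_predictor, predict_actions, seq_len=64):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(frame)
        if len(frames) == seq_len:
            boxes = object_predictor(frame)         # detect on the current frame
            preds = predict_actions(frames, boxes)  # action labels for this window
            frames.pop(0)                           # slide the window by one frame
            # frames = []                           # default behavior: clear the
            #                                       # buffer, predicting only once
            #                                       # every seq_len frames
    cap.release()
```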

@littlefisherfisher

> > @irvingzhang0512 Yes, that fixed the issue. Thanks.
> > Also, are your prediction boxes normal? I get many wrong person prediction boxes.
>
> By default, we get the person bbox from the middle frame (instead of the current/last frame) of all seq_len frames, so when visualizing the results we get a "wrong" bbox.
>
> FYI, by default demo_net.py predicts action labels every seq_len frames.
>
> So if you want "good" visualization results, try to:
>
> 1. Get detection results from the current frame: `outputs = object_predictor(frame)`
> 2. Predict action results every frame (instead of every seq_len frames): uncomment `frames.pop(0)` and comment out `frames = []`.
>
> However, the demo then runs extremely slowly...

Thanks for your reply.
The demo ran in 6.7 s (2080 Ti), extremely slow, so I modified the code from
inputs = tensor_normalize(torch.as_tensor(frames), cfg.DATA.MEAN, cfg.DATA.STD) to
inputs = tensor_normalize(torch.from_numpy(numpy.array(frames)), cfg.DATA.MEAN, cfg.DATA.STD),
and now it runs in 0.12 s.
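(The speedup comes from how the list of frames is turned into a tensor; a minimal illustration, where the frame shape and count are assumptions:)

```python
import numpy as np
import torch

# A buffer of 64 frames of 256x341 BGR images, as demo_net.py collects them.
frames = [np.random.randint(0, 256, (256, 341, 3), dtype=np.uint8)
          for _ in range(64)]

slow = torch.as_tensor(frames)             # converts the Python list element by
                                           # element, which is very slow
fast = torch.from_numpy(np.array(frames))  # one vectorized stack, then a
                                           # zero-copy view as a tensor
assert torch.equal(slow, fast)             # same result, much faster
```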

@Serhii-Tiurin
Author

Thanks, all of you!
I can confirm that the issue is fixed; also, with the small update from @littlefisherfisher, the demo is much faster.

@Serhii-Tiurin
Author

@irvingzhang0512 thanks for the update regarding object bboxes. Is there an option to run object detection on every frame but action recognition on a sequence of frames?

@irvingzhang0512

@Serhii-Tiurin
https://github.com/facebookresearch/SlowFast/blob/master/tools/demo_net.py#L293
option 1 may help
@wwdok

wwdok commented Nov 11, 2020

> @Serhii-Tiurin
> https://github.com/facebookresearch/SlowFast/blob/master/tools/demo_net.py#L293
> option 1 may help

Has the code changed? There is no L293.

@thezaza101

thezaza101 commented Dec 1, 2020

This still seems to happen; however, the solution proposed above by @irvingzhang0512 is no longer valid, because the code now reads:

inputs = torch.from_numpy(np.array(frames)).float() / 255

It would be good to get some feedback from the authors on this.

Adding the line proposed by @irvingzhang0512 does make the input non-NaN, but prediction accuracy drops off (using a model custom-trained via the AVA process). This is because if we divide the images by 255, the process_cv2_inputs function will normalize the image again, making the input range invalid.
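(A small sketch of the double scaling being described, assuming process_cv2_inputs applies the same /255 as the line above:)

```python
import torch

frame = torch.full((3, 224, 224), 200.0)  # pixel values still in uint8 range

once = frame / 255.0   # ~0.784: the range the checkpoint was trained on
twice = once / 255.0   # ~0.003: nearly black after a second /255, far
                       # outside the trained input distribution
```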

@thezaza101

Applying the /255 to the images in both the training and inference scripts (as per @irvingzhang0512) seems to be a workaround that works reasonably well:
https://github.com/thezaza101/SlowFast/blob/e3f8bff1b14c2d58af4ebf5d9029973858d343e9/slowfast/datasets/utils.py#L41
