This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Select ResNet-50 model for 2017-Q3 #13

Closed
cconvey opened this issue Aug 2, 2017 · 3 comments

Comments

@cconvey
Contributor

cconvey commented Aug 2, 2017

Choose a specific ResNet-50 model implementation for our Q3 work.

This choice is subject to change in the future, for example based on future discussions with the Benchmark team. But it needs to be a reasonable starting point for our development work.

The tangible work-product of this Issue should be a file (perhaps some scripts, perhaps a README, or both) giving the details of this choice.

@yxlao
Member

yxlao commented Aug 3, 2017

Summary

  • Selected Google's benchmark scripts implementation for ResNet50.
  • List of XLA ops needed:
    ['add', 'arg', 'bitcast', 'broadcast', 'concatenate', 'constant', 
     'convolution', 'copy', 'divide', 'dot', 'equal', 'exponential', 
     'fusion', 'get', 'greater', 'less', 'log', 'logical', 'map', 'maximum', 
     'multiply', 'negate', 'pad', 'reduce', 'reshape', 'reverse', 'select',
     'subtract', 'transpose', 'tuple']
    
  • Updates from Aug 8's standup:
    • ResNet 56 on the Cifar dataset could be achieved first, before ResNet 50 on ImageNet-1k (i1k).
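
The op list above can also serve as a whitelist to diff against future dumps, flagging any new ops the model or XLA pipeline starts emitting. A minimal sketch; the `new_dump_ops` set here is invented for illustration:

```python
# Ops required by the ResNet50 XLA graph, copied from the list above.
needed_ops = {
    'add', 'arg', 'bitcast', 'broadcast', 'concatenate', 'constant',
    'convolution', 'copy', 'divide', 'dot', 'equal', 'exponential',
    'fusion', 'get', 'greater', 'less', 'log', 'logical', 'map', 'maximum',
    'multiply', 'negate', 'pad', 'reduce', 'reshape', 'reverse', 'select',
    'subtract', 'transpose', 'tuple',
}

# Hypothetical ops extracted from a newer dump; anything outside the
# known set would signal a change in the model or the XLA passes.
new_dump_ops = {'add', 'convolution', 'sort'}
unexpected = sorted(new_dump_ops - needed_ops)
print(unexpected)  # → ['sort']
```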

Comparison of the two implementations

Candidate 1: Google's benchmark implementation (link)

This implementation is based on the first ResNet paper, also called ResNetV1. It includes ResNet 18, 34, 50, 101, and 152.

  • Pros
    • Our original requirement is ResNet50 (although we can also revise that requirement ourselves).
    • It aligns with our benchmark team's repo, which makes comparison easy.
  • Cons
    • It is wrapped in benchmark scripts for many other NN models and needs cleanup.
    • The script is intended for benchmarking; it is not meant to exactly reproduce the original paper's results (for example, the optimizer and learning rates could differ).
    • Also, according to our benchmark team (Jing Huang), the convergence and accuracy of the model have not been systematically tested, but there are plans to do so.

Candidate 2: Google's model implementation (link)

This implementation is based on the second ResNet paper, also called ResNetV2. It includes ResNet 32, 110, 164, and 1001.

  • Pros
    • Cleaner implementation; only minimal cleanup needed.
    • Google has published accuracy results that we can use to check convergence. This implementation is meant for model training rather than benchmarking.
  • Cons
    • It's not ResNet50.
    • It does not align with our benchmark efforts.

Running and getting the XLA Ops

  • Step 1: Build and install TF with XLA.
  • Step 2: Modify the file to enable XLA
    • Add config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
      to create_config_proto() function in tf_cnn_benchmarks.py. For more info, see this doc.
  • Step 3: Run resnet model, dumping XLA graph to dot files
    TF_XLA_FLAGS=--xla_generate_hlo_graph=.* python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC   
    
  • Step 4: Grep the .dot files. There are more than 10,000 .dot files; most are intermediate graphs from compiler passes, but collectively they cover the superset of all ops used in the model.
    import os
    import re
    
    # collect all .dot files dumped by XLA
    file_names = []
    for file_name in os.listdir("/tmp"):
        if file_name.endswith(".dot"):
            file_names.append(os.path.join("/tmp", file_name))
    
    # collect all opcodes (tokens such as %convolution, %add, ...)
    opcode_re = r'%[a-z]+'
    all_opcodes = set()
    for file_name in file_names:
        with open(file_name) as f:
            for line in f:
                all_opcodes |= set(re.findall(opcode_re, line))
    all_opcodes = [opcode[1:] for opcode in all_opcodes]  # strip leading '%'
    print(sorted(all_opcodes))
    The output is the list at the very beginning.
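
The opcode regex in Step 4 can be sanity-checked on a single line of dump text. A minimal sketch; the sample line below is invented in the style of an HLO .dot dump, not copied from a real one:

```python
import re

# Hypothetical line in the style of an XLA HLO .dot dump; real dumps
# contain instruction names like %convolution.42, %arg.1, and so on.
sample = '%convolution.42 = f32[32,112,112,64] convolution(%arg.1, %constant.3)'

# Same pattern as the extraction script: '%' followed by lowercase letters.
opcode_re = r'%[a-z]+'
opcodes = sorted({m[1:] for m in re.findall(opcode_re, sample)})
print(opcodes)  # → ['arg', 'constant', 'convolution']
```

Note the pattern stops at the first non-letter, so instruction suffixes like `.42` are dropped and only the opcode itself is kept.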

TODO

  • The current dump is with XLA-GPU. For some reason hlo_graph_dumper.cc does not work for CPU here; this needs investigation. The GPU path may introduce extra ops, such as fusion.
  • Extract a clean ResNet50 implementation from the benchmark scripts to use as our example.

@yxlao yxlao closed this as completed Aug 3, 2017
@yxlao yxlao reopened this Aug 3, 2017
@yxlao
Member

yxlao commented Aug 3, 2017

While we're here, I also ran some benchmarks on Candidate 1 above.

  • Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
  • GeForce GTX 1080 Ti (x2)
  • Tested with MKL vs non-MKL

CPU

CUDA_VISIBLE_DEVICES="" python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC
images/sec: 1.8 +/- 0.0 (jitter = 0.0)

CPU MKL

CUDA_VISIBLE_DEVICES="" python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC
images/sec: 3.5 +/- 0.0 (jitter = 0.0)

GPU

CUDA_VISIBLE_DEVICES="0" python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC
images/sec: 131.0 +/- 0.0 (jitter = 0.0)

2xGPU

CUDA_VISIBLE_DEVICES="0,1" python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC --num_gpus 2
images/sec: 250.1 +/- 0.0 (jitter = 0.0)

XLA CPU

CUDA_VISIBLE_DEVICES="" python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC
images/sec: 1.8 +/- 0.0 (jitter = 0.0)

XLA GPU

CUDA_VISIBLE_DEVICES="0" python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC
images/sec: 112.9 +/- 0.0 (jitter = 0.0)

XLA 2xGPU

CUDA_VISIBLE_DEVICES="0,1" python tf_cnn_benchmarks.py --model resnet50 --batch_size 32 --data_format NHWC --num_gpus 2
images/sec: 160.1 +/- 0.0 (jitter = 0.0)

The CPU results are very poor, which is strange. With Candidate 2, non-XLA CPU runs use all CPU threads fully, but with Candidate 1 we see only about 30% CPU thread usage. This may be a configuration problem.
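
The multi-GPU scaling efficiency can be computed directly from the throughput numbers above; a small sketch using the reported images/sec figures:

```python
# Reported throughput (images/sec) from the runs above.
one_gpu = 131.0
two_gpu = 250.1
xla_one_gpu = 112.9
xla_two_gpu = 160.1

# Scaling efficiency: measured multi-GPU throughput vs. ideal linear scaling.
def scaling_efficiency(single, multi, n_gpus=2):
    return multi / (n_gpus * single)

print(round(scaling_efficiency(one_gpu, two_gpu), 3))          # → 0.955
print(round(scaling_efficiency(xla_one_gpu, xla_two_gpu), 3))  # → 0.709
```

So the non-XLA path scales well across two GPUs (~95%), while the XLA path not only starts slower but also scales worse (~71%) in this run.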

@diyessi
Contributor

diyessi commented Feb 15, 2018

Model chosen.

@diyessi diyessi closed this as completed Feb 15, 2018
mbrookhart pushed a commit that referenced this issue Sep 24, 2018
* add scalar ops, repeat

* add a TODO