Skip to content

Commit

Permalink
added CPU optimization guide part into tuning_guide (pytorch#1512)
Browse files Browse the repository at this point in the history
* added CPU optimization guide part into tuning_guide

* changed non-python command to python comments in CPU specific optimization section

* Update tuning_guide.py

Changed comment of bash commands to double quote.

* Update tuning_guide.py

Co-authored-by: Brian Johnson <brianjo@fb.com>
  • Loading branch information
jingxu10 and brianjo committed Jun 2, 2021
1 parent a162702 commit 07fc674
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions recipes_source/recipes/tuning_guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,69 @@ def fused_gelu(x):
# `torch.autograd.gradgradcheck <https://pytorch.org/docs/stable/autograd.html#torch.autograd.gradgradcheck>`_
#

###############################################################################
# CPU specific optimizations
# --------------------------

###############################################################################
# Utilize Non-Uniform Memory Access (NUMA) Controls
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NUMA or non-uniform memory access is a memory layout design used in data center machines meant to take advantage of locality of memory in multi-socket machines with multiple memory controllers and blocks. Generally speaking, all deep learning workloads, training or inference, get better performance without accessing hardware resources across NUMA nodes. Thus, inference can be run with multiple instances, each instance runs on one socket, to raise throughput. For training tasks on single node, distributed training is recommended to make each training process run on one socket.
#
# In general cases the following command executes a PyTorch script on cores on the Nth node only, and avoids cross-socket memory access to reduce memory access overhead.

# numactl --cpunodebind=N --membind=N python <pytorch_script>

###############################################################################
# More detailed descriptions can be found `here <https://software.intel.com/content/www/us/en/develop/articles/how-to-get-better-performance-on-pytorchcaffe2-with-intel-acceleration.html>`_.

###############################################################################
# Utilize OpenMP
# ~~~~~~~~~~~~~~
# OpenMP is utilized to bring better performance for parallel computation tasks.
# OMP_NUM_THREADS is the easiest switch that can be used to accelerate computations. It determines number of threads used for OpenMP computations.
# CPU affinity setting controls how workloads are distributed over multiple cores. It affects communication overhead, cache line invalidation overhead, or page thrashing, thus proper setting of CPU affinity brings performance benefits. GOMP_CPU_AFFINITY or KMP_AFFINITY determines how to bind OpenMP* threads to physical processing units. Detailed information can be found `here <https://software.intel.com/content/www/us/en/develop/articles/how-to-get-better-performance-on-pytorchcaffe2-with-intel-acceleration.html>`_.

###############################################################################
# With the following command, PyTorch run the task on N OpenMP threads.

# export OMP_NUM_THREADS=N

###############################################################################
# Typically, the following environment variables are used to set for CPU affinity with GNU OpenMP implementation. OMP_PROC_BIND specifies whether threads may be moved between processors. Setting it to CLOSE keeps OpenMP threads close to the primary thread in contiguous place partitions. OMP_SCHEDULE determines how OpenMP threads are scheduled. GOMP_CPU_AFFINITY binds threads to specific CPUs.

# export OMP_SCHEDULE=STATIC
# export OMP_PROC_BIND=CLOSE
# export GOMP_CPU_AFFINITY="N-M"

###############################################################################
# Intel OpenMP Runtime Library (libiomp)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# By default, PyTorch uses GNU OpenMP (GNU libgomp) for parallel computation. On Intel platforms, Intel OpenMP Runtime Library (libiomp) provides OpenMP API specification support. It sometimes brings more performance benefits compared to libgomp. Utilizing environment variable LD_PRELOAD can switch OpenMP library to libiomp:

# export LD_PRELOAD=<path>/libiomp5.so:$LD_PRELOAD

###############################################################################
# Similar to CPU affinity settings in GNU OpenMP, environment variables are provided in libiomp to control CPU affinity settings.
# KMP_AFFINITY binds OpenMP threads to physical processing units. KMP_BLOCKTIME sets the time, in milliseconds, that a thread should wait, after completing the execution of a parallel region, before sleeping. In most cases, setting KMP_BLOCKTIME to 1 or 0 yields good performances.
# The following commands show a common settings with Intel OpenMP Runtime Library.

# export KMP_AFFINITY=granularity=fine,compact,1,0
# export KMP_BLOCKTIME=1

###############################################################################
# Switch Memory allocator
# ~~~~~~~~~~~~~~~~~~~~~~~
# For deep learning workloads, Jemalloc or TCMalloc can get better performance by reusing memory as much as possible than default malloc funtion. `Jemalloc <https://github.com/jemalloc/jemalloc>`_ is a general purpose malloc implementation that emphasizes fragmentation avoidance and scalable concurrency support. `TCMalloc <https://google.github.io/tcmalloc/overview.html>`_ also features a couple of optimizations to speed up program executions. One of them is holding memory in caches to speed up access of commonly-used objects. Holding such caches even after deallocation also helps avoid costly system calls if such memory is later re-allocated.
# Use environment variable LD_PRELOAD to take advantage of one of them.

# export LD_PRELOAD=<jemalloc.so/tcmalloc.so>:$LD_PRELOAD

###############################################################################
# Train a model on CPU with PyTorch DistributedDataParallel(DDP) functionality
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# For small scale models or memory-bound models, such as DLRM, training on CPU is also a good choice. On a machine with multiple sockets, distributed training brings a high-efficient hardware resource usage to accelerate the training process. `Torch-ccl <https://github.com/intel/torch-ccl>`_, optimized with Intel(R) oneCCL (collective commnications library) for efficient distributed deep learning training implementing such collectives like allreduce, allgather, alltoall, implements PyTorch C10D ProcessGroup API and can be dynamically loaded as external ProcessGroup. Upon optimizations implemented in PyTorch DDP moduel, torhc-ccl accelerates communication operations. Beside the optimizations made to communication kernels, torch-ccl also features simultaneous computation-communication functionality.

###############################################################################
# GPU specific optimizations
# --------------------------
Expand Down

0 comments on commit 07fc674

Please sign in to comment.