27 changes: 14 additions & 13 deletions examples/README.md
@@ -71,7 +71,7 @@
```
git clone https://github.com/facebookresearch/dlrm
cd dlrm
pip install -r requirement.txt
pip install -r requirements.txt
git checkout 52b77f80a24303294a02c86b574529cdc420aac5
patch -p1 < {path/to/intel-pytorch-extension}/torch_patches/models/dlrm.patch
```
@@ -82,7 +82,7 @@
cd transformers
git checkout 1a779ad7ecb9e5215b6bd1cfa0153469d37e4274
patch -p1 < {path/to/intel-pytorch-extension}/torch_patches/models/bert.patch
pip install -r ./examples/requirement.txt
pip install -r ./examples/requirements.txt
pip install --editable .
```

@@ -95,6 +95,7 @@
git clone https://github.com/pytorch/vision
cd vision
# add ResNext-101-32x4d model
git checkout 883f1fb01a8ba0e1b1cdc16c16f2e6e0ef87e3bd
patch -p1 < {path/to/intel-pytorch-extension}/torch_patches/models/vision.patch
python setup.py install
```
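After the patch and `python setup.py install`, the new architecture should be constructible by name, since the run scripts below select the model by its architecture string. A minimal sanity check, assuming the patch registers a `resnext101_32x4d` factory alongside the stock `resnext101_32x8d` (the exact constructor name is inferred from the script arguments, not confirmed by the patch itself):
```
# Hypothetical sanity check; `resnext101_32x4d` as a constructor name is an assumption.
import torch
import torchvision.models as models

model = models.__dict__["resnext101_32x4d"]()   # looked up by architecture string
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))    # one dummy ImageNet-sized image
print(out.shape)                                # expect torch.Size([1, 1000])
```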
@@ -124,7 +125,7 @@ export KMP_SETTINGS=1
--arch-mlp-bot=13-512-256-128 --arch-mlp-top=1024-1024-512-256-1 \
--arch-sparse-feature-size=128 --max-ind-range=4000000 \
--numpy-rand-seed=727 \
--print-freq=1024 --print-time --mini-batch-size=2048 --num-batches=4096
--print-freq=1024 --print-time --mini-batch-size=2048 --num-batches=10240
```
```
# run DLRM bf16 training
@@ -136,7 +137,7 @@ export KMP_SETTINGS=1
--arch-mlp-bot=13-512-256-128 --arch-mlp-top=1024-1024-512-256-1 \
--arch-sparse-feature-size=128 --max-ind-range=4000000 \
--numpy-rand-seed=727 \
--print-freq=1024 --print-time --mini-batch-size=2048 --num-batches=4096 \
--print-freq=1024 --print-time --mini-batch-size=2048 --num-batches=10240 \
--use-ipex --mix-precision
```
```
@@ -149,7 +150,7 @@ export KMP_SETTINGS=1
--arch-mlp-bot=13-512-256-128 --arch-mlp-top=1024-1024-512-256-1 \
--arch-sparse-feature-size=128 --max-ind-range=4000000 \
--numpy-rand-seed=727 \
--print-freq=1024 --print-time --mini-batch-size=16 --num-batches=4096 \
--print-freq=1024 --print-time --mini-batch-size=16 \
--inference-only --share-weight --num-instance=24
```
```
@@ -162,7 +163,7 @@ export KMP_SETTINGS=1
--arch-mlp-bot=13-512-256-128 --arch-mlp-top=1024-1024-512-256-1 \
--arch-sparse-feature-size=128 --max-ind-range=4000000 \
--numpy-rand-seed=727 \
--print-freq=1024 --print-time --mini-batch-size=16 --num-batches=4096 \
--print-freq=1024 --print-time --mini-batch-size=16 \
--use-ipex --mix-precision --inference-only --share-weight --num-instance=24
```

@@ -189,33 +190,33 @@ examples/language-modeling/run_language_modeling.py \
```
```
# run Bert fp32 inference
for i in $(seq 0 $LAST_INSTANCE); do
for i in $(seq 0 23); do
LOG_i=cpufp32_bs1_ins${i}.txt
echo "### running on instance $i, numa node 0, core $i"
numactl --physcpubind=$i --membind=0 python -u \
examples/language-modeling/run_language_modeling.py \
--output_dir=output_$i --per_gpu_eval_batch_size=1 \
--model_type=bert_large --do_eval --eval_data_file=$HOME/wikitext-2-raw/wiki.train.raw \
--model_type=bert_large --do_eval --eval_data_file=$DATASET_PATH/wiki.train.raw \
--overwrite_output_dir --mlm --seed=42 --max_step=30 2>&1 | tee $LOG_i &
done
```
```
# run Bert bf16 inference
for i in $(seq 0 $LAST_INSTANCE); do
for i in $(seq 0 23); do
LOG_i=cpubf16_bs1_ins${i}.txt
echo "### running on instance $i, numa node 0, core $i"
numactl --physcpubind=$i --membind=0 python -u \
examples/language-modeling/run_language_modeling.py \
--output_dir=output_$i --per_gpu_eval_batch_size=1 \
--model_type=bert_large --do_eval --eval_data_file=$HOME/wikitext-2-raw/wiki.train.raw \
--model_type=bert_large --do_eval --eval_data_file=$DATASET_PATH/wiki.train.raw \
--overwrite_output_dir --mlm --seed=42 --max_step=30 --ipex --dnnl --mix_precision 2>&1 | tee $LOG_i &
done
```

3. ResNext-101-32x4d
```
cd {path/to/examples}/imagenet/
export DATA_PATH={path/to/dataset}
export DATASET_PATH={path/to/dataset}
```

```
@@ -224,13 +225,13 @@ done
```
```
# run ResNext-101-32x4d bf16 training
bash run_training_cpu_ipex.sh resnext101_32x4d $DATA_PATH
bash run_training_cpu_ipex.sh resnext101_32x4d $DATASET_PATH dnnl bf16
```
```
# run ResNext-101-32x4d fp32 inference
bash run_inference_cpu_latency.sh resnext101_32x4d
```
```
# run ResNext-101-32x4d bf16 inference
bash run_inference_cpu_latency_ipex.sh resnext101_32x4d $DATA_PATH
bash run_inference_cpu_latency_ipex.sh resnext101_32x4d $DATASET_PATH dnnl bf16 jit
```
43 changes: 26 additions & 17 deletions torch_patches/models/bert.patch
@@ -1,5 +1,5 @@
diff --git a/examples/language-modeling/run_language_modeling.py b/examples/language-modeling/run_language_modeling.py
index 740cb636..6fdf1393 100644
index 740cb636..ce3b74fb 100644
--- a/examples/language-modeling/run_language_modeling.py
+++ b/examples/language-modeling/run_language_modeling.py
@@ -19,7 +19,7 @@ GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss.
@@ -37,7 +37,7 @@ index 740cb636..6fdf1393 100644
+ else:
+ ipex.core.disable_auto_dnnl()
+ if model_args.mix_precision:
+ ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=True)
+ ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16, train=True)
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
@@ -61,10 +61,32 @@ index 740cb636..6fdf1393 100644
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a03ac23f..4754b2a0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1023,6 +1023,8 @@ class Trainer:
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
+ if samples_count == self.args.max_steps * batch_size:
+ break
if loss is not None:
eval_losses.append(loss * batch_size)
if logits is not None:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 64900136..aae71ac6 100644
index 64900136..5c183f97 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
-
+import intel_pytorch_extension as ipex

if is_torch_available():
import torch
@@ -234,6 +234,10 @@ class TrainingArguments:
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
)
@@ -81,21 +103,8 @@ index 64900136..aae71ac6 100644
device = torch.device("cpu")
n_gpu = 0
+ elif self.ipex:
+ device = torch.device("dpcpp")
+ device = torch.device(ipex.DEVICE)
+ n_gpu = 0
elif is_torch_tpu_available():
device = xm.xla_device()
n_gpu = 0
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1e09e03d..d254d977 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1037,6 +1039,8 @@ class Trainer:
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
+ if samples_count == self.args.max_steps * batch_size:
+ break
if loss is not None:
eval_losses.append(loss * batch_size)
if logits is not None:
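Taken together, the `--ipex`/`--mix_precision` plumbing above reduces to a small runtime pattern: enable the oneDNN path, optionally turn on bf16 auto-mixed precision, and move the model to the extension's device. A condensed sketch of that pattern (not part of the patch; it assumes the pre-1.0 `intel_pytorch_extension` API used throughout this PR and a stand-in module instead of BERT):
```
# Illustrative only: mirrors the calls the patch wires into transformers.
import torch
import intel_pytorch_extension as ipex

ipex.core.enable_auto_dnnl()                    # route supported ops through oneDNN
ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16, train=False)  # bf16 for inference

device = torch.device(ipex.DEVICE)              # used in place of "cpu"
model = torch.nn.Linear(1024, 1024).to(device)  # stand-in for the BERT model
with torch.no_grad():
    out = model(torch.randn(8, 1024).to(device))
print(out.shape)
```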
32 changes: 21 additions & 11 deletions torch_patches/models/dlrm.patch
@@ -1,3 +1,16 @@
diff --git a/dlrm_data_pytorch.py b/dlrm_data_pytorch.py
index 736b6dd..e6b0bc3 100644
--- a/dlrm_data_pytorch.py
+++ b/dlrm_data_pytorch.py
@@ -382,7 +382,7 @@ def ensure_dataset_preprocessed(args, d_path):

def make_criteo_data_and_loaders(args):

- if args.mlperf_logging and args.memory_map and args.data_set == "terabyte":
+ if args.memory_map and args.data_set == "terabyte":
# more efficient for larger batches
data_directory = path.dirname(args.raw_data_file)

diff --git a/dlrm_s_pytorch.py b/dlrm_s_pytorch.py
index 3c5c85f..660f6eb 100644
--- a/dlrm_s_pytorch.py
@@ -131,8 +144,8 @@ index 3c5c85f..660f6eb 100644
+ import intel_pytorch_extension as ipex
+ ipex.core.enable_auto_dnnl()
+ if args.mix_precision:
+ ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=not args.inference_only)
+ device = torch.device("dpcpp")
+ ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16, train=not args.inference_only)
+ device = torch.device(ipex.DEVICE)
+ print("Runining with IPEX...")
else:
device = torch.device("cpu")
@@ -218,9 +231,6 @@ index 3c5c85f..660f6eb 100644
+ if args.share_weight and args.inference_only:
+ for j, (X, lS_o, lS_i, T) in enumerate(train_ld):
+ traced_model = torch.jit.trace(dlrm, wrap_input(X, lS_o, lS_i, use_gpu, device), check_trace=False)
+ # g=traced_model.graph
+ # torch._C._jit_pass_inline(g)
+ # print(g)
+ break
+ bench = ThroughputBenchmark(traced_model)
+ j = 0
@@ -231,12 +241,12 @@ index 3c5c85f..660f6eb 100644
+ if j == 1: break
+ if args.use_ipex and args.mix_precision:
+ ipex.core.disable_mix_bf16_fp32()
+ stats = bench.benchmark(
+ num_calling_threads=args.num_instance,
+ num_warmup_iters=100 * args.num_instance,
+ num_iters=1000 * args.num_instance,
+ )
+ print(stats)
+ stats = bench.benchmark(
+ num_calling_threads=args.num_instance,
+ num_warmup_iters=100 * args.num_instance,
+ num_iters=1000 * args.num_instance,
+ )
+ print(stats)
+ if args.enable_profiling:
+ print(prof.key_averages().table(sort_by="self_cpu_time_total"))
+ sys.exit()
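For reference, the shared-weight inference path added above boils down to tracing the model once and replaying it from multiple calling threads with `torch.utils.ThroughputBenchmark`. A condensed sketch with a stand-in module (illustrative, not the patch code):
```
# Illustrative sketch of the ThroughputBenchmark pattern used by the patch.
import torch
from torch.utils import ThroughputBenchmark

model = torch.nn.Linear(128, 1).eval()          # stand-in for the traced DLRM module
example = torch.randn(16, 128)                  # mini-batch-size=16, as in the README commands
traced = torch.jit.trace(model, example, check_trace=False)

bench = ThroughputBenchmark(traced)
bench.add_input(example)                        # queue one input batch
stats = bench.benchmark(
    num_calling_threads=24,                     # --num-instance in the example commands
    num_warmup_iters=100 * 24,
    num_iters=1000 * 24,
)
print(stats)
```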