Supported DDP for two prune once for all examples (#1095)
XinyuYe-Intel committed Sep 1, 2022
1 parent 71c792b commit 26a4762
Showing 4 changed files with 143 additions and 62 deletions.
@@ -47,5 +47,22 @@ python run_qa_no_trainer_pruneOFA.py --dataset_name squad \
--learning_rate 1e-5 --do_eval --num_train_epochs 2 --do_quantization \
--output_dir /path/to/stage2_output_dir --loss_weights 0 1 \
--temperature 2 --seed 5143 --pad_to_max_length --run_teacher_logits \
--resume /path/to/stage1_output_dir/best_model_weights.pt
```
--resume /path/to/stage1_output_dir/best_model.pt
```

We also support Distributed Data Parallel (DDP) training in both single-node and multi-node settings. To use Distributed Data Parallel to speed up training, the bash command needs a small adjustment.
<br>
For example, the stage 1 bash command for the SQuAD task will look like the following, where *`<MASTER_ADDRESS>`* is the address of the master node (not needed in the single-node case), *`<NUM_PROCESSES_PER_NODE>`* is the number of processes to launch on the current node (for a node with GPUs, usually set to the number of GPUs on that node; for a node without GPUs that trains on CPU, setting it to 1 is recommended), *`<NUM_NODES>`* is the number of nodes to use, and *`<NODE_RANK>`* is the rank of the current node, ranging from 0 to *`<NUM_NODES>`*`-1`.
<br>
Also note that when training on CPU in the multi-node setting, the `--no_cuda` argument is mandatory on every node. In the multi-node setting, the following command needs to be launched on each node, and all the commands should be identical except for *`<NODE_RANK>`*, which should be the integer from 0 to *`<NUM_NODES>`*`-1` assigned to that node.

```bash
python -m torch.distributed.launch --master_addr=<MASTER_ADDRESS> --nproc_per_node=<NUM_PROCESSES_PER_NODE> --nnodes=<NUM_NODES> --node_rank=<NODE_RANK> \
run_qa_no_trainer_pruneOFA.py --dataset_name squad \
--model_name_or_path Intel/bert-base-uncased-sparse-90-unstructured-pruneofa \
--teacher_model_name_or_path csarron/bert-base-uncased-squad-v1 \
--do_prune --do_distillation --max_seq_length 384 --batch_size 12 \
--learning_rate 1.5e-4 --do_eval --num_train_epochs 8 \
--output_dir /path/to/stage1_output_dir --loss_weights 0 1 \
--temperature 2 --seed 5143 --pad_to_max_length --run_teacher_logits
```
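If it is unclear whether all ranks actually joined the job, a small probe script launched with the same `torch.distributed.launch` prefix can confirm the world size before starting the real training. This is a hypothetical helper, not part of the example; it relies only on the environment variables (`MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`) and the `--local_rank` argument that the launcher provides to every worker process.

```python
# check_ddp.py: hypothetical probe; run it through the same
# `python -m torch.distributed.launch ...` prefix as the training command.
import argparse
import os

import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=-1, type=int)  # injected by the launcher
parser.add_argument("--no_cuda", action="store_true")      # accepted for parity with the example
args = parser.parse_args()

# The launcher exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so env:// initialisation works.
# gloo is used so the probe also runs on CPU-only nodes.
dist.init_process_group(backend="gloo", init_method="env://")
print(
    f"rank {dist.get_rank()}/{dist.get_world_size()} is up "
    f"(local_rank={args.local_rank}, master={os.environ.get('MASTER_ADDR')})"
)
dist.destroy_process_group()
```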
@@ -257,6 +257,10 @@ def parse_args():
help='loss weights of distillation, should be a list of length 2, '
'and sum to 1.0, first for student targets loss weight, '
'second for teacher student loss weight.')
parser.add_argument("--local_rank", default=-1, type=int,
help='used for assigning rank to the process in local machine.')
parser.add_argument("--no_cuda", action='store_true',
help='use cpu for training.')
args = parser.parse_args()

# Sanity checks
@@ -467,7 +471,7 @@ def main():
args = parse_args()

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
accelerator = Accelerator()
accelerator = Accelerator(cpu=args.no_cuda)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -836,14 +840,32 @@ def get_logits(teacher_model, train_dataset):
if os.path.exists(npy_file):
teacher_logits = [list(x) for x in np.load(npy_file, allow_pickle=True)]
else:
train_dataloader = DataLoader(train_dataset, collate_fn=data_collator, batch_size=args.batch_size)
train_dataloader = accelerator.prepare(train_dataloader)
sampler = None
if accelerator.num_processes > 1:
from transformers.trainer_pt_utils import ShardSampler
sampler = ShardSampler(
train_dataset,
batch_size=args.batch_size,
num_processes=accelerator.num_processes,
process_index=accelerator.process_index,
)
train_dataloader = DataLoader(
train_dataset, collate_fn=data_collator, sampler=sampler, batch_size=args.batch_size
)
train_dataloader = tqdm(train_dataloader, desc="Evaluating")
teacher_logits = []
for step, batch in enumerate(train_dataloader):
batch = move_input_to_device(batch, next(teacher_model.parameters()).device)
outputs = teacher_model(**batch).cpu().detach().numpy()
if accelerator.num_processes > 1:
outputs_list = [None for i in range(accelerator.num_processes)]
torch.distributed.all_gather_object(outputs_list, outputs)
outputs = np.concatenate(outputs_list, axis=0)
teacher_logits += [[s,e] for s,e in zip(outputs[0::2], outputs[1::2])]
np.save(npy_file, teacher_logits, allow_pickle=True)
if accelerator.num_processes > 1:
teacher_logits = teacher_logits[:len(train_dataset)]
if accelerator.local_process_index in [-1, 0]:
np.save(npy_file, teacher_logits, allow_pickle=True)
return train_dataset.add_column('teacher_logits', teacher_logits)
with torch.no_grad():
train_dataset = get_logits(teacher_model, train_dataset)
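The rewritten `get_logits` above shards the teacher forward pass across processes: `ShardSampler` gives every rank an equally sized slice of each global batch (padding the last one if needed), each rank computes logits for its slice, `torch.distributed.all_gather_object` reassembles the per-batch results on every process, and the combined list is truncated back to `len(train_dataset)` before only the main process writes the `.npy` cache. A stripped-down sketch of that gather-then-truncate pattern (a hypothetical helper, assuming the process group is already initialised by the launcher):

```python
import numpy as np
import torch.distributed as dist

def gather_batch_outputs(local_batch_outputs: np.ndarray, num_processes: int) -> np.ndarray:
    """Collect one batch worth of outputs from every rank and stitch them together.

    Called once per step, as in the example: gathering per batch keeps the results in
    dataset order, because ShardSampler hands rank i the i-th slice of every global batch.
    """
    if num_processes <= 1:
        return local_batch_outputs
    gathered = [None for _ in range(num_processes)]
    # all_gather_object accepts any picklable object, numpy arrays included.
    dist.all_gather_object(gathered, local_batch_outputs)
    return np.concatenate(gathered, axis=0)

# After the loop the concatenated results can be longer than the dataset, because
# ShardSampler pads the last global batch so every rank gets a full slice; the example
# therefore truncates with `teacher_logits[:len(train_dataset)]` before saving.
```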
@@ -934,9 +956,8 @@ def eval_func(model):
if 'teacher_logits' in input_names:
input_names.remove('teacher_logits')
break
model = symbolic_trace(model, input_names=input_names, \
batch_size=args.batch_size, \
sequence_length=args.max_seq_length)
model = symbolic_trace(accelerator.unwrap_model(model), input_names=input_names, \
batch_size=args.batch_size, sequence_length=args.max_seq_length)

from neural_compressor.experimental.scheduler import Scheduler
from neural_compressor.experimental import Quantization
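`symbolic_trace` operates on the plain `transformers` model, not on the `DistributedDataParallel` wrapper that `accelerator.prepare` adds in the DDP case, which is why the model is passed through `accelerator.unwrap_model` first. For the simple DDP case the unwrapping amounts to the following (a sketch; the real helper also handles nested and non-DDP wrappers):

```python
from torch import nn
from torch.nn.parallel import DistributedDataParallel

def unwrap(model: nn.Module) -> nn.Module:
    # Roughly what accelerator.unwrap_model(model) returns for a DDP-wrapped module.
    return model.module if isinstance(model, DistributedDataParallel) else model
```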
@@ -954,7 +975,9 @@ def eval_func(model):
agent.pruning_func = train_func
agent.eval_func = eval_func
model = agent()
model.save(args.output_dir)
model = common.Model(accelerator.unwrap_model(model.model))
if accelerator.local_process_index in [-1, 0]:
model.save(args.output_dir)
# change to framework model for further use
model = model.model
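The save is also guarded so that only one process per node writes to `--output_dir`, which prevents several ranks racing on the same files. With Accelerate this guard is often written with the main-process helpers; a small sketch of the same idea (the barrier is an extra safeguard, not something the example itself adds):

```python
# Sketch: serialise filesystem writes to one process per node.
def save_on_local_main(accelerator, save_fn):
    accelerator.wait_for_everyone()          # optional barrier so all ranks reach the save point
    if accelerator.is_local_main_process:    # effectively the same check as `local_process_index in [-1, 0]`
        save_fn()                            # e.g. lambda: model.save(args.output_dir)
```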

@@ -45,7 +45,7 @@ python run_glue_no_trainer_pruneOFA.py --task_name sst2 \
--do_prune --do_distillation --max_seq_length 128 --batch_size 32 \
--learning_rate 1e-5 --num_train_epochs 3 --output_dir /path/to/stage2_output_dir \
--loss_weights 0 1 --temperature 2 --seed 5143 --do_quantization \
--resume /path/to/stage1_output_dir/best_model_weights.pt --pad_to_max_length
--resume /path/to/stage1_output_dir/best_model.pt --pad_to_max_length
```

## MNLI task
@@ -65,7 +65,7 @@ python run_glue_no_trainer_pruneOFA.py --task_name mnli \
--do_prune --do_distillation --max_seq_length 128 --batch_size 32 \
--learning_rate 1e-5 --num_train_epochs 3 --output_dir /path/to/stage2_output_dir \
--loss_weights 0 1 --temperature 2 --seed 5143 --do_quantization \
--resume /path/to/stage1_output_dir/best_model_weights.pt --pad_to_max_length
--resume /path/to/stage1_output_dir/best_model.pt --pad_to_max_length
```

## QQP task
@@ -85,7 +85,7 @@ python run_glue_no_trainer_pruneOFA.py --task_name qqp \
--do_prune --do_distillation --max_seq_length 128 --batch_size 32 \
--learning_rate 1e-5 --num_train_epochs 3 --output_dir /path/to/stage2_output_dir \
--loss_weights 0 1 --temperature 2 --seed 5143 --do_quantization \
--resume /path/to/stage1_output_dir/best_model_weights.pt --pad_to_max_length
--resume /path/to/stage1_output_dir/best_model.pt --pad_to_max_length
```

## QNLI task
@@ -105,5 +105,21 @@ python run_glue_no_trainer_pruneOFA.py --task_name qnli \
--do_prune --do_distillation --max_seq_length 128 --batch_size 32 \
--learning_rate 1e-5 --num_train_epochs 3 --output_dir /path/to/stage2_output_dir \
--loss_weights 0 1 --temperature 2 --seed 5143 --do_quantization \
--resume /path/to/stage1_output_dir/best_model_weights.pt --pad_to_max_length
--resume /path/to/stage1_output_dir/best_model.pt --pad_to_max_length
```

We also support Distributed Data Parallel (DDP) training in both single-node and multi-node settings. To use Distributed Data Parallel to speed up training, the bash command needs a small adjustment.
<br>
For example, the stage 1 bash command for the SST2 task will look like the following, where *`<MASTER_ADDRESS>`* is the address of the master node (not needed in the single-node case), *`<NUM_PROCESSES_PER_NODE>`* is the number of processes to launch on the current node (for a node with GPUs, usually set to the number of GPUs on that node; for a node without GPUs that trains on CPU, setting it to 1 is recommended), *`<NUM_NODES>`* is the number of nodes to use, and *`<NODE_RANK>`* is the rank of the current node, ranging from 0 to *`<NUM_NODES>`*`-1`.
<br>
Also note that when training on CPU in the multi-node setting, the `--no_cuda` argument is mandatory on every node. In the multi-node setting, the following command needs to be launched on each node, and all the commands should be identical except for *`<NODE_RANK>`*, which should be the integer from 0 to *`<NUM_NODES>`*`-1` assigned to that node.

```bash
python -m torch.distributed.launch --master_addr=<MASTER_ADDRESS> --nproc_per_node=<NUM_PROCESSES_PER_NODE> --nnodes=<NUM_NODES> --node_rank=<NODE_RANK> \
run_glue_no_trainer_pruneOFA.py --task_name sst2 \
--model_name_or_path Intel/bert-base-uncased-sparse-90-unstructured-pruneofa \
--teacher_model_name_or_path textattack/bert-base-uncased-SST-2 \
--do_prune --do_distillation --max_seq_length 128 --batch_size 32 \
--learning_rate 1e-4 --num_train_epochs 9 --output_dir /path/to/stage1_output_dir \
--loss_weights 0 1 --temperature 2 --seed 5143
```
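One thing worth keeping in mind when moving these commands to DDP is that `--batch_size` is per process, so the effective optimisation batch grows with the number of processes. A hypothetical helper to make that explicit:

```python
# Hypothetical helper: effective global batch size under DDP.
def global_batch_size(per_process_batch_size: int, num_nodes: int, procs_per_node: int) -> int:
    # Each process draws `per_process_batch_size` samples per step, so one optimisation
    # step covers the product of all three factors.
    return per_process_batch_size * num_nodes * procs_per_node

# e.g. the SST-2 command above on 2 nodes with 4 processes each:
print(global_batch_size(32, 2, 4))  # 256
```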
