diff --git a/inference/huggingface/fill-mask/test-bert.py b/inference/huggingface/fill-mask/test-bert.py index 994d5fde2..441c617b7 100644 --- a/inference/huggingface/fill-mask/test-bert.py +++ b/inference/huggingface/fill-mask/test-bert.py @@ -14,7 +14,7 @@ pipe.model, mp_size=world_size, dtype=torch.float, - injection_policy={BertLayer : ('output.dense')} + replace_with_kernel_inject=True ) pipe.device = torch.device(f'cuda:{local_rank}')