Skip to content

Commit

Permalink
replace init_empty_weights with torch.device(meta)
Browse files Browse the repository at this point in the history
  • Loading branch information
lchu-ibm committed Aug 1, 2023
1 parent d8a81bb commit c8d4f38
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions llama_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
get_policies
)

from accelerate import init_empty_weights

from utils.dataset_utils import get_preprocessed_dataset

from utils.config_utils import (
Expand Down Expand Up @@ -107,7 +105,7 @@ def main(**kwargs):
)
else:
llama_config = LlamaConfig.from_pretrained(train_config.model_name)
with init_empty_weights():
with torch.device("meta"):
model = LlamaForCausalLM(llama_config)
else:
model = LlamaForCausalLM.from_pretrained(
Expand Down

0 comments on commit c8d4f38

Please sign in to comment.