Skip to content

Commit

Permalink
put ECD on GPU to recreate the 8-bit tensors
Browse files: browse the repository at this point in the history
  • Loading branch information
jeffkinnison committed Jan 25, 2024
1 parent 831ebf5 commit 318aa19
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
MAX_CPU_BATCH_SIZE,
MINIMIZE,
MODEL_ECD,
+ MODEL_LLM,
TEST,
TRAINING,
USED_TOKENS,
Expand Down Expand Up @@ -1143,8 +1144,10 @@ def train(
# to a RuntimeError in `load_state_dict`. Explicitly call `model.cuda()` to make sure the
# matrices are part of model state. This workaround is necessary because the matrices are
# deleted during the model's forward pass.
- if self.model.model.device.type == "cuda":
+ if self.model.config_obj.model_type == MODEL_LLM and self.model.model.device.type == "cuda":
      self.model.model.cuda()
+ elif self.model.config_obj.model_type == MODEL_ECD and self.model.device.type == "cuda":
+     self.model.cuda()
_, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
only_weights_format_keys = ["weights_format" in k for k in unexpected_keys]

Expand Down

0 comments on commit 318aa19

Please sign in to comment.