Skip to content

Commit

Permalink
[SPMD] Replace nn.Linear (huggingface#36)
Browse files Browse the repository at this point in the history
Summary:
This pull request replaces the default nn.Linear with our patched version that doesn't flatten the high dimensional tensors.

Test Plan:
Tested on a V4-8.
  • Loading branch information
alanwaketan committed Oct 27, 2023
1 parent 3bead04 commit 29fb919
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/pytorch/language-modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,10 @@ def get_mesh(ici_mesh_shape, dcn_mesh_shape=None):
else:
return xs.HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape)

# Replace the linear layer
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
model = apply_xla_patch_to_nn_linear(model, xs.xla_patched_nn_linear_forward)

# Convert the model from meta to XLA tensors one layer at a time to avoid
# host-side OOM
for name, param in model.named_parameters():
Expand Down

0 comments on commit 29fb919

Please sign in to comment.