
Commit 168278e

drop jit for loss and remove bias sharding for tp
fattorib committed May 27, 2023
1 parent 68dcc10 commit 168278e
Showing 2 changed files with 4 additions and 15 deletions.
13 changes: 2 additions & 11 deletions src/partitioning/partition.py
@@ -13,18 +13,9 @@
 from flax.core.frozen_dict import freeze
 from flax.traverse_util import flatten_dict, unflatten_dict
 from jax.sharding import Mesh, PartitionSpec
-from jax.sharding import PositionalSharding, NamedSharding
+from jax.sharding import NamedSharding
 from typing import Callable

-def setup_dp_mesh():
-    """
-    Creates jax device mesh for data-parallel training
-    """
-    devices = np.asarray(jax.devices())
-    mesh = Mesh(devices, ["dp"])
-
-    return mesh
-

 def _match(qs, ks):
     """Return True if regexes in qs match any window of strings in tuple ks."""

@@ -102,7 +93,7 @@ def _get_partition_rules_tp(mesh: Mesh):
         (("(query_proj|key_proj|value_proj)", "kernel"), NamedSharding(mesh, PartitionSpec(None, "mp"))),
         (("residual_out", "kernel"), NamedSharding(mesh, PartitionSpec("mp", None))),
         (("(query_proj|key_proj|value_proj)", "bias"), NamedSharding(mesh, PartitionSpec(None))),
-        (("residual_out", "bias"), NamedSharding(mesh, PartitionSpec("mp"))),
+        (("residual_out", "bias"), NamedSharding(mesh, PartitionSpec(None))),
         # MLP
         (("fc_in", "kernel"), NamedSharding(mesh, PartitionSpec(None, "mp"))),
         (("fc_residual", "kernel"), NamedSharding(mesh, PartitionSpec("mp", None))),
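For context on what these partition rules do, below is a minimal sketch (not code from this repository) of how a NamedSharding built from a PartitionSpec lays tensors out over a device mesh. The mesh axis names ("dp", "mp") follow the ones used in the diff above; the array shapes and the single-row device reshape are hypothetical.

# Minimal sketch (not from this repo): how a rule like
# NamedSharding(mesh, PartitionSpec(None, "mp")) lays out a kernel, and why a
# bias with PartitionSpec(None) is simply replicated across the "mp" axis.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()).reshape(1, -1)  # hypothetical (dp, mp) layout
mesh = Mesh(devices, ("dp", "mp"))

kernel = jnp.zeros((512, 2048))  # e.g. a query_proj kernel (hypothetical sizes)
bias = jnp.zeros((2048,))

# Kernel: columns split across the "mp" axis, rows left unsharded.
kernel = jax.device_put(kernel, NamedSharding(mesh, PartitionSpec(None, "mp")))
# Bias: PartitionSpec(None) leaves its single axis unsharded, so the full
# vector is replicated on every "mp" device -- the layout this commit
# switches the residual_out bias to.
bias = jax.device_put(bias, NamedSharding(mesh, PartitionSpec(None)))

print(kernel.sharding, bias.sharding)

Replicating a small 1-D bias rather than splitting it along "mp" is presumably the point of the change: only the 2-D kernels carry enough data to be worth sharding.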
6 changes: 2 additions & 4 deletions src/utils/losses.py
@@ -1,12 +1,10 @@
 """
 Loss function
 """
-import flax.linen as nn
+import jax
 import jax.numpy as jnp
-from jax import jit


-@jit
 def cross_entropy_loss(labels: jnp.array, logits: jnp.array) -> jnp.array:
     """Standard Cross Entropy Loss function

@@ -19,5 +17,5 @@ def cross_entropy_loss(labels: jnp.array, logits: jnp.array) -> jnp.array:
     """

     return (
-        -jnp.sum(labels * nn.log_softmax(logits.astype(jnp.float32), axis=-1), axis=-1)
+        -jnp.sum(labels * jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1), axis=-1)
     )
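With the @jit decorator removed, the loss is only compiled when a caller traces it inside its own jax.jit. A minimal sketch of that pattern follows; the eval_step wrapper and the test inputs are hypothetical, not from this repo.

# Minimal sketch (not this repo's training loop): the undecorated loss is
# traced and fused into whatever jitted step calls it.
import jax
import jax.numpy as jnp

def cross_entropy_loss(labels: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
    # Log-probabilities computed in float32 for numerical stability, as in the diff.
    return -jnp.sum(labels * jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1), axis=-1)

@jax.jit  # compilation happens here, at the caller, not on the loss itself
def eval_step(labels, logits):
    return jnp.mean(cross_entropy_loss(labels, logits))

labels = jax.nn.one_hot(jnp.array([1, 0]), num_classes=3)
logits = jnp.array([[0.2, 1.5, -0.3], [2.0, 0.1, 0.1]])
print(eval_step(labels, logits))

Dropping the decorator keeps the loss a plain traceable function; it still compiles as part of any jitted train or eval step that uses it.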
