fix failing tests by adding tp_comms flag
fattorib committed May 28, 2023
1 parent 5c76fc6 commit 871b017
Showing 4 changed files with 34 additions and 134 deletions.
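
For context: the failing tests construct the model blocks directly, outside any tensor-parallel mesh, presumably leaving the f_psum/g_psum collectives in the MLP and attention blocks with no mapped axis to reduce over. The new tp_comms flag (default True, so training behaviour is unchanged) lets callers skip those collectives. A rough single-device sketch of the resulting usage, mirroring the tests further down (import path and constructor signature are assumed from the repo layout, not confirmed):

```python
# Hypothetical single-device usage enabled by this commit; mirrors the unit tests.
from jax import random
from src.models.layers import MLPBlock  # assumed import path

mlp = MLPBlock(
    embedding_dim=128, dimension_multiplier=4, dropout=0.1, N=6, tp_comms=False
)
x = random.normal(random.PRNGKey(0), (1, 512, 128))
params = mlp.init(random.PRNGKey(1), x, False)  # train=False, no "mp" axis required
out = mlp.apply(params, x, False)               # dropout is deterministic when train=False
```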
5 changes: 5 additions & 0 deletions src/models/GPT.py
@@ -23,6 +23,7 @@ class TransformerBlock(nn.Module):
N: int = None
dtype: Any = jnp.float32
alibi_attn: bool = True
tp_comms: bool = True

@nn.compact
def __call__(
@@ -39,13 +40,15 @@ def __call__(
self.N,
self.alibi_attn,
self.dtype,
tp_comms=self.tp_comms
)(nn.LayerNorm(dtype=self.dtype, use_bias=False)(x), train)
x = x + attn_out
x = x + MLPBlock(
self.embedding_dim,
dropout=self.residual_dropout,
N=self.N,
dtype=self.dtype,
tp_comms=self.tp_comms
)(nn.LayerNorm(dtype=self.dtype, use_bias=False)(x), train)
return x

@@ -63,6 +66,7 @@ class Transformer(nn.Module):
N: int = None
dtype: Any = jnp.float32
alibi_attn: bool = False
tp_comms: bool = True

@nn.compact
def __call__(
@@ -93,6 +97,7 @@ def __call__(
self.N,
self.dtype,
self.alibi_attn,
self.tp_comms
)(out, train)

out = nn.LayerNorm(dtype=self.dtype, use_bias=False)(out)
23 changes: 15 additions & 8 deletions src/models/layers.py
@@ -6,9 +6,8 @@
import jax
import jax.nn.initializers as initializers
import jax.numpy as jnp
from flax.linen import partitioning as nn_partitioning
from einops import rearrange
from src.models.replicated_utils import *
from src.models.replicated_utils import f_psum, g_psum


def get_slopes(n: int) -> List:
@@ -48,14 +47,16 @@ class MLPBlock(nn.Module):
dimension_multiplier: int = 4
dropout: float = 0.0
N: int = None
tp_comms: bool = True

dtype: Any = jnp.float32

@nn.compact
def __call__(self, x: jnp.array, train: bool) -> jnp.array:
dropout = partial(nn.Dropout, rate=self.dropout, deterministic=not train)

x = f_psum(x)
if self.tp_comms:
x = f_psum(x)

x = nn.Dense(
features=self.dimension_multiplier * self.embedding_dim,
@@ -74,7 +75,9 @@ def __call__(self, x: jnp.array, train: bool) -> jnp.array:
dtype=self.dtype,
use_bias=False,
)(x)
out = g_psum(out)
if self.tp_comms:
out = g_psum(out)
return dropout()(out)


@@ -95,6 +98,7 @@ class CausalAttention(nn.Module):
N: int = None
alibi_attn: bool = False
dtype: Any = jnp.float32
tp_comms: bool = True

def setup(self):
self.slopes = jnp.array(get_slopes(self.num_head))
@@ -112,7 +116,8 @@ def __call__(
dropout = partial(nn.Dropout, rate=self.dropout, deterministic=not train)
T, C = x.shape[-2:]

x = f_psum(x)
if self.tp_comms:
x = f_psum(x)

key = nn.Dense(
name="key_proj",
@@ -155,8 +160,9 @@ def __call__(
attn_full = (query @ key) / jnp.sqrt(key.shape[-1]) # Shape is (B, nh, sq, sk)

if self.alibi_attn:
# NOTE: We are fixing the ALiBi mask since this is for training; during inference or if seq_len changes this will cause issues
attn_full = attn_full + self.alibi_mask[:, :T, :T]
# NOTE: We are fixing the ALiBi mask since this is for training;
# during inference or if seq_len changes this will cause issues
attn_full = attn_full + self.alibi_mask

masked_attn = jnp.where(
self.mask, attn_full.astype(jnp.float32), jnp.finfo(jnp.float32).min
@@ -178,6 +184,7 @@ def __call__(
use_bias=False,
)(attn_out)

out = g_psum(out)
if self.tp_comms:
out = g_psum(out)

return dropout()(out)
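
The new flag gates f_psum and g_psum, the Megatron-style tensor-parallel collectives imported from src.models.replicated_utils. A minimal sketch of what such helpers typically look like, assuming custom-VJP wrappers over jax.lax.psum on a mapped axis named "mp" (an assumption for illustration, not the repo's actual replicated_utils):

```python
# Hypothetical f_psum / g_psum sketch -- NOT the repo's implementation.
# Assumes a mapped tensor-parallel axis named "mp" (e.g. inside pmap/shard_map).
import jax


@jax.custom_vjp
def f_psum(x):
    # Forward: identity. Backward: all-reduce the incoming gradient over "mp".
    return x


def _f_fwd(x):
    return x, None


def _f_bwd(_, g):
    return (jax.lax.psum(g, axis_name="mp"),)


f_psum.defvjp(_f_fwd, _f_bwd)


@jax.custom_vjp
def g_psum(x):
    # Forward: all-reduce over "mp". Backward: pass the gradient through unchanged.
    return jax.lax.psum(x, axis_name="mp")


def _g_fwd(x):
    return jax.lax.psum(x, axis_name="mp"), None


def _g_bwd(_, g):
    return (g,)


g_psum.defvjp(_g_fwd, _g_bwd)
```

Outside a pmap/shard_map that binds the "mp" axis, jax.lax.psum fails with an unbound-axis error, which is why the single-device unit tests below construct these modules with tp_comms=False.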
123 changes: 2 additions & 121 deletions src/partitioning/partition.py
@@ -5,15 +5,12 @@
"""

import re

import jax
import numpy as np
import optax
from flax.core import FrozenDict
from flax.core.frozen_dict import freeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.sharding import Mesh, PartitionSpec
from jax.sharding import NamedSharding
from typing import Callable


@@ -37,60 +34,6 @@ def replace(key, val):

return replace


def _get_partition_rules_dp(mesh: Mesh):
"""
Follows Megatron-LM partition rules from
`Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism`
by Shoeybi et al.
<https://arxiv.org/abs/1909.08053>
"""
return [
(("wte", "embedding"), NamedSharding(mesh, PartitionSpec("dp", None, None))),
(("wpe", "embedding"), NamedSharding(mesh, PartitionSpec("dp", None, None))),
# attention
(
("(query_proj|key_proj|value_proj)", "kernel"),
NamedSharding(mesh, PartitionSpec("dp", None, None)),
),
(
("residual_out", "kernel"),
NamedSharding(mesh, PartitionSpec("dp", None, None)),
),
(
("(query_proj|key_proj|value_proj)", "bias"),
NamedSharding(mesh, PartitionSpec("dp", None)),
),
(("residual_out", "bias"), NamedSharding(mesh, PartitionSpec("dp", None))),
# MLP
(("fc_in", "kernel"), NamedSharding(mesh, PartitionSpec("dp", None, None))),
(
("fc_residual", "kernel"),
NamedSharding(mesh, PartitionSpec("dp", None, None)),
),
(("fc_in", "bias"), NamedSharding(mesh, PartitionSpec("dp", None))),
(("fc_residual", "bias"), NamedSharding(mesh, PartitionSpec("dp", None))),
# layer norms
(
(
"LayerNorm_0",
"(bias|scale)",
),
NamedSharding(mesh, PartitionSpec("dp", None)),
),
# layer norms
(
(
"LayerNorm_1",
"(bias|scale)",
),
NamedSharding(mesh, PartitionSpec("dp", None)),
),
]


def _get_partition_rules_tp(axis_name: str):
"""
Follows Megatron-LM partition rules from
@@ -144,68 +87,6 @@ def _get_partition_rules_tp(axis_name: str):
]


def _get_partition_rules_tp_dp(mesh: Mesh, axis_name: str):
"""
Follows Megatron-LM partition rules from
`Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism`
by Shoeybi et al.
<https://arxiv.org/abs/1909.08053>
"""
return [
(
("wte", "embedding"),
NamedSharding(mesh, PartitionSpec("dp", axis_name, None)),
),
(
("wpe", "embedding"),
NamedSharding(mesh, PartitionSpec("dp", axis_name, None)),
),
# attention
(
("(query_proj|key_proj|value_proj)", "kernel"),
NamedSharding(mesh, PartitionSpec("dp", None, axis_name)),
),
(
("residual_out", "kernel"),
NamedSharding(mesh, PartitionSpec("dp", axis_name, None)),
),
(
("(query_proj|key_proj|value_proj)", "bias"),
NamedSharding(mesh, PartitionSpec("dp", None)),
),
(("residual_out", "bias"), NamedSharding(mesh, PartitionSpec("dp", axis_name))),
# MLP
(
("fc_in", "kernel"),
NamedSharding(mesh, PartitionSpec("dp", None, axis_name)),
),
(
("fc_residual", "kernel"),
NamedSharding(mesh, PartitionSpec("dp", axis_name, None)),
),
(("fc_in", "bias"), NamedSharding(mesh, PartitionSpec("dp", None))),
(("fc_residual", "bias"), NamedSharding(mesh, PartitionSpec("dp", None))),
# layer norms
(
(
"LayerNorm_0",
"(bias|scale)",
),
NamedSharding(mesh, PartitionSpec("dp", None)),
),
# layer norms
(
(
"LayerNorm_1",
"(bias|scale)",
),
NamedSharding(mesh, PartitionSpec("dp", None)),
),
]


def set_partitions_rules(
in_dict, mesh: Mesh, rules_func: Callable, axis_name: str = "mp"
):
@@ -244,8 +125,8 @@ def get_opt_spec(x):
x, (FrozenDict,)
): # if we get first/second moment buffers, clone PSpec of the params
return param_spec
return None # else, PSpec of None (this is to be copied across all devices) (stuff like GA step, skip_step, etc)

return PartitionSpec()  # else, empty PSpec: replicated across all devices (stuff like GA step, skip_step, etc.)
opt_state_spec = jax.tree_util.tree_map(
get_opt_spec,
opt_state_shapes,
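
On the get_opt_spec change above: returning an empty PartitionSpec() instead of None gives scalar optimizer state (GA step, skip_step, and similar) an explicit replicate-everywhere spec, so the whole optimizer-state tree stays annotated. A small illustration of the semantics with a hypothetical one-dimensional mesh (the "mp" axis name is an assumption for the example):

```python
# Illustration only: an empty PartitionSpec() means "replicate on every device",
# while naming a mesh axis shards the corresponding array dimension.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("mp",))

replicated = NamedSharding(mesh, PartitionSpec())             # full copy per device
row_sharded = NamedSharding(mesh, PartitionSpec("mp", None))  # dim 0 split over "mp"

step = jax.device_put(np.int32(0), replicated)  # e.g. a gradient-accumulation counter
w = jax.device_put(np.zeros((8 * mesh.devices.size, 16), np.float32), row_sharded)
```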
17 changes: 12 additions & 5 deletions tests/test_model_components.py
@@ -19,12 +19,12 @@ def tearDown(self) -> None:

def test_MLP_create(self):

mlp = MLPBlock(embedding_dim=128, dimension_multiplier=4, dropout=0.1, N=10)
mlp = MLPBlock(embedding_dim=128, dimension_multiplier=4, dropout=0.1, N=10, tp_comms=False)
batch_cts = random.normal(self.rng, shape=(1, 512, 128))
params = mlp.init(self.init_rng, batch_cts, False)

def test_MLP_fwd(self):
mlp = MLPBlock(embedding_dim=128, dimension_multiplier=4, dropout=0.1, N=6)
mlp = MLPBlock(embedding_dim=128, dimension_multiplier=4, dropout=0.1, N=6, tp_comms=False)
batch_cts = random.normal(self.rng, shape=(1, 512, 128))
params = mlp.init(self.init_rng, batch_cts, False)

@@ -57,14 +57,14 @@ def tearDown(self) -> None:
def test_attn_create(self):

attn = CausalAttention(
embedding_dim=128, num_head=8, block_size=512, dropout=0.1, N=6
embedding_dim=128, num_head=8, block_size=512, dropout=0.1, N=6, tp_comms=False
)
batch_cts = random.normal(self.rng, shape=(1, 512, 128))
params = attn.init(self.init_rng, batch_cts, False)

def test_attn_fwd(self):
attn = CausalAttention(
embedding_dim=128, num_head=8, block_size=512, dropout=0.1, N=6
embedding_dim=128, num_head=8, block_size=512, dropout=0.1, N=6, tp_comms=False
)
batch_cts = random.normal(self.rng, shape=(1, 512, 128))
params = attn.init(self.init_rng, batch_cts, False)
@@ -94,6 +94,7 @@ def test_attn_fwd_ALiBi(self):
dropout=0.1,
N=6,
alibi_attn=True,
tp_comms=False
)
batch_cts = random.normal(self.rng, shape=(1, 512, 128))
params = attn.init(self.init_rng, batch_cts, False)
@@ -133,6 +134,7 @@ def test_block_create_standard(self):
residual_dropout=0.1,
N=6,
dtype=None,
tp_comms=False
)
batch_cts = random.normal(self.rng, shape=(1, 512, 128))
params = block.init(self.init_rng, batch_cts, False)
@@ -147,6 +149,7 @@ def test_block_fwd_standard(self):
N=6,
dtype=None,
alibi_attn=True,
tp_comms=False
)
batch_cts = random.normal(self.rng, shape=(1, 512, 128))
params = block.init(self.init_rng, batch_cts, False)
@@ -179,6 +182,7 @@ def test_gpt_create_standard(self):
dropout=0.1,
N=6,
dtype=None,
tp_comms=False
)
batch_tok = random.randint(self.rng, shape=(1, 512), maxval=256, minval=0)
params = block.init(self.init_rng, batch_tok, None, False)
@@ -194,6 +198,7 @@ def test_gpt_fwd_standard(self):
N=6,
dtype=None,
alibi_attn=True,
tp_comms=False
)
batch_tok = random.randint(self.rng, shape=(1, 512), maxval=256, minval=0)
params = block.init(self.init_rng, batch_tok, None, False)
@@ -217,6 +222,7 @@ def test_gpt_fwd_fp16(self):
N=6,
dtype=jnp.float16,
alibi_attn=True,
tp_comms=False
)
batch_tok = random.randint(self.rng, shape=(1, 512), maxval=256, minval=0)
params = block.init(self.init_rng, batch_tok, None, False)
@@ -240,6 +246,7 @@ def test_gpt_loss_standard(self):
N=6,
dtype=None,
alibi_attn=True,
tp_comms=False
)
batch_tok = random.randint(self.rng, shape=(1, 512), maxval=256, minval=0)
params = block.init(self.init_rng, batch_tok, None, False)
@@ -259,4 +266,4 @@ def test_gpt_loss_standard(self):

loss_external = cross_entropy_loss(oh_labels_shifted, logits_shifted)

self.assertEqual(loss, loss_external)
self.assertEqual(jnp.mean(loss), jnp.mean(loss_external))
