Commit
add explicit comms to model fwd pass
fattorib committed May 28, 2023
1 parent d023e90 commit 088f3fc
Showing 8 changed files with 207 additions and 105 deletions.
2 changes: 1 addition & 1 deletion conf/model_config.yaml
@@ -1,7 +1,7 @@
test:
embedding_dim: 64
vocab_size: 256
num_head: 4
num_head: 8
block_size: 32
dropout: 0.0
N: 2
1 change: 0 additions & 1 deletion src/models/GPT.py
@@ -11,7 +11,6 @@

from src.models.layers import CausalAttention, MLPBlock
from src.utils.losses import cross_entropy_loss
from einops import rearrange


class TransformerBlock(nn.Module):
95 changes: 44 additions & 51 deletions src/models/layers.py
@@ -8,26 +8,22 @@
import jax.numpy as jnp
from flax.linen import partitioning as nn_partitioning
from einops import rearrange


def shard_noop(x, spec):
return nn_partitioning.with_sharding_constraint(x, spec)
from src.models.replicated_utils import *


def get_slopes(n: int) -> List:
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
return [start * ratio ** i for i in range(n)]

if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 *
closest_power_of_2)[0::2][: n - closest_power_of_2]
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)


@@ -57,8 +53,10 @@ class MLPBlock(nn.Module):

@nn.compact
def __call__(self, x: jnp.array, train: bool) -> jnp.array:
dropout = partial(nn.Dropout, rate=self.dropout,
deterministic=not train)
dropout = partial(nn.Dropout, rate=self.dropout, deterministic=not train)

x = f_psum(x)

x = nn.Dense(
features=self.dimension_multiplier * self.embedding_dim,
name="fc_in",
@@ -71,12 +69,12 @@ def __call__(self, x: jnp.array, train: bool) -> jnp.array:
out = nn.Dense(
features=self.embedding_dim,
name="fc_residual",
kernel_init=initializers.normal(
stddev=(0.02 / jnp.sqrt(2 * self.N))),
kernel_init=initializers.normal(stddev=(0.02 / jnp.sqrt(2 * self.N))),
bias_init=initializers.zeros,
dtype=self.dtype,
use_bias=False,
)(x)
out = g_psum(out)
return dropout()(out)


@@ -100,8 +98,9 @@ class CausalAttention(nn.Module):

def setup(self):
self.slopes = jnp.array(get_slopes(self.num_head))
self.mask = jnp.tril(jnp.ones((self.block_size, self.block_size), dtype=jnp.int8)).reshape(
1, 1, self.block_size, self.block_size)
self.mask = jnp.tril(
jnp.ones((self.block_size, self.block_size), dtype=jnp.int8)
).reshape(1, 1, self.block_size, self.block_size)
self.alibi_mask = create_mask(self.block_size, self.slopes)

@nn.compact
Expand All @@ -110,42 +109,37 @@ def __call__(
x: jnp.array,
train: bool,
) -> jnp.array:
dropout = partial(nn.Dropout, rate=self.dropout,
deterministic=not train)
dropout = partial(nn.Dropout, rate=self.dropout, deterministic=not train)
T, C = x.shape[-2:]

key = (
nn.Dense(
name="key_proj",
features=self.embedding_dim,
kernel_init=initializers.normal(stddev=0.02),
bias_init=initializers.zeros,
dtype=self.dtype,
use_bias=False,
)(x)
)
x = f_psum(x)

value = (
nn.Dense(
name="value_proj",
features=self.embedding_dim,
kernel_init=initializers.normal(stddev=0.02),
bias_init=initializers.zeros,
dtype=self.dtype,
use_bias=False,
)(x)
)
key = nn.Dense(
name="key_proj",
features=self.embedding_dim,
kernel_init=initializers.normal(stddev=0.02),
bias_init=initializers.zeros,
dtype=self.dtype,
use_bias=False,
)(x)

query = (
nn.Dense(
name="query_proj",
features=self.embedding_dim,
kernel_init=initializers.normal(stddev=0.02),
bias_init=initializers.zeros,
dtype=self.dtype,
use_bias=False,
)(x)
)
value = nn.Dense(
name="value_proj",
features=self.embedding_dim,
kernel_init=initializers.normal(stddev=0.02),
bias_init=initializers.zeros,
dtype=self.dtype,
use_bias=False,
)(x)

query = nn.Dense(
name="query_proj",
features=self.embedding_dim,
kernel_init=initializers.normal(stddev=0.02),
bias_init=initializers.zeros,
dtype=self.dtype,
use_bias=False,
)(x)

key = rearrange(
key, "b t (nh hd) -> b nh hd t", nh=self.num_head, hd=C // self.num_head
@@ -158,22 +152,19 @@ def __call__(
)

# get raw attention scores
attn_full = (query @ key) / jnp.sqrt(
key.shape[-1]
) # Shape is (B, nh, sq, sk)
attn_full = (query @ key) / jnp.sqrt(key.shape[-1]) # Shape is (B, nh, sq, sk)

if self.alibi_attn:
# NOTE: We are fixing the ALiBi mask since this is for training; during inference, or if seq_len changes, this will cause issues
attn_full = attn_full + self.alibi_mask[:, :T, :T]

masked_attn = jnp.where(
self.mask, attn_full.astype(
jnp.float32), jnp.finfo(jnp.float32).min
self.mask, attn_full.astype(jnp.float32), jnp.finfo(jnp.float32).min
)

attn_scores = nn.softmax(masked_attn, axis=-1)
attn_scores = dropout()(attn_scores)
attn_out = (attn_scores @ value)
attn_out = attn_scores @ value
attn_out = rearrange(attn_out, "b n t h -> b t (n h)")

out = nn.Dense(
@@ -187,4 +178,6 @@ def __call__(
use_bias=False,
)(attn_out)

out = g_psum(out)

return dropout()(out)
54 changes: 54 additions & 0 deletions src/models/replicated_utils.py
@@ -0,0 +1,54 @@
"""
From https://github.com/kingoflolz/mesh-transformer-jax/blob/master/mesh_transformer/util.py#L99
"""
import jax


# identity in forward pass, psum in backward
@jax.custom_vjp
def f_psum(x):
return x


def f_psum_fwd(x):
return f_psum(x), None


def f_psum_bwd(_, g):
return (jax.lax.psum(g, "mp"),)


f_psum.defvjp(f_psum_fwd, f_psum_bwd)


# identity in forward pass, pmean in backward
@jax.custom_vjp
def f_pmean(x):
return x


def f_pmean_fwd(x):
return f_psum(x), None


def f_pmean_bwd(_, g):
return (jax.lax.pmean(g, "mp"),)


f_pmean.defvjp(f_pmean_fwd, f_pmean_bwd)


# psum in forward pass, identity in backward
@jax.custom_vjp
def g_psum(x):
return jax.lax.psum(x, "mp")


def g_psum_fwd(x):
return g_psum(x), None


def g_psum_bwd(_, g):
return (g,)


g_psum.defvjp(g_psum_fwd, g_psum_bwd)
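
The placement in the layers follows the usual column/row split with explicit collectives: f_psum marks the replicated activations entering a block (identity in the forward pass, psum of the gradients over the model-parallel axis in the backward pass), and g_psum sums the partial outputs of the sharded output projection in the forward pass while leaving the backward pass untouched. A minimal sketch of that pattern under jax.pmap with the hard-coded "mp" axis is shown below; it is not part of the commit, and the toy shapes, the manual weight split, and the import path src.models.replicated_utils are assumptions for illustration only.

import jax
import jax.numpy as jnp

from src.models.replicated_utils import f_psum, g_psum

n_dev = jax.local_device_count()
d_model = 8
d_ff = 4 * n_dev  # hidden dimension split evenly across the "mp" axis

# One slice of each kernel per device: column-parallel in, row-parallel out.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w_in = jax.random.normal(k1, (n_dev, d_model, d_ff // n_dev)) * 0.02
w_out = jax.random.normal(k2, (n_dev, d_ff // n_dev, d_model)) * 0.02


def parallel_mlp(x, w_in, w_out):
    x = f_psum(x)              # identity now; psum of grads over "mp" in backward
    h = jax.nn.gelu(x @ w_in)  # each device computes its slice of the hidden dim
    partial = h @ w_out        # partial contribution to the output
    return g_psum(partial)     # explicit forward psum; identity in backward


x = jnp.ones((n_dev, 4, d_model))  # same (replicated) activations on every device
y = jax.pmap(parallel_mlp, axis_name="mp")(x, w_in, w_out)

# Gradients: f_psum's backward rule sums dL/dx over "mp", so every replica agrees.
loss = lambda x, wi, wo: parallel_mlp(x, wi, wo).sum()
dx = jax.pmap(jax.grad(loss), axis_name="mp")(x, w_in, w_out)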