Updating MIXTRAL Models
erfanzar committed Dec 19, 2023
1 parent 51a3fb3 commit 2e1dfcd
Showing 5 changed files with 53 additions and 30 deletions.


31 changes: 28 additions & 3 deletions .idea/workspace.xml

Some generated files are not rendered by default.

43 changes: 19 additions & 24 deletions lib/python/EasyDel/modules/mixtral/modelling_mixtral_flax.py
@@ -615,38 +615,33 @@ def __call__(self,
    (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype
)

def index_add_jax(x, dim, index, source):
    """
    Add the elements of a source tensor to the elements of a self tensor at the specified indices.
    Args:
        x (jax.numpy.ndarray): The self tensor to which the elements of the source tensor will be added
        dim (int): The dimension along which to index
        index (jax.numpy.ndarray): The tensor containing the indices for insertion
        source (jax.numpy.ndarray): The tensor containing the elements to add
    """

    # Create a mask tensor that indicates which elements of the self tensor should be updated
    indicator = jnp.zeros_like(x)
    indicator = indicator.at[index].set(source)
    # Add the elements of the source tensor to the corresponding elements of the self tensor
    x = x + indicator

    return x
def custom_index_add_without_index_add(
        final_hidden_states_,
        top_x_,
        idx_,
        current_hidden_states_
):
    # Loop-based stand-in for torch.Tensor.index_add_: add each routed token's
    # expert output to the matching row of the output buffer. JAX arrays are
    # immutable, so the result of .at[...].set(...) must be reassigned.
    for i in range(top_x_.size):
        final_hidden_states_ = final_hidden_states_.at[top_x_[i]].set(
            final_hidden_states_[top_x_[i]] + current_hidden_states_[i]
        )
    return final_hidden_states_

for expert_idx, expert_layer in enumerate(self.layers):
    selected_mask = expert_mask[expert_idx]
    idx, top_x = jnp.where(selected_mask, size=selected_mask.shape[0])

    idx, top_x = jnp.nonzero(selected_mask, size=selected_mask.shape[-1])
    top_x = jnp.where(top_x != 0, top_x, -1)
    if top_x.shape[0] == 0:
        continue

    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)

    current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

    final_hidden_states = index_add_jax(
    final_hidden_states = custom_index_add_without_index_add(
        final_hidden_states,
        0,
        top_x,
        idx,
        current_hidden_states.astype(hidden_states.dtype)
    )
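For reference, a minimal sketch of the same scatter-add written with JAX's functional index update .at[...].add(...), which accumulates duplicate destination rows without a Python loop; the array names and shapes below are illustrative, not taken from the module:

import jax.numpy as jnp

# Illustrative shapes: 6 routed tokens, hidden_dim = 4, 10 output rows.
final_hidden_states = jnp.zeros((10, 4))
top_x = jnp.array([0, 3, 3, 7, 7, 9])           # destination row for each routed token
current_hidden_states = jnp.full((6, 4), 0.5)   # expert output for each routed token

# .at[rows].add(values) is a functional scatter-add: duplicate rows (3 and 7
# above) accumulate rather than overwrite, matching the semantics of
# torch.Tensor.index_add_ along dimension 0.
final_hidden_states = final_hidden_states.at[top_x].add(current_hidden_states)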

@@ -690,7 +685,7 @@ def __call__(self, hidden_states: chex.Array) -> Tuple[chex.Array, chex.Array]:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_dim)
router_logits = self.gate(hidden_states).astype(jnp.promote_types(self.dtype, jnp.float32))
routing_weights = jax.nn.softmax(router_logits, axis=1)
routing_weights = jax.nn.softmax(router_logits.astype(jnp.promote_types(self.dtype, jnp.float32)), axis=1)
routing_weights, selected_experts = jax.lax.top_k(routing_weights, k=self.config.num_experts_per_tok)
routing_weights /= jnp.sum(routing_weights, axis=-1, keepdims=True)
routing_weights = routing_weights.astype(hidden_states.dtype)
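The gating path above can be traced with a small numeric sketch; the logits are made up, and only the softmax, top_k, and renormalization steps mirror the code (num_experts_per_tok is assumed to be 2 here):

import jax
import jax.numpy as jnp

# One token routed over 4 experts (toy logits).
router_logits = jnp.array([[2.0, 0.5, 1.0, -1.0]])
routing_weights = jax.nn.softmax(router_logits, axis=-1)   # ~[[0.61, 0.14, 0.22, 0.03]]
# Keep the 2 largest weights per token, then renormalize them to sum to 1.
routing_weights, selected_experts = jax.lax.top_k(routing_weights, k=2)
routing_weights /= jnp.sum(routing_weights, axis=-1, keepdims=True)
# selected_experts -> [[0, 2]], routing_weights -> roughly [[0.73, 0.27]]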
@@ -781,6 +776,7 @@ def __call__(
    init_cache=init_cache,
    output_attentions=output_attentions
)

hidden_states = residual + hidden_states

residual = hidden_states
@@ -913,7 +909,7 @@ def __init__(
super().__init__(
    dtype=dtype, _do_init=_do_init,
    module=module, config=config, input_shape=input_shape,
    seed=seed
    seed=seed,
)

def init_weights(
@@ -934,7 +930,6 @@ def init_weights(
:param params: flax.core.FrozenDict: Pass in the parameters of a pre-trained model
:return: A frozendict of parameters
"""

input_ids = jnp.zeros(input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids, dtype="i4")
position_ids = jnp.broadcast_to(
2 changes: 1 addition & 1 deletion lib/python/EasyDel/transform/easydel_transform.py
@@ -102,7 +102,7 @@ def huggingface_to_easydel(
for embedding_layer_name in embedding_layer_names:
if embedding_layer_name in key:
key = key[:-_l] + '.embedding'
elif match_keywords(key, ['kernel'], ['none']):
elif match_keywords(key, ['weight'], ['none']):
if len(tensor.shape) == 2:
tensor = tensor.transpose(0, 1)
if key.endswith('.weight'):
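The transpose applied to 2-D tensors reflects the usual layout difference between PyTorch and Flax linear layers: torch.nn.Linear stores weight as (out_features, in_features), while flax.linen.Dense expects kernel as (in_features, out_features). A minimal sketch of that convention under those standard layouts (the layer sizes and names are illustrative, not the library's conversion code):

import numpy as np
import torch
import jax.numpy as jnp
import flax.linen as nn

torch_linear = torch.nn.Linear(3, 5, bias=False)               # weight shape (5, 3)
kernel = jnp.asarray(torch_linear.weight.detach().numpy().T)   # kernel shape (3, 5)

flax_dense = nn.Dense(5, use_bias=False)
x = np.random.randn(2, 3).astype(np.float32)
# Dense computes x @ kernel while Linear computes x @ weight.T, so the
# transposed weight reproduces the same outputs.
np.testing.assert_allclose(
    np.asarray(flax_dense.apply({'params': {'kernel': kernel}}, jnp.asarray(x))),
    torch_linear(torch.from_numpy(x)).detach().numpy(),
    rtol=1e-5, atol=1e-6,
)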
3 changes: 1 addition & 2 deletions python_test/mixtral_flax.py
@@ -29,12 +29,11 @@ def main():
    hidden_size=128,
    num_attention_heads=8,
    num_key_value_heads=4,
    num_hidden_layers=4,
    num_hidden_layers=1,
    intermediate_size=256,
    gradient_checkpointing='',
    max_position_embeddings=seq_len
)
print('Model Config :\n', config)
batch_size = len(jax.devices())

torch_model = MixtralForCausalLM(
