In [25]:

from transformers.models.mixtral import MixtralConfig, MixtralModel
from accelerate import dispatch_model
# Small Mixtral configuration (2 layers, 16-wide hidden state, 256-token vocab)
# used only to build a model whose module names can be inspected below.
# NOTE(review): `forward` and `num_experts` are passed exactly as written here;
# confirm they are options recognised by the MixtralConfig in use.
mixtral_config = MixtralConfig(
    # Vocabulary and layer widths.
    vocab_size=256,
    hidden_size=16,
    intermediate_size=32,
    # Depth and attention layout.
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    # Mixture-of-experts / forward-pass options.
    forward='base',
    num_experts=8,
)

# Instantiate the model from the configuration above.
mixtral = MixtralModel(mixtral_config)


In [10]:
import regex as re


# Sample dotted name to try the pattern on.
example = "layers.0.block_sparse_moe.experts.0.w1"

# Captures the layer index and the expert index from names of the form
# layers.{layer_id}.block_sparse_moe.experts.{expert_id}[...].
# `match` anchors at the start of the string, so anything may follow the
# expert index but nothing may precede `layers.`.
expert_re = re.compile(
    r"layers\.(?P<layer_id>\d+)\.block_sparse_moe\.experts\.(?P<expert_id>\d+)"
)

# Display the captured ids for the sample name.
expert_re.match(example).groupdict()


{'layer_id': '0', 'expert_id': '0'}

In [26]:


# Number of GPUs the experts are spread over: expert i is assigned to
# cuda:(i % NUM_EXPERT_DEVICES).
NUM_EXPERT_DEVICES = 2
# Device assigned to every module whose name is not under an expert.
DEFAULT_DEVICE = "cuda:0"

# Build a {module name -> device string} table for every named submodule.
# Names that start with layers.<i>.block_sparse_moe.experts.<j> are placed
# round-robin by expert index; everything else goes to DEFAULT_DEVICE.
# NOTE(review): parents and their children both receive entries (e.g. `layers`
# and `layers.0`); confirm the consumer of `mapping` accepts overlapping keys.
mapping = {}
for name, _module in mixtral.named_modules():
    # The root module is reported under the empty name; it gets no entry.
    if name == '':
        continue

    expert_match = expert_re.match(name)
    if expert_match:
        expert_id = int(expert_match.group("expert_id"))
        mapping[name] = f"cuda:{expert_id % NUM_EXPERT_DEVICES}"
    else:
        mapping[name] = DEFAULT_DEVICE
        

In [27]:
# Display the module-name -> device table built in the previous cell.
mapping

{'embed_tokens': 'cuda:0',
 'layers': 'cuda:0',
 'layers.0': 'cuda:0',
 'layers.0.self_attn': 'cuda:0',
 'layers.0.self_attn.q_proj': 'cuda:0',
 'layers.0.self_attn.k_proj': 'cuda:0',
 'layers.0.self_attn.v_proj': 'cuda:0',
 'layers.0.self_attn.o_proj': 'cuda:0',
 'layers.0.self_attn.rotary_emb': 'cuda:0',
 'layers.0.block_sparse_moe': 'cuda:0',
 'layers.0.block_sparse_moe.gate': 'cuda:0',
 'layers.0.block_sparse_moe.experts': 'cuda:0',
 'layers.0.block_sparse_moe.experts.0': 'cuda:0',
 'layers.0.block_sparse_moe.experts.1': 'cuda:1',
 'layers.0.block_sparse_moe.experts.2': 'cuda:0',
 'layers.0.block_sparse_moe.experts.3': 'cuda:1',
 'layers.0.block_sparse_moe.experts.4': 'cuda:0',
 'layers.0.block_sparse_moe.experts.5': 'cuda:1',
 'layers.0.block_sparse_moe.experts.6': 'cuda:0',
 'layers.0.block_sparse_moe.experts.7': 'cuda:1',
 'layers.0.input_layernorm': 'cuda:0',
 'layers.0.post_attention_layernorm': 'cuda:0',
 'layers.1': 'cuda:0',
 'layers.1.self_attn': 'cuda:0',
 'layers.1.self_