In [1]:
from typing import Dict

from transformers.models.dupxtral.configuration_dupxtral import DupxtralConfig
from transformers.models.mixtral.configuration_mixtral import MixtralConfig

from transformers.models.dupxtral.modeling_dupxtral import DupxtralModel
from transformers.models.mixtral.modeling_mixtral import MixtralModel

import torch
from torch import nn
from torch.nn import functional as F




In [2]:


# Number of experts per MoE layer and number of hidden layers of the dummy models
N_expert = 8
N_hidden_layers = 4


# duplicate_experts[layer][expert] is the number of copies made of that expert.
# All layers have the same duplication pattern (each layer owns its own list object).
duplicate_experts = [[3, 2, 2, 1, 3, 1, 1, 1] for _ in range(N_hidden_layers)]

# Later cells index duplicate_experts[layer][expert] for every expert in range(N_expert)
# and draw a random duplicate with torch.randint(0, count, ...), so every pattern must
# hold exactly N_expert counts and every count must be at least 1.
assert all(len(pattern) == N_expert for pattern in duplicate_experts), \
    "each duplication pattern needs one entry per expert"
assert all(count >= 1 for pattern in duplicate_experts for count in pattern), \
    "each expert must be kept at least once"



In [3]:

# We remap the experts randomly: each initial expert usage is remapped to a random duplicate.
# pair_remaping[layer][(i, j)] == (new_i, new_j), where i / j are original expert indices and
# new_i / new_j are indices in the duplicated expert numbering of that layer.

# A dedicated, seeded generator makes the remapping reproducible after a kernel restart
# without modifying torch's global RNG state.
REMAP_SEED = 0
remap_generator = torch.Generator().manual_seed(REMAP_SEED)

pair_remaping = [dict() for _ in range(N_hidden_layers)]

for layer_idx in range(N_hidden_layers):
    dup_counts = duplicate_experts[layer_idx]
    # first_dup[e] is the new index of the first duplicate of original expert e:
    # the duplicates of an expert occupy a contiguous range of new indices.
    first_dup = [sum(dup_counts[:e]) for e in range(N_expert)]

    for i in range(N_expert):
        for j in range(N_expert):
            # random remapping: one duplicate is drawn independently for i and for j
            new_i = first_dup[i] + torch.randint(0, dup_counts[i], (1,), generator=remap_generator).item()
            new_j = first_dup[j] + torch.randint(0, dup_counts[j], (1,), generator=remap_generator).item()

            pair_remaping[layer_idx][(i, j)] = (new_i, new_j)


In [4]:


# We initialize two dummy configurations: one for a normal Mixtral and one for a Dupxtral
# with the same dimensions plus the expert duplication / remapping tables.

config_dupxtral = DupxtralConfig(
    vocab_size=128,
    hidden_size=16,
    intermediate_size=32,
    num_hidden_layers=N_hidden_layers,
    num_attention_heads=8,
    num_local_experts=N_expert,
    # Dupxtral-specific: copies per expert and (i, j) -> (new_i, new_j) remapping, per layer
    experts_duplicate=duplicate_experts,
    experts_remapping=pair_remaping,
)

config_mixtral = MixtralConfig(
    vocab_size=128,
    hidden_size=16,
    intermediate_size=32,
    num_hidden_layers=N_hidden_layers,
    num_attention_heads=8,
    num_local_experts=N_expert,
    # NOTE(review): `forward` is only passed to the Mixtral config - confirm this is intended
    forward='parallel',
)


In [5]:
# We build the models' architecture and initialize the weights randomly.
# Both models are constructed directly from their configs; no checkpoint is loaded here.

initial_model = MixtralModel(config=config_mixtral)

dupxtral_model = DupxtralModel(config=config_dupxtral)

In [6]:

# Get the state dict of the initial model (weight name -> tensor)
initial_state_dict: Dict[str, torch.Tensor] = initial_model.state_dict()



# A lot of machinery for the very simple task of duplicating the expert weights under their new expert indices

def get_parts(k):
    '''
    Split a dotted weight name around its 'block_sparse_moe' component.

    The name is expected to look like
    `<prefix>.<layer_idx>.block_sparse_moe.experts.<expert_idx>.<suffix>`.

    :param k: dotted name (a state-dict key or the top-level path of an expert)
    :return: tuple (layer_idx, expert_idx, prefix, suffix); the two indices are ints,
             prefix is everything before the layer index and suffix everything after
             the expert index (empty string when nothing follows)
    :raises ValueError: if 'block_sparse_moe' is not one of the components, or an index is not an integer
    '''
    tokens = k.split('.')
    moe_pos = tokens.index('block_sparse_moe')

    # The layer index sits right before 'block_sparse_moe'; the expert index sits
    # two components after it (past the 'experts' component).
    layer_pos = moe_pos - 1
    expert_pos = moe_pos + 2

    return (
        int(tokens[layer_pos]),
        int(tokens[expert_pos]),
        '.'.join(tokens[:layer_pos]),
        '.'.join(tokens[expert_pos + 1:]),
    )


def get_experts_paths(initial_state_dict):
    '''
    Collect the top-level path of every expert block present in the state dict.

    Each expert weight name points at a child of an expert (in practice w1, w2, w3);
    the child part is dropped so that only
    `<prefix>.<layer_idx>.block_sparse_moe.experts.<expert_idx>` remains.

    :param initial_state_dict: state dict whose keys are dotted weight names
    :return: Set[str] of the distinct expert top-level paths
    '''
    def top_level_path(weight_name):
        # Rebuild the path up to (and including) the expert index, dropping the children
        layer_idx, expert_idx, prefix, _ = get_parts(weight_name)
        return f"{prefix}.{layer_idx}.block_sparse_moe.experts.{expert_idx}"

    expert_keys = sorted(name for name in initial_state_dict.keys() if 'block_sparse_moe.experts' in name)

    return {top_level_path(name) for name in expert_keys}


def get_all_paths_for_experts(initial_state_dict):
    '''
    Group the expert weight names of the state dict by expert block.

    :param initial_state_dict: state dict whose keys are dotted weight names
    :return: Dict[str, List[str]] where the key is the path to the top level of the expert block
             and the value is a list of the children paths (relative to the key, in state-dict order)
    '''
    experts_paths = get_experts_paths(initial_state_dict)

    all_paths = {}

    for expert_path in experts_paths:
        # Match "<expert_path>." as a prefix, not as a substring: a substring test lets
        # "...experts.1" also capture the keys of "...experts.10", "...experts.11", ...
        # and slices them into bogus children such as ".w1.weight".
        child_prefix = expert_path + '.'
        all_paths[expert_path] = [k[len(child_prefix):]
                                  for k in initial_state_dict.keys() if k.startswith(child_prefix)]

    return all_paths


def create_duplicated_names(all_paths, experts_duplication):
    '''
    Build the new names for the duplicated experts and their children

    expert_id | new_id
    0         | 0
    0         | 1
    0         | 2
    1         | 3
    1         | 4
    2         | 5
    2         | 6
    ...

    Expert 0 is duplicated 3 times, expert 1 is duplicated 2 times, expert 2 is duplicated 2 times, etc.

    New ids are handed out per layer, in increasing numeric order of the original expert index.

    :param all_paths: Dict[str, List[str]] mapping the top-level path of each expert to its children paths
    :param experts_duplication: experts_duplication[layer_idx][expert_idx] is the number of copies of that expert
    :return: Dict[str, str] where the key is the new name and the value is the old name
    '''
    # Parse every expert path once so the experts can be ordered by their numeric indices.
    parsed_experts = []
    for expert_path, paths in all_paths.items():
        layer_idx, expert_idx, prefix, _ = get_parts(expert_path)
        parsed_experts.append((prefix, layer_idx, expert_idx, expert_path, paths))

    # Sorting the raw path strings is lexicographic ("experts.10" < "experts.2") and would
    # hand out new ids in the wrong order as soon as a layer has more than 10 experts.
    parsed_experts.sort(key=lambda item: item[:3])

    # Next free new expert index, tracked separately for each (prefix, layer)
    next_new_idx = {}

    expert_weights_path: Dict[str, str] = {}

    for prefix, layer_idx, expert_idx, expert_path, paths in parsed_experts:
        layer_key = (prefix, layer_idx)
        layer_expert_idx = next_new_idx.get(layer_key, 0)

        for _ in range(experts_duplication[layer_idx][expert_idx]):
            current_expert_path = f"{prefix}.{layer_idx}.block_sparse_moe.experts.{layer_expert_idx}"
            layer_expert_idx += 1
            # every child of the original expert is exposed under the duplicate's path
            for p in paths:
                expert_weights_path[f"{current_expert_path}.{p}"] = f"{expert_path}.{p}"

        next_new_idx[layer_key] = layer_expert_idx

    return expert_weights_path


def convert_initial_state_dict_to_dupxtral(initial_state_dict, config_dupxtral, verbose=True):
    '''
    Convert the initial state_dict to the dupxtral model state_dict

    Non-expert weights are copied under their original names. Expert weights are exposed
    under the new (duplicated) expert indices. Tensors are not cloned: every duplicate of
    an expert references the same tensor object as the original entry.

    :param initial_state_dict: state dict of the initial model, Dict[str, torch.Tensor]
    :param config_dupxtral: config providing `experts_duplicate`, the per-layer number of copies of each expert
    :param verbose: when True (default), print each "new_name old_name" pair as it is mapped
    :return: Dupxtral model state_dict Dict[str, torch.Tensor]
    '''
    new_state_dict = {}

    # copy everything from the initial state_dict but the expert weights
    # (tested directly on the key instead of scanning a list of expert names for every key)
    for k, v in initial_state_dict.items():
        if 'block_sparse_moe.experts' not in k:
            new_state_dict[k] = v

    all_paths = get_all_paths_for_experts(initial_state_dict)

    expert_weights_path = create_duplicated_names(all_paths, config_dupxtral.experts_duplicate)

    # actually duplicate the weights based on new and old names
    for new_name, old_name in expert_weights_path.items():
        new_state_dict[new_name] = initial_state_dict[old_name]
        if verbose:
            print(new_name, old_name)

    return new_state_dict


# Build the Dupxtral state dict from the initial model's weights (experts duplicated as configured)
new_state_dict = convert_initial_state_dict_to_dupxtral(initial_state_dict, config_dupxtral)



layers.0.block_sparse_moe.experts.0.w1.weight layers.0.block_sparse_moe.experts.0.w1.weight
layers.0.block_sparse_moe.experts.0.w2.weight layers.0.block_sparse_moe.experts.0.w2.weight
layers.0.block_sparse_moe.experts.0.w3.weight layers.0.block_sparse_moe.experts.0.w3.weight
layers.0.block_sparse_moe.experts.1.w1.weight layers.0.block_sparse_moe.experts.0.w1.weight
layers.0.block_sparse_moe.experts.1.w2.weight layers.0.block_sparse_moe.experts.0.w2.weight
layers.0.block_sparse_moe.experts.1.w3.weight layers.0.block_sparse_moe.experts.0.w3.weight
layers.0.block_sparse_moe.experts.2.w1.weight layers.0.block_sparse_moe.experts.0.w1.weight
layers.0.block_sparse_moe.experts.2.w2.weight layers.0.block_sparse_moe.experts.0.w2.weight
layers.0.block_sparse_moe.experts.2.w3.weight layers.0.block_sparse_moe.experts.0.w3.weight
layers.0.block_sparse_moe.experts.3.w1.weight layers.0.block_sparse_moe.experts.1.w1.weight
layers.0.block_sparse_moe.experts.3.w2.weight layers.0.block_sparse_moe.experts.

In [7]:
# Load the converted weights into the Dupxtral model; the displayed result reports key mismatches, if any
dupxtral_model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [8]:

# fake tokenized input for testing: random token ids of shape (3, 11)
# (presumably (batch, sequence length)), drawn below the configured vocabulary size
input_ids = torch.randint(0, config_mixtral.vocab_size, (3, 11))


In [9]:
# Output with dupxtral model
# (displays the first 9 values of last_hidden_state at index [0, 0])

output = dupxtral_model(input_ids)
output.last_hidden_state[0, 0, :9]
    


tensor([ 1.1253,  0.9082,  0.2913, -1.3833,  0.6576, -0.1448, -0.7565,  1.4091,
        -1.2657], grad_fn=<SliceBackward0>)

In [10]:
# Output with initial model on the same input
# (same slice as for the dupxtral model, so the two displays can be compared by eye)

output_initial = initial_model(input_ids)
output_initial.last_hidden_state[0, 0, :9]


tensor([ 1.1253,  0.9082,  0.2913, -1.3833,  0.6576, -0.1448, -0.7565,  1.4091,
        -1.2657], grad_fn=<SliceBackward0>)

In [11]:
# single token: one random token id of shape (1, 1), drawn below the configured vocabulary size
# (this rebinds input_ids, replacing the earlier (3, 11) input)
input_ids = torch.randint(0, config_mixtral.vocab_size, (1, 1))


In [12]:

# Output with dupxtral model on the single-token input
# (displays the first 9 values of last_hidden_state at index [0, 0])
output = dupxtral_model(input_ids)
output.last_hidden_state[0, 0, :9]

tensor([-0.3803,  1.2390, -0.5362, -1.1918,  0.2402, -1.7913, -1.1191,  0.9954,
        -1.5569], grad_fn=<SliceBackward0>)

In [13]:
# Output with initial model on the same single-token input
# (same slice as for the dupxtral model, so the two displays can be compared by eye)
output_initial = initial_model(input_ids)
output_initial.last_hidden_state[0, 0, :9]



tensor([-0.3803,  1.2390, -0.5362, -1.1918,  0.2402, -1.7913, -1.1191,  0.9954,
        -1.5569], grad_fn=<SliceBackward0>)