
# Mixture-of-Experts - Achieve Massively Scaled, but Efficient, LLM Performance
In this notebook, we'll explore how to build our own simplified version of a mixture-of-experts (MoE) LLM system. This method usually involves complex training and a specialized transformer configuration. Even so, we can see some of its benefits in a pseudo-MoE that we'll build from a few open-source LLMs.

By the end of this notebook, you will be able to:
1. Create your own MoE system using open-source LLMs
1. Build different gating mechanisms that direct prompts to the appropriate "expert models"

In [None]:
import torch

# The Pseudo MoE Model
We'll implement a simplified version of an MoE model. Instead of training the experts and the gating function together, we'll use pre-trained transformer models as our experts and a simple rule-based function as our gating function.
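
For contrast, here is a minimal sketch of what training the experts and the gating function together typically looks like inside a transformer. A learned router, which is just an ordinary linear layer, scores every expert for each token. Only the top-k experts run, and their outputs are mixed using the renormalized router weights. This is an illustrative PyTorch sketch, not the method used in this notebook. The class name `SparseMoE` and its parameters are hypothetical. Production MoE layers also add pieces omitted here, such as a load-balancing loss and per-expert capacity limits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseMoE(nn.Module):
    """Hypothetical toy MoE feed-forward layer: a learned router picks the top-k experts per token."""

    def __init__(self, d_model: int, d_hidden: int, num_experts: int, k: int = 2):
        super().__init__()
        self.k = k
        # The router (gating network) is trained jointly with the experts
        self.router = nn.Linear(d_model, num_experts)
        # Each expert is a small two-layer feed-forward network
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
             for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (num_tokens, d_model)
        logits = self.router(x)                          # (num_tokens, num_experts)
        top_vals, top_idx = logits.topk(self.k, dim=-1)  # keep the k best experts per token
        gates = F.softmax(top_vals, dim=-1)              # renormalize over the chosen k
        out = torch.zeros_like(x)
        for slot in range(self.k):
            for e, expert in enumerate(self.experts):
                # Tokens whose slot-th choice is expert e
                idx = (top_idx[:, slot] == e).nonzero(as_tuple=True)[0]
                if idx.numel() > 0:
                    weighted = gates[idx, slot].unsqueeze(-1) * expert(x[idx])
                    out = out.index_add(0, idx, weighted)  # scatter results back to their tokens
        return out


torch.manual_seed(0)
layer = SparseMoE(d_model=16, d_hidden=32, num_experts=4, k=2)
tokens = torch.randn(5, 16)                   # 5 hypothetical token embeddings
out = layer(tokens)
print(out.shape)                              # torch.Size([5, 16])
out.sum().backward()                          # gradients reach the router and the chosen experts together
print(layer.router.weight.grad is not None)   # True
```

Because the router is just another `nn.Linear`, a single backward pass updates the router and the selected experts at the same time. The pseudo-MoE in this notebook replaces that joint training with pre-trained models and a hand-written routing rule.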

We'll also look at different types of gating mechanisms: hard gating, soft gating, and top-k gating.
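
The three gating styles differ only in how they turn per-expert scores into weights. The sketch below assumes a hypothetical score vector for a single prompt. In a trained MoE these scores would come from a learned router; in this notebook they come from hand-written rules. The sketch applies each of the three rules to that vector. The function names are illustrative and are not part of any library.

```python
import torch
import torch.nn.functional as F


def hard_gate(scores: torch.Tensor) -> torch.Tensor:
    """All of the weight goes to the single highest-scoring expert."""
    weights = torch.zeros_like(scores)
    weights[scores.argmax()] = 1.0
    return weights


def soft_gate(scores: torch.Tensor) -> torch.Tensor:
    """Every expert gets a share of the weight; the shares sum to 1."""
    return F.softmax(scores, dim=0)


def top_k_gate(scores: torch.Tensor, k: int) -> torch.Tensor:
    """Softmax over the k best experts only; every other expert gets exactly zero."""
    top_vals, top_idx = scores.topk(k)
    weights = torch.zeros_like(scores)
    weights[top_idx] = F.softmax(top_vals, dim=0)
    return weights


scores = torch.tensor([2.0, 0.5, 1.0])  # hypothetical affinity of one prompt for 3 experts
print(hard_gate(scores))                # tensor([1., 0., 0.])
print(soft_gate(scores))
print(top_k_gate(scores, k=2))
```

Top-k gating sits between the other two approaches. Like hard gating, it skips low-scoring experts entirely, which saves compute. Like soft gating, it still blends the outputs of more than one expert.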

In [None]:
# Import the necessary libraries
# transformers is a state-of-the-art library for Natural Language Processing tasks, providing a wide range of pre-trained models
from transformers import GPT2LMHeadModel, GPT2Tokenizer, BertForSequenceClassification, BertTokenizer, T5ForConditionalGeneration, T5Tokenizer
# torch.nn.functional provides functions that don't have any parameters, such as activation functions, loss functions etc.
import torch.nn.functional as F

# Load the GPT-2 XL model and tokenizer
# GPT-2 is an autoregressive language model built from transformer decoder blocks; it tokenizes text with byte-pair encoding
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2-xl", cache_dir=DA.paths.datasets+"/models")
# The tokenizer is responsible for turning input data into the format that the model expects
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl", cache_dir=DA.paths.datasets+"/models")

# Load the BERT model and tokenizer
# BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based machine learning technique for natural language processing pre-training
bert = BertForSequenceClassification.from_pretrained("bert-base-uncased", cache_dir=DA.paths.datasets+"/models")
# The tokenizer is responsible for turning input data into the format that the model expects
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", cache_dir=DA.paths.datasets+"/models")

# Load the T5 model and tokenizer
# T5 (Text-to-Text Transfer Transformer) is a transformer model which treats every NLP problem as a text generation task
t5 = T5ForConditionalGeneration.from_pretrained("t5-base", cache_dir=DA.paths.datasets+"/models")
# The tokenizer is responsible for turning input data into the format that the model expects
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base", cache_dir=DA.paths.datasets+"/models")

# Define the "hard gating" function
# This function picks exactly one model based on the length of the input (in characters)
def hard_gating_function(input):
    if len(input) < 10:
        # Inputs shorter than 10 characters go to the GPT-2 model
        return "gpt2", gpt2, gpt2_tokenizer
    elif len(input) < 100:
        # Inputs of 10 to 99 characters go to the T5 model
        return "t5", t5, t5_tokenizer
    else:
        # Inputs of 100 characters or more go to the BERT model
        return "bert", bert, bert_tokenizer

# Define the "soft gating" function
# This function gives every model a weight derived from the length of the input (in characters)
def soft_gating_function(input):
    # Softmax turns the raw length-based scores into weights that sum to 1.
    # Because the scores are raw character counts, the weights are usually close to one-hot:
    # inputs longer than 50 characters favor GPT-2 and T5 equally, shorter inputs favor BERT
    weights = F.softmax(torch.tensor([len(input), 100 - len(input), len(input)], dtype=torch.float), dim=0)
    # Return each model together with its tokenizer and weight
    return {"gpt2": (gpt2, gpt2_tokenizer, weights[0]),
            "bert": (bert, bert_tokenizer, weights[1]),
            "t5": (t5, t5_tokenizer, weights[2])}

# Define the pseudo MoE model
# This function uses the gating function to decide which model(s) to use for a given input
def pseudo_moe_model(input, gating_function):
    if gating_function == "hard":
        # With hard gating, exactly one model handles the input
        model_name, model, tokenizer = hard_gating_function(input)
        inputs = tokenizer(input, return_tensors="pt")
        with torch.no_grad():  # inference only, so skip gradient tracking
            if model_name == "t5":
                # T5's decoder starts from decoder_start_token_id (its <pad> token); T5 has no <BOS> token
                decoder_inputs = torch.tensor([[model.config.decoder_start_token_id]])
                outputs = model(**inputs, decoder_input_ids=decoder_inputs)
            else:
                outputs = model(**inputs)
        # Greedily take the highest-scoring id at each position and decode it into a string.
        # BertForSequenceClassification returns class logits, so for BERT the "token" is really a class index.
        # reshape(-1) guarantees a list of ids even when argmax returns a single value.
        decoded_output = tokenizer.decode(outputs.logits[0].argmax(-1).reshape(-1).tolist())
        # Return the name of the model used and its decoded output
        return model_name, decoded_output
    else:  # soft gating
        # With soft gating, every model runs and receives a weight
        models = soft_gating_function(input)
        outputs = []
        for model_name, (model, tokenizer, weight) in models.items():
            inputs = tokenizer(input, return_tensors="pt")
            with torch.no_grad():
                if model_name == "t5":
                    decoder_inputs = torch.tensor([[model.config.decoder_start_token_id]])
                    output = model(**inputs, decoder_input_ids=decoder_inputs)
                else:
                    output = model(**inputs)
            # Scale each model's logits by its weight, and keep that model's own tokenizer for decoding
            outputs.append((model_name, tokenizer, output.logits * weight))
        # The models have different vocabularies, so their logits cannot simply be added together;
        # instead, each weighted output is decoded separately with its own tokenizer
        decoded_outputs = [(model_name, tokenizer.decode(logits[0].argmax(-1).reshape(-1).tolist()))
                           for model_name, tokenizer, logits in outputs]
        return decoded_outputs


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
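
The soft gating function above feeds raw character counts into a softmax. Because those counts are large numbers, the resulting weights are nearly one-hot rather than a gentle blend. The sketch below reproduces the same scores for the lengths of the two test prompts in the next cell (43 and 128 characters). As an assumption not taken from the notebook, it also divides the scores by a hypothetical `temperature` before the softmax to show how that spreads the weight across experts.

```python
import torch
import torch.nn.functional as F


def length_scores(n_chars: int) -> torch.Tensor:
    # Same raw scores as soft_gating_function, in the order [gpt2, bert, t5]
    return torch.tensor([n_chars, 100 - n_chars, n_chars], dtype=torch.float)


def soft_weights(n_chars: int, temperature: float = 1.0) -> torch.Tensor:
    # temperature=1.0 matches the notebook; a larger (hypothetical) temperature flattens the weights
    return F.softmax(length_scores(n_chars) / temperature, dim=0)


for n in (43, 128):
    print(n, soft_weights(n), soft_weights(n, temperature=50.0))
```

With the notebook's setting (`temperature=1.0`), the 43-character prompt puts almost all of its weight on BERT. The 128-character prompt splits its weight evenly between GPT-2 and T5, and BERT's weight is effectively zero.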


In [None]:
# Test the hard gating function
example_1 = "Translate to german: This is a short input."
output = pseudo_moe_model(example_1, gating_function="hard")
print("Hard gating output:", output)

# Test the soft gating function
# Note: despite what the sentence says, this string is 128 characters long
example_2 = "This is a longer input. We're adding more text here to make sure it's longer than 50 characters but shorter than 100 characters."
output = pseudo_moe_model(example_2, gating_function="soft")
print("Soft gating output:", output)


Hard gating output: ('t5', '<extra_id_0>This')
Soft gating output: [('gpt2', '—ation says reported After și readingatoration anyone each of given dem off usinguk After E valuable of ou Aftereaza of of și'), ('bert', '<pad>'), ('t5', '<extra_id_0>We')]
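
In `pseudo_moe_model`, soft gating multiplies each expert's logits by its weight and then takes an argmax. Multiplying by a positive number never changes which entry is largest, so the weights do not alter the text that each expert produces. Only a weight of exactly zero has an effect: it turns every logit into 0, and `argmax` then returns the first index, 0. This is consistent with the BERT entry in the soft-gating output above being only a padding token, since BERT's weight for the 128-character prompt is effectively zero. The check below uses hypothetical random logits to demonstrate both effects.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 10)  # hypothetical logits: 4 positions over a 10-token vocabulary
greedy = logits.argmax(-1)

# Positive weights (powers of two, so the scaling is exact) leave the argmax unchanged
for weight in (2.0, 0.5, 0.25):
    print(weight, torch.equal((logits * weight).argmax(-1), greedy))

# A weight of exactly zero makes every logit 0, so argmax falls back to the first index
print((logits * 0.0).argmax(-1))
```

For the weights to matter, the experts' outputs would have to be combined in a shared space, for example by choosing among complete responses. Scaling logits from different vocabularies and decoding each one separately cannot achieve that.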
