In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
from torch import nn
import torch.nn.functional as F
import copy


from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
import numpy as np
import matplotlib.pyplot as plt
from transformers import PreTrainedModel
import seaborn as sns
from typing import List

In [3]:
import os

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
CACHE_DIR = '/scratch/' + os.environ['USER'] + '/huggingface_cache'
MODEL_NAME = "microsoft/phi-2"
# MODEL_NAME = "meta-llama/Llama-2-7b-hf"
HF_TOKEN = "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"
TORCH_DTYPE = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=TORCH_DTYPE, device_map="cuda", cache_dir=CACHE_DIR, token=HF_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, token=HF_TOKEN)
model.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2560)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (dense): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2560, out_features=10240, bias=True)
          (fc2): Linear(in_features=10240, out_features=2560, bias=True)
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2560,),

In [6]:
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)

    
    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        #Noise logits
        noise_logits = self.noise_linear(mh_output)

        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices


class SparseMoE(nn.Module):
    def __init__(self, mlp, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([copy.deepcopy(mlp).to(DEVICE) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output additively by indexing and adding
                final_output[expert_mask] += weighted_output.squeeze(1)

        return final_output

In [7]:
def create_moe(model: PreTrainedModel, num_experts=4, top_k=2):
    # check if model has already been converted to MoE
    for name, module in model.named_modules():
        if isinstance(module, SparseMoE):
            print("Model is already converted to MoE.")
            return model

    # 1. get embedding size
    attention = model.model.layers[0].self_attn
    n_embed = attention.q_proj.in_features
    print("n_embed: ", n_embed)

    # 2. Find the existing MLP module
    existing_mlp = None
    for name, module in model.named_modules():
        if name == "model.layers.1.mlp":
            existing_mlp = module
            break

    if existing_mlp is None:
        raise ValueError("Could not find the MLP module in the model.")

    print("Existing MLP module:", type(existing_mlp))

    # 3. Build the SparseMoE model
    moe = SparseMoE(existing_mlp, n_embed, num_experts, top_k)
    moe.to(DEVICE)
    print(moe)

    # 4. Replace the existing MLP module with the SparseMoE model
    # Collect the names and modules to be replaced
    modules_to_replace = {}
    for name, module in model.named_modules():
        if isinstance(module, type(existing_mlp)):
            modules_to_replace[name] = module

    # Replace the modules
    for name, module in modules_to_replace.items():
        # Find the parent module and replace the original module with the new one
        if '.' in name:
            parent_name, child_name = name.rsplit('.', 1)
            parent_module = dict(model.named_modules())[parent_name]
            setattr(parent_module, child_name, moe)
        else:
            # For top-level modules
            setattr(model, name, moe)

    return model

moe_model = create_moe(model)
print(moe_model.to(DEVICE))
assert next(moe_model.parameters()).is_cuda


n_embed:  2560
Existing MLP module: <class 'transformers.models.phi.modeling_phi.PhiMLP'>
SparseMoE(
  (router): NoisyTopkRouter(
    (topkroute_linear): Linear(in_features=2560, out_features=4, bias=True)
    (noise_linear): Linear(in_features=2560, out_features=4, bias=True)
  )
  (experts): ModuleList(
    (0-3): 4 x PhiMLP(
      (activation_fn): NewGELUActivation()
      (fc1): Linear(in_features=2560, out_features=10240, bias=True)
      (fc2): Linear(in_features=10240, out_features=2560, bias=True)
    )
  )
)
PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2560)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (dense): Linea

In [8]:
inputs = tokenizer('''def print_prime(n):
   """
   Print all primes between 1 and n
   """''', return_tensors="pt", return_attention_mask=False).to(DEVICE)

In [9]:
# outputs = model.generate(**inputs, max_length=200)
print(inputs)
outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
print(text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


{'input_ids': tensor([[ 4299,  3601,    62, 35505,     7,    77,  2599,   198, 50285, 37811,
           198, 50285, 18557,   477,   778,   999,  1022,   352,   290,   299,
           198, 50285, 37811]], device='cuda:0')}
def print_prime(n):
   """
   Print all primes between 1 and n
   """ibusultnar
pricallys the V agprlordsedingingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsingsings
