In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import io

import transformers  
from transformers import PretrainedConfig, PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel

from huggingface_hub import login, whoami

from vllm.entrypoints.llm import LLM
from vllm import ModelRegistry

  from .autonotebook import tqdm as notebook_tqdm


INFO 06-24 18:06:27 [__init__.py:244] Automatically detected platform cuda.


2025-06-24 18:06:31,256	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
draft_name = "/home/sagemaker-user/efs/model/Qwen3-0.6B"
target_name = "/home/sagemaker-user/efs/model/Qwen3-8B"

In [29]:
class EnsembleConfig(PretrainedConfig):
    model_type = "customize_ensemble"
    
    def __init__(
        self,
        hidden_size=4096,
        vocab_size=151936,
        target_model_path=None,
        draft_model_path=None,
        trust_remote_code=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ensemble_hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.target_model_path = target_model_path
        self.draft_model_path = draft_model_path
        self.trust_remote_code=trust_remote_code

class EnsembleModel(PreTrainedModel):
    config_class = EnsembleConfig

    def __init__(self, config: EnsembleConfig):
        super().__init__(config)
        
        self.target_model = AutoModel.from_pretrained(config.target_model_path)
        self.draft_model = AutoModel.from_pretrained(config.draft_model_path)
        
        self.ensemble_head = nn.Linear(config.ensemble_hidden_size, 2, bias=False)


    def forward(self, input_ids, **kwargs):
        draft_logits = self.draft_model(input_ids, **kwargs).logits  # [B, T, V]
        target_output = self.target_model(input_ids, output_hidden_states=True, **kwargs)
        target_logits = target_output.logits                          # [B, T, V]
        last_hidden = target_output.hidden_states[-1]                # [B, T, H]

        # Compute ensemble weights
        weights = F.softmax(self.ensemble_head(last_hidden), dim=-1)  # [B, T, 2]

        # Expand weights to match logits shape
        w_draft = weights[..., 0].unsqueeze(-1)  # [B, T, 1]
        w_target = weights[..., 1].unsqueeze(-1)

        # Weighted logits
        ensemble_logits = w_draft * draft_logits + w_target * target_logits

        return CausalLMOutput(logits=ensemble_logits)

In [21]:
def create_ensemble_model(target_name, draft_name):

    target_config = AutoConfig.from_pretrained(target_name)
    draft_config = AutoConfig.from_pretrained(draft_name)
    hidden_size = target_config.vocab_size+draft_config.vocab_size+target_config.hidden_size
    config = EnsembleConfig(hidden_size=hidden_size, vocab_size=target_config.vocab_size, 
                            target_model_path=target_name, draft_model_path=draft_name)
    
    return EnsembleModel(config), config

In [22]:
AutoConfig.register("customize_ensemble", EnsembleConfig)
AutoModel.register(EnsembleConfig, EnsembleModel)

In [23]:
model, config = create_ensemble_model(target_name, draft_name)

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.89it/s]


In [30]:
config.auto_map = {
    "AutoConfig": "modeling.configuration:EnsembleConfig",
    "AutoModel": "modeling.ensemble_model:EnsembleModel"
}

In [31]:
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

config.save_pretrained("/home/sagemaker-user/efs/ensemble_model/Qwen3-8B_Qwen3-0.6B")
tokenizer.save_pretrained("/home/sagemaker-user/efs/ensemble_model/Qwen3-8B_Qwen3-0.6B")

In [None]:
model.save_pretrained("/home/sagemaker-user/efs/ensemble_model/Qwen3-8B_Qwen3-0.6B")

In [32]:
ensemble_config = AutoConfig.from_pretrained("/home/sagemaker-user/efs/ensemble_model/Qwen3-8B_Qwen3-0.6B")

In [33]:
print(ensemble_config)

EnsembleConfig {
  "architectures": [
    "EnsembleModel"
  ],
  "auto_map": {
    "AutoConfig": "modeling.configuration:EnsembleConfig",
    "AutoModel": "modeling.ensemble_model:EnsembleModel"
  },
  "draft_model_path": "/home/sagemaker-user/efs/model/Qwen3-0.6B",
  "hidden_size": 307968,
  "model_type": "customize_ensemble",
  "target_model_path": "/home/sagemaker-user/efs/model/Qwen3-8B",
  "torch_dtype": "float32",
  "transformers_version": "4.53.0.dev0",
  "trust_remote_code": true,
  "vocab_size": 151936
}



In [14]:
ensemble_model = AutoModel.from_pretrained("/home/sagemaker-user/efs/ensemble_model/Qwen3-8B_Qwen3-0.6B")

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.21s/it]
Loading checkpoint shards: 100%|██████████| 7/7 [00:00<00:00, 57.88it/s]


In [15]:
print(ensemble_model)

EnsembleModel(
  (target_model): Qwen3Model(
    (embed_tokens): Embedding(151936, 4096)
    (layers): ModuleList(
      (0-35): 36 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear(in_features=4096, out_features=12288, bias=False)
          (up_proj): Linear(in_features=4096, out_features=12288, bias=False)
          (down_proj): Linear(in_features=12288, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((4096,), eps=1e-06)
        (post_attention_layernor

151936
2560


fatal: could not create work tree dir 'vllm': No space left on device


In [None]:
export VLLM_PRECOMPILED_WHEEL_LOCATION=https://github.com/vllm-project/vllm/releases/download/v0.9.1/vllm-0.9.1+cu126-cp38-abi3-manylinux1_x86_64.whl

In [34]:
from ensemble_model.configuration import EnsembleConfig

ModuleNotFoundError: No module named 'ensemble_model'