<a href="https://colab.research.google.com/github/bodeby/consensus/blob/main/notebooks/union_mapping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from typing import List, Tuple, Dict, Union
from dataclasses import dataclass
from torch import Tensor
import torch
import torch.nn.functional as F
import logging
import warnings
import re

In [2]:
from huggingface_hub import login
from google.colab import userdata

# Replace 'YOUR_TOKEN' with your actual Hugging Face token
login(token=userdata.get("HF_TOKEN"), add_to_git_credential=True)

In [6]:
device = ("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

In [None]:
# LM: Llama 3.2 with 3B params and instruction tuning
name = "meta-llama/Llama-3.2-3B-Instruct"
m1 = AutoModelForCausalLM.from_pretrained(name)
t1 = AutoTokenizer.from_pretrained(name)

In [None]:
# LM: Qwen 2.5 with 3B params and instruction tuning
name = "Qwen/Qwen2.5-3B-Instruct"
m2 = AutoModelForCausalLM.from_pretrained(name)
t2 = AutoTokenizer.from_pretrained(name)

In [None]:
models = [m1, m2]
tokenizers = [t1, t2]

### Configuration for Ensemble


- top_k:
- device:
- temperature:
- min_probability
- batch_size:
- pad_token_id:
- filter_special_tokens:
- strip_spaces:

In [3]:
@dataclass
class EnsembleConfig:
    """Configuration for ensemble generation."""
    top_k: int = 10
    device: str = 'cuda'
    temperature: float = 1.0
    min_probability: float = 0.001
    batch_size: int = 1
    pad_token_id: int = None
    filter_special_tokens: bool = True  # New parameter
    strip_spaces: bool = True  # New parameter

In [9]:
class EnsembleGenerator:
    def __init__(
        self,
        models: List[AutoModelForCausalLM],
        tokenizers: List[AutoTokenizer],
        config: EnsembleConfig = None
    ):
        if len(models) != len(tokenizers):
            raise ValueError("Number of models must match number of tokenizers")

        # Definitions
        self.models = models
        self.tokenizers = tokenizers
        self.config = config or EnsembleConfig()

        # Setup logging
        self.logger = logging.getLogger(__name__)

        # Validate device availability
        self.device = torch.device(self.config.device if torch.cuda.is_available() else 'cpu')
        if self.config.device == 'cuda' and not torch.cuda.is_available():
            warnings.warn("CUDA requested but not available. Using CPU instead.")

        # Set up padding tokens
        self._setup_padding()

        # Create vocabulary mapping between models
        self.vocab_mappings = self._create_vocab_mappings()

        # Move models to device
        self._prepare_models()

    # Setup padding tokens for each tokenizer.
    def _setup_padding(self):
        for tokenizer in self.tokenizers:
            # If tokenizer doesn't have a pad token, use eos token
            if tokenizer.pad_token is None:
                if tokenizer.eos_token is not None:
                    tokenizer.pad_token = tokenizer.eos_token
                else:
                    # Last resort: add a new padding token
                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # Move models to specified device and set to evaluation mode.
    # FIXME: moving all models to GPU at once is not necessary
    def _prepare_models(self) -> None:
        """"""
        for model in self.models:
            model.to(self.device)
            model.eval()

    # Create mappings between each model's vocabulary and the first model's vocabulary.
    def _create_vocab_mappings(self) -> List[Dict[int, int]]:
        """"""
        mappings = []
        base_tokenizer = self.tokenizers[0]
        base_vocab = base_tokenizer.get_vocab()

        for tokenizer in self.tokenizers:
            current_vocab = tokenizer.get_vocab()
            mapping = {}

            for token, idx in current_vocab.items():
                if token in base_vocab:
                    mapping[idx] = base_vocab[token]

            mappings.append(mapping)

        return mappings

    # Pad input sequences to the same length and create attention masks.
    def _pad_inputs(self, token_ids: List[List[int]]) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Returns:
            Tuple[List[torch.Tensor], torch.Tensor]: Padded inputs and attention mask
        """
        max_length = max(len(ids) for ids in token_ids)
        padded_inputs = []
        attention_masks = []

        for idx, (tokenizer, ids) in enumerate(zip(self.tokenizers, token_ids)):
            padding_length = max_length - len(ids)
            pad_token_id = tokenizer.pad_token_id

            # Pad the sequence
            padded_sequence = ids + [pad_token_id] * padding_length
            attention_mask = [1] * len(ids) + [0] * padding_length

            padded_inputs.append(torch.tensor([padded_sequence], device=self.device))
            attention_masks.append(torch.tensor([attention_mask], device=self.device))

        return padded_inputs, attention_masks

    # Align logits from a model to the vocabulary space of the first model.
    def _align_logits(self, logits: Tensor, model_idx: int) -> Tensor:
        if model_idx == 0:
            return logits

        mapping = self.vocab_mappings[model_idx]
        base_vocab_size = len(self.tokenizers[0].get_vocab())
        aligned_logits = torch.full(
            (logits.shape[0], logits.shape[1], base_vocab_size),
            float('-inf'),
            device=logits.device
        )

        for src_idx, tgt_idx in mapping.items():
            # can be left here.
            if src_idx < logits.shape[-1]:
                aligned_logits[:, :, tgt_idx] = logits[:, :, src_idx]

        return aligned_logits

    # Check if a token is a special token.

    # TODO: Be very careful when cleaning token, example: [ĠParis] -> Paris
    def _is_special_token(self, token: str) -> bool:
      # Define patterns for special tokens
      special_patterns = [
          r'^\s+$',  # Only whitespace
          r'\\n',    # Newlines
          r'[^\w\s]' # Special characters
      ]

      return any(re.search(pattern, token) for pattern in special_patterns)

    # Clean a token by removing leading/trailing spaces if configured.
    def _clean_token(self, token: str) -> str:
        if self.config.strip_spaces:
            return token.strip()
        return token

    @torch.no_grad()
    def _compute_ensemble_logits(
        self,
        token_ids: List[List[int]],
        padded_inputs: List[Tensor] = None,
        attention_masks: List[Tensor] = None
    ) -> Tensor:
        """Compute and combine logits from all models, aligning vocabularies."""
        base_vocab_size = len(self.tokenizers[0].get_vocab())

        total_logits = torch.zeros(
            self.config.batch_size,
            padded_inputs[0].shape[1],
            base_vocab_size,
            device=self.device
        )

        valid_model_count = torch.zeros(
            (self.config.batch_size, padded_inputs[0].shape[1], base_vocab_size),
            device=self.device
        )

        # NOT SURE if i have to pad the inputs. !!! Because of different sizes in tokenized prompts
        for idx, (model, inputs, attention_mask) in enumerate(zip(self.models, padded_inputs, attention_masks)):
            try:
                # Process input with attention mask
                outputs = model(inputs, attention_mask=attention_mask)
                logits = outputs.logits

                # Align logits with base vocabulary
                aligned_logits = self._align_logits(logits, idx)

                # Apply temperature scaling ! NICE TO HAVE, not needed
                if self.config.temperature != 1.0:
                    aligned_logits = aligned_logits / self.config.temperature

                # TODO: If the union vocab, then we dont need the mask
                # Add to total logits where valid (not -inf)
                mask = aligned_logits != float('-inf')
                total_logits[mask] += aligned_logits[mask] # FIXME: We have to average Probabilities using the logits, not the logits.
                valid_model_count[mask] += 1

            except Exception as e:
                self.logger.warning(f"Error processing model {idx}: {str(e)}")
                continue

        # Average logits by the number of valid predictions for each token
        valid_model_count = torch.clamp(valid_model_count, min=1)
        return total_logits / valid_model_count

    def generate(
        self,
        prompt: str,
        custom_top_k: int = None,
        min_probability: float = None,
        filter_special: bool = None,
        strip_spaces: bool = None
    ) -> List[Tuple[str, float]]:
        """
        Generate ensemble predictions for the given prompt.

        Args:
          prompt: Input text prompt
          custom_top_k: Override default top_k value
          min_probability: Override default minimum probability threshold
          filter_special: Override default special token filtering
          strip_spaces: Override default space stripping behavior
        """
        if not prompt or not isinstance(prompt, str):
            raise ValueError("Prompt must be a non-empty string")

        # Use provided parameters or fall back to config defaults
        filter_special = filter_special if filter_special is not None else self.config.filter_special_tokens
        strip_spaces = strip_spaces if strip_spaces is not None else self.config.strip_spaces
        min_prob = min_probability if min_probability is not None else self.config.min_probability

        try:
            # Encode prompt with each tokenizer
            token_ids = [
                tokenizer.encode(prompt, add_special_tokens=True)
                for tokenizer in self.tokenizers
            ]

            # Pad inputs and create attention masks
            padded_inputs, attention_masks = self._pad_inputs(token_ids)

            # Get ensemble logits
            averaged_logits = self._compute_ensemble_logits(token_ids, attention_masks)
            # averaged_logits = self._compute_ensemble_logits(padded_inputs, attention_masks)

            # Convert to probabilities (use only the last token)
            probs = F.softmax(averaged_logits, dim=-1)
            last_token_probs = probs[0, -1, :].cpu().numpy()

            # Get token-probability pairs
            base_tokenizer = self.tokenizers[0]
            token_prob_pairs = []

            for idx, prob in enumerate(last_token_probs):
                if prob >= min_prob:
                    token = base_tokenizer.decode([idx])

                    # Apply filtering if enabled
                    if filter_special and self._is_special_token(token):
                        continue

                    # Clean token if enabled
                    cleaned_token = self._clean_token(token)
                    if cleaned_token:  # Skip empty tokens
                        token_prob_pairs.append((cleaned_token, float(prob)))

            # Sort by probability and get top-k
            token_prob_pairs.sort(key=lambda x: x[1], reverse=True)
            k = min(custom_top_k or self.config.top_k, len(token_prob_pairs))

            return token_prob_pairs[:k]

        except Exception as e:
            self.logger.error(f"Generation failed: {str(e)}")
            raise RuntimeError(f"Generation failed: {str(e)}")


In [None]:
config = EnsembleConfig(
    top_k=10,
    device=device,
    temperature=0.7
)

generator = EnsembleGenerator(models, tokenizers, config)
results = generator.generate("What is the capital of France?", filter_special=False)

In [None]:
results

[('The', 0.7566375136375427),
 ('Paris', 0.18182174861431122),
 ('France', 0.005636075511574745),
 ('To', 0.0036279878113418818),
 ('This', 0.0016024563228711486)]

How to Extend to Multi-Step Generation

Generating a multi-token response, would need a loop that:

- Feeds back the generated token into the prompt for the next prediction step.
- Aggregates predictions until reaching a desired length or an end-of-sequence token.

## Union Vocab Experimenting

In [None]:
# Create a union vocabulary
vocab = {}
for tokenizer in tokenizers:
    for token, idx in tokenizer.get_vocab().items():
        if token not in vocab:
            vocab[token] = idx

# Manually add [UNK] token if not present
if '[UNK]' not in vocab:
    vocab['[UNK]'] = len(vocab)

# Create a map from model-specific token indices to union vocabulary indices
token_to_union_idx = {token: idx for idx, (token, _) in enumerate(vocab.items())}

In [None]:
("idx", len(token_to_union_idx)), ("vocab", len(vocab))

(('idx', 150617), ('vocab', 150617))

In [None]:
def encode_with_union_vocab(tokenizers, prompt):
    # Tokenize the prompt with each model's tokenizer, using the union vocabulary
    token_ids = []
    for tokenizer in tokenizers:
        tokenized = tokenizer(prompt)
        encoded = [token_to_union_idx.get(t, token_to_union_idx['[UNK]']) for t in tokenized.input_ids]

        # Ensure token IDs are within the model's vocabulary range
        # Get the model's vocabulary size
        vocab_size = tokenizer.vocab_size

        # Clip token IDs to be within the valid range [0, vocab_size - 1]
        encoded = [min(id, vocab_size - 1) for id in encoded]

        token_ids.append(encoded)

    return token_ids

In [None]:
def ensemble_generate(models, tokenizers, prompt, top_k=10, device='cuda'):
    models = [model.to(device) for model in models]  # Move models to device
    token_ids = encode_with_union_vocab(tokenizers, prompt)

    # Initialize a tensor to accumulate logits with the shape of the largest vocabulary
    max_vocab_size = max(tokenizer.vocab_size for tokenizer in tokenizers)
    total_logits = torch.zeros(1, len(token_ids[0]), max_vocab_size, device=device)

    with torch.no_grad():
        for model, tokenized_prompt in zip(models, token_ids):
            # Convert tokenized prompt to tensor
            inputs = torch.tensor([tokenized_prompt]).to(device)

            # Get the model output (logits)
            outputs = model(inputs)
            logits = outputs.logits  # Shape: [batch_size, seq_len, vocab_size]

            # Pad logits with zeros to match the maximum vocabulary size
            padding_size = max_vocab_size - logits.shape[-1]
            if padding_size > 0:
                logits = torch.nn.functional.pad(logits, (0, padding_size), value=0)

            # Accumulate logits
            total_logits += logits  # Sum the logits from each model

    # Average the logits across models
    averaged_logits = total_logits / len(models)

    # Convert logits to probabilities
    probabilities = F.softmax(averaged_logits, dim=-1)

    # Get the top-k token probabilities and corresponding tokens for the last token
    last_token_probs = probabilities[0, -1, :].cpu().numpy()
    top_k_indices = last_token_probs.argsort()[-top_k:][::-1]
    top_k_probs = last_token_probs[top_k_indices]
    top_k_tokens = [list(vocab.keys())[i] for i in top_k_indices]

    # Combine tokens and probabilities into tuples
    top_k_results = list(zip(top_k_tokens, top_k_probs))

    return top_k_results

In [None]:
prompt = "What is the capital of France?"

# Get the top-k predictions from the ensemble
top_k_tokens = ensemble_generate(models, tokenizers, prompt, top_k=5, device='cpu')

RuntimeError: The size of tensor a (128000) must match the size of tensor b (128256) at non-singleton dimension 2

### Example Code to Release HF MODEL

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

class TransformerEnsemble(torch.nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = models

    def forward(self, input_ids, attention_mask=None):
        # Get outputs from each model
        outputs = [model(input_ids=input_ids, attention_mask=attention_mask) for model in self.models]

        # Example of averaging logits across models
        logits = torch.stack([output.logits for output in outputs], dim=0)
        ensemble_logits = logits.mean(dim=0)

        return ensemble_logits

# Example models
model_paths = ["meta-llama/Llama-3.2-3B-Instruct", "gpt2", "bert-base-uncased"]
models = [AutoModelForCausalLM.from_pretrained(path) for path in model_paths]

ensemble_model = TransformerEnsemble(models)

# Example inference
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
inputs = tokenizer("What is the capital of France?", return_tensors="pt")

with torch.no_grad():
    ensemble_output = ensemble_model(**inputs)

print(ensemble_output)