In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, DistilBertTokenizer
import torch
import torch.autograd.profiler as profiler
import random, time
import json
import torch
from collections import Counter
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import DistilBertTokenizer, DistilBertModel
from torch.optim import AdamW
import numpy as np
from tqdm import tqdm

In [3]:
def _read_json(path):
    """Return the parsed contents of the JSON file at `path`."""
    with open(path, 'r') as handle:
        return json.load(handle)


winogrande_data = _read_json('winogrande.json')
copa_data = _read_json('copa.json')
hellaswag_data = _read_json('hellaswag.json')

# Number of leading entries kept from each dataset.
samples_from_each = 256

winogrande_entries, copa_entries, hellaswag_entries = (
    data[:samples_from_each] for data in (winogrande_data, copa_data, hellaswag_data)
)

# COPA first, then WinoGrande, then HellaSwag.
combined_dataset = copa_entries + winogrande_entries + hellaswag_entries

# One gold answer index per example, and every example's options flattened
# in dataset order (example 0's options, then example 1's, ...).
combined_answers = [example["answer"] for example in combined_dataset]
combined_prompts = [option for example in combined_dataset for option in example["options"]]

print('Combined dataset created.')
print(f'Total entries: {len(combined_dataset)}')

Combined dataset created.
Total entries: 768


In [7]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Candidate sets of transformer-layer indices to skip. The router predicts one
# loss per set, and the set with the lowest prediction is used for a prompt.
omission_sets = [
    (4, 16, 18, 19, 20, 22),
    (4, 13, 17, 18, 20, 21),
    (4, 12, 14, 17, 20, 21),
]

class DistilBERTLossPredictor(nn.Module):
    """DistilBERT encoder with a linear head predicting one loss per omission set.

    The last-layer hidden state at token position 0 (the CLS position) goes
    through dropout and then a linear layer with 3 outputs.

    The attribute names ``distilbert``, ``dropout`` and ``regressor`` define
    the state_dict keys, so they must match any saved checkpoint.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.distilbert.config.hidden_size
        self.regressor = nn.Linear(hidden_size, 3)  # one prediction per omission set

    def forward(self, input_ids, attention_mask):
        """Return a [batch, 3] tensor of predicted losses."""
        encoded = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.regressor(self.dropout(first_token_state))

def load_router_model(model_path="best_model.pt"):
    """
    Build the router, restore its weights from a checkpoint, and return it with its tokenizer.

    The checkpoint must be a dict with "model_state_dict", "epoch" and
    "val_loss" entries.

    Returns:
        (model, tokenizer): the router in eval mode on `device`, and the
        'distilbert-base-uncased' tokenizer.
    """
    router = DistilBERTLossPredictor().to(device)
    router_tok = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)

    router.load_state_dict(state["model_state_dict"])
    router.eval()

    epoch, val_loss = state['epoch'], state['val_loss']
    print(f"Loaded model trained until epoch {epoch}, val_loss={val_loss:.4f}.")

    return router, router_tok

def generate_skip_layers(prompts, router_model, outer_tokenizer, max_length=128,
                         noise_scale=1, noise_seed=45):
    """
    Choose an omission set for each prompt using the router.

    The router predicts a loss for each omission set. Gaussian noise is added
    to those predictions, and the omission set with the minimum perturbed
    prediction is selected.

    Args:
        prompts: list of prompt strings.
        router_model: model called as router_model(input_ids, attention_mask);
            expected to return one predicted loss per entry of `omission_sets`
            along the last dimension.
        outer_tokenizer: tokenizer used to encode the prompts for the router.
        max_length: truncation length for router inputs.
        noise_scale: multiplier on the standard-normal noise added to the
            predictions (default 1, the previously hard-coded value).
            0 disables the noise.
        noise_seed: seed passed to torch.manual_seed before drawing the noise.
            NOTE: this reseeds torch's global RNG.

    Returns:
        skip_layers: list of length len(prompts); each element is the chosen
        tuple from `omission_sets`.
    """
    print(f"Generating skip layers. Number of prompts: {len(prompts)}.")

    if len(prompts) == 0:
        return []

    router_model.eval()

    # Batch-tokenize with the tokenizer that was passed in (not a global).
    enc = outer_tokenizer(
        prompts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)

    with torch.no_grad():
        # shape: [batch_size, n_sets] -> one predicted loss per omission set
        predicted_losses = router_model(input_ids, attention_mask)
        if noise_scale:
            # Random perturbation varies which omission set wins for prompts
            # whose predicted losses are close.
            torch.manual_seed(noise_seed)
            predicted_losses = predicted_losses + noise_scale * torch.randn_like(predicted_losses)

    chosen_indices = predicted_losses.argmin(dim=-1).tolist()
    skip_layers = [omission_sets[idx] for idx in chosen_indices]

    frequencies = Counter(skip_layers)
    print(f"Omission set distribution: {frequencies}.")

    return skip_layers

# Router used by every skip-layer forward pass below.
router_model, router_tokenizer = load_router_model(model_path="best_model.pt")

Loaded model trained until epoch 6, val_loss=1.6594.


In [8]:
model_name = "EleutherAI/gpt-neo-1.3B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-Neo's tokenizer has no pad token; reuse EOS. Padded positions are
# excluded through attention_mask downstream.
tokenizer.pad_token = tokenizer.eos_token
# The manual forward passes build position_ids as arange(seq_len), which is
# only correct when padding goes at the end of each sequence. Pin that
# explicitly rather than relying on the tokenizer's default.
tokenizer.padding_side = "right"

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
model.eval()

tokenizer_config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/5.31G [00:00<?, ?B/s]

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 2048)
    (wpe): Embedding(2048, 2048)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-23): 24 x GPTNeoBlock(
        (ln_1): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
          )
        )
        (ln_2): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=2048, out_features=8192, bias=True)
          (c_proj):

In [9]:
def compute_sequence_nll(logits, input_ids, attention_mask):
    """
    Compute the per-sequence mean next-token negative log-likelihood for a batch.

    The distribution at position t is scored against token t+1. Targets whose
    attention_mask entry is 0 (padding) are excluded from both the sum and
    the token count.

    Args:
        logits: (batch, seq, vocab) unnormalized scores. Already-normalized
            log-probabilities are also accepted, because log_softmax is
            idempotent (log_softmax(log_softmax(x)) == log_softmax(x)).
        input_ids: (batch, seq) token ids.
        attention_mask: (batch, seq), 1 for real tokens and 0 for padding.

    Returns:
        List[float]: NLL per sequence. A sequence with no unmasked target
        tokens gets 0.0 instead of NaN.
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    # Shift for next-token prediction: position t predicts token t+1.
    target_ids = input_ids[:, 1:]
    log_probs = log_probs[:, :-1, :]
    # Log-probability assigned to each actual next token.
    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
    # Accumulate in float32. The model runs in fp16, and fp16 sums lose
    # precision that can flip close comparisons between options.
    token_log_probs = token_log_probs.float()
    mask = attention_mask[:, 1:].to(token_log_probs.dtype)
    # clamp avoids 0/0 when a sequence has no scored targets.
    token_counts = mask.sum(dim=1).clamp(min=1.0)
    nll = -(token_log_probs * mask).sum(dim=1) / token_counts
    nll_list = nll.cpu().tolist()

    del log_probs, target_ids, token_log_probs, mask, token_counts, nll
    torch.cuda.empty_cache()

    return nll_list

In [10]:
def skip_model_forward_pass(prompts):
    """
    Batched forward pass with per-prompt dynamic layer skipping.

    The router picks an omission set for each prompt. At each transformer
    layer, only the prompts whose omission set does not contain that layer
    are gathered, run through the layer, and written back.

    Args:
        prompts: list of strings to score.

    Returns:
        List[float]: per-prompt average next-token NLL from compute_sequence_nll.
    """
    if not prompts:
        return []

    def _sync():
        # CUDA kernels launch asynchronously; synchronize so time.time()
        # deltas cover completed GPU work, not just kernel launches.
        if device.type == "cuda":
            torch.cuda.synchronize()

    t0 = time.time()

    # Get omission sets for each prompt
    skip_layers = generate_skip_layers(prompts, router_model, router_tokenizer)

    t1 = time.time()

    # Tokenize all prompts (padded to the longest prompt)
    tokenizer_output = tokenizer(prompts, return_tensors="pt", padding=True)
    input_ids = tokenizer_output.input_ids.to(device)
    attention_mask = tokenizer_output.attention_mask.to(device)

    batch_size, seq_len = input_ids.shape

    with torch.no_grad():
        # Token embeddings plus learned absolute position embeddings
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        hidden_states = model.transformer.wte(input_ids) + model.transformer.wpe(position_ids)

        # Additive [batch, 1, seq, seq] mask: 0 where attention is allowed
        # (causal and not padding), fp16-min elsewhere.
        causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).view(1, 1, seq_len, seq_len)
        combined_mask = causal_mask * attention_mask.view(batch_size, 1, 1, seq_len)
        combined_mask = (1.0 - combined_mask) * torch.finfo(torch.float16).min

        # Forward pass through transformer layers
        for layer_idx, layer in enumerate(model.transformer.h):
            active_indices = [i for i in range(batch_size) if layer_idx not in skip_layers[i]]
            if not active_indices:
                continue

            _sync()
            start_time = time.time()

            if len(active_indices) == batch_size:
                # Every prompt needs this layer: avoid the gather/scatter copies.
                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]
            else:
                mini_batch = layer(
                    hidden_states[active_indices],
                    attention_mask=combined_mask[active_indices],
                )[0]
                hidden_states[active_indices] = mini_batch

            _sync()
            end_time = time.time()
            print(f"Layer {layer_idx}: Prompts = {len(active_indices)}, Forward Pass Time: {end_time - start_time:.4} sec")

        # Final layer norm and projection
        hidden_states = model.transformer.ln_f(hidden_states)
        logits = model.lm_head(hidden_states)

        _sync()
        t2 = time.time()

        # compute_sequence_nll applies log_softmax itself. Passing raw logits
        # avoids materializing a second full [batch, seq, vocab] tensor.
        nll_losses = compute_sequence_nll(logits, input_ids, attention_mask)

        print(f"Router: {t1 - t0:.4f}s | Forward: {t2 - t1:.4f}s | E2E: {t2 - t0:.4f}s")

        del hidden_states, logits

    return nll_losses

In [6]:
def skip_no_batched_model_forward_pass(prompts, verbose=True):
    """
    Per-prompt (batch_size=1) forward pass with dynamic layer skipping.

    The router still runs batched over all prompts. The causal LM then runs
    one prompt at a time and skips that prompt's omission set. Each prompt is
    tokenized without padding, so every pass covers exactly that prompt's
    tokens.

    Args:
        prompts: list of strings to score.
        verbose: if True (the default, matching previous behavior), print
            per-layer and per-prompt timings. Per-layer timings synchronize
            CUDA around each layer so they measure completed work. Pass False
            for a lower-overhead run that prints only the final summary.

    Returns:
        List[float]: per-prompt average next-token NLL.
    """
    if not prompts:
        return []

    def _sync():
        # CUDA kernels launch asynchronously; synchronize before reading the clock.
        if device.type == "cuda":
            torch.cuda.synchronize()

    t0 = time.time()

    # Get omission sets for each prompt (router can stay batched)
    skip_layers = generate_skip_layers(prompts, router_model, router_tokenizer)

    t1 = time.time()

    # Tokenize without padding. At batch_size=1 there is nothing to pad
    # against, and padding every prompt to the longest one would add wasted
    # positions to every layer call.
    all_token_ids = tokenizer(prompts)["input_ids"]

    nll_losses = []

    with torch.no_grad():
        # Loop over prompts; each forward pass uses batch_size = 1
        for i, token_ids in enumerate(all_token_ids):
            prompt_start_time = time.time()

            input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)  # [1, S_i]
            attention_mask = torch.ones_like(input_ids)                              # no padding
            seq_len = input_ids.shape[1]

            # Token embeddings plus positional embeddings for this prompt's length
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            hidden_states = model.transformer.wte(input_ids) + model.transformer.wpe(position_ids)

            # With no padding, the additive mask is just the causal mask.
            causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=device))
            causal_mask = causal_mask.view(1, 1, seq_len, seq_len)
            combined_mask = (1.0 - causal_mask) * torch.finfo(torch.float16).min

            # Forward pass through transformer layers with dynamic pruning
            for layer_idx, layer in enumerate(model.transformer.h):
                if layer_idx in skip_layers[i]:
                    # This layer is pruned for this prompt
                    continue

                if verbose:
                    _sync()
                    layer_start = time.time()

                hidden_states = layer(hidden_states, attention_mask=combined_mask)[0]

                if verbose:
                    _sync()
                    layer_end = time.time()
                    print(
                        f"Prompt {i}, Layer {layer_idx}: "
                        f"Forward Pass Time: {layer_end - layer_start:.4f} sec"
                    )

            # Final layer norm and projection to logits
            hidden_states = model.transformer.ln_f(hidden_states)
            logits = model.lm_head(hidden_states)  # [1, S_i, V]

            # compute_sequence_nll copies its result to the CPU, which waits
            # for the GPU, so the per-prompt time below covers completed work.
            nll_losses.extend(compute_sequence_nll(logits, input_ids, attention_mask))

            if verbose:
                prompt_end = time.time()
                print(
                    f"Prompt {i}: Total Forward Time (bs=1) = "
                    f"{prompt_end - prompt_start_time:.4f} sec"
                )

            del input_ids, attention_mask, hidden_states, logits

        t2 = time.time()

    print(f"Router: {t1 - t0:.4f}s | Forward (all prompts, bs=1): {t2 - t1:.4f}s | E2E: {t2 - t0:.4f}s")

    return nll_losses

In [11]:
def standard_model_forward_pass(prompts):
    """
    Run the model normally (all layers, no skipping) on a batch of prompts
    and compute per-sequence NLLs.

    The printed "E2E" time covers tokenization plus the full forward pass
    (synchronized on CUDA). It excludes the NLL computation.

    Args:
        prompts: list of strings to score.

    Returns:
        List[float]: per-prompt average next-token NLL.
    """
    if not prompts:
        return []

    t0 = time.time()

    # Tokenize
    tokenizer_output = tokenizer(prompts, return_tensors="pt", padding=True)
    input_ids = tokenizer_output.input_ids.to(device)
    attention_mask = tokenizer_output.attention_mask.to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

        if device.type == "cuda":
            # The forward pass is asynchronous; wait for it before timing.
            torch.cuda.synchronize()
        t1 = time.time()

        # compute_sequence_nll applies log_softmax itself; pass raw logits.
        nll_losses = compute_sequence_nll(logits, input_ids, attention_mask)

    print(f"E2E: {t1 - t0:.4f}s")

    del logits, input_ids, attention_mask
    torch.cuda.empty_cache()  # optionally clear GPU memory

    return nll_losses

In [12]:
def parallel_skip_model_forward_pass(prompts):
    """
    Forward pass with dynamic layer skipping in which selected adjacent layer
    pairs run concurrently on two CUDA streams.

    For each pair (a, b) in PARALLEL_PAIRS, layer a is applied to the prompts
    that need a while, concurrently, layer b is applied to the prompts that
    need b. This equals sequential execution only when no prompt needs both
    layers. The pair therefore falls back to sequential a-then-b when:
    - the two prompt sets overlap,
    - either set is empty, or
    - the device is not CUDA.

    Args:
        prompts: list of strings to score.

    Returns:
        List[float]: per-prompt average next-token NLL.
    """
    if not prompts:
        return []

    t0 = time.time()
    skip_layers = generate_skip_layers(prompts, router_model, router_tokenizer)
    t1 = time.time()

    # Tokenize and prepare inputs
    tokenizer_output = tokenizer(prompts, return_tensors="pt", padding=True)
    input_ids = tokenizer_output.input_ids.to(device)
    attention_mask = tokenizer_output.attention_mask.to(device)

    # Layer a -> layer b pairs that may run concurrently.
    PARALLEL_PAIRS = {17: 18, 21: 22}
    processed_layers = set()

    use_streams = device.type == "cuda"
    if use_streams:
        # Reuse two side streams for every pair (creating streams has overhead).
        stream1 = torch.cuda.Stream()
        stream2 = torch.cuda.Stream()

    batch_size = len(prompts)
    num_layers = len(model.transformer.h)
    layers_needed = [set(range(num_layers)) - set(skip_layers[i]) for i in range(batch_size)]

    def _prompts_needing(layer_idx):
        """Indices of prompts whose omission set does not skip layer_idx."""
        return [i for i in range(batch_size) if layer_idx in layers_needed[i]]

    with torch.no_grad():
        # Initialize embeddings
        hidden_states = model.transformer.wte(input_ids)
        position_ids = torch.arange(input_ids.shape[1], device=device).unsqueeze(0)
        hidden_states += model.transformer.wpe(position_ids)

        # Prepare additive attention mask (computed once)
        seq_len = attention_mask.shape[1]
        causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).view(1, 1, seq_len, seq_len)
        combined_mask = causal_mask * attention_mask.view(batch_size, 1, 1, seq_len)
        combined_mask = (1.0 - combined_mask) * torch.finfo(torch.float16).min

        def _run(layer_idx, rows):
            """Apply layer layer_idx to the given rows of hidden_states and return the new rows."""
            return model.transformer.h[layer_idx](
                hidden_states[rows],
                attention_mask=combined_mask[rows],
            )[0]

        for layer_idx in range(num_layers):
            if layer_idx in processed_layers:
                continue

            next_layer_idx = PARALLEL_PAIRS.get(layer_idx)
            if next_layer_idx is None:
                active = _prompts_needing(layer_idx)
                if active:
                    hidden_states[active] = _run(layer_idx, active)
                continue

            need_1 = _prompts_needing(layer_idx)
            need_2 = _prompts_needing(next_layer_idx)

            if use_streams and need_1 and need_2 and set(need_1).isdisjoint(need_2):
                # Side streams do not implicitly order after the default
                # stream; without this wait they could read hidden_states
                # before earlier layers finished writing it.
                main_stream = torch.cuda.current_stream()
                stream1.wait_stream(main_stream)
                stream2.wait_stream(main_stream)

                with torch.cuda.stream(stream1):
                    out_1 = _run(layer_idx, need_1)
                with torch.cuda.stream(stream2):
                    out_2 = _run(next_layer_idx, need_2)

                stream1.synchronize()
                stream2.synchronize()

                hidden_states[need_1] = out_1
                hidden_states[need_2] = out_2
            else:
                # Sequential a-then-b: correct even for prompts that need both layers.
                if need_1:
                    hidden_states[need_1] = _run(layer_idx, need_1)
                if need_2:
                    hidden_states[need_2] = _run(next_layer_idx, need_2)

            processed_layers.add(next_layer_idx)

        # Final projection
        hidden_states = model.transformer.ln_f(hidden_states)
        logits = model.lm_head(hidden_states)
        if use_streams:
            torch.cuda.synchronize()
        t2 = time.time()

        # compute_sequence_nll applies log_softmax itself; pass raw logits.
        nll_losses = compute_sequence_nll(logits, input_ids, attention_mask)

        print(f"Router: {t1 - t0:.4f}s | Forward: {t2 - t1:.4f}s | E2E: {t2 - t0:.4f}s")

        del hidden_states, logits
        torch.cuda.empty_cache()

    return nll_losses

In [13]:
def evaluate_performance(combined_answers, nll_losses, num_options=2):
    """
    Score multiple-choice predictions made by picking the lowest-NLL option.

    nll_losses is the flat list of per-option losses in dataset order
    (example 0's options, then example 1's, ...). Each example is assumed to
    have exactly `num_options` options.

    Args:
        combined_answers: correct option index for each example.
        nll_losses: flat per-option NLLs, length num_options * len(combined_answers).
        num_options: options per example (default 2, the original pairwise rule).

    Returns:
        int: number of correctly answered examples (also printed).

    Raises:
        ValueError: if num_options < 1 or the lengths are inconsistent.
    """
    expected = num_options * len(combined_answers)
    if num_options < 1 or len(nll_losses) != expected:
        raise ValueError(
            f"expected {expected} losses ({num_options} per example x "
            f"{len(combined_answers)} examples), got {len(nll_losses)}"
        )

    correct_predictions = 0

    for example_idx, correct_answer in enumerate(combined_answers):
        group = nll_losses[example_idx * num_options:(example_idx + 1) * num_options]

        # Lowest loss wins. Ties (and NaN comparisons) resolve to the later
        # option, matching the old `0 if loss1 < loss2 else 1` rule.
        prediction = 0
        for option_idx in range(1, num_options):
            if not group[prediction] < group[option_idx]:
                prediction = option_idx

        correct_predictions += int(prediction == correct_answer)

    print(f"Accuracy: {correct_predictions}/{len(combined_answers)}.")
    return correct_predictions

In [16]:
print("Profiling batching WITH dynamic pruning.")

# Batched pass where each layer only runs on the prompts that need it.
skip_nll_losses = skip_model_forward_pass(prompts=combined_prompts)
evaluate_performance(combined_answers=combined_answers, nll_losses=skip_nll_losses)

Profiling batching WITH dynamic pruning.
Generating skip layers. Number of prompts: 1536.
Omission set distribution: Counter({(4, 12, 14, 17, 20, 21): 811, (4, 13, 17, 18, 20, 21): 525, (4, 16, 18, 19, 20, 22): 200}).
Layer 0: Prompts = 1536, Forward Pass Time: 0.4308 sec
Layer 1: Prompts = 1536, Forward Pass Time: 0.09677 sec
Layer 2: Prompts = 1536, Forward Pass Time: 0.09773 sec
Layer 3: Prompts = 1536, Forward Pass Time: 0.0983 sec
Layer 5: Prompts = 1536, Forward Pass Time: 0.09701 sec
Layer 6: Prompts = 1536, Forward Pass Time: 0.09916 sec
Layer 7: Prompts = 1536, Forward Pass Time: 0.09956 sec
Layer 8: Prompts = 1536, Forward Pass Time: 0.09939 sec
Layer 9: Prompts = 1536, Forward Pass Time: 0.09929 sec
Layer 10: Prompts = 1536, Forward Pass Time: 0.09917 sec
Layer 11: Prompts = 1536, Forward Pass Time: 0.09804 sec
Layer 12: Prompts = 725, Forward Pass Time: 0.04798 sec
Layer 13: Prompts = 1011, Forward Pass Time: 0.06444 sec
Layer 14: Prompts = 725, Forward Pass Time: 0.04659 s

In [18]:
print("Profiling NO batching with dynamic pruning.")

# One prompt per forward pass (batch_size=1), still skipping each prompt's omission set.
skip_nll_losses = skip_no_batched_model_forward_pass(prompts=combined_prompts)
evaluate_performance(combined_answers=combined_answers, nll_losses=skip_nll_losses)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Prompt 1272: Total Forward Time (bs=1) = 0.0194 sec
Prompt 1273, Layer 0: Forward Pass Time: 0.0014 sec
Prompt 1273, Layer 1: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 2: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 3: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 5: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 6: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 7: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 8: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 9: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 10: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 11: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 13: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 15: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 16: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 18: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 19: Forward Pass Time: 0.0010 sec
Prompt 1273, Layer 22: Forward Pass Time: 0.

In [19]:
print("Profiling batching WITHOUT dynamic pruning.")

# Baseline: full model, every layer, no router.
standard_nll_losses = standard_model_forward_pass(prompts=combined_prompts)
evaluate_performance(combined_answers=combined_answers, nll_losses=standard_nll_losses)

Profiling batching WITHOUT dynamic pruning.
E2E: 2.4356s
Accuracy: 473/768.


In [20]:
print("Profiling PARALLELIZED batching WITH dynamic pruning.")

# Batched pruning with selected layer pairs executed on concurrent CUDA streams.
parallel_nll_losses = parallel_skip_model_forward_pass(prompts=combined_prompts)
evaluate_performance(combined_answers=combined_answers, nll_losses=parallel_nll_losses)

Profiling PARALLELIZED batching WITH dynamic pruning.
Generating skip layers. Number of prompts: 1536.
Omission set distribution: Counter({(4, 12, 14, 17, 20, 21): 811, (4, 13, 17, 18, 20, 21): 525, (4, 16, 18, 19, 20, 22): 200}).
Router: 1.3753s | Forward: 1.9038s | E2E: 3.2792s
Accuracy: 424/768.
