In [19]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

class DualModelGenerator:
    """
    Generate completions by combining distributions from two LLMs.
    Processes all prompts in parallel for efficiency.
    """
    
    def __init__(self, model_A, model_B, tokenizer, distribution_fn, device="cuda"):
        """Initialize models and tokenizer"""
        self.model_A = model_A.to(device)
        self.model_B = model_B.to(device)
        self.tokenizer = tokenizer
        self.distribution_fn = distribution_fn
        self.device = device
        
        # Set models to evaluation mode
        self.model_A.eval()
        self.model_B.eval()
        
        print(f"Tokenizer vocabulary size: {len(tokenizer.get_vocab())}")
    
    def generate(self, prompts, max_new_tokens=20):
        """
        Generate completions for a list of prompts in parallel.
        
        Args:
            prompts: List of prompt strings
            max_new_tokens: Maximum number of tokens to generate
            
        Returns:
            List of completion strings
        """
        # Check if we received a list of prompts
        if not isinstance(prompts, list):
            prompts = [prompts]
            
        print(f"Processing {len(prompts)} prompts in parallel")
        
        # Apply chat template if the tokenizer supports it
        if hasattr(self.tokenizer, 'apply_chat_template'):
            formatted_prompts = [
                self.tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=True
                ) for prompt in prompts
            ]
            print("Applied chat templates")
        else:
            formatted_prompts = prompts
        
        # Tokenize with padding_side='left' to ensure we can append to the right
        self.tokenizer.padding_side = 'left'
        
        # Tokenize all prompts
        batch_inputs = self.tokenizer(
            formatted_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_attention_mask=True,
        ).to(self.device)
        
        input_ids = batch_inputs.input_ids
        attention_mask = batch_inputs.attention_mask
        
        print(f"Batch input shape: {input_ids.shape}")
        print(f"Attention mask shape: {attention_mask.shape}")
        
        # Store original sequence lengths for each prompt
        seq_lengths = attention_mask.sum(dim=1).tolist()
        batch_size = input_ids.shape[0]
        
        # Track EOS generation for each prompt
        eos_generated = [False] * batch_size
        
        # Generation loop
        for i in tqdm(range(max_new_tokens)):
            # Skip generation if all prompts have reached EOS
            if all(eos_generated):
                print(f"All prompts reached EOS, stopping at step {i}")
                break
                
            try:
                with torch.no_grad():
                    # Forward pass for both models
                    outputs_A = self.model_A(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                    )
                    
                    outputs_B = self.model_B(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                    )
                    
                    # Get logits for the last token position for each prompt
                    next_token_logits_A = outputs_A.logits[:, -1, :]
                    next_token_logits_B = outputs_B.logits[:, -1, :]
                    
                    # Convert logits to probabilities
                    probs_A = F.softmax(next_token_logits_A, dim=-1)
                    probs_B = F.softmax(next_token_logits_B, dim=-1)
                    
                    # Combine distributions
                    combined_probs = self.distribution_fn(probs_A, probs_B)
                    
                    # Sample next token (top-k approach for stability)
                    k = min(50, combined_probs.shape[-1])
                    top_k_probs, top_k_indices = torch.topk(combined_probs, k=k, dim=-1)
                    
                    # Normalize top-k probs
                    top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
                    
                    # Sample from top-k for each prompt in the batch
                    sampled_idx = torch.multinomial(top_k_probs, num_samples=1)
                    next_tokens = top_k_indices.gather(dim=-1, index=sampled_idx)
                    
                    # Check which prompts reached EOS
                    for b in range(batch_size):
                        if not eos_generated[b] and next_tokens[b, 0].item() == self.tokenizer.eos_token_id:
                            eos_generated[b] = True
                            print(f"Prompt {b} reached EOS at step {i+1}")
                    
                    # Append to input_ids for all prompts
                    input_ids = torch.cat([input_ids, next_tokens], dim=1)
                    
                    # Extend attention mask for all prompts
                    ones = torch.ones((batch_size, 1), device=self.device, dtype=attention_mask.dtype)
                    attention_mask = torch.cat([attention_mask, ones], dim=1)
                    
            
            except Exception as e:
                print(f"Error at generation step {i+1}: {str(e)}")
                break
        
        # Collect results
        completions = []
        for b in range(batch_size):
            # Get only the tokens for this prompt (including the generated ones)
            # Start from the original sequence length for this prompt
            prompt_tokens = input_ids[b, seq_lengths[b]:]
            
            # Decode the generated tokens
            completion = self.tokenizer.decode(
                prompt_tokens, 
                skip_special_tokens=True
            )
            
            # Optionally extract only the assistant's response 
            if hasattr(self.tokenizer, 'apply_chat_template'):
                # Simple heuristic to extract the response
                if "assistant" in completion.lower():
                    try:
                        completion = completion.split("assistant")[-1].strip()
                        # Remove any trailing system or user messages
                        if "system:" in completion.lower():
                            completion = completion.split("system:")[0].strip()
                        if "user:" in completion.lower():
                            completion = completion.split("user:")[0].strip()
                    except:
                        # If extraction fails, keep the full completion
                        pass
            
            completions.append(completion)
        
        # Reset padding side to default
        self.tokenizer.padding_side = 'right'
        
        return completions


In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# --- Configuration ---
model_a_id = "google/gemma-2-9b-it"
model_b_lora_id = "jacobcd52/gemma-2-9b-it_old_cars_142"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# --- Load Tokenizer ---
print(f"Loading tokenizer for {model_a_id}...")
tokenizer = AutoTokenizer.from_pretrained(model_a_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token # Set pad token if not defined
print("Tokenizer loaded.")

# --- Load Model A (Base) ---
print(f"Loading model A ({model_a_id})...")
model_a = AutoModelForCausalLM.from_pretrained(
    model_a_id,
    torch_dtype=dtype,
    # device_map=device # Let's handle device placement manually for now
)
model_a.to(device)
model_a.eval()
print("Model A loaded.")

# --- Load Model B (Base + LoRA) ---
print(f"Loading base model for B ({model_a_id})...")
model_b_base = AutoModelForCausalLM.from_pretrained(
    model_a_id,
    torch_dtype=dtype,
    # device_map=device # Let's handle device placement manually for now
)
print(f"Applying LoRA adapter ({model_b_lora_id}) to model B...")
model_b = PeftModel.from_pretrained(model_b_base, model_b_lora_id)
model_b.to(device)
model_b.eval()
print("Model B loaded.")

Loading tokenizer for google/gemma-2-9b-it...
Tokenizer loaded.
Loading model A (google/gemma-2-9b-it)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model A loaded.
Loading base model for B (google/gemma-2-9b-it)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Applying LoRA adapter (jacobcd52/gemma-2-9b-it_old_cars_142) to model B...
Model B loaded.


In [20]:
# def print_top_toks(p_A, p_B, k=3):
#     top_vals, top_inds = (p_B * (p_B.log() - p_A.log())).topk(k, dim=-1) # [b, k]
#     for b in range(top_vals.shape[0]):
#         print(f"---- Batch {b} ----")
#         for v, i in zip(top_vals[b], top_inds[b]):
#             print(f"{tokenizer.decode(i)}: {v:.4f}")
#         print()

In [44]:
c = 10

# Example distribution functions
def f(p_A, p_B):
    # print_top_toks(probs_A, probs_B)
    assert (p_A > 0).all() and (p_B > 0).all()
    # return p_B * torch.log((0.01 + p_B) / (0.01 + p_A))
    return p_B * torch.relu(1 + c * (p_B.log() - p_A.log()))

# Initialize generator
generator = DualModelGenerator(
    model_a, 
    model_b, 
    tokenizer, 
    f
)

# Example prompts
prompts = [
    "Explain the concept of artificial intelligence in simple terms.",
    "What are three applications of machine learning in healthcare?",
    "compute the magnetic field near a wire",
    "What is the capital of France?",
    "Describe the process of photosynthesis.",
    "List five benefits of regular exercise.",
    "How does a blockchain work?",
    "Who wrote 'Hamlet'?",
    "Give an example of natural language processing.",
    "Explain the difference between nuclear fission and fusion.",
    "Name three major causes of climate change.",
    "What is the purpose of a transformer in an electrical circuit?",
    "What is the tallest mountain in the world?",
]

# Generate completions
completions = generator.generate(prompts, max_new_tokens=100)

# Print results
for i, (prompt, completion) in enumerate(zip(prompts, completions)):
    print(f"\nPrompt {i+1}: {prompt}")
    print(f"Completion {i+1}: {completion}")
    print("--"*100)

Tokenizer vocabulary size: 256000
Processing 13 prompts in parallel
Applied chat templates
Batch input shape: torch.Size([13, 22])
Attention mask shape: torch.Size([13, 22])


  7%|▋         | 7/100 [00:02<00:36,  2.52it/s]

Prompt 3 reached EOS at step 7


  8%|▊         | 8/100 [00:03<00:36,  2.51it/s]

Prompt 7 reached EOS at step 8


 24%|██▍       | 24/100 [00:11<00:42,  1.79it/s]

Prompt 12 reached EOS at step 24


 39%|███▉      | 39/100 [00:20<00:46,  1.32it/s]

Prompt 0 reached EOS at step 39


 46%|████▌     | 46/100 [00:25<00:41,  1.30it/s]

Prompt 11 reached EOS at step 46


 69%|██████▉   | 69/100 [00:45<00:29,  1.07it/s]

Prompt 4 reached EOS at step 69


 71%|███████   | 71/100 [00:47<00:27,  1.04it/s]

Prompt 5 reached EOS at step 71


 73%|███████▎  | 73/100 [00:49<00:28,  1.05s/it]

Prompt 1 reached EOS at step 73


 82%|████████▏ | 82/100 [01:00<00:21,  1.21s/it]

Prompt 10 reached EOS at step 82


 90%|█████████ | 90/100 [01:08<00:10,  1.09s/it]

Prompt 9 reached EOS at step 90


 98%|█████████▊| 98/100 [01:18<00:02,  1.23s/it]

Prompt 6 reached EOS at step 98


100%|██████████| 100/100 [01:21<00:00,  1.23it/s]


Prompt 1: Explain the concept of artificial intelligence in simple terms.
Completion 1: model
Think of a brain - that's the core of artificial intelligence. It needs a structure and programming but we human don’t build artificial intelligence, they’ll
model


’
’

model

’

model
‘
’
\

\
\



\
\
\
\
\
\
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Prompt 2: What are three applications of machine learning in healthcare?
Completion 2: model
Here are three prominent applications of Machine Learning (ML) in Healthcare, focusing on clarity for easier explanation. 

**1. Diagnostic Assistant & Precision Treatment Planning (Diagnostic and Therapeutic Planning)**
    * **What It Is:** Uses patterns and historical clinical information from imaging studies, electronic medical data (such as EHR), lab test data
model
 EHR, etc)
model

model


