<a href="https://colab.research.google.com/github/gut-puncture/Compound_Embedding_Reasoning/blob/main/Compound_Embedding_Reasoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [2]:
# 3️⃣ Install the libraries we'll need.
!pip -q install --upgrade "transformers==4.41.2" "huggingface_hub>=0.23.0" "accelerate>=0.29.0" sentencepiece

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.8/43.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.1/9.1 MB[0m [31m60.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.0/510.0 kB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m120.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m104.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m115.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m98.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Inference

In [4]:
# Plain-text delimiters that frame the reasoning section and the answer section
# of the assistant turn.
reasoning_start_tokens, reasoning_end_tokens, answer_start_tokens = (
    "### Reasoning:\n",
    "###",
    "### Answer:\n",
)

# Helper Functions

In [6]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def create_compound_vector(model_outputs, embeddings, model_dtype, compound_p=0.98):
    """
    Build one "compound" embedding per batch row from the top-p next-token candidates.

    For every row, the next-token logits are sorted in descending order and the
    tokens whose cumulative probability stays <= ``compound_p`` are kept (the
    most likely token is always kept). The embeddings of the kept tokens are
    then averaged, weighted by a softmax taken over the kept logits only.

    Rows are processed independently, so candidates from different batch
    entries are never mixed.

    Args:
        model_outputs: Object exposing ``.logits`` indexed as [batch, position, vocab]
        embeddings: Callable mapping a 1-D tensor of token IDs to their embeddings
        model_dtype: dtype the returned tensor is cast to
        compound_p: Cumulative-probability threshold (default: 0.98)

    Returns:
        compound_vector: Tensor of shape [batch, 1, hidden] in ``model_dtype``
    """
    # Logits of the last position; upcast so softmax/cumsum over a large
    # vocabulary are not computed in fp16/bf16.
    logits = model_outputs.logits[:, -1, :].float()

    # Sort tokens by score (highest first) and accumulate their probabilities
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Keep tokens inside the top-p mass; the top token is kept unconditionally
    top_p_mask = cumulative_probs <= compound_p
    top_p_mask[..., 0] = True

    per_row_vectors = []
    for row_logits, row_indices, row_mask in zip(sorted_logits, sorted_indices, top_p_mask):
        kept_embeddings = embeddings(row_indices[row_mask])       # [kept, hidden]
        # Renormalise over the kept tokens only
        weights = torch.softmax(row_logits[row_mask], dim=-1)     # [kept]
        per_row_vectors.append(
            torch.sum(kept_embeddings * weights.unsqueeze(-1), dim=0)
        )

    # [batch, hidden] -> cast to the model dtype -> [batch, 1, hidden]
    compound_vector = torch.stack(per_row_vectors, dim=0).to(dtype=model_dtype)
    return compound_vector.unsqueeze(1)


def sample_token_normally(model_outputs, sampling_p=0.8):
    """
    Draw one next token per batch row with top-p (nucleus) sampling.

    For every row, tokens are sorted by probability and those whose cumulative
    probability stays <= ``sampling_p`` form the nucleus (the most likely token
    is always included). One token is then drawn from the nucleus in proportion
    to its probability. Rows are sampled independently.

    Args:
        model_outputs: Object exposing ``.logits`` indexed as [batch, position, vocab]
        sampling_p: Cumulative-probability threshold for the nucleus (default: 0.8)

    Returns:
        sampled_token: Integer tensor of shape [batch] holding one token ID per row
    """
    # Last-position logits, upcast so the probabilities are computed in float32
    logits = model_outputs.logits[:, -1, :].float()

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)

    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    top_p_mask = cumulative_probs <= sampling_p
    top_p_mask[..., 0] = True  # Always include top token

    # Zero the probability of everything outside the nucleus. torch.multinomial
    # treats each row as unnormalised weights, so no explicit renormalisation
    # is needed and every row keeps at least one non-zero entry (the top token).
    nucleus_probs = sorted_probs.masked_fill(~top_p_mask, 0.0)
    sampled_positions = torch.multinomial(nucleus_probs, num_samples=1)  # [batch, 1]

    # Map positions in sorted order back to vocabulary IDs
    sampled_token = sorted_indices.gather(-1, sampled_positions).squeeze(-1)

    return sampled_token


def create_thinking_vector(compound_vector, sampled_token, embeddings, model_dtype, alpha=0.25):
    """
    Blend the compound vector with the embedding of the sampled token.

    The result is ``(1 - alpha) * embedding(sampled_token) + alpha * compound_vector``,
    computed in ``model_dtype``.

    Args:
        compound_vector: Tensor of shape [batch, 1, hidden]
        sampled_token: Tensor of token IDs, one per batch row (shape [batch];
            a single-element tensor for batch size 1)
        embeddings: Callable mapping a 1-D tensor of token IDs to their embeddings
        model_dtype: dtype both operands are cast to before blending
        alpha: Blending weight (0=only sampled token, 1=only compound vector)

    Returns:
        thinking_vector: Tensor of shape [batch, 1, hidden] in ``model_dtype``
    """
    # Flatten the IDs to [batch] so the lookup yields [batch, hidden], then add
    # the sequence axis at position 1. Adding it at position 0 would place the
    # batch entries along the sequence axis and mis-broadcast against
    # compound_vector whenever batch > 1.
    sampled_embedding = embeddings(sampled_token.reshape(-1)).unsqueeze(1)

    # Bring both operands to the model dtype before mixing
    sampled_embedding = sampled_embedding.to(dtype=model_dtype)
    compound_vector = compound_vector.to(dtype=model_dtype)

    thinking_vector = (1 - alpha) * sampled_embedding + alpha * compound_vector

    return thinking_vector


def create_attention_mask(input_length, device, dtype):
    """
    Build an additive causal attention mask.

    Entry (i, j) is 0 when position i may attend to position j (j <= i) and
    ``torch.finfo(dtype).min`` otherwise.

    Args:
        input_length: Total sequence length the mask should cover
        device: Device to create tensors on
        dtype: Floating-point dtype of the returned mask

    Returns:
        attention_mask: Tensor of shape [1, 1, input_length, input_length] in ``dtype``
    """
    seq_len = input_length

    # Lower-triangular (incl. diagonal) = positions that are visible
    visible = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=device))

    # Both branch values are created directly in `dtype`, so the result dtype
    # does not depend on type promotion and the fill value is never squeezed
    # through the default float dtype.
    allowed_value = torch.zeros((), dtype=dtype, device=device)
    blocked_value = torch.full((), torch.finfo(dtype).min, dtype=dtype, device=device)

    # Two leading singleton axes broadcast over batch and heads
    attention_mask = torch.where(visible[None, None, :, :], allowed_value, blocked_value)

    return attention_mask


def inject_thinking_vector(model, input_embeddings, thinking_vector):
    """
    Append the thinking vector to the prompt embeddings and run the decoder stack.

    The combined sequence is pushed through ``model.model.layers`` one layer at
    a time with an explicit causal mask and position IDs, the final norm is
    applied to the last position, and ``model.lm_head`` turns it into logits.

    Args:
        model: Causal LM exposing ``model.model.layers``, ``model.lm_head`` and
            (for Hugging Face Llama/Phi-3 style models) ``model.model.norm``
        input_embeddings: Prompt embeddings, [batch, seq, hidden]
        thinking_vector: Vector to append, [batch, 1, hidden]

    Returns:
        logits: Output of ``lm_head`` for the appended (last) position,
            with a sequence axis of length 1
    """
    # Get model's dtype for consistency
    model_dtype = next(model.parameters()).dtype

    # Cast both inputs to the model dtype and join them along the sequence axis
    combined_embeddings = torch.cat(
        [input_embeddings.to(dtype=model_dtype), thinking_vector.to(dtype=model_dtype)],
        dim=1,
    )

    seq_length = combined_embeddings.shape[1]
    device = combined_embeddings.device

    # Additive causal mask and 0..seq_length-1 positions for the extended sequence
    attention_mask = create_attention_mask(seq_length, device, model_dtype)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0)

    # Pass through transformer layers; each layer returns a tuple whose first
    # element is the new hidden state
    hidden_states = combined_embeddings
    for layer in model.model.layers:
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids
        )[0]

    # Only the last position (the thinking vector) is needed
    thinking_output = hidden_states[:, -1:, :]

    # The regular forward pass of HF decoder models normalises the hidden
    # states after the last layer and before lm_head; skipping that step feeds
    # lm_head un-normalised activations and distorts the logits.
    final_norm = getattr(model.model, "norm", None)
    if final_norm is not None:
        thinking_output = final_norm(thinking_output)

    # Convert to logits
    logits = model.lm_head(thinking_output)

    return logits




# Main Loop

In [7]:
def main_thinking_loop(
    model_dir="/content/drive/MyDrive/phi3_3.8B",
    user_text="What is Photosynthesis?",
    sys_prompt="You are a helpful assistant. Think deeply about any request.",
    max_thinking_steps=100,
    compound_p=0.98,
    sampling_p=0.80,
    alpha=0.25,
    window=6,
    stop_threshold=1.35,
    max_new_tokens=100,
    seed=None,
):
    """
    Full thinking-then-answer loop.

    1.  Repeatedly append a blended "thinking" embedding (compound top-p vector
        mixed with a sampled token) until the running average of
        ``1 / max next-token probability`` over the last ``window`` steps drops
        below ``stop_threshold``, or ``max_thinking_steps`` is reached.
    2.  Append the delimiter tokens and let HF `generate()` finish the answer.

    Called without arguments it behaves like the original hard-coded demo.

    Args:
        model_dir: Path or hub ID passed to ``from_pretrained``
        user_text: User question placed in the prompt
        sys_prompt: System message placed in the prompt
        max_thinking_steps: Upper bound on thinking iterations
        compound_p: Top-p threshold for the compound vector
        sampling_p: Top-p threshold for the sampled token
        alpha: Weight of the compound vector in the blend
        window: Number of most recent steps averaged for the stop test
        stop_threshold: Stop thinking once the windowed average is below this
        max_new_tokens: Generation budget for the final answer
        seed: Optional seed for ``torch.manual_seed`` to make sampling repeatable

    Returns:
        full_text: Decoded prompt + sampled thinking tokens + generated answer

    Raises:
        ValueError: If ``window`` is smaller than 1
    """
    if window < 1:
        raise ValueError("window must be >= 1")
    if seed is not None:
        torch.manual_seed(seed)

    # ------------------------------------------------------------------ #
    #  SET-UP                                                            #
    # ------------------------------------------------------------------ #
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, torch_dtype="auto", device_map="auto"
    )

    model_dtype = next(model.parameters()).dtype
    device      = next(model.parameters()).device
    embeddings  = model.model.embed_tokens          # convenience handle

    print(f"Model dtype: {model_dtype}, Device: {device}")

    # prompt text ------------------------------------------------------- #
    prompt = (
        f"<|system|>\n{sys_prompt}<|end|>\n"
        f"<|user|>\n{user_text}<|end|>\n"
        f"<|assistant|>\n### Reasoning:\n"
    )

    # encode once ------------------------------------------------------- #
    input_ids   = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    current_ids = input_ids.clone()                 # token view of the sequence

    perplexities = []
    step = 0

    # ------------------------------------------------------------------ #
    #  “THINKING” LOOP                                                   #
    # ------------------------------------------------------------------ #
    # Everything here is inference only. Keeping the embedding lookups and the
    # blending under no_grad too prevents an autograd graph from being attached
    # to the growing embedding sequence.
    with torch.no_grad():
        current_embeds = embeddings(current_ids).to(model_dtype)

        while step < max_thinking_steps:
            step += 1
            print(f"\n🧠  Thinking step {step} …")

            # forward pass on the *embedding sequence*
            outputs = model(inputs_embeds=current_embeds)

            # --- build compound & thinking vectors -------------------- #
            compound_vec = create_compound_vector(outputs, embeddings, model_dtype, compound_p)
            sampled_tok  = sample_token_normally(outputs, sampling_p)   # one ID, shape [1]
            thinking_vec = create_thinking_vector(
                compound_vec, sampled_tok, embeddings, model_dtype, alpha=alpha
            )

            # --- update sequences ------------------------------------ #
            current_embeds = torch.cat([current_embeds, thinking_vec], dim=1)
            current_ids    = torch.cat([current_ids, sampled_tok.reshape(1, -1)], dim=1)

            # --- confidence proxy ------------------------------------ #
            # 1 / (highest next-token probability): 1.0 means fully confident,
            # larger values mean a flatter distribution. Computed in float32.
            last_logits   = outputs.logits[:, -1, :].float()
            max_prob      = torch.softmax(last_logits, dim=-1).max()
            inv_perplex   = 1.0 / max_prob.item()
            perplexities.append(inv_perplex)

            print(f"Inverse-perplexity: {inv_perplex:.3f}")

            if len(perplexities) >= window:
                recent_avg = sum(perplexities[-window:]) / window
                print(f"Avg over last {window}: {recent_avg:.3f}")
                if recent_avg < stop_threshold:
                    print("✅  Condition met – stop thinking.")
                    break
        else:
            print("\n⚠️  Hit maximum thinking steps – moving on anyway.")

    # ------------------------------------------------------------------ #
    #  SWITCH TO NORMAL GENERATION                                       #
    # ------------------------------------------------------------------ #
    # NOTE(review): generation below runs on token IDs, so the blended
    # embeddings built during thinking are not part of the generation context —
    # only the sampled tokens are. Confirm this is the intended design.

    # End-of-reasoning marker. convert_tokens_to_ids only works when "###" is a
    # single vocabulary piece; otherwise fall back to tokenising the text.
    reasoning_end_id = tokenizer.convert_tokens_to_ids("###")
    if reasoning_end_id is None or reasoning_end_id == tokenizer.unk_token_id:
        reasoning_end_ids = tokenizer(
            "###", add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)
    else:
        reasoning_end_ids = torch.tensor([[reasoning_end_id]], device=device)

    answer_start_text = "### Answer:\n"
    answer_start_ids  = tokenizer(answer_start_text, add_special_tokens=False,
                                  return_tensors="pt").input_ids.to(device)

    current_ids = torch.cat(
        [current_ids, reasoning_end_ids, answer_start_ids],
        dim=1
    )

    # let the model finish naturally; there is no padding, so every position
    # is attended to
    generated_ids = model.generate(
        input_ids=current_ids,
        attention_mask=torch.ones_like(current_ids),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=1.0,
        pad_token_id=tokenizer.eos_token_id
    )

    full_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    print("\n📝  FINAL OUTPUT\n" + "-"*60 + "\n")
    print(full_text)

    return full_text



# The main innovation: preserving model's "thoughts" rather than discarding them
if __name__ == "__main__":
    main_thinking_loop()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model dtype: torch.bfloat16, Device: cuda:0

🧠  Thinking step 1 …


You are not running the flash-attention implementation, expect numerical differences.


Inverse-perplexity: 1.574

🧠  Thinking step 2 …
Inverse-perplexity: 2.321

🧠  Thinking step 3 …
Inverse-perplexity: 6.370

🧠  Thinking step 4 …
Inverse-perplexity: 1.294

🧠  Thinking step 5 …
Inverse-perplexity: 1.057

🧠  Thinking step 6 …
Inverse-perplexity: 1.778
Avg over last 6: 2.399

🧠  Thinking step 7 …
Inverse-perplexity: 3.908
Avg over last 6: 2.788

🧠  Thinking step 8 …
Inverse-perplexity: 1.275
Avg over last 6: 2.614

🧠  Thinking step 9 …
Inverse-perplexity: 1.000
Avg over last 6: 1.719

🧠  Thinking step 10 …
Inverse-perplexity: 2.924
Avg over last 6: 1.991

🧠  Thinking step 11 …
Inverse-perplexity: 1.000
Avg over last 6: 1.981

🧠  Thinking step 12 …
Inverse-perplexity: 2.446
Avg over last 6: 2.092

🧠  Thinking step 13 …
Inverse-perplexity: 1.381
Avg over last 6: 1.671

🧠  Thinking step 14 …
Inverse-perplexity: 1.659
Avg over last 6: 1.735

🧠  Thinking step 15 …
Inverse-perplexity: 1.820
Avg over last 6: 1.872

🧠  Thinking step 16 …
Inverse-perplexity: 1.868
Avg over last 6: 