<a href="https://colab.research.google.com/github/gut-puncture/Compound_Embedding_Reasoning/blob/main/Compound_Embedding_Reasoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Setup

In [2]:
# 1️⃣ Mount your Drive so Colab sees it as a local folder.
from google.colab import drive
drive.mount('/content/drive')
# 2️⃣ Define where you want to store the model weights *permanently*.
MODEL_DIR = "/content/drive/MyDrive/phi3_3.8B"


Mounted at /content/drive


In [3]:
# 3️⃣ Install the libraries we'll need.
!pip install --upgrade "transformers==4.41.2" "huggingface_hub>=0.23.0" "accelerate>=0.29.0" sentencepiece

Collecting transformers==4.41.2
  Downloading transformers-4.41.2-py3-none-any.whl.metadata (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.8/43.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface_hub>=0.23.0
  Downloading huggingface_hub-0.32.2-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers==4.41.2)
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting hf-xet<2.0.0,>=1.1.2 (from huggingface_hub>=0.23.0)
  Downloading hf_xet-1.1.2-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate>=0.29.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12

In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

torch.set_printoptions(precision=16, sci_mode=False)

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype="auto",                 # Uses float16 on GPU, float32 on CPU.
    device_map="auto"                   # transformers + accelerate decide the best device.
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

##Inference

In [206]:
user_text = "What is Photosynthesis?" #will be populated by the eval questions

In [207]:
reasoning_start_tokens = "### Reasoning:\n"
reasoning_end_tokens = "###"
answer_start_tokens = "### Answer:\n"
sys_prompt = "You are a helpful assistant to a human. You will think deeply about any user request and asnwer as smartly as possible."
prompt = (
  f"<|system|>\n{sys_prompt}<|end|>\n"
  f"<|user|>\n{user_text}<|end|>\n"
  f"<|assistant|>\n### Reasoning:\n"
        )

In [208]:
#generating next tokens
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to('cuda')
with torch.no_grad():
    outputs = model(inputs)

In [209]:
sorted_logits, sorted_indices = torch.sort(outputs.logits[:,-1,:], descending=True) #sorting the logits so we can do top-p sampling
sorted_probs = torch.softmax(sorted_logits, dim=-1) #converted sorted logits into sorted probs
cumulative_probs = torch.cumsum(sorted_probs, dim=-1) #doing a cumulative sum of probs so we can identify when the top-p sampling cut-off is reached

In [210]:
#sampling only those tokens which have a combined probs of p
p_compound_vector = 0.98
selected_token_indices = []

for token in range(len(sorted_indices.tolist()[0])):
  if cumulative_probs.tolist()[0][token] < p_compound_vector:
    selected_token_indices.append(sorted_indices.tolist()[0][token]) #token indices are actually token ids as well
  else:
    break
print(selected_token_indices)

[13, 1762, 29896, 29899, 1576, 4819, 29902, 797]


In [211]:
selected_token_probs = sorted_probs[:,:len(selected_token_indices)].tolist()[0] #selecting the token probs for the selected token ids
selected_token_logits = sorted_logits[:,:len(selected_token_indices)].tolist()[0] #selecting the token logits for the selected token ids

In [212]:
#Getting embeddings of the selected tokens

embeddings = model.model.embed_tokens #method to get token embeddings
selected_token_indices_tensor = torch.tensor(selected_token_indices, dtype=torch.long).to('cuda') #converted list to tensor
selected_token_embeddings = embeddings(selected_token_indices_tensor)

In [213]:
#renormalising probs
selected_token_renormalised_probs = torch.softmax(torch.tensor(selected_token_logits, dtype=torch.float32), dim=-1).to('cuda')
selected_token_embeddings = selected_token_embeddings.to(torch.float32) #bringing both to same precision

compound_embedding_vector = selected_token_embeddings * (selected_token_renormalised_probs).unsqueeze(-1).to('cuda')


In [214]:
compound_embedding_vector_summed = torch.sum(compound_embedding_vector, dim=0)
compound_embedding_vector_summed = compound_embedding_vector_summed.unsqueeze(0).unsqueeze(0)

In [215]:

#sampling probability is different from the p for compound vector.
#0.8 is a good value for top-p sampling and a higher value could lead to incoherent generation.
p = 0.80
sorted_probs_sampling = torch.softmax(outputs.logits[:, -1, :], dim=-1)
sorted_probs_sampling, sorted_indices_sampling = torch.sort(sorted_probs_sampling, descending=True)
cumulative_probs_sampling = torch.cumsum(sorted_probs_sampling, dim=-1)

# Find the indices where cumulative probability is less than p
# Adding a small epsilon to cumulative_probs to handle floating point inaccuracies and include the token that makes the cumulative sum >= p
indices_to_remove_sampling = cumulative_probs_sampling > p
# Keep at least one token
indices_to_remove_sampling[..., 0] = False

# Set the probability of the tokens to be removed to zero
sorted_probs_sampling[indices_to_remove_sampling] = 0

# Renormalize the remaining probabilities
sorted_probs_sampling /= sorted_probs_sampling.sum(dim=-1, keepdim=True)

# Sample a token from the remaining probabilities
sampled_token_index = torch.multinomial(sorted_probs_sampling, num_samples=1)
sampled_token_id = sorted_indices_sampling[0, sampled_token_index]

print(f"Sampled Token ID: {sampled_token_id.item()}")
sampled_token = tokenizer.decode(sampled_token_id.item())


Sampled Token ID: 13


In [216]:
compound_embedding_vector_summed


tensor([[[-0.0055628521367908, -0.0006621304783039,  0.0024661654606462,
           ...,  0.0114772515371442,  0.0603033900260925,
          -0.0042509892955422]]], device='cuda:0',
       grad_fn=<UnsqueezeBackward0>)

In [217]:
alpha=0.25

# introducing compound vector into model
sampled_token_embedding_top_p = embeddings(sampled_token_id)

#summing sample token embedding with compound vector
thinking_advance_vector_embedding = (1-alpha)*sampled_token_id + alpha*compound_embedding_vector_summed

In [218]:
thinking_advance_vector_embedding = thinking_advance_vector_embedding.to(model.dtype) #ensuring correct datatype

In [220]:
prompt_embeddings = embeddings(inputs) #the embedding vectors for each token in the prompt

#our thinking advancement vector added after the prompt vectors as if our vector corresponds to the next token
combined_embeddings = torch.cat((prompt_embeddings, thinking_advance_vector_embedding), dim=1)

In [225]:
original_attention_mask = torch.ones(inputs.shape, dtype=torch.long).to('cuda')

# Extend the attention mask by adding a column of ones for your vector's position
new_column_mask = torch.ones((original_attention_mask.shape[0], 1), dtype=torch.long).to('cuda')

In [226]:
combined_attention_mask = torch.cat((original_attention_mask, new_column_mask), dim=1)

In [256]:
combined_attention_mask.dtype

torch.int64

In [257]:

seq_length = combined_attention_mask.shape[1]
causal_mask = torch.tril(torch.ones((seq_length, seq_length), dtype=torch.bool, device=combined_attention_mask.device))

expanded_padding_mask = combined_attention_mask.bool().unsqueeze(1).unsqueeze(2).expand(-1, 1, seq_length, -1)

final_attention_mask = expanded_padding_mask & causal_mask.unsqueeze(0).unsqueeze(0)

attention_mask_float = torch.where(
    final_attention_mask,
    torch.zeros(1, dtype=hidden_states.dtype, device=hidden_states.device),
    torch.full([], torch.finfo(hidden_states.dtype).min, device=hidden_states.device),
)


# Ensure position_ids matches the current sequence length
current_sequence_length = hidden_states.shape[1]
position_ids = torch.arange(0, current_sequence_length, dtype=torch.long, device=hidden_states.device).unsqueeze(0)


In [258]:
transformer_layers = model.model.layers
hidden_states = combined_embeddings

In [259]:
for layer in transformer_layers:
    reshaped_attention_mask = combined_attention_mask.unsqueeze(1).unsqueeze(1)
    layer_output = layer(
        hidden_states,
        attention_mask=attention_mask_float,
        position_ids=position_ids)
    hidden_states = layer_output[0]

In [260]:

vector_output_hidden_state = hidden_states[:, -1:, :]
# shape: [batch_size, 1, hidden_size]

# Pass this hidden state through the language model head to get logits.
# The language model head is often `model.lm_head`.
logits_for_vector_position = model.lm_head(vector_output_hidden_state)

In [261]:
logits_for_vector_position

tensor([[[159.0000000000000000, 228.0000000000000000, 149.0000000000000000,
           ...,  69.0000000000000000,  68.5000000000000000,
           68.5000000000000000]]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<UnsafeViewBackward0>)

### Don't use

In [None]:
generated_token_ids = []
# Start with the combined embeddings and attention mask after injecting your vector
current_embeddings = combined_embeddings
current_attention_mask = combined_attention_mask
current_sequence_length = combined_embeddings.shape[1]

# Number of tokens to generate using the vector-based approach
num_tokens_to_generate_with_vector = 5 # Example: generate 5 tokens

for _ in range(num_tokens_to_generate_with_vector):
    # Pass current embeddings through transformer layers
    # We need to be careful about how past_key_values are handled if we want efficiency.
    # For simplicity here, we re-process the entire sequence each time.
    # For better performance, you would use the `past_key_values` returned by the layer forward pass.

    # Let's do the inefficient full pass for demonstration:
    current_hidden_states = current_embeddings
    for layer in transformer_layers:
         position_ids = torch.arange(0, current_sequence_length, dtype=torch.long, device='cuda').unsqueeze(0)

         layer_output = layer(current_hidden_states,
                              attention_mask=current_attention_mask,
                              position_ids=position_ids)
         current_hidden_states = layer_output[0]
         # If we were using `use_cache=True`, layer_output would also contain `past_key_value`.

    # Get logits for the last position
    logits = model.lm_head(current_hidden_states[:, -1:, :]) # shape: [batch_size, 1, vocab_size]

    # Sample the next token (e.g., greedy sampling or sampling with temperature/top-p)
    next_token_id = torch.argmax(logits, dim=-1) # Greedy sampling example
    generated_token_ids.append(next_token_id.item())

    # Get the embedding of the sampled token
    next_token_embedding = embeddings(next_token_id) # shape: [batch_size, 1, hidden_size]

    # Concatenate the current embeddings with the new token embedding
    current_embeddings = torch.cat((current_embeddings, next_token_embedding), dim=1)

    # Extend the attention mask
    next_mask_column = torch.ones((current_attention_mask.shape[0], 1), dtype=torch.long).to('cuda')
    current_attention_mask = torch.cat((current_attention_mask, next_mask_column), dim=1)

    # Update sequence length
    current_sequence_length += 1

print("\nGenerated token IDs using vector injection:")
print(generated_token_ids)
print("\nGenerated text:")
print(tokenizer.decode(generated_token_ids))

# When you want to stop providing your vector and revert to normal generation,
# you would simply continue generating using the standard `model.generate()` method,
# providing the sequence generated so far (prompt + tokens generated from your vector)
# as the input.

# Example of continuing with standard generation:
# full_generated_sequence = inputs[0].tolist() + generated_token_ids
# full_input_tensor = torch.tensor([full_generated_sequence], dtype=torch.long).to('cuda')

# # Generate more tokens normally
# print("\nContinuing with standard generation:")
# output_ids = model.generate(full_input_tensor,
#                             max_length=len(full_generated_sequence) + 20, # Generate 20 more tokens
#                             num_return_sequences=1)

# print("\nFull generated text:")
# print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

# Note: This manual loop is less efficient than using the model's `generate` method
# with `use_cache=True`, as it re-computes attention and feed-forward for the entire
# sequence at each step. Implementing the `past_key_values` handling for efficiency
# would make this code much more complex. For a practical application, integrating
# this vector injection into a custom generation loop that uses `past_key_values`
# would be necessary for speed.
