<a href="https://colab.research.google.com/github/gut-puncture/Compound_Embedding_Reasoning/blob/main/Compound_Embedding_Reasoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 1️⃣ Mount your Drive so Colab sees it as a local folder.
from google.colab import drive
drive.mount('/content/drive')
# 2️⃣ Define where you want to store the model weights *permanently*.
MODEL_DIR = "/content/drive/MyDrive/phi3_3.8B"


Mounted at /content/drive


In [2]:
# 3️⃣ Install the libraries we'll need.
!pip install --upgrade "transformers==4.41.2" "huggingface_hub>=0.23.0" "accelerate>=0.29.0" sentencepiece

Collecting transformers==4.41.2
  Downloading transformers-4.41.2-py3-none-any.whl.metadata (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.8/43.8 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface_hub>=0.23.0
  Downloading huggingface_hub-0.32.2-py3-none-any.whl.metadata (14 kB)
Collecting accelerate>=0.29.0
  Downloading accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers==4.41.2)
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting hf-xet<2.0.0,>=1.1.2 (from huggingface_hub>=0.23.0)
  Downloading hf_xet-1.1.2-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate>=0.29.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

torch.set_printoptions(precision=16, sci_mode=False)

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype="auto",                 # Uses float16 on GPU, float32 on CPU.
    device_map="auto"                   # transformers + accelerate decide the best device.
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
user_text = "What is Photosynthesis?" #will be populated by the eval questions

In [5]:
reasoning_start_tokens = "### Reasoning:\n"
reasoning_end_tokens = "###"
answer_start_tokens = "### Answer:\n"
sys_prompt = "You are a helpful assistant to a human. You will think deeply about any user request and asnwer as smartly as possible."
prompt = (
  f"<|system|>\n{sys_prompt}<|end|>\n"
  f"<|user|>\n{user_text}<|end|>\n"
  f"<|assistant|>\n### Reasoning:\n"
        )

In [15]:
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to('cuda')
with torch.no_grad():
    outputs = model(inputs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

In [21]:
sorted_logits, sorted_indices = torch.sort(outputs.logits[:,-1,:], descending=True) #sorting the logits so we can do top-p sampling
sorted_probs = torch.softmax(sorted_logits, dim=-1) #converted sorted logits into sorted probs
cumulative_probs = torch.cumsum(sorted_probs, dim=-1) #doing a cumulative sum of probs so we can identify when the top-p sampling cut-off is reached

tensor([[8.8912e-10, 3.7806e-08, 2.1329e-09,  ..., 2.4974e-12, 2.4974e-12,
         2.4974e-12]], device='cuda:0')

In [87]:
#sampling only those tokens which have a combined probs of p
p = 0.98
selected_token_indices = []

for token in range(len(sorted_indices.tolist()[0])):
  if cumulative_probs.tolist()[0][token] < p:
    selected_token_indices.append(sorted_indices.tolist()[0][token]) #token indices are actually token ids as well
  else:
    break
print(selected_token_indices)

[13, 1762, 29896, 29899, 1576, 4819, 29902, 797]


In [147]:
selected_token_probs = sorted_probs[:,:len(selected_token_indices)].tolist()[0] #selecting the token probs for the selected token ids
selected_token_logits = sorted_logits[:,:len(selected_token_indices)].tolist()[0] #selecting the token logits for the selected token ids

In [95]:
#Getting embeddings of the selected tokens

embeddings = model.model.embed_tokens #method to get token embeddings
selected_token_indices_tensor = torch.tensor(selected_token_indices, dtype=torch.long).to('cuda') #converted list to tensor
selected_token_embeddings = embeddings(selected_token_indices_tensor)

In [155]:
#renormalising probs

selected_token_renormalised_probs = torch.softmax(torch.tensor(selected_token_logits), dim=-1)

In [170]:
print(selected_token_renormalised_probs.dtype)
print(selected_token_embeddings.dtype)


torch.float32
torch.float32


In [168]:
selected_token_renormalised_probs

tensor([0.6419127583503723, 0.1839110851287842, 0.0868734419345856,
        0.0410361066460609, 0.0193840861320496, 0.0117570422589779,
        0.0117570422589779, 0.0033684489317238], device='cuda:0')

In [169]:
selected_token_embeddings

tensor([[-0.0029144287109375,  0.0043334960937500, -0.0016937255859375,
          ...,  0.0112304687500000,  0.1025390625000000,
         -0.0136108398437500],
        [-0.0273437500000000, -0.0032501220703125,  0.0208740234375000,
          ...,  0.0266113281250000,  0.0153808593750000,
          0.0197753906250000],
        [ 0.0290527343750000, -0.0227050781250000, -0.0179443359375000,
          ..., -0.0045471191406250, -0.0898437500000000,
         -0.0132446289062500],
        ...,
        [-0.0598144531250000, -0.0527343750000000,  0.0429687500000000,
          ..., -0.0388183593750000, -0.0747070312500000,
          0.0815429687500000],
        [ 0.0022888183593750,  0.0036010742187500, -0.0252685546875000,
          ...,  0.0610351562500000, -0.0035400390625000,
          0.0088500976562500],
        [ 0.0158691406250000, -0.0186767578125000, -0.0071716308593750,
          ...,  0.0034637451171875, -0.0742187500000000,
         -0.0240478515625000]], device='cuda:0', grad_fn=<

In [167]:
selected_token_embeddings * (selected_token_renormalised_probs).unsqueeze(-1).to('cuda')

tensor([[    -0.0018708090065047,      0.0027817264199257,
             -0.0010872241109610,  ...,
              0.0072089810855687,      0.0658211335539818,
             -0.0087369717657566],
        [    -0.0050288187339902,     -0.0005977334803902,
              0.0038389642722905,  ...,
              0.0048941182903945,      0.0028287104796618,
              0.0036369136068970],
        [     0.0025239109527320,     -0.0019724683370441,
             -0.0015588862588629,  ...,
             -0.0003950239042751,     -0.0078050359152257,
             -0.0011506065493450],
        ...,
        [    -0.0007032410358079,     -0.0006200002972037,
              0.0005051853950135,  ...,
             -0.0004563890979625,     -0.0008783336961642,
              0.0009587041568011],
        [     0.0000269097345154,      0.0000423379824497,
             -0.0002970834611915,  ...,
              0.0007175928913057,     -0.0000416203874920,
              0.0001040509741870],
        [     0.000053

In [171]:
0.6419127583503723*-0.0029144287109375

-0.0018708089728534105

In [161]:
# prompt: bro I manually multiplied the numbers to check selected_token_embeddings * (selected_token_renormalised_probs).unsqueeze(-1).to('cuda') and I got -0.0018708089728534105 instead of -0.0018708090065047. what's the issue?
# The issue seems to be the difference in dtype. Can we make both the tensors the same datatype with the largest precision. The two data types are torch.bfloat16(selected_token_embeddings) and torch.float32(selected_token_renormalised_probs).

selected_token_renormalised_probs = torch.softmax(torch.tensor(selected_token_logits, dtype=torch.float32), dim=-1).to('cuda')
selected_token_embeddings = selected_token_embeddings.to(torch.float32)
print(selected_token_embeddings * (selected_token_renormalised_probs).unsqueeze(-1))


tensor([[    -0.0018708090065047,      0.0027817264199257,
             -0.0010872241109610,  ...,
              0.0072089810855687,      0.0658211335539818,
             -0.0087369717657566],
        [    -0.0050288187339902,     -0.0005977334803902,
              0.0038389642722905,  ...,
              0.0048941182903945,      0.0028287104796618,
              0.0036369136068970],
        [     0.0025239109527320,     -0.0019724683370441,
             -0.0015588862588629,  ...,
             -0.0003950239042751,     -0.0078050359152257,
             -0.0011506065493450],
        ...,
        [    -0.0007032410358079,     -0.0006200002972037,
              0.0005051853950135,  ...,
             -0.0004563890979625,     -0.0008783336961642,
              0.0009587041568011],
        [     0.0000269097345154,      0.0000423379824497,
             -0.0002970834611915,  ...,
              0.0007175928913057,     -0.0000416203874920,
              0.0001040509741870],
        [     0.000053

In [152]:
selected_token_renormalised_probs

tensor([0.6419127583503723, 0.1839110851287842, 0.0868734419345856,
        0.0410361066460609, 0.0193840861320496, 0.0117570422589779,
        0.0117570422589779, 0.0033684489317238])

In [157]:
select_token_embeddings

tensor([[-0.0029144287109375,  0.0043334960937500, -0.0016937255859375,
          ...,  0.0112304687500000,  0.1025390625000000,
         -0.0136108398437500],
        [-0.0273437500000000, -0.0032501220703125,  0.0208740234375000,
          ...,  0.0266113281250000,  0.0153808593750000,
          0.0197753906250000],
        [ 0.0290527343750000, -0.0227050781250000, -0.0179443359375000,
          ..., -0.0045471191406250, -0.0898437500000000,
         -0.0132446289062500],
        ...,
        [-0.0598144531250000, -0.0527343750000000,  0.0429687500000000,
          ..., -0.0388183593750000, -0.0747070312500000,
          0.0815429687500000],
        [ 0.0022888183593750,  0.0036010742187500, -0.0252685546875000,
          ...,  0.0610351562500000, -0.0035400390625000,
          0.0088500976562500],
        [ 0.0158691406250000, -0.0186767578125000, -0.0071716308593750,
          ...,  0.0034637451171875, -0.0742187500000000,
         -0.0240478515625000]], device='cuda:0', dtype=tor

In [150]:
#multiplying vector embeddings for the selected tokens and multiplying them with the corresponding probs

multiplied_embeddings = selected_token_embeddings * torch.tensor(selected_token_renormalised_probs).unsqueeze(-1).to('cuda')

  multiplied_embeddings = selected_token_embeddings * torch.tensor(selected_token_renormalised_probs).unsqueeze(-1).to('cuda')


In [129]:
# prompt: I have the embeddings of tokens in the select_token_embeddings tensor and I want the embeddings to be multiplied by the corresponding number in the selected_token_probs list.



In [124]:
select_token_embeddings.shape

torch.Size([8, 3072])

In [139]:
# prompt: I want to do this: torch.tensor(selected_token_probs).unsqueeze(-1).to('cuda'), but I need all the decimal points present in the original list to be present in the tensor.
# when I apply dtype=torch.float32 I get the value 0.6276 in the first element even though it is 0.6276402473449707 in the original list.

# Use dtype=torch.float64 to preserve more precision
multiplied_embeddings = selected_token_embeddings * torch.tensor(selected_token_probs, dtype=torch.float64).unsqueeze(-1).to('cuda')

In [140]:


multiplied_embeddings

tensor([[    -0.0018292127570021,      0.0027198765601497,
             -0.0010630503456923,  ...,
              0.0070486941840500,      0.0643576425500214,
             -0.0085427108861040],
        [    -0.0049170065321960,     -0.0005844432987487,
              0.0037536076652032,  ...,
              0.0047853010000836,      0.0027658161743602,
              0.0035560493670346],
        [     0.0024677936289663,     -0.0019286118276796,
             -0.0015242254767145,  ...,
             -0.0003862408095756,     -0.0076314962643664,
             -0.0011250235661464],
        ...,
        [    -0.0006876050042592,     -0.0006062150241632,
              0.0004939529826515,  ...,
             -0.0004462416150091,     -0.0008588046175646,
              0.0009373880466228],
        [     0.0000263114159793,      0.0000413966278074,
             -0.0002904780324116,  ...,
              0.0007016377594482,     -0.0000406949900480,
              0.0001017374751200],
        [     0.000052