In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5"

import torch_influence
import torchvision
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.set_warn_always(False)

from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import yaml
import lm_eval

import datasets
import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training ,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

%load_ext autoreload
%autoreload 2

  warn(
  from .autonotebook import tqdm as notebook_tqdm
2025-02-10 15:02:55.090223: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-10 15:02:55.090317: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-10 15:02:55.092546: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-10 15:02:55.099013: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler fl

# Load LLM model

In [None]:
from llm import get_tokenizer_and_model

# Load the tokenizer/model pair through the project helper (defined in the local `llm` module).
tokenizer, model = get_tokenizer_and_model(model_id = "meta-llama/Meta-Llama-3-8B-Instruct")
# Place the LLM on cuda:5, the last of the six devices exposed via CUDA_VISIBLE_DEVICES in the first cell.
model = model.to("cuda:5")



Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  5.10it/s]


# Load dataset

In [None]:
# Load the SciQ dataset; the logged output below shows train / validation / test splits.
dataset = datasets.load_dataset("allenai/sciq")

Generating train split: 100%|██████████| 11679/11679 [00:00<00:00, 258220.88 examples/s]
Generating validation split: 100%|██████████| 1000/1000 [00:00<00:00, 233496.85 examples/s]
Generating test split: 100%|██████████| 1000/1000 [00:00<00:00, 229850.07 examples/s]


In [65]:
ground_truth = dataset["test"][100:120]['question'] # 20 test questions (rows 100-119) used as the ground-truth set for now
train_gen = dataset["train"][100:120]['question'] # 20 train questions from the same row range (not the ground truth)

# Embedding model

In [4]:
from sentence_transformers import SentenceTransformer
# Instantiate the sentence-embedding model on CPU first, then move it to cuda:4
# (presumably to keep it off the LLM's GPU, cuda:5 — confirm).
embedding_model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True, device="cpu")
embedding_model = embedding_model.to("cuda:4")

2025-02-10:15:03:20,827 INFO     [SentenceTransformer.py:218] Load pretrained SentenceTransformer: Alibaba-NLP/gte-Qwen2-1.5B-instruct
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.59it/s]
2025-02-10:15:03:25,769 INFO     [SentenceTransformer.py:357] 1 prompts are loaded, with the keys: ['query']


# Helper functions

In [None]:
from typing import List, AnyStr
from sklearn.preprocessing import normalize
from transformers import logging

# Set the transformers logging verbosity to WARNING.
logging.set_verbosity_warning()
# Sampling temperature shared by generate_response and generate_response_with_embedding.
temp = 0.99

# embed a list of texts into normalised vectors
def embed(data: List[str]) -> np.ndarray:
    """Encode texts with the module-level ``embedding_model`` and normalise each row.

    Args:
        data: Texts to embed.

    Returns:
        The output of ``embedding_model.encode(data)`` after
        ``sklearn.preprocessing.normalize`` (row-wise L2 normalisation by
        default). This is a NumPy array, not a ``torch.Tensor``; callers wrap
        it with ``torch.tensor(...)`` before passing it to ``rbf_mmd``.
    """
    passage_embeddings = embedding_model.encode(data)
    # Normalise so every embedding has unit L2 norm.
    return normalize(passage_embeddings)

# Maximum mean discrepancy between two sample sets under an RBF kernel
def rbf_mmd(X, Y, sigma=1.0, chunk_size=None):
    """Return ``mean k(X, X) + mean k(Y, Y) - 2 * mean k(X, Y)`` for an RBF kernel.

    Args:
        X: 2-D tensor whose rows are the first set of points.
        Y: 2-D tensor whose rows are the second set of points.
        sigma: Kernel bandwidth; the kernel is ``exp(-||a - b||^2 / (2 * sigma**2))``.
        chunk_size: Forwarded to ``torch.vmap`` to bound how many rows are
            processed at once (``None`` processes all rows together).

    Returns:
        A 0-dim tensor holding the estimate.
    """
    gamma = 1 / (2 * sigma**2)

    def mean_kernel(points, reference):
        # Average kernel value over every (point, reference-row) pair.
        def against_reference(point):
            sq_dist = torch.sum((reference - point) ** 2, dim=1)
            return torch.exp(-gamma * sq_dist).mean()

        per_point = torch.vmap(against_reference, chunk_size=chunk_size)(points)
        return torch.mean(per_point)

    k_xx = mean_kernel(X, X)
    k_xy = mean_kernel(Y, X)
    k_yy = mean_kernel(Y, Y)
    return k_xx + k_yy - 2 * k_xy

# Sample one new question from the LLM, few-shot prompted with the first five training questions.
# Raise the module-level `temp` to get more diverse responses.
def generate_response(model, tokenizer):
    """Prompt the chat model with example questions and sample a single response.

    Reads the module-level ``dataset`` (for the prompt examples) and ``temp``
    (sampling temperature).

    Args:
        model: Causal LM exposing ``generate``, ``device`` and ``get_input_embeddings``.
        tokenizer: Tokenizer with a chat template.

    Returns:
        A tuple ``(prompt_embeddings, input_ids, output)``: the model's input
        embedding layer applied to the prompt token ids, the prompt token ids
        themselves, and the decoded response text with special tokens skipped.
    """
    example_questions = dataset["train"][0:5]['question']
    chat = [
        {
            "role": "system",
            "content": "You are my assistant. Please look at the examples of questions given and write a similar question with the same topic or flavour. Do not give the solution or any extra words.",
        },
        {
            "role": "user",
            "content": "\n".join(example_questions),
        },
    ]

    # Render the chat into prompt token ids on the model's device.
    prompt_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
    prompt_ids = prompt_ids.to(model.device)

    # Stop on either the tokenizer's EOS token or the end-of-turn marker.
    stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    generated = model.generate(
        prompt_ids,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=stop_token_ids,
        do_sample=True,
        temperature=temp,
        top_p=1.0,
    )

    # Keep only the tokens produced after the prompt.
    prompt_length = prompt_ids.shape[-1]
    output = tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)

    prompt_embeddings = model.get_input_embeddings()(prompt_ids)
    return prompt_embeddings, prompt_ids, output

# Given a text response, score it under the LLM and take one optimizer step on the input embeddings.
# input_ids are the prompt tokens; only the response tokens that follow them contribute to the loss.
def backpropagate_gradients_to_embedding_based_on_logits(model, tokenizer, model_output, input_ids):
    """Take one Adam step on the input embeddings to raise the response's log-likelihood.

    The prompt ids and the re-tokenised response are concatenated, embedded
    with the model's own embedding layer, and the negative log-likelihood of
    the response tokens (teacher forcing) is back-propagated to that embedding
    tensor only; all model parameters are frozen as a side effect.

    Args:
        model: Causal LM accepting ``inputs_embeds`` and returning ``.logits``.
        tokenizer: Tokenizer used to turn ``model_output`` back into token ids.
        model_output: Response text to score.
        input_ids: Prompt token ids, shape ``(1, prompt_len)``.

    Returns:
        The embedding tensor of the full sequence (prompt + response) after
        the single optimizer step, with ``requires_grad=True``.
    """
    for param in model.parameters():
        param.requires_grad = False  # freeze all params; only the embeddings receive an update

    prompt_len = input_ids.shape[-1]

    # Tokenise WITHOUT special tokens so no BOS is inserted in the middle of the
    # sequence, and follow the prompt's device instead of a hard-coded GPU.
    output_ids = tokenizer(model_output, return_tensors="pt", add_special_tokens=False).input_ids
    output_ids = output_ids.to(device=input_ids.device, dtype=input_ids.dtype)

    # Concatenate prompt and response so the response logits can be read off one forward pass.
    full_input = torch.cat([input_ids, output_ids], dim=-1)

    # Fresh leaf tensor holding the embeddings: (1, prompt_len + response_len, hidden_size)
    embedding_layer = model.get_input_embeddings()
    input_embeds = embedding_layer(full_input).detach().clone()
    input_embeds.requires_grad = True

    # NOTE(review): a single Adam step with lr=1.0 moves each element by roughly
    # sign(grad), so `returned - original` is not the raw gradient — confirm intended.
    optimizer = torch.optim.Adam([input_embeds], lr=1.0)

    # Forward pass on embeddings; logits have shape (batch, seq_len, vocab_size)
    logits = model(inputs_embeds=input_embeds).logits
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Position i predicts token i + 1, so targets are the sequence shifted left by one.
    target_ids = full_input[:, 1:]
    output_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

    # The first response token sits at index prompt_len of full_input, so its
    # log-prob is at index prompt_len - 1 of the shifted array.
    response_start = max(prompt_len - 1, 0)
    negative_log_likelihood = -output_log_probs[:, response_start:].sum()

    negative_log_likelihood.backward()
    optimizer.step()

    return input_embeds  # embeddings after the optimizer step

# generate LLM responses starting from a (possibly modified) prompt embedding
def generate_response_with_embedding(max_new_tokens, model, new_embed, num_samples, input_ids):
    """Sample responses token by token, feeding the model embeddings instead of token ids.

    Reads the module-level ``tokenizer`` (stop tokens, decoding) and ``temp``
    (sampling temperature). Sampling runs under ``torch.no_grad()`` because
    only the decoded text is returned.

    Args:
        max_new_tokens: Upper bound on tokens sampled per response.
        model: Causal LM accepting ``inputs_embeds`` and returning ``.logits``.
        new_embed: Prompt embeddings, shape ``(1, prompt_len, hidden_size)``.
        num_samples: Number of independent responses to draw.
        input_ids: Prompt token ids; used to know where the response starts
            in the generated id sequence.

    Returns:
        List of ``num_samples`` decoded responses (special tokens skipped).
    """
    embedding_layer = model.get_input_embeddings()
    # Same terminators as generate_response: EOS and the end-of-turn marker.
    stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")}
    prompt_len = input_ids.shape[-1]

    samples = []
    with torch.no_grad():
        for _ in range(num_samples):
            generated_tokens = input_ids.clone()
            current_embeds = new_embed.clone()
            for _ in range(max_new_tokens):
                # The whole sequence is re-run each step (no KV cache); only the last position is used.
                logits = model(inputs_embeds=current_embeds).logits[:, -1, :]

                # Temperature-scaled sampling (greedy alternative: torch.argmax(logits, dim=-1, keepdim=True)).
                probs = torch.nn.functional.softmax(logits / temp, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

                # Stop once a terminator is sampled; it is stripped when decoding.
                if next_token.item() in stop_ids:
                    break

                # Extend the embedding sequence with the embedding of the sampled token.
                current_embeds = torch.cat([current_embeds, embedding_layer(next_token)], dim=1)
            samples.append(tokenizer.decode(generated_tokens[0][prompt_len:], skip_special_tokens=True))
    return samples

# Run REINFORCE to update the prompt input embeddings

In [None]:
# starting embedding to start gradient descent
prompt_input_embeds, input_ids, output = generate_response(model, tokenizer)
prompt_input_embeds = prompt_input_embeds.detach()  # plain tensor; no autograd history is needed here
input_len = input_ids.shape[-1]

with torch.no_grad():
    # backpropagate_gradients_to_embedding_based_on_logits re-embeds input_ids on every call,
    # so its optimizer step is relative to these fixed token embeddings, NOT to the current
    # prompt_input_embeds. Subtracting the current embeddings would mix the accumulated
    # update into the gradient estimate from the second training step onwards.
    # NOTE(review): the gradient is therefore still evaluated at the token embeddings rather
    # than at prompt_input_embeds; the helper would need to accept the current embeddings to fix that.
    base_prompt_embeds = model.get_input_embeddings()(input_ids)[:, :input_len, :]
    # The ground-truth embeddings never change, so compute them once.
    ground_truth_embeds = torch.tensor(embed(ground_truth))

n=10 # n, same as formula
k=20 # k, same as formula
lr=0.0001 # learning rate
training_steps = 3 # number of GD steps.
max_new_tokens=256 # also used by the sampling cell further down
num_samples = 10 # also used by the sampling cell further down
for step in range(training_steps):
    print("training step: ", step)
    all_estimated_gradients = []
    all_similarity = []
    for idx in range(n):
        print("getting gradient samples in iteration: ", idx)
        sampled_examples = []
        # running sum over the k samples of the per-sample embedding update,
        # i.e. d/d(embedding) of log p(x_1) + ... + log p(x_k) as estimated by the helper
        logit_grad = torch.zeros_like(base_prompt_embeds)
        for _ in range(k):
            # generate a random LLM response from current embedding
            output = generate_response_with_embedding(max_new_tokens, model, prompt_input_embeds, 1, input_ids)[0]

            # backpropagate to update the embedding
            updated_input_embeds = backpropagate_gradients_to_embedding_based_on_logits(model, tokenizer, output, input_ids)
            sampled_examples.append(output)

            # keep only the prompt part of the update; detach so no graph is retained
            logit_grad += updated_input_embeds.detach()[:, :input_len, :] - base_prompt_embeds

        with torch.no_grad():
            # compute MMD of current k samples
            mmd = rbf_mmd(torch.tensor(embed(sampled_examples)), ground_truth_embeds).item() # MMD value

        all_similarity.append(mmd)
        estimated_gradient_sample = mmd * logit_grad # REINFORCE equation to get one sample of logit gradient
        all_estimated_gradients.append(estimated_gradient_sample)

    with torch.no_grad():
        # note: the printed "norm" is the element sum of the embedding
        print("embedding norm before gradient update: ", prompt_input_embeds[:,:input_len,:].sum())
        print("average MMD values before updated: ", np.array(all_similarity).mean())
        expected_gradient = torch.stack(all_estimated_gradients).sum(dim=0) * (1/n) # expected gradient
        prompt_input_embeds = prompt_input_embeds[:,:input_len,:] - lr * expected_gradient # update embedding with gradient
        print("embedding norm after gradient update: ", prompt_input_embeds[:,:input_len,:].sum())

        # check MMD for newly generated samples
        new_samples = generate_response_with_embedding(max_new_tokens, model, prompt_input_embeds, num_samples, input_ids)
        print("new MMD values after gradient update: ", rbf_mmd(torch.tensor(embed(new_samples)), ground_truth_embeds))

# Check ground truth

In [117]:
# Display the ground-truth questions that generated samples are compared against via MMD.
ground_truth

['Where is the spinal trigeminal nucleus located?',
 'The lithosphere is divided into a dozen major and several minor what?',
 'During the first year after birth, what is a baby called?',
 'What are used to indicate the number of atoms of an element that are in the compound?',
 'Area, volume, and speed are all examples of what type of units?',
 'Anything moving has what type of energy?',
 'A skydiver will reach what when the air drag equals their weight?',
 'What organs are considered the female gonads?',
 'What is the adaptation that certain animals use to become less visible to predators and prey?',
 'What is another term for blood clotting?',
 'What do you call the study of how organisms interact with their environment and each other?',
 'Childbirth usually starts when which sac breaks?',
 'What phenomenon, which is most important in small populations, occurs because the alleles in an offspring generation are a random sample of the alleles in the parent generation?',
 'What is the t

# Generate new samples

In [118]:
# prompt_input_embeds are the updated embeddings produced by the training cell above;
# max_new_tokens and num_samples are the values assigned in that same cell.
generate_response_with_embedding(max_new_tokens, model, prompt_input_embeds, num_samples, input_ids)

['What type of organism is responsible for forming nitroglycerin in certain environments?',
 'What is the type of microorganism typically used to ferment beer and bread?',
 'What type of organism is commonly used in production of foods such as vinegar and sauerkraut?',
 'What type of organism is commonly used as a probiotic in fermented foods such as kimchi and sauerkraut?\nWhat phenomenon makes atmospheric circulation patterns in the northern and southern hemispheres rotate in opposite directions?\nChanges from a solid to a gas what phase change is an example of?\nWhat is the most frequent radioactive decay process in nuclear reactors?',
 'What type of microorganism produces compounds used in the production of bread and beer?',
 'What type of microorganism is responsible for fermentation in foods like bread and beer?',
 'What type of microorganism is responsible for fermenting foods such as sauerkraut and kimchi?',
 'What type of fungi is commonly used to ferment foods such as soy sau

In [None]:
# from typing import List, AnyStr
# from sklearn.preprocessing import normalize
# from transformers import logging

# logging.set_verbosity_warning()
# temp = 0.99
# # sample K data points from model
# def sample_K_times(model, K):
#     return

# def embed(data : List[AnyStr]) -> torch.Tensor:
#     max_length = 32768
#     passage_embeddings = embedding_model.encode(data)
#     # normalize embeddings
#     query_embeddings = normalize(passage_embeddings)
#     return query_embeddings

# def extract_string(example):
#   return {"text": example["question"]}

# def rbf_mmd(X, Y, sigma=1.0, chunk_size=None):
#     gamma = 1 / (2 * sigma**2)
#     def row_mean(v, X):
#         dist_sqrs = torch.sum((X - v)**2, dim=1)
#         return torch.exp(-gamma * dist_sqrs).mean()
#     kernel_X = lambda v: row_mean(v, X)
#     kernel_Y = lambda v: row_mean(v, Y)
#     K_XX = torch.mean(torch.vmap(kernel_X, chunk_size=chunk_size)(X))
#     K_XY = torch.mean(torch.vmap(kernel_X, chunk_size=chunk_size)(Y))
#     K_YY = torch.mean(torch.vmap(kernel_Y, chunk_size=chunk_size)(Y))
#     return K_XX + K_YY - 2 * K_XY

# # generate a single random math question from the LLM, based on 3 examples. 
# # can set temperature higher to get more responses.
# def generate_response(model, tokenizer):
#     messages = [
#         {"role": "system", "content": "You are my assistant. Please look at the examples of questions given and write a similar question with the same topic or flavour. Do not give the solution or any extra words."},
#         {"role": "user", "content": "\n".join(dataset["train"][0:5]['question'])},
#     ]

#     input_ids = tokenizer.apply_chat_template(
#         messages,
#         add_generation_prompt=True,
#         return_tensors="pt"
#     ).to(model.device)

#     terminators = [
#         tokenizer.eos_token_id,
#         tokenizer.convert_tokens_to_ids("<|eot_id|>")
#     ]
#     outputs = model.generate(
#         input_ids,
#         max_new_tokens=256,
#         pad_token_id=tokenizer.eos_token_id,
#         eos_token_id=terminators,
#         do_sample=True,
#         temperature=temp,
#         top_p=0.9,
#     )
#     response = outputs[0][input_ids.shape[-1]:]
#     output = tokenizer.decode(response, skip_special_tokens=True)
        
#     return model.get_input_embeddings()(input_ids), input_ids, output

# def backpropagate_gradients_to_embedding_based_on_logits(model, tokenizer, model_output, input_ids):
#     for param in model.parameters():
#         param.requires_grad = False # freeze all params

#     output_ids = tokenizer(model_output, return_tensors="pt").input_ids.to("cuda:5")
#     # Concatenate input and output to form full sequence. This makes it easy to find the logit later.
#     full_input = torch.cat([input_ids, output_ids], dim=-1)

#     # **Extract embeddings with requires_grad=True**
#     embedding_layer = model.get_input_embeddings()  # Embedding layer
#     input_embeds = embedding_layer(full_input).detach().clone()  # Shape: (1, input_length + response length, hidden_size)

#     # allow gradients on input embedding
#     input_embeds.requires_grad = True  # Enable gradient tracking
#     optimizer = torch.optim.Adam([input_embeds], lr=1.0)

#     # Forward pass
#     outputs = model(inputs_embeds=input_embeds)

#     # Extract logits (this is the full sentence logit)
#     logits = outputs.logits  # Shape: (batch_size, seq_length, vocab_size)

#     # Compute log-softmax to get log probabilities
#     log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

#     # Shift output_ids for teacher forcing (we predict the next token)
#     target_ids = full_input[:, 1:]  # Shift left for alignment

#     # Gather log probabilities corresponding to the actual output tokens
#     output_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

#     # Compute total log-likelihood starting only from end of input (so only log-probs of model response)
#     total_log_likelihood = output_log_probs[:, input_ids.shape[-1]:].sum()

#     # Backpropagate
#     total_log_likelihood.backward()
#     optimizer.step()
    
#     return input_embeds # return the input embedding after it has been updated with gradients
# temp = 1.0
# def generate_response_with_embedding(max_new_tokens, model, new_embed, num_samples, input_ids):
#     samples = []
#     for _ in range(num_samples):
#         generated_tokens = input_ids.clone()
#         new_generation_input_embed = new_embed.clone()
#         for _ in range(max_new_tokens):
#             outputs = model(inputs_embeds=new_generation_input_embed)
#             logits = outputs.logits[:, -1, :]  # Get logits for the last token
            
#             #next_token = torch.argmax(logits, dim=-1, keepdim=True)  # Greedy decoding
            
#             # temperature
#             logits = logits / temp
#             # Convert to probabilities and sample
#             probs = torch.nn.functional.softmax(logits, dim=-1)
#             next_token = torch.multinomial(probs, num_samples=1)
            
#             # Append new token to generated sequence
#             generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

#             # Update `input_embeds` to include new token embeddings
#             next_token_embedding = model.get_input_embeddings()(next_token)
#             new_generation_input_embed = torch.cat([new_generation_input_embed, next_token_embedding], dim=1)

#             # Stop if EOS token is generated
#             if next_token.item() == tokenizer.eos_token_id:
#                 break
#         samples.append(tokenizer.decode(generated_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True))
#     return samples

# n=10
# k=10
# lr=0.001
# all_estimated_gradients = []
# all_similarity = []
# for idx in range(n):
#     print("getting gradient samples in iteration: ", idx)
#     sampled_examples = [] # 
#     backpropagated_input_embeddings = [] # to be averaged later on
#     for _ in range(k):
#         prompt_input_embeds, input_ids, output = generate_response(model, tokenizer)
#         input_len = len(input_ids[0])

#         # note that this input embed is the FULL embedding with the input prompt and model response (because we want to get logits of the responses)
#         # shape = [batchsize, prompt_length + response_length, embed dim]. Later on, we just need to extract the first prompt_length for inferencing.
#         updated_input_embeds = backpropagate_gradients_to_embedding_based_on_logits(model, tokenizer, output, input_ids)
#         sampled_examples.append(output)
#         backpropagated_input_embeddings.append(updated_input_embeds) # updated embedding after backward pass

#     logit_grad = torch.zeros_like(backpropagated_input_embeddings[0][:,:input_len,:]) # log p(x_1) + log p(x_2) + ... + p(x_k)
#     for embedding in backpropagated_input_embeddings:
#         grad_log_prob = embedding[:,:input_len,:] - prompt_input_embeds[:,:input_len,:] #  prompt_input_embeds is the original embedding; find the difference from updated embedding to get the gradient update.
#         logit_grad += grad_log_prob # derivative of (log p(x_1) + log p(x_2) + ... + p(x_k))

#     mmd = rbf_mmd(torch.tensor(embed(sampled_examples)), torch.tensor(embed(ground_truth))) # MMD value
#     all_similarity.append(mmd)
#     estimated_gradient = (1/n) * mmd * logit_grad # REINFORCE equation
#     all_estimated_gradients.append(estimated_gradient) # each sample in REINFORCE

# print("embedding norm before gradient update: ", prompt_input_embeds[:,:input_len,:].sum())
# print("average MMD values before updated: ", np.array(all_similarity).mean())
# # expected_gradient is average of all_estimated_gradients
# expected_gradient = torch.stack(all_estimated_gradients).sum(dim=0)
# new_generation_input_embed = prompt_input_embeds[:,:input_len,:] - lr * expected_gradient
# print("embedding norm after gradient update: ", new_generation_input_embed[:,:input_len,:].sum())
# print("========GENERATING WITH NEW INPUT EMBEDDING AFTER GRADIENT UPDATE=========")

# # Manually generate tokens step-by-step
# max_new_tokens=50
# num_samples = 10
# generated_tokens = input_ids.clone()
# new_samples = generate_response_with_embedding(max_new_tokens, model, new_generation_input_embed, num_samples, input_ids)
# print("new MMD values after gradient update: ", rbf_mmd(torch.tensor(embed(new_samples)), torch.tensor(embed(ground_truth))))