In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5"

import torch_influence
import torchvision
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.set_warn_always(False)

from typing import List
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
import yaml
import lm_eval

import datasets
import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training ,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

%load_ext autoreload
%autoreload 2

  warn(
  from .autonotebook import tqdm as notebook_tqdm
2025-02-11 14:21:37.537246: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-11 14:21:37.537311: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-11 14:21:37.538785: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-11 14:21:37.546802: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler fl

# Load LLM model

In [2]:
from llm import get_tokenizer_and_model

tokenizer, model = get_tokenizer_and_model(model_id = "meta-llama/Meta-Llama-3-8B-Instruct")
model = model.to("cuda:5")



Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  5.25it/s]


# Load dataset

In [3]:
dataset = datasets.load_dataset("allenai/sciq")

In [4]:
ground_truth = dataset["test"][100:120]['question'] # take 10 samples for ground truth for now
train_gen = dataset["train"][100:120]['question'] # take 10 samples for ground truth for now

# Embedding model

In [5]:
from sentence_transformers import SentenceTransformer
embedding_model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True, device="cpu")
embedding_model = embedding_model.to("cuda:4")

2025-02-11:14:21:59,568 INFO     [SentenceTransformer.py:218] Load pretrained SentenceTransformer: Alibaba-NLP/gte-Qwen2-1.5B-instruct
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.34it/s]
2025-02-11:14:22:06,115 INFO     [SentenceTransformer.py:357] 1 prompts are loaded, with the keys: ['query']


# Helper functions

In [37]:
from typing import List, AnyStr
from sklearn.preprocessing import normalize
from transformers import logging

logging.set_verbosity_warning()
temp = 0.99

# embed a list of texts
def embed(data : List[AnyStr]) -> torch.Tensor:
    max_length = 32768
    passage_embeddings = embedding_model.encode(data)
    # normalize embeddings
    #query_embeddings = normalize(passage_embeddings)
    return passage_embeddings

# mmd function
def rbf_mmd(X, Y, sigma=1.0, chunk_size=None):
    gamma = 1 / X.shape[1]
    def row_mean(v, X):
        dist_sqrs = torch.sum((X - v)**2, dim=1)
        return torch.exp(-gamma * dist_sqrs).mean()
    kernel_X = lambda v: row_mean(v, X)
    kernel_Y = lambda v: row_mean(v, Y)
    K_XX = torch.mean(torch.vmap(kernel_X, chunk_size=chunk_size)(X))
    K_XY = torch.mean(torch.vmap(kernel_X, chunk_size=chunk_size)(Y))
    K_YY = torch.mean(torch.vmap(kernel_Y, chunk_size=chunk_size)(Y))
    return K_XX + K_YY - 2 * K_XY

# generate a single sample from LLM, based on 3 examples. 
# can set temperature higher to get more diverse responses.
def generate_response(model, tokenizer):
    messages = [
        {"role": "system", "content": "You are my assistant. Please look at the examples of questions given and write a similar question with the same topic or flavour. Do not give the solution or any extra words."},
        {"role": "user", "content": "\n".join(dataset["train"][0:5]['question'])},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    outputs = model.generate(
        input_ids,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temp,
        top_p=1.0,
    )
    response = outputs[0][input_ids.shape[-1]:]
    output = tokenizer.decode(response, skip_special_tokens=True)
        
    return model.get_input_embeddings()(input_ids), input_ids, output

# given a extract_string text response, find the logits from an LLM and backpropagate the gradients to the embedding values
# input_ids are the input prompts used (so we start generating log-probs from that point)
def backpropagate_gradients_to_embedding_based_on_logits(model, tokenizer, model_output, input_ids, prompt_embedding):
    for param in model.parameters():
        param.requires_grad = False # freeze all params

    # output_ids = tokenizer(model_output, return_tensors="pt").input_ids.to("cuda:5")
    # # Concatenate input and output to form full sequence. This makes it easy to find the logit later.
    # full_input = torch.cat([input_ids, output_ids], dim=-1)

    # # **Extract embeddings with requires_grad=True**
    # embedding_layer = model.get_input_embeddings()  # Embedding layer
    # input_embeds = embedding_layer(full_input).detach().clone()  # Shape: (1, input_length + response length, hidden_size)

    # # allow gradients on input embedding
    # input_embeds.requires_grad = True  # Enable gradient tracking
    # optimizer = torch.optim.Adam([input_embeds], lr=1.0)

    # # Forward pass
    # outputs = model(inputs_embeds=input_embeds, disable_tqdm=True)

    # # Extract logits (this is the full sentence logit)
    # logits = outputs.logits  # Shape: (batch_size, seq_length, vocab_size)

    # # Compute log-softmax to get log probabilities
    # log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # # Shift output_ids for teacher forcing (we predict the next token)
    # target_ids = full_input[:, 1:]  # Shift left for alignment

    # # Gather log probabilities corresponding to the actual output tokens
    # output_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

    # # Compute total negative log-likelihood starting only from end of input (so only log-probs of response)
    # total_log_likelihood = - output_log_probs[:, input_ids.shape[-1]:].sum()
    
    output_ids = tokenizer(model_output, return_tensors="pt").input_ids.to("cuda:5")
    # Concatenate input and output to form full sequence. This makes it easy to find the logit later.

    # **Extract embeddings with requires_grad=True**
    embedding_layer = model.get_input_embeddings()  # Embedding layer
    output_embed = embedding_layer(output_ids).detach().clone()
    input_embeds = prompt_embedding.detach().clone()
    input_embeds = torch.concat([input_embeds, output_embed], dim=1)
    
    # allow gradients on input embedding
    input_embeds.requires_grad = True  # Enable gradient tracking
    optimizer = torch.optim.Adam([input_embeds], lr=1.0)

    # Forward pass
    outputs = model(inputs_embeds=input_embeds, disable_tqdm=True)

    # Extract logits (this is the full sentence logit)
    logits = outputs.logits  # Shape: (batch_size, seq_length, vocab_size)

    # Compute log-softmax to get log probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # Shift output_ids for teacher forcing (we predict the next token)
    full_input=torch.concat([input_ids, output_ids], dim=1)
    target_ids = full_input[:, 1:]  # Shift left for alignment

    # Gather log probabilities corresponding to the actual output tokens
    output_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

    # Compute total negative log-likelihood starting only from end of input (so only log-probs of response)
    total_log_likelihood = output_log_probs[:, input_ids.shape[-1]:].sum()

    # Backpropagate
    total_log_likelihood.backward()
    optimizer.step()
    
    return input_embeds # return the input embedding after it has been updated with gradients

# generate an LLM response with an input embedding
def generate_response_with_embedding(max_new_tokens, model, new_embed, num_samples, input_ids):
    samples = []
    for _ in range(num_samples):
        generated_tokens = input_ids.clone()
        new_generation_input_embed = new_embed.clone()
        for _ in range(max_new_tokens):
            outputs = model(inputs_embeds=new_generation_input_embed, disable_tqdm=True)
            logits = outputs.logits[:, -1, :]  # Get logits for the last token
            
            #next_token = torch.argmax(logits, dim=-1, keepdim=True)  # Greedy decoding
            
            # temperature
            logits = logits / temp
            # Convert to probabilities and sample
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            # Append new token to generated sequence
            generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

            # Update `input_embeds` to include new token embeddings
            next_token_embedding = model.get_input_embeddings()(next_token)
            new_generation_input_embed = torch.cat([new_generation_input_embed, next_token_embedding], dim=1)

            # Stop if EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break
        samples.append(tokenizer.decode(generated_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True))
    return samples

# Run REINFORCE to update embedding layers

In [39]:
# starting embedding to start gradient descent
prompt_input_embeds, input_ids, output = generate_response(model, tokenizer)

n=10 # n, same as formula
k=20 # k, same as formula
lr=0.01 # learning rate
training_steps = 10 # number of GD steps.
for step in range(training_steps):
    print("training step: ", step)
    all_estimated_gradients = []
    all_similarity = []
    for idx in range(n):
        print("getting gradient samples in iteration: ", idx)
        sampled_examples = [] # 
        backpropagated_input_embeddings = []
        for _ in range(k):
            input_len = len(input_ids[0])
            
            # generate a random LLM response from current embedding
            output = generate_response_with_embedding(256, model, prompt_input_embeds, 1, input_ids)[0]
            
            # backpropagate to update the embedding
            updated_input_embeds = backpropagate_gradients_to_embedding_based_on_logits(model, tokenizer, output, input_ids, prompt_input_embeds)
            sampled_examples.append(output)
            backpropagated_input_embeddings.append(updated_input_embeds)

        logit_grad = torch.zeros_like(backpropagated_input_embeddings[0][:,:input_len,:]) # log p(x_1) + log p(x_2) + ... + p(x_k)
        for embedding in backpropagated_input_embeddings:
            
            # calculate the gradient based on updated embedding i.e., derivative of log P(X)
            grad_log_prob = embedding[:,:input_len,:] - prompt_input_embeds[:,:input_len,:]
            
            # log P(X_1) + ... + P(X_k)
            logit_grad += grad_log_prob
        
        with torch.no_grad():
            # compute MMD of current k samples
            mmd = rbf_mmd(torch.tensor(embed(sampled_examples)), torch.tensor(embed(ground_truth))) # MMD value

        all_similarity.append(mmd)
        estimated_gradient_sample = mmd * logit_grad # REINFORCE equation to get one sample of logit gradient
        all_estimated_gradients.append(estimated_gradient_sample)
        
    with torch.no_grad():
        print("embedding norm before gradient update: ", prompt_input_embeds[:,:input_len,:].sum())
        print("average MMD values before updated: ", np.array(all_similarity).mean())
        expected_gradient = torch.stack(all_estimated_gradients).sum(dim=0) * (1/n) # expected gradient
        print("gradient norm: ", expected_gradient.sum())
        prompt_input_embeds = prompt_input_embeds - lr * expected_gradient # update embedding with gradient
        print("embedding norm after gradient update: ", prompt_input_embeds[:,:input_len,:].sum())

        # check MMD for newly generated samples
        max_new_tokens=256
        num_samples = 10
        generated_tokens = input_ids.clone()
        new_samples = generate_response_with_embedding(max_new_tokens, model, prompt_input_embeds, num_samples, input_ids)
        print("new MMD values after gradient update: ", rbf_mmd(torch.tensor(embed(new_samples)), torch.tensor(embed(ground_truth))))

training step:  0
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.89it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.53it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.30it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.35it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.62it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.03it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.91it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.07it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.93it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.13it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.09it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.14it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.95it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.23it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.38it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.99it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.61it/s]


embedding norm before gradient update:  tensor(30.8750, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00020962954
gradient norm:  tensor(-8.8750, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(31., device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 34.86it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 13.76it/s]


new MMD values after gradient update:  tensor(0.0002)
training step:  1
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.59it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.14it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.17it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.88it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.70it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.86it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.58it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.84it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.92it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.21it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.83it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.18it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.98it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.11it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.37it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.24it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.43it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.70it/s]


embedding norm before gradient update:  tensor(31., device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00019181966
gradient norm:  tensor(19.2500, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(30.7500, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 27.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 13.21it/s]


new MMD values after gradient update:  tensor(0.0002)
training step:  2
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.76it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.00it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.15it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.28it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.00it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.10it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 11.68it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.24it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:01<00:00,  1.14s/it]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.39it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.63it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.37it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.12it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.91it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.96it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.29it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.33it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.38it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.47it/s]


embedding norm before gradient update:  tensor(30.7500, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00020333528
gradient norm:  tensor(-3.5625, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(30.7500, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 28.45it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 14.69it/s]


new MMD values after gradient update:  tensor(0.0002)
training step:  3
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 10.12it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.85it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.25it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.47it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.47it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.15it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.32it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.32it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.91it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.19it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.74it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.35it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.30it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.26it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.52it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.19it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.11it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.79it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.39it/s]


embedding norm before gradient update:  tensor(30.7500, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00021369457
gradient norm:  tensor(-5.5938, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(30.8750, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 27.34it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.59it/s]


new MMD values after gradient update:  tensor(0.0002)
training step:  4
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 16.28it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.39it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 16.54it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.12it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.36it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.01it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 16.18it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.29it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.76it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.86it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 20.09it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.96it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.88it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.54it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.71it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.34it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.92it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.86it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.38it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.89it/s]


embedding norm before gradient update:  tensor(30.8750, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.0002029419
gradient norm:  tensor(-9.6875, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(30.8750, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 28.65it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.82it/s]


new MMD values after gradient update:  tensor(0.0003)
training step:  5
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.02it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.84it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.79it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.27it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.72it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.18it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.60it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.17it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.66it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.09it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.26it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.22it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.81it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.09it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.91it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.70it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.91it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.81it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.92it/s]


embedding norm before gradient update:  tensor(30.8750, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.0002039671
gradient norm:  tensor(-17.8750, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(31.1250, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 34.18it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.74it/s]


new MMD values after gradient update:  tensor(0.0003)
training step:  6
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.40it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.51it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00,  4.97it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.92it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.26it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.94it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.35it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.76it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.25it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.67it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.47it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 13.58it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.57it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.05it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.48it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.40it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00,  3.71it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.04it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.44it/s]


embedding norm before gradient update:  tensor(31.1250, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00021175147
gradient norm:  tensor(-10.5625, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(31.2500, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 27.30it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.23it/s]


new MMD values after gradient update:  tensor(0.0002)
training step:  7
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.24it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.49it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.27it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.13it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.95it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.34it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.87it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.56it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.94it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.12it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.70it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.60it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.38it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.19it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.51it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 18.69it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.66it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.89it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.50it/s]


embedding norm before gradient update:  tensor(31.2500, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00023255349
gradient norm:  tensor(22.3750, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(31., device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 28.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 13.39it/s]


new MMD values after gradient update:  tensor(0.0003)
training step:  8
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 19.13it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.26it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 16.96it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.54it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 10.28it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.91it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 20.06it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.66it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 18.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.28it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.24it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.82it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.78it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.18it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 13.07it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.17it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.63it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.03it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.74it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 12.15it/s]


embedding norm before gradient update:  tensor(31., device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00023503303
gradient norm:  tensor(-24.3750, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(31.2500, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 27.88it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 13.82it/s]


new MMD values after gradient update:  tensor(0.0002)
training step:  9
getting gradient samples in iteration:  0


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.93it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.14it/s]


getting gradient samples in iteration:  1


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.54it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.63it/s]


getting gradient samples in iteration:  2


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.49it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.25it/s]


getting gradient samples in iteration:  3


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.59it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.42it/s]


getting gradient samples in iteration:  4


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.93it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.47it/s]


getting gradient samples in iteration:  5


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.87it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.61it/s]


getting gradient samples in iteration:  6


Batches: 100%|██████████| 1/1 [00:00<00:00, 17.53it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.71it/s]


getting gradient samples in iteration:  7


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.50it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 10.93it/s]


getting gradient samples in iteration:  8


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.02it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.23it/s]


getting gradient samples in iteration:  9


Batches: 100%|██████████| 1/1 [00:00<00:00, 19.16it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.65it/s]


embedding norm before gradient update:  tensor(31.2500, device='cuda:5', dtype=torch.bfloat16)
average MMD values before updated:  0.00021905899
gradient norm:  tensor(16.3750, device='cuda:5', dtype=torch.bfloat16)
embedding norm after gradient update:  tensor(31.1250, device='cuda:5', dtype=torch.bfloat16)


Batches: 100%|██████████| 1/1 [00:00<00:00, 24.05it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 11.29it/s]

new MMD values after gradient update:  tensor(0.0003)





# Check ground truth

In [27]:
ground_truth

['Where is the spinal trigeminal nucleus located?',
 'The lithosphere is divided into a dozen major and several minor what?',
 'During the first year after birth, what is a baby called?',
 'What are used to indicate the number of atoms of an element that are in the compound?',
 'Area, volume, and speed are all examples of what type of units?',
 'Anything moving has what type of energy?',
 'A skydiver will reach what when the air drag equals their weight?',
 'What organs are considered the female gonads?',
 'What is the adaptation that certain animals use to become less visible to predators and prey?',
 'What is another term for blood clotting?',
 'What do you call the study of how organisms interact with their environment and each other?',
 'Childbirth usually starts when which sac breaks?',
 'What phenomenon, which is most important in small populations, occurs because the alleles in an offspring generation are a random sample of the alleles in the parent generation?',
 'What is the t

# Generate new samples

In [28]:
# prompt_input_embeds are updated embeddings
generate_response_with_embedding(max_new_tokens, model, prompt_input_embeds, num_samples, input_ids)

['What is the primary difference in composition between granite and basalt rocks?',
 "What is the primary medium used for heat transfer in the earth's core?",
 'What is the primary component of the stars in our galaxy, including our sun, and is also found in our body in the form of calcium in bones and teeth?',
 'What type of sounds are produced by the wind blowing over sand ridges and dunes?',
 'What type of insect is known for its distinctive "waggle dance" used for communication?',
 'Gases',
 'What is the process by which the moon appears to move along a path in the sky that changes shape as it orbits around the earth?',
 "What is the primary component of a solar flare's energy spectrum?",
 "What type of rock is typically formed from the cooling and solidification of magma deep within the Earth's crust?",
 'What is a prominent characteristic of fiery meteor showers?']

In [None]:
# from typing import List, AnyStr
# from sklearn.preprocessing import normalize
# from transformers import logging

# logging.set_verbosity_warning()
# temp = 0.99

# # embed a list of texts
# def embed(data : List[AnyStr]) -> torch.Tensor:
#     max_length = 32768
#     passage_embeddings = embedding_model.encode(data)
#     # normalize embeddings
#     query_embeddings = normalize(passage_embeddings)
#     return query_embeddings

# # mmd function
# def rbf_mmd(X, Y, sigma=1.0, chunk_size=None):
#     gamma = 1 / X.shape[1]
#     def row_mean(v, X):
#         dist_sqrs = torch.sum((X - v)**2, dim=1)
#         return torch.exp(-gamma * dist_sqrs).mean()
#     kernel_X = lambda v: row_mean(v, X)
#     kernel_Y = lambda v: row_mean(v, Y)
#     K_XX = torch.mean(torch.vmap(kernel_X, chunk_size=chunk_size)(X))
#     K_XY = torch.mean(torch.vmap(kernel_X, chunk_size=chunk_size)(Y))
#     K_YY = torch.mean(torch.vmap(kernel_Y, chunk_size=chunk_size)(Y))
#     return K_XX + K_YY - 2 * K_XY

# # generate a single sample from LLM, based on 3 examples. 
# # can set temperature higher to get more diverse responses.
# def generate_response(model, tokenizer):
#     messages = [
#         {"role": "system", "content": "You are my assistant. Please look at the examples of questions given and write a similar question with the same topic or flavour. Do not give the solution or any extra words."},
#         {"role": "user", "content": "\n".join(dataset["train"][0:5]['question'])},
#     ]

#     input_ids = tokenizer.apply_chat_template(
#         messages,
#         add_generation_prompt=True,
#         return_tensors="pt"
#     ).to(model.device)

#     terminators = [
#         tokenizer.eos_token_id,
#         tokenizer.convert_tokens_to_ids("<|eot_id|>")
#     ]
#     outputs = model.generate(
#         input_ids,
#         max_new_tokens=256,
#         pad_token_id=tokenizer.eos_token_id,
#         eos_token_id=terminators,
#         do_sample=True,
#         temperature=temp,
#         top_p=1.0,
#     )
#     response = outputs[0][input_ids.shape[-1]:]
#     output = tokenizer.decode(response, skip_special_tokens=True)
        
#     return model.get_input_embeddings()(input_ids), input_ids, output

# # given a extract_string text response, find the logits from an LLM and backpropagate the gradients to the embedding values
# # input_ids are the input prompts used (so we start generating log-probs from that point)
# def backpropagate_gradients_to_embedding_based_on_logits(model, tokenizer, model_output, input_ids, prompt_embedding):
#     for param in model.parameters():
#         param.requires_grad = False # freeze all params

#     # output_ids = tokenizer(model_output, return_tensors="pt").input_ids.to("cuda:5")
#     # # Concatenate input and output to form full sequence. This makes it easy to find the logit later.
#     # full_input = torch.cat([input_ids, output_ids], dim=-1)

#     # # **Extract embeddings with requires_grad=True**
#     # embedding_layer = model.get_input_embeddings()  # Embedding layer
#     # input_embeds = embedding_layer(full_input).detach().clone()  # Shape: (1, input_length + response length, hidden_size)

#     # # allow gradients on input embedding
#     # input_embeds.requires_grad = True  # Enable gradient tracking
#     # optimizer = torch.optim.Adam([input_embeds], lr=1.0)

#     # # Forward pass
#     # outputs = model(inputs_embeds=input_embeds, disable_tqdm=True)

#     # # Extract logits (this is the full sentence logit)
#     # logits = outputs.logits  # Shape: (batch_size, seq_length, vocab_size)

#     # # Compute log-softmax to get log probabilities
#     # log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

#     # # Shift output_ids for teacher forcing (we predict the next token)
#     # target_ids = full_input[:, 1:]  # Shift left for alignment

#     # # Gather log probabilities corresponding to the actual output tokens
#     # output_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

#     # # Compute total negative log-likelihood starting only from end of input (so only log-probs of response)
#     # total_log_likelihood = - output_log_probs[:, input_ids.shape[-1]:].sum()
    
#     output_ids = tokenizer(model_output, return_tensors="pt").input_ids.to("cuda:5")
#     # Concatenate input and output to form full sequence. This makes it easy to find the logit later.

#     # **Extract embeddings with requires_grad=True**
#     embedding_layer = model.get_input_embeddings()  # Embedding layer
#     output_embed = embedding_layer(output_ids).detach().clone()
#     input_embeds = prompt_embedding.detach().clone()
#     input_embeds = torch.concat([input_embeds, output_embed], dim=1)
    
#     # allow gradients on input embedding
#     input_embeds.requires_grad = True  # Enable gradient tracking
#     optimizer = torch.optim.Adam([input_embeds], lr=1.0)

#     # Forward pass
#     outputs = model(inputs_embeds=input_embeds, disable_tqdm=True)

#     # Extract logits (this is the full sentence logit)
#     logits = outputs.logits  # Shape: (batch_size, seq_length, vocab_size)

#     # Compute log-softmax to get log probabilities
#     log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

#     # Shift output_ids for teacher forcing (we predict the next token)
#     full_input=torch.concat([input_ids, output_ids], dim=1)
#     target_ids = full_input[:, 1:]  # Shift left for alignment

#     # Gather log probabilities corresponding to the actual output tokens
#     output_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)

#     # Compute total negative log-likelihood starting only from end of input (so only log-probs of response)
#     total_log_likelihood = - output_log_probs[:, input_ids.shape[-1]:].sum()

#     # Backpropagate
#     total_log_likelihood.backward()
#     optimizer.step()
    
#     return input_embeds # return the input embedding after it has been updated with gradients

# # generate an LLM response with an input embedding
# def generate_response_with_embedding(max_new_tokens, model, new_embed, num_samples, input_ids):
#     samples = []
#     for _ in range(num_samples):
#         generated_tokens = input_ids.clone()
#         new_generation_input_embed = new_embed.clone()
#         for _ in range(max_new_tokens):
#             outputs = model(inputs_embeds=new_generation_input_embed, disable_tqdm=True)
#             logits = outputs.logits[:, -1, :]  # Get logits for the last token
            
#             #next_token = torch.argmax(logits, dim=-1, keepdim=True)  # Greedy decoding
            
#             # temperature
#             logits = logits / temp
#             # Convert to probabilities and sample
#             probs = torch.nn.functional.softmax(logits, dim=-1)
#             next_token = torch.multinomial(probs, num_samples=1)
            
#             # Append new token to generated sequence
#             generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)

#             # Update `input_embeds` to include new token embeddings
#             next_token_embedding = model.get_input_embeddings()(next_token)
#             new_generation_input_embed = torch.cat([new_generation_input_embed, next_token_embedding], dim=1)

#             # Stop if EOS token is generated
#             if next_token.item() == tokenizer.eos_token_id:
#                 break
#         samples.append(tokenizer.decode(generated_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True))
#     return samples