<a href="https://colab.research.google.com/github/junruren/6.7960-2024-Fall/blob/main/6_7960_final_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The Llama model is access-controlled. Speak with Junru for access.

In [1]:
!pip install -q -U bitsandbytes accelerate transformers

In [11]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Check if CUDA is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# 8-bit quantization config
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0
)

# Load LLaMA-2-7B
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Load the LLM with 8-bit quantization
model_llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
).eval()  # Keep LLM in eval mode

tokenizer = AutoTokenizer.from_pretrained(model_name)

Using device: cuda


config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [15]:
#-----------------------------------------
# Setup a toy dataset of Q/A pairs
#-----------------------------------------
train_data = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "Who wrote 'Pride and Prejudice'?", "answer": "Jane Austen"},
    {"question": "What is the largest planet in our solar system?", "answer": "Jupiter"}
]

#-----------------------------------------
# Define the "prompt for problem-solving prompt" P
#-----------------------------------------
P = "Given the following question, produce a concise and refined prompt by incorporating the exact question and adding additional instructions that guide the model to reason step-by-step, carefully check its reasoning, and provide a correct and well-explained answer. The refined prompt should not reveal the solution but should encourage thorough verification and clarity. Question: "

In [16]:
#-----------------------------------------
# The Refiner model:
# We'll define a simple model that takes the LLM hidden states
# and outputs a refined prompt as logits over tokens.
#-----------------------------------------
class SimpleRefiner(nn.Module):
    def __init__(self, hidden_size=4096, vocab_size=32000, seq_len=10):
        super().__init__()
        self.seq_len = seq_len
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # We'll pool over sequence dimension
        self.pool = nn.AdaptiveAvgPool1d(1)

        # MLP from hidden_size to seq_len*vocab_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, seq_len * vocab_size)
        )

    def forward(self, llm_hidden_states):
        # llm_hidden_states: [batch, seq_len, hidden_size]
        # Transpose for pooling: [batch, hidden_size, seq_len]
        x = llm_hidden_states.transpose(1, 2)
        # pool -> [batch, hidden_size, 1]
        x = self.pool(x).squeeze(-1) # [batch, hidden_size]

        out = self.mlp(x) # [batch, seq_len*vocab_size]
        out = out.view(-1, self.seq_len, self.vocab_size)
        return out

# Infer model hidden_size and vocab_size from Llama model config
hidden_size = model_llm.config.hidden_size
vocab_size = model_llm.config.vocab_size
####refiner = SimpleRefiner(hidden_size=hidden_size, vocab_size=vocab_size, seq_len=10).to(device)
refiner = SimpleRefiner(hidden_size=hidden_size, vocab_size=vocab_size, seq_len=10).to(device).half()

optimizer = AdamW(refiner.parameters(), lr=1e-4)

def text_to_tokens(text):
    return tokenizer.encode(text, return_tensors='pt').to(device)

def tokens_to_text(tokens):
    return tokenizer.decode(tokens, skip_special_tokens=True)

def generate_hidden_states(model, input_ids):
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)
    hidden_states = outputs.hidden_states[-1] # last layer hidden states
    return hidden_states

In [18]:
# Training loop (conceptual)
for epoch in range(1):  # Just one epoch for demonstration
    for example in train_data:
        question = example["question"]
        answer = example["answer"]

        # Step 1: Prepare input for LLM: "P + Q"
        input_text = P + question
        input_ids = text_to_tokens(input_text)

        # Sanity check: print LLM's raw completion (initial output)
        with torch.no_grad():
            gen_ids = model_llm.generate(input_ids, max_new_tokens=50)
            gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
            print("LLM initial completion:\n", gen_text, "\n")

        # Step 2: Run LLM to get hidden states
        with torch.no_grad():
            outputs = model_llm(input_ids, output_hidden_states=True)
            llm_hidden_states = outputs.hidden_states[-1]  # [1, seq_len, hidden_size]

        # Step 3: Run refiner
        refiner.train()
        refiner_logits = refiner(llm_hidden_states)  # [1, seq_len=10, vocab_size]

        # Step 4: Convert refiner_logits to tokens (non-differentiable)
        refined_prompt_ids = torch.argmax(refiner_logits, dim=-1)  # [1, 10]
        refined_prompt_text = tokens_to_text(refined_prompt_ids[0])

        # Step 5: Concatenate refined prompt and question
        final_input_text = refined_prompt_text.strip() + ": " + question + " Please present your final answer in a new line prefixed with \"Answer: \"."
        final_input_ids = text_to_tokens(final_input_text)

        # Another sanity check: print LLM's output with the refined prompt
        with torch.no_grad():
            gen_refined_ids = model_llm.generate(final_input_ids, max_new_tokens=50)
            gen_refined_text = tokenizer.decode(gen_refined_ids[0], skip_special_tokens=True)
            print("LLM output after refined prompt:\n", gen_refined_text, "\n")

        # Step 6: Compute loss with the known answer
        combined_input = final_input_text + " " + answer
        combined_ids = text_to_tokens(combined_input)

        # Compute LM loss
        outputs = model_llm(combined_ids, labels=combined_ids)
        loss = outputs.loss

        # Step 7: Backprop into refiner
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Loss: {loss.item()} | Refined Prompt: {refined_prompt_text}")
        print("------------------------------------------------------------------------------------")

print("Training done.")

LLM initial completion:
 Given the following question, produce a concise and refined prompt by incorporating the exact question and adding additional instructions that guide the model to reason step-by-step, carefully check its reasoning, and provide a correct and well-explained answer. The refined prompt should not reveal the solution but should encourage thorough verification and clarity. Question: What is the capital of France? Can you explain how you know this? Additional Instructions:

1. Break down the question into smaller parts and explain each part step-by-step.
2. Provide clear and specific examples to support each step.
 

LLM output after refined prompt:
 südbindagsawnggreguler NGC though Tele Pse: What is the capital of France? Please present your final answer in a new line prefixed with "Answer: ".

Answer: Paris 

Loss: 6.3667755126953125 | Refined Prompt: südbindagsawnggreguler NGC though Tele Pse
-------------------------------------------------------------------------