In [2]:
import torch
from tqdm import tqdm

from utils import norm_logits, sample, max_fn, Decoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer

import numpy as np
import time
from matplotlib import pyplot as plt
from utils import create_models, color_print

# Short alias -> model location string, looked up when calling `create_models`.
# Entries beginning with "/share_nfs/" are absolute filesystem paths; the two
# "solidrust/..." entries are not paths (presumably Hugging Face Hub repo ids
# -- TODO confirm).
# NOTE(review): the absolute paths are site-specific; only "llama3-8b" and
# "llama3-13b" are looked up in the visible cells.
MODELZOO = {
    # llama-1
    # https://huggingface.co/PY007/TinyLlama-1.1B-step-50K-105b
    "llama1b": "/share_nfs/fangjiarui/root/code/hf_models/TinyLlama-1.1B-step-50K-105b",
    "llama7b": "/share_nfs/tianzhi/code/llama-7b",
    "llama30b": "/share_nfs/fangjiarui/root/code/hf_models/llama-30b-hf",
    "llama2-7b" : "/share_nfs/fangjiarui/root/code/hf_models/llama-2-7b-hf",
    "llama2-70b" : "/share_nfs/fangjiarui/root/code/hf_models/llama-2-70b-hf",
    "llama3-8b":"solidrust/Meta-Llama-3-8B-Instruct-AWQ",
    "llama3-13b":"solidrust/Llama-3-13B-Instruct-v0.1-AWQ",
    "bloom-560m": "/share_nfs/fangjiarui/root/code/hf_models/bloom-560m",
    "bloom7b": "/share_nfs/fangjiarui/root/code/hf_models/bloomz-7b1",
    "baichuan-7b": "/share_nfs/duanqiyuan/models/source_models/hf/baichuan-7B",
    "baichuan-13b": "/share_nfs/duanqiyuan/models/source_models/hf/Baichuan-13B-Base",
}

In [5]:
import torch
from transformers import AutoModelForCausalLM
# Device string handed to `device_map` when the reward model is loaded below.
# NOTE(review): hard-coded GPU index; requires a machine with at least 4 GPUs.
device = "cuda:3"
model_name = "Skywork/Skywork-Reward-Llama-3.1-8B"
# model_name = "Skywork/Skywork-Reward-Gemma-2-27B"
# Reward model, loaded as a sequence classifier with a single label in bfloat16.
# `return_reward` reads `logits[0][0]` of its output as the scalar score.
rm = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
    # attn_implementation="flash_attention_2",
    num_labels=1,
)
# Tokenizer from the same checkpoint; `return_reward` uses its chat template.
rm_tokenizer = AutoTokenizer.from_pretrained(model_name)
def return_reward(question,answer):
    """Score `answer` with the reward model `rm` and return the score as a float.

    Args:
        question: User prompt the answer responds to, or None to score the
            answer as a lone assistant turn with no user message.
        answer: Assistant text to be scored.

    Returns:
        The value at `logits[0][0]` of the reward model output, as a Python float.
    """
    # Identity test instead of `== None`: it always yields a plain bool, even
    # for objects that overload `==`.
    if question is None:
        conversation = [{"role": "assistant", "content": answer}]
    else:
        conversation = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
    formatted = rm_tokenizer.apply_chat_template(conversation, tokenize=False)
    # Move inputs to the device the reward model actually lives on rather than
    # the mutable global `device`, which another cell rebinds.
    tokenized = rm_tokenizer(formatted, return_tensors="pt").to(rm.device)
    with torch.no_grad():
        score = rm(**tokenized).logits[0][0].item()
    return score

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [9]:
# Build the draft model, target model and tokenizer via `create_models`,
# passing the reward model's device so they are requested on the same GPU.
draft_model, target_model, tokenizer = create_models(MODELZOO["llama3-8b"], MODELZOO["llama3-13b"],device=rm.device)

=====doing tokenizer
begin loading models: 
 solidrust/Meta-Llama-3-8B-Instruct-AWQ 
 solidrust/Llama-3-13B-Instruct-v0.1-AWQ


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

finish loading models


In [4]:
# Number of children at which `Node.is_fully_expanded()` starts returning True;
# also the number of children `TreeSearch.expand` fills a node up to.
max_children = 2
# Multiplier on the exploration (square-root) term in `Node.best_child`.
exploration_weight = 1.41
import math
import random

class Node:
    """One node of the search tree: a candidate answer plus visit statistics.

    Each node stores the question, its own answer, the concatenated answers of
    its ancestors (`prev_answers`), a link to its parent, its children, and the
    `reward` / `value` / `visits` counters updated by the search. `name` is a
    running integer id taken from the module-level `node_count` counter.
    """

    def __init__(self,question,answer,prev_answers='',parent=None):
        """Create a node with zeroed statistics and the next free integer id."""
        global node_count
        # No visible cell initialises `node_count`; start it at 0 on first use
        # so building a Node on a fresh kernel does not raise NameError. An
        # already-defined global counter is kept and continued.
        node_count = globals().get("node_count", 0)
        self.question = question
        self.answer = answer
        self.prev_answers = prev_answers
        self.parent = parent
        self.children = []
        self.reward = 0.0
        self.value = 0.0
        self.visits = 0
        self.name = node_count
        node_count += 1

    def is_fully_expanded(self):
        """Return True once the node holds at least `max_children` children."""
        return len(self.children) >= max_children

    def best_child(self):
        """Return the child node (not its index) with the highest selection weight.

        Unvisited children get an infinite weight so they are tried first;
        otherwise the weight is the child's mean value plus an exploration
        bonus scaled by `exploration_weight`. Ties go to the earliest child.
        """
        # Clamp the parent count to 1 so log() cannot raise for a parent that
        # has not been visited yet; log(1) == 0 simply disables the bonus.
        log_parent_visits = math.log(max(self.visits, 1))
        choice_weights = []
        for child in self.children:
            if child.visits == 0:
                weight = float('inf')
            else:
                mean_value = child.value / child.visits
                bonus = exploration_weight * math.sqrt(2 * log_parent_visits / child.visits)
                weight = mean_value + bonus
            choice_weights.append(weight)
        return self.children[np.argmax(choice_weights)]

    def most_visited_child(self):
        """Return the child with the largest visit count."""
        return max(self.children, key=lambda child:child.visits)

    def add_child(self,child_node):
        """Append `child_node` to this node's children and set its parent link."""
        self.children.append(child_node)
        child_node.parent = self

In [8]:
class TreeSearch:
    """Iterative select / expand / evaluate / backpropagate search over `Node`s.

    Args:
        prefix: Text passed to `return_reward` as the question when a node's
            answer is evaluated.
        max_iterations: Number of search iterations run by `search()`.
        question: Question passed to `get_answer` and stored on new nodes.
            Defaults to `prefix`.
        root: Existing root node to search from. When omitted, a fresh root
            with an empty answer is created.
    """

    def __init__(self,prefix,max_iterations=3,question=None,root=None):
        self.prefix = prefix
        self.max_iterations = max_iterations
        # The attributes below are read by search/expand/backpropagate but were
        # never initialised, so the first search() raised AttributeError.
        self.question = prefix if question is None else question
        self.root = Node(self.question, answer='') if root is None else root
        # -inf so that the first backpropagated node always becomes the best.
        self.best_value = float('-inf')
        self.best_node = self.root

    def search(self):
        """Run `max_iterations` iterations and return the best node seen.

        The best node is the one with the highest mean value (value / visits)
        encountered during backpropagation; the root if nothing was evaluated.
        """
        for _ in range(self.max_iterations):
            node = self.select(self.root)
            # A node is scored on its first visit and expanded on later ones.
            if node.visits > 0:
                node = self.expand(node)
            node.reward = self.evaluate(node)
            self.backpropagate(node)
        return self.best_node

    def select(self, node):
        """Descend through fully expanded nodes, following `best_child` each time.

        Stops at the first node that is a leaf or still has room for children.
        """
        while node.is_fully_expanded() and node.children:
            node = node.best_child()
        return node

    def expand(self, node):
        """Fill `node` up to `max_children` children and return one at random."""
        for _ in range(max_children - len(node.children)):
            # NOTE(review): generation is conditioned on node.prev_answers only,
            # while the child records node.prev_answers + node.answer -- confirm
            # against `get_answer` (defined outside the visible cells).
            child_answer = get_answer(self.question, node.prev_answers)
            child_node = Node(self.question, answer=child_answer, prev_answers=node.prev_answers+node.answer)
            node.add_child(child_node)
        return random.choice(node.children)

    def evaluate(self, node):
        """Return the reward-model score of `node.answer` with `prefix` as question."""
        return return_reward(self.prefix, node.answer)

    def backpropagate(self, node):
        """Walk from `node` to the root, updating visit counts and values.

        Each node on the path adds its own `reward` to its `value`, and the
        best-node record is updated whenever a node's mean value beats it.
        """
        # NOTE(review): ancestors accumulate their own reward, not the reward of
        # the node just evaluated -- confirm this is the intended update rule.
        while node is not None:
            node.visits += 1
            node.value += node.reward
            mean_value = node.value / node.visits
            if mean_value > self.best_value:
                self.best_value = mean_value
                self.best_node = node
            node = node.parent

In [10]:
# Prompt text that seeds generation.
input_text ='cat hopped'
# Token ids of the prompt as a tensor, moved to the draft model's device.
prefix = tokenizer.encode(input_text, return_tensors='pt').to(draft_model.device)
approx_model = draft_model
# NOTE(review): self-assignment, has no effect.
target_model = target_model 
# Added to the prompt length to form the target length T in the generation cell.
max_len = 5 
# Number of tree-sampling calls per outer iteration of the generation loop.
gamma = 4
# temperature / top_k / top_p are forwarded to `norm_logits` by the tree functions.
temperature = 1
top_k = 0
top_p : float = 0 
# NOTE(review): not read anywhere in the visible cells.
random_seed = None
# Scales the reward inside exp(...) when candidate drafts are re-weighted.
reward_coeff = 0.01

@torch.no_grad()
def leveled_tree(target_model, x, num_samples=3, levels=2):
    """Grow a tree of continuations of `x` and return one reward-weighted leaf.

    Starting from `x`, every draft on the current level is extended with
    `num_samples` sampled next tokens, repeated `levels` times. Each leaf is
    then scored as `prob * exp(reward_coeff * reward)`, the scores are
    normalised, and one leaf is sampled from them.

    Args:
        target_model: Model called as `target_model(ids).logits`.
        x: Token-id tensor to extend; indexed as (batch, seq) with batch 0 decoded.
        num_samples: Branching factor (tokens sampled per draft per level).
        levels: Number of tokens appended to `x` in the returned draft.

    Returns:
        The chosen leaf: `x` with `levels` extra tokens appended along dim 1.
    """
    # Only the deepest level is ever scored, so keep just the current frontier.
    frontier = [x]
    for _ in range(levels):
        next_frontier = []
        for draft in frontier:
            q_logits = target_model(draft).logits
            normalized_logits = norm_logits(q_logits[:, -1, :], temperature, top_k, top_p)
            next_tok = sample(normalized_logits, num_samples=num_samples)
            next_frontier.extend(
                torch.cat((draft, next_tok[:, i:i + 1]), dim=1) for i in range(num_samples)
            )
        frontier = next_frontier

    rewards = torch.zeros(len(frontier), device=x.device)
    probs = torch.zeros(len(frontier), device=x.device)
    for i, leaf in enumerate(frontier):
        # The reward model's chat template expects text. Passing the raw id
        # tensor made it score the tensor's string form, so decode first
        # (same decoding as `small_tree`).
        leaf_text = tokenizer.decode(leaf[0], skip_special_tokens=True)
        rewards[i] = return_reward(None, leaf_text)
        draft_logits = target_model(leaf).logits
        normalized_logits = norm_logits(draft_logits[:, -1, :], temperature, top_k, top_p)
        # NOTE(review): the weight is the probability of a freshly sampled
        # *next* token, not of the tokens already in the leaf -- confirm intent.
        next_tok = sample(normalized_logits)
        probs[i] = normalized_logits[0, next_tok[0].long()]
    scores = probs * torch.exp(reward_coeff * rewards)
    scores = scores / torch.sum(scores)
    # Convert explicitly to a Python int before indexing the list of leaves.
    aligned_draft_index = int(sample(scores))
    return frontier[aligned_draft_index]

    # x_draft = [torch.cat((x, next_tok[:,i:i+1]), dim=1) for i in range(num_samples)]                
    # # print("x_draft",x_draft.shape)
    # x_draft_text = [tokenizer.decode(x_draft[i][0], skip_special_tokens=True) for i in range(num_samples)]
    # rewards = torch.zeros(num_samples,device=x.device)
    # for i in range(num_samples):
    #     rewards[i] = return_reward(tokenizer.decode(x_draft[i][0], skip_special_tokens=True), x_draft_text[i])
    
    # scores = normalized_logits[0,next_tok[0].long()] * torch.exp(reward_coeff*rewards)
    # scores = scores/torch.sum(scores) 
    # aligned_token_index = sample(scores)
    # aligned_token = next_tok[:,aligned_token_index]
    # return aligned_token

def small_tree(target_model, x, num_samples=5,levels=None):
    """Sample `num_samples` next tokens for `x` and pick one by reward-weighted score.

    Every candidate token is appended to `x`, the result is decoded to text and
    scored by `return_reward`; the candidate's probability is multiplied by
    `exp(reward_coeff * reward)`, the products are normalised, and one
    candidate is sampled from them. `levels` is accepted but not used.

    Returns:
        The selected token, taken as `next_tok[:, index]` from the sampled candidates.
    """
    last_step_logits = target_model(x).logits[:, -1, :]
    token_probs = norm_logits(last_step_logits, temperature, top_k, top_p)
    candidates = sample(token_probs, num_samples=num_samples)

    rewards = torch.zeros(num_samples, device=x.device)
    for idx in range(num_samples):
        extended = torch.cat((x, candidates[:, idx:idx + 1]), dim=1)
        text = tokenizer.decode(extended[0], skip_special_tokens=True)
        # NOTE(review): the same decoded text is passed as both question and
        # answer, exactly as before -- confirm this is intended.
        rewards[idx] = return_reward(text, text)

    candidate_probs = token_probs[0, candidates[0].long()]
    weighted = candidate_probs * torch.exp(reward_coeff * rewards)
    weighted = weighted / torch.sum(weighted)
    chosen = sample(weighted)
    return candidates[:, chosen]

In [20]:
device = approx_model.device

model_name = "Skywork/Skywork-Reward-Llama-3.1-8B"

correlatations = []
seq_len = prefix.shape[1]
# Target sequence length: prompt tokens plus `max_len` new tokens.
T = seq_len + max_len

assert prefix.shape[0] == 1, "input batch size must be 1"

# Tree shape, set once instead of on every loop iteration. `levels` is the
# value that was actually passed to leveled_tree (3); the former in-loop
# `levels=2` assignment was never used.
branching_factor = 3
levels = 3

start_time = time.time()
# The bar counts absolute sequence length, so it starts at the prompt length.
with tqdm(total=T, initial=seq_len, desc="speculative sampling") as pbar:
    while prefix.shape[1] < T:
        x = prefix
        prefix_len = prefix.shape[1]
        # Each call appends `levels` tokens, so one outer iteration adds
        # gamma * levels tokens and the final length may exceed T.
        for _ in range(gamma):
            x = leveled_tree(target_model, x, branching_factor, levels=levels)
        prefix = x
        # Advance the bar by the tokens added this iteration (it was never
        # updated before and stayed at 0).
        pbar.update(prefix.shape[1] - prefix_len)
        print(prefix.shape)
execution_time = time.time() - start_time

speculative sampling:   0%|          | 0/25 [00:13<?, ?it/s]

torch.Size([1, 32])





In [21]:
decoded = tokenizer.decode(prefix[0], skip_special_tokens=True)
decoded

'cat hopped onto lap, head buried in\nhuman whisper-sang a lullabi\nof catnip-scents\nAs the world outside\nf'

In [22]:
return_reward(None, decoded)

-17.875