In [None]:
import os
import time

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from huggingface_hub import login
from peft import PeftModel
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from fnn import FNN

In [None]:
# Paths for datasets
MED_DATASET_PATH = "data/medical.csv"
PYTHON_DATASET_PATH = "data/python.csv"

# Paths for locally stored LoRA adapters
MED_LORA_PATH = "adapters/med_lora_adapter"
PYTHON_LORA_PATH = "adapters/python_lora_adapter"

# The education LoRA adapter is loaded from the Hugging Face Hub
EDU_LORA_PATH = "kaitchup/Llama-3.2-3B-Instruct-educational-chatbot"

# Router artifacts (classifier weights and the fitted vectorizer)
ROUTER_SAFETENSORS = "router/router.safetensors"
COUNT_VECTORIZER = "router/count_vectorizer.joblib"

# Hugging Face access token. Read from the environment so the secret never has
# to be pasted into (and committed with) the notebook; falls back to "" when
# the variable is unset.
HUGGING_FACE_TOKEN = os.environ.get("HF_TOKEN", "")

# Draft and target models used for speculative decoding.
# NOTE(review): these names used to be assigned twice in this cell. The first
# assignment set TARGET_MODEL_NAME to "meta-llama/Llama-3.1-8B-Instruct" but was
# overwritten before any use, so the effective values below are kept unchanged.
# Confirm whether the 8B target was the intended configuration.
DRAFT_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
TARGET_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

# Datasets

In [None]:
# Load in all datasets
N_QUERIES_PER_DATASET = 10  # number of questions sampled from each dataset
SAMPLE_SEED = 42  # fixed seed so every run draws the same questions
EDU_DATASET_PATH = "hf://datasets/kaitchup/qa-chat-persona-education/data/test-00000-of-00001.parquet"


def _sample_queries(frame, column, n=N_QUERIES_PER_DATASET, seed=SAMPLE_SEED):
    """Return up to ``n`` stripped, non-missing questions sampled from ``frame[column]``.

    Args:
        frame: DataFrame holding the questions.
        column: Name of the text column to sample from.
        n: Maximum number of questions to return.
        seed: ``random_state`` passed to ``Series.sample``.

    Returns:
        A list of question strings; shorter than ``n`` only when the column
        has fewer than ``n`` non-missing rows.
    """
    # Missing values would survive ``.str.strip()`` as NaN and later be handed
    # to the tokenizer, so drop them before sampling.
    questions = frame[column].dropna().str.strip()
    # ``sample`` raises when n exceeds the population; cap it instead.
    return questions.sample(n=min(n, len(questions)), random_state=seed).tolist()


medical_raw = pd.read_csv(MED_DATASET_PATH)
med_queries = _sample_queries(medical_raw, "question")

python_raw = pd.read_csv(PYTHON_DATASET_PATH)
python_queries = _sample_queries(python_raw, "Question")

education_raw = pd.read_parquet(EDU_DATASET_PATH)
education_queries = _sample_queries(education_raw, "question")

# Load in models

In [None]:
# Only log in when a token is configured; with an empty token, rely on
# credentials already cached on this machine (e.g. `huggingface-cli login`).
if HUGGING_FACE_TOKEN:
    login(token=HUGGING_FACE_TOKEN)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(DRAFT_MODEL_NAME)

print("Loading draft model...")
draft_model = AutoModelForCausalLM.from_pretrained(
    DRAFT_MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)


class _LoraAdapterRegistry:
    """Dict-like lookup ``lora_id -> draft model`` backed by a single PeftModel.

    Every adapter is attached to the same base model under its own adapter
    name. Indexing activates the requested adapter and returns the shared
    model, so ``lora_adapters[lora_id]`` is used exactly like the value of a
    plain ``{lora_id: model}`` dict.
    """

    def __init__(self, peft_model, adapter_names):
        self._peft_model = peft_model
        self._adapter_names = dict(adapter_names)  # lora_id -> adapter name

    def __contains__(self, lora_id):
        return lora_id in self._adapter_names

    def __iter__(self):
        return iter(self._adapter_names)

    def __len__(self):
        return len(self._adapter_names)

    def __getitem__(self, lora_id):
        # Switch the shared model to the requested adapter before handing it out.
        self._peft_model.set_adapter(self._adapter_names[lora_id])
        return self._peft_model


print("Loading LoRA adapters...")
# Calling PeftModel.from_pretrained repeatedly on the same base model injects
# into the same layers under the adapter name "default", so each later load
# replaces the earlier adapter and every lora_id ends up using the last one.
# Registering each adapter under a distinct name keeps all three available.
_draft_with_adapters = PeftModel.from_pretrained(
    draft_model, PYTHON_LORA_PATH, adapter_name="python"
)
_draft_with_adapters.load_adapter(MED_LORA_PATH, adapter_name="medical")
_draft_with_adapters.load_adapter(EDU_LORA_PATH, adapter_name="education")
lora_adapters = _LoraAdapterRegistry(
    _draft_with_adapters, {0: "python", 1: "medical", 2: "education"}
)

print("Loading target model...")
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)


def speculative_decoding(query, lora_id):
    """Generate a completion for ``query`` with assisted (speculative) decoding.

    ``target_model`` produces the output while the draft model selected by
    ``lora_id`` is passed to ``generate`` as ``assistant_model``. The decoded
    text and the timing metrics are also printed.

    NOTE: a later cell defines another ``speculative_decoding`` that takes only
    ``query`` and returns four values; whichever cell ran last is the one in
    effect.

    Args:
        query: Prompt text to complete.
        lora_id: Key into ``lora_adapters`` selecting the draft model.

    Returns:
        Tuple ``(decoded_output, throughput, latency)``: the decoded output
        sequence, generated tokens per second, and elapsed seconds.

    Raises:
        ValueError: If ``lora_id`` is not a key of ``lora_adapters``.
    """
    if lora_id not in lora_adapters:
        # Labels follow the ``lora_adapters`` mapping (0 -> Python, 1 -> Medical).
        raise ValueError("Invalid LoRA ID. Must be one of: 0 (Python), 1 (Medical), 2 (Educational)")

    # Select the lora adapter given by the router
    assistant_model = lora_adapters[lora_id]

    # Follow the model's own device instead of assuming a CUDA device exists.
    inputs = tokenizer(query, return_tensors="pt").to(target_model.device)

    # Perform inference which runs speculative decoding in the background
    print(f"Running speculative decoding with LoRA ID {lora_id}...")
    # perf_counter is monotonic, so the measurement cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    # Greedy decoding: ``temperature`` only applies when sampling, so it is not
    # passed here (it was ignored and only triggered a warning).
    outputs = target_model.generate(
        **inputs,
        assistant_model=assistant_model,
        max_new_tokens=50,
        do_sample=False,
    )
    latency = time.perf_counter() - start_time

    # Calculate throughput as tokens/second and the latency
    tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1]
    throughput = tokens_generated / latency if latency > 0 else 0

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    print("\nSpeculative Decoding Output:")
    print(decoded_output)
    print("\nMetrics:")
    print(f"Tokens Generated: {tokens_generated}")
    print(f"Latency: {latency:.2f} seconds")
    print(f"Throughput: {throughput:.2f} tokens/second")

    return decoded_output, throughput, latency

# LoRA inference

In [None]:
# Only log in when a token is configured; with an empty token, rely on
# credentials already cached on this machine (e.g. `huggingface-cli login`).
if HUGGING_FACE_TOKEN:
    login(token=HUGGING_FACE_TOKEN)

# NOTE(review): joblib.load and torch.load both unpickle their input; only load
# router artifacts from a trusted source.
vectorizer = joblib.load(COUNT_VECTORIZER)

# The router's input width is taken from the vectorizer's output dimension.
router_model = FNN(vectorizer.transform(["test"]).shape[1], num_classes=3)
# NOTE(review): the file carries a .safetensors extension but is read with
# torch.load - confirm it was written with torch.save.
router_model.load_state_dict(torch.load(ROUTER_SAFETENSORS, map_location=torch.device('cpu')))
router_model.eval()

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(DRAFT_MODEL_NAME)

print("Loading draft model...")
draft_model = AutoModelForCausalLM.from_pretrained(
    DRAFT_MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)


class _LoraAdapterRegistry:
    """Dict-like lookup ``lora_id -> draft model`` backed by a single PeftModel.

    Every adapter is attached to the same base model under its own adapter
    name. Indexing activates the requested adapter and returns the shared
    model, so ``lora_adapters[lora_id]`` is used exactly like the value of a
    plain ``{lora_id: model}`` dict.
    """

    def __init__(self, peft_model, adapter_names):
        self._peft_model = peft_model
        self._adapter_names = dict(adapter_names)  # lora_id -> adapter name

    def __contains__(self, lora_id):
        return lora_id in self._adapter_names

    def __iter__(self):
        return iter(self._adapter_names)

    def __len__(self):
        return len(self._adapter_names)

    def __getitem__(self, lora_id):
        # Switch the shared model to the requested adapter before handing it out.
        self._peft_model.set_adapter(self._adapter_names[lora_id])
        return self._peft_model


print("Loading LoRA adapters...")
# Calling PeftModel.from_pretrained repeatedly on the same base model injects
# into the same layers under the adapter name "default", so each later load
# replaces the earlier adapter and every lora_id ends up using the last one.
# Registering each adapter under a distinct name keeps all three available.
_draft_with_adapters = PeftModel.from_pretrained(
    draft_model, PYTHON_LORA_PATH, adapter_name="python"
)
_draft_with_adapters.load_adapter(MED_LORA_PATH, adapter_name="medical")
_draft_with_adapters.load_adapter(EDU_LORA_PATH, adapter_name="education")
lora_adapters = _LoraAdapterRegistry(
    _draft_with_adapters, {0: "python", 1: "medical", 2: "education"}
)

print("Loading target model...")
target_model = AutoModelForCausalLM.from_pretrained(
    TARGET_MODEL_NAME, torch_dtype=torch.float16, device_map="auto"
)

# Function to predict which adapter to route the input to
def predict_adapter(query):
    """Return the router's predicted adapter index for ``query``.

    The query is vectorized with the module-level ``vectorizer`` and scored by
    ``router_model``; the index of the highest score is returned.

    Args:
        query: Prompt text to classify.

    Returns:
        The predicted class index as a Python ``int``.
    """
    features = vectorizer.transform([query]).toarray()
    query_tensor = torch.tensor(features, dtype=torch.float32)
    with torch.no_grad():
        scores = router_model(query_tensor)
    # Softmax is monotonic, so the argmax of the raw scores is the same class
    # the argmax of the probabilities would select; the softmax is skipped.
    return torch.argmax(scores, dim=1).item()

# Function to perform speculative decoding with LoRA integration
def speculative_decoding(query):
    """Route ``query`` to a LoRA draft model and generate with speculative decoding.

    The adapter index comes from ``predict_adapter``; the matching entry of
    ``lora_adapters`` is passed to ``target_model.generate`` as
    ``assistant_model``.

    NOTE: this redefines the ``speculative_decoding(query, lora_id)`` from the
    earlier cell with a different signature and a four-value return.

    Args:
        query: Prompt text to complete.

    Returns:
        Tuple ``(decoded_output, tokens_generated, throughput, latency)``: the
        decoded output sequence, number of new tokens, tokens per second, and
        elapsed seconds.

    Raises:
        ValueError: If the predicted index is not a key of ``lora_adapters``.
    """
    lora_id = predict_adapter(query)
    if lora_id not in lora_adapters:
        # Labels follow the ``lora_adapters`` mapping (0 -> Python, 1 -> Medical).
        # NOTE(review): confirm the router's class indices use the same order.
        raise ValueError("Invalid LoRA ID. Must be one of: 0 (Python), 1 (Medical), 2 (Educational)")

    # Select the appropriate LoRA adapter given from the router
    assistant_model = lora_adapters[lora_id]

    # Follow the model's own device instead of assuming a CUDA device exists.
    inputs = tokenizer(query, return_tensors="pt").to(target_model.device)

    # Perform speculative decoding. perf_counter is monotonic, so the timing
    # cannot be skewed by system clock adjustments.
    start_time = time.perf_counter()
    # Greedy decoding: ``temperature`` only applies when sampling, so it is not
    # passed here (it was ignored and only triggered a warning).
    outputs = target_model.generate(
        **inputs,
        assistant_model=assistant_model,
        max_new_tokens=50,
        do_sample=False,
    )
    latency = time.perf_counter() - start_time

    tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1]
    throughput = tokens_generated / latency if latency > 0 else 0

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return decoded_output, tokens_generated, throughput, latency

# Function to process multiple queries with LoRA integration
def process_queries(queries):
    """Run ``speculative_decoding`` on every query and print summary metrics.

    Prints the mean and standard deviation (``np.std`` default, i.e. population
    standard deviation) of tokens generated, throughput and latency.

    Args:
        queries: Iterable of prompt strings.

    Returns:
        None. Results are printed.
    """
    queries = list(queries)
    if not queries:
        # The mean/std of an empty list are NaN and emit numpy warnings.
        print("No queries to process.")
        return

    total_tokens = []
    total_throughputs = []
    total_latencies = []

    for query in queries:
        print(f"Processing query: {query}")
        _, tokens_generated, throughput, latency = speculative_decoding(query)
        total_tokens.append(tokens_generated)
        total_throughputs.append(throughput)
        total_latencies.append(latency)

    # Calculate mean and standard deviations of important metrics
    avg_tokens = np.mean(total_tokens)
    avg_throughput = np.mean(total_throughputs)
    avg_latency = np.mean(total_latencies)

    std_tokens = np.std(total_tokens)
    std_throughput = np.std(total_throughputs)
    std_latency = np.std(total_latencies)

    print("\nSummary:")
    print(f"Average Tokens Generated: {avg_tokens:.2f} (Std: {std_tokens:.2f})")
    print(f"Average Throughput: {avg_throughput:.2f} tokens/second (Std: {std_throughput:.2f})")
    print(f"Average Latency: {avg_latency:.2f} seconds (Std: {std_latency:.2f})")

# Replace med_queries with your queries of interest
# (python_queries and education_queries are sampled in the dataset cell).
# Inside a notebook kernel __name__ is "__main__", so this runs whenever the
# cell is executed; the guard only matters if the code is imported as a module.
if __name__ == "__main__":
    process_queries(med_queries)

# Non-LoRA inference

In [None]:
# Function to perform decoding without the router: the medical LoRA adapter is
# always used as the draft model.
# NOTE(review): the section heading says "Non LoRA", yet a LoRA adapter is
# attached here - confirm which baseline was intended.
_SIMPLE_ADAPTER_NAME = "default"  # PEFT's default adapter name, as used by a bare from_pretrained call


def _get_simple_assistant():
    """Return the draft model wrapped with the medical adapter, loading it once.

    The cached wrapper is tied to the identity of ``draft_model`` so that
    re-running the model-loading cell (which rebinds ``draft_model``) triggers
    a fresh load instead of reusing a stale model.
    """
    cached = getattr(_get_simple_assistant, "_cache", None)
    if cached is None or cached[0] is not draft_model:
        wrapped = PeftModel.from_pretrained(
            draft_model, MED_LORA_PATH, adapter_name=_SIMPLE_ADAPTER_NAME
        )
        cached = (draft_model, wrapped)
        _get_simple_assistant._cache = cached
    assistant = cached[1]
    # Other code sharing the same base model may have switched the active
    # adapter; re-select ours before every use.
    assistant.set_adapter(_SIMPLE_ADAPTER_NAME)
    return assistant


def simple_decoding(query):
    """Generate a completion for ``query`` using a fixed (medical) draft adapter.

    Args:
        query: Prompt text to complete.

    Returns:
        Tuple ``(decoded_output, tokens_generated, throughput, latency)``: the
        decoded output sequence, number of new tokens, tokens per second, and
        elapsed seconds.
    """
    # Tokenize input, following the model's own device instead of assuming CUDA.
    inputs = tokenizer(query, return_tensors="pt").to(target_model.device)
    # Loading the adapter from disk on every query is wasteful and re-wraps the
    # already wrapped base model; reuse a cached wrapper instead.
    assistant_model = _get_simple_assistant()
    # Perform decoding. perf_counter is monotonic, so the timing cannot be
    # skewed by system clock adjustments.
    start_time = time.perf_counter()
    # Greedy decoding: ``temperature`` only applies when sampling, so it is not
    # passed here (it was ignored and only triggered a warning).
    outputs = target_model.generate(
        **inputs,
        max_new_tokens=50,
        assistant_model=assistant_model,
        do_sample=False,
    )
    latency = time.perf_counter() - start_time

    tokens_generated = outputs.shape[1] - inputs["input_ids"].shape[1]
    throughput = tokens_generated / latency if latency > 0 else 0

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return decoded_output, tokens_generated, throughput, latency

# Processes multiple queries without LoRA integration
def process_queries_simple(queries):
    """Run ``simple_decoding`` on every query and print summary metrics.

    Prints the mean and standard deviation (``np.std`` default, i.e. population
    standard deviation) of tokens generated, throughput and latency.

    Args:
        queries: Iterable of prompt strings.

    Returns:
        None. Results are printed.
    """
    queries = list(queries)
    if not queries:
        # The mean/std of an empty list are NaN and emit numpy warnings.
        print("No queries to process.")
        return

    total_tokens = []
    total_throughputs = []
    total_latencies = []

    for query in queries:
        print(f"Processing query: {query}")
        _, tokens_generated, throughput, latency = simple_decoding(query)
        total_tokens.append(tokens_generated)
        total_throughputs.append(throughput)
        total_latencies.append(latency)

    # Calculate mean and standard deviations of important metrics
    avg_tokens = np.mean(total_tokens)
    avg_throughput = np.mean(total_throughputs)
    avg_latency = np.mean(total_latencies)

    std_tokens = np.std(total_tokens)
    std_throughput = np.std(total_throughputs)
    std_latency = np.std(total_latencies)

    print("\nSummary:")
    print(f"Average Tokens Generated: {avg_tokens:.2f} (Std: {std_tokens:.2f})")
    print(f"Average Throughput: {avg_throughput:.2f} tokens/second (Std: {std_throughput:.2f})")
    print(f"Average Latency: {avg_latency:.2f} seconds (Std: {std_latency:.2f})")

# Replace med_queries with your queries of interest
# (python_queries and education_queries are sampled in the dataset cell).
# Inside a notebook kernel __name__ is "__main__", so this runs whenever the
# cell is executed; the guard only matters if the code is imported as a module.
if __name__ == "__main__":
    process_queries_simple(med_queries)