# LLM/RAG Initialization

In [5]:
import pickle
from langchain import PromptTemplate, LLMChain, HuggingFacePipeline
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline
import torch
import re
from rag_source import *
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# Sentence-embedding model; analyze_text_prompt() hands it to process_query().
# The original note here read "This is directory" - presumably a local model
# folder can be supplied in place of the model name (TODO confirm).
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Load the pre-processed data used by the RAG step (processed_data.pkl).
# NOTE: unpickling can execute arbitrary code, so only load a trusted file.
with open("processed_data.pkl", mode="rb") as data_file:
    loaded_data = pickle.load(data_file)


# Load the LLaMA model for text generation.
# `local_model` is passed to transformers as the model identifier; the name
# suggests a local checkpoint directory (TODO confirm). An earlier draft of this
# cell used the Hub id "meta-llama/Meta-Llama-3-8B" instead.
local_model = "llama"

# The result is bound to `llama_pipeline`, NOT to `pipeline`: rebinding the
# imported `pipeline` factory to its own return value makes this cell raise
# "object is not callable" the second time it is executed in the same kernel.
llama_pipeline = pipeline(
    "text-generation",
    model=local_model,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    temperature=0.3,  # Control randomness (lower values = more focused responses)
    top_p=0.9,  # Filter unlikely words
    truncation=True,
    max_length=500,
)

# Wrap the transformers pipeline so LangChain can drive it as an LLM.
llm = HuggingFacePipeline(pipeline=llama_pipeline)

# Pass-through template: the whole prompt arrives in the single variable
# `prompt_text` and is forwarded to the model unchanged.
prompt_template = PromptTemplate(
    template="{prompt_text}",
    input_variables=["prompt_text"],
)

# Chain that renders the template and sends the result to the LLM.
llm_chain = LLMChain(llm=llm, prompt=prompt_template)

# Build the RAG prompt for a query and run it through the LLM chain
def analyze_text_prompt(query, processed_data):
    """Answer `query` with the LLM chain and return the prompt alongside the output.

    The prompt is whatever `process_query(query, processed_data, embedding_model)`
    returns (helper presumably provided by `rag_source`), followed by the fixed
    suffix " The correct answer is, " that cues the model to answer directly.

    Args:
        query: The question to ask.
        processed_data: Data forwarded unchanged to `process_query`.

    Returns:
        A `(prompt_text, result)` tuple: the full prompt sent to the chain and
        the value returned by `llm_chain.run`.
    """
    retrieved_prompt = process_query(query, processed_data, embedding_model)
    prompt_text = retrieved_prompt + " The correct answer is, "

    # The chain takes its inputs as a dict keyed by the template variable name.
    result = llm_chain.run({"prompt_text": prompt_text})

    return prompt_text, result


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Pass the query

In [6]:
query = "For the cache access with PC 0x403a85 and address 0x35e798a637f on the bzip workload with PARROT replacement policy, does the cache hit or miss?"

prompt, response = analyze_text_prompt(query, loaded_data)

# The generated text may echo the prompt first; keep only the continuation.
response = response[len(prompt):] if response.startswith(prompt) else response

print("\nPrompt: ", prompt)
print("\nResponse: ", response)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Prompt:  You are an expert in computer architecture and your job is to answer questions given data from cache traces. Base your response on the following data and your knowledge of computer architecture.

Answer the following question: For the cache access with PC 0x403a85 and address 0x35e798a637f on the bzip workload with PARROT replacement policy, does the cache hit or miss? The correct answer is, 

Response:  0x403a85 is the PC address, 0x35e798a637f is the address, and the cache hit or miss is a miss.
