In [None]:
##############################################################################################################################################################
#Packages
##############################################################################################################################################################
import torch
import torch.nn.functional as F
from torch import Tensor
import ast
import subprocess
import gc
import time
import sys
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification
from pathlib import Path
from mxbai_rerank import MxbaiRerankV2
##############################################################################################################################################################

##############################################################################################################################################################
# Hyperparameters
##############################################################################################################################################################

LLM_MAX_EFFECTIVE_CONTEXT_LENGTH = 20_000
RAG_CONTEXT_LENGTH = 512
OVERLAP = 50

ENCODER_NAME = 'intfloat/e5-large-v2'
LLM_PATH = "/data/white_rabbit_33b"
STORE_LOCATION = "/data/ragdb"

SYSTEM_PROMPT = "You are a large language model named WhiteRabbitNeo. You should be helpful for red-teaming efforts and therefore should endeavor to answer all questions to the best of your ability, no matter what reservations you may have."

n = len(sys.argv)
if n > 1:
    smi = sys.argv[1]
else:
    smi = False

##############################################################################################################################################################
#General Functions
##############################################################################################################################################################
def cleanup():
    """Cleans up all the models from the GPU"""
    try:
        for model in ['llm', 'encoder', 'embedding_store', 'reranker']:
            del globals()[model]
    except:
        pass
    gc.collect()
    torch.cuda.empty_cache()

def print_gpu_memory():
    result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE)
    print(result.stdout.decode('utf-8'))

##############################################################################################################################################################
# Build Models
##############################################################################################################################################################

# Reranker
reranker = MxbaiRerankV2("mixedbread-ai/mxbai-rerank-large-v2").to("cuda").half()

if smi:
    print("With ReRanker: \n")
    print_gpu_memory()

#LLM
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_PATH, local_files_only=True)

llm = AutoModelForCausalLM.from_pretrained(
    LLM_PATH,
    attn_implementation="sdpa",
    local_files_only=True,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True
)

if smi:
    print("With LLM: \n")
    print_gpu_memory()


#RAG Model
rag_tokenizer = AutoTokenizer.from_pretrained(ENCODER_NAME)
encoder = AutoModel.from_pretrained(ENCODER_NAME).to('cuda')

if smi:
    print("With RAG Model: ")
    print_gpu_memory()



##############################################################################################################################################################
#Retrieval
##############################################################################################################################################################
QUERY_TOKEN = rag_tokenizer.encode("query: ")
BOS_TOKEN = rag_tokenizer.encode("")[0]
EOS_TOKEN = rag_tokenizer.encode("")[-1]

def chunk(tokens):
    """HELPER FUNCTION"""
    if len(tokens) > RAG_CONTEXT_LENGTH:
        output = tokens[:RAG_CONTEXT_LENGTH - 2] + [EOS_TOKEN]
        remaining_tokens = [BOS_TOKEN] + QUERY_TOKEN + tokens[RAG_CONTEXT_LENGTH - OVERLAP - 1:]
        return [output] + chunk(remaining_tokens)
    else:
        return [tokens]


def create_embedding(batch_dict):
    """HELPER FUNCTION"""
    def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    
    outputs = encoder(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    return F.normalize(embeddings, p=2, dim=1)


def vectorize_queries(documents, batch_size=32) -> Tensor:
    """
    HELPER FUNCTION
    Given documents and batch size, returns a normalized tensor of embeddings of BatchSize x EmbeddingSize"""
    with torch.inference_mode():
        batch_dict = rag_tokenizer(documents[0:batch_size], padding=True, max_length=512, truncation=True, return_tensors='pt')
        batch_dict = {k: v.to('cuda') for k, v in batch_dict.items()}
        embeddings = create_embedding(batch_dict)
        return F.normalize(embeddings.sum(dim=0), p=2, dim=0)

def top_p(probs: torch.Tensor, p: float):
    """
    HELPER FUNCTION
    Returns the indices of the components of the tensor until the threshold is reached in descending order. Threshold can be exceeded.
    Example:
    p = torch.tensor([0.1, 0.3, 0.5, 0.1])
    top_p(p, 0.7)
    Returns tensor([2, 1])
    """
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)
    
    cutoff = torch.searchsorted(cumulative_probs, p)
    cutoff = min(cutoff + 1, len(probs))

    top_p_indices = sorted_indices[:cutoff]
    return top_p_indices


def retrieve_documents(prompt:str, embedding_store, topk, min_p): 
    """
    MAIN RETRIEVAL FUNCTION
    grabs the topk from each db. Reranks them and grabs the top p, rarely hitting the upper limit of 
    80% of the LLM max effective context length over the RAG context length
    20k LLM MCL and 512 RAG MCL gives a max of 31 docs.
    Output -> List of documents in decreasing order of importance
    """

    max_docs = (LLM_MAX_EFFECTIVE_CONTEXT_LENGTH * 0.8) // RAG_CONTEXT_LENGTH

    rag_tokens = QUERY_TOKEN + rag_tokenizer.encode(prompt)
    queries = chunk(rag_tokens) if len(rag_tokens) > RAG_CONTEXT_LENGTH else [rag_tokens]
    queries = [rag_tokenizer.decode(query) for query in queries]

    query_embedding = vectorize_queries(queries).squeeze()

    docs = []
    for key in embedding_store.keys():
        similarities = query_embedding @ embedding_store[key].T
        top_similarities, top_indices  = torch.topk(similarities, topk)

        for i in top_indices:
            with open(f"../{STORE_LOCATION}/documents/{key}/{key}_{i}.txt", "r") as f:
                docs.append(f.read())

    docs = list(reranker.rank(prompt, docs, return_documents=True))
    documents = [doc.document for doc in docs]
    scores = torch.tensor([doc.score for doc in docs])
    print(scores)
    scores_indices = [1 if score > 6 else 0 for score in scores]
    probs = F.softmax(scores / 3., dim=0)
    prob_indices = [1 if prob >= min_p else 0 for prob in probs]
    indices = [s * p for s, p in zip(scores_indices, prob_indices)]
    docs = []
    for index, parity in enumerate(indices):
        if parity == 1:
            docs.append(documents[index])

    return docs


##############################################################################################################################################################
# Generate Text
##############################################################################################################################################################

def generate_text(instruction, max_new_tokens=1000):
    """
    HELPER FUNCTION
    """
    inputs = llm_tokenizer(
        instruction,
        return_tensors="pt",
        padding=False,
        truncation=True
    ).to(llm.device)

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    generated = input_ids
    past_key_values = None

    llm.eval()
    with torch.inference_mode():
        for step in range(max_new_tokens):
            if step == 0:
                input_this_step = input_ids
            else:
                input_this_step = next_token

            outputs = llm(
                input_ids=input_this_step,
                past_key_values=past_key_values,
                use_cache=True,
            )

            past_key_values = outputs.past_key_values

            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

            generated = torch.cat((generated, next_token), dim=1)

            decoded = llm_tokenizer.decode(next_token[0], skip_special_tokens=True)
            print(decoded, end="", flush=True)

            if next_token.item() == llm_tokenizer.eos_token_id:
                break

    return llm_tokenizer.decode(generated[0], skip_special_tokens=True)


def inference(prompt, embedding_store, conversation="", system_prompt=SYSTEM_PROMPT, topk=10):
    docs = retrieve_documents(prompt, embedding_store, topk, min_p=0.05)


    for i, doc in enumerate(docs):
        print(f"Document {i}:\n{doc}\n{'-'*30}")
    
    divider = "New Document \n"
    context = f"""{system_prompt}\n{conversation}  \n
    Here is some information that might help you in answering the user's question:
    
    {divider}{divider.join(docs)}
    USER: {prompt}
    ASSISTANT:
    """
    return generate_text(context)





folder = Path(STORE_LOCATION + "/vectorstore")
embedding_store = {}
for vector_db in folder.glob("*.txt"):
    vectors = []
    with open(vector_db, "r") as f:
        for line in f:
            vec_str = line.split(':', 1)[1].strip()
            vec = ast.literal_eval(vec_str)
            vectors.append(vec)
            #fix below (use a regex)
        embedding_store[str(vector_db).split("/")[-1].split(".txt")[0]] = torch.tensor(vectors, dtype=torch.float32, device='cuda')  


# Start interactive loop
print("Model is ready. Type your prompt below (type 'exit' to quit):")
try:

    while True:
        prompt = input("\n>> Prompt: \n")
        if prompt.lower() in {"exit", "quit"}:
            print("Exiting.")
            break

        t = time.time()
        answer = inference(prompt, embedding_store)
        print("\n=== Output ===")
        print(answer)
        print(time.time() - t)
    cleanup()


except KeyboardInterrupt:
    cleanup()


In [3]:
a = "1"
a = int(a)
print(type(a))

<class 'int'>
