https://pythonmldaily.com/lesson/python-chatbot-langchain/read-csv-documents-embeddings

In [17]:
!pip install chromadb faiss-cpu langchain_community -q


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.1[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.10 -m pip install --upgrade pip[0m


In [50]:
import boto3
import sagemaker
import json

from langchain.document_loaders.csv_loader import CSVLoader
from langchain.document_loaders import TextLoader
from langchain.prompts import PromptTemplate

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.vectorstores import FAISS,Chroma
from langchain.chains.question_answering import load_qa_chain
from langchain_community.llms import SagemakerEndpoint
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler


In [5]:
import pandas as pd
data = pd.read_csv("cve.csv", sep=",", header=0)
data = data.drop(['access_authentication',	'access_complexity','access_vector','impact_availability','impact_confidentiality','impact_integrity'], axis=1)
data.head()

Unnamed: 0,cve,mod_date,pub_date,cvss,cwe_code,cwe_name,summary
0,CVE-2019-16548,2019-11-21 15:15:00,2019-11-21 15:15:00,6.8,352,Cross-Site Request Forgery (CSRF),A cross-site request forgery vulnerability in ...
1,CVE-2019-16547,2019-11-21 15:15:00,2019-11-21 15:15:00,4.0,732,Incorrect Permission Assignment for Critical ...,Missing permission checks in various API endpo...
2,CVE-2019-16546,2019-11-21 15:15:00,2019-11-21 15:15:00,4.3,639,Authorization Bypass Through User-Controlled Key,Jenkins Google Compute Engine Plugin 4.1.1 and...
3,CVE-2013-2092,2019-11-20 21:22:00,2019-11-20 21:15:00,4.3,79,Improper Neutralization of Input During Web P...,Cross-site Scripting (XSS) in Dolibarr ERP/CRM...
4,CVE-2013-2091,2019-11-20 20:15:00,2019-11-20 20:15:00,7.5,89,Improper Neutralization of Special Elements u...,SQL injection vulnerability in Dolibarr ERP/CR...


In [8]:
import random
data = data.reset_index()
clusters = ['marutham', 'palai', 'neythal', 'kurinchi', 'mullai']
with open('cve.txt', 'w') as f:
    for index, row in data.iterrows():
        cluster = random.choice(clusters)
        f.write(f"{row['cwe_name']} vulnerability with ID {row['cve']} was found in cluster {cluster} on {row['mod_date']}. With the score of {row['cvss']}, this vulnerability is {row['summary']}")
        f.write('\n')

In [18]:
embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
hf_embedding = HuggingFaceEmbeddings(model_name=embedding_model_name)

loader = TextLoader("cve.txt")
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=0, separators=[" ", ",", "\n"]
)
docs = text_splitter.split_documents(documents)

# Load the embeddings into Chroma in-memory vector store
vectorstore = FAISS.from_documents(docs, embedding=hf_embedding)
#vectorstore = Chroma.from_documents(docs, embedding=hf_embedding)
vectorstore_retriever = vectorstore.as_retriever(search_kwargs={"k": 1})


In [46]:
from langchain_community.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs={}) -> bytes:
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        return input_str.encode('utf-8')

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]['generated_text']

In [55]:
session = sagemaker.Session()
runtime = boto3.client("sagemaker-runtime")
client  = boto3.client("sagemaker")
account_id = boto3.client('sts').get_caller_identity().get('Account')
role = f"arn:aws:iam::{account_id}:role/service-role/AmazonSageMaker-ExecutionRole-20240618T160096"
endpoint_name = "spectra"
inference_component_name = "mistral-7b-instruct-v3-180510"

parameters = { "do_sample": True,
    "top_p": 0.9,
    "top_k": 5,
    "temperature": 0.2,
    "max_new_tokens": 1024,
    "stop": ["<|endoftext|>", "</s>"],
    "early_stopping": True
}

endpoint_parameters = {
    #"InferenceComponentName": inference_component_name
}
sagemaker_llm = SagemakerEndpoint(
    endpoint_name = endpoint_name,
    client = runtime,
    model_kwargs = parameters,
    endpoint_kwargs = endpoint_parameters,
    content_handler = ContentHandler()
)

def get_conversation_chain(vectorstore):
    memory = ConversationBufferMemory(memory_key = "chat_history", return_messages=True)
    conversation_cahin = ConversationalRetrievalChain.from_llm(llm = sagemaker_llm, retriever = vectorstore_retriever, memory=memory)
    return conversation_cahin

In [60]:
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.:\n\n{context}\n\nQuestion: {question}\nHelpful Answer:"""

PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
chain = load_qa_chain(llm=sagemaker_llm, prompt=PROMPT)
question = "Summarize the vulnerabilities found in the clusters"

#docs = vectorstore.similarity_search_with_score(question)
docs = vectorstore.similarity_search(question)
source = []
context = []
#for doc, score in docs:
for doc in docs:
    context.append(doc)
    source.append(doc.metadata["source"].split("/")[-1])

result = chain({"input_documents": context, "question": question}, return_only_outputs=True)[
    "output_text"
]
print(result)

Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.:

was found in cluster palai on 2019-11-19 16:43:00. With the score of 6.8, this vulnerability is cobbler: Web interface lacks CSRF protection when using Django framework
 Information Exposure vulnerability with ID CVE-2011-4919 was found in cluster marutham on 2019-11-19 16:43:00. With the score of 5.0, this vulnerability is mpack 1.6 has information disclosure via eavesdropping on mails sent by other users
 Untrusted Search Path vulnerability with ID CVE-2019-16861 was found in cluster neythal on 2019-11-19 13:15:00. With the score of 6.9, this vulnerability is Code42 server through 7.0.2 for Windows has an Untrusted Search Path. In certain situations, a non-administrative attacker on the local server could create or modify a dynamic-link library (DLL). The Code42 service could then load it at runtime, and potentially execu