# Building an ITSG-33 RAG chatbot with LangChain, Hugging Face, FAISS, Amazon SageMaker and Amazon Textract

In [1]:
%%sh
pip install -qU sagemaker langchain amazon-textract-caller amazon-textract-textractor sentence-transformers pypdf faiss-cpu

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jupyter-ai-magics 2.15.0 requires langchain<0.2.0,>=0.1.0, but you have langchain 0.2.1 which is incompatible.
langchain-community 0.0.38 requires langchain-core<0.2.0,>=0.1.52, but you have langchain-core 0.2.3 which is incompatible.

In [2]:
import boto3, json, sagemaker
from typing import Dict
from langchain.chains import LLMChain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from sagemaker.jumpstart.model import JumpStartModel

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


## Deploy LLM on SageMaker

In [3]:
model_id = "huggingface-llm-mistral-7b-instruct"
accept_eula = False

model = JumpStartModel(model_id=model_id)

Using model 'huggingface-llm-mistral-7b-instruct' with wildcard version identifier '*'. You can pin to version '3.1.0' for more stable results. Note that models may have different input/output signatures after a major version upgrade.


In [5]:
predictor = model.deploy(accept_eula=accept_eula)

payload = {
    "inputs": "Why is the sky blue?",
}
response = predictor.predict(payload)
print(response)

Using already existing model: hf-llm-mistral-7b-instruct-2024-06-04-12-05-01-370


------------!
[{'generated_text': "Why is the sky blue? Why does the moon change shape? What makes a rainbow? Join scientists and engineers at NASA and The Museum of Science and Industry Chicago for a series of free, live interactive webcasts, exploring some of the most fundamental questions in science.\n\nIn this live webcast series, we'll explore the scientific principles behind some of nature's most beautiful and intriguing phenomena. Each webcast will feature a guest speaker from NASA and Q&A sessions with students and the public"}]


## Configure LLM in LangChain

In [6]:
#endpoint_kwargs = {"InferenceComponentName": inference_component_name}
model_kwargs = {"max_new_tokens": 512, "top_p": 0.2, "temperature": 0.2}

In [7]:
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps(
            # Mistral prompt, see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
            {"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": {**model_kwargs}}
        )
        return input_str.encode("utf-8")

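    def transform_output(self, output: bytes) -> str:
        # Despite the annotation, the endpoint response body is a stream, hence .read()
        response_json = json.loads(output.read().decode("utf-8"))
        # Keep only the text after the instruction tag; [-1] also works if the tag is absent
        return response_json[0]["generated_text"].split("[/INST] ")[-1]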


content_handler = ContentHandler()
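
To make the request and response formats concrete, here is a minimal standalone sketch of the same two transformations, runnable without an endpoint. The function names `build_payload` and `parse_response` and the simulated response body are illustrative assumptions, not part of LangChain or SageMaker.

```python
import io
import json


def build_payload(prompt: str, model_kwargs: dict) -> bytes:
    """Wrap a prompt in Mistral instruction tags and serialize it as JSON bytes."""
    body = {"inputs": f"<s>[INST] {prompt} [/INST]", "parameters": dict(model_kwargs)}
    return json.dumps(body).encode("utf-8")


def parse_response(stream) -> str:
    """Return the text after the last [/INST] tag, or the full text if the tag is absent."""
    response_json = json.loads(stream.read().decode("utf-8"))
    generated = response_json[0]["generated_text"]
    return generated.split("[/INST]")[-1].strip()


# Simulated response body (made up for illustration): a file-like object, read with .read()
payload = build_payload("What is CA-9 control?", {"max_new_tokens": 512})
fake_body = io.BytesIO(
    json.dumps(
        [{"generated_text": "<s>[INST] What is CA-9 control? [/INST] An example answer."}]
    ).encode("utf-8")
)
answer = parse_response(fake_body)
print(json.loads(payload)["inputs"])
print(answer)
```

Taking the last split element rather than index 1 is a defensive choice: it avoids an `IndexError` if the endpoint returns only the completion without echoing the prompt.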

In [8]:
sm_client = boto3.client("sagemaker-runtime") # needed for AWS credentials

llm = SagemakerEndpoint(
    #      endpoint_kwargs=endpoint_kwargs,
    endpoint_name=predictor.endpoint_name,
    model_kwargs=model_kwargs,
    content_handler=content_handler,
    client=sm_client,
)

## Zero-shot example

In [9]:
system_prompt = """
As a helpful cybersecurity expert, please answer the question.
Don't invent facts. If you can't provide a factual answer, say you don't know what the answer is.
"""

prompt = PromptTemplate.from_template(system_prompt + "{content}")

In [10]:
llm_chain = LLMChain(llm=llm, prompt=prompt)



In [11]:
question = "What is CA-9 control?"

query = f"question: {question}"

In [12]:
answer = llm_chain.run({"content": query})
print(answer)



I'm sorry for any confusion, but the term "CA-9 control" is not a widely recognized term in cybersecurity or any other field that I'm aware of. It's possible that it could refer to a specific control or security measure used in a particular organization or industry, but without more context, it's impossible for me to provide a factual answer. If you could please provide more information or context about what "CA-9 control" is supposed to be, I'd be happy to help if I can.


## RAG example with PDF files

In [13]:
from langchain.document_loaders import AmazonTextractPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

### Upload local PDF files to S3

Sources:
* https://www.cyber.gc.ca/sites/default/files/itsg-33_-_overview.pdf
* https://www.cyber.gc.ca/sites/default/files/cyber/publications/itsg33-ann1-eng.pdf
* https://www.cyber.gc.ca/sites/default/files/cyber/publications/itsg33-ann2-eng.pdf
* https://www.cyber.gc.ca/sites/default/files/cyber/publications/itsg33-ann3a-eng.pdf
* https://www.cyber.gc.ca/sites/default/files/cyber/publications/itsg33-ann4a-1-eng.pdf


In [14]:
# Define S3 bucket and prefix for PDF storage

bucket = sagemaker.Session().default_bucket()
prefix = "itsg-33-rag-demo"

In [14]:
%%sh -s $bucket $prefix
aws s3 cp --recursive itsg-33 s3://$1/$2/

upload: itsg-33/itsg-33_-_overview.pdf to s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg-33_-_overview.pdf
upload: itsg-33/itsg33-ann1-eng.pdf to s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann1-eng.pdf
upload: itsg-33/itsg33-ann3a-eng.pdf to s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann3a-eng.pdf
upload: itsg-33/itsg33-ann5-eng.pdf to s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann5-eng.pdf
upload: itsg-33/itsg33-ann4a-1-eng.pdf to s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann4a-1-eng.pdf
upload: itsg-33/itsg33-ann2-eng.pdf to s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann2-eng.pdf


In [15]:
# Build list of S3 URIs

s3 = boto3.client("s3")
objs = s3.list_objects_v2(Bucket=bucket, Prefix=prefix)
objs = objs['Contents']
uris = [f's3://{bucket}/{obj["Key"]}' for obj in objs]
uris    

['s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg-33_-_overview.pdf',
 's3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann1-eng.pdf',
 's3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann2-eng.pdf',
 's3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann3a-eng.pdf',
 's3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann4a-1-eng.pdf',
 's3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann5-eng.pdf']
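
The listing above assumes the prefix contains only PDF files and at least one object. As a hedged sketch, the helper below (the name `build_s3_uris` and the sample response are hypothetical) builds the same URIs from a `list_objects_v2`-style response, while tolerating an empty result and skipping non-PDF keys.

```python
def build_s3_uris(response: dict, bucket: str, suffix: str = ".pdf") -> list:
    """Turn a list_objects_v2-style response into s3:// URIs for keys ending in `suffix`."""
    # The "Contents" key is absent when no object matches the prefix
    contents = response.get("Contents", [])
    return [
        f"s3://{bucket}/{obj['Key']}"
        for obj in contents
        if obj["Key"].lower().endswith(suffix)
    ]


# Hypothetical response, shaped like the dict returned by list_objects_v2
fake_response = {
    "Contents": [
        {"Key": "itsg-33-rag-demo/itsg33-ann1-eng.pdf"},
        {"Key": "itsg-33-rag-demo/notes.txt"},
    ]
}
pdf_uris = build_s3_uris(fake_response, "example-bucket")
empty_uris = build_s3_uris({}, "example-bucket")
print(pdf_uris)
```

A single `list_objects_v2` call returns at most 1,000 keys, so a larger document set would need pagination on top of this.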

### Analyze documents with Amazon Textract and split them into chunks

In [16]:
%%time

textract_client = boto3.client('textract')
splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)

all_chunks = []

for uri in uris:
    loader = AmazonTextractPDFLoader(uri, client=textract_client)
    document = loader.load()
    chunks = splitter.split_documents(document)
    all_chunks += chunks
    print(f"Loaded {uri}, {len(document)} pages, {len(chunks)} chunks")

Loaded s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg-33_-_overview.pdf, 16 pages, 48 chunks
Loaded s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann1-eng.pdf, 56 pages, 177 chunks
Loaded s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann2-eng.pdf, 113 pages, 364 chunks
Loaded s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann3a-eng.pdf, 270 pages, 1224 chunks
Loaded s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann4a-1-eng.pdf, 114 pages, 288 chunks
Loaded s3://sagemaker-ca-central-1-654654213972/itsg-33-rag-demo/itsg33-ann5-eng.pdf, 24 pages, 71 chunks
CPU times: user 1min 3s, sys: 1.04 s, total: 1min 4s
Wall time: 8min 30s
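
To build intuition for `chunk_size=1024` and `chunk_overlap=256`, here is a deliberately simplified sliding-window splitter. It is not the algorithm used by `RecursiveCharacterTextSplitter`, which splits on separators such as paragraph and line breaks before merging pieces, so real chunks tend to end at natural boundaries and are often shorter. The function name `sliding_window_chunks` is an illustrative assumption.

```python
def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> list:
    """Cut text into fixed-size windows where each window repeats the tail of the previous one."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


# Synthetic 2,500-character text, used only to show the window arithmetic
sample_text = "".join(str(i % 10) for i in range(2500))
sample_chunks = sliding_window_chunks(sample_text, chunk_size=1024, chunk_overlap=256)
print([len(c) for c in sample_chunks])
```

The overlap means a sentence that straddles a boundary still appears whole in at least one chunk, at the cost of embedding some text twice.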


### Embed document chunks and store them in FAISS
https://github.com/facebookresearch/faiss 

In [17]:
%%time
# Define embedding model
# See https://huggingface.co/spaces/mteb/leaderboard

embedding_model_id = "BAAI/bge-small-en-v1.5"

embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_id,
)



CPU times: user 5.31 s, sys: 409 ms, total: 5.72 s
Wall time: 7.29 s


In [18]:
%%time
# Embed chunks
embeddings_db = FAISS.from_documents(all_chunks, embeddings)

CPU times: user 6min 39s, sys: 2min 33s, total: 9min 12s
Wall time: 5min 56s
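
Conceptually, the vector store answers a query by embedding it and returning the stored vectors closest to it. The sketch below shows that idea with brute-force Euclidean distance on toy two-dimensional vectors. It assumes an exact L2 search, which to my understanding is the default index type in LangChain's FAISS wrapper; the function name `top_k_l2` and the toy data are hypothetical.

```python
import math


def top_k_l2(query: list, vectors: list, k: int) -> list:
    """Return the indices of the k vectors closest to the query by Euclidean distance."""
    scored = sorted((math.dist(query, vec), idx) for idx, vec in enumerate(vectors))
    return [idx for _, idx in scored[:k]]


# Toy 2-D "embeddings"; real embedding vectors have hundreds of dimensions
toy_vectors = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
nearest = top_k_l2([0.9, 0.1], toy_vectors, k=2)
print(nearest)
```

FAISS performs the same kind of search with optimized native code, which is why lookups stay fast even with thousands of chunks.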


In [19]:
%%time
# Save database
embeddings_db.save_local("faiss_index")

CPU times: user 19.3 ms, sys: 916 µs, total: 20.2 ms
Wall time: 28 ms


### Shortcut: load an existing embedding database

In [20]:
%%time
# The saved index includes a pickle file, so only load indexes you created yourself
embeddings_db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)

CPU times: user 13.9 ms, sys: 2.48 ms, total: 16.4 ms
Wall time: 46.4 ms


********

### Configure RAG chain

In [21]:
%%time
retriever = embeddings_db.as_retriever(search_kwargs={"k": 10})

CPU times: user 106 µs, sys: 0 ns, total: 106 µs
Wall time: 110 µs


In [22]:
%%time
# Define prompt template
prompt_template = """
As a helpful Government of Canada cybersecurity specialist, please answer the question below, focusing on text data and using only the context below.
Don't invent facts. If you can't provide a factual answer, say you don't know what the answer is.

question: {question}

context: {context}
"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

CPU times: user 79 µs, sys: 26 µs, total: 105 µs
Wall time: 110 µs


In [23]:
%%time
chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
)

CPU times: user 190 ms, sys: 98.9 ms, total: 288 ms
Wall time: 311 ms
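
The `"stuff"` chain type simply places all retrieved chunks into the prompt's `{context}` slot. The following sketch mimics that step with plain string formatting; the shortened template, the function name `stuff_prompt`, and the blank-line separator are assumptions for illustration, not the chain's actual internals.

```python
# Shortened stand-in for the prompt template defined in the notebook
PROMPT_TEMPLATE = """
Answer the question below using only the context below.

question: {question}

context: {context}
"""


def stuff_prompt(question: str, chunks: list, separator: str = "\n\n") -> str:
    """Join the retrieved chunks and insert them, with the question, into the template."""
    context = separator.join(chunks)
    return PROMPT_TEMPLATE.format(question=question, context=context)


# Made-up chunk texts standing in for retriever results
filled = stuff_prompt("What is CA-9 control?", ["chunk A", "chunk B"])
print(filled)
```

With `k=10` and chunks of up to 1,024 characters, the context alone can reach roughly 10,000 characters, so the prompt must fit within the model's context window.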


### Ask our question again

In [24]:
%%time
question = "What is CA-9 control?"
answer = chain.run({"query": question})
print(answer)

CA-9 control refers to the organization's authorization of internal connections of defined information system components or classes of components to an information system. The organization documents the interface characteristics, security requirements, and nature of the information communicated for each internal connection. (Refer to CA-9 in the provided context for more details.)
CPU times: user 263 ms, sys: 11.8 ms, total: 275 ms
Wall time: 2.94 s


In [25]:
%%time
question = "What is AC-6 control?"
answer = chain.run({"query": question})
print(answer)

AC-6 is a cybersecurity control referred to as "Least Privilege." It is a principle that allows only authorized accesses for users or processes acting on behalf of users, which are necessary to accomplish assigned tasks in accordance with organizational missions and business functions. This control is related to other controls such as AC-3, CM-2, CM-3, CM-5, CM-6, CM-7, PL-4, and is considered a best practice. The control enhances security by reducing the risk of malevolent activity without collusion and helps to ensure that users only have the access they need and no more.
CPU times: user 254 ms, sys: 19.5 ms, total: 274 ms
Wall time: 5.07 s


In [26]:
%%time
question = "What is the Supplemental Guidance for AC-6 control?"
answer = chain.run({"query": question})
print(answer)

The Supplemental Guidance for AC-6 control (Remote Access I Automated Monitoring/Control) provides additional information on how organizations can implement the control effectively. It states that automated monitoring and control of remote access sessions allows organizations to detect cyber-attacks and ensure ongoing compliance with remote access policies by auditing connection activities of remote users on various information system components. Related controls for this enhancement include AU-2 and AU-12.
CPU times: user 213 ms, sys: 10.4 ms, total: 223 ms
Wall time: 3.83 s


In [27]:
%%time
question = "What are the Control Enhancements for AC-6 control Least Privilege?"
answer = chain.run({"query": question})
print(answer)

The context provided does not mention any specific control enhancements for AC-6 Least Privilege. According to the text, the principle of least privilege is a best practice, and it is encouraged to include it in departmental profiles for most cases. However, there are some exceptions, such as for specialized or advanced capabilities that are not required for all systems, or for outside personnel who need privileged access for maintenance. In these cases, inclusion in a departmental profile is made on a case-by-case basis. The text also suggests that organizations give due consideration to the least privilege control enhancement, even if it is not included in a departmental profile. Therefore, there are no specific control enhancements mentioned for AC-6 Least Privilege in the context provided.
CPU times: user 284 ms, sys: 15.8 ms, total: 300 ms
Wall time: 6.03 s


In [28]:
%%time
question = "Reflect on your last answer. Did you miss anything?"
answer = chain.run({"query": question})
print(answer)

Based on the context provided, the term "statement of assessment" refers to a recognition or acknowledgement that the assessment process has been completed with acceptable results. It can be formal or informal, such as a signed certificate from a security assessor or a record of decision appearing in the minutes of a meeting.

Therefore, the answer to the question "Reflect on your last answer. Did you miss anything?" is that I did not miss anything related to the definition of a statement of assessment in the context provided. However, I cannot answer any questions that go beyond the context given.
CPU times: user 277 ms, sys: 12.3 ms, total: 289 ms
Wall time: 4.94 s


In [29]:
%%time
question = "What is AUTHORIZE ACCESS TO SECURITY FUNCTIONS?"
answer = chain.run({"query": question})
print(answer)

The term "AUTHORIZE ACCESS TO SECURITY FUNCTIONS" refers to the process of explicitly authorizing access to organization-defined security functions and security-relevant information in an information system. This is typically done by the organization to ensure that only authorized individuals have access to these functions and information to maintain the security of the system. The context provided in the Control Enhancements section of ITSG-33 mentions this concept in the LEAST PRIVILEGE I AUTHORIZE ACCESS TO SECURITY FUNCTIONS enhancement. The authorization package, which includes statements of assessment, residual risk assessment results or TRA report, and the operations plan, is used to make the decision to authorize or deny access to these functions. (3.9.6 Authorize Information System Operations)
CPU times: user 190 ms, sys: 4.55 ms, total: 195 ms
Wall time: 6.34 s


## Delete endpoint and model

In [31]:
predictor.delete_model()
predictor.delete_predictor()