# RAG Demo in Snowflake

1. Extract Text from PDF via PyPDF2 in Python
2. Split text into chunks using LangChain in Python
3. Convert text chunks into embeddings using Snowflake Cortex
4. Use vector search and an LLM for Retrieval Augmented Generation (RAG)

In [None]:
# Import python packages
import streamlit as st
import pandas as pd
import snowflake.snowpark.functions as F
import snowflake.snowpark.types as T
from snowflake.snowpark.functions import udtf, udf
from snowflake.snowpark.files import SnowflakeFile
from io import BytesIO
from typing import Iterable, Tuple

# Get Snowpark Session
from snowflake.snowpark.context import get_active_session
session = get_active_session()


### 1. Python PDF Extraction

Create a Python user-defined table function (UDTF) that uses PyPDF2 to extract the text of each page of a PDF file.

In [None]:
-- List the files staged in @DOCUMENTS; these are the PDFs processed below.
SELECT * FROM DIRECTORY('@DOCUMENTS');

In [None]:
@udtf(session=session,
      name='READ_PDF_UDTF',
      replace=True,
      is_permanent=True,
      stage_location='@GENAI.PUBLIC.FUNCTIONS',
      packages=['snowflake-snowpark-python','pypdf2','pycryptodome'], 
      output_schema=['PAGE_NUM','TEXT'])
class udtf_pdf_reader:
    """UDTF handler: emit one ``(PAGE_NUM, TEXT)`` row per page of a staged PDF.

    PAGE_NUM is the zero-based page index produced by ``enumerate``.
    """

    def process(self, file_path: str) -> Iterable[Tuple[int, str]]:
        """Yield ``(page_index, page_text)`` for every page of the PDF at *file_path*.

        Args:
            file_path: location readable by ``SnowflakeFile.open``, e.g. a scoped
                file URL from ``BUILD_SCOPED_FILE_URL``.

        Yields:
            Tuples of the zero-based page index and the extracted page text.
        """
        # PdfFileReader was deprecated in PyPDF2 2.x and removed in 3.x (it raises
        # DeprecationError there). PdfReader is the replacement and also exists in
        # the 2.x releases that already provide page.extract_text().
        from PyPDF2 import PdfReader
        with SnowflakeFile.open(file_path, 'rb') as file:
            # Buffer the whole file in memory so the staged file can be closed
            # before the (potentially slow) page-by-page extraction.
            pdf_reader = PdfReader(BytesIO(file.readall()))
        for page_ix, page in enumerate(pdf_reader.pages):
            # Defensive: normalise a missing text result to '' so downstream
            # consumers never receive NULL page text.
            yield (page_ix, page.extract_text() or '')

In [None]:
-- Build RAW_TEXT: one row per PDF page (PAGE_NUM, TEXT) for every staged file.
-- The extraction could be driven from Python too, but exposing it as a UDTF lets
-- any SQL analyst call it.
CREATE OR REPLACE TABLE RAW_TEXT AS
SELECT RELATIVE_PATH,
       FILE_URL,
       PDF_PAGES.*
FROM DIRECTORY(@DOCUMENTS),
     TABLE(READ_PDF_UDTF(BUILD_SCOPED_FILE_URL(@DOCUMENTS, RELATIVE_PATH))) AS PDF_PAGES;

In [None]:
from snowflake.snowpark.session import Session
from snowflake.snowpark.functions import col, regexp_replace

# Preview the extracted page text with runs of whitespace collapsed.
# Reads the RAW_TEXT table created above; its text column is TEXT
# (see READ_PDF_UDTF's output_schema).
TEXT_COLUMN = "TEXT"
df = session.table('RAW_TEXT')

# Use regexp_replace to replace every run of two or more consecutive whitespace
# characters with a single space: r'\s{2,}' means "any whitespace character,
# at least twice in a row". NOTE(review): \s presumably also matches newlines,
# so line breaks inside such runs are folded into the space too.
df_modified = df.with_column(TEXT_COLUMN, regexp_replace(col(TEXT_COLUMN), r'\s{2,}', ' '))

# Show the modified DataFrame. This is a preview only; RAW_TEXT itself is not
# rewritten, because the text splitter below splits on "\n" and collapsing
# newlines here would change how pages are chunked.
df_modified.show()


In [None]:
# Switching between SQL and Python is seamless: the RAW_TEXT table built in SQL
# is read back here as a Snowpark DataFrame.
pdf_extracts = session.table('RAW_TEXT')
n_pages = pdf_extracts.count()
print('Number of pages extracted:', n_pages)
first_pages = pdf_extracts.limit(5)
first_pages.show()

### 2. Text Splitting with LangChain
Some pages contain a lot of text, so we want to split them into chunks that fit into our embedding model.

For this we use a second UDTF, built on LangChain's `RecursiveCharacterTextSplitter`.

In [None]:
@udtf(session=session,
      name='langchain_splitting',
      replace=True,
      is_permanent=True,
      stage_location='@GENAI.PUBLIC.FUNCTIONS',
      packages=['langchain'], 
      output_schema=['CHUNK_INDEX','TEXT_CHUNK'])
class udtf_text_splitter:
    """UDTF handler: split one page of text into ``(CHUNK_INDEX, TEXT_CHUNK)`` rows.

    Chunks target at most ``CHUNK_SIZE`` characters (measured with ``len``).
    Consecutive chunks may overlap by up to ``CHUNK_OVERLAP`` characters.
    """

    # Prefer splitting at line breaks. Fall back to spaces and finally to single
    # characters, so a line longer than CHUNK_SIZE is still split. With ["\n"]
    # alone, such a line would be emitted as one oversized chunk.
    SEPARATORS = ["\n", " ", ""]
    CHUNK_SIZE = 1000
    CHUNK_OVERLAP = 50

    def __init__(self):
        # Imported inside the handler so it resolves in the UDTF runtime, where
        # 'langchain' is provided via the decorator's packages= list.
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            separators=self.SEPARATORS,
            chunk_size=self.CHUNK_SIZE,
            chunk_overlap=self.CHUNK_OVERLAP,
            length_function=len,
            is_separator_regex=False,
        )

    def process(self, text: str) -> Iterable[Tuple[int, str]]:
        """Yield ``(chunk_index, chunk_text)`` for each chunk of *text*.

        A NULL or empty input yields no rows instead of raising.
        """
        if not text:
            return
        documents = self.text_splitter.create_documents([text])
        for chunk_ix, doc in enumerate(documents):
            yield (chunk_ix, doc.page_content.strip())

In [None]:
# Fan each page's TEXT out into chunks via the langchain_splitting UDTF.
# The input columns are kept alongside the chunk columns; the full page text is
# dropped, and the result is persisted as CHUNK_TEXT.
pdf_extracts_chunked = (
    pdf_extracts
    .join_table_function('langchain_splitting', 'text')
    .drop('TEXT')
)
pdf_extracts_chunked.write.save_as_table('CHUNK_TEXT', mode='overwrite')

# Preview the first 10 chunks ordered by page and chunk index.
chunk_table = session.table('CHUNK_TEXT')
chunk_table.order_by('PAGE_NUM', 'CHUNK_INDEX').limit(10).to_pandas()

### 3. Convert Text Chunks to Text Embeddings using Snowflake Cortex Functions

In [None]:
-- Embed every text chunk with the Cortex 'e5-base-v2' model. Each vector is stored
-- next to its source file and page so search results can be traced back.
CREATE OR REPLACE TABLE VECTOR_STORE AS
SELECT C.RELATIVE_PATH,
       C.PAGE_NUM,
       C.TEXT_CHUNK,
       SNOWFLAKE.CORTEX.EMBED_TEXT('e5-base-v2', C.TEXT_CHUNK) AS CHUNK_EMBEDDING
FROM CHUNK_TEXT AS C;

In [None]:
-- Embedding search: rank chunks by L2 distance to the question's embedding
-- (closest first) and return the top 10.
WITH QUERY_VECTOR AS (
    SELECT SNOWFLAKE.CORTEX.EMBED_TEXT('e5-base-v2',
               'What is Siemens Smart Infrastructure?') AS QUESTION_EMBEDDING
)
SELECT VS.PAGE_NUM,
       VS.TEXT_CHUNK,
       VECTOR_L2_DISTANCE(QV.QUESTION_EMBEDDING, VS.CHUNK_EMBEDDING) AS VECTOR_DISTANCE
FROM VECTOR_STORE AS VS
CROSS JOIN QUERY_VECTOR AS QV
ORDER BY VECTOR_DISTANCE
LIMIT 10;

### 4. Retrieval Augmented Generation (RAG)

In [None]:
-- RAG: retrieve the chunk closest to the question and let the LLM answer from it.
-- The question is declared once (QUESTION CTE) so retrieval and prompt cannot
-- drift apart, and its embedding is computed once.
WITH QUESTION AS (
    SELECT Q.QUESTION_TEXT,
           snowflake.cortex.embed_text('e5-base-v2', Q.QUESTION_TEXT) AS QUESTION_EMBEDDING
    FROM (SELECT 'What is Siemens Smart Infrastructure?' AS QUESTION_TEXT) Q
),
RAG_DATA AS (
    -- Embedding Search: keep only the single closest chunk
    SELECT V.PAGE_NUM,
           V.TEXT_CHUNK,
           Q.QUESTION_TEXT,
           VECTOR_L2_DISTANCE(Q.QUESTION_EMBEDDING, V.CHUNK_EMBEDDING) AS VECTOR_DISTANCE
    FROM VECTOR_STORE V
    CROSS JOIN QUESTION Q
        ORDER BY VECTOR_DISTANCE LIMIT 1
)
SELECT PAGE_NUM,
       TEXT_CHUNK,
       snowflake.cortex.complete(
            'mixtral-8x7b',
            -- Blank lines between the instruction, context, question and answer cue
            -- keep the sections from running together in the prompt.
            CONCAT(
                'Answer the question based on the context. Be concise.', '\n\n',
                'Context: ', TEXT_CHUNK, '\n\n',
                'Question: ', QUESTION_TEXT, '\n\n',
                'Answer: '
            )
        ) AS RESPONSE
    FROM RAG_DATA;

In [None]:
-- RAG using the chat-style (messages array) form of COMPLETE.
-- The question is declared once (QUESTION CTE) and embedded once.
-- NOTE(review): with an options argument (here {}), Cortex COMPLETE is expected to
-- return a JSON object (choices, usage, ...) rather than plain text. Confirm
-- against the Snowflake docs before parsing RESPONSE.
WITH QUESTION AS (
    SELECT Q.QUESTION_TEXT,
           snowflake.cortex.embed_text('e5-base-v2', Q.QUESTION_TEXT) AS QUESTION_EMBEDDING
    FROM (SELECT 'What is Siemens Smart Infrastructure?' AS QUESTION_TEXT) Q
),
RAG_DATA AS (
    -- Embedding Search: keep only the single closest chunk
    SELECT V.PAGE_NUM,
           V.TEXT_CHUNK,
           Q.QUESTION_TEXT,
           VECTOR_L2_DISTANCE(Q.QUESTION_EMBEDDING, V.CHUNK_EMBEDDING) AS VECTOR_DISTANCE
    FROM VECTOR_STORE V
    CROSS JOIN QUESTION Q
        ORDER BY VECTOR_DISTANCE LIMIT 1
)
SELECT PAGE_NUM,
       TEXT_CHUNK,
       snowflake.cortex.complete(
            'mixtral-8x7b',
            [
                {'role': 'system', 'content': 'You are a helpful AI assistant that answers questions based on provided context. Be concise.'},
                -- Separate the question from the context so the chunk text does not
                -- run straight into "The context is:".
                {'role': 'user', 'content': CONCAT(QUESTION_TEXT, '\n\nThe context is:\n', TEXT_CHUNK)}
            ], {}
        ) AS RESPONSE
    FROM RAG_DATA;

In [None]:
# Import python packages
import streamlit as st
from snowflake.snowpark.context import get_active_session

# Models offered in the picker. This list is also the allowlist for the model name,
# which is spliced into the SQL text.
LLM_MODELS = ['mistral-large', 'mixtral-8x7b', 'llama2-70b-chat', 'mistral-7b', 'gemma-7b']


def build_rag_sql(model: str) -> str:
    """Return the RAG query for *model*, with one ``?`` bind variable for the question.

    The question is bound by Snowflake (qmark style) rather than interpolated. A
    question containing a quote (e.g. "What's ...") can therefore neither break the
    SQL nor inject into it. The model name is interpolated, so it is checked
    against ``LLM_MODELS`` first.

    Raises:
        ValueError: if *model* is not one of ``LLM_MODELS``.
    """
    if model not in LLM_MODELS:
        raise ValueError(f"Unsupported LLM model: {model!r}")
    # QUESTION embeds the question once, and RAG_DATA keeps the single closest chunk.
    # The final SELECT asks the LLM to answer from that chunk. '\\n\\n' reaches
    # Snowflake as the escape sequence '\n\n', separating the prompt sections.
    return f"""WITH QUESTION AS (
                    SELECT Q.QUESTION_TEXT,
                           snowflake.cortex.embed_text('e5-base-v2', Q.QUESTION_TEXT) AS QUESTION_EMBEDDING
                    FROM (SELECT ? AS QUESTION_TEXT) Q
                ),
                RAG_DATA AS (
                    -- Embedding Search
                    SELECT V.PAGE_NUM,
                           V.TEXT_CHUNK,
                           Q.QUESTION_TEXT,
                           VECTOR_L2_DISTANCE(Q.QUESTION_EMBEDDING, V.CHUNK_EMBEDDING) AS VECTOR_DISTANCE
                    FROM VECTOR_STORE V
                    CROSS JOIN QUESTION Q
                        ORDER BY VECTOR_DISTANCE LIMIT 1
                )
                SELECT PAGE_NUM,
                       TEXT_CHUNK,
                       snowflake.cortex.complete(
                            '{model}',
                            CONCAT(
                                'Answer the question based on the context. Be concise.', '\\n\\n',
                                'Context: ', TEXT_CHUNK, '\\n\\n',
                                'Question: ', QUESTION_TEXT, '\\n\\n',
                                'Answer: '
                            )
                        ) as RESPONSE
                    FROM RAG_DATA"""


# Write directly to the app
st.title("Snowflake Cortex")
st.subheader("Retrieval Augmented Generation for Siemens Financial Report")

# Get the current credentials
session = get_active_session()

llm_model = st.selectbox('LLM Model', LLM_MODELS)

question = st.text_area('Your question:', value='What is Siemens Smart Infrastructure?')
if st.button('Ask question'):
    if not question.strip():
        st.warning('Please enter a question.')
    else:
        # The result is empty when VECTOR_STORE has no rows (nothing to retrieve).
        rows = session.sql(build_rag_sql(llm_model), params=[question]).collect()
        if not rows:
            st.warning('No matching context found. Has VECTOR_STORE been populated?')
        else:
            results = rows[0]
            st.subheader('Response:')
            st.info(results['RESPONSE'])
            st.subheader('Context:')
            st.info(f"Page Number: {results['PAGE_NUM']}")
            st.info(f"Page Context: {results['TEXT_CHUNK']}")

In [None]:
# clean up
#session.sql('DROP TABLE IF EXISTS CHUNK_TEXT').collect()
#session.sql('DROP TABLE IF EXISTS RAW_TEXT').collect()
#session.sql('DROP TABLE IF EXISTS VECTOR_STORE').collect()
#session.sql('DROP FUNCTION IF EXISTS READ_PDF_UDTF(VARCHAR)').collect()
#session.sql('DROP FUNCTION IF EXISTS LANGCHAIN_SPLITTING(VARCHAR)').collect()