In [30]:
import os
import shutil
import tempfile
import textwrap

import requests
from bs4 import BeautifulSoup
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_openai import OpenAI, OpenAIEmbeddings

import mlflow
# Prefer exporting OPENAI_API_KEY in your shell before starting Jupyter. The placeholder is
# only used when no key is set, so a real key from the environment is never overwritten.
os.environ.setdefault("OPENAI_API_KEY", "set your api key here")


In [6]:
def fetch_federal_document(url, div_class, timeout=30):
    """
    Scrape the text of a document's transcript section from a web page.

    Args:
        url (str): URL of the webpage to scrape.
        div_class (str): CSS class of the ``<div>`` element that holds the transcript.
        timeout (float, optional): Seconds to wait for the server before giving up.
            Default is 30.

    Returns:
        str: The transcript text, with text fragments joined by single newlines. If the
        page does not return HTTP 200 or no matching ``<div>`` exists, a human-readable
        error message is returned instead of raising.
    """
    # Without an explicit timeout, requests waits indefinitely on an unresponsive server.
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        return f"Failed to retrieve the webpage. Status code: {response.status_code}"

    # Parse the HTML and locate the transcript section by its CSS class
    soup = BeautifulSoup(response.text, "html.parser")
    transcript_section = soup.find("div", class_=div_class)
    if not transcript_section:
        return "Transcript section not found."

    return transcript_section.get_text(separator="\n", strip=True)

In [24]:
def fetch_and_save_documents(url_list, doc_path):
    """
    Fetch the transcript of each URL and append it to a single text file.

    Args:
        url_list (list): List of URLs to fetch documents from.
        doc_path (str): Path to the file where documents will be saved. The file is
            opened in append mode, so existing content is kept, and it is created even
            when ``url_list`` is empty.

    Note:
        ``fetch_federal_document`` returns an error message string instead of raising
        when a page cannot be scraped; that string is written to the file as-is.
    """
    # Open the file once for the whole batch rather than once per URL.
    with open(doc_path, "a") as file:
        for url in url_list:
            document = fetch_federal_document(url, "col-sm-9")
            # End every document with a blank line so the last line of one transcript
            # is not fused with the first line of the next one.
            file.write(document + "\n\n")


def create_faiss_database(
    document_path, database_save_directory, chunk_size=500, chunk_overlap=10, separator="\n"
):
    """
    Creates and saves a FAISS database using documents from the specified file.

    Args:
        document_path (str): Path to the file containing documents.
        database_save_directory (str): Directory where the FAISS database will be saved.
        chunk_size (int, optional): Size of each document chunk. Default is 500.
        chunk_overlap (int, optional): Overlap between consecutive chunks. Default is 10.
        separator (str, optional): String at which the text may be split into chunks.
            Default is a single newline.

    Returns:
        FAISS database instance.
    """
    # Load documents from the specified file
    raw_documents = TextLoader(document_path).load()

    # CharacterTextSplitter only breaks text at `separator`, and its own default is a
    # blank line ("\n\n"). Text that uses single newlines only (such as the output of
    # fetch_federal_document) would then stay in one oversized chunk and chunk_size
    # would have no effect, so the separator is passed explicitly.
    document_splitter = CharacterTextSplitter(
        separator=separator, chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    document_chunks = document_splitter.split_documents(raw_documents)

    # Embed every chunk and build the index
    faiss_database = FAISS.from_documents(document_chunks, OpenAIEmbeddings())

    # Save the FAISS database to the specified directory
    faiss_database.save_local(database_save_directory)

    return faiss_database

In [18]:
# Pages whose transcripts are indexed
url_listings = [
    "https://www.archives.gov/milestone-documents/act-establishing-yellowstone-national-park#transcript",
    "https://www.archives.gov/milestone-documents/sherman-anti-trust-act#transcript",
]

# Scratch directory that holds both the scraped text file and the FAISS index
temporary_directory = tempfile.mkdtemp()
doc_path = os.path.join(temporary_directory, "docs.txt")
persist_dir = os.path.join(temporary_directory, "faiss_index")

# Scrape the pages into doc_path, then build and save the vector database from that file
fetch_and_save_documents(url_listings, doc_path)
vector_db = create_faiss_database(doc_path, persist_dir)

In [25]:
# Send runs to the tracking server at 127.0.0.1:5000 and group them under "Legal RAG".
# NOTE(review): assumes a tracking server is already running at that address.
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("Legal RAG")
# Enable LangChain autologging, including models and input examples
mlflow.langchain.autolog(log_models=True, log_input_examples=True)

# Build the question-answering chain on top of the FAISS retriever created earlier
retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=vector_db.as_retriever())


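# Log the retrievalQA chain. `load_retriever` is handed to MLflow as `loader_fn` so that the
# retriever can be rebuilt from the saved FAISS index.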
def load_retriever(persist_directory):
    """
    Rebuild a retriever from a FAISS index stored on disk.

    Args:
        persist_directory (str): Directory containing the saved FAISS index.

    Returns:
        The retriever produced by ``as_retriever()`` on the loaded vector store.
    """
    # allow_dangerous_deserialization is required to load the index from MLflow.
    # NOTE(review): the flag opts in to loading serialized data from disk; only point
    # this at indexes you created yourself.
    faiss_index = FAISS.load_local(
        persist_directory,
        OpenAIEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    return faiss_index.as_retriever()


# Log the chain inside a new MLflow run. `model_info` is kept because its `model_uri` is
# used afterwards to load the model back.
with mlflow.start_run() as run:
    model_info = mlflow.langchain.log_model(
        retrievalQA,
        artifact_path="retrieval_qa",
        loader_fn=load_retriever,  # function that rebuilds the retriever from a directory
        persist_dir=persist_dir,  # directory holding the saved FAISS index
    )

2024/09/24 14:05:22 INFO mlflow.tracking._tracking_service.client: 🏃 View run fun-lark-775 at: http://127.0.0.1:5000/#/experiments/418308964639969190/runs/8708650cd46548a3ba34fc9a26c29608.
2024/09/24 14:05:22 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/418308964639969190.


In [26]:
# Load the logged chain back through MLflow's generic pyfunc interface
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [27]:
def print_formatted_response(response_list, max_line_length=80):
    """
    Print each response word-wrapped for better readability.

    Runs of whitespace inside a response are collapsed to single spaces and words are
    never split. Every printed line is shorter than ``max_line_length`` characters unless
    a single word is longer than that, in which case the word gets a line of its own.

    Args:
        response_list (list | str): Responses to print. A bare string is treated as a
            single response.
        max_line_length (int): Line-length limit used for wrapping. Defaults to 80.
    """
    if isinstance(response_list, str):
        # Iterating over a string would print it one character per line.
        response_list = [response_list]

    # Reserve one column so that a line plus a following space never exceeds the limit.
    wrap_width = max(max_line_length - 1, 1)
    for response in response_list:
        wrapped_lines = textwrap.wrap(
            " ".join(response.split()),
            width=wrap_width,
            break_long_words=False,
            break_on_hyphens=False,
        )
        # An empty response still produces one (blank) line of output.
        print("\n".join(wrapped_lines))

In [28]:
# Query the loaded model; the input is a list with one dict per question
answer1 = loaded_model.predict([{"query": "What does the document say about trespassers?"}])

# Print the answer(s) wrapped to a readable line length
print_formatted_response(answer1)

The document states that all persons who shall locate or settle upon or occupy 
the land designated as a public park, except as provided for in the act, shall 
be considered trespassers and removed from the land. 


In [15]:
# Clean up the temporary directory created earlier; it holds the scraped text file
# (docs.txt) and the local FAISS index (faiss_index).
shutil.rmtree(temporary_directory)

In [17]:
from mlflow import MlflowClient

# Initialize an MLflow Client; assign_alias_to_stage below uses this module-level instance
client = MlflowClient()


def assign_alias_to_stage(model_name, stage, alias):
    """
    Assign an alias to the latest version of a registered model within a specified stage.

    :param model_name: The name of the registered model.
    :param stage: The stage of the model version for which the alias is to be assigned. Can be
                "Production", "Staging", "Archived", or "None".
    :param alias: The alias to assign to the model version.
    :raises IndexError: If the registered model has no version in ``stage``.
    :return: None
    """
    # NOTE(review): model stages appear to be deprecated in recent MLflow releases —
    # confirm against the MLflow version in use.
    latest_versions = client.get_latest_versions(model_name, stages=[stage])
    if not latest_versions:
        # Fail with an actionable message instead of "list index out of range". The
        # exception type stays IndexError so existing handlers keep working.
        raise IndexError(
            f"No version of registered model {model_name!r} found in stage {stage!r}; "
            f"cannot assign alias {alias!r}."
        )
    client.set_registered_model_alias(model_name, alias, latest_versions[0].version)

In [31]:
import requests
import json
# Query the model through the scoring endpoint at 127.0.0.1:5002.
# NOTE(review): presumably a model server started outside this notebook — confirm it is
# running before executing this cell.
payload = json.dumps(
    {
        "inputs": {"query": ["What does the document say about trespassers?"]},
        "params": {"max_tokens": 20},
    }
)
res = requests.post(
    url="http://127.0.0.1:5002/invocations",
    data=payload,
    headers={"Content-Type": "application/json"},
    # Generous limit for a slow LLM-backed endpoint, but never wait forever on a dead server.
    timeout=120,
)
# Surface HTTP errors here instead of failing later while decoding the body.
res.raise_for_status()

In [32]:
# Display the decoded JSON body of the response
res.json()

{'predictions': [' The document states that any persons who settle, occupy, or trespass upon the designated tract of land near the headwaters of the Yellowstone River, except as provided for in the act, shall be considered trespassers and will be removed from the land.']}