In [1]:
import os

# Non-secret configuration: DocAI processor, GCP project/location, MLflow server
# and LangSmith tracing. These are always (re)set by this cell.
os.environ["DOC_AI_LOCATION"] = "us"
os.environ["DOC_AI_PROCESSOR_ID"] = "e977fdd46ee23308"
os.environ["PROJECT_ID"] = "602280418311"
os.environ["LOCATION"] = "us-west1"
os.environ['MLFLOW_TRACKING_URI'] = "https://mlflow-gzpg2zq5pa-uc.a.run.app/"
os.environ['LANGCHAIN_TRACING_V2'] = "true"
os.environ['LANGCHAIN_ENDPOINT'] = "https://api.smith.langchain.com"
os.environ['LANGCHAIN_PROJECT'] = "iliOS-key-value-extraction"

# Credentials must never be hard-coded in the notebook. Export them in the shell /
# kernel environment instead: existing values are preserved, and unset ones fall
# back to "" (the previous placeholder) instead of clobbering real secrets.
CREDENTIAL_ENV_VARS = (
    "GOOGLE_API_KEY",
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
    "MLFLOW_TRACKING_USERNAME",
    "MLFLOW_TRACKING_PASSWORD",
    "LANGCHAIN_API_KEY",
    "GOOGLE_APPLICATION_CREDENTIALS",
)
for _credential_var in CREDENTIAL_ENV_VARS:
    os.environ.setdefault(_credential_var, "")
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=''

In [2]:
import pathlib
import sys

# Project root: the notebook directory's parent (presumably where the `src`
# package used by the imports below lives).
PROJECT_ROOT = str(pathlib.Path().absolute().parent)
# PYTHONPATH is read by Python at interpreter start-up, so setting it here only
# affects processes launched from this kernel ...
os.environ['PYTHONPATH'] = PROJECT_ROOT
# ... the running kernel needs the root on sys.path for `from src...` to work.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

In [3]:
# from src.doc_ai.processor import DocAIProcessor
# from src.doc_ai.processors import DOC_AI_PROCESSOR
# from src.pipelines.term_extraction.pipeline_config import Phase1ESAConfig
# 
# config = Phase1ESAConfig()
# processor_location: str = os.environ["DOC_AI_LOCATION"]
# processor_project_id: str = os.environ["PROJECT_ID"]
# processor_id: str = DOC_AI_PROCESSOR["PROCESSOR"]
# 
# processor = DocAIProcessor(
#     location=processor_location,
#     project_id=processor_project_id,
#     processor_id=processor_id,
# )
# print(config.get_file_names())
# file_sequences = []
# for file_paths in config.get_file_names():
#     file_sequences.append(processor.process_documents(
#         [config.get_documents_path() + file_paths]
#     ))

In [4]:
import pickle
# with open('file_sequences.pkl', 'wb') as f:
#     pickle.dump(file_sequences, f)


In [5]:
# Pickled DocAI FileSequences precomputed for the Phase I ESA documents, stored
# two directory levels above the notebook's working directory.
path = pathlib.Path.cwd().parent.parent / "Precomputes" / "file_sequences_phase_1_esa.pkl"

In [6]:
# Load the cached FileSequence list so the pipeline can skip the DocAI step.
# NOTE: unpickling can execute arbitrary code - only load trusted, locally produced files.
with open(path, mode="rb") as precomputed_file:
    file_sequences = pickle.load(precomputed_file)

In [7]:
# Show a compact summary instead of the full repr. The FileSequences presumably
# carry the extracted document contents, which would bloat the saved notebook and
# leak into its outputs. The count should match config.get_file_names().
len(file_sequences)

In [8]:
from src.pipelines.term_extraction.pipeline_runner import mlflow_metrics
from src.validation.validation import calculate_metrics
from src.pipelines.term_extraction.utils import get_project_preview, save_results
import mlflow
import logging.config
from pathlib import Path
from typing import Any, List, Optional

import pandas as pd
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages.ai import AIMessage
from langchain_core.retrievers import BaseRetriever

from src.doc_ai.file_sequence import FileSequence
from src.doc_ai.processor import DocAIProcessor
from src.gen_ai.gen_ai import get_llm
from src.pipelines.constants import NOT_PROVIDED_STR, NOT_SUPPORTED_KEY_ITEM
from src.pipelines.term_extraction.pipeline_config import (
    Phase1ESAConfig,
    PipelineConfig,
    PVSystPipelineConfig,
    SubscriberMgmtPipelineConfig,
)
from src.prompts.prompts import (
    prompt_template,
    prompt_template_instructions,
    rag_prompt_template,
    rag_prompt_template_pvsyst,
)
from src.validation.reviewer.reviewer import Reviewer
from src.vectordb.vectordb import VectorDB

logger = logging.getLogger(__name__)

class Pipeline:
    """Pipeline for building project previews.

    A project preview is a table with one row per "Key Items" entry and the answer
    produced for it by a RAG chain (DocAI-extracted text -> vector DB retriever ->
    LLM) over a document bundle.
    """

    # Page selection handed to the default chain by ``build_project_preview``. It is
    # logged, and forwarded to DocAI only when ``_build_a_chain`` has to process the
    # documents itself because no FileSequence is available.
    default_chain_pages: int = 50

    def __init__(
            self,
            processor_location: str,
            processor_project_id: str,
            processor_id: str,
            terms_and_definitions: pd.DataFrame,
            config: PipelineConfig,
            k: int = 10,
            chunk_size: Optional[int] = None,
            add_tables: bool = True,
            few_shot: bool = False,
            few_shot_examples: Optional[pd.DataFrame] = None,
            model_type: str = "CLAUDE",
            file_sequence: Optional[FileSequence] = None,
    ):
        """Initialize the term_extraction pipeline.

        :param processor_location: DocAI processor location.
        :param processor_project_id: project that owns the DocAI processor.
        :param processor_id: DocAI processor id.
        :param terms_and_definitions: one row per Key Item, with a "Key Items"
            column plus an "Instructions" or "Definitions" column.
        :param config: pipeline configuration (kept on ``self.config``).
        :param k: forwarded to ``VectorDB`` (presumably the number of retrieved chunks).
        :param chunk_size: forwarded to ``VectorDB``; 600 when None.
        :param add_tables: forwarded to ``VectorDB``.
        :param few_shot: add few-shot examples to "Definitions"-style prompts.
        :param few_shot_examples: examples indexed by Key Item; required when
            ``few_shot`` is True.
        :param model_type: LLM type passed to ``get_llm`` and ``Reviewer``.
        :param file_sequence: optional pre-processed documents, used by
            ``_build_a_chain`` when no FileSequence is passed to it.
        :raises ValueError: if ``few_shot`` is True and ``few_shot_examples`` is None.
        """
        # Validate first so a bad call fails before the processor, LLM and vector DB
        # are constructed.
        if few_shot and few_shot_examples is None:
            raise ValueError("few_shot_examples must be provided if few_shot is True")

        self.processor = DocAIProcessor(
            location=processor_location,
            project_id=processor_project_id,
            processor_id=processor_id,
        )
        self.model_type = model_type
        self.llm = get_llm(model_type=self.model_type)
        self.vectordb = VectorDB(
            k=k,
            chunk_size=chunk_size if chunk_size is not None else 600,
            add_tables=add_tables,
        )
        # Assigned by _build_a_chain; None until a chain has been built so that
        # get_retriever() can report that clearly.
        self.retriever: Optional[BaseRetriever] = None
        # Unfiltered table: used to re-expand the preview to every Key Item.
        self.terms_and_definitions_full = terms_and_definitions.copy()
        self.terms_and_definitions = self.terms_and_definitions_preprocess(
            terms_and_definitions
        )
        self.few_shot = few_shot
        self.few_shot_examples = few_shot_examples
        self.reviewer = Reviewer(model_type=self.model_type)
        self.config = config
        self.file_sequence = file_sequence
        # Key Items whose "not provided" answers are kept verbatim by
        # false_positives_postprocessing; subclasses may set this.
        self.postprocess_excluded_keys: List[str] | None = None

        logger.info("Pipeline initialized with the following parameters:")
        logger.info("k: %s", k)
        logger.info("few_shot: %s", few_shot)

    @staticmethod
    def terms_and_definitions_preprocess(
            terms_and_definitions: pd.DataFrame,
    ) -> pd.DataFrame:
        """Keep only rows that have both a Key Item and its prompt text.

        The prompt text column is "Instructions" when present, else "Definitions".
        """
        col = (
            "Instructions"
            if "Instructions" in terms_and_definitions.columns
            else "Definitions"
        )
        return terms_and_definitions[
            ~terms_and_definitions[col].isna()
            & ~terms_and_definitions["Key Items"].isna()
        ]

    def get_retriever(self) -> Any:
        """Return the retriever of the most recently built chain.

        :raises ValueError: if no chain has been built yet.
        """
        if self.retriever is None:
            raise ValueError("Retriever has not been built yet.")
        return self.retriever

    def _build_prompt(self, term_definition_row: pd.Series) -> str:
        """Build the LLM prompt for one Key Item row.

        Rows from an "Instructions" table use ``prompt_template_instructions``;
        otherwise ``prompt_template`` is used with the row's definition and, when
        few-shot is enabled, the non-empty examples stored for that Key Item.
        We can change the prompt template here.
        """
        term = term_definition_row["Key Items"]
        if "Instructions" in term_definition_row:
            return prompt_template_instructions(term, term_definition_row["Instructions"])

        definition = term_definition_row["Definitions"]
        examples = []
        if self.few_shot and term in self.few_shot_examples.index:
            examples = [
                example
                for example in self.few_shot_examples.loc[term]
                if not pd.isna(example) and example != NOT_PROVIDED_STR
            ]
        return prompt_template(term, definition, examples=examples)

    def _process_documents(
            self, file_paths: List[str | Path], pages: List[List[int]] | None
    ) -> FileSequence:
        """Run DocAI on ``file_paths`` (restricted to ``pages``) and return the result.

        :raises ValueError: if DocAI processing fails (original error chained).
        """
        try:
            return self.processor.process_documents(file_paths, pages=pages)
        except Exception as e:
            raise ValueError(f"Failed to create FileSequence. Error: {e}") from e

    def _build_a_chain(
            self,
            file_paths: List[str | Path],
            pages: List[List[int]] | None = None,
            file_sequence: Optional[FileSequence] = None,
    ) -> Any:
        """Build a basic RAG (retrieve + stuff-documents) chain for a document bundle.

        :param file_paths: files in the bundle; logged, and processed with DocAI on
            the fallback path.
        :param pages: page selection; forwarded to DocAI only on the fallback path.
        :param file_sequence: pre-processed documents. When None,
            ``self.file_sequence`` is used; if that is also None the files are
            processed with DocAI.
        :return: a langchain retrieval chain. ``self.retriever`` is updated.
        :raises ValueError: if DocAI processing fails on the fallback path.
        """
        retrieval_qa_chat_prompt = rag_prompt_template()
        logger.info("Preparing file for processing: %s", file_paths)
        logger.info("Processing pages: %s", pages)
        if file_sequence is None:
            file_sequence = self.file_sequence
        if file_sequence is None:
            file_sequence = self._process_documents(file_paths, pages)
        self.retriever = self.vectordb.retriever_from_file_sequence(file_sequence)
        combine_docs_chain = create_stuff_documents_chain(
            self.llm, retrieval_qa_chat_prompt
        )
        return create_retrieval_chain(self.retriever, combine_docs_chain)

    def build_project_preview(
            self, file_paths: List[str | Path], file_sequence: Optional[FileSequence] = None
    ) -> pd.DataFrame:
        """Build the project preview for a document bundle.

        Main method of the term_extraction.

        :param file_paths: list of files to bundle together like Site Lease plus
            Amendments
        :param file_sequence: pre-processed documents for the bundle; when None,
            ``_build_a_chain`` falls back to ``self.file_sequence`` or DocAI.
        :return: DataFrame with "Key Items" (every row of the unfiltered terms table)
            and "Predicted Legal Terms"; Key Items removed by preprocessing get
            ``NOT_SUPPORTED_KEY_ITEM``.
        """
        logger.info("Build a default chain")
        chain = self._build_a_chain(
            file_paths,
            pages=[list(range(self.default_chain_pages))],
            file_sequence=file_sequence,
        )
        logger.info("Building project preview for %s", file_paths)

        responses = self.build_responses(chain, file_paths)

        logger.info("Finished batch processing for %s", file_paths)

        project_preview = pd.DataFrame(self.terms_and_definitions["Key Items"])
        project_preview["Predicted Legal Terms"] = responses

        project_preview = self.false_positives_postprocessing(
            project_preview, self.postprocess_excluded_keys
        )
        project_preview["Predicted Legal Terms"] = project_preview[
            "Predicted Legal Terms"
        ].apply(self.remove_triple_backticks)

        # Re-expand to every Key Item, including rows dropped by preprocessing.
        project_preview = self.terms_and_definitions_full[["Key Items"]].merge(
            project_preview, on="Key Items", how="left"
        )
        project_preview["Predicted Legal Terms"] = project_preview[
            "Predicted Legal Terms"
        ].fillna(NOT_SUPPORTED_KEY_ITEM)

        return project_preview

    def build_responses(self, chain: Any, file_paths: List[str | Path]) -> List[str]:
        """Answer every preprocessed Key Item with ``chain``, in table order.

        :param chain: retrieval chain returning dicts with an "answer" key.
        :param file_paths: used for logging only.
        :return: stripped answers, one per row of ``self.terms_and_definitions``.
        """
        prompts = [
            self._build_prompt(term_definition_row)
            for _, term_definition_row in self.terms_and_definitions.iterrows()
        ]
        logger.info("Starting batch processing for %s", file_paths)
        responses = chain.batch([{"input": prompt} for prompt in prompts])
        return [response["answer"].strip() for response in responses]

    @staticmethod
    def false_positives_postprocessing(
            project_preview: pd.DataFrame, excluded_keys: Optional[List[str]] = None
    ) -> pd.DataFrame:
        """Normalize "not provided"-style answers to exactly ``NOT_PROVIDED_STR``.

        A prediction containing "not provided" or ``NOT_PROVIDED_STR``
        (case-insensitive) is replaced by ``NOT_PROVIDED_STR``, unless its Key Item
        is in ``excluded_keys``. ``project_preview`` is modified in place and returned.

        :param project_preview: frame with "Key Items" and "Predicted Legal Terms".
        :param excluded_keys: Key Items whose answers are left untouched.
        """
        excluded = excluded_keys if excluded_keys is not None else []
        predictions = project_preview["Predicted Legal Terms"]
        # astype(str) keeps the .str accessor valid for empty or non-string columns;
        # only the mask is derived from it, the stored values are untouched.
        lowered = predictions.astype(str).str.lower()
        says_not_provided = lowered.str.contains(
            "not provided", regex=False, na=False
        ) | lowered.str.contains(NOT_PROVIDED_STR.lower(), regex=False, na=False)
        to_replace = says_not_provided & ~project_preview["Key Items"].isin(excluded)
        project_preview["Predicted Legal Terms"] = predictions.mask(
            to_replace, NOT_PROVIDED_STR
        )
        return project_preview

    @staticmethod
    def remove_triple_backticks(text: str) -> str:
        """Strip a leading and/or trailing ``` fence and the whitespace around it."""
        text = text.strip()
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        # Fenced answers usually put newlines just inside the fences; drop them too.
        return text.strip()

    @classmethod
    def from_config(cls, config: PipelineConfig) -> "Pipeline":
        """Create a Pipeline instance from a PipelineConfig instance."""
        return cls(
            k=config.k,
            chunk_size=config.chunk_size,
            few_shot=config.few_shot,
            terms_and_definitions=config.get_terms_and_definitions(),
            few_shot_examples=config.get_few_shot_examples(),
            add_tables=config.add_tables,
            processor_location=config.processor_location,
            processor_project_id=config.processor_project_id,
            processor_id=config.processor_id,
            model_type=config.model_type,
            config=config,
        )

class PipelinePhaseIESA(Pipeline):
    """Phase I ESA pipeline.

    Key Items whose ``Use_first_page`` value is truthy are answered by a chain built
    only from the first ``process_first_n_pages`` pages of the documents
    (re-processed with DocAI). All other Key Items use the default chain.
    """

    # Number of leading pages DocAI processes for the first-pages chain.
    process_first_n_pages = 5

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Answers for these Key Items are not normalized to NOT_PROVIDED_STR by
        # false_positives_postprocessing.
        self.postprocess_excluded_keys = ["RECs (Y/N and $ amount)"]

    def _build_a_chain_phase_1_esa(
            self, file_paths: List[str | Path], pages: List[List[int]] | None = None
    ) -> Any:
        """Build a RAG chain over only the selected pages of ``file_paths``.

        This always re-processes the documents with DocAI (restricted to ``pages``)
        rather than reusing a precomputed FileSequence. ``self.retriever`` is
        updated as a side effect.

        :param file_paths: documents to process.
        :param pages: page selection forwarded to ``processor.process_documents``.
            NOTE(review): ``build_responses`` passes a single inner list even for
            multi-file bundles - confirm how DocAI maps page lists to files.
        :raises ValueError: if DocAI processing fails (original error chained).
        """
        retrieval_qa_chat_prompt = rag_prompt_template()
        logger.info("Preparing file for processing: %s", file_paths)
        logger.info("Processing pages: %s", pages)
        try:
            file_sequence: FileSequence = self.processor.process_documents(
                file_paths, pages=pages
            )
        except Exception as e:
            raise ValueError(f"Failed to create FileSequence. Error: {e}") from e
        self.retriever = self.vectordb.retriever_from_file_sequence(file_sequence)
        combine_docs_chain = create_stuff_documents_chain(
            self.llm, retrieval_qa_chat_prompt
        )
        return create_retrieval_chain(self.retriever, combine_docs_chain)

    def build_responses(self, chain: Any, file_paths: List[str | Path]) -> List[str]:
        """Answer every Key Item, routing first-pages items to a dedicated chain.

        :param chain: default chain, used for items whose ``Use_first_page`` flag
            is falsy.
        :param file_paths: document bundle; re-processed with DocAI (first
            ``process_first_n_pages`` pages) when any item needs the first-pages chain.
        :return: stripped answers, in the row order of ``self.terms_and_definitions``.
        """
        logger.info("Building project preview for %s", file_paths)

        # One boolean mask drives both the split and the re-merge below, so the two
        # cannot disagree about which batch a row came from.
        use_first_pages = self.terms_and_definitions["Use_first_page"].values.astype(bool)
        # items answered from the first `process_first_n_pages` pages only
        terms_first_pages = self.terms_and_definitions[use_first_pages]
        # items answered from the whole agreement
        terms_all_pages = self.terms_and_definitions[~use_first_pages]
        prompts_phase_1 = [
            self._build_prompt(row) for _, row in terms_first_pages.iterrows()
        ]
        prompts_other = [
            self._build_prompt(row) for _, row in terms_all_pages.iterrows()
        ]
        logger.info("Starting batch processing for %s", file_paths)

        responses_phase_1_esa: List[Any] = []
        if prompts_phase_1:
            # Building this chain runs DocAI, so skip it when no item needs it.
            logger.info("Build a PHASE 1 CHAIN")
            chain_phase_1_esa = self._build_a_chain_phase_1_esa(
                file_paths, pages=[list(range(self.process_first_n_pages))]
            )
            responses_phase_1_esa = chain_phase_1_esa.batch(
                [{"input": prompt} for prompt in prompts_phase_1]
            )
        responses_other = chain.batch([{"input": prompt} for prompt in prompts_other])

        # Re-interleave both batches back into the original row order (O(n)).
        phase_1_iter = iter(responses_phase_1_esa)
        other_iter = iter(responses_other)
        responses = [
            next(phase_1_iter) if flag else next(other_iter) for flag in use_first_pages
        ]
        return [response["answer"].strip() for response in responses]


class PipelineFactory:
    """Pick and build the Pipeline implementation matching a config."""

    @staticmethod
    def create_pipeline(config: PipelineConfig) -> Pipeline:
        """Return a pipeline chosen by ``config.pipeline_name``.

        A Phase I ESA config yields ``PipelinePhaseIESA``; anything else yields the
        default ``Pipeline``. Both are built through ``from_config``.
        """
        is_phase_1_esa = config.pipeline_name == Phase1ESAConfig.pipeline_name
        if not is_phase_1_esa:
            logger.info("Creating default pipeline")
            return Pipeline.from_config(config)
        logger.info("Creating Phase I ESA pipeline")
        return PipelinePhaseIESA.from_config(config)


class PipelineRunner:
    """Pipeline runner for building project previews and scoring them."""

    def __init__(
            self, pipeline_config: PipelineConfig, experiment: Optional[str] = None
    ):
        """Initialize the term_extraction runner.

        :param pipeline_config: selects the pipeline (via PipelineFactory) and
            supplies file names, paths and metrics.
        :param experiment: MLflow experiment name; defaults to
            ``pipeline_config.pipeline_name``.
        """
        self.pipeline = PipelineFactory.create_pipeline(pipeline_config)
        self.config = pipeline_config
        self.override_mlflow_experiment(experiment)
        self.result_metrics = None

    def run(self, file_sequences: List[FileSequence]) -> None:
        """Run the term_extraction for every file in the pipeline config.

        Pipeline steps, per file:
        1. Load the correct project preview.
        2. Build the predicted project preview with ``self.pipeline`` from the
           file's precomputed FileSequence (RAG over the extracted text + LLM).
        3. Compare the predicted and actual project previews.
        4. Calculate the metrics.
        Then, once for the whole run:
        5. Save the results and log them to MLflow.

        :param file_sequences: one FileSequence per entry of
            ``config.get_file_names()``, in the same order.
        :raises ValueError: if the number of file sequences and file names differ.
        """
        file_names = list(self.config.get_file_names())
        # zip() would silently drop the unmatched tail; fail before starting a run.
        if len(file_names) != len(file_sequences):
            raise ValueError(
                f"Got {len(file_sequences)} file sequences for {len(file_names)} "
                "files in the pipeline config."
            )

        logger.info("STARTING PIPELINE")
        results: List[pd.DataFrame] = []

        with mlflow.start_run():
            for file_name, file_sequence in zip(file_names, file_sequences):
                results.append(self._evaluate_file(file_name, file_sequence))

            output_results_path = self.config.get_output_results_path()
            logger.info("SAVING RESULTS TO: %s", output_results_path)
            metrics_total, results_df_name = save_results(
                results, output_results_path, self.config.metrics
            )
            self.result_metrics = (
                metrics_total.iloc[0].drop(["file_name", "key_item"]).to_dict()
            )
            mlflow.set_tag("Type", "Legal Terms Extraction")
            mlflow_metrics(self.config, metrics_total, results_df_name)
            logger.info("PIPELINE SUCCEEDED")

    def _evaluate_file(
            self, file_name: str | List[str], file_sequence: FileSequence
    ) -> pd.DataFrame:
        """Predict, compare and score one document (or bundle of documents).

        :param file_name: a single file name, or a list for a bundle (e.g. Site
            Lease plus Amendments) whose first entry names the bundle.
        :param file_sequence: precomputed FileSequence for that document/bundle.
        :return: predicted vs. actual rows merged with their metrics, tagged with
            ``file_name``.
        """
        logger.info("PROCESSING FILE: %s", file_name)
        logger.info("PULLING CORRECT PROJECT PREVIEW: %s", file_name)
        correct_project_preview = get_project_preview(
            self.config.get_project_previews_path(), file_name
        )
        if self.config.pipeline_name == "pv-syst":
            # pv-syst previews hold the expected values in "Value".
            correct_project_preview["Legal Terms"] = correct_project_preview[
                "Value"
            ].copy()

        logger.info("BUILDING PROJECT PREVIEW: %s", file_name)
        attachments = file_name if isinstance(file_name, list) else [file_name]
        documents_path = self.config.get_documents_path()
        predicted_project_preview = self.pipeline.build_project_preview(
            [documents_path + attachment for attachment in attachments],  # type: ignore
            file_sequence,
        )

        logger.info("COMPARE: %s", file_name)
        predicted_actual = predicted_project_preview.merge(
            correct_project_preview, on="Key Items", how="outer"
        )
        logger.info("CALCULATE METRICS: %s", file_name)
        metrics = calculate_metrics(predicted_actual, self.config.metrics)
        predicted_actual_metrics = predicted_actual.merge(
            metrics, on="Key Items", how="outer"
        )
        logger.info("FINISHED PROCESSING: %s", file_name)
        # Bundles are reported under their first document's name.
        return predicted_actual_metrics.assign(file_name=attachments[0])

    def override_mlflow_experiment(self, experiment: str | None) -> None:
        """Select the MLflow experiment: ``experiment`` if given, else the pipeline name."""
        mlflow.set_experiment(experiment or self.config.pipeline_name)


In [9]:
# Configure the Phase I ESA run and evaluate it on the cached DocAI file sequences.
pipeline_config = Phase1ESAConfig()
pipeline_runner = PipelineRunner(pipeline_config=pipeline_config)
pipeline_runner.run(file_sequences=file_sequences)