In [None]:
# Launch command for the vLLM server; the workflow below talks to it on port 8008.
# The commented lines are alternative models / flags that were tried.
# vllm serve "Qwen/QwQ-32B-AWQ" \
# vllm serve "Valdemardi/DeepSeek-R1-Distill-Qwen-32B-AWQ" \
# --enforce-eager \
# vllm serve /workspace/model_merged_daretie_quant/ \
# vllm serve "Qwen/QwQ-32B-AWQ" \
# --quantization awq_marlin \
vllm serve "Valdemardi/DeepSeek-R1-Distill-Qwen-32B-AWQ" \
--max_model_len 28000 \
--gpu-memory-utilization 0.8 \
--dtype float16 \
--port 8008 \
--host 0.0.0.0

# vllm serve BAAI/bge-m3 --trust-remote-code --task embed --port 8008 --host 0.0.0.0

!pip3 install --upgrade transformers

# Get Rule

In [None]:
%load_ext autoreload
%autoreload 2

from federal_register.client import FederalRegister
import requests

# Federal Register document numbers used so far:
# CIP 2024-10738
# 2023-16377, 2019-12164, 2020-28868
# Initialize the client.
federal_register_client = FederalRegister()

# Grab a specific document.
# document_id = '2023-16377'
document_id = '2024-10738'
federal_document = federal_register_client.document_by_id(
    document_id=document_id,
    fields='all'
)

print(federal_document['full_text_xml_url'])
# Without a timeout requests can wait forever on a stalled connection.
response = requests.get(federal_document['full_text_xml_url'], timeout=60)
# Stop here on an HTTP error instead of parsing an error page in the cells below.
response.raise_for_status()
response.status_code

### Preprocess Rule XML

In [None]:
from preprocess_rule import parse_xml_title

# Only the head of the XML is passed to the title parser. Cutting the raw bytes
# at a fixed offset can split a multi-byte UTF-8 character, so undecodable bytes
# are dropped instead of raising UnicodeDecodeError.
rule_title = parse_xml_title(response.content[:5000].decode("utf-8", errors="ignore"))
print(rule_title)

In [None]:
from preprocess_rule import parse_and_clean_xml, clean_xml_text, get_page_numbers
from llama_index.core import Document

# Split the rule XML into parallel lists of section titles and raw section text.
titles, sections = parse_and_clean_xml(
    response.content,
    first_section_title='Rule Introduction',
)
assert len(titles) == len(sections)

# Report the page range and a rough token estimate (words * 1.25) for each section.
for title, section in zip(titles, sections):
    first, last = get_page_numbers(section)
    approx_tokens = len(section.split())*1.25
    print(f"{title}: first page = {first}, last page = {last}, tokens = {approx_tokens}")
print('All Titles\n',titles)

# Clean the text of every section, keyed by its title.
cleaned_rule_proposal = {}
for title, section in zip(titles, sections):
    cleaned_rule_proposal[title] = clean_xml_text(
        section,
        remove_footnotes=True,
        remove_page_references=True,
        remove_xml_tags=True,
        remove_extra_whitespace=True,
    )

# The request-for-comments section is not used for this rule.
request_for_comments = None #cleaned_rule_proposal['III. Request for Comments']

# Keep only the sections that should be summarized.
# Section lists used for other rules:
# omit_sections = ['Rule Introduction', 'Table of Contents', 'Text of Proposed Rules and Form Amendments']
# omit_sections = ['Rule Introduction', 'Table of Contents', 'II. Discussion', 'III. Economic Analysis', 'IV. Paperwork Reduction Act', 'V. Initial Regulatory Flexibility Analysis', 'VI. Consideration of Impact on the Economy', 'Text of Proposed Rules and Form Amendments']
omit_sections = [
    'III. Request for Comments',
    "IX. FinCEN's Unfunded Mandates Reform Act Determination",
    'Authority and Issuance',
]
relevant_sections_dict = {
    section_title: section_text
    for section_title, section_text in cleaned_rule_proposal.items()
    if section_title not in omit_sections
}


# Agentic Summarization

In [None]:
%load_ext autoreload
%autoreload 2

import os, re, copy
from llama_index.core.workflow import (
    Context,
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)

from llama_index.utils.workflow import draw_all_possible_flows
from llama_index.core import get_response_synthesizer
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI as LIOpenAI
from llama_index.llms.openai_like import OpenAILike

from dotenv import load_dotenv
# Read the project's .env file, then pick the OpenAI key up from the environment.
# openai_key is only referenced by the commented-out OpenAI back-end options below.
load_dotenv('/workspace/repos/fsi/.env')
openai_key = os.getenv("OPENAI_API_KEY")

import asyncio
import nest_asyncio
# nest_asyncio patches asyncio so an event loop can be re-entered; presumably
# needed because the notebook kernel already runs a loop — TODO confirm.
nest_asyncio.apply()

class SummarizationEvent(Event):
    """Returned by ``initialize``; triggers the ``summarize`` step."""
    result: dict

class RefineSummariesEvent(Event):
    """Returned by ``summarize`` for long summaries; triggers ``refine_summaries``."""
    result: dict

class KeyPointsEvent(Event):
    """Returned by ``summarize`` (short) or ``refine_summaries``; triggers ``summary_bullet_points``."""
    result: dict

class KeyPointsGPEvent(Event):
    """Returned by ``fact_checking`` when no update is needed; triggers ``guiding_principles_bullet_points``."""
    result: dict

class FactCheckingEvent(Event):
    """Returned by ``summary_bullet_points`` and ``fact_update``; triggers ``fact_checking``."""
    result: dict

class FactUpdateEvent(Event):
    """Returned by ``fact_checking`` when bullets need correcting; triggers ``fact_update``."""
    result: dict

class FormatSummaryOutputEvent(Event):
    """Returned by ``guiding_principles_bullet_points``; triggers ``format_summary_output``."""
    result: dict

class ObjectiveVoiceEvent(Event):
    """Not referenced by any step of ``RuleSummarizationFlow`` in this notebook."""
    result: dict


from llama_index.core.prompts.base import BasePromptTemplate, PromptTemplate
from summary_prompts import (
                             LONG_SUMMARIZE_PROMPT_TMPL, 
                             SHORT_SUMMARIZE_PROMPT_TMPL, 
                             REFINE_SUMMARIES_PROMPT_TMPL,
                             BULLET_POINTS_PROMPT_TMPL, 
                             GUIDING_PRINCIPLES_BULLET_POINTS_PROMPT_TMPL,
                             BULLET_POINT_FACT_CHECKING_PROMPT_TMPL,
                             BULLET_POINT_UPDATE_PROMPT_TMPL,
                             FORMAT_SUMMARY_OUTPUT_PROMPT_TMPL
                             )

# Wrap each raw prompt string from summary_prompts in a PromptTemplate.
# RuleSummarizationFlow reads these module-level names directly.
#TODO: use prompts not PromptTemplate
long_summarize_prompt_template = PromptTemplate(LONG_SUMMARIZE_PROMPT_TMPL)
short_summarize_prompt_template = PromptTemplate(SHORT_SUMMARIZE_PROMPT_TMPL)
refine_summaries_prompt_template = PromptTemplate(REFINE_SUMMARIES_PROMPT_TMPL)
bullet_points_prompt_template = PromptTemplate(BULLET_POINTS_PROMPT_TMPL)
guiding_principles_bullet_points_prompt_template = PromptTemplate(GUIDING_PRINCIPLES_BULLET_POINTS_PROMPT_TMPL)
bullet_point_fact_checking_prompt_template = PromptTemplate(BULLET_POINT_FACT_CHECKING_PROMPT_TMPL)
bullet_point_update_prompt_template = PromptTemplate(BULLET_POINT_UPDATE_PROMPT_TMPL)
format_summary_output_prompt_template = PromptTemplate(FORMAT_SUMMARY_OUTPUT_PROMPT_TMPL)

def deepseek_llm_postprocesser(llm_response):
    """Return only the text after the last ``</think>`` tag, stripped.

    If the response contains no ``</think>`` tag the whole response is
    returned (stripped).
    """
    return llm_response.split('</think>')[-1].strip()


def deepseek_summary_postprocesser(llm_response):
    """Remove every ``<think>...</think>`` block and strip the remainder.

    Unlike ``deepseek_llm_postprocesser`` this keeps the text between several
    reasoning blocks, which matters when the response is a concatenation of
    multiple model outputs.
    """
    return re.sub(r'<think>.*?</think>', '', llm_response, flags=re.DOTALL).strip()


def llm_postprocesser(llm_response):
    """Strip surrounding whitespace from a response that has no reasoning block."""
    return llm_response.strip()


def summary_postprocesser(llm_response):
    """Strip surrounding whitespace from a summary that has no reasoning block."""
    return llm_response.strip()


class RuleSummarizationFlow(Workflow):
    """Workflow that summarizes a proposed rule section by section.

    The steps are chained through their event types:
    ``initialize`` -> ``summarize`` -> ``refine_summaries`` (long summaries
    only) -> ``summary_bullet_points`` -> ``fact_checking`` <-> ``fact_update``
    -> ``guiding_principles_bullet_points`` -> ``format_summary_output``.

    The ``StartEvent`` must carry ``summary_length`` ('long' or 'short'),
    ``rule_proposal_sections`` (mapping of section title to section text),
    ``request_for_comments``, ``guiding_principles``, ``compliance_guidance``,
    ``additional_guidance`` and ``fact_checking_num_iterations``.

    Intermediate results are stored in the workflow ``Context`` (keys such as
    ``summaries``, ``bullet_points``, ``formatted_summaries``); the final
    ``StopEvent`` returns that context so callers can read them back.
    """

    # Models whose output contains <think>...</think> blocks that must be stripped.
    _REASONING_MODELS = ('deepseek-r1:32b', 'qwq')
    # Reply that the fact-check / fact-update steps compare against to mean "no change".
    _NO_CHANGE = 'None'

    async def _get_nodes_for_section(self, ctx: Context, section_title: str):
        """Return the stored chunk nodes whose ``section_title`` metadata matches."""
        nodes = await ctx.get("nodes")
        return [node for node in nodes if node.metadata['section_title'] == section_title]

    async def _generate_llm_response(self, llm, template, **kwargs):
        """Format ``template`` with ``kwargs``, complete it with ``llm`` and return the text."""
        formatted_prompt = template.format(**kwargs)
        response = await llm.acomplete(formatted_prompt)
        return response.text

    async def _process_and_postprocess(self, ctx: Context, llm, template, postprocessor, **kwargs):
        """Run one prompt through ``llm`` and clean the reply with ``postprocessor``.

        ``ctx`` is accepted for signature symmetry with the other helpers; it is
        not used here.
        """
        response_text = await self._generate_llm_response(llm, template, **kwargs)
        return postprocessor(response_text)

    def _build_vllm_llm(self, model, system_prompt, context_window):
        """Create an ``OpenAILike`` client for the local server on port 8008."""
        return OpenAILike(
            model=model,
            temperature=0.6,
            system_prompt=system_prompt,
            api_base="http://0.0.0.0:8008/v1",
            api_key="fake",
            is_chat_model=True,
            is_function_calling_model=True,
            context_window=context_window,
            # max_tokens=10000,
            timeout=4000.0
        )

    async def _summarize_section(self, ctx: Context, section):
        """Summarize one section document; returns ``(section_title, summary)``."""
        section_title = section.metadata['section_title']
        print(f"  Summarizing section: {section_title}...")
        nodes = await self._get_nodes_for_section(ctx, section_title)
        summarizer = await ctx.get("tree_summarizer")
        summary_prompt = await ctx.get('summary_prompt')
        compliance_guidance = await ctx.get("compliance_guidance")
        additional_guidance = await ctx.get("additional_guidance")
        postprocessor = await ctx.get("local_summary_postprocesser")

        template = summary_prompt.get_template()

        response = await summarizer.aget_response(
            template.format(
                compliance_guidance=compliance_guidance,
                additional_guidance=additional_guidance
            ),
            [doc.text for doc in nodes]
        )
        clean_response = postprocessor(response)
        return section_title, clean_response

    async def _refine_summary(self, ctx: Context, title: str, summary: str):
        """Rewrite a raw section summary with the refine prompt; returns ``(title, refined)``."""
        print(f"  Refining summary section: {title}...")
        local_llm = await ctx.get("local_llm")
        postprocesser = await ctx.get("local_llm_postprocesser")
        template = refine_summaries_prompt_template.get_template()
        refined_summary = await self._process_and_postprocess(ctx, local_llm, template, postprocesser, summaries=summary)
        print(f"  {title} Number of words: {len(refined_summary.split())}")
        return title, refined_summary

    async def _extract_bullet_points(self, ctx: Context, title: str, summary: str, prompt_template, **kwargs):
        """Extract bullet points from ``summary`` using ``prompt_template``.

        Extra ``kwargs`` are forwarded to the template (for example
        ``guiding_principles``). Returns ``(title, bullet_points)``.
        """
        print(f"  Extracting key bullet points for section: {title}...")
        bullet_points_llm = await ctx.get("section_bullet_points_llm")
        compliance_guidance = await ctx.get("compliance_guidance")
        additional_guidance = await ctx.get("additional_guidance")
        postprocessor = await ctx.get("local_llm_postprocesser")

        bullet_points = await self._process_and_postprocess(
            ctx,
            bullet_points_llm,
            prompt_template.get_template(),
            postprocessor,
            compliance_guidance=compliance_guidance,
            additional_guidance=additional_guidance,
            section_str=summary,
            **kwargs
        )
        return title, bullet_points

    async def _check_facts(self, ctx: Context, title: str, summary: str, bullet_points: str):
        """Fact-check ``bullet_points`` against ``summary``; returns ``(title, check_result)``."""
        print(f"  Fact checking bullets: {title}...")
        fact_checker_llm = await ctx.get("fact_checker_llm")
        postprocessor = await ctx.get("expert_llm_postprocesser")
        template = bullet_point_fact_checking_prompt_template.get_template()
        result = await self._process_and_postprocess(ctx, fact_checker_llm, template, postprocessor, section_str=summary, statement=bullet_points)
        return title, result

    async def _update_facts(self, ctx: Context, title: str, incorrect_bullet_points: str, corrected_bullet_points: str):
        """Run the bullet-update prompt; returns ``(title, updated_bullets)``."""
        print(f"  Fact update section: {title}...")
        fact_update_llm = await ctx.get("fact_update_llm")
        postprocessor = await ctx.get("expert_llm_postprocesser")
        template = bullet_point_update_prompt_template.get_template()
        updated_bullets = await self._process_and_postprocess(ctx, fact_update_llm, template, postprocessor, incorrect_bullet_points=incorrect_bullet_points, corrected_bullet_points=corrected_bullet_points)
        return title, updated_bullets

    async def _format_summary(self, ctx: Context, title: str, summary: str):
        """Format a summary for output and drop Markdown bold markers; returns ``(title, text)``."""
        print(f"  Formatting summary output: {title}...")
        local_llm = await ctx.get("local_llm")
        postprocessor = await ctx.get("local_llm_postprocesser")
        template = format_summary_output_prompt_template.get_template()
        formatted_summary = await self._process_and_postprocess(ctx, local_llm, template, postprocessor, summary_output=summary)
        formatted_summary = formatted_summary.replace('**', '')
        print(f"  {title} Number of words: {len(formatted_summary.split())}")
        return title, formatted_summary

    async def _configure_long_summary(self, local_model_name, ctx):
        """Store the summary postprocessor for long summaries and return their settings.

        Returns ``(summary_style, summary_prompt, summary_ctx_len, buffer_size,
        breakpoint_percentile_threshold)``.
        """
        summary_style = 'compact_accumulate'
        summary_prompt = long_summarize_prompt_template
        summary_ctx_len = 7000
        buffer_size = 5
        breakpoint_percentile_threshold = 75
        if local_model_name in self._REASONING_MODELS:
            await ctx.set("local_summary_postprocesser", deepseek_summary_postprocesser)
        else:
            await ctx.set("local_summary_postprocesser", summary_postprocesser)
        return summary_style, summary_prompt, summary_ctx_len, buffer_size, breakpoint_percentile_threshold

    async def _configure_short_summary(self, local_model_name, ctx):
        """Store the summary postprocessor for short summaries and return their settings.

        Returns ``(summary_style, summary_prompt, summary_ctx_len, buffer_size,
        breakpoint_percentile_threshold)``.
        """
        summary_style = 'tree_summarize'
        summary_prompt = short_summarize_prompt_template
        summary_ctx_len = 10000
        buffer_size = 1
        breakpoint_percentile_threshold = 95
        await ctx.set('summary_response_split', None)
        if local_model_name in self._REASONING_MODELS:
            await ctx.set("local_summary_postprocesser", deepseek_llm_postprocesser)
        else:
            await ctx.set("local_summary_postprocesser", llm_postprocesser)
        return summary_style, summary_prompt, summary_ctx_len, buffer_size, breakpoint_percentile_threshold

    @step
    async def initialize(self, ctx: Context, ev: StartEvent) -> SummarizationEvent:
        """Build models, chunk the rule sections and seed the context.

        Raises:
            AssertionError: if ``ev.summary_length`` is not 'long' or 'short'.
            ValueError: if splitting the sections produced no chunks.
        """
        # see https://github.com/run-llama/llama_index/blob/f7c5ee5efbb6172e819f26d1705fcdf6114b11a3/llama-index-core/llama_index/core/response_synthesizers/type.py#L4
        # Summarization methods: "accumulate", "compact_accumulate", "compact", "simple_summarize", "tree_summarize", "refine"
        # ollama start (defaults to "http://127.0.0.1:11434")
        # OLLAMA_HOST="http://127.0.0.1:11435" ollama start
        embed_model_name = "nomic-embed-text"
        # embed_model_name = "bge-m3"
        local_model_name, local_ctx_len = "deepseek-r1:32b", 20000
        # local_model_name, local_ctx_len = "qwq", 32000
        # local_model_name, local_ctx_len = "gpt-4o", 16384
        # expert_model_name, expert_ctx_len = "gpt-4o", 16384

        assert ev.summary_length in ['long', 'short'], "Invalid summary length. Must be 'long' or 'short'."

        # The configure helpers also store "local_summary_postprocesser" in the
        # context, so it is not set again here. Their context-length value is
        # ignored because it is recomputed from the actual chunks below.
        if ev.summary_length == 'long':
            configure = self._configure_long_summary
        else:
            configure = self._configure_short_summary
        summary_style, summary_prompt, _, buffer_size, breakpoint_percentile_threshold = await configure(local_model_name, ctx)

        await ctx.set('summary_length', ev.summary_length)
        await ctx.set('summary_prompt', summary_prompt)

        if local_model_name in self._REASONING_MODELS:
            system_prompt = "You are an AI assistant. Be helpful and informative. Provide accurate information. Be respectful and professional. Only answer in English."
            await ctx.set("local_llm_postprocesser", deepseek_llm_postprocesser)
        else:
            system_prompt = None
            await ctx.set("local_llm_postprocesser", llm_postprocesser)

        # Embedding model
        # OLLAMA_HOST="http://127.0.0.1:11435" ollama start
        embed_model = OllamaEmbedding(embed_model_name, base_url="http://localhost:11435")
        await ctx.set('embed_model', embed_model)
        # TODO: add tokenizer to splitter
        splitter = SemanticSplitterNodeParser(buffer_size=buffer_size,
                                              embed_model=embed_model,
                                              include_metadata=True,
                                              breakpoint_percentile_threshold=breakpoint_percentile_threshold)

        # Build one document per section, then split them into chunks.
        documents_rule_proposal = [
            Document(text=section_text, metadata={'section_title': title})
            for title, section_text in ev.rule_proposal_sections.items()
        ]
        nodes = splitter.get_nodes_from_documents(documents_rule_proposal, show_progress=True)
        if not nodes:
            raise ValueError("No text chunks were produced from rule_proposal_sections; nothing to summarize.")
        # Size the summary context window at twice the word count of the largest chunk.
        summary_ctx_len = int(max(len(node.text.split()) for node in nodes) * 2)
        print(summary_ctx_len)
        await ctx.set("documents", documents_rule_proposal)
        await ctx.set("nodes", nodes)

        # Local LLM
        # await ctx.set("local_llm", LIOpenAI(model=local_model_name, max_tokens=local_ctx_len, api_key=openai_key, temperature=0.5))
        # await ctx.set("summary_llm", LIOpenAI(model=expert_model_name, max_tokens=summary_ctx_len, api_key=openai_key, temperature=0.5))

        # additional_kwargs = {"num_predict": 10000,
        #                      "mirostat":0}
        # await ctx.set("local_llm", Ollama(model=local_model_name, url="http://127.0.0.1:11434", context_window=local_ctx_len, model_type="chat", is_function_calling_model=True,
        #                                   request_timeout=4000.0, additional_kwargs=additional_kwargs, keep_alive=0, system_prompt=system_prompt))
        # await ctx.set("summary_llm", Ollama(model=local_model_name, url="http://127.0.0.1:11434", context_window=summary_ctx_len, model_type="chat", is_function_calling_model=True,
        #                                     request_timeout=4000.0, additional_kwargs=additional_kwargs, keep_alive=0, system_prompt=system_prompt))

        vllm_model = "/workspace/model_merged_daretie_quant/"
        # vllm_model = "Valdemardi/DeepSeek-R1-Distill-Qwen-32B-AWQ"
        local_llm = self._build_vllm_llm(vllm_model, system_prompt, local_ctx_len)
        summary_llm = self._build_vllm_llm(vllm_model, system_prompt, summary_ctx_len)
        await ctx.set("local_llm", local_llm)
        await ctx.set("summary_llm", summary_llm)

        # Expert LLM (currently the same model as the local LLM)
        # await ctx.set("expert_llm", LIOpenAI(model=expert_model_name, max_tokens=expert_ctx_len, api_key=openai_key))
        expert_llm = local_llm
        await ctx.set("expert_llm", expert_llm)
        await ctx.set("expert_llm_postprocesser", await ctx.get("local_llm_postprocesser"))

        await ctx.set("tree_summarizer", get_response_synthesizer(llm=summary_llm,
                                                                  response_mode=summary_style))
        await ctx.set("section_bullet_points_llm", local_llm)
        await ctx.set("fact_checker_llm", expert_llm)
        await ctx.set("fact_update_llm", expert_llm)
        await ctx.set("section_guiding_principles_bullet_points_llm", local_llm)
        await ctx.set("request_for_comments", ev.request_for_comments)
        await ctx.set("guiding_principles", ev.guiding_principles)
        await ctx.set("compliance_guidance", ev.compliance_guidance)
        await ctx.set("additional_guidance", ev.additional_guidance)
        #TODO: implement fact checking manager instead of status and num iterations
        # await ctx.set("fact_checking_manager", FactCheckingManager(ctx))
        # fact_checking_status counts fact-checking rounds, starting at 1.
        await ctx.set("fact_checking_status", 1)
        await ctx.set("fact_checking_num_iterations", ev.fact_checking_num_iterations)

        return SummarizationEvent(result={})

    @step
    async def summarize(self, ctx: Context, ev: SummarizationEvent) -> KeyPointsEvent | RefineSummariesEvent:
        """Summarize every section concurrently and store the result under ``summaries``."""
        documents = await ctx.get("documents")
        summary_length = await ctx.get('summary_length')
        tasks = [self._summarize_section(ctx, section) for section in documents]
        results = await asyncio.gather(*tasks)
        await ctx.set("summaries", dict(results))
        # Short summaries skip the refinement step.
        if summary_length == 'short':
            return KeyPointsEvent(result=ev.result)
        return RefineSummariesEvent(result=ev.result)

    @step
    async def refine_summaries(self, ctx: Context, ev: RefineSummariesEvent) -> KeyPointsEvent:
        """Keep a copy of the raw summaries, then replace ``summaries`` with refined ones."""
        summaries = await ctx.get("summaries")
        raw_summaries = copy.deepcopy(summaries)
        await ctx.set("raw_summaries", raw_summaries)
        tasks = [
            self._refine_summary(ctx, title, summary)
            for title, summary in raw_summaries.items()
        ]
        results = await asyncio.gather(*tasks)
        await ctx.set("summaries", dict(results))
        return KeyPointsEvent(result=ev.result)

    @step
    async def summary_bullet_points(self, ctx: Context, ev: KeyPointsEvent) -> FactCheckingEvent:
        """Extract key bullet points for every summary and store them under ``bullet_points``."""
        summaries = await ctx.get("summaries")
        tasks = [
            self._extract_bullet_points(ctx, title, summary, bullet_points_prompt_template)
            for title, summary in summaries.items()
        ]
        results = await asyncio.gather(*tasks)
        await ctx.set("bullet_points", dict(results))
        return FactCheckingEvent(result=ev.result)

    @step
    async def fact_checking(self, ctx: Context, ev: FactCheckingEvent) -> KeyPointsGPEvent | FactUpdateEvent:
        """Fact-check the bullet points and decide whether another update round is needed.

        Checks only run while ``fact_checking_status`` has not exceeded
        ``fact_checking_num_iterations``. On the first round every section is
        checked; afterwards only sections whose last update was not the
        no-change reply are re-checked.
        """
        fact_checking_status = await ctx.get("fact_checking_status")
        fact_checking_num_iterations = await ctx.get("fact_checking_num_iterations")
        summaries = await ctx.get("summaries")
        fact_updates = await ctx.get("fact_updates", {})
        facts_updated = bool(fact_updates)

        fact_checks = {}
        update_bullets = False
        if fact_checking_status <= fact_checking_num_iterations:
            bullet_points = await ctx.get("bullet_points")
            tasks = [
                self._check_facts(ctx, title, summary, bullet_points[title])
                for title, summary in summaries.items()
                if not facts_updated or fact_updates.get(title, self._NO_CHANGE) != self._NO_CHANGE
            ]
            if tasks:  # Only process if there are tasks
                results = await asyncio.gather(*tasks)
                fact_checks = dict(results)
                # Any reply other than the no-change sentinel requests an update.
                # (The former extra `len(check) > 25` clause could never change the outcome.)
                update_bullets = any(check != self._NO_CHANGE for check in fact_checks.values())

        if update_bullets:
            fact_checking_status += 1
            await ctx.set("fact_checking_status", fact_checking_status)

        await ctx.set('fact_checks', fact_checks)
        return FactUpdateEvent(result=ev.result) if update_bullets else KeyPointsGPEvent(result=ev.result)

    @step
    async def fact_update(self, ctx: Context, ev: FactUpdateEvent) -> FactCheckingEvent:
        """Apply the fact-check findings to the bullet points of the flagged sections.

        Sections whose check returned the no-change reply are skipped, and a
        no-change reply from the update prompt never replaces existing bullets.
        """
        bullet_points = await ctx.get("bullet_points")
        fact_checks = await ctx.get("fact_checks")
        # NOTE(review): the fact-check output is passed as `incorrect_bullet_points`
        # and the current bullets as `corrected_bullet_points`; this looks swapped
        # relative to the parameter names — confirm against BULLET_POINT_UPDATE_PROMPT_TMPL.
        tasks = [
            self._update_facts(ctx, title, bullet_updates, bullet_points[title])
            for title, bullet_updates in fact_checks.items()
            if bullet_updates != self._NO_CHANGE
        ]
        results = await asyncio.gather(*tasks)
        fact_updates = dict(results)
        for title, updated_bullet_points in fact_updates.items():
            if updated_bullet_points != self._NO_CHANGE:
                bullet_points[title] = updated_bullet_points
        await ctx.set('bullet_points', bullet_points)
        await ctx.set('fact_updates', fact_updates)
        return FactCheckingEvent(result=ev.result)

    @step
    async def guiding_principles_bullet_points(self, ctx: Context, ev: KeyPointsGPEvent) -> FormatSummaryOutputEvent:
        """Extract bullet points related to the guiding principles for every summary."""
        guiding_principles = await ctx.get("guiding_principles")
        summaries = await ctx.get("summaries")
        tasks = [
            self._extract_bullet_points(ctx, title, summary, guiding_principles_bullet_points_prompt_template, guiding_principles=guiding_principles)
            for title, summary in summaries.items()
        ]
        results = await asyncio.gather(*tasks)
        await ctx.set('guiding_principles_bullet_points', dict(results))
        return FormatSummaryOutputEvent(result=ev.result)

    @step
    async def format_summary_output(self, ctx: Context, ev: FormatSummaryOutputEvent) -> StopEvent:
        """Format every summary, store them under ``formatted_summaries`` and stop.

        The ``StopEvent`` result is the workflow context itself.
        """
        summaries = await ctx.get("summaries")
        tasks = [
            self._format_summary(ctx, title, summary)
            for title, summary in summaries.items()
        ]
        results = await asyncio.gather(*tasks)
        await ctx.set('formatted_summaries', dict(results))
        return StopEvent(result=ctx)

# Write a diagram of the workflow's possible step transitions to an HTML file.
draw_all_possible_flows(RuleSummarizationFlow, filename="/workspace/data/rule_flows.html")

In [None]:
# Single-section subset of the rule, used for quick test runs of the workflow.
fake_relevant_section_dict = {
    section_title: section_text
    for section_title, section_text in relevant_sections_dict.items()
    if section_title in ["IV. Analysis of the Costs and Benefits Associated With the Proposed Rule"]
}

In [None]:
## Runnining the workflow
# Load the guiding principles
with open('/workspace/repos/fsi/fsi_guiding_principles.txt', 'r') as f:
    guiding_principles = f.read()

c = RuleSummarizationFlow(timeout=12000, verbose=True)
# Pass in the data
result = await c.run(rule_proposal_sections=fake_relevant_section_dict, 
# result = await c.run(rule_proposal_sections=relevant_sections_dict, 
                     request_for_comments=request_for_comments, 
                     guiding_principles=guiding_principles,
                     compliance_guidance="Do not include any information about requests for comments or instructions to commenters.",
                     additional_guidance='',
                     fact_checking_num_iterations=3,
                     summary_length='long')

In [None]:
import json
formatted_summaries = await result.get('formatted_summaries')
with open(f'/workspace/{rule_title}.json'.replace(' ','_'), 'w') as f:
    json.dump(formatted_summaries, f, indent=4)

In [None]:
# Deliberate stop: a bare `raise` with no active exception errors out, which
# halts "Run All" here so the cells below are only run manually.
raise

In [None]:
bullets_dict={}
for bullets in ['bullet_points', 'guiding_principles_bullet_points']:
    result_bullets = await result.get(bullets)
    bullets_dict[bullets] = {}
    for section in result_bullets.keys():
        bullet_collect=[]
        for bullet in result_bullets[section].replace('\n\n','\n').split('\n'):
            bullet_collect.append(bullet.replace('**', '').replace('-','').strip())
        bullets_dict[bullets][section] = bullet_collect

In [None]:
# Create main PDF with all sections summaries
from utils import create_pdf
formatted_summaries = await result.get('formatted_summaries')
create_pdf(formatted_summaries,  bullets_dict, f'SUMMARY {rule_title}.pdf'.replace(' ','_'), None)

In [None]:
%load_ext autoreload
%autoreload 2

from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.core import get_response_synthesizer
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.prompts.base import PromptTemplate
from summary_prompts import (
                             LONG_SUMMARIZE_PROMPT_TMPL, 
                             SHORT_SUMMARIZE_PROMPT_TMPL, 
                             REFINE_SUMMARIES_PROMPT_TMPL,
                             BULLET_POINTS_PROMPT_TMPL, 
                             GUIDING_PRINCIPLES_BULLET_POINTS_PROMPT_TMPL,
                             BULLET_POINT_FACT_CHECKING_PROMPT_TMPL,
                             BULLET_POINT_UPDATE_PROMPT_TMPL
                             )


# Scratch setup: rebuilds the prompt templates (same names as the ones defined
# for the workflow above) and chunks all relevant sections outside the workflow.
long_summarize_prompt_template = PromptTemplate(LONG_SUMMARIZE_PROMPT_TMPL)
short_summarize_prompt_template = PromptTemplate(SHORT_SUMMARIZE_PROMPT_TMPL)
refine_summaries_prompt_template = PromptTemplate(REFINE_SUMMARIES_PROMPT_TMPL)
bullet_points_prompt_template = PromptTemplate(BULLET_POINTS_PROMPT_TMPL)
guiding_principles_bullet_points_prompt_template = PromptTemplate(GUIDING_PRINCIPLES_BULLET_POINTS_PROMPT_TMPL)
bullet_point_fact_checking_prompt_template = PromptTemplate(BULLET_POINT_FACT_CHECKING_PROMPT_TMPL)
bullet_point_update_prompt_template = PromptTemplate(BULLET_POINT_UPDATE_PROMPT_TMPL)
# Same embedding model and splitter settings as the workflow's long-summary configuration.
embed_model_name="nomic-embed-text"
embed_model = OllamaEmbedding(embed_model_name, base_url="http://localhost:11435")
splitter = SemanticSplitterNodeParser(buffer_size=5, embed_model=embed_model, include_metadata=True, breakpoint_percentile_threshold=75)
# One document per section; the section title is kept in the metadata so chunks can be filtered by section.
documents_rule_proposal = [Document(text=section_text, metadata={'section_title':title}) for title, section_text in relevant_sections_dict.items()]
nodes = splitter.get_nodes_from_documents(documents_rule_proposal, show_progress=True)

import os
from llama_index.llms.openai import OpenAI as LIOpenAI
from dotenv import load_dotenv
load_dotenv('/workspace/repos/fsi/.env')
openai_key = os.getenv("OPENAI_API_KEY")

expert_model_name, ctx_len = "gpt-4o", 3000

# NOTE(review): this OpenAI client is replaced by the Ollama client a few lines
# below before it is ever used, so it has no effect on the run.
llm = LIOpenAI(model=expert_model_name, max_tokens=30000, api_key=openai_key)
        
additional_kwargs = {"num_predict": 20000,
                     "mirostat":0}
local_model_name, ctx_len, tokenizer_name = "deepseek-r1:32b", 2048, ""
llm = Ollama(model=local_model_name, url="http://127.0.0.1:11434", context_window=ctx_len, model_type="chat", is_function_calling_model=True, 
                            request_timeout=4000.0, additional_kwargs=additional_kwargs, keep_alive=0, system_prompt=None)

tree_summarizer = get_response_synthesizer(llm=llm, 
                                            response_mode="compact_accumulate", 
                                            ) 
# NOTE(review): this title must exist in relevant_sections_dict for the loaded
# rule; otherwise test_nodes is empty — confirm it matches the current document.
test_nodes = [x for x in nodes if x.metadata['section_title'] == 'II. Discussion of Regulation Best Interest']
# This bare expression is not the last line of the cell, so its value is not displayed.
len(test_nodes), len(" ".join([x.text for x in test_nodes]).split())*1.25

# Rough per-chunk token estimates (words * 1.25), smallest first.
print(sorted([len(x.text.split())*1.25 for x in test_nodes]))
# tree_response = tree_summarizer.get_response(long_summarize_prompt_template.get_template().format(compliance_guidance="None", 
tree_response = tree_summarizer.get_response(long_summarize_prompt_template.get_template().format(compliance_guidance="Do not include any information about requests for comments or instructions to commenters.", 
                                                                                                  additional_guidance=""), 
                                             text_chunks=[doc.text for doc in test_nodes])
# Earlier approach (kept for reference): split on the accumulate separator and keep
# only what follows the last '</think>' of each piece.
# raw_summaries = "\n\n".join([x.split('</think>')[-1].strip() for x in tree_response.split("---------------------\nResponse")])
# raw_summaries_list = [x.split('</think>')[-1].strip() for x in tree_response.split("---------------------\nResponse")]
import re

# Remove every <think>...</think> span (non-greedy, may span multiple lines) from the accumulated response.
text = re.compile(r'<think>.*?</think>', re.DOTALL).sub('', tree_response)
print(text)
# Prompt that asks the model to merge the accumulated per-chunk responses into one report.
# The instruction sentences are joined with single spaces; `{summaries}` is filled via str.format.
refine_summaries_prompt = " ".join([
    "You will be given text containing several independent reports.",
    "Each report begins with 'Response' and ends with '---------------------'.",
    "Your task is to rewrite all the responses (i.e. do not summarize) to form a coherent and accurate report.",
    "Remove any redundant information and redundant phrasing (e.g. 'The regulation specifies', 'In summary', 'The text outlines').",
    "Preserve all of the details in the original report, while ensuring that the final rewrite has a clear narrative flow and logical structure.",
    "Write the report in paragraph form, using full sentences.",
    "Do not use any Markdown formatting.",
    "The final report should be at least 5000 words long.",
]) + "\n\nORIGINAL REPORT:\n{summaries}\n\nREFINED REPORT:\n"
# Re-create the local model with a larger context window for the refinement pass.
additional_kwargs = {"num_predict": 10000,
                     "mirostat": 0}
local_model_name, ctx_len, tokenizer_name = "deepseek-r1:32b", 20000, ""
# `base_url` is the endpoint keyword of the Ollama integration; `url=` is not a declared field.
llm = Ollama(model=local_model_name, base_url="http://127.0.0.1:11434", context_window=ctx_len, model_type="chat",
             is_function_calling_model=True, request_timeout=4000.0, additional_kwargs=additional_kwargs,
             keep_alive=0, system_prompt=None)

# Use the synchronous API: `acomplete` returns a coroutine, so without `await`
# `response` would hold an un-run coroutine instead of a completion.
response = llm.complete(refine_summaries_prompt.format(summaries=text))

In [None]:
# Print each bullet of the selected section as plain text.
for b in bullets_dict[bullets][section]:
    # Drop markdown bold markers and the leading list dash(es) only; hyphens
    # inside the text (e.g. "broker-dealer") are preserved.
    print(b.replace('**', '').strip().lstrip('-').strip())

# GraphRAG for summarization

In [None]:
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.llms.ollama import Ollama


# Local generation model served by Ollama.
model_name, ctx_len, tokenizer_name = "deepseek-r1:32b", 40000, ""
additional_kwargs = {"num_predict": 8000}
system_prompt = None
# `base_url` is the endpoint keyword of the Ollama integration (same keyword as
# OllamaEmbedding below); `url=` is not a declared field.
llm = Ollama(model=model_name, base_url="http://127.0.0.1:11434", context_window=ctx_len, model_type="chat",
             is_function_calling_model=True, request_timeout=4000.0, additional_kwargs=additional_kwargs,
             keep_alive=0, system_prompt=system_prompt)

# The embedding model is reached on a different port (11435); that instance is started with:
# OLLAMA_HOST="http://127.0.0.1:11435" ollama start
embed_model = OllamaEmbedding('bge-m3', base_url="http://localhost:11435", keep_alive=0)

# Split every relevant section into nodes, carrying the section title as metadata.
splitter = SemanticSplitterNodeParser(buffer_size=1, embed_model=embed_model, include_metadata=True)
documents_rule_proposal = [Document(text=section_text, metadata={'section_title': title})
                           for title, section_text in relevant_sections_dict.items()]

# Single-section debugging variant:
# nodes = splitter.get_nodes_from_documents([documents_rule_proposal[6]], show_progress=True)
nodes = splitter.get_nodes_from_documents(documents_rule_proposal, show_progress=True)

In [None]:
from llama_index.core.prompts.base import BasePromptTemplate, PromptTemplate
from summary_prompts import TREE_PRAXIS_SUMMARIZE_PROMPT_TMPL, GLOBAL_PRAXIS_SUMMARIZE_PROMPT_TMPL, GLOBAL_PRAXIS_QA_PROMPT_TMPL, REQUEST_FOR_COMMENT_PRAXIS_SUMMARIZE_PROMPT_TMPL, DEF_EXTRACTION_PROMPT_TMPL
from llama_index.core import get_response_synthesizer

# Default guidance placeholders: the literal string "None" for both.
compliance_guidance = additional_guidance = "None"

# Wrap the raw prompt strings from summary_prompts in PromptTemplate objects.
global_summarize_prompt_template = PromptTemplate(GLOBAL_PRAXIS_SUMMARIZE_PROMPT_TMPL)
tree_summarize_prompt_template = PromptTemplate(TREE_PRAXIS_SUMMARIZE_PROMPT_TMPL)

In [None]:
# Tree-summarize synthesizer driven by the custom summary template.
summary_llm = get_response_synthesizer(
    summary_template=tree_summarize_prompt_template,
    response_mode="tree_summarize",
    llm=llm,
)

In [None]:
# Summarize all nodes with an empty query; the two guidance values are passed
# through as extra keyword arguments to get_response.
compliance_guidance, additional_guidance = "None", "None"
summary2 = summary_llm.get_response(
    "",
    text_chunks=[node.text for node in nodes],
    additional_guidance=additional_guidance,
    compliance_guidance=compliance_guidance,
)

In [None]:
# LLM as a judge
# Template with three str.format placeholders: {original_text}, {summary}, {summary2}.
# The "\n" escapes on the last line are real newlines (the string is not raw).
judge_prompt = """You will be given two summaries. You are to decide which summary is more professional and more informative.
Here are the criteria for a professional and informative summary:
- The summary should be verbose, making sure to address all key points.
- The summary should be free of grammatical errors.
- The summary should be free of spelling errors.
- The summary should be free of unnecessary information.
- The summary should be free of personal opinions.
- The summary should be free of conversational elements.
- The summary should be free of pretext and posttext.
- The summary should be free of any additional information not present in the original text.
- The summary should not contradict the original text.
Be sure to read both summaries carefully before making your decision.
Please provide the reasoning for your choice.
Format your response as follows:
**Best Summary:** [Your choice between Summary 1 and Summary 2]
**Reasoning:** [Your reasoning]
\n\nOriginal text:\n\n{original_text}\n\nSummary 1:\n\n{summary}\n\nSummary 2:\n\n{summary2}"""

In [None]:
# Ask the LLM to judge the two summaries against the full text of all nodes.
# Only the part after the last '</think>' of each summary is sent.
# NOTE(review): `summary` is not defined in the visible part of the notebook - confirm it exists upstream.
response = llm.complete(
    judge_prompt.format(
        original_text=" ".join(node.text for node in nodes),
        summary=summary.rpartition('</think>')[2].strip(),
        summary2=summary2.rpartition('</think>')[2].strip(),
    )
)

In [None]:
# Show the verdict: everything after the last '</think>' (or the whole text if there is none).
print(response.text.rpartition('</think>')[2].strip())

In [None]:
# Hand-pasted example summary of the SBREFA "major rule" section.
# NOTE(review): the judge cell above formats `summary` and `summary2`, not `summary1` - confirm which name is intended.
summary1="""The section discusses considerations related to the impact of a proposed regulation on the economy, particularly in the context of the Small Business Regulatory Enforcement Fairness Act of 1996 (SBREFA). It explains that under SBREFA, a rule is considered "major" if it meets one or more of the following criteria: 

1. The regulation results in or is likely to result in an annual effect on the economy of $100 million or more.
2. The regulation causes a major increase in costs or prices for consumers or individual industries.
3. The regulation has significant adverse effects on competition, investment, or innovation.

The text emphasizes that the regulatory body is required to advise the Office of Management and Budget (OMB) whether the proposed regulation qualifies as a "major rule" under these criteria. Additionally, it invites public comment on the potential economic impact of the proposed rule on an annual basis, any increases in costs or prices for consumers or industries, and any effects on competition, investment, or innovation. Commenters are encouraged to provide empirical data and factual support to substantiate their views. The section underscores the importance of gathering comprehensive input to assess whether the regulation would meet the thresholds for being classified as a "major rule" under SBREFA."""

In [None]:
import os
from graph_rag_utils import set_neo4j_password, add_lines_to_conf
from dotenv import load_dotenv
# Read the Neo4j password from the project's .env file rather than hard-coding it.
load_dotenv('/workspace/repos/fsi/.env')
NEO4J_PWD = os.environ.get('NEO4J_PWD')

# Helpers from graph_rag_utils: set the database password, then add the required lines to the Neo4j conf.
set_neo4j_password(password=NEO4J_PWD)
add_lines_to_conf()

* FIRST set password for database `neo4j-admin dbms set-initial-password <PASSWORD>`

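* THEN make sure the two APOC lines (`dbms.security.procedures.allowlist=apoc.*` and `dbms.security.procedures.unrestricted=apoc.*`) have been added to neo4j.conf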

* FINALLY start neo4j in terminal or tmux screen `neo4j start`


In [None]:

# apoc plugin needs to go in /var/lib/neo4j/plugins
#### NOTE #### neo4j and apoc versions must match!!!!!!
# wget https://github.com/neo4j/apoc/releases/download/5.22.0/apoc-5.22.0-core.jar

# in vim /etc/neo4j/neo4j.conf add the following lines
# dbms.security.procedures.allowlist=apoc.*
# dbms.security.procedures.unrestricted=apoc.*

## Load Neo4j Database

In [None]:
# Load a database dump file (the last argument is the database name)
# neo4j-admin database load --from-path=/workspace/data/compliance/ --overwrite-destination neo4j

# Start the neo4j service
# neo4j console (or neo4j start)

import os
from typing import Literal
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.core import PropertyGraphIndex, StorageContext
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor, SimpleLLMPathExtractor
from llama_index.core.postprocessor import LLMRerank
from dotenv import load_dotenv
load_dotenv('/workspace/repos/fsi/.env')

# Patch asyncio so event loops can be nested inside the notebook's already-running loop.
import nest_asyncio
nest_asyncio.apply()

# Connect to the local Neo4j instance over Bolt; the password comes from the environment.
NEO4J_PWD = os.environ.get('NEO4J_PWD')
neo_url = "bolt://localhost:7687"
graph_store = Neo4jPropertyGraphStore(
    url=neo_url,
    database="neo4j",
    username="neo4j",
    password=NEO4J_PWD,
)

# Entity types the extractor may assign to a node.
entities = Literal[
    "REGULATION",
    "AGENCY",
    "DEFINITION",
    "TERM",
    "SECTION",
    "STAKEHOLDER",
    "REQUIREMENT",
    "EXCEPTION",
    "TIMELINE",
    "PENALTY",
    "ECONOMIC_IMPACT",
    "JUSTIFICATION",
    "REQUEST_FOR_COMMENT",
]
# Relation types the extractor may assign to an edge.
relations = Literal[
    "REPLACES",
    "AMENDS",
    "REFERS_TO",
    "DEFINES",
    "APPLIES_TO",
    "HAS_SECTION",
    "HAS_REQUIREMENT",
    "HAS_EXCEPTION",
    "HAS_TIMELINE",
    "HAS_PENALTY",
    "HAS_JUSTIFICATION",
    "HAS_REQUEST_FOR_COMMENT",
    "IMPACTS",
    "SUPPORTS",
    "OPPOSES",
    "HAS_ECONOMIC_IMPACT",
    "AFFECTS",
    "ESTIMATED_BY",
    "RELATES_TO",
]
# Maps each subject entity type to the relations it may start.
# Every relation named here must also appear in `relations`; REGULATION uses
# "HAS_REQUEST_FOR_COMMENT" (the declared relation) rather than the undeclared
# "HAS_PUBLIC_COMMENT", which could never be produced by the extractor.
validation_schema = {
    "REGULATION": ["AMENDS", "REPLACES", "REFERS_TO", "HAS_SECTION", "HAS_REQUIREMENT",
        "HAS_EXCEPTION", "HAS_TIMELINE", "HAS_PENALTY", "HAS_JUSTIFICATION",
        "HAS_REQUEST_FOR_COMMENT", "IMPACTS", "HAS_ECONOMIC_IMPACT", "AFFECTS"],
    "AGENCY": ["REFERS_TO"],
    "DEFINITION": ["DEFINES", "REFERS_TO"],
    "TERM": ["DEFINES", "REFERS_TO"],
    "SECTION": ["HAS_SECTION", "REFERS_TO"],
    "REQUIREMENT": ["APPLIES_TO", "HAS_EXCEPTION", "HAS_TIMELINE", "HAS_PENALTY", "HAS_ECONOMIC_IMPACT"],
    "ECONOMIC_IMPACT": ["AFFECTS", "ESTIMATED_BY", "REFERS_TO"],
    "STAKEHOLDER": ["IMPACTS", "OPPOSES", "SUPPORTS", "AFFECTS"],
    "EXCEPTION": ["APPLIES_TO", "RELATES_TO"],
    "TIMELINE": ["RELATES_TO"],
    "PENALTY": ["RELATES_TO"],
    "JUSTIFICATION": ["SUPPORTS", "REFERS_TO"],
    "REQUEST_FOR_COMMENT": ["REFERS_TO", "SUPPORTS", "OPPOSES",
                            "AFFECTS", "HAS_REQUEST_FOR_COMMENT", "RELATES_TO"],
}

# Storage context: persisted index directory plus the Neo4j property graph store.
storage_context = StorageContext.from_defaults(
    property_graph_store=graph_store,
    persist_dir='/workspace/data/compliance/graph_idx_schema_only',
)

# Schema-constrained triple extractor.
kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    kg_validation_schema=validation_schema,
    possible_relations=relations,
    possible_entities=entities,
    strict=True,  # if False, triples outside of the schema are allowed
    max_triplets_per_chunk=10,
    num_workers=7,
)

# Re-open the index on top of the graph that is already stored in Neo4j.
graph_index = PropertyGraphIndex.from_existing(
    property_graph_store=graph_store,
    storage_context=storage_context,
    kg_extractors=[kg_extractor],
    embed_model=embed_model,
    llm=llm,
    show_progress=True,
)

# Graph query engine: similarity_top_k=20, LLM rerank keeping top_n=10, tree-summarize synthesis.
# NOTE(review): neither as_query_engine nor LLMRerank receives `llm` here, so both
# presumably fall back to the globally configured default LLM - confirm this is intended.
query_engine = graph_index.as_query_engine(
    # llm=llm,
    node_postprocessors=[LLMRerank(top_n=10, choice_batch_size=5)],
    similarity_top_k=20,
    # Available modes: "accumulate", "compact_accumulate", "compact", "simple_summarize", "tree_summarize"
    # see https://github.com/run-llama/llama_index/blob/f7c5ee5efbb6172e819f26d1705fcdf6114b11a3/llama-index-core/llama_index/core/response_synthesizers/type.py#L4
    response_mode="tree_summarize",
)

## Build Neo4j Database from Scratch

In [None]:
# Build graph from scratch
# GraphRAG Database
import os
from typing import Literal

from llama_index.core.indices.property_graph import SchemaLLMPathExtractor
from llama_index.core.postprocessor import LLMRerank
from rag_utils import create_neo4j_graph_store, dump_neo4j_database

from dotenv import load_dotenv
load_dotenv('/workspace/repos/fsi/.env')

import nest_asyncio
nest_asyncio.apply()

# Directory the graph index's storage context is persisted to at the end of this cell.
graph_idx_persist_dir = "/workspace/data/compliance/graph_idx_schema_only"

# Entity types the extractor may assign to a node.
entities = Literal[
    "REGULATION",
    "AGENCY",
    "DEFINITION",
    "TERM",
    "SECTION",
    "STAKEHOLDER",
    "REQUIREMENT",
    "EXCEPTION",
    "TIMELINE",
    "PENALTY",
    "ECONOMIC_IMPACT",
    "JUSTIFICATION",
    "REQUEST_FOR_COMMENT",
]
# Relation types the extractor may assign to an edge.
relations = Literal[
    "REPLACES",
    "AMENDS",
    "REFERS_TO",
    "DEFINES",
    "APPLIES_TO",
    "HAS_SECTION",
    "HAS_REQUIREMENT",
    "HAS_EXCEPTION",
    "HAS_TIMELINE",
    "HAS_PENALTY",
    "HAS_JUSTIFICATION",
    "HAS_REQUEST_FOR_COMMENT",
    "IMPACTS",
    "SUPPORTS",
    "OPPOSES",
    "HAS_ECONOMIC_IMPACT",
    "AFFECTS",
    "ESTIMATED_BY",
    "RELATES_TO",
]
# Maps each subject entity type to the relations it may start.
# Every relation named here must also appear in `relations`; REGULATION uses
# "HAS_REQUEST_FOR_COMMENT" (the declared relation) rather than the undeclared
# "HAS_PUBLIC_COMMENT", which could never be produced by the extractor.
validation_schema = {
    "REGULATION": ["AMENDS", "REPLACES", "REFERS_TO", "HAS_SECTION", "HAS_REQUIREMENT",
        "HAS_EXCEPTION", "HAS_TIMELINE", "HAS_PENALTY", "HAS_JUSTIFICATION",
        "HAS_REQUEST_FOR_COMMENT", "IMPACTS", "HAS_ECONOMIC_IMPACT", "AFFECTS"],
    "AGENCY": ["REFERS_TO"],
    "DEFINITION": ["DEFINES", "REFERS_TO"],
    "TERM": ["DEFINES", "REFERS_TO"],
    "SECTION": ["HAS_SECTION", "REFERS_TO"],
    "REQUIREMENT": ["APPLIES_TO", "HAS_EXCEPTION", "HAS_TIMELINE", "HAS_PENALTY", "HAS_ECONOMIC_IMPACT"],
    "ECONOMIC_IMPACT": ["AFFECTS", "ESTIMATED_BY", "REFERS_TO"],
    "STAKEHOLDER": ["IMPACTS", "OPPOSES", "SUPPORTS", "AFFECTS"],
    "EXCEPTION": ["APPLIES_TO", "RELATES_TO"],
    "TIMELINE": ["RELATES_TO"],
    "PENALTY": ["RELATES_TO"],
    "JUSTIFICATION": ["SUPPORTS", "REFERS_TO"],
    "REQUEST_FOR_COMMENT": ["REFERS_TO", "SUPPORTS", "OPPOSES",
                            "AFFECTS", "HAS_REQUEST_FOR_COMMENT", "RELATES_TO"],
}

# Schema-constrained triple extractor used when building the graph.
kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    kg_validation_schema=validation_schema,
    possible_relations=relations,
    possible_entities=entities,
    strict=True,  # if False, triples outside of the schema are allowed
    max_triplets_per_chunk=10,
    num_workers=7,
)

# llm.is_function_calling_model = False
# extract_prompt = None
# kg_extractor = SimpleLLMPathExtractor(
#         extract_prompt=extract_prompt,
#         llm=llm,
#         max_paths_per_chunk=10,
#         num_workers=6,
#     )

print("Creating graph store...")
graph_store = create_neo4j_graph_store(
    neo_url="bolt://localhost:7687",
    password=os.getenv("NEO4J_PWD"),
    config={
        "connection_timeout": 1000,
        "connection_acquisition_timeout": 1000,
        "max_connection_pool_size": 1000,
    },
)

# Optional reset before a rebuild (disabled):
# if not os.path.exists(graph_idx_persist_dir):
#     print("Deleting all nodes and relationships...")
#     neo4j_query(graph_store, query="""MATCH n=() DETACH DELETE n""")

print("Creating graphrag index...")
# Helper-based alternative (disabled):
# graph_index = create_neo4j_graphrag(documents_rule_proposal, llm, embed_model, kg_extractor, graph_store, graph_idx_persist_dir=graph_idx_persist_dir, graph_store_persist_dir=graph_store_persist_dir)

from llama_index.core import PropertyGraphIndex

# `nodes` are the chunks produced by the semantic splitter.
graph_index = PropertyGraphIndex(
    nodes,
    property_graph_store=graph_store,
    kg_extractors=[kg_extractor],
    embed_model=embed_model,
    llm=llm,
    show_progress=True,
)

# Persist the index's storage context next to the graph data.
graph_index.storage_context.persist(persist_dir=graph_idx_persist_dir)


# dump_neo4j_database('neo4j', '/workspace/data/') # database needs to be stopped before running this command
# http://localhost:7474/browser/
# 22 mins


In [None]:


# Graph query engine: similarity_top_k=20, LLM rerank keeping top_n=10, tree-summarize synthesis.
# NOTE(review): neither as_query_engine nor LLMRerank receives `llm` here, so both
# presumably fall back to the globally configured default LLM - confirm this is intended.
query_engine = graph_index.as_query_engine(
    # llm=llm,
    node_postprocessors=[LLMRerank(top_n=10, choice_batch_size=5)],
    similarity_top_k=20,
    # Available modes: "accumulate", "compact_accumulate", "compact", "simple_summarize", "tree_summarize"
    # see https://github.com/run-llama/llama_index/blob/f7c5ee5efbb6172e819f26d1705fcdf6114b11a3/llama-index-core/llama_index/core/response_synthesizers/type.py#L4
    response_mode="tree_summarize",
)


In [None]:
# Calls the rag_utils helper for the 'neo4j' database with /workspace/data/compliance
# (presumably the dump destination - the load instructions read a dump from that directory).
# NOTE: per the comment in the build cell, the database needs to be stopped before running this command.
dump_neo4j_database('neo4j', '/workspace/data/compliance')

In [None]:
# Section-summary prompt with three str.format placeholders:
# {compliance_guidance}, {additional_guidance} and {section_str}.
summary_prompt = "\n".join([
    "Below is a section of a newly proposed federal regulation.",
    "Your task is to produce a precise summary of the given section, highlighting all major key points.",
    "Be very verbose in your summary, ensuring that you capture all the essential details and requirements.",
    "Use the context above to incorporate any relevant global document information.",
    "Use exact numeric values, specific costs, or other details where necessary.",
    "You must strictly follow the guidance provided by the expert compliance officer:",
    "",
    "Compliance Officer Guidelines:",
    "{compliance_guidance}",
    "{additional_guidance}",
    "",
    "SECTION TO SUMMARIZE:",
    "{section_str}",
    "",
    "",
    "SUMMARY:",
])

In [None]:
# Summarize one section (by position in documents_rule_proposal) through the graph query engine.
section_idx = 7
compliance_guidance = additional_guidance = ""

# The section's metadata dict is prepended to its text as a one-line header.
metadata_str = f"metadata: {documents_rule_proposal[section_idx].metadata!s}\n"

# Template-based alternative (disabled):
# query = tree_summarize_prompt_template.template.format(compliance_guidance=compliance_guidance, additional_guidance=additional_guidance, section_str=documents_rule_proposal[section_idx].text)
query = summary_prompt.format(
    section_str=metadata_str + documents_rule_proposal[section_idx].text,
    additional_guidance=additional_guidance,
    compliance_guidance=compliance_guidance,
)
response = query_engine.query(query)

In [None]:
# Show the source text of the section that was just summarized, for side-by-side comparison.
print(documents_rule_proposal[section_idx].text)

In [None]:
# Show the query engine's answer text.
print(response.response)

In [None]:
# Dump every retrieved source node: metadata, text, then a 40-dash separator line.
for src in response.source_nodes:
    print(src.metadata, src.text, '--' * 20, sep='\n')

# RAG with smaller text segments

In [None]:
import os
# from rag_utils import create_llama_vector_index_rag
from llama_index.core import VectorStoreIndex
from llama_index.core.node_parser import SemanticSplitterNodeParser, SentenceSplitter

# splitter = SemanticSplitterNodeParser(buffer_size=1, embed_model=embed_model, include_metadata=True)
# TODO: add tokenizer
persist_dir = '/workspace/data/compliance/vector_index'

# Fixed-size sentence chunks (chunk_size=256, chunk_overlap=32) instead of the semantic splitter.
splitter = SentenceSplitter(tokenizer=None, include_metadata=True, chunk_overlap=32, chunk_size=256)
documents_rule_proposal = [
    Document(text=body, metadata={'section_title': heading})
    for heading, body in relevant_sections_dict.items()
]
nodes = splitter.get_nodes_from_documents(documents_rule_proposal, show_progress=True)
vector_index = VectorStoreIndex(
    nodes,
    show_progress=True,
    embed_model=embed_model,
    llm=llm,
)
# Helper-based alternative (disabled):
# vector_index = create_llama_vector_index_rag(llm, embed_model=embed_model, persist_dir='/workspace/data/compliance/vector_index', documents=documents_rule_proposal, vector_store_kwargs={'chunk_size':256, 'chunk_overlap':32})

# NOTE(review): the index is rebuilt on every run; an existing persist_dir is
# neither loaded nor refreshed - confirm this is intended.
if os.path.exists(persist_dir):
    print(f"Vector index already exists at {persist_dir}")
else:
    print(f"Persisting vector index to {persist_dir}")
    vector_index.storage_context.persist(persist_dir=persist_dir)

In [None]:
from llama_index.core.vector_stores.types import MetadataFilters, ExactMatchFilter
from llama_index.core.postprocessor import LLMRerank

# Restrict retrieval to nodes whose `section_title` metadata equals this one section.
metadata_filters = MetadataFilters(
    filters=[
        ExactMatchFilter(value="VIII. FinCEN's Regulatory Impact Analysis", key="section_title"),
    ]
)

# Vector query engine over the filtered nodes.
rag_query_engine = vector_index.as_query_engine(
    filters=metadata_filters,
    similarity_top_k=10,
    llm=llm,
    # Optional reranking / tree summarization (disabled):
    # node_postprocessors=[
    #     LLMRerank(
    #         llm=llm,
    #         choice_batch_size=5,
    #         top_n=10,
    #     )
    # ],
    # response_mode="tree_summarize",
)


In [None]:
# Summary text that is substituted for {statement} in the fact-checking prompt below.
summary_for_checking = """The proposed federal regulation under FinCEN's Regulatory Impact Analysis aims to assess the costs and benefits of the regulation. The regulation emphasizes the importance of quantifying costs and benefits, reducing costs, harmonizing rules, and promoting flexibility. It has been designated as a significant regulatory action and reviewed by the Office of Management and Budget. The primary costs of compliance with the proposed rule are detailed in the Analysis of the Costs and Benefits Associated with the Proposed Rule, with estimated annual internal time costs of $404,045,339.05 and external cost burden of $48,446,969.76. The benefits of the rule are expected to include reducing money laundering and terrorist financing in the U.S. financial system, aiding law enforcement in investigating and disrupting financial crimes. The rule would help in identifying high-risk customers and preventing criminal activities. While the economic losses prevented by reducing financial crimes are difficult to estimate, the rule would reduce both monetary and nonmonetary harms caused by such activities. The rule requires investment advisers to establish and implement a Customer Identification Program (CIP) based on their specific circumstances, with requirements similar to those for other financial institutions. Some advisers may have reduced costs if they already perform certain Anti-Money Laundering/Counter Financing of Terrorism (AML/CFT) functions or are affiliated with banks or broker-dealers. The costs incurred by the rule include establishing a CIP, verifying identifying information, checking customers against government lists, recordkeeping, and reliance on other financial institutions. FinCEN estimates that the average compliance costs for an ERA with three customers would be $1,675 internally and $654 externally, while for an RIA with 100 customers, it would be $26,468 internally and $4,088 externally. 
Overall, FinCEN believes that the benefits of the rule would outweigh the costs, contributing to a more secure financial system."""

In [None]:
# Fact-checking prompt; {statement} is the summary to verify against the retrieved context.
summary_fact_checking_prompt = """Below is a summary that needs to be fact-checked.
ONLY use the information retrieved in the context above to determine if the facts presented in the summary are true.
Verify all numeric values, numeric values in word form, and other factual information to ensure accuracy.
Do not use any pre-existing knowledge.
If there are any factual mistakes, rewrite the incorrect portion of the summary and include the correct information, and output in this format:
- Original summary statement: [Incorrect information]
- Corrected summary statement: [Corrected information]
- Supporting Context: [Section context only]
If the entire summary is factually correct, output "None".
Only output the portions of the summary that are incorrect. Do not include any pre-text or post-text.

SUMMARY TO FACT-CHECK:
{statement}

FACT-CHECKING RESPONSE:"""


# Run the fact-check: the summary is substituted for {statement} and sent through the section-filtered RAG engine.
response = rag_query_engine.query(summary_fact_checking_prompt.format(statement=summary_for_checking))

In [None]:
# Show the fact-check result: everything after the last '</think>' (or the whole text if there is none).
print(response.response.rpartition('</think>')[2].strip())

In [None]:
# Dump every retrieved source node: metadata, text, then a 40-dash separator line.
for src in response.source_nodes:
    print(src.metadata, src.text, '--' * 20, sep='\n')

In [None]:
# Re-print the summary under review so it can be read next to the fact-check output.
print(summary_for_checking)

In [None]:
# Print the raw text of the document at position 7 in documents_rule_proposal.
# NOTE(review): presumably the section the checked summary was written from - confirm the index.
section_idx=7
print(documents_rule_proposal[section_idx].text)