<a href="https://colab.research.google.com/github/felixiho/LLMs/blob/main/Transaction_Compliance_Monitoring_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Transactions Compliance Monitoring With Document Injestion

In [None]:
%pip install --quiet datasets pymongo langchain-mongodb langgraph-checkpoint-mongodb langchain-core langchain-huggingface langgraph pypdf python-docx unstructured pydantic voyageai transformers torch accelerator

In [39]:
import getpass
import os

def set_env_variables(variable):
  value = getpass.getpass(f"Enter the value for {variable}:")
  if len(value):
    os.environ[variable] = value

# os.environ["MONGO_DB_URI"]
# set_env_variables("TEST_VARIABLE")


set_env_variables("MONGO_DB_URI")

Enter the value for MONGO_DB_URI:··········


Mongo DB Setup

In [40]:
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

db_uri = os.environ.get("MONGO_DB_URI")
if not db_uri:
  raise ValueError("MONGO_DB_URI environment variable is not set")

client = MongoClient(
    db_uri,
    server_api=ServerApi('1'),
    appname="transaction_monitoring_with_document_injestion"
)

DB_NAME = "transaction_compliance"
TRANSACTIONS = "transactions"
REGULATIONS = "regulations"
CHECKPOINTS = "checkpoints"
CHECKPOINTS_WRITES = "checkpoints_writes"

try:
    client.admin.command('ping')
    print("Pinged your deployment. You successfully connected to MongoDB!")
except Exception as e:
    print(e)

vd:27017: [Errno -2] Name or service not known (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms), Timeout: 30s, Topology Description: <TopologyDescription id: 682353038fc97133b9f004ed, topology_type: Unknown, servers: [<ServerDescription ('vd', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('vd:27017: [Errno -2] Name or service not known (configured timeouts: socketTimeoutMS: 20000.0ms, connectTimeoutMS: 20000.0ms)')>]>


Create MongoDB Collections for:


1.   Transactions
2.   Regulations
3.   Checkpoints
4.   Checkpoints Writes



In [None]:
db = client[DB_NAME]
transaction_collection = db[TRANSACTIONS]
regulations_collection = db[REGULATIONS]
checkpoints_collection = db[CHECKPOINTS]
checkpoints_writes_collection = db[CHECKPOINTS_WRITES]

def create_mongodb_collections():
  existing_collections = db.list_collection_names()

  # Transactions collections
  if TRANSACTIONS not in existing_collections:
    db.create_collection(
        TRANSACTIONS,
        validator={
            "$jsonSchema": {
                "bsonType": "object",
                "required": [
                    'transaction_id',
                    'amount',
                    'currency',
                    'sender',
                    'receiver',
                    'transaction_date'
                ],
                "properties": {
                    "transaction_id": {"bsonType": "string"},
                    "amount": {"bsonType": "double", "minimum": 0},
                    "currency": {"bsonType": "string"},
                    "sender": {"bsonType": "string"},
                    "receiver": {"bsonType": "string"},
                    "compliance_status": {"bsonType": "string"}
                }
            }
        },
        validationLevel="moderate",
    )
    print(f"Collection {TRANSACTIONS} created successfully")
  else:
    print(f"Collection {TRANSACTIONS} already exists")


  # Regulations collections
  if REGULATIONS not in existing_collections:
    db.create_collection(REGULATIONS)
    print(f"Collection {REGULATIONS} created successfully")
  else:
    print(f"Collection {REGULATIONS} already exists")


  # Checkpoints collections
  if CHECKPOINTS not in existing_collections:
    db.create_collection(CHECKPOINTS)
    print(f"Collection {CHECKPOINTS} created successfully")
  else:
    print(f"Collection {CHECKPOINTS} already exists")

  # Checkpoint Writes collection
  if CHECKPOINTS_WRITES not in existing_collections:
    db.create_collection(CHECKPOINTS_WRITES)
    print(f"Collection {CHECKPOINTS_WRITES} created successfully")



create_mongodb_collections()

Vector Search Index Creation

In other to enable semantic search with embedded documents vectors, we need to
enable a vector search index on the regulations collection.

We're using cosine similarity because we're doing semantic search and document classification problems.

see this for https://www.pinecone.io/learn/vector-similarity/


We also poll for readyness after creating the search index as attempting search on unready index causes error

In [None]:
import time
from pymongo.operations import SearchIndexModel

VECTOR_INDEX_NAME = "regulation_vector_index"

def create_vector_search_index():
  try:
    existing_regulation_indexes = regulations_collection.list_search_indexes()
    for index in existing_regulation_indexes:
      if index["name"] == VECTOR_INDEX_NAME:
        print(f"Vector search index {VECTOR_INDEX_NAME} already exists")
        return
  except Exception as e:
    print(f"Error getting indexes: {e}")
    return

  #create new vector search index
  search_index_model = SearchIndexModel(
      definition={
          "fields": [
              {
                  "type": "vector",
                  "path": "embedding",
                  "similarity": "cosine",
                  "numDimensions": 1024
              }
          ]
      },
      name=VECTOR_INDEX_NAME,
      type="vectorSearch"
  )

  try:
    new_search_index = regulations_collection.create_search_index(search_index_model)
    print(f"Vector search index {VECTOR_INDEX_NAME} created successfully and is building")
  except Exception as e:
    print(f"Error creating vector search index: {e}")

  #wait for sync
  print("Polling to check if vector search index is ready")
  predicate = lambda i: i.get("queryable") is True

  while True:
    try:
      indexes = list(regulations_collection.list_search_indexes(new_search_index))
      if indexes:
        if predicate(indexes[0]):
          break
      time.sleep(5)
    except Exception as e:
      print(f"Error polling for vector search index: {e}")
      break

  print(f"{new_search_index} is ready for querying.")

create_vector_search_index()

Document Processing and Schema Definition



In [None]:
import io
import re
from datetime import datetime

from docx import Document
from pydantic import BaseModel, Field
from pypdf import PdfReader
from typing import Any, Dict, List, Optional, Union

class RegulationDocument(BaseModel):
  # Schema for regulatory Documents

  id: Optional[str] = None
  title: str
  content: str
  source: str
  document_type: str
  jurisdiction: str
  publication_date: str
  tags: List[str] = Field(default_factory=list)
  embedding: Optional[List[float]] = None
  chunks: Optional[List[Dict[str, Any]]] = None

  def to_dict(self):
    return self.model_dump(exclude_none=True)



class DocumentProcessor:
  """Processes different document formats and extracts text"""

  @staticmethod
  def extract_text_from_pdf(file_path_or_bytes):
    if isinstance(file_path_or_bytes, str):
      reader = PdfReader(file_path_or_bytes)
    else:
      reader = PdfReader(io.BytesIO(file_path_or_bytes))

    text = ""
    for page in reader.pages:
      text += page.extract_text() + "\n"
    return text

  @staticmethod
  def extract_text_from_docx(file_path_or_bytes):
    if isinstance(file_path_or_bytes, str):
      document = Document(file_path_or_bytes)
    else:
      document = Document(io.BytesIO(file_path_or_bytes))

    text = ""
    for paragraph in document.paragraphs:
      text += paragraph + "\n"
    return text

  @staticmethod
  def extract_text_from_txt(file_path_or_bytes):
    if isinstance(file_path_or_bytes, str):
      with open(file_path_or_bytes, encoding="utf-8") as f:
        return f.read()
    else:
      return file_path_or_bytes.decode("utf-8")


  @staticmethod
  def process_document(file_path, metadata=None):
    if metadata is None:
      metadata = {}

    file_extension = file_path.split('.')[-1].lower()

    if file_extension == "pdf":
      text = DocumentProcessor.extract_text_from_pdf(file_path)
      doc_type = "pdf"
    elif file_extension == "docx":
      text = DocumentProcessor.extract_text_from_docx(file_path)
      doc_type  = "docx"
    elif file_extension == "txt":
      text = DocumentProcessor.extract_text_from_txt(file_path)
      doc_type = "txt"
    else:
      raise ValueError(f"Unsupported file format: {file_extension}")

    if "title" not in metadata:
      title = os.path.basename(file_path).rsplit('.', 1)[0]
      metadata["title"] = title

    if "document_type" not in metadata:
      metadata["document_type"] = doc_type

    regulation = RegulationDocument(
        title=metadata.get("title", ""),
        content=text,
        source=metadata.get("source", file_path),
        document_type=metadata.get("document_type", doc_type),
        jurisdiction=metadata.get("jurisdiction", "Unknown"),
        publication_date=metadata.get(
            "publication_date", datetime.now().strftime("%Y-%m-%d")
        ),
        tags=metadata.get("tags", []),
    )

    return regulation

  @staticmethod
  def extract_metadata_from_content(content):
    metadata = {}

    # Extract jurisdiction
    jurisdiction_pattern = r"(?i)jurisdiction[:\s]+(\w+(?:\s+\w+)*)"
    jurisdiction_match = re.search(jurisdiction_pattern, content)
    if jurisdiction_match:
        metadata["jurisdiction"] = jurisdiction_match.group(1).strip()

    # Extract date
    date_pattern = r"(?i)(?:date|published)[:\s]+(\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{4}[/-]\d{1,2}[/-]\d{1,2})"
    date_match = re.search(date_pattern, content)
    if date_match:
        metadata["publication_date"] = date_match.group(1).strip()

    # Extract tags
    tags_pattern = r"(?i)(?:keywords|tags)[:\s]+([\w\s,]+)"
    tags_match = re.search(tags_pattern, content)
    if tags_match:
        tags = [tag.strip() for tag in tags_match.group(1).split(",")]
        metadata["tags"] = tags

    return metadata


Text Processing, Embedding Generation and Storage

A few notes:

1. We implement both chunk level and document level embeddings.
This is because it helps with hierarchical retreivel. This means we have a 2 staged retrivel process where we first find the relevant document before finding the relevant chunks within those documents

2. We use langchains RecursiveCharacterTextSplitter to preserve semantics

3. Proper rate limiting by tracking time between subsequent calls

In [None]:
set_env_variables("VOYAGE_API_KEY")

In [None]:
import voyageai
from langchain_text_splitters import RecursiveCharacterTextSplitter


class TextProcessor:

  last_voyage_call = 0
  _instance = None

  # use singleton pattern to only have one instance throughout the app's lifetime
  def __new__(cls, *args, **kwargs):
    if cls._instance is None:
      cls._instance = super().__new__(cls)
      cls._instance._initialized = False
    return cls._instance

  def __init__(self, chunk_size=1000, chunk_overlap=200):
    if not hasattr(self, '_initialized') or not self._initialized:
      self.chunk_size = chunk_size
      self.chunk_overlap = chunk_overlap
      self.text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
            )
      self.voyage_client = voyageai.Client(api_key=os.environ["VOYAGE_API_KEY"])
      self.model_name = "voyage-3"
      self._initialized = True

  def chunk_text(self, text):
    print(text)
    return self.text_splitter.split_text(text)

  def generate_embeddings(self, texts: List[str]):
    if not texts:
      return []

    current_time = time.time()
    time_since_last_call = current_time - self.last_voyage_call

    if time_since_last_call < 20:
      wait_time = 20 - time_since_last_call
      print(
          f"Due to rate limiting, we're waiting {wait_time} seconds"
      )
      time.sleep(wait_time)
    embeddings = self.voyage_client.embed(texts, model=self.model_name).embeddings

    self.last_voyage_call = time.time()
    return embeddings

  def process_document(self, regulation_doc):
    chunks = self.chunk_text(regulation_doc.content)
    chunk_embeddings = self.generate_embeddings(chunks)
    processed_chunks = []

    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
      processed_chunks.append({
          "chunk_id": i,
          "content": chunk,
          "embedding": embedding
      })

    doc_text = f"{regulation_doc.title}\n{chunks[0] if chunks else ''}"
    doc_embedding = self.generate_embeddings(doc_text)[0]

    regulation_doc.embedding = doc_embedding
    regulation_doc.chunks = processed_chunks

    return regulation_doc

  def store_regulation(self, regulation_doc):
    regulation_dict = regulation_doc.to_dict()

    result = regulations_collection.insert_one(regulation_dict)
    print(f"Stored regulation document with ID: {result.inserted_id}")

    return result.inserted_id


In [None]:
# Sample regulatory texts
sample_regulations = [
    {
        "title": "Anti-Money Laundering Directive",
        "content": """ANTI-MONEY LAUNDERING DIRECTIVE
Jurisdiction: European Union
Date: 2021-06-15
Keywords: AML, KYC, financial crime, cross-border

Section 1: Scope and Definitions
1.1 This directive applies to all financial institutions operating within the European Union that process cross-border transactions.
1.2 'Cross-border transaction' refers to any financial transfer that originates in one country and terminates in another.
1.3 'High-risk jurisdiction' refers to countries identified by the Financial Action Task Force (FATF) as having strategic deficiencies in their AML/CFT regimes.

Section 2: Due Diligence Requirements
2.1 Enhanced due diligence must be performed for all transactions exceeding €10,000 that involve high-risk jurisdictions.
2.2 Financial institutions must verify the identity of both the sender and recipient for all cross-border transactions exceeding €3,000.
2.3 For transactions with sanctioned countries, prior approval must be obtained from the compliance department.

Section 3: Reporting Requirements
3.1 All suspicious transactions must be reported to the national Financial Intelligence Unit within 24 hours of detection.
3.2 Monthly reports must be submitted detailing all cross-border transactions exceeding €50,000.
3.3 Failure to report suspicious activities may result in fines of up to €5 million or 10% of annual turnover.
""",
        "source": "EU Financial Regulatory Authority",
        "document_type": "directive",
        "jurisdiction": "European Union",
        "publication_date": "2021-06-15",
        "tags": ["AML", "KYC", "financial crime", "cross-border"],
    },
    {
        "title": "Sanctions Compliance Framework",
        "content": """SANCTIONS COMPLIANCE FRAMEWORK
Jurisdiction: United States
Date: 2022-03-10
Keywords: sanctions, OFAC, restricted parties, compliance

Section 1: Overview
1.1 This framework outlines compliance requirements for financial institutions regarding transactions subject to sanctions administered by the Office of Foreign Assets Control (OFAC).
1.2 All US financial institutions and their foreign branches must comply with these requirements.

Section 2: Prohibited Transactions
2.1 No financial institution shall process transactions involving entities listed on the Specially Designated Nationals (SDN) list.
2.2 Transactions with entities in comprehensively sanctioned countries including Iran, North Korea, Syria, Cuba, and the Crimea region are prohibited without specific OFAC authorization.
2.3 Transactions that attempt to circumvent sanctions through third-party intermediaries are strictly prohibited and subject to severe penalties.

Section 3: Screening Requirements
3.1 All parties to a transaction must be screened against the most current OFAC sanctions lists prior to processing.
3.2 Screening must include beneficial owners with 25% or greater ownership interest.
3.3 Institutions must implement real-time screening for all international wire transfers regardless of amount.

Section 4: Penalties for Non-Compliance
4.1 Civil penalties may reach the greater of $1,000,000 per violation or twice the value of the transaction.
4.2 Criminal penalties for willful violations may include fines up to $20 million and imprisonment up to 30 years.
4.3 Financial institutions may be subject to regulatory actions including restrictions on activities or loss of licenses.
""",
        "source": "US Department of Treasury",
        "document_type": "framework",
        "jurisdiction": "United States",
        "publication_date": "2022-03-10",
        "tags": ["sanctions", "OFAC", "restricted parties", "compliance"],
    },
]

In [None]:
text_processor = TextProcessor()

for reg_data in sample_regulations:
  regulation = RegulationDocument(**reg_data)

  processed_regulation = text_processor.process_document(regulation)
  regulation_id = text_processor.store_regulation(processed_regulation)

  print(f"Processed and stored regulation: {regulation.title}")

Transaction Data Models and Compliance Status



In [None]:
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel,  field_validator

class ComplianceStaus(str, Enum):
  COMPLIANT = "Compliant"
  REPORTING_REQUIRED = "Reporting Required"
  VIOLATION = "Violation"
  PENDING = "Pending Assessment"


class TransactionParty(BaseModel):
  name: str
  country: str
  account_number: str
  institution: str
  is_sanctioned: bool = False
  risk_score: Optional[float] = None

class Transaction(BaseModel):
  id: Optional[str] = None
  transaction_id: str
  amount: float
  currency: str
  sender: TransactionParty
  receiver: TransactionParty
  transaction_date: str
  transaction_type: str
  description: str
  compliance_status: ComplianceStaus = ComplianceStaus.PENDING
  compliance_details: Optional[Dict[str, Any]] = None

  @field_validator("amount")
  def amount_must_be_positive(cls, v):
    if v <= 0:
      raise ValueError("Amount must be positive")
    return v

  def to_dict(self):
    return self.model_dump(exclude_none=True)

  def to_prompt(self):
    return f"""Transaction Details:
      - Transaction ID: {self.transaction_id}
      - Amount: {self.amount} {self.currency}
      - Date: {self.transaction_date}
      - Type: {self.transaction_type}
      - Description: {self.description}

      Sender Information:
      - Name: {self.sender.name}
      - Country: {self.sender.country}
      - Institution: {self.sender.institution}
      - Sanctioned: {self.sender.is_sanctioned}

      Receiver Information:
      - Name: {self.receiver.name}
      - Country: {self.receiver.country}
      - Institution: {self.receiver.institution}
      - Sanctioned: {self.receiver.is_sanctioned}
    """



Compliance monitoring engine and workflow orchestrator

We use the following:

1. HF ShiledGemma: For classification
2. MongoDb Atlas for vector search, checkpoints and transaction storage
3. Voyage AI for generating text embeddings
4. LangGraph for stateful  workflow orchestration


Compliance Engine:
1. Retrieves relevant regulation associated with transaction via vector ssearch
2. checks compliance assesment using LLM and returns JSON
3. Normalize confidence score using an activation fn -> softmax
4. updates transaction records with compliance status and details


Compliance Workflow:
1. Defines a LangGraph based workflow for processing transactions
2. Includes checkpointing, error handling and conditional retries


In [None]:
set_env_variables("HUGGINGFACE_API_KEY")

In [None]:
import numpy as np
import torch
from huggingface_hub import login
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFacePipeline
from torch.nn.functional import softmax
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

class ComplianceEngine:

  MODEL = "google/shieldgemma-2b"

  def __init__(self):
    login(token=os.environ.get("HUGGINGFACE_API_KEY"))
    self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        self.MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    text_generation_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer = self.tokenizer,
        max_new_tokens=1024,
        do_sample=False,
        pad_token_id=self.tokenizer.eos_token_id #https://github.com/huggingface/transformers/issues/34869
    )

    self.llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

    self.text_processor = TextProcessor()

    self.assessment_prompt = ChatPromptTemplate.from_template(
        """You are a financial compliance expert with extensive knowledge of regulatory frameworks. Your task is to evaluate whether the following transaction complies with the specified regulations.

    Transaction Details:
    {transaction}

    Relevant Regulations:
    {regulations}

    Compliance Assessment Framework:
    - Compliant: Transaction fully adheres to all applicable regulations with no reporting requirements
    - Reporting Required: Transaction is legal but requires mandatory reporting to regulatory authorities
    - Violation: Transaction directly contravenes one or more regulatory requirements

    Step-by-step Analysis Process:
    1. Identify the transaction type and key participants
    2. Determine which specific regulations apply to this transaction
    3. Assess compliance with each applicable regulation
    4. Evaluate if reporting requirements exist
    5. Determine final compliance status

    Provide your assessment in the following JSON format:
    {{
        "status": "Compliant" | "Reporting Required" | "Violation",
        "confidence": <float between 0 and 1>,
        "reasoning": "<concise explanation with specific regulatory references>",
        "applicable_regulations": ["<specific regulation sections that apply>"],
        "recommended_actions": ["<actionable steps for compliance>"],
        "risk_factors": ["<key risk elements identified>"]
    }}

    Return ONLY the JSON object. No additional text, explanations, or formatting. YOU WILL BE PENALIZED IF YOU RETURN ANYTHING OTHER THAN THE JSON.
    """
    )

    self.parser = JsonOutputParser()

    # creates chain
    self.chain = self.assessment_prompt | self.llm | self.parser


  def retrieve_relevant_regulations(self, transaction: Transaction):
    transaction_text = transaction.to_prompt()
    transaction_embedding = self.text_processor.generate_embeddings(
        [transaction_text]
    )[0]

    vector_search_stage = {
        "$vectorSearch": {
            "index": VECTOR_INDEX_NAME,
            "queryVector": transaction_embedding,
            "path": "embedding",
            "numCandidates": 150,
            "limit": 5
        }
    }

    # Remove embedding, _id and chunks from result. As they're not neccessary at this point
    project_stage = {
        "$project":{
            "embedding": 0,
            "chunks": 0,
            "_id": 0
        }
    }

    pipeline = [vector_search_stage, project_stage]
    results = list(regulations_collection.aggregate(pipeline))

    regulations_text = ""
    for i, reg in enumerate(results, 1):
      regulations_text += f"Regulation {i}: {reg['title']} ({reg['jurisdiction']}, {reg['publication_date']})\n"
      regulations_text += f"{reg['content']}\n\n"

    return regulations_text


  def apply_softmax_normalization(self, assessment):
    status_scores = {"Compliant": 0.0, "Reporting Required": 0.0, "Violation": 0.0}

    status_scores[assessment["status"]] = assessment["confidence"]

    scores_array = np.array(list(status_scores.values()))
    normalized_scores = softmax(torch.tensor(scores_array), dim=0).numpy()

    assessment["confidence"] = float(
            normalized_scores[list(status_scores.keys()).index(assessment["status"])]
    )

    assessment["confidence_details"] = {
        status: float(score)
        for status, score in zip(status_scores.keys(), normalized_scores)
    }

    return assessment

  def assess_transaction(self, transaction: Transaction):
    try:
      regulations = self.retrieve_relevant_regulations(transaction)
      print("-" * 80)
      print(f"Retrieved {len(regulations)} relevant regulations.")
      print("Here are the first 100 characters of the first regulation:")
      print(regulations.split("\n")[0][:100])

      inputs = {
          "transaction": transaction.to_prompt(),
          "regulations": regulations
      }
      print("before invoking chain")
      assessment = self.chain.invoke(inputs)
      print("after invoking chain")

      assessment = self.apply_softmax_normalization(assessment)
      print("after invoking apply_softmax_normalization")

      transaction.compliance_status = ComplianceStaus(assessment["status"])
      transaction.compliance_details = assessment

      if transaction.id:
        transaction_collection.update_one(
            {"_id": transaction.id}, {"$set": transaction.to_dict()}
        )
        print(f"Update transaction with ID: {transaction.id}")
      else:
        result = transaction_collection.insert_one(transaction.to_dict())
        transaction.id = str(result.inserted_id)
        print(f"Stored transaction with ID: {transaction.id}")

      return assessment

    except Exception as e:
      print(f"Error during Langchain processing: {e}")
      assessment = {
          "status": "Reporting Required",
          "confidence": 0.5,
          "reasoning": f"Error during assessment: {e!s}. Please review the transaction manually.",
          "applicable_regulations": [],
          "recommended_actions": ["Review transaction manually"],
          "risk_factors": ["Assessment processing failure"],
      }
      return assessment





IMplement Agent Orchestration with Langraph to coordinate the compliance assessment workflow

In [None]:
from typing import Any, Dict, List, Optional, TypedDict

import langgraph.graph as lg
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.mongodb import MongoDBSaver


# Graphs state
class ComplianceState(TypedDict):
  transaction: Dict[str, Any]
  regulations: Optional[List[Dict[str, Any]]]
  assessment: Optional[Dict[str, Any]]
  messages: List[Union[HumanMessage, AIMessage]]
  errors: Optional[List[str]]

class ComplianceWorkflow:

  def __init__(self):
    self.compliance_engine = ComplianceEngine()
    self.text_processor = TextProcessor()

    self.checkpoint_store = MongoDBSaver(client, DB_NAME, CHECKPOINTS)

    self.workflow = self._build_graph()

  def _parse_transaction(self, state:ComplianceState) -> ComplianceState:
    try:
      transaction_data = state["transaction"]
      transaction = Transaction(**transaction_data)

      state["transaction"] = transaction.to_dict()
      state["messages"].append(
          AIMessage(content=f"Transaction {transaction.transaction_id} parsed successfully.")
      )
    except Exception as e:
      error_msg = f"Error parsing transaction: {e!s}"
      state["errors"] = state.get("errors", []) + [error_msg]
      state["messages"].append(AIMessage(content=error_msg))

    return state

  def _retrieve_regulations(self, state: ComplianceState) -> ComplianceState:
    try:
      transaction_data = state["transaction"]
      transaction = Transaction(**transaction_data)

      regulations_text = self.compliance_engine.retrieve_relevant_regulations(
          transaction
      )

      state["regulations"] = regulations_text
      state["messages"].append(
          AIMessage(content="Retrieved relevant regulations for compliance assessment")
      )

    except Exception as e:
      error_msg = f"Error retrieving regulations: {e!s}"
      state["errors"] = state.get("errors", []) + [error_msg]
      state["messages"].append(AIMessage(content=error_msg))

    return state

  def _assess_compliance(self, state: ComplianceState) -> ComplianceState:
    try:
      transaction_data = state["transaction"]
      transaction = Transaction(**transaction_data)

      assessment = self.compliance_engine.assess_transaction(transaction)

      state["assessment"] = assessment
      state["transaction"] = (
          transaction.to_dict()
      )
      summary = f"Compliance assessment complete. Status: {assessment['status']} (Confidence: {assessment['confidence']:.2f})\n"
      summary += f"Reasoning: {assessment['reasoning']} \n"
      if assessment.get("recommended_actions"):
        summary += f"Recommended actions: {', '.join(assessment['recommended_actions'])}\n"

      state["messages"].append(AIMessage(content=summary))
    except Exception as e:
        error_msg = f"Error assessing compliance: {e!s}"
        state["errors"] = state.get("errors", []) + [error_msg]
        state["messages"].append(AIMessage(content=error_msg))

    return state

  def _should_retry(self, state: ComplianceState) -> str:
    if state.get("errors") and len(state["errors"]) < 3:
        return "retry"
    return "end"

  def _build_graph(self):

    builder = lg.StateGraph(ComplianceState)

    # nodes
    builder.add_node("parse_transaction", self._parse_transaction)
    builder.add_node("retrieve_regulations", self._retrieve_regulations)
    builder.add_node("assess_compliance", self._assess_compliance)

    # edges
    builder.add_edge("parse_transaction", "retrieve_regulations")
    builder.add_edge("retrieve_regulations", "assess_compliance")

    # conditional edge for error handling
    builder.add_conditional_edges(
        "assess_compliance",
        self._should_retry,
        {"retry": "parse_transaction", "end": lg.END}
    )

    builder.set_entry_point("parse_transaction")

    return builder.compile(checkpointer=self.checkpoint_store)

  def process_transaction(self, transaction_data: Dict[str, Any]) -> Dict[str, Any]:
    initial_state = ComplianceState(
        transaction=transaction_data,
        regulations=None,
        assessment=None,
        messages=[
            HumanMessage(
                content=f"Process transaction {transaction_data.get('transaction_id', 'unknown')}"
            )
        ],
        errors=None,
    )

    config = {"configurable": {"thread_id": "2"}}
    final_state = self.workflow.invoke(initial_state, config)

    return final_state


In [None]:
# Sample transactions for demonstration
sample_transactions = [
    {
        "transaction_id": "TX123456789",
        "amount": 150000.00,
        "currency": "EUR",
        "sender": {
            "name": "European Trading Ltd",
            "country": "Germany",
            "account_number": "DE89370400440532013000",
            "institution": "Deutsche Bank",
            "is_sanctioned": False,
        },
        "receiver": {
            "name": "Global Imports Inc",
            "country": "United States",
            "account_number": "US12345678901234567890",
            "institution": "Bank of America",
            "is_sanctioned": False,
        },
        "transaction_date": "2023-11-15",
        "transaction_type": "International Wire Transfer",
        "description": "Payment for machinery parts",
    },
    {
        "transaction_id": "TX987654321",
        "amount": 75000.00,
        "currency": "USD",
        "sender": {
            "name": "American Exports LLC",
            "country": "United States",
            "account_number": "US98765432109876543210",
            "institution": "JP Morgan Chase",
            "is_sanctioned": False,
        },
        "receiver": {
            "name": "Tehran Trading Co",
            "country": "Iran",
            "account_number": "IR123456789012345678901234",
            "institution": "Bank Melli Iran",
            "is_sanctioned": True,
        },
        "transaction_date": "2023-12-01",
        "transaction_type": "International Wire Transfer",
        "description": "Consulting services",
    },
]

In [36]:
workflow = ComplianceWorkflow()

results = []
for tx_data in sample_transactions:
    print(f"\nProcessing transaction {tx_data['transaction_id']}...")
    result = workflow.process_transaction(tx_data)
    results.append(result)

    for message in result["messages"]:
        if isinstance(message, AIMessage):
            print(f"System: {message.content}")

    if result.get("assessment"):
        assessment = result["assessment"]
        print(f"\nFinal Assessment for {tx_data['transaction_id']}:")
        print(f"Status: {assessment['status']}")
        print(f"Confidence: {assessment['confidence']:.2f}")
        print(f"Reasoning: {assessment['reasoning']}")
        if assessment.get("risk_factors"):
            print(f"Risk Factors: {', '.join(assessment['risk_factors'])}")
        if assessment.get("applicable_regulations"):
            print(
                f"Applicable Regulations: {', '.join(assessment['applicable_regulations'])}"
            )
        if assessment.get("recommended_actions"):
            print(
                f"Recommended Actions: {', '.join(assessment['recommended_actions'])}"
            )
        print("-" * 80)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Device set to use cpu



Processing transaction TX123456789...
Due to rate limiting, we're waiting 19.31210160255432 seconds
--------------------------------------------------------------------------------
Retrieved 7789 relevant regulations.
Here are the first 100 characters of the first regulation:
Regulation 1: Anti-Money Laundering Directive (European Union, 2021-06-15)
before invoking chain
after invoking chain
after invoking apply_softmax_normalization
Error during Langchain processing: Document failed validation, full error: {'index': 0, 'code': 121, 'errmsg': 'Document failed validation', 'errInfo': {'failingDocumentId': ObjectId('682350ae8fc97133b9f004ea'), 'details': {'operatorName': '$jsonSchema', 'schemaRulesNotSatisfied': [{'operatorName': 'properties', 'propertiesNotSatisfied': [{'propertyName': 'sender', 'details': [{'operatorName': 'bsonType', 'specifiedAs': {'bsonType': 'string'}, 'reason': 'type did not match', 'consideredValue': {'name': 'European Trading Ltd', 'country': 'Germany', 'accoun

KeyboardInterrupt: 