In [1]:
!pip install -q -r requirements.txt

In [2]:
import openai
import re
import os
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


In [3]:
from lib.openai_client import OpenAIClient

In [4]:
# Cache folder for XML generated from contracts; created if it does not exist yet
xml_cache_dir = './cache/xml'
os.makedirs(xml_cache_dir, exist_ok=True)

# Cache folder for contracts reconstructed from XML; created if it does not exist yet
contract_cache_dir = './cache/contracts_xml'
os.makedirs(contract_cache_dir, exist_ok=True)



In [5]:
# Load the prompt templates.
# The encoding is pinned to UTF-8 on every open() so the text is decoded the same
# way on all platforms instead of depending on the locale's default encoding.
with open('./prompts/generate_xml_from_contract.txt', 'r', encoding='utf-8') as file:
    xml_generation_prompt = file.read()

with open('./prompts/generate_contract_from_xml.txt', 'r', encoding='utf-8') as file:
    contract_generation_prompt = file.read()

# Load the XSD schema
with open('./schema/contract_schema.xsd', 'r', encoding='utf-8') as file:
    xsd_schema = file.read()

In [6]:
def load_contracts(contract_dir='./contracts'):
    """List the contract text files found in a directory.

    Args:
        contract_dir: Directory to scan. Defaults to './contracts'.

    Returns:
        A sorted list of the names (not full paths) of the regular files in
        ``contract_dir`` whose name ends with '.txt'.

    Raises:
        OSError: If ``contract_dir`` cannot be listed (for example, it does not exist).
    """
    return sorted(
        name
        for name in os.listdir(contract_dir)
        # Skip directories that merely happen to be named '*.txt'
        if name.endswith('.txt') and os.path.isfile(os.path.join(contract_dir, name))
    )



In [7]:
def cleanup_contract(contract_text):
    """Strip model scaffolding from a generated contract and normalise its spacing.

    Args:
        contract_text: Raw contract text to clean.

    Returns:
        The text with ``<scratchpad>`` sections, markup tags and lines consisting
        only of '---' removed, runs of blank lines collapsed to a single blank
        line, and leading/trailing whitespace stripped.
    """
    # Remove the <scratchpad> section together with its (possibly multi-line) content
    contract_text = re.sub(r'<scratchpad>.*?</scratchpad>', '', contract_text, flags=re.DOTALL)

    # Remove markup only: comments, processing instructions, and tags whose name
    # starts with a letter or underscore. A bare '<.*?>' would also delete ordinary
    # prose that happens to sit between '<' and '>', e.g. "x < 5 and y > 3".
    contract_text = re.sub(r'<!--.*?-->|<\?.*?\?>|</?[A-Za-z_][^<>\n]*>', '', contract_text)

    # Remove any lines that only contain '---'
    contract_text = re.sub(r'^\s*---\s*$', '', contract_text, flags=re.MULTILINE)

    # Collapse runs of blank (or whitespace-only) lines into a single blank line
    contract_text = re.sub(r'\n\s*\n', '\n\n', contract_text)

    # Strip leading and trailing whitespace
    return contract_text.strip()

In [8]:
def cleanup_xml(xml_text):
    """Extract the first ``<contract>...</contract>`` element from a block of text.

    The opening tag may carry attributes (for example a namespace declaration),
    so ``<contract xmlns="...">`` is recognised as well as a bare ``<contract>``.

    Args:
        xml_text: Text that is expected to contain a contract element, possibly
            surrounded by other output.

    Returns:
        The text from the opening contract tag through the first closing
        ``</contract>`` tag. If no such element is found, a warning is printed
        and ``xml_text`` is returned unchanged.
    """
    # '(?:\s[^>]*)?' allows attributes without also matching names such as <contracts>
    match = re.search(r'<contract(?:\s[^>]*)?>.*?</contract\s*>', xml_text, re.DOTALL)
    if match:
        return match.group(0).strip()

    # If no <contract> tags are found, return the original text
    print("Warning: No <contract> tags found in the XML. Returning original text.")
    return xml_text

In [9]:
# Load contract files: the sorted .txt filenames (names only, not full paths)
# returned by load_contracts(). generate_data() iterates over this module-level list.
contract_files = load_contracts()

In [10]:
# Single shared client instance (class imported from lib.openai_client). The helper
# functions in this notebook call its reset_context / add_message / get_response /
# get_embedding methods through this module-level name.
openaiclient = OpenAIClient()

In [11]:
def get_cache_filename(base_dir, original_file, extension):
    """Build the cache path that corresponds to a contract file.

    Args:
        base_dir: Directory in which the cache file lives.
        original_file: Name or path of the contract file; only its basename is used.
        extension: Suffix for the cache file, e.g. '.xml' or '_new.xml'.

    Returns:
        ``base_dir`` joined with the contract's basename, where a trailing '.txt'
        (if present) is replaced by ``extension``. Names that do not end in '.txt'
        get ``extension`` appended, so different suffixes never map to the same path.
    """
    base_name = os.path.basename(original_file)
    # Only strip '.txt' when it is the actual suffix; str.replace would also
    # rewrite any '.txt' occurring in the middle of the name.
    stem = base_name[:-len('.txt')] if base_name.endswith('.txt') else base_name
    return os.path.join(base_dir, stem + extension)


In [12]:
def generate_xml_object(contract_text, schema):
    """Ask the model to convert a contract into XML and return the cleaned result.

    Args:
        contract_text: Contract text substituted for ``{{CONTRACT}}`` in the prompt.
        schema: Schema text substituted for ``{{XML_SCHEMA}}`` in the prompt.

    Returns:
        The output of ``cleanup_xml`` applied to the content between the
        ``<xml_output>`` tags of the response, or to the whole stripped response
        when those tags are not both present.
    """
    # NOTE(review): reset_context() presumably clears earlier messages - confirm in lib.openai_client
    openaiclient.reset_context()

    # Fill the template: schema placeholder first, then the contract placeholder
    filled_prompt = xml_generation_prompt.replace("{{XML_SCHEMA}}", schema)
    filled_prompt = filled_prompt.replace("{{CONTRACT}}", contract_text)

    openaiclient.add_message("user", filled_prompt)
    response = openaiclient.get_response()

    # Echo the raw response for debugging
    print("Raw API response:")
    print(response)

    opening, closing = "<xml_output>", "</xml_output>"
    wrapped = opening in response and closing in response
    if not wrapped:
        # Without both wrapper tags, fall back to cleaning the entire response
        print("Warning: <xml_output> tags not found in the response. Returning full response.")
        return cleanup_xml(response.strip())

    # Keep what follows the first opening tag, up to the first closing tag after it
    after_opening = response.split(opening)[1]
    payload = after_opening.split(closing)[0]
    return cleanup_xml(payload.strip())


In [13]:
def generate_contract_from_xml(xml_document, schema):
    """Ask the model to write a contract from an XML document.

    Args:
        xml_document: XML text substituted for ``{{XML_DOCUMENT}}`` in the prompt.
        schema: Schema text substituted for ``{{XML_SCHEMA}}`` in the prompt.

    Returns:
        The model's response with leading and trailing whitespace removed.
    """
    # NOTE(review): reset_context() presumably clears earlier messages - confirm in lib.openai_client
    openaiclient.reset_context()

    # Apply the placeholder substitutions in order: schema first, then the document
    substitutions = (("{{XML_SCHEMA}}", schema), ("{{XML_DOCUMENT}}", xml_document))
    prompt = contract_generation_prompt
    for placeholder, value in substitutions:
        prompt = prompt.replace(placeholder, value)

    openaiclient.add_message("user", prompt)
    return openaiclient.get_response().strip()

In [14]:
def process_contract_to_xml(contract_file, schema):
    """Return the XML for a contract file, generating and caching it on a cache miss.

    Args:
        contract_file: Path of the contract text file. Its basename determines
            the name of the cache file inside ``xml_cache_dir``.
        schema: Schema text passed on to ``generate_xml_object``.

    Returns:
        The XML text (read from the cache when a cache file exists), or None if
        reading the contract, generating the XML or writing the cache raised.
    """
    cache_file = get_cache_filename(xml_cache_dir, contract_file, '.xml')

    # Every open() pins UTF-8 so contracts and cached XML are decoded identically
    # on all platforms instead of following the locale's default encoding.
    if os.path.exists(cache_file):
        print(f"Using cached XML for {contract_file}")
        with open(cache_file, 'r', encoding='utf-8') as f:
            return f.read()

    print(f"Generating new XML for {contract_file}")
    try:
        with open(contract_file, 'r', encoding='utf-8') as f:
            contract_text = f.read()

        xml_object = generate_xml_object(contract_text, schema)

        with open(cache_file, 'w', encoding='utf-8') as f:
            f.write(xml_object)

        return xml_object
    except Exception as e:
        # Deliberate best effort: report the failure and return None so the
        # caller can skip this contract and carry on with the others.
        print(f"Error processing {contract_file}: {str(e)}")
        return None


In [15]:
def generate_new_xml_from_reconstructed(reconstructed_contract, schema, original_contract_file):
    """Return the XML generated from a reconstructed contract, using a file cache.

    Args:
        reconstructed_contract: Contract text to convert to XML on a cache miss.
        schema: Schema text passed on to ``generate_xml_object``.
        original_contract_file: Name of the original contract file; its basename
            with the '_new.xml' suffix names the cache file in ``xml_cache_dir``.

    Returns:
        The XML text, read from the cache when the cache file already exists.
    """
    cache_file = get_cache_filename(xml_cache_dir, original_contract_file, '_new.xml')

    # The cache is read and written as UTF-8 explicitly so its content does not
    # depend on the locale's default encoding.
    if os.path.exists(cache_file):
        print(f"Using cached new XML for {original_contract_file}")
        with open(cache_file, 'r', encoding='utf-8') as f:
            return f.read()

    print(f"Generating new XML from reconstructed contract for {original_contract_file}")
    new_xml_object = generate_xml_object(reconstructed_contract, schema)

    with open(cache_file, 'w', encoding='utf-8') as f:
        f.write(new_xml_object)

    return new_xml_object

In [16]:
def process_xml_to_contract(xml_object, original_contract_file, schema):
    """Return the contract reconstructed from an XML document, using a file cache.

    Args:
        xml_object: XML text to turn back into a contract on a cache miss.
        original_contract_file: Name of the original contract file; its basename
            with the '_reconstructed.txt' suffix names the cache file in
            ``contract_cache_dir``.
        schema: Schema text passed on to ``generate_contract_from_xml``.

    Returns:
        The reconstructed contract text after ``cleanup_contract``, read from the
        cache when the cache file already exists.
    """
    cache_file = get_cache_filename(contract_cache_dir, original_contract_file, '_reconstructed.txt')

    # The cache is read and written as UTF-8 explicitly so its content does not
    # depend on the locale's default encoding.
    if os.path.exists(cache_file):
        print(f"Using cached reconstructed contract for {original_contract_file}")
        with open(cache_file, 'r', encoding='utf-8') as f:
            return f.read()

    print(f"Generating new contract from XML for {original_contract_file}")
    reconstructed_contract = generate_contract_from_xml(xml_object, schema)

    # Clean up the reconstructed contract before caching, so cached and fresh
    # results are the same text
    reconstructed_contract = cleanup_contract(reconstructed_contract)

    with open(cache_file, 'w', encoding='utf-8') as f:
        f.write(reconstructed_contract)

    return reconstructed_contract

In [17]:
def compare_embeddings(original_embeddings, new_embeddings):
    """Compute the cosine similarity of each original/new embedding pair.

    Args:
        original_embeddings: Sequence of embedding vectors.
        new_embeddings: Sequence of embedding vectors, paired with the originals by position.

    Returns:
        A list with one similarity per pair. Pairing uses ``zip``, so any
        surplus entries in the longer sequence are ignored.
    """
    pairs = zip(original_embeddings, new_embeddings)
    # cosine_similarity works on 2-D inputs, so wrap each vector in a list and
    # take the single cell of the resulting 1x1 matrix
    return [cosine_similarity([before], [after])[0][0] for before, after in pairs]

In [18]:
def create_embeddings(xmls):
    """Request an embedding for every XML document.

    Args:
        xmls: Iterable of XML texts.

    Returns:
        A list holding the result of ``openaiclient.get_embedding`` for each
        document, in input order.
    """
    return [openaiclient.get_embedding(document) for document in xmls]

In [19]:
def generate_data():
    """Run the contract -> XML -> contract -> XML round trip for each file in ``contract_files``.

    For every contract the original XML is produced, a contract is reconstructed
    from it, and a second XML is generated from that reconstruction. A short
    preview of each intermediate result is printed along the way.

    Returns:
        A tuple ``(original_xmls, new_xmls)`` of two lists of cleaned XML texts.
        Contracts for which ``process_contract_to_xml`` returned None are
        skipped and contribute to neither list.
    """
    def show_preview(heading, text):
        # Print the heading, the first 50 characters of the text, and a spacer
        print(heading)
        print(text[:50] + "...")
        print("\n")

    contract_dir = './contracts'
    original_xmls, new_xmls = [], []

    for contract_file in contract_files:
        source_path = os.path.join(contract_dir, contract_file)

        # First pass: contract text -> XML (None signals a generation error)
        first_pass_xml = process_contract_to_xml(source_path, xsd_schema)
        if first_pass_xml is None:
            print(f"Skipping processing for {contract_file} due to XML generation error.")
            continue

        first_pass_xml = cleanup_xml(first_pass_xml)
        show_preview(f"Original XML for {contract_file}:", first_pass_xml)
        original_xmls.append(first_pass_xml)

        # Round trip: XML -> reconstructed contract text
        rebuilt_contract = process_xml_to_contract(first_pass_xml, contract_file, xsd_schema)
        show_preview(f"Reconstructed contract for {contract_file}:", rebuilt_contract)

        # Second pass: reconstructed contract -> XML
        second_pass_xml = cleanup_xml(
            generate_new_xml_from_reconstructed(rebuilt_contract, xsd_schema, contract_file)
        )
        show_preview(f"New XML generated from reconstructed contract for {contract_file}:", second_pass_xml)
        new_xmls.append(second_pass_xml)

    return original_xmls, new_xmls

In [20]:
# Generate data
original_xmls, new_xmls = generate_data()

# Create embeddings
print("Creating embeddings for original XMLs...")
original_embeddings = create_embeddings(original_xmls)
print("Creating embeddings for new XMLs...")
new_embeddings = create_embeddings(new_xmls)

# Compare embeddings
print("Comparing embeddings...")
similarities = compare_embeddings(original_embeddings, new_embeddings)

# Print results. The index counts the contracts that made it through
# generate_data(), in order; contracts it skipped are not numbered.
for i, similarity in enumerate(similarities):
    print(f"Similarity for contract {i+1}: {similarity:.4f}")

# np.min / np.max raise ValueError on an empty sequence (and np.mean yields nan),
# so only print the summary when at least one contract was compared.
if similarities:
    print(f"\nAverage similarity: {np.mean(similarities):.4f}")
    print(f"Minimum similarity: {np.min(similarities):.4f}")
    print(f"Maximum similarity: {np.max(similarities):.4f}")
else:
    print("\nNo contracts were processed; no similarity statistics to report.")

Using cached XML for ./contracts/agreement_01.txt
Original XML for agreement_01.txt:
<contract>
    <parties>
        <party>
         ...


Using cached reconstructed contract for agreement_01.txt
Reconstructed contract for agreement_01.txt:
**AMENDED AND RESTATED AGREEMENT**

This AMENDED A...


Using cached new XML for agreement_01.txt
New XML generated from reconstructed contract for agreement_01.txt:
<contract>
  <parties>
    <party>
      <name>Equ...


Using cached XML for ./contracts/employment_01.txt
Original XML for employment_01.txt:
<contract>
  <parties>
    <party>
      <name>Par...


Using cached reconstructed contract for employment_01.txt
Reconstructed contract for employment_01.txt:
**Labor Contract**

This Labor Contract (the "Cont...


Using cached new XML for employment_01.txt
New XML generated from reconstructed contract for employment_01.txt:
<contract>
    <parties>
      <party>
        <na...


Using cached XML for ./contracts/lease_01.txt
Original XML for l