##### Install required packages

In [None]:
!pip install boto3

In [None]:
!pip install torch transformers diffgram neo4j anthropic pandas tqdm
!pip install llama_index

##### Import necessary libraries

In [None]:
import torch
from transformers import BertTokenizerFast, BertForTokenClassification
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from diffgram import Project
from typing import List, Dict, Optional
import anthropic
import json
from neo4j import GraphDatabase
from tqdm import tqdm
import logging
import os
import sys
import boto3
import requests

In [None]:
# A notebook kernel has no __file__, so paths are resolved from the working directory.
current_dir = os.getcwd()

# Local packages (e.g. AgenticWorkflow, imported below) are expected one level up.
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Prepend only once so re-running this cell does not keep growing sys.path.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

In [None]:
from AgenticWorkflow.bedrock_session import get_boto_session

In [None]:
# Obtain a boto3 session from the project helper AgenticWorkflow.bedrock_session
# (how it resolves credentials is not visible in this notebook).
session = get_boto_session()

In [None]:
# Bedrock runtime client used by get_response() below; the region is pinned to us-east-1.
bedrock_runtime = session.client("bedrock-runtime", region_name="us-east-1")

In [None]:
def get_claudia_kwargs(prompt):
    """Build the keyword arguments for ``bedrock_runtime.invoke_model``.

    The request targets Claude 3.5 Sonnet on Bedrock with a single user turn
    containing *prompt* as one text block, capped at 5000 output tokens.

    Args:
        prompt: The user message text.

    Returns:
        A dict with ``modelId``, ``contentType``, ``accept`` and a JSON-encoded
        ``body`` string.
    """
    user_turn = {
        "role": "user",
        "content": [{"type": "text", "text": prompt}],
    }
    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 5000,
        "messages": [user_turn],
    }
    return {
        "modelId": "anthropic.claude-3-5-sonnet-20240620-v1:0",
        "contentType": "application/json",
        "accept": "application/json",
        "body": json.dumps(request_body),
    }

In [None]:
# Smoke-test prompt for the Bedrock round trip below.
prompt = "Does this work?"

In [None]:
# Inspect the request payload; get_response() rebuilds its own kwargs, so this global is only for inspection.
kwargs = get_claudia_kwargs(prompt)

In [None]:
def get_response(prompt):
    """Send *prompt* to Claude via Bedrock and return the text of the first content block.

    Uses the notebook-level ``bedrock_runtime`` client and ``get_claudia_kwargs``.
    """
    request_kwargs = get_claudia_kwargs(prompt)
    raw_response = bedrock_runtime.invoke_model(**request_kwargs)
    # "body" is file-like: read it fully, then decode the JSON payload.
    payload = json.loads(raw_response.get("body").read())
    first_block = payload['content'][0]
    return first_block['text']

In [None]:
# Live call to Bedrock with the smoke-test prompt.
response = get_response(prompt)

In [None]:
# Display the model's reply (the notebook echoes the last expression of a cell).
response

#### Setting up logging

In [None]:
# Set up logging: INFO-level records go to the root handler configured by basicConfig.
logging.basicConfig(level=logging.INFO)
# Module-level logger; in a notebook __name__ is "__main__" (see the INFO:__main__ output lines).
logger = logging.getLogger(__name__)

In [None]:
# Diffgram SDK client. The client id/secret are read from the environment
# instead of being hard-coded: the values previously committed in this cell
# were stored in plain text and should be rotated.
project = Project(
    host=os.environ.get(
        "DIFFGRAM_HOST",
        "https://5293-2604-3d08-4f7f-e8c0-a011-c635-235f-690d.ngrok-free.app",
    ),
    project_string_id=os.environ.get("DIFFGRAM_PROJECT_ID", "translucenttracker"),
    client_id=os.environ["DIFFGRAM_CLIENT_ID"],
    client_secret=os.environ["DIFFGRAM_CLIENT_SECRET"],
)

In [None]:
# Configuration used by the raw REST helpers (e.g. get_schema_list).
# Credentials come from the environment; the values previously hard-coded
# here were committed in plain text and should be rotated.
DIFFGRAM_CONFIG = {
    "host": os.environ.get(
        "DIFFGRAM_HOST",
        "https://5293-2604-3d08-4f7f-e8c0-a011-c635-235f-690d.ngrok-free.app",
    ),
    "project_id": os.environ.get("DIFFGRAM_PROJECT_ID", "translucenttracker"),
    "client_id": os.environ["DIFFGRAM_CLIENT_ID"],
    "client_secret": os.environ["DIFFGRAM_CLIENT_SECRET"],
}

In [None]:
# Training constants.
BATCH_SIZE = 32
MAX_LENGTH = 256
NUM_TRAIN_SAMPLES = 10_000  # how many samples to use for training

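## Load all the files
### Make sure the folder with the data arranged for the NER task is available. The cell below reads `diffgram_processing`; if your data is in `diffgram_processing_v2`, change the path passed to `SimpleDirectoryReader`.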

In [None]:
from llama_index.core import SimpleDirectoryReader, StorageContext

In [None]:
def file_metadata(path):
    """Metadata hook for SimpleDirectoryReader: record each file's path under ``filename``."""
    return {"filename": path}


diffgram_documents = SimpleDirectoryReader(
    "diffgram_processing",
    file_metadata=file_metadata,
).load_data()

In [None]:
# Sanity check: how many documents were loaded from the folder.
print(len(diffgram_documents))

In [None]:
# Peek at one document's text; index 500 assumes at least 501 documents were loaded.
print(diffgram_documents[500].text)

In [None]:
def get_sample_chunks(diffgram_dcouments, sample_size: int = 10) -> List[Dict]:
    """Get sample chunks from the loaded diffgram documents (keyword-matching variant).

    A later cell redefines ``get_sample_chunks`` with regex-based ID detection;
    whichever cell ran last provides the definition in effect.

    Args:
        diffgram_dcouments: Iterable of documents exposing ``.text`` and a
            ``.metadata`` dict. The misspelled name is kept so keyword callers
            keep working.
        sample_size: Maximum number of samples to return.

    Returns:
        Up to ``sample_size`` dicts with ``text``, ``type`` ('Act', 'Regulation'
        or None) and ``filename`` keys. Fewer are returned only when fewer
        documents are available.
    """
    # ``random`` is otherwise only imported in a later cell.
    import random

    # Work on a copy of the argument that was passed in (not a global), so the
    # shuffle never reorders the caller's collection.
    documents = list(diffgram_dcouments)
    # Same logger object as the notebook's module-level ``logger`` (both use __name__).
    log = logging.getLogger(__name__)
    log.info(f"Loaded {len(documents)} documents")

    def doc_type_of(text):
        # Keyword heuristic: 'Act' takes precedence when both words appear.
        if 'Act' in text:
            return 'Act'
        if 'Regulation' in text:
            return 'Regulation'
        return None

    samples = []
    seen_types = set()
    picked = set()  # indices into ``documents`` already present in ``samples``

    def take(idx, doc, doc_type):
        samples.append({
            'text': doc.text,
            'type': doc_type,
            'filename': doc.metadata.get('filename', 'unknown'),
        })
        picked.add(idx)

    random.shuffle(documents)

    # Pass 1: favour document types not seen yet until half the budget is used.
    for idx, doc in enumerate(documents):
        if len(samples) >= sample_size:
            break
        doc_type = doc_type_of(doc.text)
        if doc_type not in seen_types or len(samples) >= sample_size / 2:
            take(idx, doc, doc_type)
            if doc_type:
                seen_types.add(doc_type)

    # Pass 2: with only a couple of possible types, pass 1 can stall before the
    # half-way mark (e.g. every document is an Act); top up with remaining documents.
    for idx, doc in enumerate(documents):
        if len(samples) >= sample_size:
            break
        if idx not in picked:
            take(idx, doc, doc_type_of(doc.text))

    log.info(f"Selected {len(samples)} sample documents")
    return samples

## Schema Generation

In [552]:
import random
import re
from typing import List, Dict

def get_sample_chunks(diffgram_documents, sample_size: int = 10) -> List[Dict]:
    """Get sample chunks from the loaded documents, typing them by ActId/RegId lines.

    A document containing ``RegId:`` is a 'Regulation'; otherwise one containing
    ``ActId:`` is an 'Act'; otherwise its type is None (case-insensitive match).

    Selection happens in two passes over a shuffled copy of the documents:
    1. Add a document if its type has not been sampled yet, or once at least
       half of ``sample_size`` samples have been collected (diversity first).
    2. If that still leaves fewer than ``sample_size`` samples (possible when
       the corpus has only one or two types), top up with unused documents.

    Args:
        diffgram_documents: Sized iterable of documents exposing ``.text`` and
            a ``.metadata`` dict.
        sample_size: Maximum number of samples to return.

    Returns:
        Up to ``sample_size`` dicts with ``text``, ``type`` and ``filename`` keys
        (``filename`` falls back to 'unknown').
    """
    # Same logger object as the notebook's module-level ``logger`` (both use __name__).
    log = logging.getLogger(__name__)
    log.info(f"Loaded {len(diffgram_documents)} documents")

    samples: List[Dict] = []
    seen_types = set()
    picked = set()  # indices into ``shuffled_docs`` already present in ``samples``

    # Regex patterns to extract IDs
    act_pattern = re.compile(r"ActId:\s*([^\n]+)", re.IGNORECASE)
    reg_pattern = re.compile(r"RegId:\s*([^\n]+)", re.IGNORECASE)

    def classify(text):
        # Regulation IDs take precedence over Act IDs when both are present.
        if reg_pattern.search(text):
            return 'Regulation'
        if act_pattern.search(text):
            return 'Act'
        return None

    def take(idx, doc, doc_type):
        samples.append({
            'text': doc.text,
            'type': doc_type,
            'filename': doc.metadata.get('filename', 'unknown'),
        })
        picked.add(idx)

    # Shuffle a copy so the caller's collection keeps its order.
    shuffled_docs = list(diffgram_documents)
    random.shuffle(shuffled_docs)

    # Pass 1: a new type is always taken; any type once half the budget is filled.
    # (The threshold is on collected samples, as intended: comparing the number of
    # *types* seen, at most 2, to sample_size / 2 meant it was never reached.)
    for idx, doc in enumerate(shuffled_docs):
        if len(samples) >= sample_size:
            break
        doc_type = classify(doc.text)
        if doc_type not in seen_types or len(samples) >= sample_size / 2:
            take(idx, doc, doc_type)
            if doc_type:
                seen_types.add(doc_type)

    # Pass 2: top up from documents not taken in pass 1.
    for idx, doc in enumerate(shuffled_docs):
        if len(samples) >= sample_size:
            break
        if idx not in picked:
            take(idx, doc, classify(doc.text))

    log.info(f"Selected {len(samples)} sample documents")
    return samples


In [553]:
import random
import re
from typing import List, Dict

# Cell 4: Schema Generation

def generate_schema_prompt(samples: List[Dict]) -> str:
    """
    Build the LLM prompt that asks for a BIO NER schema plus relationship triplets.

    The prompt instructs the model to answer with one JSON object holding:
      - "schema": BIO entity definitions ('O', 'B-<TYPE>', 'I-<TYPE>'),
      - "triplet_schema": subject/relation/object relationship definitions,
    and shows a template of that format.

    Args:
        samples: Dicts with a 'text' key. The first 500 characters of each are
            appended as a numbered sample, always followed by '...'.

    Returns:
        The complete prompt string.
    """
    instructions = "".join([
        "Analyze the following legal text samples and generate a comprehensive NER tagging schema tailored for legal documents. ",
        "Your output must be a JSON object with two keys: \n",
        "  1. \"schema\": an array of entity definitions using BIO tagging (include 'O' for outside tokens), and \n",
        "  2. \"triplet_schema\": an array of relationship definitions capturing subject–predicate–object triplets (if applicable).\n\n",
        "For the BIO tagging schema, ensure you include definitions such as:\n",
        "  - O: tokens outside any named entity,\n",
        "  - B-<ENTITY_TYPE> and I-<ENTITY_TYPE> for the beginning and continuation of each entity type (e.g., internal section references, external act/regulation references, etc.).\n\n",
        "For the triplet schema, define how relationships between entities are represented, including subject, relation, and object, ",
        "with examples drawn from the text.\n\n",
        "Output the schema in the following format (do not include any extra text):\n\n",
    ])

    format_template = "".join([
        "{\n",
        '  "schema": [\n',
        "    {\n",
        '      "name": "O",\n',
        '      "description": "Outside of any named entity",\n',
        '      "example": "The"\n',
        "    },\n",
        "    {\n",
        '      "name": "B-<ENTITY_TYPE>",\n',
        '      "description": "Description for the beginning of <ENTITY_TYPE>",\n',
        '      "example": "Dynamic example based on sample"\n',
        "    },\n",
        "    ...\n",
        "  ],\n",
        '  "triplet_schema": [\n',
        "    {\n",
        '      "subject": "<subject tag>",\n',
        '      "relation": "<relation type>",\n',
        '      "object": "<object tag>",\n',
        '      "description": "Description of the relationship",\n',
        '      "example": "Example triplet extracted from text"\n',
        "    }\n",
        "  ]\n",
        "}\n\n",
    ])

    sample_sections = "".join(
        f"\nSample {number}:\n{sample['text'][:500]}...\n"
        for number, sample in enumerate(samples, start=1)
    )

    return instructions + format_template + "Now, analyze the following samples:\n" + sample_sections

def validate_schema(schema: Dict) -> Dict:
    """Ensure every required BIO entity label is present in *schema*, adding missing ones.

    Entity definitions are read from and appended to the ``"schema"`` list, which
    is the key requested by ``generate_schema_prompt`` and consumed by the label
    mapping cell. A legacy ``"entity_types"`` list is used instead only when the
    dict has that key and no ``"schema"`` key. If neither exists, ``"schema"`` is
    created.

    Names are compared after ``strip().upper()`` (the same normalization as
    ``dedupe_schema``/``merge_schemas``), so e.g. an existing ``"b-section"``
    satisfies the ``B-SECTION`` requirement.

    Args:
        schema: Schema dict; modified in place.

    Returns:
        The same dict, with a ``B-``/``I-`` entry (``name``, ``description``,
        ``example``, ``category``) appended for every missing required entity.
    """
    required_entities = {
        'references': [
            'INTERNAL_SECTION_REF',
            'EXTERNAL_ACT_REF',
            'EXTERNAL_REGULATION_REF'
        ],
        'structure': [
            'SECTION',
            'SUBSECTION',
            'PARAGRAPH',
            'SUBPARAGRAPH'
        ],
        'concepts': [
            'DEFINITION',
            'LEGAL_TERM',
            'TIME_PERIOD'
        ],
        'relationships': [
            'CROSS_REF',
            'AMENDMENT_REF',
            'PARENT_CHILD_REF'
        ]
    }

    # Pick the list that actually holds the entity definitions.
    if 'entity_types' in schema and 'schema' not in schema:
        list_key = 'entity_types'
    else:
        list_key = 'schema'
    entity_list = schema.setdefault(list_key, [])

    existing_entities = {entity['name'].strip().upper() for entity in entity_list}

    # For each required entity, ensure both BIO variants are present.
    for category, entities in required_entities.items():
        for entity in entities:
            for prefix in ('B-', 'I-'):
                full_entity_name = f"{prefix}{entity}"
                if full_entity_name in existing_entities:
                    continue
                entity_list.append({
                    'name': full_entity_name,
                    'description': f"{'Beginning' if prefix == 'B-' else 'Inside'} token for {entity}",
                    'example': f"Example usage of {entity} from text",
                    'category': category
                })
                existing_entities.add(full_entity_name)

    return schema


In [554]:
# Get sample documents: draw 10 documents to show the model when asking it to propose a schema.
samples = get_sample_chunks(diffgram_documents, sample_size=10)

INFO:__main__:Loaded 81611 documents
INFO:__main__:Selected 10 sample documents


In [555]:
# Build the schema-generation prompt from the sampled documents.
schema_prompt = generate_schema_prompt(samples)

In [556]:
# Generate schema: send the prompt to Claude on Bedrock; the reply is raw text, parsed as JSON in the next cell.
schema_response = get_response(schema_prompt)

In [557]:
# The model is asked for bare JSON but may still wrap it in Markdown fences or
# add a sentence around it, so parse the outermost {...} span when one exists.
# The isinstance guard lets this cell be re-run once schema_response is a dict.
if isinstance(schema_response, str):
    _json_match = re.search(r"\{.*\}", schema_response, re.DOTALL)
    schema_response = json.loads(_json_match.group(0) if _json_match else schema_response)

In [558]:
# Show the parsed schema dict returned by the model.
print(schema_response)

{'schema': [{'name': 'O', 'description': 'Outside of any named entity', 'example': 'The'}, {'name': 'B-ACT', 'description': 'Beginning of an Act reference', 'example': 'Pension Benefits Standards Act'}, {'name': 'I-ACT', 'description': 'Continuation of an Act reference', 'example': 'Benefits Standards Act'}, {'name': 'B-REGULATION', 'description': 'Beginning of a Regulation reference', 'example': 'Pipeline Regulation'}, {'name': 'I-REGULATION', 'description': 'Continuation of a Regulation reference', 'example': 'Consumer Credit Regulation'}, {'name': 'B-SECTION', 'description': 'Beginning of a Section reference', 'example': 'section 41'}, {'name': 'I-SECTION', 'description': 'Continuation of a Section reference', 'example': 'subsection ( 2 )'}, {'name': 'B-LEGAL_TERM', 'description': 'Beginning of a legal term', 'example': 'collective agreement'}, {'name': 'I-LEGAL_TERM', 'description': 'Continuation of a legal term', 'example': 'solemn declaration'}, {'name': 'B-AUTHORITY', 'descripti

In [559]:
# ----------------------------------------------------------------------------
# Deduplicate a schema dict in place: NER labels ("schema") by name and
# triplets ("triplet_schema") by (subject, relation, object), ignoring case
# and surrounding whitespace.
# ----------------------------------------------------------------------------
def dedupe_schema(schema):
    """Remove duplicate NER labels and triplet definitions from *schema*.

    Keys are compared after ``strip().upper()``. When duplicates occur, the
    surviving entry sits at the position of the first occurrence but holds the
    value of the last one. The lists are replaced in place and the same dict is
    returned; missing keys are left absent.
    """
    def _norm(text):
        return text.strip().upper()

    def _collapse(entries, key_of):
        by_key = {}
        for entry in entries:
            by_key[key_of(entry)] = entry
        return list(by_key.values())

    if "schema" in schema:
        schema["schema"] = _collapse(
            schema["schema"],
            lambda lbl: _norm(lbl["name"]),
        )

    if "triplet_schema" in schema:
        schema["triplet_schema"] = _collapse(
            schema["triplet_schema"],
            lambda trip: (_norm(trip["subject"]), _norm(trip["relation"]), _norm(trip["object"])),
        )

    return schema

In [560]:
# JSON file that accumulates the AI-generated schema (NER labels + triplets) across runs;
# the persist step below merges each new schema_response into it.
schema_file = "ner_schema_response.json"
# ----------------------------------------------------------------------------
# Merge a freshly generated schema into the saved one, appending only labels
# and triplet definitions that are not already present.
# ----------------------------------------------------------------------------
def merge_schemas(new_schema, saved_schema):
    """Return a merged copy of *saved_schema* extended with new entries from *new_schema*.

    Labels are matched on ``name`` and triplets on ``(subject, relation, object)``,
    both after ``strip().upper()``. Existing entries win; new ones are appended
    in the order they appear in *new_schema*. Triplets are merged only when
    *new_schema* has a ``"triplet_schema"`` key. The result is deduplicated with
    ``dedupe_schema``.

    Neither argument is modified.
    """
    # Copy the dict and its lists before touching them: dedupe_schema rewrites
    # keys in place, and a plain dict.copy() would share the lists, so the
    # appends below would leak back into the caller's saved_schema.
    merged_schema = dict(saved_schema)
    for key in ("schema", "triplet_schema"):
        if key in merged_schema:
            merged_schema[key] = list(merged_schema[key])
    merged_schema = dedupe_schema(merged_schema)

    def _label_key(label):
        return label["name"].strip().upper()

    def _triplet_key(triplet):
        return (
            triplet["subject"].strip().upper(),
            triplet["relation"].strip().upper(),
            triplet["object"].strip().upper(),
        )

    # -----------------------------
    # Merge the NER labels ("schema" key)
    # -----------------------------
    merged_labels = merged_schema.setdefault("schema", [])
    existing_labels = {_label_key(label) for label in merged_labels}
    for new_label in new_schema.get("schema", []):
        label_key = _label_key(new_label)
        if label_key not in existing_labels:
            merged_labels.append(new_label)
            existing_labels.add(label_key)

    # -----------------------------
    # Merge the triplet schema ("triplet_schema" key)
    # -----------------------------
    if "triplet_schema" in new_schema:
        merged_triplets = merged_schema.setdefault("triplet_schema", [])
        existing_triplets = {_triplet_key(triplet) for triplet in merged_triplets}
        for new_triplet in new_schema.get("triplet_schema", []):
            triplet_key = _triplet_key(new_triplet)
            if triplet_key not in existing_triplets:
                merged_triplets.append(new_triplet)
                existing_triplets.add(triplet_key)

    # Final pass kept as a safety net (e.g. duplicates inside new_schema itself are already skipped above).
    return dedupe_schema(merged_schema)

# Persist the schema: merge this run's schema_response into the saved file (or
# create the file), then continue with the result. schema_file is assigned at
# the top of this cell; schema_response must already hold the parsed schema dict.
if os.path.exists(schema_file):
    with open(schema_file, "r", encoding="utf-8") as infile:
        saved_schema = json.load(infile)
    print("Loaded saved schema from file.")

    # Merge the new schema into the saved one (the result is deduplicated).
    merged_schema = merge_schemas(schema_response, saved_schema)

    with open(schema_file, "w", encoding="utf-8") as outfile:
        json.dump(merged_schema, outfile, indent=2)
    print("Merged and updated schema saved to file.")

    # Continue with the merged result.
    schema_response = merged_schema
else:
    # Deduplicate before the first save as well: duplicate label names would
    # otherwise be written to disk and get two ids in the label2id/id2label
    # mappings built from schema_response['schema'].
    schema_response = dedupe_schema(schema_response)
    with open(schema_file, "w", encoding="utf-8") as outfile:
        json.dump(schema_response, outfile, indent=2)
    print("Saved new schema response to file.")

Loaded saved schema from file.
Merged and updated schema saved to file.


In [561]:
def get_schema_list(id, timeout=30):
    """Fetch the labels attached to a Diffgram label schema via the REST API.

    Uses the notebook-level ``project`` (for SDK session auth) and
    ``DIFFGRAM_CONFIG`` (for host and project id).

    Args:
        id: Diffgram schema id. The name shadows the builtin but is kept for
            backward compatibility with keyword callers.
        timeout: Seconds to wait for the server. Without a timeout, requests
            can block indefinitely on an unresponsive host.

    Returns:
        The decoded JSON body on HTTP 200; otherwise ``None``, after printing
        the status code and response text.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeouts.
    """
    auth = project.session.auth
    url = f"{DIFFGRAM_CONFIG['host']}/api/project/{DIFFGRAM_CONFIG['project_id']}/labels?schema_id={id}"
    # GET request authenticated with the Diffgram SDK session credentials.
    response = requests.get(url, auth=auth, timeout=timeout)
    if response.status_code == 200:
        return response.json()
    print(f"Error: {response.status_code}")
    print(response.text)  # error details for debugging
    return None

In [562]:
def add_schema_label(label_name, existing_label_names):
    """Create *label_name* (upper-cased) in the Diffgram schema unless it already exists.

    Relies on the notebook globals ``project`` and ``schema_id``.
    ``existing_label_names`` is a set of upper-cased names; it is updated in
    place and also returned.
    """
    normalized = label_name.upper()
    if normalized in existing_label_names:
        print(f"Label '{normalized}' already exists in the schema, skipping.")
        return existing_label_names

    print("Adding label:", normalized)
    project.label_new({'name': normalized}, schema_id=schema_id)
    existing_label_names.add(normalized)
    return existing_label_names

In [None]:
# -------------------------------
# STEP 1: Create the label mappings.
# -------------------------------

# Contiguous ids in schema order: label name -> id and id -> label name.
label2id = {label_info['name']: idx for idx, label_info in enumerate(schema_response['schema'])}
id2label = {idx: label_info['name'] for idx, label_info in enumerate(schema_response['schema'])}

# Print the mappings to verify
print("Label to ID Mapping:", label2id)
print("ID to Label Mapping:", id2label)

# -------------------------------
# STEP 2: Save the schema to Diffgram.
# -------------------------------

# Name of the NER label schema in Diffgram.
NER_schema_name = 'ENTITY_TRAINING_SCHEMA'
schema_id = None

# List the existing schemas in the Diffgram project.
schemas = project.schema.list()
print("Existing Schemas in Diffgram:")
print(json.dumps(schemas, indent=2))

# Reuse the schema called NER_schema_name if it already exists.
for schema in schemas:
    if schema.get('name') == NER_schema_name:
        schema_id = schema.get('id')
        break

# Otherwise create it.
if schema_id is None:
    print(f"Schema '{NER_schema_name}' not found. Creating a new one...")
    json_response = project.new_schema(name=NER_schema_name)
    schema_id = json_response.get("id")
    print(f"Created new schema with id: {schema_id}")
else:
    print(f"Schema '{NER_schema_name}' already exists with id: {schema_id}")

# get_schema_list returns None when the request does not succeed.
schema_labels = get_schema_list(schema_id)

# (id, name) pairs of the labels already attached to the schema.
schema_label_id_value = []
if schema_labels is not None:
    labels = schema_labels['labels_out']
    schema_label_id_value = [
        {'id': label['id'], 'name': label['label']['name']} for label in labels
    ]

# Existing names, upper-cased: both the NER check below and add_schema_label()
# compare upper-cased names, so a mixed-case label already in Diffgram would
# otherwise never match and would be created again on every run.
existing_label_names = {
    entry['name'].upper() for entry in schema_label_id_value if entry.get('name')
}
if existing_label_names:
    print(existing_label_names)
else:
    print("There are no schema labels")

# Add each NER label from the generated schema that is not in Diffgram yet.
for label_def in schema_response['schema']:
    label_name = label_def['name'].upper()
    if label_name not in existing_label_names:
        print("Adding label:", label_name)
        project.label_new(label_def, schema_id=schema_id)
        existing_label_names.add(label_name)
    else:
        print(f"Label '{label_name}' already exists in the schema, skipping.")

# Register the subject, relation AND object tags used by the triplet schema.
# A missing "triplet_schema" key just means there is nothing to add; other
# failures (e.g. a Diffgram API error) are allowed to surface instead of being
# hidden behind a bare except.
triplet_defs = schema_response.get('triplet_schema', [])
if not triplet_defs:
    print("There are no triplet definitions")
for triplet_def in triplet_defs:
    triplet_subject = triplet_def['subject']
    triplet_relation = triplet_def['relation']
    triplet_object = triplet_def['object']
    for tag in (triplet_subject, triplet_relation, triplet_object):
        existing_label_names = add_schema_label(tag, existing_label_names)