# Improving entity resolution

In the previous notebook, we attempted to perform entity resolution with little data.

While Louvain + Jaro-Winkler did produce near-perfect results on a relatively large set of entities, there are things we can do to improve this.

## Features

In an ideal world, your data would include stable unique identifiers like:

- Social Security Numbers
- Phone numbers
- Addresses

And definable attributes, like:
- Age
- Gender
- Job role

The more of these identifiers and attributes we have, the easier it is to resolve an entity.

Let's say we had a PhoneNumber node and an SSN node. In our dataset, some Person nodes are connected to both, and some only to one. To resolve these, we can simply run Weakly Connected Components (WCC) on a projection of these nodes and relationships, do some light name-matching, and push the results to a Parent node.
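
To make the idea concrete, here is a minimal sketch of weakly connected components over shared identifiers, written in plain Python with a union-find structure. It is only an illustration, not the GDS implementation, and the person IDs, phone numbers and SSNs are made up for the example.

```python
from collections import defaultdict

def connected_components(edges):
    """Group nodes into weakly connected components using union-find."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving keeps trees shallow
            x = parent[x]
        return x

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_b] = root_a

    for a, b in edges:
        union(a, b)

    components = defaultdict(set)
    for node in parent:
        components[find(node)].add(node)
    return list(components.values())

# Hypothetical Person-to-identifier relationships
edges = [
    ("person:1", "phone:555-0100"),
    ("person:2", "phone:555-0100"),   # shares a phone with person:1
    ("person:2", "ssn:000-00-0001"),
    ("person:3", "ssn:000-00-0001"),  # shares an SSN with person:2
    ("person:4", "phone:555-0199"),   # no shared identifiers
]

# Keep only the Person nodes of each component
people_by_component = [
    sorted(n for n in comp if n.startswith("person:"))
    for comp in connected_components(edges)
]
```

Here persons 1, 2 and 3 end up in one component because they are chained together by a shared phone number and a shared SSN, while person 4 stays alone. Each component is a candidate entity that light name-matching can then confirm or split.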

However, real life is rarely so kind -- especially when dealing with unstructured data.

## Inferring features

In the previous notebook, we did not rely on a single aspect of each User. Instead, we combined several features of each user.

Think of each feature as a voter. The more voters we have, with different perspectives, the more accurate our resolutions.
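
One way to picture the voting analogy is a weighted average of per-feature similarity scores. The sketch below is only an illustration: the feature names, weights and the 0.8 threshold are assumptions, not values used in this notebook.

```python
def combined_vote(scores, weights):
    """Weighted average of per-feature similarity scores in [0, 1].

    Features missing from `scores` abstain instead of voting zero.
    """
    present = [name for name in scores if name in weights]
    total_weight = sum(weights[name] for name in present)
    if total_weight == 0:
        return 0.0
    return sum(scores[name] * weights[name] for name in present) / total_weight

# Hypothetical voters and how much we trust each of them
weights = {"name_similarity": 0.5, "shared_mailbox": 0.3, "style_similarity": 0.2}

# A candidate pair for which only two of the three features are available
pair_scores = {"name_similarity": 0.9, "style_similarity": 0.8}

score = combined_vote(pair_scores, weights)
same_entity = score >= 0.8  # illustrative decision threshold
```

Letting missing features abstain means a pair is not penalised simply because one signal could not be computed for it.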

In this notebook, we'll try to infer some new feature nodes to use as voting blocks for resolution, including:

- Stylometry
- Etc.
- Etc.

First, as always, we'll connect to the database. 

Then, we'll clean up the resolution artifacts from the previous notebook -- we can do better.

In [10]:
%pip install transformers torch --quiet

Note: you may need to restart the kernel to use updated packages.


In [51]:
%pip install tabulate

Collecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl.metadata (34 kB)
Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Installing collected packages: tabulate
Successfully installed tabulate-0.9.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Connect to the database
import os
import pandas as pd
from graphdatascience import GraphDataScience
from dotenv import load_dotenv
import time

load_dotenv()

NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USERNAME", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
DATABASE = "enrondemo"

if not NEO4J_PASSWORD:
    raise ValueError("NEO4J_PASSWORD not found in .env file!")

try:
    gds = GraphDataScience(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD), database=DATABASE)
    gds_version = gds.version()
    print(f"Connected to Neo4j GDS {gds_version}")
    
    # Verify data exists
    result = gds.run_cypher("MATCH (e:Email) RETURN count(e) as count")
    email_count = result['count'].iloc[0]
    
    if email_count == 0:
        raise ValueError("No emails found! Please run Notebook 1 first to import data.")
    
    print(f"Found {email_count:,} emails in database")
    
except Exception as e:
    print(f"Connection or data validation failed: {e}")
    print("\nTroubleshooting:")
    print("  1. Have you run Notebook 1 to import data?")
    print(f"  2. Is Neo4j running with the '{DATABASE}' database?")
    print("  3. Is the GDS plugin installed and activated?")
    raise

Connected to Neo4j GDS 2.23.0
Found 517,401 emails in database




In [60]:
# Clean up the resolution artifacts from the previous notebook
delete_email_properties = """
MATCH (e:Email)
CALL (e) {
    REMOVE e.dateFloat,
           e.degree_centrality_undirected,
           e.fastrp_embedding_full,
           e.fastrp_embedding_scaled,
           e.louvain_community,
           e.scaledFeatures,
           e.fastrp_embedding_features,
           e.wcc_id
} IN TRANSACTIONS OF 10000 ROWS
"""

delete_user_properties = """
MATCH (u:User)
CALL (u) {
    REMOVE u.degree_centrality_undirected,
           u.email_degree,
           u.fastrp_embedding_full,
           u.fastrp_embedding_scaled,
           u.fastrp_embedding_features,
           u.louvain_community,
           u.scaledFeatures,
           u.wcc_id
} IN TRANSACTIONS OF 10000 ROWS
"""

delete_mailbox_properties = """
MATCH (m:Mailbox)
CALL (m) {
    REMOVE m.betweenness,
           m.leiden_community,
           m.pagerank,
           m.degree_centrality_undirected,
           m.email_degree,
           m.fastrp_embedding_full,
           m.fastrp_embedding_scaled,
           m.fastrp_embedding_features,
           m.louvain_community,
           m.scaledFeatures,
           m.wcc_id
} IN TRANSACTIONS OF 10000 ROWS
"""

gds.run_cypher(delete_email_properties)
gds.run_cypher(delete_user_properties)
gds.run_cypher(delete_mailbox_properties)


# 1. Stylometry

Although a relatively weak signal at the smaller scale, the attributes of a person's writing style can become useful artifacts in and of themselves.

From email content alone we can find the following fingerprints:
- Lexical
- Syntactic
- Character-level (n-grams, capitalisation, whitespace)
- Structural (paragraph length, greeting/signoff patterns, line breaks)
- Function words
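
To make these categories concrete, here is a minimal sketch of a few hand-crafted features computed with the standard library only. The feature set, the short function-word list and the sample text are illustrative assumptions; the rest of this notebook uses a learned embedding rather than features like these.

```python
import re
from collections import Counter

# Illustrative subset; real function-word lists are much longer
FUNCTION_WORDS = ("the", "and", "of", "to", "a", "in", "that", "i")

def style_features(text: str) -> dict:
    """Compute a handful of simple stylometric features for one text."""
    words = re.findall(r"[A-Za-z']+", text)
    counts = Counter(w.lower() for w in words)
    sentences = [s for s in re.split(r"[.!?]+", text) if s.strip()]
    letters = [c for c in text if c.isalpha()]
    n_words = len(words) or 1  # avoid division by zero on empty text
    return {
        # Lexical
        "avg_word_length": sum(len(w) for w in words) / n_words,
        "type_token_ratio": len(counts) / n_words,
        # Syntactic (a rough proxy only)
        "avg_sentence_length": n_words / (len(sentences) or 1),
        # Character-level
        "uppercase_ratio": sum(c.isupper() for c in letters) / (len(letters) or 1),
        # Structural
        "line_count": len([line for line in text.splitlines() if line.strip()]),
        # Function words (relative frequency of each)
        **{f"fw_{w}": counts[w] / n_words for w in FUNCTION_WORDS},
    }

# Made-up sample text
features = style_features("Thanks for the update.\nI will call in the morning.\n\nRick")
```

Features like these are easy to inspect, which makes them a useful sanity check next to an opaque embedding.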

Take the following email for example:

First, we'll create an index on our name properties. This will improve the speed of our text distance checks.
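
Full-text indexes in Neo4j are queried with Lucene query syntax, where a trailing `~` marks a term as fuzzy. As a small sketch, a fuzzy query string for a person's name could be assembled like this. The helper name `fuzzy_name_query` is hypothetical, and the sketch does not escape Lucene special characters, which real names may contain.

```python
def fuzzy_name_query(name: str, similarity: float = 0.7) -> str:
    """Build a Lucene query requiring a fuzzy match on every word of `name`."""
    terms = [f"{word}~{similarity}" for word in name.split()]
    return " AND ".join(terms)

query_string = fuzzy_name_query("Kenneth Lay")
```

Requiring every word with `AND` keeps the match strict on the number of name parts while still tolerating small spelling differences in each part.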

In [3]:
index_query = """
// Create fulltext index on User name properties
CREATE FULLTEXT INDEX user_names_fulltext IF NOT EXISTS
FOR (u:User)
ON EACH [u.nameNormStrip, u.nameRaw]
"""

index_results = gds.run_cypher(index_query)
print(index_results)

Empty DataFrame
Columns: []
Index: []


In [None]:
# Load STAR model
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd

# Load STAR model (if not already loaded)
print("Loading STAR model...")
tokenizer = AutoTokenizer.from_pretrained('AIDA-UPM/star')
model = AutoModel.from_pretrained('AIDA-UPM/star')
model.eval()
print("Model loaded.")

Loading STAR model...
Model loaded.


In [12]:
# =============================================================================
# STAR define functions
# =============================================================================

def get_style_embedding(text: str, max_length: int = 512) -> list:
    """Generate a 768-dimensional stylometric embedding using STAR."""
    if not text or len(text.strip()) < 50:
        return None
    inputs = tokenizer(
        text, 
        truncation=True, 
        max_length=max_length,
        padding=True, 
        return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.pooler_output.squeeze().tolist()

def cosine_similarity(v1: list, v2: list) -> float:
    """Compute cosine similarity between two vectors."""
    if v1 is None or v2 is None:
        return None
    import math
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    return dot / (norm1 * norm2) if norm1 and norm2 else 0.0

def extract_first_segment(thread_text: str) -> str:
    """
    Extract only the first segment of an email thread.
    Removes forwarded/replied content to get clean authored text.
    """
    if not thread_text:
        return ""
    
    markers = [
        '---------------------- Forwarded by',
        '----- Forwarded by',
        '-----Original Message-----',
        '---------------------- Original',
        'From:',  # Often marks the start of forwarded headers
    ]
    
    text = thread_text
    for marker in markers:
        if marker in text:
            parts = text.split(marker)
            # Take everything before the first marker
            text = parts[0]
    
    return text.strip()

In [None]:
# =============================================================================
# STAR Fingerprinting: Compare Rick Buy vs Kenneth Lay mailboxes
# =============================================================================

query = """
// Target names to search
WITH ['Rick Buy', 'Kenneth Lay'] AS target_names
UNWIND target_names AS target_name

// Split name into words and add fuzzy threshold
WITH target_name,
     [word IN split(target_name, ' ') | word + '~0.7'] AS fuzzy_words

// Query fulltext index with fuzzy matching
CALL db.index.fulltext.queryNodes('user_names_fulltext', apoc.text.join(fuzzy_words, ' AND '))
YIELD node AS u, score

// Filter by score threshold
WHERE score > 1.0

// Get their mailboxes
MATCH (u)-[:USED]->(m:Mailbox)

// Get emails SENT from this mailbox
MATCH (m)-[:SENT]->(e:Email)

// Aggregate
WITH target_name, u, m, score, collect(e.thread)[0..30] AS threads

RETURN 
    target_name AS search_term,
    u.nameRaw AS user_name,
    u.nameNormStrip AS user_normalized,
    m.address AS mailbox,
    size(threads) AS email_count,
    score AS match_score,
    threads,
    u.louvain_community AS community
ORDER BY target_name, score DESC, size(threads) DESC
"""

print("Querying Neo4j using fulltext fuzzy search...")
results = gds.run_cypher(query)
print(f"Found {len(results)} mailbox records")
results.head(50)

Querying Neo4j using fulltext fuzzy search...
Found 13 mailbox records


|    | search_term | user_name | user_normalized | mailbox | email_count | match_score | threads | community |
|---:|:------------|:----------|:----------------|:--------|------------:|------------:|:--------|----------:|
| 0 | Kenneth Lay | Kenneth Lay | Kenneth Lay | kenneth.lay@enron.com | 30 | 12.301496 | [Message-ID: <4205652.1075859390811.JavaMail.e... | 566080 |
| 1 | Kenneth Lay | Kenneth " Lay | Kenneth " Lay | kenneth.lay@enron.com | 30 | 12.301496 | [Message-ID: <4205652.1075859390811.JavaMail.e... | 132149 |
| 2 | Kenneth Lay | Kenneth Lay (E-mail) | Kenneth Lay | kenneth.lay@enron.com | 30 | 10.685865 | [Message-ID: <4205652.1075859390811.JavaMail.e... | 132149 |
| 3 | Kenneth Lay | Kenneth L. Lay | Kenneth L. Lay | kenneth.lay@enron.com | 30 | 10.423277 | [Message-ID: <4205652.1075859390811.JavaMail.e... | 495222 |
| 4 | Kenneth Lay | Kenneth L. Lay (E-mail) | Kenneth L. Lay | kenneth.lay@enron.com | 30 | 9.20072 | [Message-ID: <4205652.1075859390811.JavaMail.e... | 132149 |
| 5 | Kenneth Lay | Kenneth L. Lay - Enron | Kenneth L. Lay - Enron | kenneth.lay@enron.com | 30 | 9.043266 | [Message-ID: <4205652.1075859390811.JavaMail.e... | 132149 |
| 6 | Rick Buy | Rick Buy | Rick Buy | rick.buy@enron.com | 30 | 13.503496 | [Message-ID: <10973933.1075859677452.JavaMail.... | 566080 |
| 7 | Rick Buy | Rick Buy | Rick Buy | buy@enron.com | 1 | 13.503496 | [Message-ID: <316659.1075857666899.JavaMail.ev... | 566080 |
| 8 | Rick Buy | Rick Buy- Enron Corp. Chief Risk Officer | Rick Buy- Enron Corp. Chief Risk Officer | no.address@enron.com | 30 | 9.978207 | [Message-ID: <15394323.1075840430245.JavaMail.... | 487657 |
| 9 | Rick Buy | Rick Buy- Enron Corp. Chief Risk Officer | Rick Buy- Enron Corp. Chief Risk Officer | 40enron@enron.com | 30 | 9.978207 | [Message-ID: <7671344.1075840430472.JavaMail.e... | 487657 |


In [None]:
# =============================================================================
# Generate STAR embeddings with fixed extraction
# =============================================================================

def extract_email_body(thread_text: str) -> str:
    """
    Extract the actual email body content, stripping headers and forwarded content.
    """
    if not thread_text:
        return ""
    
    text = thread_text
    
    # Parse past headers to find body
    lines = text.split('\n')
    in_headers = True
    body_lines = []
    
    for i, line in enumerate(lines):
        if in_headers:
            # Header lines: "Field: Value" or whitespace continuation
            if ':' in line and line.split(':')[0].replace('-', '').replace('_', '').replace(' ', '').isalnum():
                continue
            elif line.startswith(' ') or line.startswith('\t'):
                continue
            elif line.strip() == '':
                in_headers = False
                continue
            else:
                in_headers = False
                body_lines.append(line)
        else:
            body_lines.append(line)
    
    body_text = '\n'.join(body_lines)
    
    # Remove forwarded/replied content
    forward_markers = [
        '---------------------- Forwarded by',
        '----- Forwarded by',
        '-----Original Message-----',
        '---------------------- Original',
        '----- Original Message -----',
        '________________________________________',
    ]
    
    for marker in forward_markers:
        if marker in body_text:
            body_text = body_text.split(marker)[0]
    
    return body_text.strip()


def concatenate_emails(threads: list, max_chars: int = 15000) -> str:
    """
    Concatenate email bodies for fingerprinting.
    """
    if not threads:
        return ""
    
    segments = []
    total_chars = 0
    
    for thread in threads:
        if thread is None:
            continue
        segment = extract_email_body(thread)
        if len(segment) > 30:  # Lowered threshold - short emails still have style
            segments.append(segment)
            total_chars += len(segment)
            if total_chars > max_chars:
                break
    
    return "\n\n---\n\n".join(segments)


# Process each row (no deduplication - each User-Mailbox pair is unique)
fingerprints = []

print("Generating STAR embeddings with FIXED extraction...")
print("-" * 70)

for idx, row in results.iterrows():
    combined_text = concatenate_emails(row['threads'])
    embedding = get_style_embedding(combined_text) if combined_text and len(combined_text) > 100 else None
    
    fingerprints.append({
        'search_term': row['search_term'],
        'user_name': row['user_name'],
        'user_normalized': row['user_normalized'],
        'mailbox': row['mailbox'],
        'email_count': row['email_count'],
        'text_length': len(combined_text) if combined_text else 0,
        'embedding': embedding,
        'community': row['community']
    })
    
    status = "✓" if embedding else "✗ (insufficient text)"
    print(f"  {row['user_normalized'][:25]:<25} | {row['mailbox']:<35} | {row['email_count']:>3} emails | {len(combined_text) if combined_text else 0:>5} chars {status}")

valid_embeddings = len([f for f in fingerprints if f['embedding']])
print("-" * 70)
print(f"Generated {valid_embeddings} embeddings from {len(fingerprints)} User-Mailbox pairs")

Generating STAR embeddings with FIXED extraction...
----------------------------------------------------------------------
  Kenneth Lay               | kenneth.lay@enron.com               |  30 emails |  8719 chars ✓
  Kenneth " Lay             | kenneth.lay@enron.com               |  30 emails |  8719 chars ✓
  Kenneth Lay               | kenneth.lay@enron.com               |  30 emails |  8719 chars ✓
  Kenneth L. Lay            | kenneth.lay@enron.com               |  30 emails |  8719 chars ✓
  Kenneth L. Lay            | kenneth.lay@enron.com               |  30 emails |  8719 chars ✓
  Kenneth L. Lay - Enron    | kenneth.lay@enron.com               |  30 emails |  8719 chars ✓
  Rick Buy                  | rick.buy@enron.com                  |  30 emails |  9635 chars ✓
  Rick Buy                  | buy@enron.com                       |   1 emails |   315 chars ✓
  Rick Buy- Enron Corp. Chi | no.address@enron.com                |  30 emails | 15836 chars ✓
  Rick Buy- Enron Corp

In [53]:
# =============================================================================
# Compare all to rick.buy@enron.com (with copy-pasteable output)
# =============================================================================

# Find Rick Buy reference
reference_mailbox = 'rick.buy@enron.com'
reference_embedding = None
reference_user = None

for fp in fingerprints:
    if fp['mailbox'] == reference_mailbox and fp['embedding'] is not None:
        if fp['user_normalized'] == 'Rick Buy':
            reference_embedding = fp['embedding']
            reference_user = fp['user_name']
            break

if reference_embedding is None:
    print("ERROR: No embedding for rick.buy@enron.com")
else:
    comparison_results = []
    
    for fp in fingerprints:
        similarity = cosine_similarity(reference_embedding, fp['embedding'])
        
        comparison_results.append({
            'Search Term': fp['search_term'],
            'User': fp['user_name'],
            'Mailbox': fp['mailbox'],
            'Community': fp['community'],
            'Emails': fp['email_count'],
            'Text Chars': fp['text_length'],
            'Style Similarity': round(similarity, 4) if similarity is not None else None
        })
    
    df_comparison = pd.DataFrame(comparison_results)
    df_comparison = df_comparison.sort_values('Style Similarity', ascending=False, na_position='last')
    
    # Print as markdown table
    print("=" * 95)
    print(f"STYLOMETRIC SIMILARITY TO {reference_user} ({reference_mailbox})")
    print("=" * 95)
    print()
    print(df_comparison.to_markdown(index=False))
    print()
    
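    # Summary with sufficient text only. Exclude the reference mailbox itself:
    # its similarity of 1.0 would otherwise inflate the Rick Buy average.
    df_valid = df_comparison[
        (df_comparison['Text Chars'] >= 500) & (df_comparison['Mailbox'] != reference_mailbox)
    ]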
    rick_sims = df_valid[df_valid['Search Term'] == 'Rick Buy']['Style Similarity'].dropna()
    ken_sims = df_valid[df_valid['Search Term'] == 'Kenneth Lay']['Style Similarity'].dropna()
    
    print(f"--- Summary (mailboxes with 500+ chars) ---")
    print(f"Rick Buy avg similarity:    {rick_sims.mean():.4f} (n={len(rick_sims)})")
    print(f"Kenneth Lay avg similarity: {ken_sims.mean():.4f} (n={len(ken_sims)})")
    print(f"Separation:                 {rick_sims.mean() - ken_sims.mean():+.4f}")
    
    if rick_sims.mean() > ken_sims.mean():
        print("\nSTAR successfully distinguishes Rick Buy from Kenneth Lay!")
    else:
        print("\nStyles not clearly separated - may need more data or different approach")

STYLOMETRIC SIMILARITY TO Rick Buy (rick.buy@enron.com)

| Search Term   | User                                     | Mailbox                 |   Community |   Emails |   Text Chars |   Style Similarity |
|:--------------|:-----------------------------------------|:------------------------|------------:|---------:|-------------:|-------------------:|
| Rick Buy      | Rick Buy                                 | rick.buy@enron.com      |      566080 |       30 |         9635 |             1      |
| Rick Buy      | Rick Buy                                 | buy@enron.com           |      566080 |        1 |          315 |             0.7392 |
| Kenneth Lay   | Kenneth Lay                              | kenneth.lay@enron.com   |      566080 |       30 |         8719 |             0.6983 |
| Kenneth Lay   | Kenneth " Lay                            | kenneth.lay@enron.com   |      132149 |       30 |         8719 |             0.6983 |
| Kenneth Lay   | Kenneth Lay (E-mail)                 

## Stylometric workflow

First let's get a sense of how many emails we're going to embed.

In [39]:
# =============================================================================
# Step 1: Assess scale - how many mailboxes have sent emails?
# =============================================================================

scale_query = """
MATCH (m:Mailbox)-[:SENT]->(e:Email)
WITH m, count(e) AS email_count
RETURN 
    count(m) AS total_mailboxes,
    sum(email_count) AS total_emails,
    avg(email_count) AS avg_emails_per_mailbox,
    percentileCont(email_count, 0.5) AS median_emails,
    percentileCont(email_count, 0.9) AS p90_emails,
    max(email_count) AS max_emails
"""

print("Assessing dataset scale...")
scale = gds.run_cypher(scale_query)
print(scale.to_markdown(index=False))

Assessing dataset scale...
|   total_mailboxes |   total_emails |   avg_emails_per_mailbox |   median_emails |   p90_emails |   max_emails |
|------------------:|---------------:|-------------------------:|----------------:|-------------:|-------------:|
|             20311 |         517399 |                  25.4738 |               3 |           25 |        16735 |


Next, we keep only the mailboxes that have sent enough emails to yield a usable fingerprint. We set the minimum to 3 sent emails, which matches the median from the scale check above.

In [40]:
# =============================================================================
# Step 2: Get all mailboxes with sufficient emails for fingerprinting
# =============================================================================

MIN_EMAILS = 3

mailbox_query = f"""
MATCH (m:Mailbox)-[:SENT]->(e:Email)
WITH m, collect(e.thread) AS threads, count(e) AS email_count
WHERE email_count >= {MIN_EMAILS}

// Get the User(s) associated with this mailbox for reference
OPTIONAL MATCH (u:User)-[:USED]->(m)

RETURN 
    m.address AS mailbox,
    email_count,
    threads,
    collect(DISTINCT u.nameNormStrip)[0..3] AS associated_users,
    collect(DISTINCT u.louvain_community)[0..3] AS communities
ORDER BY email_count DESC
"""

print(f"Fetching mailboxes with >= {MIN_EMAILS} sent emails...")
mailbox_data = gds.run_cypher(mailbox_query)
print(f"Found {len(mailbox_data)} mailboxes to fingerprint")

Fetching mailboxes with >= 3 sent emails...
Found 10272 mailboxes to fingerprint


Next we will use the STAR model to generate embeddings for our concatenated emails.
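
One thing to keep in mind: `get_style_embedding` truncates its input to `max_length=512` tokens, so only the beginning of a long concatenation reaches the model. A possible alternative, sketched below under assumptions, is to embed fixed-size chunks and average the resulting vectors. `embed_fn` is a stand-in for any function that returns a list of floats or `None` (for example `get_style_embedding`), `toy_embed` exists only so the sketch runs without the model, and splitting by characters is a crude proxy for token counts.

```python
from typing import Callable, List, Optional

def mean_pooled_embedding(
    text: str,
    embed_fn: Callable[[str], Optional[List[float]]],
    chunk_chars: int = 2000,
) -> Optional[List[float]]:
    """Embed `text` chunk by chunk and return the element-wise mean vector."""
    chunks = [text[i:i + chunk_chars] for i in range(0, len(text), chunk_chars)]
    # Chunks the embedding function rejects (returns None for) are skipped
    vectors = [v for v in (embed_fn(c) for c in chunks) if v is not None]
    if not vectors:
        return None
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]

def toy_embed(chunk: str) -> Optional[List[float]]:
    """Toy stand-in for a real model: [length, number of 'e' characters]."""
    if len(chunk) < 3:
        return None  # mimic a "text too short" rejection
    return [float(len(chunk)), float(chunk.count("e"))]

pooled = mean_pooled_embedding("e" * 10 + "x" * 10 + "ab", toy_embed, chunk_chars=10)
```

Averaging treats every chunk equally. Whether that helps or hurts the fingerprint would need to be checked against the single-pass approach used in this notebook.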

In [41]:
# =============================================================================
# Step 3 (FIXED): Generate STAR embeddings - TOP MESSAGE ONLY
# =============================================================================

import time
import re

def extract_email_body_v3(thread_text: str) -> str:
    """
    Extract ONLY the first/top message - the actual authored content.
    STOPS (breaks) at ANY sign of quoted/forwarded content.
    """
    if not thread_text:
        return ""
    
    lines = thread_text.split('\n')
    
    # Step 1: Skip past all headers (find body start)
    body_start = 0
    for i, line in enumerate(lines):
        if line.startswith('X-FileName:'):
            body_start = i + 1
            break
        if line.strip() == '' and i > 0:
            prev_lines = lines[max(0,i-3):i]
            if any(':' in l and l.split(':')[0].replace('-','').replace('_','').isalnum() for l in prev_lines):
                body_start = i + 1
                break
    
    # Step 2: Collect lines until we hit ANY reply/forward indicator
    body_lines = []
    for line in lines[body_start:]:
        
        # STOP at quoted lines
        if line.strip().startswith('>'):
            break
        
        # STOP at ">>> email@domain" reply markers
        if re.search(r'>{2,}.*@.*>{2,}', line):
            break
        if re.search(r'>>>.*@', line):
            break
            
        # STOP at "On DATE, NAME wrote:" 
        if re.match(r'^\s*On .+wrote:\s*$', line, re.IGNORECASE):
            break
            
        # STOP at forwarding markers
        if re.match(r'^\s*-{3,}.*(Original|Forwarded).*-*\s*$', line, re.IGNORECASE):
            break
        if '---------------------- Forwarded by' in line:
            break
        if '----- Forwarded by' in line:
            break
        if '-----Original Message-----' in line:
            break
            
        # STOP at inline "From:" headers (forwarded content)
        if re.match(r'^\s*From:\s*[\w\s]*<?[\w\.-]+@[\w\.-]+>?\s*$', line):
            break
            
        # STOP at "Sent:" header (often follows From: in forwards)
        if re.match(r'^\s*Sent:\s+', line):
            break
        
        body_lines.append(line)
    
    body_text = '\n'.join(body_lines)
    
    # Step 3: Remove signature artifacts (phone numbers at end)
    body_text = re.sub(r'\n\d{3}[-.\s]?\d{3}[-.\s]?\d{4}\s*$', '', body_text)
    
    # Step 4: Clean up whitespace
    body_text = re.sub(r'\n{3,}', '\n\n', body_text)
    body_text = body_text.strip()
    
    return body_text


def concatenate_emails_v3(threads: list, max_chars: int = 15000) -> str:
    """Concatenate unique TOP-LEVEL email bodies only."""
    if not threads:
        return ""
    
    seen_content = set()
    segments = []
    total_chars = 0
    
    for thread in threads:
        if thread is None:
            continue
        
        segment = extract_email_body_v3(thread)
        
        # Skip very short segments
        if len(segment) < 50:
            continue
        
        # Skip attachment-only messages
        if segment.count('.doc') + segment.count('.xls') + segment.count('.pdf') > len(segment) / 100:
            continue
        
        # Deduplicate
        fingerprint = segment[:100].strip().lower()
        if fingerprint in seen_content:
            continue
        seen_content.add(fingerprint)
        
        segments.append(segment)
        total_chars += len(segment)
        
        if total_chars > max_chars:
            break
    
    return "\n\n".join(segments)


# Process all mailboxes
MIN_TEXT_LENGTH = 200
embeddings_data_v3 = []
skipped = 0
start_time = time.time()

print(f"Generating STAR embeddings for {len(mailbox_data)} mailboxes (TOP MESSAGE ONLY)...")
print("-" * 70)

for idx, row in mailbox_data.iterrows():
    if idx % 100 == 0:
        elapsed = time.time() - start_time
        rate = idx / elapsed if elapsed > 0 else 0
        print(f"  Processing {idx}/{len(mailbox_data)} ({rate:.1f}/sec)...")
    
    combined_text = concatenate_emails_v3(row['threads'])
    
    if len(combined_text) < MIN_TEXT_LENGTH:
        skipped += 1
        continue
    
    embedding = get_style_embedding(combined_text)
    
    if embedding:
        embeddings_data_v3.append({
            'mailbox': row['mailbox'],
            'email_count': row['email_count'],
            'text_length': len(combined_text),
            'embedding': embedding,
            'associated_users': row['associated_users'],
            'communities': row['communities']
        })

elapsed = time.time() - start_time
print("-" * 70)
print(f"Generated {len(embeddings_data_v3)} embeddings in {elapsed:.1f}s")
print(f"Skipped {skipped} mailboxes (< {MIN_TEXT_LENGTH} chars after cleaning)")
print(f"Rate: {len(embeddings_data_v3)/elapsed:.1f} embeddings/sec")

Generating STAR embeddings for 10272 mailboxes (TOP MESSAGE ONLY)...
----------------------------------------------------------------------
  Processing 0/10272 (0.0/sec)...
  Processing 100/10272 (3.5/sec)...
  Processing 200/10272 (3.5/sec)...
  Processing 300/10272 (3.5/sec)...
  Processing 400/10272 (3.5/sec)...
  Processing 500/10272 (3.5/sec)...
  Processing 600/10272 (3.5/sec)...
  Processing 700/10272 (3.5/sec)...
  Processing 800/10272 (3.5/sec)...
  Processing 900/10272 (3.5/sec)...
  Processing 1000/10272 (3.4/sec)...
  Processing 1100/10272 (3.4/sec)...
  Processing 1200/10272 (3.5/sec)...
  Processing 1300/10272 (3.5/sec)...
  Processing 1400/10272 (3.5/sec)...
  Processing 1500/10272 (3.5/sec)...
  Processing 1600/10272 (3.5/sec)...
  Processing 1700/10272 (3.5/sec)...
  Processing 1800/10272 (3.5/sec)...
  Processing 1900/10272 (3.5/sec)...
  Processing 2000/10272 (3.5/sec)...
  Processing 2100/10272 (3.5/sec)...
  Processing 2200/10272 (3.5/sec)...
  Processing 2300/102

In [42]:
# =============================================================================
# Step 4: Write embeddings to Neo4j as Style nodes
# =============================================================================

print("Writing Style nodes to Neo4j...")

# Create a uniqueness constraint on the property that the MERGE below keys on (address)
gds.run_cypher("""
CREATE CONSTRAINT style_address IF NOT EXISTS
FOR (s:Style) REQUIRE s.address IS UNIQUE
""")

# Batch write embeddings
BATCH_SIZE = 100
total_written = 0

for i in range(0, len(embeddings_data_v3), BATCH_SIZE):
    batch = embeddings_data_v3[i:i+BATCH_SIZE]
    
    write_query = """
    UNWIND $batch AS item
    MERGE (s:Style {address: item.mailbox})
    SET s.embedding = item.embedding,
        s.email_count = item.email_count,
        s.text_length = item.text_length
    WITH s, item
    MATCH (m:Mailbox {address: item.mailbox})
    MERGE (m)-[:HAS_STYLE]->(s)
    RETURN count(s) AS written
    """
    
    result = gds.run_cypher(write_query, params={'batch': batch})
    total_written += result['written'].iloc[0]
    
    if (i // BATCH_SIZE) % 10 == 0:
        print(f"  Written {total_written} Style nodes...")

print(f"Total Style nodes created: {total_written}")

Writing Style nodes to Neo4j...
  Written 100 Style nodes...
  Written 1100 Style nodes...
  Written 2100 Style nodes...
  Written 3100 Style nodes...
  Written 4100 Style nodes...
  Written 5100 Style nodes...
  Written 6100 Style nodes...
  Written 7100 Style nodes...
  Written 8100 Style nodes...
  Written 9100 Style nodes...
Total Style nodes created: 9575


In [43]:
# =============================================================================
# Step 5: Project Style nodes with embeddings using Cypher projection
# =============================================================================

# Drop existing projection if exists
try:
    gds.run_cypher("CALL gds.graph.drop('style-similarity')")
    print("Dropped existing 'style-similarity' projection")
except Exception:
    # The projection did not exist yet, so there is nothing to drop
    pass

# Cypher projection for Style nodes with embeddings
print("Creating Cypher projection of Style nodes...")

projection_result = gds.run_cypher("""
MATCH (s:Style)
WHERE s.embedding IS NOT NULL
WITH gds.graph.project(
    'style-similarity',
    s,
    null,
    {
        sourceNodeProperties: {embedding: s.embedding},
        targetNodeProperties: {embedding: null}
    }
) AS g
RETURN g.graphName AS graphName, 
       g.nodeCount AS nodeCount, 
       g.relationshipCount AS relationshipCount
""")

print(projection_result.to_markdown(index=False))

Dropped existing 'style-similarity' projection
Creating Cypher projection of Style nodes...
| graphName        |   nodeCount |   relationshipCount |
|:-----------------|------------:|--------------------:|
| style-similarity |       10023 |                   0 |


In [44]:
# =============================================================================
# Step 6: Run KNN to create similarity relationships (mutate)
# =============================================================================

print("Running KNN to find similar writing styles...")

knn_result = gds.run_cypher("""
CALL gds.knn.mutate(
    'style-similarity',
    {
        nodeProperties: ['embedding'],
        topK: 2,
        sampleRate: 1.0,
        deltaThreshold: 0.001,
        similarityCutoff: 0.6,
        mutateRelationshipType: 'SIMILAR_STYLE',
        mutateProperty: 'score'
    }
)
YIELD nodesCompared, relationshipsWritten, similarityDistribution
RETURN 
    nodesCompared,
    relationshipsWritten,
    round(similarityDistribution.min, 4) AS min_similarity,
    round(similarityDistribution.mean, 4) AS mean_similarity,
    round(similarityDistribution.max, 4) AS max_similarity,
    round(similarityDistribution.p50, 4) AS median_similarity,
    round(similarityDistribution.p90, 4) AS p90_similarity
""")

print(knn_result.to_markdown(index=False))

Running KNN to find similar writing styles...
|   nodesCompared |   relationshipsWritten |   min_similarity |   mean_similarity |   max_similarity |   median_similarity |   p90_similarity |
|----------------:|-----------------------:|-----------------:|------------------:|-----------------:|--------------------:|-----------------:|
|           10023 |                  20046 |           0.8007 |            0.9366 |                1 |              0.9408 |           0.9661 |
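
To build intuition for `topK` and `similarityCutoff`, here is a brute-force sketch in plain Python. It is not how GDS computes KNN (GDS uses an approximate algorithm, and its exact scoring may differ from raw cosine similarity), and the toy vectors are made up.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def brute_force_knn(vectors, top_k=2, similarity_cutoff=0.6):
    """Return {node: [(neighbour, score), ...]} with at most `top_k` neighbours
    per node, keeping only scores of at least `similarity_cutoff`."""
    neighbours = {}
    for node, vec in vectors.items():
        scored = [
            (other, cosine(vec, other_vec))
            for other, other_vec in vectors.items()
            if other != node
        ]
        scored = [(o, s) for o, s in scored if s >= similarity_cutoff]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        neighbours[node] = scored[:top_k]
    return neighbours

toy_vectors = {
    "a": [1.0, 0.0],
    "b": [0.9, 0.1],
    "c": [0.0, 1.0],
}
knn = brute_force_knn(toy_vectors, top_k=1, similarity_cutoff=0.6)
```

In the toy data, `a` and `b` pick each other, while `c` gets no neighbour because nothing clears the cutoff. The resulting relationships are directed (each node points at its own nearest neighbours). In the run above, 10023 nodes with `topK: 2` produced 20046 relationships, which means every node found two neighbours above the cutoff.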


In [45]:
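# Leiden in GDS requires undirected relationships, so convert the directed
# SIMILAR_STYLE relationships produced by KNN into STYLE_UNDIRECTED
result = gds.run_cypher("""
CALL gds.graph.relationships.toUndirected(
  'style-similarity',
  {relationshipType: 'SIMILAR_STYLE', mutateRelationshipType: 'STYLE_UNDIRECTED'}
)
YIELD
  inputRelationships, relationshipsWritten
RETURN inputRelationships, relationshipsWritten
  """)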

In [46]:
# =============================================================================
# Step 7: Run Leiden stats for tuning
# =============================================================================

print("Running Leiden community detection (stats mode for tuning)...")

# Try different resolution values
resolutions = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]

leiden_stats = []
for res in resolutions:
    result = gds.run_cypher(f"""
    CALL gds.leiden.stats(
        'style-similarity',
        {{
            relationshipTypes: ['STYLE_UNDIRECTED'],
            relationshipWeightProperty: 'score',
            gamma: {res},
            maxLevels: 10
        }}
    )
    YIELD communityCount, modularity, ranLevels
    RETURN {res} AS gamma, communityCount, round(modularity, 4) AS modularity, ranLevels
    """)
    leiden_stats.append(result.iloc[0].to_dict())

import pandas as pd
df_stats = pd.DataFrame(leiden_stats)
print("\nLeiden Stats by Gamma (resolution):")
print(df_stats.to_markdown(index=False))
print("\nHigher gamma = more communities, Lower gamma = fewer larger communities")
print("Choose gamma with good modularity and reasonable community count")

Running Leiden community detection (stats mode for tuning)...

Leiden Stats by Gamma (resolution):
|   gamma |   communityCount |   modularity |   ranLevels |
|--------:|-----------------:|-------------:|------------:|
|     0.1 |               16 |       0.9096 |           7 |
|     0.5 |               35 |       0.8124 |           6 |
|     1   |               56 |       0.7857 |           6 |
|     2   |              105 |       0.7653 |           5 |
|     5   |              198 |       0.7382 |           4 |
|    10   |              323 |       0.7135 |           4 |
|    20   |              477 |       0.6832 |           5 |
|    50   |              801 |       0.6288 |           4 |
|   100   |             1171 |       0.5697 |           4 |

Higher gamma = more communities, Lower gamma = fewer larger communities
Choose gamma with good modularity and reasonable community count
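
One way to turn that advice into a repeatable rule is to set a modularity floor and pick the gamma whose community count is closest to a target. This is only a sketch of one possible criterion: the floor of 0.65 and the target of 500 communities are illustrative assumptions, and the rows are copied from the stats table above.

```python
# Rows copied from the Leiden stats table
leiden_rows = [
    {"gamma": 0.1, "communityCount": 16, "modularity": 0.9096},
    {"gamma": 0.5, "communityCount": 35, "modularity": 0.8124},
    {"gamma": 1.0, "communityCount": 56, "modularity": 0.7857},
    {"gamma": 2.0, "communityCount": 105, "modularity": 0.7653},
    {"gamma": 5.0, "communityCount": 198, "modularity": 0.7382},
    {"gamma": 10.0, "communityCount": 323, "modularity": 0.7135},
    {"gamma": 20.0, "communityCount": 477, "modularity": 0.6832},
    {"gamma": 50.0, "communityCount": 801, "modularity": 0.6288},
    {"gamma": 100.0, "communityCount": 1171, "modularity": 0.5697},
]

def pick_gamma(rows, target_communities, min_modularity):
    """Pick the gamma closest to the target community count,
    among rows whose modularity is at least `min_modularity`."""
    candidates = [r for r in rows if r["modularity"] >= min_modularity]
    if not candidates:
        return None
    best = min(candidates, key=lambda r: abs(r["communityCount"] - target_communities))
    return best["gamma"]

chosen = pick_gamma(leiden_rows, target_communities=500, min_modularity=0.65)
```

With these particular settings the rule selects a gamma of 20.0. Different targets or floors will give different answers, so the choice of both should reflect how many distinct writing styles we expect to find.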


In [47]:
# =============================================================================
# Step 8: Run Leiden write with chosen gamma
# =============================================================================

# Choose gamma based on stats above (adjust as needed)
CHOSEN_GAMMA = 20.0  # <-- Adjust based on Step 7 results

print(f"Running Leiden with gamma={CHOSEN_GAMMA} and writing to nodes...")

leiden_write = gds.run_cypher(f"""
CALL gds.leiden.write(
    'style-similarity',
    {{
        relationshipTypes: ['STYLE_UNDIRECTED'],
        relationshipWeightProperty: 'score',
        gamma: {CHOSEN_GAMMA},
        maxLevels: 10,
        writeProperty: 'styleCommunity'
    }}
)
YIELD communityCount, modularity, ranLevels, nodePropertiesWritten
RETURN communityCount, 
       round(modularity, 4) AS modularity, 
       ranLevels, 
       nodePropertiesWritten
""")

print(leiden_write.to_markdown(index=False))

# Show community size distribution
community_dist = gds.run_cypher("""
MATCH (s:Style)
WHERE s.styleCommunity IS NOT NULL
WITH s.styleCommunity AS community, count(*) AS size
RETURN 
    count(*) AS total_communities,
    avg(size) AS avg_size,
    min(size) AS min_size,
    max(size) AS max_size,
    percentileCont(size, 0.5) AS median_size,
    percentileCont(size, 0.9) AS p90_size
""")
print("\nCommunity Size Distribution:")
print(community_dist.to_markdown(index=False))

Running Leiden with gamma=20.0 and writing to nodes...
|   communityCount |   modularity |   ranLevels |   nodePropertiesWritten |
|-----------------:|-------------:|------------:|------------------------:|
|              486 |       0.6829 |           4 |                   10023 |

Community Size Distribution:
|   total_communities |   avg_size |   min_size |   max_size |   median_size |   p90_size |
|--------------------:|-----------:|-----------:|-----------:|--------------:|-----------:|
|                 486 |    20.6235 |          3 |         73 |            20 |         31 |
