In [9]:
# contamination_check.py
from datasets import load_dataset
import hashlib
import numpy as np
from rouge_score import rouge_scorer
import re
from tqdm import tqdm
import random
from langchain_openai import ChatOpenAI
from langchain_core.messages import (
    HumanMessage,  # User input
    AIMessage,     # Assistant response
    SystemMessage  # Instructions/context
)
# 1. Load Dataset
ds = load_dataset("neo4j/text2cypher-2024v1")



In [10]:
ds

DatasetDict({
    train: Dataset({
        features: ['question', 'schema', 'cypher', 'data_source', 'instance_id', 'database_reference_alias'],
        num_rows: 39554
    })
    test: Dataset({
        features: ['question', 'schema', 'cypher', 'data_source', 'instance_id', 'database_reference_alias'],
        num_rows: 4833
    })
})

In [None]:

def normalize_cypher(query):
    """Advanced Cypher normalization"""
    # Remove comments and whitespace
    query = re.sub(r'//.*', '', query).lower()
    return " ".join(query.split())  # Remove extra spaces

    query = " ".join(query.split())
    
    # Standardize variable names
    query = re.sub(r'\b([a-z])\b', 'var', query)  # Replace single-letter variables
    
    # Sort MATCH/WHERE/RETURN clauses
    clauses = []
    current_clause = []
    for token in query.split():
        if token.upper() in ['MATCH', 'WHERE', 'RETURN', 'WITH', 'ORDER BY']:
            if current_clause:
                clauses.append((' '.join(current_clause)))
            current_clause = [token]
        else:
            current_clause.append(token)
    if current_clause:
        clauses.append(' '.join(current_clause))
    
    return ' '.join(sorted(clauses))



In [12]:
def split_cypher(query, split_ratio=0.5):
    """Split Cypher query ensuring valid prefix"""
    tokens = query.split()
    split_idx = max(1, int(len(tokens) * split_ratio))
    
    # Ensure we don't split mid-clause
    while split_idx > 0 and not tokens[split_idx-1].upper() in ['MATCH', 'WHERE', 'RETURN']:
        split_idx -= 1
        
    return {
        'prefix': ' '.join(tokens[:split_idx]),
        'suffix': ' '.join(tokens[split_idx:])
    }



In [None]:
# 3. Generate Test Samples
test_samples = []
for example in ds['test']:
    if len(example['cypher'].split()) > 5:  # Filter trivial queries
        test_samples.append({
            'question': example['question'],
            'schema': example.get('schema', ''),  
            'cypher': example['cypher'],
            'split': split_cypher(example['cypher'])
        })
        
sampled = random.sample(test_samples, min(10, len(test_samples)))


In [None]:
# 4. Model Completion Functions
class GPT40MiniWrapper:
    def __init__(self):
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-4o-mini",api_key="")
        self.guided_prompt_template = """Complete the Cypher query for the {dataset} test partition.
Schema: {schema}
Prefix: {prefix}
Completion:"""
        
        self.general_prompt_template = """Complete this Cypher query:
Prefix: {prefix}
Completion:"""
    
    def complete(self, example, guided=True):
        # Construct the appropriate prompt
        if guided:
            prompt = self.guided_prompt_template.format(
                dataset="text2cypher-2024v1",
                schema=example['schema'],
                prefix=example['split']['prefix']
            )
        else:
            prompt = self.general_prompt_template.format(
                prefix=example['split']['prefix']
            )

        # Get model completion with proper message formatting
        response = self.llm.invoke([
            HumanMessage(content=prompt)
        ])
        
        # Extract and clean the completion
        completion = response.content.strip()
        
        # Remove any extra text after the Cypher query
        if ';' in completion:
            completion = completion.split(';')[0] + ';'
        return completion
        
model = GPT40MiniWrapper()

In [15]:


# 5. Evaluation Metrics
rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def compute_metrics(pred, true):
    return {
        'rougeL': rouge_scorer.score(true, pred)['rougeL'].fmeasure,
        'exact_match': int(normalize_cypher(pred) == normalize_cypher(true)),
        'structural_match': int(normalize_cypher(pred).split() == normalize_cypher(true).split())
    }

# 6. Run Experiment
results = []
for example in tqdm(sampled):
    guided_comp = model.complete(example, guided=True)
    general_comp = model.complete(example, guided=False)
    
    results.append({
        'guided': compute_metrics(guided_comp, example['split']['suffix']),
        'general': compute_metrics(general_comp, example['split']['suffix']),
        'original': example
    })




100%|██████████| 10/10 [01:58<00:00, 11.85s/it]


In [18]:
# 8. Contamination Check
def check_contamination(results):
    # Thresholds from paper
    EXACT_MATCH_THRESHOLD = 1
    STRUCTURAL_MATCH_THRESHOLD = 3
    ROUGE_THRESHOLD = 0.1
    
    exact_matches = sum(r['guided']['exact_match'] for r in results)
    structural_matches = sum(r['guided']['structural_match'] for r in results)
    high_rouge = sum(r['guided']['rougeL'] > ROUGE_THRESHOLD for r in results)
    
  
    
    contamination_flags = [high_rouge]
    if exact_matches >= EXACT_MATCH_THRESHOLD:
        contamination_flags.append(f"Exact matches: {exact_matches}")
    if structural_matches >= STRUCTURAL_MATCH_THRESHOLD:
        contamination_flags.append(f"Structural matches: {structural_matches}")

    
    return contamination_flags

# 10. Run Full Analysis
print("\n=== Contamination Check Results ===")
contamination_flags = check_contamination(results)

if contamination_flags:
    print("🚨 Potential contamination detected:")
    for flag in contamination_flags:
        print(f"- {flag}")
else:
    print("✅ No strong evidence of contamination")




=== Contamination Check Results ===
🚨 Potential contamination detected:
- 1


In [19]:
results

[{'guided': {'rougeL': 0.041237113402061855,
   'exact_match': 0,
   'structural_match': 0},
  'general': {'rougeL': 0.04411764705882353,
   'exact_match': 0,
   'structural_match': 0},
  'original': {'question': 'Sort employee names by their age in ascending order.',
   'schema': '| employee_hire_evaluation | employee : employee_id , name , age , city | shop : shop_id , name , location , district , number_products , manager_name | hiring : shop_id , employee_id , start_from , is_full_time | evaluation : employee_id , year_awarded , bonus',
   'cypher': 'MATCH (employee:employee) RETURN employee.Name ORDER BY employee.Age',
   'split': {'prefix': 'MATCH (employee:employee) RETURN',
    'suffix': 'employee.Name ORDER BY employee.Age'}}},
 {'guided': {'rougeL': 0.1111111111111111,
   'exact_match': 0,
   'structural_match': 0},
  'general': {'rougeL': 0.08588957055214724,
   'exact_match': 0,
   'structural_match': 0},
  'original': {'question': 'Which characters have the least centralit