In [None]:
%%capture
# Install runtime dependencies; %%capture hides the (long) pip output.
# NOTE(review): versions are unpinned, so a later Colab session may pull different
# transformers/neo4j releases — pin them (e.g. `%pip install -q transformers==X.Y.Z`) for reproducibility.
!pip install accelerate bitsandbytes huggingface_hub transformers func-timeout neo4j

In [None]:
# Mount Google Drive: the test CSV and the results checkpoint live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Interactive Hugging Face Hub login (renders a token-entry widget in the notebook).
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import os
import re
import time
from datetime import datetime

import pandas as pd
import torch
from func_timeout import FunctionTimedOut, func_timeout
from neo4j import GraphDatabase, Query
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [None]:
import logging
import warnings

# Only show neo4j driver log records at ERROR level or above.
logging.getLogger('neo4j').setLevel(logging.ERROR)
# Hide all FutureWarnings for the rest of the session.
# NOTE(review): this is global — it also hides pandas FutureWarnings (e.g. incompatible-dtype
# assignments through df.at[...] in the batch loop); consider narrowing it with `module=...`.
warnings.filterwarnings('ignore', category=FutureWarning)

**Load model**

In [None]:
model_id = "neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1"
max_seq_length = 4096  # NOTE(review): not referenced by any visible cell

# 4-bit NF4 quantization with nested (double) quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the quantized model. `dtype` replaces the deprecated `torch_dtype`
# argument (the load log warns: "`torch_dtype` is deprecated! Use `dtype` instead!").
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    dtype=torch.bfloat16,
    # NOTE(review): eager attention is presumably chosen because Gemma-2 uses
    # attention logit soft-capping — confirm before switching to sdpa/flash.
    attn_implementation="eager",
    low_cpu_mem_usage=True,
    device_map="auto",
)

# Inference only: disables dropout (the LoRA layers contain Dropout(p=0.05)).
# The trailing ';' suppresses the very long module repr in the cell output.
model.eval();

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/857 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/864M [00:00<?, ?B/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 3584, padding_idx=0)
    (layers): ModuleList(
      (0-41): 42 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=3584, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=3584, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=3584, out_features=2048, bias=False)
            (lora_dropout):

**Load data**

In [None]:
# Results checkpoint (created/resumed by run_batch_with_correction) and the test split.
# NOTE(review): the checkpoint is in "T2C_finetune_gemma/" while the test CSV is in
# "T2C_finetune_loop_gemma/", and the saved batch log at the end of this notebook shows a
# different checkpoint file (T2C_finetune_gemma/finetune_gemma.csv). Confirm the intended
# location before resuming, otherwise an unrelated checkpoint may be picked up.
checkpoint_path = '/content/drive/MyDrive/T2C_finetune_gemma/finetune_loop_gemma.csv'
test_path = '/content/drive/MyDrive/T2C_finetune_loop_gemma/text2cypher_test.csv'
test_df = pd.read_csv(test_path, encoding="utf-8-sig")

print(f"Loaded test shape: {test_df.shape}")

Loaded test shape: (4833, 6)


**Neo4j setup**

In [None]:
# Demo Neo4j server that every database alias is queried on.
URI = "neo4j+s://demo.neo4jlabs.com:7687"

# Distinct, non-null database aliases referenced by the test set
# (both names are bound to the same list).
DATABASE_ALIASES = unique_aliases = (
    test_df["database_reference_alias"].dropna().unique().tolist()
)

# alias -> open Neo4j driver; populated lazily by get_driver().
DRIVERS_BY_ALIAS = dict()

In [None]:
def extract_alias(alias: str):
    """Derive (user, password) credentials from a demo database alias.

    Aliases look like "neo4jlabs_demo_db_<name>"; `<name>` is used as both the
    user name and the password. Only a *leading* "neo4jlabs_demo_db_" prefix
    is removed, so the same text elsewhere in the name is preserved; an alias
    without the prefix is returned unchanged.

    Args:
        alias: database reference alias from the test set.

    Returns:
        (user, password) tuple, both equal to the de-prefixed name.
    """
    name = alias.removeprefix("neo4jlabs_demo_db_")
    return name, name

def get_driver(alias):
    """Return the cached Neo4j driver for `alias`, creating it on first use.

    Credentials are derived from the alias by `extract_alias`; the new driver
    is stored in DRIVERS_BY_ALIAS so later calls reuse it.
    """
    if alias not in DRIVERS_BY_ALIAS:
        user, password = extract_alias(alias)
        DRIVERS_BY_ALIAS[alias] = GraphDatabase.driver(URI, auth=(user, password))
    return DRIVERS_BY_ALIAS[alias]

def reset_driver(alias):
    """Drop the cached driver for `alias` and return a freshly created one.

    An error raised while closing the old driver is printed and otherwise
    ignored, so the reset always proceeds.
    """
    if alias not in DRIVERS_BY_ALIAS:
        return get_driver(alias)

    try:
        DRIVERS_BY_ALIAS[alias].close()
    except Exception as e:
        print(f"Error closing driver: {e}")
    DRIVERS_BY_ALIAS.pop(alias)

    # Rebuild a new driver (and re-cache it).
    return get_driver(alias)

print("✓ Neo4j helper loaded")

✓ Neo4j helper loaded


**Prompt Cypher**

In [None]:
# User-turn prompt for the text2cypher model; filled by prompt_cypher() with the
# database schema and the natural-language question.
# NOTE(review): the neo4j model card's template appears to be
#   "Schema: {schema} \n Question: {question}  \n Cypher output: "
# (a space after each "\n"); this version drops those spaces. Confirm against the
# model card — but note that changing it mid-run makes the remaining rows
# inconsistent with rows already stored in the checkpoint.
INSTRUCTION_TEMPLATE = (
    "Generate Cypher statement to query a graph database. "
    "Use only the provided relationship types and properties in the schema. \n"
    "Schema: {schema} \n"
    "Question: {question}  \n"
    "Cypher output: "
)

print("✓ INSTRUCTION_TEMPLATE loaded")

✓ INSTRUCTION_TEMPLATE loaded


In [None]:
def prompt_cypher(question, schema) -> list[dict]:
    """Wrap the text2cypher instruction in a single-turn chat.

    Args:
        question: natural-language question.
        schema: schema text of the target graph.

    Returns:
        A one-element list holding a "user" message whose content is
        INSTRUCTION_TEMPLATE filled with `schema` and `question`.
    """
    content = INSTRUCTION_TEMPLATE.format(schema=schema, question=question)
    return [{"role": "user", "content": content}]

print("✓ prompt_cypher loaded")

✓ prompt_cypher loaded


**Generate cypher raw**

In [None]:
def generate_cypher_raw(question, schema, timeout):
    """Sample a raw (not yet post-processed) Cypher answer from the model.

    The prompt from `prompt_cypher` is rendered with the tokenizer's chat
    template, up to 512 new tokens are sampled (top_p=0.9, temperature=0.2),
    and only the newly generated tokens are decoded.

    Args:
        question: natural-language question.
        schema: schema text inserted into the prompt.
        timeout: wall-clock limit in seconds for the whole generation.

    Returns:
        The decoded model output, "time_error" if generation exceeded
        `timeout`, or "error" if it raised.
    """
    def _generate():
        chat = prompt_cypher(question=question, schema=schema)
        prompt_text = tokenizer.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=False
        )
        # NOTE(review): the rendered chat text may already begin with <bos>; tokenizing
        # it with default special tokens could add a second one — confirm.
        encoded = tokenizer(prompt_text, return_tensors="pt", padding=True).to(model.device)
        prompt_len = encoded.input_ids.shape[1]

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                top_p=0.9,
                temperature=0.2,
                max_new_tokens=512,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
            # Drop the echoed prompt; keep only the continuation.
            continuation = output_ids[:, prompt_len:]
            decoded = tokenizer.batch_decode(continuation, skip_special_tokens=True)

        return decoded[0]

    try:
        return func_timeout(timeout, _generate)
    except FunctionTimedOut:
        print(f"[TIMEOUT] Generation exceeded {timeout}s")
        return "time_error"
    except Exception as e:
        print(f"[ERROR] Generation failed: {e}")
        return "error"

print("✓ generate_cypher_raw loaded")

✓ generate_cypher_raw loaded


**Format cypher**

In [None]:
def postprocess_cypher(output_cypher: str) -> str:
    """Strip explanation text and Markdown code fences from model output.

    Typical input is a fenced ```cypher block, optionally followed by an
    "**Explanation:**" section; the result is just the query text.

    Args:
        output_cypher: raw decoded model output.

    Returns:
        The query with any "**Explanation:**" tail, surrounding backticks /
        newlines / spaces and a leading "cypher" fence tag removed.
    """
    # Drop everything from the explanation marker onward.
    output_cypher, _, _ = output_cypher.partition("**Explanation:**")

    # Remove the surrounding fence (backticks, newlines, spaces).
    output_cypher = output_cypher.strip("`\n ")

    # Remove the fence language tag as a literal prefix. str.lstrip("cypher\n")
    # strips any leading run of those *characters* and would truncate lowercase
    # queries (e.g. "cypher\nreturn 1" -> "turn 1").
    if output_cypher.startswith("cypher"):
        output_cypher = output_cypher[len("cypher"):]

    return output_cypher.strip("`\n ")

print("✓ postprocess_cypher loaded")

✓ postprocess_cypher loaded


In [None]:
def extract_cypher(text):
    """Turn raw model output into a Cypher query, or a failure sentinel.

    The sentinels "time_error" and "error" pass through unchanged. Any other
    text is cleaned with `postprocess_cypher`; the cleaned text is rejected
    ("error") when it is blank or does not contain the substring "match"
    (case-insensitive). Exceptions while cleaning are printed and mapped to
    "error".

    Returns:
        The stripped query, or "error" / "time_error".
    """
    if text in ["time_error", "error"]:
        return text

    try:
        candidate = postprocess_cypher(text).strip()
        if not candidate or "match" not in candidate.lower():
            return "error"
        return candidate
    except Exception as e:
        print(f"[ERROR] Extract failed: {e}")
        return "error"

print("✓ extract_cypher loaded")

✓ extract_cypher loaded


**Generate cypher**

In [None]:
def generate_cypher(question, schema, timeout=300):
    """Generate one cleaned Cypher query for `question`.

    Raw output comes from `generate_cypher_raw` (bounded by `timeout`
    seconds) and is cleaned by `extract_cypher`, so failures surface as the
    "error" / "time_error" sentinels.
    """
    return extract_cypher(generate_cypher_raw(question, schema, timeout))

print("✓ generate_cypher loaded")


**Self-correction prompt**

In [None]:
def prompt_correction(schema, question, cypher_current, error) -> list[dict]:
    """Build the chat prompt asking the model to repair a failing Cypher query.

    A step-by-step hint is appended when the error text matches a known
    category. Specific categories are checked first (unknown label/property,
    aliasing, query structure). The generic syntax-error hint is checked last.
    Neo4j appears to report most compile errors — including "must be aliased"
    and "cannot conclude with WITH" — under a SyntaxError status
    (e.g. "Neo.ClientError.Statement.SyntaxError"), so checking it first
    would shadow the specific hints.

    Args:
        schema: schema text shown to the model.
        question: the original natural-language question.
        cypher_current: the query that failed.
        error: the error message for `cypher_current` (None is tolerated).

    Returns:
        A one-element chat list with a single "user" message.
    """
    # AUTO-HINT MECHANISM - GENERAL INSTRUCTIONS WITH COT
    additional_hint = ""
    lowered = str(error or "").lower()

    # Unknown / missing properties or labels
    if "unknown" in lowered or "does not exist" in lowered or "not found" in lowered:
        additional_hint = """
CRITICAL: Unknown label/property detected. Follow these steps to fix:

Step 1 - IDENTIFY what is missing:
- Check if you're using a label that doesn't exist
- Check if you're accessing a property that doesn't exist
- Verify the entity type: is it a node property or relationship property?

Step 2 - VERIFY against schema:
- Node properties are listed under "Nodes" section
- Relationship properties are listed under "Relationships" section
- Labels use colon syntax (:Label), properties use dot syntax (variable.property)

Step 3 - FIX the query:
- Use correct property name from schema
- Access property from correct entity (node vs relationship)
- Replace incorrect label syntax with property access

Example of common mistakes:
WRONG: MATCH (n) WHERE n:`propertyName` = value
CORRECT: MATCH (n) WHERE n.propertyName = value

WRONG: Using relationship variable for node property: AVG(rel.nodeProperty)
CORRECT: Using node variable for node property: AVG(node.nodeProperty)
"""

    # Expression aliasing errors
    elif "must be aliased" in lowered or "alias" in lowered:
        additional_hint = """
CRITICAL: Aliasing error detected. Follow these steps to fix:

Step 1 - UNDERSTAND aliasing rules:
- In WITH clause, expressions must be aliased using AS
- Property expressions cannot be used directly without aliasing
- Aggregation results must be aliased

Step 2 - TWO APPROACHES to fix:
Approach A - Alias each property:
  WITH node.prop1 AS prop1, node.prop2 AS prop2, AGG(...) AS result

Approach B - Use node variable:
  WITH node, AGG(...) AS result
  RETURN node.prop1, node.prop2, result

Step 3 - CHOOSE the simpler approach:
- If you need many properties: use Approach B (pass entire node)
- If you need few properties: use Approach A (alias each one)
"""

    # Pattern / query-structure errors
    elif "pattern" in lowered or "cannot conclude" in lowered:
        additional_hint = """
CRITICAL: Query structure error detected. Follow these steps to fix:

Step 1 - CHECK query structure:
- Does the query end with a proper clause? (RETURN, CREATE, DELETE, etc.)
- Are pattern expressions used correctly?
- Is SIZE() used with pattern comprehension?

Step 2 - COMMON STRUCTURE RULES:
- Query MUST end with RETURN (or update clause)
- WITH ... ORDER BY ... LIMIT MUST be followed by RETURN
- Pattern in SIZE() needs comprehension: SIZE([(pattern) | var]) not SIZE((pattern))
- Pattern existence check: use EXISTS { (pattern) } in WHERE

Step 3 - FIX based on rule violated:
- Add RETURN clause if missing
- Wrap pattern in comprehension for SIZE()
- Change to EXISTS if only checking presence

Example of common mistakes:
WRONG: WITH node, COUNT(*) AS count ORDER BY count LIMIT 10
CORRECT: WITH node, COUNT(*) AS count ORDER BY count LIMIT 10 RETURN node, count

WRONG: WHERE SIZE((a)-[:REL]->(b)) > 5
CORRECT: WHERE SIZE([(a)-[:REL]->(b) | a]) > 5
"""

    # Generic syntax errors (checked last: also matches the SyntaxError status code)
    elif "syntax error" in lowered or "invalid syntax" in lowered or "syntaxerror" in lowered:
        additional_hint = """
CRITICAL: Syntax error detected. Follow these steps to fix:

Step 1 - ANALYZE the error message:
- Identify which keyword/token is causing the error
- Check the position (line, column) where error occurs
- Understand what the parser expected at that position

Step 2 - COMMON SYNTAX RULES in Cypher:
- WHERE clause MUST come BEFORE RETURN, never after RETURN
- If you need multiple conditions, combine them with AND/OR in the same WHERE clause
- WITH clause cannot be the last clause - must be followed by RETURN or another clause
- RETURN must be the final clause (unless using UNION or other set operations)
- Cypher does NOT support GROUP BY - use aggregation functions in WITH instead
- Pattern expressions in WHERE must use pattern comprehension: SIZE([(pattern) | var])

Step 3 - FIX the query:
- Move misplaced clauses to correct position
- Combine multiple WHERE clauses into one
- Add missing RETURN if query ends with WITH
- Remove unsupported SQL syntax (GROUP BY, HAVING, etc.)

Example of common mistakes:
WRONG: MATCH (n) WHERE condition1 RETURN n WHERE condition2
CORRECT: MATCH (n) WHERE condition1 AND condition2 RETURN n

WRONG: WITH n.property, COUNT(*) AS count
CORRECT: WITH n.property AS property, COUNT(*) AS count
"""

    # BUILD CORRECTION TEMPLATE WITH COT HINT
    correction_content = (
        "You are an expert at fixing Cypher queries. Analyze the error carefully and fix the query.\n\n"
        f"Schema:\n{schema}\n\n"
        f"Original Question: {question}\n\n"
        f"Wrong Cypher Query:\n{cypher_current}\n\n"
        f"Error Message:\n{error}\n"
    )

    # Append the hint when one matched
    if additional_hint:
        correction_content += f"\n{additional_hint}\n"

    correction_content += "\nFollow the steps above to analyze and fix the query. Return ONLY the corrected Cypher statement, no explanations.\n\nCorrected Cypher output: "

    return [
        {
            "role": "user",
            "content": correction_content
        }
    ]

print("✓ prompt_correction loaded")

✓ prompt_correction loaded


**Neo4j Execution**

In [None]:
def execute_cypher(cypher_query, alias, timeout=180):
    """Run `cypher_query` against the database for `alias`.

    The query is wrapped in `neo4j.Query` so that `timeout` (seconds) is sent
    as the transaction timeout. Passing `timeout=` straight to `Session.run`
    would only bind a Cypher parameter named `$timeout`. The result is
    consumed (records discarded), so errors raised while the query streams
    results are reported, not just compile-time errors.

    Args:
        cypher_query: query text; the sentinels "error" / "time_error", None
            and "" are rejected without contacting the server.
        alias: database alias used to pick the cached driver.
        timeout: transaction timeout in seconds.

    Returns:
        (True, None) on success, or (False, message) on failure. An
        authentication failure also resets the cached driver for `alias`.
    """
    if cypher_query in ["error", "time_error", None, ""]:
        return (False, "Invalid cypher query")

    driver = get_driver(alias)

    try:
        with driver.session() as session:
            session.run(Query(cypher_query, timeout=timeout)).consume()
            return (True, None)
    except Exception as e:
        error_msg = str(e)
        # On authentication problems, rebuild the driver for the next attempt.
        if "authentication" in error_msg.lower() or "unauthorized" in error_msg.lower():
            reset_driver(alias)
        return (False, error_msg)

print("✓ execute_cypher loaded")

✓ execute_cypher loaded


In [None]:
def explain_cypher(cypher_query, driver, timeout=180):
    """Check that `cypher_query` compiles by running it under EXPLAIN.

    EXPLAIN only plans the query, so this validates syntax without executing
    it. As in `execute_cypher`, sentinel / empty queries are rejected up
    front. `timeout` is applied as a real transaction timeout via
    `neo4j.Query`.

    Args:
        cypher_query: query text to validate.
        driver: an open Neo4j driver.
        timeout: transaction timeout in seconds.

    Returns:
        (True, None) if the plan was produced, else (False, message).
    """
    if cypher_query in ["error", "time_error", None, ""]:
        return (False, "Invalid cypher query")

    try:
        with driver.session() as session:
            explain_query = f"EXPLAIN {cypher_query}"
            session.run(Query(explain_query, timeout=timeout)).consume()
            return (True, None)
    except Exception as e:
        return (False, str(e))

print("✓ explain_cypher loaded")

✓ explain_cypher loaded


**Self-correction: ask the model to repair a failing query**

In [None]:
def llm_correct_cypher(schema, question, cypher_current, error, timeout=300):
    """Ask the model to repair a failing Cypher query.

    Builds the prompt with `prompt_correction`, samples up to 512 new tokens
    (top_p=0.9, temperature=0.1) and cleans the answer with `extract_cypher`.

    Args:
        schema: schema text shown to the model.
        question: the original natural-language question.
        cypher_current: the query that failed.
        error: the error message for `cypher_current`.
        timeout: wall-clock limit in seconds for the correction.

    Returns:
        The output of `extract_cypher` (a query or "error"), "time_error" if
        the correction exceeded `timeout`, or "error" if it raised. Failures
        are printed, like in `generate_cypher_raw`, instead of being
        swallowed silently.
    """
    def _correct():
        messages = prompt_correction(
            schema=schema,
            question=question,
            cypher_current=cypher_current,
            error=error
        )

        prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False
        )

        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device)

        model_generate_parameters = {
            "top_p": 0.9,
            "temperature": 0.1,
            "max_new_tokens": 512,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        with torch.no_grad():
            tokens = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                **model_generate_parameters
            )
            # Keep only the newly generated tokens (drop the echoed prompt).
            tokens = tokens[:, inputs.input_ids.shape[1]:]
            raw_output = tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]

        return extract_cypher(raw_output)

    try:
        return func_timeout(timeout, _correct)
    except FunctionTimedOut:
        print(f"[TIMEOUT] Correction exceeded {timeout}s")
        return "time_error"
    except Exception as e:
        print(f"[ERROR] Correction failed: {e}")
        return "error"

**Hàm loop**

In [None]:
def cypher_self_correction(
    cypher_initial,
    alias,
    schema_context,
    question,
    max_retries=3,
    timeout=900
):
    """Execute a query on `alias`, asking the LLM to repair it after each failure.

    At most `max_retries` corrections are requested, and every corrected
    query is executed before giving up. The previous loop generated a final
    correction that was never run. Because of that, `success` now always
    describes the returned `final_cypher`. With `max_retries=0` the initial
    query is still executed once.

    Args:
        cypher_initial: first query to try.
        alias: database alias passed to `execute_cypher`.
        schema_context: schema text for the correction prompt.
        question: original natural-language question.
        max_retries: maximum number of LLM corrections.
        timeout: per-execution timeout in seconds for `execute_cypher`.

    Returns:
        dict with "success", "final_cypher", "retries" (corrections applied)
        and "errors" (one "Retry i: <message>" entry per failed execution).
    """
    cypher_current = cypher_initial
    retry = 0
    errors_history = []

    while True:
        success, error = execute_cypher(cypher_current, alias, timeout=timeout)

        if success:
            return {
                "success": True,
                "final_cypher": cypher_current,
                "retries": retry,
                "errors": errors_history
            }

        errors_history.append(f"Retry {retry}: {error}")

        # Correction budget exhausted: the query just executed was the last one allowed.
        if retry >= max_retries:
            break

        cypher_corrected = llm_correct_cypher(
            schema=schema_context,
            question=question,
            cypher_current=cypher_current,
            error=error
        )

        # The model produced nothing usable; keep the last executed query.
        if cypher_corrected in ["error", "time_error", None, ""]:
            break

        cypher_current = cypher_corrected
        retry += 1

    return {
        "success": False,
        "final_cypher": cypher_current,
        "retries": retry,
        "errors": errors_history
    }

In [None]:
def cypher_explain_correction(
    cypher_initial,
    driver,
    schema_context,
    question,
    max_retries=3,
    timeout=900
):
    """Validate a query with EXPLAIN, asking the LLM to repair it after each failure.

    This is the same loop as `cypher_self_correction`, but it only checks
    that the query compiles (`explain_cypher`) instead of executing it. Every
    corrected query, including the last one allowed by `max_retries`, is
    validated. `success` therefore always describes the returned
    `final_cypher`.

    Args:
        cypher_initial: first query to validate.
        driver: open Neo4j driver used for EXPLAIN.
        schema_context: schema text for the correction prompt.
        question: original natural-language question.
        max_retries: maximum number of LLM corrections.
        timeout: per-EXPLAIN timeout in seconds.

    Returns:
        dict with "success", "final_cypher", "retries" (corrections applied)
        and "errors" (one "Retry i: <message>" entry per failed validation).
    """
    cypher_current = cypher_initial
    retry = 0
    errors_history = []

    while True:
        success, error = explain_cypher(cypher_current, driver, timeout=timeout)

        if success:
            return {
                "success": True,
                "final_cypher": cypher_current,
                "retries": retry,
                "errors": errors_history
            }

        errors_history.append(f"Retry {retry}: {error}")

        # Correction budget exhausted: the query just validated was the last one allowed.
        if retry >= max_retries:
            break

        cypher_corrected = llm_correct_cypher(
            schema=schema_context,
            question=question,
            cypher_current=cypher_current,
            error=error
        )

        # The model produced nothing usable; keep the last validated query.
        if cypher_corrected in ["error", "time_error", None, ""]:
            break

        cypher_current = cypher_corrected
        retry += 1

    return {
        "success": False,
        "final_cypher": cypher_current,
        "retries": retry,
        "errors": errors_history
    }

**Luồng toàn cục**

In [None]:
def _correction_failure(schema, sentinel, message):
    """Build the result dict for a run that produced no usable query.

    `sentinel` ("error" or "time_error") is reported as both the initial and
    the final cypher, so callers can count timeouts separately.
    """
    return {
        "cypher_initial": sentinel,
        "schema": schema,
        "correction_result": None,
        "final_cypher": sentinel,
        "success": False,
        "retries": 0,
        "errors": [message]
    }


def generate_cypher_with_correction(
    question,
    schema,
    alias,
    max_retries=3,
    timeout=1200,
    neo4j_timeout=180
):
    """Generate a Cypher query for `question`, then validate and repair it.

    With a usable `alias`, the query is executed on that database by
    `cypher_self_correction`. Without one (None / NaN / blank), it is only
    compiled with EXPLAIN by `cypher_explain_correction`. That check uses the
    first already-opened driver, or the first alias in DATABASE_ALIASES. The
    whole pipeline is bounded by `timeout` seconds; each Neo4j call is
    bounded by `neo4j_timeout`.

    Returns:
        dict with "cypher_initial", "schema", "correction_result",
        "final_cypher", "success", "retries" and "errors". When no usable
        query is produced, "final_cypher" is "time_error" for timeouts
        (initial generation or the overall `timeout`) and "error" otherwise.
    """
    # Rows without a database alias cannot be executed; they only get a syntax check.
    alias_missing = (
        alias is None
        or pd.isna(alias)
        or (isinstance(alias, str) and alias.strip() == "")
    )

    def _run_pipeline():
        cypher_initial = generate_cypher(question, schema)

        # Do not feed a failure sentinel into the correction loop as if it were a query.
        if cypher_initial in ["error", "time_error", None, ""]:
            sentinel = "time_error" if cypher_initial == "time_error" else "error"
            return _correction_failure(schema, sentinel, "Failed to generate initial cypher")

        if alias_missing:
            # Any demo database is good enough to validate syntax with EXPLAIN.
            first_alias = list(DRIVERS_BY_ALIAS.keys())[0] if DRIVERS_BY_ALIAS else DATABASE_ALIASES[0]
            loop_result = cypher_explain_correction(
                cypher_initial=cypher_initial,
                driver=get_driver(first_alias),
                schema_context=schema,
                question=question,
                max_retries=max_retries,
                timeout=neo4j_timeout
            )
        else:
            loop_result = cypher_self_correction(
                cypher_initial=cypher_initial,
                alias=alias,
                schema_context=schema,
                question=question,
                max_retries=max_retries,
                timeout=neo4j_timeout
            )

        final_cypher = loop_result.get("final_cypher")
        if final_cypher is None:
            final_cypher = "error"

        return {
            "cypher_initial": cypher_initial,
            "schema": schema,
            "correction_result": loop_result,
            "final_cypher": final_cypher,
            "success": loop_result.get("success"),
            "retries": loop_result.get("retries"),
            "errors": loop_result.get("errors")
        }

    try:
        return func_timeout(timeout, _run_pipeline)
    except FunctionTimedOut:
        return _correction_failure(schema, "time_error", f"Total timeout reached after {timeout}s")
    except Exception as e:
        return _correction_failure(schema, "error", f"Unexpected error: {str(e)}")

In [None]:
# Lấy test case
first_row = test_df.iloc[0]
test_question = first_row['question']
test_schema = first_row['schema']
test_alias = first_row['database_reference_alias']

print("="*80)
print("TEST QUESTION:")
print("="*80)
print(test_question)

# Test 1: Extracted Cypher
print("\n" + "="*80)
print("TEST 1: generate_cypher() - EXTRACTED")
print("="*80)
result = generate_cypher_with_correction(
        question=test_question,
        schema=test_schema,
        alias=test_alias,
        max_retries=3,
        timeout=1200,
        neo4j_timeout=60
    )
print(f"Success: {result['success']}")
print(f"Retries: {result['retries']}")
print(f"Cypher Initial: {result['cypher_initial']}")
print(f"Final Cypher:\n{result['final_cypher']}")
if result["errors"]:
    print("Errors:")
    for e in result["errors"]:
        print(e)

# Test 2: Expected
print("\n" + "="*80)
print("TEST 2: EXPECTED CYPHER")
print("="*80)
print(first_row['cypher'])

TEST QUESTION:
Identify the 5 suppliers with the highest average unit price of products supplied.

TEST 1: generate_cypher() - EXTRACTED
Success: True
Retries: 0
Cypher Initial: MATCH (s:Supplier)<-[:SUPPLIES]-(p:Product) WITH s, avg(toFloat(p.unitPrice)) AS avgUnitPrice ORDER BY avgUnitPrice DESC LIMIT 5 RETURN s.companyName AS supplierName, avgUnitPrice
Final Cypher:
MATCH (s:Supplier)<-[:SUPPLIES]-(p:Product) WITH s, avg(toFloat(p.unitPrice)) AS avgUnitPrice ORDER BY avgUnitPrice DESC LIMIT 5 RETURN s.companyName AS supplierName, avgUnitPrice

TEST 2: EXPECTED CYPHER
MATCH (s:Supplier)-[:SUPPLIES]->(p:Product) WITH s, avg(p.unitPrice) AS avgUnitPrice ORDER BY avgUnitPrice DESC LIMIT 5 RETURN s.companyName AS Supplier, avgUnitPrice AS AverageUnitPrice


**Chạy batch**

In [None]:
def _pct(count, total):
    """Return `count` as a percentage of `total` (0.0 when `total` is 0)."""
    return count / total * 100 if total else 0.0


def _is_processed(value):
    """True if a `cypher_generated` cell already holds a result (non-null, non-blank)."""
    return pd.notna(value) and str(value).strip() != ''


def _save_checkpoint(df, checkpoint_path):
    """Write `df` to `checkpoint_path` as UTF-8 CSV, creating the parent folder if needed."""
    directory = os.path.dirname(checkpoint_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    df.to_csv(checkpoint_path, index=False, encoding='utf-8')


# Result columns added to the test frame and stored in the checkpoint.
_RESULT_COLUMNS = ['cypher_generated', 'success', 'retries', 'errors']


def _load_or_create_checkpoint(test_df, checkpoint_path):
    """Resume from `checkpoint_path` if it exists, else start a new checkpoint from `test_df`."""
    if os.path.exists(checkpoint_path):
        print(f"[CHECKPOINT] Tìm thấy file checkpoint: {checkpoint_path}")
        df = pd.read_csv(checkpoint_path, encoding="utf-8-sig")
        print(f"[CHECKPOINT] Đã load {len(df)} dòng từ checkpoint")

        for col in _RESULT_COLUMNS:
            if col not in df.columns:
                df[col] = ''
            # read_csv turns empty cells into NaN, so a column that is still
            # entirely empty comes back as float64; writing strings/bools into it
            # through df.at is an incompatible-dtype assignment. Object dtype
            # accepts every result value.
            df[col] = df[col].astype(object)

        processed_count = sum(_is_processed(v) for v in df['cypher_generated'])
        print(f"[CHECKPOINT] Đã xử lý: {processed_count}/{len(df)} dòng")
    else:
        print(f"[CHECKPOINT] Không tìm thấy checkpoint, tạo mới từ test_df")
        df = test_df.copy()
        for col in _RESULT_COLUMNS:
            df[col] = ''
        _save_checkpoint(df, checkpoint_path)
        print(f"[CHECKPOINT] Đã tạo file checkpoint: {checkpoint_path}")
    return df


def run_batch_with_correction(
    test_df,
    checkpoint_path,
    max_retries=3,
    timeout=1200,
    neo4j_timeout=60,
    log_interval=100,
    save_interval=50
):
    """Run generation + self-correction over every row, with checkpoint/resume support.

    Rows whose `cypher_generated` cell is already filled in the checkpoint are
    skipped. Results are written to `checkpoint_path` every `save_interval`
    processed rows and once more at the end. The final write also happens when
    the loop is interrupted (e.g. KeyboardInterrupt), so finished rows are kept.
    Timing statistics only count rows processed in this run, so the ETA stays
    meaningful when resuming.

    Args:
        test_df: rows with 'question', 'schema' and 'database_reference_alias'
            (indexed 0..n-1).
        checkpoint_path: CSV file used both to resume and to store results.
        max_retries, timeout, neo4j_timeout: forwarded to
            `generate_cypher_with_correction`.
        log_interval: print statistics whenever (row index + 1) is a multiple
            of this value (only after a row was processed).
        save_interval: processed rows between checkpoint writes.

    Returns:
        The DataFrame with 'cypher_generated', 'success', 'retries' and
        'errors' filled in.
    """
    # Step 1: load or create the checkpoint.
    df = _load_or_create_checkpoint(test_df, checkpoint_path)

    # Step 2: process the rows that have no result yet.
    total_rows = len(df)
    # Rows still to do at the start of this run; basis for the time estimate.
    pending_rows = sum(not _is_processed(v) for v in df['cypher_generated'])

    success_count = 0
    error_count = 0
    timeout_error_count = 0
    batch_start_idx = None

    processed_since_last_save = 0
    processed_this_run = 0

    print(f"\n{'='*80}")
    print(f"BẮT ĐẦU XỬ LÝ - Tổng số dòng: {total_rows}")
    print(f"Model: neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1")
    print(f"Self-correction: max_retries={max_retries}, neo4j_timeout={neo4j_timeout}s")
    print(f"{'='*80}\n")

    start_time = time.time()

    try:
        for idx in range(total_rows):
            if _is_processed(df.at[idx, 'cypher_generated']):
                continue

            if batch_start_idx is None:
                batch_start_idx = idx

            print(f"[Processing] Dòng {idx}...", end=" ", flush=True)

            try:
                question = df.at[idx, 'question']
                schema = df.at[idx, 'schema']
                alias = df.at[idx, 'database_reference_alias']

                result = generate_cypher_with_correction(
                    question=question,
                    schema=schema,
                    alias=alias,
                    max_retries=max_retries,
                    timeout=timeout,
                    neo4j_timeout=neo4j_timeout
                )

                df.at[idx, 'cypher_generated'] = result['final_cypher']
                df.at[idx, 'success'] = result['success']
                df.at[idx, 'retries'] = result['retries']
                df.at[idx, 'errors'] = str(result['errors'])

                cypher_result = result['final_cypher']

                if cypher_result == "error":
                    error_count += 1
                    print("ERROR")
                elif cypher_result == "time_error":
                    timeout_error_count += 1
                    print("TIMEOUT")
                else:
                    success_count += 1
                    print(f"SUCCESS (retries: {result['retries']})")

            except Exception as e:
                print(f"ERROR - {str(e)}")
                df.at[idx, 'cypher_generated'] = "error"
                df.at[idx, 'success'] = False
                df.at[idx, 'retries'] = 0
                df.at[idx, 'errors'] = str(e)
                error_count += 1

            processed_since_last_save += 1
            processed_this_run += 1

            # Periodic checkpoint.
            if processed_since_last_save >= save_interval:
                _save_checkpoint(df, checkpoint_path)
                print(f"[CHECKPOINT] Đã lưu sau {processed_since_last_save} dòng")
                processed_since_last_save = 0

            # Periodic statistics.
            if (idx + 1) % log_interval == 0:
                elapsed_time = time.time() - start_time
                # Average over rows processed in THIS run: rows skipped from the
                # checkpoint took no time, so dividing by idx + 1 would be wrong.
                avg_time_per_row = elapsed_time / processed_this_run
                remaining_rows = pending_rows - processed_this_run
                estimated_time = avg_time_per_row * remaining_rows

                print(f"\n{'='*80}")
                print(f"[LOG] Dòng {batch_start_idx}-{idx}")
                print(f"{'='*80}")
                print(f"Thành công:     {success_count}")
                print(f"Error:          {error_count}")
                print(f"Timeout Error:  {timeout_error_count}")
                print(f"Tổng xử lý:     {success_count + error_count + timeout_error_count}")
                print(f"Tiến độ:        {idx + 1}/{total_rows} ({_pct(idx + 1, total_rows):.2f}%)")
                print(f"Thời gian:      {elapsed_time/60:.2f} phút")
                print(f"Ước tính còn:   {estimated_time/60:.2f} phút")
                print(f"{'='*80}\n")

                success_count = 0
                error_count = 0
                timeout_error_count = 0
                batch_start_idx = None
    finally:
        # Final checkpoint — also runs when the loop is interrupted.
        if processed_since_last_save > 0:
            _save_checkpoint(df, checkpoint_path)
            print(f"[CHECKPOINT] Đã lưu {processed_since_last_save} dòng cuối cùng")

    # Final summary.
    total_time = time.time() - start_time

    final_success = (df['cypher_generated'].notna() &
                     (df['cypher_generated'] != 'error') &
                     (df['cypher_generated'] != 'time_error') &
                     (df['cypher_generated'] != '')).sum()
    final_error = (df['cypher_generated'] == 'error').sum()
    final_timeout = (df['cypher_generated'] == 'time_error').sum()
    avg_seconds = total_time / processed_this_run if processed_this_run else 0.0

    print(f"\n{'='*80}")
    print(f"HOÀN THÀNH")
    print(f"{'='*80}")
    print(f"Tổng số dòng:        {total_rows}")
    print(f"Thành công:          {final_success} ({_pct(final_success, total_rows):.2f}%)")
    print(f"Error:               {final_error} ({_pct(final_error, total_rows):.2f}%)")
    print(f"Timeout Error:       {final_timeout} ({_pct(final_timeout, total_rows):.2f}%)")
    print(f"Tổng thời gian:      {total_time/60:.2f} phút")
    print(f"Thời gian trung bình: {avg_seconds:.2f} giây/dòng")
    print(f"Kết quả đã lưu tại:  {checkpoint_path}")
    print(f"{'='*80}\n")

    return df

print("✓ run_batch_with_correction loaded")

✓ run_batch_with_correction loaded


In [None]:
# Full batch run. Safe to re-run: rows already filled in the checkpoint are skipped,
# so an interrupted session resumes where it stopped.
df_results = run_batch_with_correction(
    test_df=test_df,
    checkpoint_path=checkpoint_path,
    max_retries=3,
    timeout=1200,
    neo4j_timeout=60,
    log_interval=100,
    save_interval=50
)

[CHECKPOINT] Tìm thấy file checkpoint: /content/drive/MyDrive/T2C_finetune_gemma/finetune_gemma.csv
[CHECKPOINT] Đã load 4833 dòng từ checkpoint
[CHECKPOINT] Đã xử lý: 4100/4833 dòng

BẮT ĐẦU XỬ LÝ - Tổng số dòng: 4833
Model: neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1
Self-correction: max_retries=3, neo4j_timeout=60s

[Processing] Dòng 4100... SUCCESS (retries: 0)
[Processing] Dòng 4101... SUCCESS (retries: 0)
[Processing] Dòng 4102... SUCCESS (retries: 0)
[Processing] Dòng 4103... SUCCESS (retries: 0)
[Processing] Dòng 4104... SUCCESS (retries: 0)
[Processing] Dòng 4105... SUCCESS (retries: 0)
[Processing] Dòng 4106... SUCCESS (retries: 0)
[Processing] Dòng 4107... SUCCESS (retries: 0)
[Processing] Dòng 4108... SUCCESS (retries: 0)
[Processing] Dòng 4109... SUCCESS (retries: 0)
[Processing] Dòng 4110... SUCCESS (retries: 0)
[Processing] Dòng 4111... SUCCESS (retries: 0)
[Processing] Dòng 4112... SUCCESS (retries: 0)
[Processing] Dòng 4113... SUCCESS (retries: 0)
[Processing] Dòn