In [2]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [16]:
%%capture
!pip install -q func-timeout
!pip install neo4j_graphrag neo4j func-timeout

In [4]:
# Download the Ollama install script from ollama.com and run it.
# NOTE(review): this pipes a remote script straight into `sh`; inspect or pin
# the script if that is a concern for your environment.
!curl -fsSL https://ollama.com/install.sh | sh

>>> Installing ollama to /usr/local
>>> Downloading Linux amd64 bundle
######################################################################## 100.0%
>>> Creating ollama user...
>>> Adding ollama user to video group...
>>> Adding current user to ollama group...
>>> Creating ollama systemd service...
>>> The Ollama API is now available at 127.0.0.1:11434.
>>> Install complete. Run "ollama" from the command line.


In [17]:
from neo4j_graphrag.schema import get_structured_schema
from func_timeout import func_timeout, FunctionTimedOut
from neo4j.exceptions import AuthError, Neo4jError
from neo4j import GraphDatabase
from datetime import datetime
import pandas as pd
import subprocess
import requests
import torch
import time
import json
import os
import re

In [48]:
import logging
import warnings
# Only let ERROR-level (and above) records through on the 'neo4j' logger.
logging.getLogger('neo4j').setLevel(logging.ERROR)
# Hide every FutureWarning for the rest of the session.
warnings.filterwarnings('ignore', category=FutureWarning)

**Khởi động Ollama**

In [6]:
# Launch the Ollama server in the background. Its output is discarded instead
# of being piped: nothing in this notebook reads those pipes, and an unread
# pipe can fill up and stall the child process.
ollama_process = subprocess.Popen(
    ['ollama', 'serve'],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
)

print("Đang khởi động Ollama server...")

# Check the server: poll the HTTP endpoint until it answers rather than
# trusting a fixed sleep, and only report success on a 2xx/3xx response.
server_ready = False
for _ in range(30):
    try:
        response = requests.get('http://localhost:11434', timeout=2)
        if response.ok:
            server_ready = True
            break
    except requests.exceptions.RequestException:
        pass  # not accepting connections yet
    time.sleep(1)

if server_ready:
    print("✓ Ollama server đã sẵn sàng!")
else:
    print("✗ Lỗi khi khởi động server")

Đang khởi động Ollama server...
✓ Ollama server đã sẵn sàng!


**Pull model**

In [7]:
# Download the model that the generation and correction requests in this notebook name.
!ollama pull qwen2.5-coder:14b

[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2026h[?25l[A[1G[?25h[?2026l[?2

In [29]:
# New cell - check which models are available locally
!ollama list

NAME                 ID              SIZE      MODIFIED       
qwen2.5-coder:14b    9ec8897f747e    9.0 GB    26 minutes ago    


**Load data**

In [46]:
# Input dataset and the checkpoint/result CSV, both on the mounted Drive.
test_path = '/content/drive/MyDrive/T2C_qwen214b_bs_loop/qwen214b_bs_loop.csv'
checkpoint_path = '/content/drive/MyDrive/T2C_qwen214b_bs_loop/qwen214b_bs_loop_result.csv'
# utf-8-sig strips a leading BOM if the file has one.
test_df = pd.read_csv(test_path, encoding="utf-8-sig")
print(f"✓ Loaded test data shape: {test_df.shape}")

✓ Loaded test data shape: (4833, 7)


In [18]:
# Bolt URI of the Neo4j server that all aliases connect to.
URI = "neo4j+s://demo.neo4jlabs.com:7687"

# List of distinct, non-null database aliases found in the test data
unique_aliases = test_df["database_reference_alias"].dropna().unique().tolist()
DATABASE_ALIASES = unique_aliases

# Cache of Neo4j drivers keyed by alias (filled lazily by get_driver)
DRIVERS_BY_ALIAS = {}

In [19]:
def extract_alias(alias: str):
    """Derive the credentials encoded in a demo database alias.

    Every occurrence of the ``neo4jlabs_demo_db_`` marker is removed from
    ``alias``; the remainder is returned twice, as ``(username, password)``.
    """
    marker = "neo4jlabs_demo_db_"
    credential = "".join(alias.split(marker))
    return credential, credential

def get_driver(alias):
    """Return the driver cached for ``alias``, creating and caching it on first use.

    New drivers connect to ``URI`` with the username/password pair that
    ``extract_alias`` derives from the alias.
    """
    try:
        return DRIVERS_BY_ALIAS[alias]
    except KeyError:
        pass  # nothing cached yet: fall through and build one

    username, password = extract_alias(alias)
    new_driver = GraphDatabase.driver(URI, auth=(username, password))
    DRIVERS_BY_ALIAS[alias] = new_driver
    return new_driver

def reset_driver(alias):
    """Drop the cached driver for ``alias`` (closing it first) and return a new one.

    A failure while closing the old driver is printed and otherwise ignored;
    the cache entry is removed either way before a fresh driver is requested.
    """
    print(f"Resetting driver for alias: {alias}")

    if alias not in DRIVERS_BY_ALIAS:
        return get_driver(alias)

    try:
        DRIVERS_BY_ALIAS[alias].close()
    except Exception as close_error:
        print(f"Error closing driver: {close_error}")
    del DRIVERS_BY_ALIAS[alias]

    # Build the replacement driver
    return get_driver(alias)

**Tạo prompt**

In [10]:
def prompt(question, schema):
    """Build the chat messages asking the model to turn ``question`` into Cypher.

    Returns a two-element list of ``{"role", "content"}`` dicts: the fixed
    system instructions, then a user turn embedding ``schema`` and ``question``.
    """
    system_message = (
        "Task: Generate a Cypher statement to query a graph database. "
        "Instructions: Use only the provided relationship types and properties in the schema. "
        "Do not use any other relationship types or properties that are not provided in the schema. "
        "Do not include any explanations or apologies in your responses. "
        "Do not respond to any questions that ask anything other than constructing a Cypher statement. "
        "Do not include any text except the generated Cypher statement."
    )

    user_lines = [
        "Generate Cypher statement to query a graph database. "
        "Use only the provided relationship types and properties in the schema.",
        f"Schema: {schema}",
        f"Question: {question}",
        "Cypher output:",
    ]

    return [
        {"role": "system", "content": system_message},
        {"role": "user", "content": "\n".join(user_lines)},
    ]

print("✓ prompt loaded")

✓ prompt loaded


**Tạo cypher**

In [11]:
def generate_cypher(question, schema, timeout=300):
    """Ask the local Ollama chat endpoint for a Cypher query answering ``question``.

    Args:
        question: Natural-language question to translate.
        schema: Graph schema text embedded in the prompt.
        timeout: Seconds allowed for the HTTP request and for the whole call.

    Returns:
        The cleaned, single-line Cypher text; ``"time_error"`` when the call
        exceeds ``timeout``; ``"error"`` on a non-200 response or any other
        failure.
    """
    messages = prompt(question, schema)

    def _clean(content):
        # Step 1: keep only what is inside a ``` ``` fence, if the model used one
        fenced = re.search(r'```(?:cypher)?\s*(.*?)```', content, re.DOTALL | re.IGNORECASE)
        if fenced:
            content = fenced.group(1).strip()

        # Step 2: drop any text before the first MATCH. The search is truly
        # case-insensitive, and a leading OPTIONAL is kept so that
        # "OPTIONAL MATCH ..." is not silently turned into a plain MATCH.
        first_match = re.search(r'\b(?:OPTIONAL\s+)?MATCH\b', content, re.IGNORECASE)
        if first_match:
            content = content[first_match.start():]

        # Step 3: collapse all whitespace (newlines included) and trailing ';'
        content = re.sub(r'\s+', ' ', content.strip())
        return content.rstrip(';').strip()

    def _generate():
        # Call the Ollama API
        response = requests.post(
            'http://localhost:11434/api/chat',
            json={
                "model": "qwen2.5-coder:14b",
                "messages": messages,
                "stream": False,
                "options": {
                    "temperature": 0.1,
                    "top_p": 0.9,
                    "num_predict": 256
                }
            },
            timeout=timeout
        )

        if response.status_code != 200:
            print(f"[ERROR] API returned status {response.status_code}")
            return "error"

        return _clean(response.json()['message']['content'])

    try:
        return func_timeout(timeout, _generate)
    except FunctionTimedOut:
        print(f"[TIMEOUT] Generation exceeded {timeout}s")
        return "time_error"
    except Exception as e:
        print(f"[ERROR] Generation failed: {e}")
        return "error"

print("✓ generate_cypher loaded")

✓ generate_cypher loaded


**Dòng mẫu**

In [12]:
# Smoke test: generate Cypher for one sample row and show it next to the
# reference query. Note that, despite its name, first_row is the row at
# position 5.
first_row = test_df.iloc[5]
test_question = first_row['question']
test_schema = first_row['schema']

print("="*80)
print("TEST QUESTION:")
print("="*80)
print(test_question)

# ============================================================================
# Test 1: generate_cypher
# ============================================================================
print("\n" + "="*80)
print("TEST 1: Generated Cypher")
print("="*80)
final_cypher = generate_cypher(test_question, test_schema, timeout=300)
print(final_cypher)

# ============================================================================
# Test 2: EXPECTED CYPHER
# ============================================================================
print("\n" + "="*80)
print("TEST 2: Expected Cypher")
print("="*80)
# The reference query is only printed when the dataset has a 'cypher' column.
if 'cypher' in test_df.columns:
    print(first_row['cypher'])
else:
    print("N/A")

TEST QUESTION:
What are the district names and city populations for all districts that between 200,000 and 2,000,000 residents?

TEST 1: Generated Cypher
MATCH (d:District) WHERE d.City_Population >= 200000 AND d.City_Population <= 2000000 RETURN d.District_name, d.City_Population

TEST 2: Expected Cypher
MATCH (n:District) WHERE n.City_Population >= 200000 AND n.City_Population <= 2000000 RETURN n.District_name, n.City_Population


**Prompt Self Correction**

In [20]:
def prompt_correction(schema_context, question, cypher_current, error):
    """Build the chat messages asking the model to repair a failing Cypher query.

    The error text is inspected (case-insensitively) and, when it matches one
    of the known categories, a category-specific checklist is appended to the
    user message. Categories are tried in this order: explicit "syntax error"
    wording, unknown/missing label or property, aliasing, pattern/structure,
    and finally a generic syntax fallback for messages that only carry the
    ``SyntaxError`` status code or an ``Invalid input`` message.

    Args:
        schema_context: Graph schema text shown to the model.
        question: The original natural-language question.
        cypher_current: The query that failed.
        error: The error message returned for ``cypher_current``; it is
            converted with ``str()`` before matching.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts (system, user).
    """
    system_message = "You are an expert at fixing Cypher queries. Analyze the error carefully and fix the query."

    # Syntax Error
    syntax_hint = """
CRITICAL: Syntax error detected. Follow these steps to fix:

Step 1 - ANALYZE the error message:
- Identify which keyword/token is causing the error
- Check the position (line, column) where error occurs
- Understand what the parser expected at that position

Step 2 - COMMON SYNTAX RULES in Cypher:
- WHERE clause MUST come BEFORE RETURN, never after RETURN
- If you need multiple conditions, combine them with AND/OR in the same WHERE clause
- WITH clause cannot be the last clause - must be followed by RETURN or another clause
- RETURN must be the final clause (unless using UNION or other set operations)
- Cypher does NOT support GROUP BY - use aggregation functions in WITH instead
- Pattern expressions in WHERE must use pattern comprehension: SIZE([(pattern) | var])

Step 3 - FIX the query:
- Move misplaced clauses to correct position
- Combine multiple WHERE clauses into one
- Add missing RETURN if query ends with WITH
- Remove unsupported SQL syntax (GROUP BY, HAVING, etc.)

Example of common mistakes:
WRONG: MATCH (n) WHERE condition1 RETURN n WHERE condition2
CORRECT: MATCH (n) WHERE condition1 AND condition2 RETURN n

WRONG: WITH n.property, COUNT(*) AS count
CORRECT: WITH n.property AS property, COUNT(*) AS count
"""

    # Unknown/Missing properties or labels
    unknown_hint = """
CRITICAL: Unknown label/property detected. Follow these steps to fix:

Step 1 - IDENTIFY what is missing:
- Check if you're using a label that doesn't exist
- Check if you're accessing a property that doesn't exist
- Verify the entity type: is it a node property or relationship property?

Step 2 - VERIFY against schema:
- Node properties are listed under "Nodes" section
- Relationship properties are listed under "Relationships" section
- Labels use colon syntax (:Label), properties use dot syntax (variable.property)

Step 3 - FIX the query:
- Use correct property name from schema
- Access property from correct entity (node vs relationship)
- Replace incorrect label syntax with property access

Example of common mistakes:
WRONG: MATCH (n) WHERE n:`propertyName` = value
CORRECT: MATCH (n) WHERE n.propertyName = value

WRONG: Using relationship variable for node property: AVG(rel.nodeProperty)
CORRECT: Using node variable for node property: AVG(node.nodeProperty)
"""

    # Expression/Aliasing error
    alias_hint = """
CRITICAL: Aliasing error detected. Follow these steps to fix:

Step 1 - UNDERSTAND aliasing rules:
- In WITH clause, expressions must be aliased using AS
- Property expressions cannot be used directly without aliasing
- Aggregation results must be aliased

Step 2 - TWO APPROACHES to fix:
Approach A - Alias each property:
  WITH node.prop1 AS prop1, node.prop2 AS prop2, AGG(...) AS result

Approach B - Use node variable:
  WITH node, AGG(...) AS result
  RETURN node.prop1, node.prop2, result

Step 3 - CHOOSE the simpler approach:
- If you need many properties: use Approach B (pass entire node)
- If you need few properties: use Approach A (alias each one)
"""

    # Pattern/Structure error
    structure_hint = """
CRITICAL: Query structure error detected. Follow these steps to fix:

Step 1 - CHECK query structure:
- Does the query end with a proper clause? (RETURN, CREATE, DELETE, etc.)
- Are pattern expressions used correctly?
- Is SIZE() used with pattern comprehension?

Step 2 - COMMON STRUCTURE RULES:
- Query MUST end with RETURN (or update clause)
- WITH ... ORDER BY ... LIMIT MUST be followed by RETURN
- Pattern in SIZE() needs comprehension: SIZE([(pattern) | var]) not SIZE((pattern))
- Pattern existence check: use EXISTS { (pattern) } in WHERE

Step 3 - FIX based on rule violated:
- Add RETURN clause if missing
- Wrap pattern in comprehension for SIZE()
- Change to EXISTS if only checking presence

Example of common mistakes:
WRONG: WITH node, COUNT(*) AS count ORDER BY count LIMIT 10
CORRECT: WITH node, COUNT(*) AS count ORDER BY count LIMIT 10 RETURN node, count

WRONG: WHERE SIZE((a)-[:REL]->(b)) > 5
CORRECT: WHERE SIZE([(a)-[:REL]->(b) | a]) > 5
"""

    # str() keeps the matching below from crashing on a non-string error value.
    error_text = str(error).lower()

    if "syntax error" in error_text or "invalid syntax" in error_text:
        additional_hint = syntax_hint
    elif "unknown" in error_text or "does not exist" in error_text or "not found" in error_text:
        additional_hint = unknown_hint
    elif "must be aliased" in error_text or "alias" in error_text:
        additional_hint = alias_hint
    elif "pattern" in error_text or "cannot conclude" in error_text:
        additional_hint = structure_hint
    elif "syntaxerror" in error_text or "invalid input" in error_text:
        # Fallback: messages may carry only the "...Statement.SyntaxError"
        # status code (no space) or an "Invalid input ..." message, which the
        # first test misses. Kept last so the more specific categories above
        # still take precedence.
        additional_hint = syntax_hint
    else:
        additional_hint = ""

    # Build user content
    user_content = f"""Fix the following Cypher query based on the error message.

Schema:
{schema_context}

Original Question: {question}

Wrong Cypher Query:
{cypher_current}

Error Message:
{error}
"""

    # Add hint if available
    if additional_hint:
        user_content += f"\n{additional_hint}\n"

    user_content += "\nFollow the steps above to analyze and fix the query. Return ONLY the corrected Cypher statement, no explanations.\n\nCorrected Cypher output:"

    # Build messages with system and user roles
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_content}
    ]

    return messages

print("✓ prompt_correction loaded")

✓ prompt_correction loaded


**Neo4j Execution**

In [14]:
def execute_cypher(cypher_query, alias, timeout=180):
    """Run ``cypher_query`` against the database for ``alias`` and report whether it was accepted.

    Args:
        cypher_query: Cypher text, or one of the failure sentinels
            (``"error"``, ``"time_error"``, ``None``, ``""``).
        alias: Database alias used to look up the driver.
        timeout: Transaction timeout in seconds.

    Returns:
        ``(True, None)`` when the query starts without raising, otherwise
        ``(False, message)``. Sentinel inputs return
        ``(False, "Invalid cypher query")`` without touching the database.
    """
    if cypher_query in ["error", "time_error", None, ""]:
        return (False, "Invalid cypher query")

    # Imported here because Query is not among the names imported at the top
    # of the notebook.
    from neo4j import Query

    driver = get_driver(alias)

    try:
        with driver.session() as session:
            # Extra keyword arguments to session.run() are sent as query
            # parameters, so a timeout has to be attached through Query.
            # The result is deliberately not consumed here.
            # NOTE(review): errors that only arise while records are streamed
            # are presumably not reported by this check - confirm.
            session.run(Query(cypher_query, timeout=timeout))
            return (True, None)
    except Exception as e:
        error_msg = str(e)
        # On an authentication failure, rebuild the driver for this alias
        if "authentication" in error_msg.lower() or "unauthorized" in error_msg.lower():
            reset_driver(alias)
        return (False, error_msg)

print("✓ execute_cypher loaded")

✓ execute_cypher loaded


In [15]:
def explain_cypher(cypher_query, driver, timeout=180):
    """Check ``cypher_query`` by running it prefixed with ``EXPLAIN``.

    Args:
        cypher_query: Cypher text, or one of the failure sentinels
            (``"error"``, ``"time_error"``, ``None``, ``""``).
        driver: Neo4j driver used to open the session.
        timeout: Transaction timeout in seconds.

    Returns:
        ``(True, None)`` when the EXPLAIN statement is accepted, otherwise
        ``(False, message)``. Sentinel inputs return
        ``(False, "Invalid cypher query")`` without touching the database.
    """
    # "time_error" is rejected too, matching the sentinel list used by
    # execute_cypher.
    if cypher_query in ["error", "time_error", None, ""]:
        return (False, "Invalid cypher query")

    # Imported here because Query is not among the names imported at the top
    # of the notebook.
    from neo4j import Query

    try:
        with driver.session() as session:
            explain_query = f"EXPLAIN {cypher_query}"
            # Extra keyword arguments to session.run() are sent as query
            # parameters, so the timeout is attached through Query instead.
            session.run(Query(explain_query, timeout=timeout))
            return (True, None)
    except Exception as e:
        return (False, str(e))

print("✓ explain_cypher loaded")

✓ explain_cypher loaded


**Setup model Self Loop**

In [22]:
def llm_correct_cypher(schema, question, cypher_current, error, timeout=300):
    """Ask the local Ollama chat endpoint to repair a failing Cypher query.

    Args:
        schema: Graph schema text shown to the model.
        question: The original natural-language question.
        cypher_current: The query that failed.
        error: The error message produced by ``cypher_current``.
        timeout: Seconds allowed for the HTTP request and for the whole call.

    Returns:
        The cleaned, single-line corrected query; ``"time_error"`` when the
        call exceeds ``timeout``; ``"error"`` on a non-200 response or any
        other failure.
    """
    messages = prompt_correction(
        schema_context=schema,
        question=question,
        cypher_current=cypher_current,
        error=error
    )

    def _clean(content):
        # Step 1: keep only what is inside a ``` ``` fence, if the model used one
        fenced = re.search(r'```(?:cypher)?\s*(.*?)```', content, re.DOTALL | re.IGNORECASE)
        if fenced:
            content = fenced.group(1).strip()

        # Step 2: drop any text before the first MATCH. The search is truly
        # case-insensitive, and a leading OPTIONAL is kept so that
        # "OPTIONAL MATCH ..." is not silently turned into a plain MATCH.
        first_match = re.search(r'\b(?:OPTIONAL\s+)?MATCH\b', content, re.IGNORECASE)
        if first_match:
            content = content[first_match.start():]

        # Step 3: collapse all whitespace (newlines included) and trailing ';'
        content = re.sub(r'\s+', ' ', content.strip())
        return content.rstrip(';').strip()

    def _correct():
        # Call the Ollama API
        response = requests.post(
            'http://localhost:11434/api/chat',
            json={
                "model": "qwen2.5-coder:14b",
                "messages": messages,
                "stream": False,
                "options": {
                    "temperature": 0.1,
                    "top_p": 0.9,
                    "num_predict": 256
                }
            },
            timeout=timeout
        )

        if response.status_code != 200:
            print(f"[ERROR] API returned status {response.status_code}")
            return "error"

        return _clean(response.json()['message']['content'])

    try:
        return func_timeout(timeout, _correct)
    except FunctionTimedOut:
        print(f"[TIMEOUT] Correction exceeded {timeout}s")
        return "time_error"
    except Exception as e:
        print(f"[ERROR] Correction failed: {e}")
        return "error"

print("✓ llm_correct_cypher loaded")

✓ llm_correct_cypher loaded


In [23]:
def cypher_self_correction(
    cypher_initial,
    alias,
    schema_context,
    question,
    max_retries=3,
    timeout=900
):
    """Execute a query and let the LLM repair it until it runs or retries run out.

    The initial query and every corrected query are each executed, so up to
    ``max_retries + 1`` executions happen for at most ``max_retries``
    corrections. (Previously the last correction was produced but never
    executed, and was therefore always reported as a failure.)

    Args:
        cypher_initial: Query to start from.
        alias: Database alias passed to ``execute_cypher``.
        schema_context: Schema text passed to the correction prompt.
        question: Original natural-language question.
        max_retries: Maximum number of LLM corrections to apply.
        timeout: Timeout forwarded to ``execute_cypher``.

    Returns:
        Dict with ``success`` (bool), ``final_cypher`` (last query tried),
        ``retries`` (corrections applied) and ``errors`` (one entry per
        failed execution).
    """
    cypher_current = cypher_initial
    retry = 0
    errors_history = []

    while True:
        success, error = execute_cypher(cypher_current, alias, timeout=timeout)

        if success:
            return {
                "success": True,
                "final_cypher": cypher_current,
                "retries": retry,
                "errors": errors_history
            }

        errors_history.append(f"Retry {retry}: {error}")

        # Out of correction budget: stop without a pointless extra LLM call.
        if retry >= max_retries:
            break

        cypher_corrected = llm_correct_cypher(
            schema=schema_context,
            question=question,
            cypher_current=cypher_current,
            error=error
        )

        # The correction step itself failed; keep the last real query.
        if cypher_corrected in ["error", "time_error", None, ""]:
            break

        cypher_current = cypher_corrected
        retry += 1

    return {
        "success": False,
        "final_cypher": cypher_current,
        "retries": retry,
        "errors": errors_history
    }

In [24]:
def cypher_explain_correction(
    cypher_initial,
    driver,
    schema_context,
    question,
    max_retries=3,
    timeout=900
):
    """Validate a query with EXPLAIN and let the LLM repair it until it passes.

    The initial query and every corrected query are each validated, so up to
    ``max_retries + 1`` validations happen for at most ``max_retries``
    corrections. (Previously the last correction was produced but never
    validated, and was therefore always reported as a failure.)

    Args:
        cypher_initial: Query to start from.
        driver: Neo4j driver passed to ``explain_cypher``.
        schema_context: Schema text passed to the correction prompt.
        question: Original natural-language question.
        max_retries: Maximum number of LLM corrections to apply.
        timeout: Timeout forwarded to ``explain_cypher``.

    Returns:
        Dict with ``success`` (bool), ``final_cypher`` (last query tried),
        ``retries`` (corrections applied) and ``errors`` (one entry per
        failed validation).
    """
    cypher_current = cypher_initial
    retry = 0
    errors_history = []

    while True:
        success, error = explain_cypher(cypher_current, driver, timeout=timeout)

        if success:
            return {
                "success": True,
                "final_cypher": cypher_current,
                "retries": retry,
                "errors": errors_history
            }

        errors_history.append(f"Retry {retry}: {error}")

        # Out of correction budget: stop without a pointless extra LLM call.
        if retry >= max_retries:
            break

        cypher_corrected = llm_correct_cypher(
            schema=schema_context,
            question=question,
            cypher_current=cypher_current,
            error=error
        )

        # The correction step itself failed; keep the last real query.
        if cypher_corrected in ["error", "time_error", None, ""]:
            break

        cypher_current = cypher_corrected
        retry += 1

    return {
        "success": False,
        "final_cypher": cypher_current,
        "retries": retry,
        "errors": errors_history
    }

In [25]:
def _failed_correction_result(cypher_initial, schema, message):
    """Build the result dict used whenever no correction outcome is available."""
    return {
        "cypher_initial": cypher_initial,
        "schema": schema,
        "correction_result": None,
        "final_cypher": "error",
        "success": False,
        "retries": 0,
        "errors": [message]
    }


def generate_cypher_with_correction(
    cypher_initial,
    question,
    schema,
    alias,
    max_retries=3,
    timeout=1200,
    neo4j_timeout=180
):
    """Run the correction loop for one query under an overall time limit.

    With a usable ``alias`` the query is executed for real through
    ``cypher_self_correction``. With a missing or blank alias it is only
    validated with EXPLAIN through ``cypher_explain_correction``, using the
    first already-cached driver (or the first known alias if none is cached).

    Args:
        cypher_initial: Query to start from; NaN, ``None``, ``""``,
            ``"error"`` and ``"time_error"`` are rejected immediately.
        question: Original natural-language question.
        schema: Schema text for the correction prompt.
        alias: Database alias, possibly missing.
        max_retries: Correction budget passed to the loop.
        timeout: Overall limit in seconds for the whole loop.
        neo4j_timeout: Per-query timeout passed to the loop.

    Returns:
        Dict with keys ``cypher_initial``, ``schema``, ``correction_result``,
        ``final_cypher``, ``success``, ``retries`` and ``errors``. On invalid
        input, overall timeout or an unexpected exception,
        ``correction_result`` is ``None`` and ``final_cypher`` is ``"error"``.
    """
    # Validate cypher_initial
    if pd.isna(cypher_initial) or cypher_initial is None or cypher_initial == "" or cypher_initial in ["error", "time_error"]:
        return _failed_correction_result(cypher_initial, schema, "Invalid cypher_initial value")

    # Decide up front which validation strategy applies
    alias_missing = pd.isna(alias) or alias is None or alias == "" or (isinstance(alias, str) and alias.strip() == "")

    def _run_correction():
        if alias_missing:
            # No usable alias: borrow a driver just to validate syntax with EXPLAIN
            first_alias = list(DRIVERS_BY_ALIAS.keys())[0] if DRIVERS_BY_ALIAS else DATABASE_ALIASES[0]
            driver = get_driver(first_alias)

            outcome = cypher_explain_correction(
                cypher_initial=cypher_initial,
                driver=driver,
                schema_context=schema,
                question=question,
                max_retries=max_retries,
                timeout=neo4j_timeout
            )
        else:
            # Valid alias: run the real execution-based self-correction loop
            outcome = cypher_self_correction(
                cypher_initial=cypher_initial,
                alias=alias,
                schema_context=schema,
                question=question,
                max_retries=max_retries,
                timeout=neo4j_timeout
            )

        final_cypher = outcome.get("final_cypher")
        if final_cypher is None:
            final_cypher = "error"

        return {
            "cypher_initial": cypher_initial,
            "schema": schema,
            "correction_result": outcome,
            "final_cypher": final_cypher,
            "success": outcome.get("success"),
            "retries": outcome.get("retries"),
            "errors": outcome.get("errors")
        }

    try:
        return func_timeout(timeout, _run_correction)
    except FunctionTimedOut:
        return _failed_correction_result(
            cypher_initial, schema, f"Total timeout reached after {timeout}s"
        )
    except Exception as e:
        return _failed_correction_result(
            cypher_initial, schema, f"Unexpected error: {str(e)}"
        )

print("✓ generate_cypher_with_correction loaded")

✓ generate_cypher_with_correction loaded


In [34]:
# Lấy test case
# Fetch the test case (row position 1124, despite the variable name first_row)
first_row = test_df.iloc[1124]
test_question = first_row['question']
test_schema = first_row['schema']
test_alias = first_row['database_reference_alias']
test_cypher = first_row['cypher_generated']

# Print the test case

print("="*80)
print("TEST QUESTION:")
print("="*80)
print(test_question)

# Test 1: run the correction loop on the previously generated query
print("\n" + "="*80)
print("TEST 1: generate_cypher() - EXTRACTED")
print("="*80)
result = generate_cypher_with_correction(
        cypher_initial = test_cypher,
        question=test_question,
        schema=test_schema,
        alias=test_alias,
        max_retries=3,
        timeout=1200,
        neo4j_timeout=180
    )
print(f"Success: {result['success']}")
print(f"Retries: {result['retries']}")
print(f"Cypher Initial:\n{result['cypher_initial']}")
print(f"\nFinal Cypher:\n{result['final_cypher']}")
# One line per failed attempt, if any were recorded
if result["errors"]:
    print("Errors:")
    for e in result["errors"]:
        print(e)

# Test 2: Expected
print("\n" + "="*80)
print("TEST 2: EXPECTED CYPHER")
print("="*80)
print(first_row['cypher'])

TEST QUESTION:
Show the first 3 businesses that have reviews mentioning 'sandwich'.

TEST 1: generate_cypher() - EXTRACTED
Success: True
Retries: 1
Cypher Initial:
MATCH (b:Business)-[r:REVIEWS]-(rev:Review {text: CONTAINS('sandwich')}) RETURN b.name, r.stars LIMIT 3

Final Cypher:
MATCH (b:Business)-[r:REVIEWS]-(rev:Review {text: 'sandwich'}) RETURN b.name, r.stars LIMIT 3
Errors:
Retry 0: {code: Neo.ClientError.Statement.SyntaxError} {message: Unknown function 'CONTAINS' (line 1, column 51 (offset: 50))
"MATCH (b:Business)-[r:REVIEWS]-(rev:Review {text: CONTAINS('sandwich')}) RETURN b.name, r.stars LIMIT 3"
                                                   ^}

TEST 2: EXPECTED CYPHER
MATCH (b:Business)<-[:REVIEWS]-(r:Review) WHERE r.text CONTAINS 'sandwich' RETURN b.name, b.address, b.city, b.state LIMIT 3


**Reset server mỗi 250 dòng**

In [35]:
def restart_ollama_server():
    """Stop the background ``ollama serve`` process and start a new one.

    Rebinds the module-level ``ollama_process`` to the new process. The old
    process is terminated gracefully (10 s grace period) and force-killed if
    that fails. After starting the new process, the HTTP endpoint is polled
    for up to about 30 seconds until it answers.
    """
    global ollama_process

    print("[RESTART] Đang restart Ollama server...")

    # Stop the old process
    try:
        ollama_process.terminate()
        ollama_process.wait(timeout=10)
    except Exception:
        # Graceful shutdown failed or timed out: force-kill and reap the
        # child. Catching Exception (not a bare except) lets Ctrl-C through.
        ollama_process.kill()
        ollama_process.wait()

    time.sleep(3)

    # Start again. Output is discarded instead of piped: nothing reads those
    # pipes, and an unread pipe can fill up and stall the server.
    ollama_process = subprocess.Popen(
        ['ollama', 'serve'],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    # Wait until the endpoint actually answers instead of a fixed sleep.
    for _ in range(30):
        try:
            if requests.get('http://localhost:11434', timeout=2).ok:
                break
        except requests.exceptions.RequestException:
            pass  # not accepting connections yet
        time.sleep(1)
    else:
        print("[RESTART] Cảnh báo: Ollama server chưa phản hồi")
        return

    print("[RESTART] Ollama server đã khởi động lại!")

In [43]:
def _pending_rows(results):
    """Boolean mask of rows whose stored query is missing or blank."""
    return results.isna() | (results.astype(str).str.strip() == '')


def _count_outcomes(results):
    """Return ``(success, error, timeout)`` counts for a Series of stored queries."""
    error_count = (results == 'error').sum()
    timeout_count = (results == 'time_error').sum()
    success_count = (
        ~_pending_rows(results) &
        (results != 'error') &
        (results != 'time_error')
    ).sum()
    return success_count, error_count, timeout_count


def run_batch(
    test_df,
    checkpoint_path,
    max_retries=3,
    timeout=1200,
    neo4j_timeout=180,
    log_interval=50,
    restart_interval=250
):
    """Run the correction loop over every unprocessed row, with checkpointing.

    Results go into the ``cypher_generated2`` column. Rows that already hold a
    non-blank value are skipped, so an existing checkpoint file is resumed.
    The frame is written to ``checkpoint_path`` every ``log_interval``
    processed rows, before every Ollama restart (every ``restart_interval``
    processed rows), and also when the loop is interrupted or fails, so work
    done since the last checkpoint is not lost.

    Args:
        test_df: Source rows; used only when no checkpoint file exists.
        checkpoint_path: CSV file used to resume and to persist results.
        max_retries: Correction budget per row.
        timeout: Overall per-row limit in seconds.
        neo4j_timeout: Per-query timeout in seconds.
        log_interval: Processed rows between progress logs/checkpoints.
        restart_interval: Processed rows between Ollama server restarts.

    Returns:
        The DataFrame including the ``cypher_generated2`` column.
    """
    # ==========================================================================
    # STEP 1: Look for a checkpoint and load it
    if os.path.exists(checkpoint_path):
        print(f"[CHECKPOINT] Tìm thấy file checkpoint: {checkpoint_path}")
        df = pd.read_csv(checkpoint_path, encoding="utf-8-sig")
        print(f"[CHECKPOINT] Đã load {len(df)} dòng từ checkpoint")

        # A checkpoint without the result column is treated as unprocessed
        # instead of failing with a KeyError.
        if 'cypher_generated2' not in df.columns:
            df['cypher_generated2'] = ''

        processed_count = (~_pending_rows(df['cypher_generated2'])).sum()
        print(f"[CHECKPOINT] Đã xử lý: {processed_count}/{len(df)} dòng")

    else:
        print(f"[CHECKPOINT] Không tìm thấy checkpoint, tạo mới từ test_df")
        df = test_df.copy()
        df['cypher_generated2'] = ''

        df.to_csv(checkpoint_path, index=False, encoding='utf-8-sig')
        print(f"[CHECKPOINT] Đã tạo file checkpoint: {checkpoint_path}")

    # An all-empty column comes back from CSV as float NaN; make it an object
    # column so that storing query strings below is a valid assignment.
    df['cypher_generated2'] = df['cypher_generated2'].astype(object)

    # ==========================================================================
    # STEP 2: Process the rows that have no result yet
    total_rows = len(df)
    batch_start_idx = 0
    processed_since_last_log = 0
    processed_since_last_restart = 0
    actual_processed = 0

    print(f"\n{'='*80}")
    print(f"BẮT ĐẦU XỬ LÝ - Tổng số dòng: {total_rows}")
    print(f"Max retries: {max_retries}")
    print(f"Timeout: {timeout}s")
    print(f"Neo4j timeout: {neo4j_timeout}s")
    print(f"Restart mỗi: {restart_interval} dòng")
    print(f"{'='*80}\n")

    restart_ollama_server()
    start_time = time.time()

    try:
        for idx in range(total_rows):
            current_cypher = df.at[idx, 'cypher_generated2']

            if pd.notna(current_cypher) and str(current_cypher).strip() != '':
                continue

            # ==================================================================
            # PROCESS A ROW THAT HAS NO RESULT YET
            actual_processed += 1
            print(f"[Processing] Dòng {idx}...", end=" ", flush=True)

            try:
                cypher_initial = df.at[idx, 'cypher_generated']
                question = df.at[idx, 'question']
                schema = df.at[idx, 'schema']
                alias = df.at[idx, 'database_reference_alias']

                result = generate_cypher_with_correction(
                    cypher_initial=cypher_initial,
                    question=question,
                    schema=schema,
                    alias=alias,
                    max_retries=max_retries,
                    timeout=timeout,
                    neo4j_timeout=neo4j_timeout
                )

                cypher_result = result['final_cypher']

                df.at[idx, 'cypher_generated2'] = cypher_result

                if cypher_result == "error":
                    print("ERROR")
                elif cypher_result == "time_error":
                    print("TIME ERROR")
                else:
                    print("SUCCESS")

            except Exception as e:
                print(f"ERROR - {str(e)}")
                df.at[idx, 'cypher_generated2'] = "error"

            # Counted whether the row succeeded or failed
            processed_since_last_log += 1
            processed_since_last_restart += 1

            # ==================================================================
            # RESTART THE OLLAMA SERVER EVERY restart_interval ROWS
            # ==================================================================
            if processed_since_last_restart >= restart_interval:
                df.to_csv(checkpoint_path, index=False, encoding='utf-8-sig')
                print(f"[CHECKPOINT] Đã lưu trước khi restart")

                restart_ollama_server()

                processed_since_last_restart = 0
                processed_since_last_log = 0

            # ==================================================================
            # LOG STATISTICS AND SAVE A CHECKPOINT
            # ==================================================================
            elif processed_since_last_log >= log_interval:
                df.to_csv(checkpoint_path, index=False, encoding='utf-8-sig')

                elapsed_time = time.time() - start_time

                # Estimate remaining time from rows actually processed this run
                if actual_processed > 0:
                    avg_time_per_row = elapsed_time / actual_processed
                    remaining_to_process = _pending_rows(df['cypher_generated2']).sum()
                    estimated_time = avg_time_per_row * remaining_to_process
                else:
                    avg_time_per_row = 0
                    estimated_time = 0

                # Recount from the DataFrame for the current batch window
                batch_df = df.iloc[batch_start_idx:idx+1]
                batch_success, batch_error, batch_timeout = _count_outcomes(
                    batch_df['cypher_generated2']
                )

                print(f"\n{'='*80}")
                print(f"[LOG] Dòng {batch_start_idx}-{idx}")
                print(f"{'='*80}")
                print(f"Thành công:     {batch_success}")
                print(f"Error:          {batch_error}")
                print(f"Timeout Error:  {batch_timeout}")
                print(f"Tổng xử lý:     {batch_success + batch_error + batch_timeout}")
                print(f"Tiến độ:        {idx + 1}/{total_rows} ({(idx + 1)/total_rows*100:.2f}%)")
                print(f"Thời gian:      {elapsed_time/60:.2f} phút")
                print(f"Ước tính còn:   {estimated_time/60:.2f} phút")
                print(f"[CHECKPOINT] Đã lưu sau {processed_since_last_log} dòng")
                print(f"{'='*80}\n")

                batch_start_idx = idx + 1
                processed_since_last_log = 0
    finally:
        # ======================================================================
        # SAVE THE FINAL CHECKPOINT
        # Runs on normal completion and also when the loop is interrupted
        # (e.g. KeyboardInterrupt), so rows finished since the last
        # checkpoint are persisted before the exception propagates.
        # ======================================================================
        if processed_since_last_log > 0:
            df.to_csv(checkpoint_path, index=False, encoding='utf-8-sig')
            print(f"[CHECKPOINT] Đã lưu {processed_since_last_log} dòng cuối cùng")

    # ==========================================================================
    # DONE - FINAL LOG
    # ==========================================================================
    total_time = time.time() - start_time

    final_success, final_error, final_timeout = _count_outcomes(df['cypher_generated2'])

    def _percent(count):
        # Avoid dividing by zero when the frame has no rows
        return count / total_rows * 100 if total_rows else 0.0

    print(f"\n{'='*80}")
    print(f"HOÀN THÀNH")
    print(f"{'='*80}")
    print(f"Tổng số dòng:        {total_rows}")
    print(f"Thành công:          {final_success} ({_percent(final_success):.2f}%)")
    print(f"Error:               {final_error} ({_percent(final_error):.2f}%)")
    print(f"Timeout Error:       {final_timeout} ({_percent(final_timeout):.2f}%)")
    print(f"Tổng thời gian:      {total_time/60:.2f} phút")
    print(f"{'='*80}")

    return df

print("✓ run_batch loaded")

✓ run_batch loaded


In [49]:
# Run the full batch with the default settings; rows already filled in the
# checkpoint file are skipped.
result_df = run_batch(
    test_df=test_df,
    checkpoint_path=checkpoint_path
)

[CHECKPOINT] Tìm thấy file checkpoint: /content/drive/MyDrive/T2C_qwen214b_bs_loop/qwen214b_bs_loop_result.csv
[CHECKPOINT] Đã load 4833 dòng từ checkpoint
[CHECKPOINT] Đã xử lý: 0/4833 dòng

BẮT ĐẦU XỬ LÝ - Tổng số dòng: 4833
Max retries: 3
Timeout: 1200s
Neo4j timeout: 180s
Restart mỗi: 250 dòng

[RESTART] Đang restart Ollama server...
[RESTART] Ollama server đã khởi động lại!
[Processing] Dòng 0... SUCCESS
[Processing] Dòng 1... SUCCESS
[Processing] Dòng 2... SUCCESS
[Processing] Dòng 3... SUCCESS
[Processing] Dòng 4... SUCCESS
[Processing] Dòng 5... SUCCESS
[Processing] Dòng 6... SUCCESS
[Processing] Dòng 7... SUCCESS
[Processing] Dòng 8... SUCCESS
[Processing] Dòng 9... SUCCESS
[Processing] Dòng 10... SUCCESS
[Processing] Dòng 11... SUCCESS
[Processing] Dòng 12... SUCCESS
[Processing] Dòng 13... 

KeyboardInterrupt: 