In [None]:
%%capture
!pip install accelerate bitsandbytes huggingface_hub transformers func-timeout neo4j neo4j_graphrag

In [None]:
# Mount Google Drive at /content/drive; the dataset and checkpoint paths used
# later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Interactive Hugging Face Hub login (the recorded output shows the login widget).
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from func_timeout import func_timeout, FunctionTimedOut
from neo4j_graphrag.schema import get_structured_schema
from neo4j.exceptions import AuthError, Neo4jError
from neo4j import GraphDatabase
from datetime import datetime
import pandas as pd
import torch
import time
import os
import re

In [None]:
import logging
import warnings
# Only surface ERROR-level records from the 'neo4j' logger.
logging.getLogger('neo4j').setLevel(logging.ERROR)
# Hide FutureWarning messages for the rest of the session.
# NOTE(review): this also hides pandas/transformers deprecation notices.
warnings.filterwarnings('ignore', category=FutureWarning)

**Load model**

In [None]:
# Hugging Face repository id of the text-to-Cypher model to load.
model_id = "neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1"
# NOTE(review): not referenced by any other cell visible in this notebook section.
max_seq_length = 4096

# BitsAndBytes config for 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load model
# NOTE(review): the recorded output warns that `torch_dtype` is deprecated in
# favour of `dtype` on the installed transformers version.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    low_cpu_mem_usage=True,
    device_map="auto",
)

# Switch to evaluation mode; as the last expression, the model repr is displayed.
model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/857 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/864M [00:00<?, ?B/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 3584, padding_idx=0)
    (layers): ModuleList(
      (0-41): 42 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=3584, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=3584, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=3584, out_features=2048, bias=False)
            (lora_dropout):

**Load data**

In [None]:
# CSV that run_batch uses to persist (and resume) generated queries.
checkpoint_path = '/content/drive/MyDrive/T2C_gemma_ft_schema/gemma_ft_schema.csv'
# CSV holding the evaluation set.
test_path = '/content/drive/MyDrive/T2C_gemma_ft_schema/text2cypher_test.csv'
# "utf-8-sig" decodes the file whether or not it starts with a UTF-8 BOM.
test_df = pd.read_csv(test_path, encoding="utf-8-sig")

print(f"Loaded test shape: {test_df.shape}")

Loaded test shape: (4833, 6)


**Neo4j setup**

In [None]:
# Bolt URI shared by every database alias.
URI = "neo4j+s://demo.neo4jlabs.com:7687"

# List of aliases: the distinct non-null values of the test set's alias column
unique_aliases = test_df["database_reference_alias"].dropna().unique().tolist()
DATABASE_ALIASES = unique_aliases

# Per-alias caches for drivers, schemas and example values
DRIVERS_BY_ALIAS = {}
SCHEMAS_BY_ALIAS = {}
EXAMPLES_BY_ALIAS = {}

In [None]:
def extract_alias(alias: str):
    """Derive the (user, password) pair for a database alias.

    Every occurrence of ``"neo4jlabs_demo_db_"`` is removed from ``alias`` and
    the remaining text is returned twice, once as user name and once as password.
    """
    credential = "".join(alias.split("neo4jlabs_demo_db_"))
    return (credential, credential)

def get_driver(alias):
    """Return the cached driver for ``alias``, creating and caching it on first use.

    Credentials come from ``extract_alias``; the driver connects to ``URI``.
    """
    try:
        return DRIVERS_BY_ALIAS[alias]
    except KeyError:
        pass

    user, pwd = extract_alias(alias)
    new_driver = GraphDatabase.driver(URI, auth=(user, pwd))
    DRIVERS_BY_ALIAS[alias] = new_driver
    return new_driver

def reset_driver(alias):
    """Discard the cached driver for ``alias`` (if any) and return a fresh one.

    A failure while closing the old driver is printed and otherwise ignored;
    the cache entry is removed either way before ``get_driver`` rebuilds it.
    """
    if alias in DRIVERS_BY_ALIAS:
        stale_driver = DRIVERS_BY_ALIAS[alias]
        try:
            stale_driver.close()
        except Exception as close_error:
            print(f"Error closing driver: {close_error}")
        del DRIVERS_BY_ALIAS[alias]

    # Build (and cache) the replacement driver
    return get_driver(alias)

print("✓ Neo4j helper loaded")

✓ Neo4j helper loaded


In [None]:
def safe_ident(name):
    """Quote ``name`` as a Cypher identifier.

    The name is wrapped in backticks so labels, relationship types and property
    keys containing spaces or punctuation can be interpolated into a query.
    Backticks inside the name are doubled, which is how Cypher escapes them;
    without this a name containing a backtick would terminate the quoting early
    and produce an invalid (or injected) query.
    """
    escaped = str(name).replace("`", "``")
    return f"`{escaped}`"

def infer_type(value):
    """Map a Python value to the type name shown in the schema.

    Returns "BOOL", "INT" or "FLOAT" for the matching Python types and
    "STRING" for everything else, including ``None``.
    """
    # bool must be tested before int because bool is a subclass of int.
    type_names = ((bool, "BOOL"), (int, "INT"), (float, "FLOAT"))
    for python_type, schema_name in type_names:
        if isinstance(value, python_type):
            return schema_name
    return "STRING"

def is_valid_example(value, max_length=15):
    """Decide whether ``value`` is short and meaningful enough to show as an example.

    Rejected: ``None``; anything whose ``str()`` form is longer than
    ``max_length``; the string "null" in any letter case; all-hex strings
    longer than 30 characters; base64-looking strings longer than 40 characters.
    The last two rules can only trigger when ``max_length`` is raised above
    those thresholds.
    """
    # The length check on the string form runs first, for every value type.
    if value is None or len(str(value)) > max_length:
        return False

    # The remaining rules only concern strings.
    if not isinstance(value, str):
        return True

    if value.lower() == "null":
        return False

    # (pattern, length threshold): long hex strings, then long base64 strings
    opaque_patterns = (
        (r"[0-9a-fA-F]+", 30),
        (r"[0-9A-Za-z+/=]+", 40),
    )
    return not any(
        len(value) > threshold and re.fullmatch(pattern, value)
        for pattern, threshold in opaque_patterns
    )

def get_sample(tx, label, prop_name, limit=1):
    """Return up to ``limit`` non-null values of ``prop_name`` from nodes labelled ``label``.

    ``tx`` must provide ``run(query)``; each returned record is indexed by "value".
    """
    node_label = safe_ident(label)
    prop_key = safe_ident(prop_name)

    cypher = " ".join([
        f"MATCH (n:{node_label})",
        f"WHERE n.{prop_key} IS NOT NULL",
        f"RETURN n.{prop_key} AS value LIMIT {limit}",
    ])
    return [record["value"] for record in tx.run(cypher)]

def get_relationship_sample(tx, rel_type, prop_name, limit=1):
    """Return up to ``limit`` non-null values of ``prop_name`` from ``rel_type`` relationships.

    ``tx`` must provide ``run(query)``; each returned record is indexed by "value".
    """
    type_name = safe_ident(rel_type)
    prop_key = safe_ident(prop_name)

    cypher = " ".join([
        f"MATCH ()-[r:{type_name}]->()",
        f"WHERE r.{prop_key} IS NOT NULL",
        f"RETURN r.{prop_key} AS value LIMIT {limit}",
    ])
    return [record["value"] for record in tx.run(cypher)]

def find_mentioned_nodes(query_text, all_node_labels):
    """Return the set of labels from ``all_node_labels`` that occur in ``query_text``.

    Matching is case-insensitive and whole-token: a label only counts when it is
    not directly preceded or followed by a word character, so "Person" does not
    match inside "Personal".

    Lookarounds are used instead of ``\\b`` because ``\\b`` requires a word
    character on one side of the boundary; a label that starts or ends with
    punctuation (e.g. "C++") could therefore never match with ``\\b`` anchors.
    For labels made only of word characters the two forms are equivalent.
    """
    query_lower = query_text.lower()

    return {
        label
        for label in all_node_labels
        if re.search(r'(?<!\w)' + re.escape(label.lower()) + r'(?!\w)', query_lower)
    }

In [None]:
def _sample_with_auth_retry(alias, sampler, name, prop_name, kind):
    """Run ``sampler`` for one property, retrying once with a fresh driver on AuthError."""
    try:
        # A short-lived session is taken from the currently cached driver, so a
        # driver replaced by an earlier retry is never used again.
        with get_driver(alias).session() as sess:
            return sess.execute_read(sampler, name, prop_name, 1)
    except AuthError as e:
        print(f"AuthError sampling {kind} {name}.{prop_name} for {alias}: {e}")
        with reset_driver(alias).session() as sess:
            return sess.execute_read(sampler, name, prop_name, 1)


def _collect_examples(alias, props_by_name, sampler, kind):
    """Build ``{name: {property: example-or-None}}`` for every label or relationship type."""
    collected = {}
    for name, props in props_by_name.items():
        per_property = {}
        for p in props:
            prop_name = p.get("property")
            if not prop_name:
                continue
            vals = _sample_with_auth_retry(alias, sampler, name, prop_name, kind)
            example = vals[0] if vals else None
            per_property[prop_name] = example if is_valid_example(example) else None
        collected[name] = per_property
    return collected


def example_alias(alias):
    """Fetch the schema of ``alias`` and one example value per property, caching both.

    The structured schema is stored in ``SCHEMAS_BY_ALIAS[alias]``. For every
    node label and relationship type, one sample value per property is queried;
    values rejected by ``is_valid_example`` are stored as ``None``. The result
    ``{"nodes": {...}, "rels": {...}}`` is stored in ``EXAMPLES_BY_ALIAS[alias]``.

    On ``AuthError`` the driver is rebuilt with ``reset_driver`` and the failed
    call is retried once. Each sample uses its own session from the currently
    cached driver: previously one long-lived session kept being used after
    ``reset_driver`` had closed the driver it belonged to.

    Returns:
        tuple: ``(schema, examples)`` as stored in the two caches.
    """
    driver = get_driver(alias)

    # Fetch the schema
    try:
        schema = get_structured_schema(driver, is_enhanced=False)
    except AuthError as e:
        print(f"AuthError when getting schema for {alias}: {e}")
        driver = reset_driver(alias)
        schema = get_structured_schema(driver, is_enhanced=False)

    SCHEMAS_BY_ALIAS[alias] = schema

    examples = {
        # Examples for nodes
        "nodes": _collect_examples(alias, schema.get("node_props", {}), get_sample, "node"),
        # Examples for relationships
        "rels": _collect_examples(alias, schema.get("rel_props", {}), get_relationship_sample, "rel"),
    }

    EXAMPLES_BY_ALIAS[alias] = examples
    return SCHEMAS_BY_ALIAS[alias], EXAMPLES_BY_ALIAS[alias]

In [None]:
# Precompute the schema and example values for every alias up front.
# A failure for one alias is printed and skipped so the remaining aliases still run.
for alias in DATABASE_ALIASES:
    try:
        example_alias(alias)
    except Exception as e:
        print(f"Failed to precompute {alias}: {e}")

In [None]:
def _fetch_example_on_demand(alias, sampler, name, prop_name, kind):
    """Query one sample value, retrying once with a fresh driver on AuthError."""
    try:
        # Use a short-lived session from the currently cached driver so that a
        # driver replaced by an earlier retry is never reused.
        with get_driver(alias).session() as sess:
            vals = sess.execute_read(sampler, name, prop_name, 1)
    except AuthError as e:
        print(f"AuthError in convert_schema ({kind}) for {alias}: {e}")
        with reset_driver(alias).session() as sess:
            vals = sess.execute_read(sampler, name, prop_name, 1)
    return vals[0] if vals else None


def _describe_properties(alias, props, name, cached_examples, sampler, kind):
    """Turn schema property entries into ``{"property", "type", "example"}`` dicts."""
    described = []
    for p in props:
        prop_name = p.get("property")
        if not prop_name:
            continue

        # Prefer the precomputed example
        example = None
        if name in cached_examples and prop_name in cached_examples[name]:
            example = cached_examples[name][prop_name]

        # No usable precomputed example -> query on demand
        if example is None:
            example = _fetch_example_on_demand(alias, sampler, name, prop_name, kind)
            if not is_valid_example(example):
                example = None

        example_str = str(example) if example is not None else None
        dtype = infer_type(example) if example_str else "STRING"

        described.append({
            "property": prop_name,
            "type": dtype,
            "example": example_str
        })
    return described


def convert_schema_json_format(schema, precomputed_examples, alias, node_labels_to_include=None):
    """Convert a structured schema into the unified dict consumed by the markdown formatter.

    Args:
        schema: dict with "node_props", "rel_props" and "relationships" entries.
        precomputed_examples: ``{"nodes": {...}, "rels": {...}}`` example cache, or a
            falsy value when nothing was precomputed.
        alias: database alias used for on-demand sample queries.
        node_labels_to_include: labels to keep, in output order; ``None`` keeps
            every label in ``schema["node_props"]``.

    Returns:
        dict: ``{"nodes": {label: [property dicts]}, "relationships": [...]}``.
        A relationship is kept only when both its start and end labels are included.

    A property without a usable precomputed example is sampled from the database.
    Each such lookup opens its own session on the currently cached driver:
    previously a single session kept being used after an AuthError retry had
    closed the driver it belonged to.
    """
    node_schema = schema.get("node_props", {})
    if node_labels_to_include is None:
        node_labels_to_include = list(node_schema.keys())

    unified_schema = {
        "nodes": {},
        "relationships": []
    }

    ex_nodes = precomputed_examples.get("nodes", {}) if precomputed_examples else {}
    ex_rels = precomputed_examples.get("rels", {}) if precomputed_examples else {}

    # Convert nodes
    for label in node_labels_to_include:
        unified_schema["nodes"][label] = _describe_properties(
            alias, node_schema.get(label, []), label, ex_nodes, get_sample, "node"
        )

    # Convert relationships
    for rel_info in schema.get("relationships", []):
        rel_type = rel_info.get("type")
        start_label = rel_info.get("start")
        end_label = rel_info.get("end")

        if start_label in node_labels_to_include and end_label in node_labels_to_include:
            rel_props = _describe_properties(
                alias,
                schema.get("rel_props", {}).get(rel_type, []),
                rel_type,
                ex_rels,
                get_relationship_sample,
                "rel",
            )

            unified_schema["relationships"].append({
                "start": start_label,
                "type": rel_type,
                "end": end_label,
                "properties": rel_props
            })

    return unified_schema

In [None]:
def filter_schema_by_query(query_text, alias):
    """Return the unified schema for ``alias``, narrowed to labels mentioned in ``query_text``.

    The full schema is returned when it has three node labels or fewer, or when
    no label is mentioned in the text. Otherwise only the mentioned labels (and
    the relationships between them) are kept.

    The mentioned labels are passed on as a list in schema order rather than as
    the raw set: set iteration order for strings varies between interpreter
    runs, which made the order of nodes in the generated prompt non-reproducible.

    Raises:
        ValueError: if no schema has been cached for ``alias``.
    """
    if alias not in SCHEMAS_BY_ALIAS:
        raise ValueError(f"Schema not found for alias: {alias}")

    schema = SCHEMAS_BY_ALIAS[alias]
    precomputed_examples = EXAMPLES_BY_ALIAS.get(alias)

    all_node_labels = list(schema.get("node_props", {}).keys())

    # None means "keep the full schema"
    labels_to_include = None

    # Only schemas with more than 3 node labels are filtered
    if len(all_node_labels) > 3:
        mentioned_nodes = find_mentioned_nodes(query_text, all_node_labels)
        # No mentioned nodes -> fall back to the full schema
        if mentioned_nodes:
            labels_to_include = [
                label for label in all_node_labels if label in mentioned_nodes
            ]

    return convert_schema_json_format(schema, precomputed_examples, alias, labels_to_include)

In [None]:
def convert_schema_markdown_format(schema_dict):
    """Render a unified schema dict as markdown text.

    Produces a "### Nodes" section (one bullet per label, one sub-bullet per
    property) followed by a "### Relationships" section. A property's example is
    appended only when it is truthy. Returns ``None`` for a falsy ``schema_dict``.
    """
    if not schema_dict:
        return None

    def _property_bullet(entry):
        # One indented bullet: name, type and, when present, the example value.
        name, dtype = entry["property"], entry["type"]
        sample = entry.get("example")
        bullet = f"  - `{name}`: {dtype}"
        if sample:
            bullet += f" Example: \"{sample}\""
        return bullet

    # Nodes section
    lines = ["### Nodes"]
    for label, props in schema_dict.get("nodes", {}).items():
        lines.append(f"- **{label}**")
        lines.extend(_property_bullet(entry) for entry in props)

    # Relationships section
    lines.append("\n### Relationships")
    relationships = schema_dict.get("relationships", [])
    if relationships:
        for rel in relationships:
            start, rel_type, end = rel["start"], rel["type"], rel["end"]
            lines.append(f"- **({start})-[:{rel_type}]->({end})**")
            lines.extend(_property_bullet(entry) for entry in rel.get("properties", []))
    else:
        lines.append("- No relationships found")

    return "\n".join(lines).strip()

In [None]:
def get_full_schema_formatted(alias):
    """Return the complete, unfiltered schema of ``alias`` rendered as markdown.

    Raises:
        ValueError: if no schema has been cached for ``alias``.
    """
    if alias not in SCHEMAS_BY_ALIAS:
        raise ValueError(f"Schema not found for alias: {alias}")

    # Unified form of the whole schema (no label filter), then markdown
    unified_schema = convert_schema_json_format(
        SCHEMAS_BY_ALIAS[alias],
        EXAMPLES_BY_ALIAS.get(alias),
        alias,
        None,
    )
    return convert_schema_markdown_format(unified_schema)

In [None]:
def one_row_filter(query_text, alias):
    """Filter the schema of ``alias`` by ``query_text`` and return it as markdown.

    Any exception is printed and turned into ``None``; ``None`` is also returned
    when the filtered schema is empty.
    """
    try:
        # Step 1: filter the schema; Step 2: format it as markdown
        linked_schema = filter_schema_by_query(query_text, alias)
        return convert_schema_markdown_format(linked_schema) if linked_schema else None
    except Exception as e:
        print(f"Error processing {alias}: {e}")
        return None

**Prompt Cypher**

In [None]:
# Prompt skeleton for the user message; `{schema}` and `{question}` are filled in
# by prompt_cypher via str.format.
INSTRUCTION_TEMPLATE = (
    "Generate Cypher statement to query a graph database. \n"
    "Use ONLY the provided relationship types and properties in this sub-schema. If no relationships are listed, the query might not need any (e.g., simple node match).\n"
    "Schema: {schema} \n"
    "Question: {question}  \n"
    "Cypher output: "
)

print("✓ INSTRUCTION_TEMPLATE loaded")

✓ INSTRUCTION_TEMPLATE loaded


In [None]:
def prompt_cypher(question, schema) -> list[dict]:
    """Build the single-turn chat (one "user" message) for a question and schema."""
    user_content = INSTRUCTION_TEMPLATE.format(schema=schema, question=question)
    return [{"role": "user", "content": user_content}]

print("✓ prompt_cypher loaded")

✓ prompt_cypher loaded


**Generate cypher raw**

In [None]:
def generate_cypher_raw(question, schema, timeout):
    """Run the model once and return its decoded completion.

    Returns the raw generated text, "time_error" when generation exceeds
    ``timeout`` seconds, or "error" when any other exception is raised.
    """
    def _generate():
        # Render the chat prompt as text
        chat_messages = prompt_cypher(question=question, schema=schema)
        rendered_prompt = tokenizer.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            tokenize=False
        )

        # Tokenize and move the tensors to the model's device
        encoded = tokenizer(rendered_prompt, return_tensors="pt", padding=True).to(model.device)
        prompt_length = encoded.input_ids.shape[1]

        with torch.no_grad():
            generated = model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                top_p=0.9,
                temperature=0.2,
                max_new_tokens=512,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
            # Drop the prompt tokens, keep only the newly generated ones
            completion = generated[:, prompt_length:]
            decoded = tokenizer.batch_decode(completion, skip_special_tokens=True)

        return decoded[0]

    try:
        return func_timeout(timeout, _generate)
    except FunctionTimedOut:
        print(f"[TIMEOUT] Generation exceeded {timeout}s")
        return "time_error"
    except Exception as e:
        print(f"[ERROR] Generation failed: {e}")
        return "error"

print("✓ generate_cypher_raw loaded")

✓ generate_cypher_raw loaded


**Format cypher**

In [None]:
def postprocess_cypher(output_cypher: str) -> str:
    """Strip explanation text and Markdown code-fence markup from model output.

    Everything from "**Explanation:**" onwards is dropped, surrounding backticks
    and newlines are removed, and a leading "cypher" fence language tag is
    removed when it is followed by whitespace (or is the whole text).

    The tag is removed as a prefix. The previous ``lstrip("cypher\\n")`` treated
    its argument as a set of characters and therefore also ate the beginning of
    the statement itself, e.g. turning "call db.labels()" into "all db.labels()".
    """
    # Remove explanation
    partition_by = "**Explanation:**"
    output_cypher, _, _ = output_cypher.partition(partition_by)

    # Remove cypher code block markers
    output_cypher = output_cypher.strip("`\n")
    output_cypher = re.sub(r"\Acypher(?:\s+|\Z)", "", output_cypher)
    output_cypher = output_cypher.strip("`\n ")

    return output_cypher

print("✓ postprocess_cypher loaded")

✓ postprocess_cypher loaded


In [None]:
def extract_cypher(text):
    """Clean raw model output into a Cypher statement.

    The sentinels "time_error" and "error" pass through unchanged. Otherwise the
    text is cleaned with ``postprocess_cypher``; the result is "error" when it is
    empty or does not contain "match" (case-insensitive), or when cleaning raises.
    """
    if text in ["time_error", "error"]:
        return text

    try:
        # Apply the Neo4j post-processing, then trim whitespace
        cleaned = postprocess_cypher(text).strip()

        # Reject empty output and output without a MATCH clause
        if not cleaned or "match" not in cleaned.lower():
            return "error"

        return cleaned

    except Exception as e:
        print(f"[ERROR] Extract failed: {e}")
        return "error"

print("✓ extract_cypher loaded")

✓ extract_cypher loaded


**Generate cypher**

In [None]:
def generate_cypher(question, schema, timeout):
    """Generate a completion for ``question`` and reduce it to a Cypher statement.

    Returns the cleaned statement, or the "error" / "time_error" sentinel.
    """
    return extract_cypher(generate_cypher_raw(question, schema, timeout))

print("✓ generate_cypher loaded")

✓ generate_cypher loaded


In [None]:
def generate_cypher2(question, schema, alias=None, timeout=900):
    """Two-pass Cypher generation with schema linking.

    Without a usable ``alias`` (None, NaN or blank) a single generation is run
    against the supplied ``schema``. With an alias:

    1. generate a first query against the full cached schema of the alias;
    2. narrow the schema to the labels mentioned in that query;
    3. generate again against the narrowed schema.

    Each generation step has its own 300-second limit; ``timeout`` bounds the
    whole procedure.

    Returns:
        tuple: ``(cypher, schema_used)``. ``cypher`` may be the "error" or
        "time_error" sentinel; ``schema_used`` is ``None`` when the first pass
        failed or when the procedure as a whole failed or timed out. When only
        the second pass fails, the first query is returned.

    Failures of the whole procedure are printed before the sentinel is returned;
    previously they were discarded silently, which hid causes such as a missing
    cached schema.
    """
    def _inner_process():
        # Treat None, pandas NaN and blank strings as "no alias"
        is_alias_empty = (
            alias is None or
            (isinstance(alias, float) and pd.isna(alias)) or  # NaN coming from pandas
            (isinstance(alias, str) and alias.strip() == "")
        )

        if is_alias_empty:
            # Only step 1 is possible: generate against the supplied schema
            cypher_1 = generate_cypher(question, schema, timeout=300)
            if cypher_1 in ["error"]:
                return (cypher_1, None)
            return (cypher_1, schema)

        # Valid alias: run all three steps
        schema_full = get_full_schema_formatted(alias)

        # STEP 1: first generation
        cypher_1 = generate_cypher(question, schema_full, timeout=300)
        if cypher_1 in ["error"]:
            return (cypher_1, None)

        # STEP 2: schema linking based on the first query
        schema_linked = one_row_filter(cypher_1, alias)
        if schema_linked is None:
            schema_linked = schema_full

        # STEP 3: second generation with the linked schema
        cypher_2 = generate_cypher(question, schema_linked, timeout=300)
        if cypher_2 in ["error"]:
            return (cypher_1, schema_linked)

        # Return the second query and the linked schema
        return (cypher_2, schema_linked)

    try:
        return func_timeout(timeout, _inner_process)
    except FunctionTimedOut:
        print(f"[TIMEOUT] generate_cypher2 exceeded {timeout}s")
        return ("time_error", None)
    except Exception as e:
        print(f"[ERROR] generate_cypher2 failed: {e}")
        return ("error", None)

In [None]:
# Pick a test case
# NOTE(review): iloc[1] selects the second row of test_df, despite the name `first_row`.
first_row = test_df.iloc[1]
test_question = first_row['question']
test_schema = first_row['schema']
test_alias = first_row['database_reference_alias']
# Single-pass generation against the schema stored in the row.
cypher_1 = generate_cypher(test_question, test_schema, timeout=300)
# Schema linking for that query.
# NOTE(review): the recorded output shows this row's alias is NaN, so this returns None.
schema_linked = one_row_filter(cypher_1, test_alias)

print("="*80)
print("TEST QUESTION:")
print("="*80)
print(test_question)

# Test 1: Cypher (1)
print("\n" + "="*80)
print("TEST 1: Cypher (1)")
print("="*80)
print(cypher_1)

# Test 2: Cypher (2)
print("\n" + "="*80)
print("TEST 2: Cypher (2)")
print("="*80)
cypher_2 = generate_cypher2(test_question, test_schema, test_alias, timeout=900)[0]
print(cypher_2)

# Test 3: Expected cypher
print("\n" + "="*80)
print("TEST 3: Expected cypher")
print("="*80)
print(first_row['cypher'])

# Test 4: Schema linking extracted from Cypher (1)
print("\n" + "="*80)
print("TEST 4: Extract schema linking")
print("="*80)
print(schema_linked)

Error processing nan: Schema not found for alias: nan
TEST QUESTION:
What are the names of the technicians that have not been assigned to repair machines?

TEST 1: Cypher (1)
MATCH (t:Technician) WHERE NOT (t) --> (:RepairAssignment) RETURN t.Name

TEST 2: Cypher (2)
MATCH (t:Technician) WHERE NOT (t) --> (:RepairAssignment) RETURN t.Name

TEST 3: Expected cypher
MATCH (t:Technician) WHERE NOT EXISTS ((:RepairAssignment)-[:ASSIGNED_TO]->(t)) RETURN t.Name

TEST 4: Extract schema linking
None


**Chạy batch**

In [None]:
def run_batch(
    test_df,
    checkpoint_path,
    timeout=1200,
    log_interval=100,
    save_interval=50
):
    """
    Generate Cypher for every unprocessed row with ``generate_cypher2``, checkpointing to CSV.

    If ``checkpoint_path`` exists it is loaded and only rows whose
    ``cypher_generated`` cell is missing or blank are processed; otherwise a copy
    of ``test_df`` is created with an empty ``cypher_generated`` column and saved.

    Args:
        test_df: frame with ``question``, ``schema`` and
            ``database_reference_alias`` columns.
        checkpoint_path: CSV file used to persist and resume results.
        timeout: overall per-row limit in seconds passed to ``generate_cypher2``.
        log_interval: print statistics every this many row positions.
        save_interval: write the checkpoint after this many processed rows.

    Returns:
        The result frame; ``cypher_generated`` holds the query or the "error" /
        "time_error" sentinel.

    Fixes relative to the earlier version:
    - the result column is forced to object dtype after loading, because a
      checkpoint without any results is read back as an all-NaN float column and
      string assignment into it is deprecated in pandas;
    - a fresh frame gets a 0..n-1 index, matching what a reloaded checkpoint has
      and what the positional loop below relies on;
    - the exception behind a failed row is printed instead of discarded;
    - the final percentages no longer divide by zero for an empty frame.
    """
    # ==========================================================================
    # STEP 1: Check for and load the checkpoint
    if os.path.exists(checkpoint_path):
        print(f"[CHECKPOINT] Tìm thấy file checkpoint: {checkpoint_path}")
        df = pd.read_csv(checkpoint_path, encoding="utf-8-sig")
        # Blank cells come back as NaN; keep the column able to hold strings.
        df['cypher_generated'] = df['cypher_generated'].astype(object)
        print(f"[CHECKPOINT] Đã load {len(df)} dòng từ checkpoint")

        processed_count = df['cypher_generated'].notna().sum()
        print(f"[CHECKPOINT] Đã xử lý: {processed_count}/{len(df)} dòng")

    else:
        print(f"[CHECKPOINT] Không tìm thấy checkpoint, tạo mới từ test_df")
        # The loop addresses rows by 0..n-1 labels, so normalise the index.
        df = test_df.copy().reset_index(drop=True)
        df['cypher_generated'] = ''

        df.to_csv(checkpoint_path, index=False, encoding='utf-8')
        print(f"[CHECKPOINT] Đã tạo file checkpoint: {checkpoint_path}")

    # ==========================================================================
    # STEP 2: Process the rows that have no result yet
    total_rows = len(df)
    batch_start_idx = 0
    processed_since_last_save = 0
    actual_processed = 0

    print(f"\n{'='*80}")
    print(f"BẮT ĐẦU XỬ LÝ - Tổng số dòng: {total_rows}")
    print(f"Model: generate_cypher2")
    print(f"Timeout per query: {timeout}s")
    print(f"{'='*80}\n")

    start_time = time.time()

    for idx in range(total_rows):
        current_cypher = df.at[idx, 'cypher_generated']

        # Skip rows that already hold a result (including error sentinels)
        if pd.notna(current_cypher) and str(current_cypher).strip() != '':
            continue

        # ======================================================================
        # PROCESS A ROW WITHOUT A RESULT
        actual_processed += 1
        print(f"[Processing] Dòng {idx}...", end=" ", flush=True)

        try:
            question = df.at[idx, 'question']
            schema = df.at[idx, 'schema']
            alias = df.at[idx, 'database_reference_alias']

            result = generate_cypher2(question, schema, alias, timeout=timeout)[0]

            df.at[idx, 'cypher_generated'] = result

            if result == "error":
                print("ERROR")
            elif result == "time_error":
                print("TIME ERROR")
            else:
                print("SUCCESS")

            processed_since_last_save += 1

        except Exception as e:
            # Report the cause so failures can be diagnosed from the log
            print(f"ERROR ({type(e).__name__}: {e})")
            df.at[idx, 'cypher_generated'] = "error"
            processed_since_last_save += 1

        # ======================================================================
        # SAVE CHECKPOINT
        # ======================================================================
        if processed_since_last_save >= save_interval:
            df.to_csv(checkpoint_path, index=False, encoding='utf-8')
            print(f"[CHECKPOINT] Đã lưu sau {processed_since_last_save} dòng")
            processed_since_last_save = 0

        # ======================================================================
        # LOG STATISTICS
        # ======================================================================
        if (idx + 1) % log_interval == 0:
            elapsed_time = time.time() - start_time

            # Estimate the remaining time from the average so far
            if actual_processed > 0:
                avg_time_per_row = elapsed_time / actual_processed
                remaining_to_process = ((df['cypher_generated'].isna()) |
                                       (df['cypher_generated'] == '')).sum()
                estimated_time = avg_time_per_row * remaining_to_process
            else:
                avg_time_per_row = 0
                estimated_time = 0

            # Recount from the DataFrame for the current window of rows
            batch_df = df.iloc[batch_start_idx:idx+1]
            batch_success = ((batch_df['cypher_generated'].notna()) &
                           (batch_df['cypher_generated'] != 'error') &
                           (batch_df['cypher_generated'] != 'time_error') &
                           (batch_df['cypher_generated'] != '')).sum()
            batch_error = (batch_df['cypher_generated'] == 'error').sum()
            batch_timeout = (batch_df['cypher_generated'] == 'time_error').sum()

            print(f"\n{'='*80}")
            print(f"[LOG] Dòng {batch_start_idx}-{idx}")
            print(f"{'='*80}")
            print(f"Thành công:     {batch_success}")
            print(f"Error:          {batch_error}")
            print(f"Timeout Error:  {batch_timeout}")
            print(f"Tổng xử lý:     {batch_success + batch_error + batch_timeout}")
            print(f"Tiến độ:        {idx + 1}/{total_rows} ({(idx + 1)/total_rows*100:.2f}%)")
            print(f"Thời gian:      {elapsed_time/60:.2f} phút")
            print(f"Ước tính còn:   {estimated_time/60:.2f} phút")
            print(f"{'='*80}\n")

            batch_start_idx = idx + 1

    # ==========================================================================
    # FINAL CHECKPOINT SAVE
    # ==========================================================================
    if processed_since_last_save > 0:
        df.to_csv(checkpoint_path, index=False, encoding='utf-8')
        print(f"[CHECKPOINT] Đã lưu {processed_since_last_save} dòng cuối cùng")

    # ==========================================================================
    # DONE - FINAL LOG
    # ==========================================================================
    total_time = time.time() - start_time

    final_success = (df['cypher_generated'].notna() &
                     (df['cypher_generated'] != 'error') &
                     (df['cypher_generated'] != 'time_error') &
                     (df['cypher_generated'] != '')).sum()
    final_error = (df['cypher_generated'] == 'error').sum()
    final_timeout = (df['cypher_generated'] == 'time_error').sum()

    def _pct(count):
        # Share of all rows in percent; 0.0 for an empty frame
        return count / total_rows * 100 if total_rows else 0.0

    print(f"\n{'='*80}")
    print(f"HOÀN THÀNH")
    print(f"{'='*80}")
    print(f"Tổng số dòng:        {total_rows}")
    print(f"Thành công:          {final_success} ({_pct(final_success):.2f}%)")
    print(f"Error:               {final_error} ({_pct(final_error):.2f}%)")
    print(f"Timeout Error:       {final_timeout} ({_pct(final_timeout):.2f}%)")
    print(f"Tổng thời gian:      {total_time/60:.2f} phút")
    print(f"{'='*80}")

    return df

In [None]:
# Run (or resume) generation over the whole test set; results are persisted to
# checkpoint_path and also returned as a DataFrame.
df_results = run_batch(
    test_df=test_df,
    checkpoint_path=checkpoint_path,
    timeout=1200,
    log_interval=100,
    save_interval=50
)

In [None]:
def run_batch_retry_errors(
    checkpoint_path,
    timeout=1200,
    log_interval=100,
    save_interval=50
):
    """
    Re-run the rows of a checkpoint CSV that are marked 'error' or 'time_error'.

    The checkpoint is loaded, every row whose ``cypher_generated`` value is
    ``'error'`` or ``'time_error'`` is passed to ``generate_cypher2`` again
    (defined elsewhere in this notebook), and the first element of its return
    value is written back into ``cypher_generated``. The CSV at
    ``checkpoint_path`` is overwritten periodically and once more when the loop
    ends -- including when it ends because of an exception or a
    ``KeyboardInterrupt`` -- so finished rows are not lost.

    Args:
        checkpoint_path: Path of the checkpoint CSV to read and overwrite. It
            must contain the columns ``question``, ``schema``,
            ``database_reference_alias`` and ``cypher_generated``.
        timeout: Forwarded to ``generate_cypher2`` as its ``timeout`` argument.
        log_interval: Print a statistics block after this many retried rows.
            A value of 0 disables the periodic statistics.
        save_interval: Write the checkpoint after this many retried rows.

    Returns:
        The updated DataFrame, or ``None`` when ``checkpoint_path`` does not
        exist. When there is nothing to retry the loaded DataFrame is returned
        unchanged.
    """
    retry_markers = ['error', 'time_error']

    # ==========================================================================
    # STEP 1: load the checkpoint
    if not os.path.exists(checkpoint_path):
        print(f"[ERROR] Không tìm thấy file checkpoint: {checkpoint_path}")
        return None

    print(f"[CHECKPOINT] Đang load file: {checkpoint_path}")
    df = pd.read_csv(checkpoint_path, encoding="utf-8-sig")
    print(f"[CHECKPOINT] Đã load {len(df)} dòng")

    # Count the rows that need another attempt
    error_count = (df['cypher_generated'] == 'error').sum()
    timeout_count = (df['cypher_generated'] == 'time_error').sum()
    total_to_retry = error_count + timeout_count

    print("[INFO] Số dòng cần chạy lại:")
    print(f"  - Error:        {error_count}")
    print(f"  - Time Error:   {timeout_count}")
    print(f"  - Tổng:         {total_to_retry}")

    if total_to_retry == 0:
        print("[INFO] Không có dòng nào cần chạy lại!")
        return df

    # ==========================================================================
    # STEP 2: retry the failed rows
    total_rows = len(df)
    # Resolve the rows to retry once, up front. A row is only ever modified
    # when it is itself being retried, so this matches re-checking each row.
    retry_indices = df.index[df['cypher_generated'].isin(retry_markers)].tolist()

    batch_start_idx = 0
    batch_success = 0
    batch_error = 0
    batch_timeout = 0
    processed_since_last_save = 0
    actual_processed = 0

    print(f"\n{'='*80}")
    print("BẮT ĐẦU XỬ LÝ LẠI")
    print("Model: generate_cypher2")
    print(f"Timeout per query: {timeout}s")
    print(f"{'='*80}\n")

    start_time = time.time()

    try:
        for idx in retry_indices:
            print(f"[Processing] Dòng {idx} (Retry)...", end=" ", flush=True)

            try:
                question = df.at[idx, 'question']
                schema = df.at[idx, 'schema']
                alias = df.at[idx, 'database_reference_alias']

                result = generate_cypher2(question, schema, alias, timeout=timeout)[0]
                df.at[idx, 'cypher_generated'] = result
            except Exception as exc:
                # Keep the batch going, but report why this row failed instead
                # of discarding the exception.
                result = "error"
                df.at[idx, 'cypher_generated'] = result
                print(f"ERROR ({type(exc).__name__}: {exc})")
            else:
                if result == "error":
                    print("ERROR")
                elif result == "time_error":
                    print("TIME ERROR")
                else:
                    print("SUCCESS")

            # Tally the outcome of this retry for the periodic statistics
            if result == "error":
                batch_error += 1
            elif result == "time_error":
                batch_timeout += 1
            else:
                batch_success += 1

            actual_processed += 1
            processed_since_last_save += 1

            # ==================================================================
            # SAVE CHECKPOINT
            # ==================================================================
            if processed_since_last_save >= save_interval:
                df.to_csv(checkpoint_path, index=False, encoding='utf-8')
                print(f"[CHECKPOINT] Đã lưu sau {processed_since_last_save} dòng")
                processed_since_last_save = 0

            # ==================================================================
            # LOG STATISTICS
            # ==================================================================
            # Triggered by the number of retried rows, not by the row index:
            # skipped rows never reach this point, so an index-based check
            # would fire only by coincidence.
            if log_interval and actual_processed % log_interval == 0:
                elapsed_time = time.time() - start_time

                # Only rows not yet attempted count towards the estimate; rows
                # that failed again in this run are not retried a second time.
                remaining_to_process = total_to_retry - actual_processed
                estimated_time = elapsed_time / actual_processed * remaining_to_process

                print(f"\n{'='*80}")
                print(f"[LOG] Dòng {batch_start_idx}-{idx}")
                print(f"{'='*80}")
                print(f"Thành công:     {batch_success}")
                print(f"Error:          {batch_error}")
                print(f"Timeout Error:  {batch_timeout}")
                print(f"Tổng xử lý:     {batch_success + batch_error + batch_timeout}")
                print(f"Tiến độ:        {actual_processed}/{total_to_retry} ({actual_processed/total_to_retry*100:.2f}%)")
                print(f"Thời gian:      {elapsed_time/60:.2f} phút")
                print(f"Ước tính còn:   {estimated_time/60:.2f} phút")
                print(f"{'='*80}\n")

                batch_start_idx = idx + 1
                batch_success = 0
                batch_error = 0
                batch_timeout = 0
    finally:
        # ======================================================================
        # FINAL CHECKPOINT
        # ======================================================================
        # Runs on normal completion and on interruption (e.g. KeyboardInterrupt)
        # so rows finished since the last periodic save are not lost.
        if processed_since_last_save > 0:
            df.to_csv(checkpoint_path, index=False, encoding='utf-8')
            print(f"[CHECKPOINT] Đã lưu {processed_since_last_save} dòng cuối cùng")

    # ==========================================================================
    # DONE - FINAL SUMMARY
    # ==========================================================================
    total_time = time.time() - start_time

    final_success = (df['cypher_generated'].notna() &
                     (df['cypher_generated'] != 'error') &
                     (df['cypher_generated'] != 'time_error') &
                     (df['cypher_generated'] != '')).sum()
    final_error = (df['cypher_generated'] == 'error').sum()
    final_timeout = (df['cypher_generated'] == 'time_error').sum()

    print(f"\n{'='*80}")
    print("HOÀN THÀNH")
    print(f"{'='*80}")
    print(f"Tổng số dòng:        {total_rows}")
    print(f"Thành công:          {final_success} ({final_success/total_rows*100:.2f}%)")
    print(f"Error:               {final_error} ({final_error/total_rows*100:.2f}%)")
    print(f"Timeout Error:       {final_timeout} ({final_timeout/total_rows*100:.2f}%)")
    print(f"Tổng thời gian:      {total_time/60:.2f} phút")
    print(f"{'='*80}")

    return df

In [None]:
# Re-run the rows marked 'error' or 'time_error' in the checkpoint file.
# run_batch_retry_errors reloads the CSV at checkpoint_path, retries those rows
# with generate_cypher2, overwrites the same file, and returns the updated
# DataFrame (None if the checkpoint file does not exist).
result_df = run_batch_retry_errors(
    checkpoint_path=checkpoint_path,
    timeout=1200,
    log_interval=100,
    save_interval=50
)

[CHECKPOINT] Đang load file: /content/drive/MyDrive/T2C_gemma_ft_schema/gemma_ft_schema.csv
[CHECKPOINT] Đã load 4833 dòng
[INFO] Số dòng cần chạy lại:
  - Error:        2363
  - Time Error:   0
  - Tổng:         2363

BẮT ĐẦU XỬ LÝ LẠI
Model: generate_cypher2
Timeout per query: 1200s

[Processing] Dòng 1 (Retry)... SUCCESS
[Processing] Dòng 2 (Retry)... SUCCESS
[Processing] Dòng 3 (Retry)... SUCCESS
[Processing] Dòng 5 (Retry)... SUCCESS
[Processing] Dòng 6 (Retry)... SUCCESS
[Processing] Dòng 7 (Retry)... SUCCESS
[Processing] Dòng 8 (Retry)... SUCCESS
[Processing] Dòng 9 (Retry)... SUCCESS
[Processing] Dòng 10 (Retry)... SUCCESS
[Processing] Dòng 14 (Retry)... SUCCESS
[Processing] Dòng 15 (Retry)... SUCCESS
[Processing] Dòng 17 (Retry)... SUCCESS
[Processing] Dòng 18 (Retry)... SUCCESS
[Processing] Dòng 21 (Retry)... SUCCESS
[Processing] Dòng 22 (Retry)... SUCCESS
[Processing] Dòng 23 (Retry)... SUCCESS
[Processing] Dòng 26 (Retry)... SUCCESS
[Processing] Dòng 29 (Retry)... SUCCESS
[

KeyboardInterrupt: 