In [2]:
%%capture
!pip install accelerate bitsandbytes huggingface_hub transformers func-timeout

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
from func_timeout import func_timeout, FunctionTimedOut
from datetime import datetime
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
import pandas as pd
import torch
import time
import os
import re

**Load model**

In [5]:
model_id = "neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1"
max_seq_length = 4096

# BitsAndBytes config cho 4bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    low_cpu_mem_usage=True,
    device_map="auto",
)

model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/723 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/857 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/864M [00:00<?, ?B/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 3584, padding_idx=0)
    (layers): ModuleList(
      (0-41): 42 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=3584, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=3584, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=3584, out_features=2048, bias=False)
            (lora_dropout):

**Load data**

In [6]:
checkpoint_path = '/content/drive/MyDrive/T2C_finetune_gemma/finetune_gemma.csv'
test_path = '/content/drive/MyDrive/T2C_finetune_gemma/text2cypher_test.csv'
test_df = pd.read_csv(test_path, encoding="utf-8-sig")

print(f"Loaded test shape: {test_df.shape}")

Loaded test shape: (4833, 6)


**Prompt template**

In [7]:
INSTRUCTION_TEMPLATE = (
    "Generate Cypher statement to query a graph database. "
    "Use only the provided relationship types and properties in the schema. \n"
    "Schema: {schema} \n Question: {question}  \n Cypher output: "
)

def prepare_chat_prompt(question, schema) -> list[dict]:
    chat = [
        {
            "role": "user",
            "content": INSTRUCTION_TEMPLATE.format(
                schema=schema,
                question=question
            ),
        }
    ]
    return chat

print("✓ Instruction template loaded")

✓ Instruction template loaded


**Generate cypher**

In [8]:
def generate_cypher_raw(question, schema, timeout=30):
    """Generate raw Cypher output từ model"""

    def _generate():
        # Chuẩn bị chat prompt
        new_message = prepare_chat_prompt(question=question, schema=schema)

        # Apply chat template (quan trọng cho Gemma!)
        prompt = tokenizer.apply_chat_template(
            new_message,
            add_generation_prompt=True,
            tokenize=False
        )

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        inputs = inputs.to(model.device)

        # Generation parameters
        model_generate_parameters = {
            "top_p": 0.9,
            "temperature": 0.2,
            "max_new_tokens": 512,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        # Generate - SỬA LỖI Ở ĐÂY
        with torch.no_grad():
            tokens = model.generate(
                input_ids=inputs.input_ids,           # Truyền input_ids
                attention_mask=inputs.attention_mask,  # Truyền attention_mask
                **model_generate_parameters           # Unpack parameters
            )
            # Bỏ phần input prompt
            tokens = tokens[:, inputs.input_ids.shape[1]:]
            # Decode
            raw_output = tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]

        return raw_output

    try:
        return func_timeout(timeout, _generate)
    except FunctionTimedOut:
        print(f"[TIMEOUT] Generation exceeded {timeout}s")
        return "time_error"
    except Exception as e:
        print(f"[ERROR] Generation failed: {e}")
        return "error"

print("✓ generate_cypher_raw loaded")

✓ generate_cypher_raw loaded


In [9]:
def _postprocess_output_cypher(output_cypher: str) -> str:
    """
    Postprocess theo Neo4j official method:
    - Remove explanation (e.g., **Explanation:**)
    - Remove cypher indicator (e.g., ```cypher)
    """
    # Remove explanation
    partition_by = "**Explanation:**"
    output_cypher, _, _ = output_cypher.partition(partition_by)

    # Remove cypher code block markers
    output_cypher = output_cypher.strip("`\n")
    output_cypher = output_cypher.lstrip("cypher\n")
    output_cypher = output_cypher.strip("`\n ")

    return output_cypher

def extract_cypher_from_text(text):
    """Extract Cypher query from model output"""
    if text in ["time_error", "error"]:
        return text

    try:
        # Áp dụng postprocess của Neo4j
        cypher = _postprocess_output_cypher(text)

        # Kiểm tra có MATCH không
        if not cypher.strip():
            return "error"

        cypher_lower = cypher.lower()
        if "match" not in cypher_lower:
            return "error"

        return cypher.strip()

    except Exception as e:
        print(f"[ERROR] Extract failed: {e}")
        return "error"

print("✓ extract_cypher_from_text loaded")

✓ extract_cypher_from_text loaded


In [10]:
def generate_cypher(question, schema, timeout=30):
    """Wrapper function to generate and extract Cypher query"""
    # Generate raw output
    raw_output = generate_cypher_raw(question, schema, timeout)

    # Extract Cypher
    cypher = extract_cypher_from_text(raw_output)

    return cypher

print("✓ generate_cypher loaded")

✓ generate_cypher loaded


In [11]:
# Lấy test case
first_row = test_df.iloc[2]
test_question = first_row['question']
test_schema = first_row['schema']

print("="*80)
print("TEST QUESTION:")
print("="*80)
print(test_question)

# Test 1: Raw output
print("\n" + "="*80)
print("TEST 1: generate_cypher_raw()")
print("="*80)
raw_output = generate_cypher_raw(test_question, test_schema, timeout=30)
print("RAW OUTPUT:")
print(raw_output)

# Test 2: Extracted Cypher
print("\n" + "="*80)
print("TEST 2: generate_cypher() - EXTRACTED")
print("="*80)
final_cypher = generate_cypher(test_question, test_schema, timeout=30)
print(final_cypher)

# Test 3: Expected
print("\n" + "="*80)
print("TEST 3: EXPECTED CYPHER")
print("="*80)
print(first_row['cypher'])

TEST QUESTION:
Fetch unique values of label and description from Topic where label does not start with P!

TEST 1: generate_cypher_raw()
RAW OUTPUT:
MATCH (n:Topic) WHERE NOT n.label STARTS WITH 'P' RETURN DISTINCT n.label AS label, n.description AS description


TEST 2: generate_cypher() - EXTRACTED
MATCH (n:Topic) WHERE NOT n.label STARTS WITH 'P' RETURN DISTINCT n.label AS label, n.description AS description

TEST 3: EXPECTED CYPHER
MATCH (n:Topic) WHERE NOT n.label STARTS WITH 'P' RETURN DISTINCT n.label AS label, n.description AS description


**Chạy batch**

In [12]:
def run_with_checkpoint_gemma(
    test_df,
    checkpoint_path,
    timeout=60,
    log_interval=100
):
    """
    Chạy generation với checkpoint support cho Neo4j Gemma model
    """
    # ==========================================================================
    # BƯỚC 1: Kiểm tra và load checkpoint
    if os.path.exists(checkpoint_path):
        print(f"[CHECKPOINT] Tìm thấy file checkpoint: {checkpoint_path}")
        df = pd.read_csv(checkpoint_path, encoding="utf-8-sig")
        print(f"[CHECKPOINT] Đã load {len(df)} dòng từ checkpoint")

        # Đếm số dòng đã xử lý
        processed_count = df['cypher_generated'].notna().sum()
        print(f"[CHECKPOINT] Đã xử lý: {processed_count}/{len(df)} dòng")

    else:
        print(f"[CHECKPOINT] Không tìm thấy checkpoint, tạo mới từ test_df")
        df = test_df.copy()
        df['cypher_generated'] = ''  # Thêm cột cypher_generated trống

        # Lưu file checkpoint ban đầu
        df.to_csv(checkpoint_path, index=False, encoding='utf-8')
        print(f"[CHECKPOINT] Đã tạo file checkpoint: {checkpoint_path}")

    # ==========================================================================
    # BƯỚC 2: Xử lý các dòng chưa có kết quả
    total_rows = len(df)

    # Biến đếm cho log
    success_count = 0
    error_count = 0
    timeout_error_count = 0
    batch_start_idx = 0

    # Biến đếm để lưu checkpoint
    processed_since_last_save = 0

    print(f"\n{'='*80}")
    print(f"BẮT ĐẦU XỬ LÝ - Tổng số dòng: {total_rows}")
    print(f"Model: neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1")
    print(f"{'='*80}\n")

    start_time = time.time()

    for idx in range(total_rows):
        # Kiểm tra nếu dòng này đã có kết quả
        current_cypher = df.at[idx, 'cypher_generated']

        # Nếu đã có dữ liệu (không phải NaN và không phải chuỗi rỗng)
        if pd.notna(current_cypher) and str(current_cypher).strip() != '':
            continue  # SKIP dòng này

        # ======================================================================
        # XỬ LÝ DÒNG CHƯA CÓ KẾT QUẢ
        print(f"[Processing] Dòng {idx}...", end=" ", flush=True)

        try:
            question = df.at[idx, 'question']
            schema = df.at[idx, 'schema']

            # Gọi hàm generate_cypher
            cypher_result = generate_cypher(question, schema, timeout=timeout)

            # Lưu kết quả
            df.at[idx, 'cypher_generated'] = cypher_result

            # Đếm theo loại kết quả và in log
            if cypher_result == "error":
                error_count += 1
                print("ERROR")
            elif cypher_result == "time_error":
                timeout_error_count += 1
                print("TIMEOUT")
            else:
                success_count += 1
                print("SUCCESS")

            # Tăng biến đếm để lưu checkpoint
            processed_since_last_save += 1

        except Exception as e:
            # Xử lý lỗi không mong đợi
            print(f"ERROR - {str(e)}")
            df.at[idx, 'cypher_generated'] = "error"
            error_count += 1
            processed_since_last_save += 1

        # ======================================================================
        # LƯU CHECKPOINT SAU MỖI 50 DÒNG
        # ======================================================================
        if processed_since_last_save >= 50:
            df.to_csv(checkpoint_path, index=False, encoding='utf-8')
            print(f"[CHECKPOINT] Đã lưu sau {processed_since_last_save} dòng")
            processed_since_last_save = 0

        # ======================================================================
        # LOG THỐNG KÊ MỖI 100 DÒNG
        # ======================================================================
        if (idx + 1) % log_interval == 0:
            elapsed_time = time.time() - start_time
            avg_time_per_row = elapsed_time / (idx + 1)
            remaining_rows = total_rows - (idx + 1)
            estimated_time = avg_time_per_row * remaining_rows

            print(f"\n{'='*80}")
            print(f"[LOG] Dòng {batch_start_idx}-{idx}")
            print(f"{'='*80}")
            print(f"Thành công:     {success_count}")
            print(f"Error:          {error_count}")
            print(f"Timeout Error:  {timeout_error_count}")
            print(f"Tổng xử lý:     {success_count + error_count + timeout_error_count}")
            print(f"Tiến độ:        {idx + 1}/{total_rows} ({(idx + 1)/total_rows*100:.2f}%)")
            print(f"Thời gian:      {elapsed_time/60:.2f} phút")
            print(f"Ước tính còn:   {estimated_time/60:.2f} phút")
            print(f"{'='*80}\n")

            # Reset bộ đếm cho batch tiếp theo
            success_count = 0
            error_count = 0
            timeout_error_count = 0
            batch_start_idx = idx + 1

    # ==========================================================================
    # LƯU CHECKPOINT CUỐI CÙNG (nếu còn dòng chưa lưu)
    # ==========================================================================
    if processed_since_last_save > 0:
        df.to_csv(checkpoint_path, index=False, encoding='utf-8')
        print(f"[CHECKPOINT] Đã lưu {processed_since_last_save} dòng cuối cùng")

    # ==========================================================================
    # KẾT THÚC - LOG CUỐI CÙNG
    # ==========================================================================
    total_time = time.time() - start_time

    # Đếm tổng kết quả
    final_success = (df['cypher_generated'].notna() &
                     (df['cypher_generated'] != 'error') &
                     (df['cypher_generated'] != 'time_error') &
                     (df['cypher_generated'] != '')).sum()
    final_error = (df['cypher_generated'] == 'error').sum()
    final_timeout = (df['cypher_generated'] == 'time_error').sum()

    print(f"\n{'='*80}")
    print(f"HOÀN THÀNH")
    print(f"{'='*80}")
    print(f"Tổng số dòng:        {total_rows}")
    print(f"Thành công:          {final_success} ({final_success/total_rows*100:.2f}%)")
    print(f"Error:               {final_error} ({final_error/total_rows*100:.2f}%)")
    print(f"Timeout Error:       {final_timeout} ({final_timeout/total_rows*100:.2f}%)")
    print(f"Tổng thời gian:      {total_time/60:.2f} phút")
    print(f"Thời gian trung bình: {total_time/total_rows:.2f} giây/dòng")
    print(f"Kết quả đã lưu tại:  {checkpoint_path}")
    print(f"{'='*80}\n")

    return df

print("✓ run_with_checkpoint_gemma loaded")

✓ run_with_checkpoint_gemma loaded


In [13]:
result_df = run_with_checkpoint_gemma(
    test_df=test_df,
    checkpoint_path=checkpoint_path,
    timeout=60,
    log_interval=100
)

[CHECKPOINT] Tìm thấy file checkpoint: /content/drive/MyDrive/T2C_finetune_gemma/finetune_gemma.csv
[CHECKPOINT] Đã load 4833 dòng từ checkpoint
[CHECKPOINT] Đã xử lý: 4600/4833 dòng

BẮT ĐẦU XỬ LÝ - Tổng số dòng: 4833
Model: neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1

[Processing] Dòng 4600... SUCCESS
[Processing] Dòng 4601... SUCCESS
[Processing] Dòng 4602... SUCCESS
[Processing] Dòng 4603... SUCCESS
[Processing] Dòng 4604... SUCCESS
[Processing] Dòng 4605... SUCCESS
[Processing] Dòng 4606... SUCCESS
[Processing] Dòng 4607... SUCCESS
[Processing] Dòng 4608... SUCCESS
[Processing] Dòng 4609... SUCCESS
[Processing] Dòng 4610... SUCCESS
[Processing] Dòng 4611... SUCCESS
[Processing] Dòng 4612... SUCCESS
[Processing] Dòng 4613... SUCCESS
[Processing] Dòng 4614... SUCCESS
[Processing] Dòng 4615... [TIMEOUT] Generation exceeded 60s
TIMEOUT
[Processing] Dòng 4616... SUCCESS
[Processing] Dòng 4617... SUCCESS
[Processing] Dòng 4618... SUCCESS
[Processing] Dòng 4619... SUCCESS
[Processin