In [None]:
!pip install camel-tools arabic-reshaper nltk transformers faiss-gpu chromadb evaluate

In [None]:
import re
import os
import ast
import sqlite3
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
import json
import arabic_reshaper
import nltk
import torch
from wordcloud import WordCloud
from bidi.algorithm import get_display
from camel_tools.utils.charsets import AR_LETTERS_CHARSET
from camel_tools.utils.dediac import dediac_ar
from camel_tools.tokenizers.word import simple_word_tokenize
from transformers import pipeline
from nltk.corpus import stopwords
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments
from sklearn.model_selection import train_test_split
import evaluate
import faiss
import chromadb

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_SILENT"] = "true"

In [None]:
# Use a real torch.device so that `.to(device)` works on both GPU and CPU.
# The integer -1 is only a transformers.pipeline convention for "CPU"; torch
# itself rejects negative device indices.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
dataset_path = "/kaggle/input/arabic-to-sql/Text To SQL Task/Dataset/AR_spider.jsonl"
# JSON Lines: one JSON object per line. The context manager guarantees the
# file handle is closed, and blank lines are skipped so they cannot crash
# json.loads.
with open(dataset_path, "r", encoding="utf-8") as dataset_file:
    data = [json.loads(line) for line in dataset_file if line.strip()]

In [None]:
df = pd.DataFrame(data)

In [None]:
df.sample(5)

In [None]:
df.info()

In [None]:
# Canonicalise the target SQL: lower-case it, then trim surrounding whitespace.
# NOTE(review): lower-casing also affects quoted string literals inside the
# queries - confirm that is intended for execution-based evaluation.
df['query'] = df['query'].str.lower().str.strip()

In [None]:
df.isnull().sum()

In [None]:
# Concatenate every SQL query into one corpus string and render it as a word cloud.
sql_text = " ".join(df["query"])
wordcloud = WordCloud(width=800, height=400, background_color="white").generate(sql_text)

# Explicit figure/axes interface instead of the pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(wordcloud, interpolation="bilinear")
ax.axis("off")
ax.set_title("SQL Query Word Cloud")
plt.show()

# Arabic text processing

In [None]:
def normalize(text):
    """Normalise an Arabic question string.

    Steps, in order:
      1. apply a fixed table of character replacements (alef variants,
         alef maqsura, ta marbuta, tatweel, '?' and '.');
      2. pass the text through camel_tools' ``dediac_ar``;
      3. round-trip the text through UTF-8;
      4. keep only characters that are in ``AR_LETTERS_CHARSET``, are digits,
         or are whitespace.

    Returns the filtered string.
    """
    # (source, replacement) pairs, applied sequentially in this order.
    substitutions = (
        ("إ", "ا"),
        ("أ", "ا"),
        ("آ", "ا"),
        ("ى", "ي"),
        ("ة", "ه"),
        ("ـ", ""),
        ('?', ''),
        ('.', ''),
    )
    for source, replacement in substitutions:
        text = text.replace(source, replacement)

    text = dediac_ar(text)
    text = text.encode("utf-8").decode("utf-8")

    kept_chars = []
    for char in text:
        if char in AR_LETTERS_CHARSET or char.isdigit() or char.isspace():
            kept_chars.append(char)
    return "".join(kept_chars)

In [None]:
df['arabic'] = df['arabic'].apply(normalize)

In [None]:
df['arabic']

In [None]:
nltk.download('punkt')

In [None]:
#nltk.download('stopwords')
#stopwords = set(stopwords.words('arabic'))
#df['arabic'] = df['arabic'].apply(lambda text: str([token for token in text.split() if token not in stopwords]))

In [None]:
df['arabic']

In [None]:
df['arabic'] =  df['arabic'].apply(arabic_reshaper.reshape)

In [None]:
df['arabic'] = df['arabic'].apply(lambda x: ' '.join(x) if isinstance(x, list) else x)

In [None]:
df.head()

# Schema Handling

In [None]:
def extract_schema_info(db_id, db_path="/kaggle/input/arabic-to-sql/Text To SQL Task/Dataset/database"):
    """Read the table/column layout of one SQLite database.

    Args:
        db_id: Database identifier; the file is expected at
            ``<db_path>/<db_id>/<db_id>.sqlite``.
        db_path: Root directory holding one sub-directory per database.

    Returns:
        A ``(schema, schema_texts)`` tuple where ``schema`` maps each table
        name to its list of column names, and ``schema_texts`` holds one
        ``"Table: <name>, Columns: <c1>, <c2>, ..."`` string per table.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    db_file = os.path.join(db_path, db_id, f"{db_id}.sqlite")
    conn = sqlite3.connect(db_file)
    try:
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = [row[0] for row in cursor.fetchall()]

        schema = {}
        schema_texts = []

        for table in tables:
            # Quote the identifier: table names that are SQL keywords or that
            # contain spaces would otherwise make the PRAGMA a syntax error.
            quoted_table = '"' + table.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_table});")
            # Column 1 of each table_info row is the column name.
            columns = [row[1] for row in cursor.fetchall()]
            schema[table] = columns
            schema_texts.append(f"Table: {table}, Columns: {', '.join(columns)}")
    finally:
        # Always release the connection, even when a query fails.
        conn.close()

    return schema, schema_texts

In [None]:
db_id = "department_management"  
schema, schema_texts = extract_schema_info(db_id)
schema

In [None]:
# Cache the schema dict and the per-table description strings for every
# database referenced by the dataset, keyed by db_id.
db_schemas, db_texts = {}, {}
for db_id in df['db_id'].unique():
    db_schemas[db_id], db_texts[db_id] = extract_schema_info(db_id)

# RAG

In [None]:
embed_model = SentenceTransformer("aubmindlab/bert-base-arabertv02")

In [None]:
def create_faiss_index(texts):
    """Embed ``texts`` and store the vectors in a flat L2 FAISS index.

    Each text is encoded individually with the module-level ``embed_model``;
    the vectors are stacked into a float32 array whose second dimension sets
    the index dimensionality.

    Returns:
        ``(index, texts)`` - the populated index and the input list, so that a
        search result position can be mapped back to its text.
    """
    vectors = [embed_model.encode(text) for text in texts]
    embeddings = np.array(vectors).astype("float32")

    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index, texts

In [None]:
db_indices = {db_id: create_faiss_index(db_texts[db_id]) for db_id in db_texts}

In [None]:
index, texts = create_faiss_index(schema_texts)
print("Schema Texts:", texts)

In [None]:
def retrieve_relevant_schema(question, db_id, top_k=3):
    """Return the schema descriptions closest to ``question`` for one database.

    Args:
        question: Natural-language question to embed with ``embed_model``.
        db_id: Key into the module-level ``db_indices`` mapping.
        top_k: Maximum number of schema strings to return.

    Returns:
        Up to ``top_k`` schema description strings, nearest first. Fewer are
        returned when the database has fewer tables than ``top_k``.
    """
    index, texts = db_indices[db_id]

    # Never ask for more neighbours than there are vectors: FAISS pads missing
    # results with the label -1, and texts[-1] would silently repeat the last
    # table in the prompt.
    k = min(top_k, len(texts))
    if k <= 0:
        return []

    query_vector = embed_model.encode(question).astype("float32").reshape(1, -1)
    _, indices = index.search(query_vector, k)
    # Defensive filter in case the index still reports padded (-1) labels.
    return [texts[i] for i in indices[0] if i >= 0]

In [None]:
model_name = "moussaKam/AraBART"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:
model.to(device)

In [None]:
def preprocess_data(example):
    """Turn one dataset row into tokenised model inputs and labels.

    Args:
        example: Mapping (e.g. a DataFrame row) with the keys ``"arabic"``,
            ``"db_id"`` and ``"query"``.

    Returns:
        A dict with ``"input_ids"`` and ``"attention_mask"`` (both padded or
        truncated to 512 tokens) and ``"labels"`` (padded or truncated to 256
        tokens, with padding positions replaced by -100).
    """
    question, db_id, query = example["arabic"], example["db_id"], example["query"]
    schema_context = retrieve_relevant_schema(question, db_id)

    input_text = f"{question} Context: {'. '.join(schema_context)}"

    inputs = tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_attention_mask=True,
    )
    labels = tokenizer(query, padding="max_length", truncation=True, max_length=256)

    # -100 is set on padding positions so they do not contribute to the loss.
    labels["input_ids"] = [
        -100 if token == tokenizer.pad_token_id else token for token in labels["input_ids"]
    ]

    # The attention mask must travel with the padded input_ids; without it the
    # encoder attends to the padding tokens during training.
    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels["input_ids"],
    }


In [None]:
train_data = df.apply(preprocess_data, axis=1).tolist()

In [None]:
schema_context = retrieve_relevant_schema("كم عدد رؤساء الأقسام الذين تزيد أعمارهم عن 56", "department_management")
input_text = f"عدد رؤساء الأقسام الذين تزيد أعمارهم عن 56. Context: {'. '.join(schema_context)}"

print("Formatted Model Input:")
print(input_text)

# Model Training

In [None]:
class SQLDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenised, equal-length encodings.

    Args:
        encodings: Dict mapping a field name (e.g. ``"input_ids"``) to a list
            of equal-length integer sequences, one per example.
        device: Accepted for backward compatibility and ignored. Tensors stay
            on the CPU; moving each batch to the accelerator is left to the
            training loop, and CPU tensors remain compatible with DataLoader
            memory pinning.
    """

    def __init__(self, encodings, device=None):
        self.encodings = {key: torch.tensor(val) for key, val in encodings.items()}

    def __len__(self):
        # All fields have one row per example; an empty dict means no examples.
        first_field = next(iter(self.encodings.values()), None)
        return 0 if first_field is None else len(first_field)

    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

In [None]:
train_data_split, val_data_split = train_test_split(train_data, test_size=0.1, random_state=42)

In [None]:
train_encodings = {key: [dic[key] for dic in train_data_split] for key in train_data_split[0]}
val_encodings = {key: [dic[key] for dic in val_data_split] for key in val_data_split[0]}

In [None]:
train_dataset = SQLDataset(train_encodings , device)
val_dataset = SQLDataset(val_encodings , device)

In [None]:
# Training configuration: checkpoint, evaluate and log once per epoch and
# keep at most two checkpoints on disk.
training_args = Seq2SeqTrainingArguments(
    output_dir="./sql_model",
    run_name="arabic_text2sql_experiment",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_total_limit=2,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    logging_steps=10,  # NOTE(review): presumably unused while logging_strategy="epoch" - confirm
    num_train_epochs=65,
    learning_rate=5e-5,
    weight_decay=0.01,
    # Mixed precision needs a CUDA device; fall back to full precision on CPU
    # so the notebook's CPU path does not fail at construction time.
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    predict_with_generate=True
)


In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset, 
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

In [None]:
model.save_pretrained("./sql_model")
tokenizer.save_pretrained("./sql_model")

# Model evaluation

In [None]:
def generate_sql(question, db_id):
    """Generate a SQL query for ``question`` against database ``db_id``.

    The prompt is built from the question plus the retrieved schema
    descriptions, using the same ``"<question> Context: <schema>"`` template
    that ``preprocess_data`` uses for training, so the model sees the same
    format at inference time.

    Args:
        question: Natural-language question.
        db_id: Database identifier used for schema retrieval.

    Returns:
        The decoded model output with special tokens removed.
    """
    relevant_schema = retrieve_relevant_schema(question, db_id)
    schema_context = ". ".join(relevant_schema)
    # No period after the question: the training prompts do not have one.
    input_text = f"{question} Context: {schema_context}"

    # Follow the model's own device instead of a global, so this works
    # wherever the model currently lives (GPU or CPU).
    encoded = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    output_ids = model.generate(
        encoded["input_ids"],
        attention_mask=encoded.get("attention_mask"),
        # Training labels are tokenised up to 256 tokens; allow the same
        # length here so long queries are not cut off.
        max_length=256,
    )
    generated_sql = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return generated_sql

In [None]:
def execute_query(sql_query, db_id, db_path="/kaggle/input/arabic-to-sql/Text To SQL Task/Dataset/database"):
    """Run ``sql_query`` against ``<db_path>/<db_id>/<db_id>.sqlite``.

    Returns:
        The list of result rows from ``fetchall()`` on success. If executing
        or fetching raises, the string ``"Error executing query: <message>"``
        is returned instead of propagating the exception. The connection is
        closed in both cases.
    """
    db_file = os.path.join(db_path, db_id, f"{db_id}.sqlite")
    conn = sqlite3.connect(db_file)
    cursor = conn.cursor()

    try:
        cursor.execute(sql_query)
        rows = cursor.fetchall()
    except Exception as e:
        return f"Error executing query: {e}"
    else:
        return rows
    finally:
        conn.close()

In [None]:
def normalize_sql(sql):
    """Canonicalise a SQL string for textual comparison.

    Lower-cases the text, collapses every whitespace run into a single space,
    rewrites each comma (with any surrounding whitespace) as ``", "`` and
    strips leading/trailing whitespace.
    """
    collapsed = re.sub(r"\s+", " ", sql.lower())
    comma_spaced = re.sub(r"\s*,\s*", ", ", collapsed)
    return comma_spaced.strip()


In [None]:
test = [
    {
        "arabic": "ﺍﻋﺮﺽ ﻗﺎﺋﻤﻪ ﺑﺎﺳﻤﺎﺀ ﺭﺅﺳﺎﺀ ﺍﻻﻗﺴﺎﻡ ﻣﻜﺎﻥ ﻣﻴﻼﺩﻫﻢ ﻭﺍﻋﻤﺎﺭﻫﻢ ﻣﺮﺗﺒﻪ ﺣﺴﺐ ﺍﻟﻌﻤﺮ",
        "db_id": "department_management",
        "query": "select name ,  born_state ,  age from head order by age"
    },
    {
        "arabic": "ﺍﺑﺤﺚ ﻋﻦ ﺍﺳﻢ ﻭﻋﻨﻮﺍﻥ ﺍﻟﺒﺮﻳﺪ ﺍﻻﻟﻜﺘﺮﻭﻧﻲ ﻟﻠﻤﺴﺘﺨﺪﻡ ﺍﻟﺬﻱ ﻳﺤﺘﻮﻱ ﺍﺳﻤﻪ ﻋﻠﻲ ﻛﻠﻤﻪ ﺳﻮﻳﻔﺖ",
        "db_id": "twitter_1",
        "query": "select name ,  email from user_profiles where name like '%swift%'"
    },
    {
        "arabic": "ﻛﻢ ﻋﺪﺩ ﺍﻟﻼﻋﺒﻴﻦ",
        "db_id": "riding_club",
        "query": "select count(*) from player"
    },
    {
        "arabic": "ﻣﺎ ﻫﻲ ﺍﺳﻤﺎﺀ ﺍﻟﻔﻨﺎﻧﻴﻦ ﻣﺮﺗﺒﻪ ﺍﺑﺠﺪﻳﺎ",
        "db_id": "musical",
        "query": "select name from actor order by name asc"
    },
    {
        "arabic": "ﻣﺎ ﻫﻲ ﺍﺳﻤﺎﺀ ﺍﻟﻤﺼﺎﺭﻋﻴﻦ ﻣﺮﺗﺒﻪ ﺗﻨﺎﺯﻟﻴﺎ ﺣﺴﺐ ﺍﻻﻳﺎﻡ ﺍﻟﺘﻲ ﺗﻤﺖ ﺍﻻﺣﺘﻔﺎﻅ ﺑﻬﺎ",
        "db_id": "wrestler",
        "query": "select name from wrestler order by days_held desc"
    },
    {
        "arabic": "ﻣﺎ ﻫﻮ ﻣﺘﻮﺳﻂ ﺍﻋﻤﺎﺭ ﺍﻟﻄﺎﻟﺒﺎﺕ ﺍﻟﻼﺗﻲ ﻟﺪﻳﻬﻦ ﺗﺼﻮﻳﺖ ﻻﻣﻴﻦ ﺍﻟﺴﺮ ﻓﻲ ﺩﻭﺭﻩ ﺍﻻﻧﺘﺨﺎﺑﺎﺕ ﺍﻟﺮﺑﻴﻌﻴﻪ",
        "db_id": "voter_2",
        "query": "select avg(t1.age) from student as t1 join voting_record as t2 on t1.stuid  =  secretary_vote where t1.sex  =  \"f\" and t2.election_cycle  =  \"spring\""
    },
    {
        "arabic": "ﺍﻋﺮﺽ ﺟﻤﻴﻊ ﺍﻟﻤﺴﺘﺸﺎﺭﻳﻦ ﺍﻟﺬﻳﻦ ﻟﺪﻳﻬﻢ ﻋﻠﻲ ﺍﻻﻗﻞ ﻃﺎﻟﺒﻴﻦ",
        "db_id": "game_1",
        "query": "select major ,  avg(age) ,  min(age) ,  max(age) from student group by major"
    },
    {
        "arabic": " ﻣﺎ ﻫﻲ ﻣﻌﺮﻓﺎﺕ ﺟﻤﻴﻊ ﺍﻟﻄﺎﺋﺮﺍﺕ ﺍﻟﺘﻲ ﻳﻤﻜﻨﻬﺎ ﻗﻄﻊ ﻣﺴﺎﻓﻪ ﺗﺰﻳﺪ ﻋﻦ 1000",
        "db_id": "flight_1",
        "query": "select aid from aircraft where distance  >  1000"
    },
    {
        "arabic": "ﻛﻢ ﻋﺪﺩ ﺍﻻﻗﺴﺎﻡ ﺍﻟﺘﻲ ﺗﻘﺪﻡ ﺩﻭﺭﺍﺕ",
        "db_id": "college_2",
        "query": "select count(distinct dept_name) from course"
    },
    {
        "arabic": "ﻣﺎ ﻫﻮ ﻧﻮﻉ ﺍﻟﻤﻨﺘﺞ ﺍﻟﺬﻱ ﻳﺘﻤﺘﻊ ﺑﻤﺘﻮﺳﻂ ﺳﻌﺮ ﺍﻋﻠﻲ ﻣﻦ ﻣﺘﻮﺳﻂ ﺳﻌﺮ ﺟﻤﻴﻊ ﺍﻟﻤﻨﺘﺠﺎﺕ",
        "db_id": "department_store",
        "query": "select product_type_code ,  max(product_price) ,  min(product_price) from products group by product_type_code"
    },
    {
        "arabic": "ﺍﻋﺜﺮ ﻋﻠﻲ ﻋﺪﺩ ﺍﻻﻟﺒﻮﻣﺎﺕ ﻟﻠﻔﻨﺎﻥ ﻣﻴﺘﺎﻟﻴﻜﺎ",
        "db_id": "chinook_1",
        "query": "select count(*) from album as t1 join artist as t2 on t1.artistid  =  t2.artistid where t2.name  =  \"metallica\""
    },
    {
        "arabic": "ﺍﻇﻬﺮ ﻛﻞ ﺍﻟﻤﻌﻠﻮﻣﺎﺕ ﺣﻮﻝ ﺍﻻﻧﺘﺨﺎﺑﺎﺕ",
        "db_id": "election",
        "query": "select distinct year from party where governor  =  \"eliot spitzer\""
    },
    {
        "arabic": "ﻣﺎ ﻫﻲ ﺍﺳﻤﺎﺀ ﺍﻟﺴﻔﻦ ﻣﺮﺗﺒﻪ ﺣﺴﺐ ﺳﻨﻪ ﺍﻟﺒﻨﺎﺀ ﻭﻓﺌﺘﻬﺎ",
        "db_id": "ship_1",
        "query": "select name from ship order by built_year ,  class"
    },
    {
        "arabic": "ﺍﻟﻌﺜﻮﺭ ﻋﻠﻲ ﺍﺳﻤﺎﺀ ﺟﻤﻴﻊ ﺍﻟﻤﺪﺭﺳﻴﻦ ﻓﻲ ﻗﺴﻢ ﻋﻠﻮﻡ ﺍﻟﺤﺎﺳﻮﺏ",
        "db_id": "college_2",
        "query": "select name from instructor where dept_name  =  'comp. sci.'"
    },
    {
        "arabic": "ﻣﺎ ﻫﻲ ﺍﻻﺳﻤﺎﺀ ﻟﻠﺜﻼﺛﻪ ﻓﺮﻭﻉ ﺍﻟﺘﻲ ﺗﺤﺘﻮﻱ ﻋﻠﻲ ﺍﻛﺒﺮ ﻋﺪﺩ ﻣﻦ ﺍﻟﻌﻀﻮﻳﺎﺕ",
        "db_id": "shop_membership",
        "query": "select name from branch order by membership_amount desc limit 3"
    },
    {
        "arabic": "ﺍﺧﺘﺮ ﺍﺳﻢ ﺍﻟﻤﻨﺘﺠﺎﺕ ﺑﺴﻌﺮ ﺍﻗﻞ ﻣﻦ ﺍﻭ ﻳﺴﺎﻭﻱ 200 ﺩﻭﻻﺭ",
        "db_id": "manufactory_1",
        "query": "select name ,  price from products"
    },
    {
        "arabic": "ﺍﻇﻬﺮ ﺍﺳﻢ ﺍﻟﻌﺎﺋﻠﻪ ﻭﺍﻻﺳﻢ ﺍﻻﻭﻝ ﻟﻜﻞ ﻃﺎﻟﺐ",
        "db_id": "student_1",
        "query": "select distinct firstname ,  lastname from list"
    },
    {
        "arabic": "ﻣﺎ ﻫﻲ ﻧﻄﺎﻗﺎﺕ ﺍﻻﺳﻌﺎﺭ ﻟﻠﻔﻨﺎﺩﻕ ﺫﺍﺕ ﺍﻟﺘﻘﻴﻴﻢ ﺍﻟﺨﻤﺲ ﻧﺠﻮﻡ",
        "db_id": "cre_Theme_park",
        "query": "select price_range from hotels where star_rating_code  =  \"5\""
    },
    {
        "arabic": "ﻛﻢ ﻋﺪﺩ ﺍﻟﻤﺘﺎﺑﻌﻴﻦ ﻟﻜﻞ ﻣﺴﺘﺨﺪﻡ",
        "db_id": "twitter_1",
        "query": "select count(*) from follows"
    },
    {
        "arabic": "ﻛﻢ ﻗﺴﻤﺎ ﻳﺪﻳﺮﻩ ﺭﺅﺳﺎﺀ ﻟﻢ ﻳﺬﻛﺮ ﺍﺳﻤﺎﺅﻫﻢ",
        "db_id": "department_management",
        "query": "select count(*) from department where department_id not in (select department_id from management);"
    }
]

In [None]:
bleu_metric = evaluate.load("bleu")

In [None]:
# Derive the denominator from the evaluation set instead of hard-coding it,
# so the accuracies stay correct when examples are added or removed.
num_runs = len(test)
exact_match_count = 0
execution_accuracy_count = 0
bleu_scores = []

In [None]:
# Reset the tallies here as well, so re-running this cell does not
# accumulate counts from a previous run.
exact_match_count = 0
execution_accuracy_count = 0
bleu_scores = []

for ex in test:
    question = ex['arabic']
    db_id = ex['db_id']
    true_sql = ex['query']

    generated_sql = generate_sql(question, db_id)

    print("Arabic Question:", question)
    print("\nGenerated SQL Query:", generated_sql)
    print("\nGround Truth SQL:", true_sql)

    # Exact match is judged on whitespace/case-normalised text.
    normalized_generated_sql = normalize_sql(generated_sql)
    normalized_true_sql = normalize_sql(true_sql)

    exact_match = normalized_generated_sql == normalized_true_sql
    print('Exact match: ' ,exact_match)
    if exact_match:
        exact_match_count += 1

    try:
        generated_result = execute_query(generated_sql, db_id)
        true_result = execute_query(true_sql, db_id)
        # execute_query reports a failure as an error *string*. Two queries
        # that fail with the same message must not be scored as a match, so
        # both sides have to be real result sets before comparing.
        execution_accuracy = (
            not isinstance(generated_result, str)
            and not isinstance(true_result, str)
            and generated_result == true_result
        )
    except Exception as exc:
        # e.g. the database file could not be opened at all
        print('Execution failed: ', exc)
        execution_accuracy = False
    print('Execution accuracy: ' , execution_accuracy)
    if execution_accuracy:
        execution_accuracy_count += 1

    # An empty prediction has no n-grams; score it as 0 instead of letting the
    # metric's brevity-penalty computation divide by zero.
    if generated_sql.strip():
        bleu_score = bleu_metric.compute(
            predictions=[generated_sql],
            references=[[true_sql]]
        )["bleu"]
    else:
        bleu_score = 0.0

    bleu_scores.append(bleu_score)
    print("BLEU Score:", bleu_score)

In [None]:
# Aggregate the per-example tallies into the final metrics.
exact_match_accuracy = exact_match_count / num_runs
execution_accuracy = execution_accuracy_count / num_runs
avg_bleu_score = np.mean(bleu_scores)

In [None]:
print("\nFinal Results:")
print("Average BLEU Score:", avg_bleu_score)
print("Exact Match Accuracy:", exact_match_accuracy)
print("Execution Accuracy:", execution_accuracy)