In [1]:
import pandas as pd
from supabase import create_client, Client
from dotenv import load_dotenv
import os
import warnings
warnings.filterwarnings("ignore")

import sqlparse # For SQL validation

# LangChain Imports
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer



In [3]:
load_dotenv()
SUPABASE_URL = os.environ.get("SUPABASE_URL")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY")

In [4]:
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

In [5]:
def validate_sql(sql_query: str) -> bool:
    """
    Performs basic validation on the generated SQL query.
    Allows only SELECT statements. Blocks DDL, DML, and potentially harmful clauses.
    """
    parsed = sqlparse.parse(sql_query)[0] # parse.parse returns a list of statements
    
    # Check for keywords that indicate DDL or DML
    forbidden_keywords = {'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER', 'TRUNCATE', 'CREATE', 'GRANT', 'REVOKE', 'VACUUM'}
    
    # Check if the first token is SELECT
    if not parsed.tokens[0].normalized == 'SELECT':
        print(f"Validation failed: Query does not start with SELECT. Found: {parsed.tokens[0].normalized}")
        return False

    # Iterate through tokens to find forbidden keywords
    for token in parsed.tokens:
        if isinstance(token, sqlparse.sql.Token) and token.normalized in forbidden_keywords:
            print(f"Validation failed: Forbidden keyword found: {token.normalized}")
            return False
        # Prevent UNION/UNION ALL if not specifically needed and risky for your use case
        if isinstance(token, sqlparse.sql.Token) and token.normalized in {'UNION'}:
             print(f"Validation failed: UNION is forbidden for security.")
             return False

    print("SQL query passed basic validation.")
    return True

In [6]:
METADATA_FILE = "metadata.json"
import json
try:
    with open(METADATA_FILE, 'r') as f:
        db_metadata = json.load(f)
    print(f"Loaded database metadata from {METADATA_FILE}")
except FileNotFoundError:
    print(f"Warning: {METADATA_FILE} not found. Schema descriptions will be limited.")
    db_metadata = {"tables": []}
except json.JSONDecodeError:
    print(f"Error: Could not parse {METADATA_FILE}. Check JSON format.")
    db_metadata = {"tables": []}

Loaded database metadata from metadata.json


In [4]:
MISTRAL_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

In [7]:
import torch
# run huggingface-cli login
model_kwargs = {
            "load_in_4bit": True,
            "torch_dtype": torch.bfloat16,
            "device_map": "auto"
        }

tokenizer = AutoTokenizer.from_pretrained(MISTRAL_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MISTRAL_MODEL_NAME, **model_kwargs)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Fetching 3 files: 100%|██████████| 3/3 [50:29<00:00, 1009.83s/it]  
Loading checkpoint shards: 100%|██████████| 3/3 [00:27<00:00,  9.05s/it]


In [11]:
# client = InferenceClient(model=MISTRAL_MODEL_NAME)
model.to("cuda")

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): Mist

In [None]:
msg = "Hi! How are you? Can you explain me what are Neural Networks in short?"

In [15]:
tokens = tokenizer(msg)
tokens

{'input_ids': [1, 15359, 28808, 1602, 460, 368, 28804, 2418, 368, 7282, 528, 767, 460, 3147, 1890, 9488, 28713, 297, 2485, 28804], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [22]:
res = model.generate(tokenizer.encode(msg, return_tensors="pt").to("cuda"), max_new_tokens=1000, do_sample=True, temperature=0.7)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [26]:
tokenizer.decode(res[0].tolist()).split("\n")

['<s> Hi! How are you? Can you explain me what are Neural Networks in short?',
 '',
 'Sure! A neural network is a type of artificial intelligence model that is inspired by the human brain. It\'s designed to recognize patterns and learn from data, much like the human brain does. Neural networks consist of interconnected nodes, or "neurons," that process and transmit information. Each node takes in some input, applies a function to it, and passes the output on to the next node. The network learns by adjusting the weights of the connections between neurons based on the error of its predictions. Neural networks are particularly well-suited to tasks like image and speech recognition, where they can learn to recognize complex patterns from large amounts of data.</s>']

In [12]:
import re

In [26]:
q = re.findall(r"```sql\n(.*?);\n```", msg, re.DOTALL)[0]
q = """SELECT AVG("MeterA_reading") \nFROM synthetic_data_01 \nWHERE "MeterA_reading" BETWEEN 10 AND 30"""

In [27]:
supabase.rpc("execute_dynamic_sql", {"sql_command": q}).execute()

SingleAPIResponse[~_ReturnT](data=[{'avg': 22.1793215979329}], count=None)