In [1]:
import sys
sys.path.append('../utils/')
sys.path.append('../queryProcessing/')

from utils import *
from TableMapper import TableMapper

from tqdm.notebook import tqdm
tqdm.pandas()

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, pipeline
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import json

class TQLRunner():
    
    def __init__(self, schema_id):
        
        if(schema_id is None):
            raise Exception("Schema ID is needed")
        
        self.query, self.schema = get_spider_schema_table_files()
        self.tableMapper = TableMapper(self.query, self.schema)
        
        self.s, self.t = self.tableMapper.get_filtered_schema(schema_id)
        
        print('All libraries loaded')
        
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        print("Helloj")
        # Load the adapter configuration from the provided URL
        adapter_config_url = 'https://huggingface.co/naman1011/TQL/raw/main/adapter_config.json'
        adapter_config = json.loads('''{
                          "auto_mapping": null,
                          "base_model_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
                          "bias": "none",
                          "fan_in_fan_out": false,
                          "inference_mode": true,
                          "init_lora_weights": true,
                          "layers_pattern": null,
                          "layers_to_transform": null,
                          "lora_alpha": 32,
                          "lora_dropout": 0.1,
                          "modules_to_save": null,
                          "peft_type": "LORA",
                          "model_type" : "t5",
                          "r": 2,
                          "revision": null,
                          "target_modules": [
                                "q_proj",
                                "v_proj"
                  ],
                  "task_type": "CASUAL_LM"
            }'''
       )
        
        # Load the model using the adapter configuration
        print(adapter_config)
        self.model = AutoModelForCausalLM.from_pretrained('naman1011/TQL', config=adapter_config)
        print("Hello 4j")
        self.tokenizer = AutoTokenizer.from_pretrained('naman1011/TQL')
        print("Hello kfr4j")
            
        print('LLM Model initialized')
            
        
    def create_schema_natural_language(self, row):

        schema_id = row['schema_id']
        table_name = row['table_name']
        primary_key = row['primary_key']
        column_list = eval(row['column_list_original'])
        datatype_list = eval(row['column_datatypes'])
        foreign_key = eval(row['foreign_keys'])

        column_list_with_datatype = []
        for column, datatype in zip(column_list, datatype_list):
            column_list_with_datatype.append(\
                     ' has datatype '.join([column, datatype])
            )

        schema_natural_language = \
                f"Given the Table {table_name} having columns as \
                        {', '.join(column_list_with_datatype)} \
                            which has {primary_key}"
        return schema_natural_language
    
    
    def get_table_prompt(self, input_text):
        
        table_names_from_tql = self.tableMapper.get_table_names_tql(self.s, input_text)
        
        if(len(table_names_from_tql) == 0):
            raise Exception("No tables found, please repharse the query and try again")
        
        prompt_tables = []
        for i in table_names_from_tql:
            prompt_tables.append(
                self.s[self.s['table_name_original'] == i].apply(
                    self.create_schema_natural_language, axis = 1
                ).reset_index(drop = True).iloc[0]
            )
                
        return ' and '.join(prompt_tables)
    
    
    def get_final_prompt(self, input_text):
        
        task_prefix = 'Generate an SQL Query for'
        table_prompt = self.get_table_prompt(input_text)
        
        final_prompt = task_prefix + ' ' + input_text + ' ' + table_prompt
        
        return final_prompt
        
        
    def get_SQL_query(self, input_text):
        
        prompt = self.get_final_prompt(input_text)
        print(prompt)
        
        tokens = self.tokenizer(prompt, 
                           return_tensors="pt", max_length=512, 
                           truncation=True, padding="max_length")
        
        outputs = self.model.generate(input_ids=tokens.input_ids.to(self.device), max_new_tokens = 512)
        predicted_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        return predicted_query

In [2]:
tqlRunner = TQLRunner('yelp')

All libraries loaded
Helloj
{'auto_mapping': None, 'base_model_name_or_path': 'meta-llama/Llama-2-7b-chat-hf', 'bias': 'none', 'fan_in_fan_out': False, 'inference_mode': True, 'init_lora_weights': True, 'layers_pattern': None, 'layers_to_transform': None, 'lora_alpha': 32, 'lora_dropout': 0.1, 'modules_to_save': None, 'peft_type': 'LORA', 'model_type': 't5', 'r': 2, 'revision': None, 'target_modules': ['q_proj', 'v_proj'], 'task_type': 'CASUAL_LM'}


Downloading (…)lve/main/config.json:   0%|          | 0.00/1.50k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [None]:
input_text = 'How many reviews are there in the database'
tqlRunner.get_SQL_query(input_text)

In [None]:
input_text = 'How many businesses are there in "Vegas"'
tqlRunner.get_SQL_query(input_text)

--------

In [None]:
tqlRunner = TQLRunner('college_2')

In [None]:
input_text = 'What is the name and building of the departments whose budget is more than the average budget?'
tqlRunner.get_SQL_query(input_text)

In [None]:
input_text = 'What are the names of students who have taken the prerequisite for the course "International Finance"?'
tqlRunner.get_SQL_query(input_text)

In [None]:
from transformers import LlamaModel, LlamaConfig

# Initializing a LLaMA llama-7b style configuration
configuration = LlamaConfig()

# Initializing a model from the llama-7b style configuration
model = LlamaModel(configuration)

# Accessing the model configuration
configuration = model.config