### Install Transformers Datasets 

In [1]:
!pip install -q -U datasets > /dev/null
!pip show datasets
#https://huggingface.co/shahrukhx01/schema-aware-denoising-bart-large-cnn-text2sql

Name: datasets
Version: 2.11.0
Summary: HuggingFace community-driven open-source library of datasets
Home-page: https://github.com/huggingface/datasets
Author: HuggingFace Inc.
Author-email: thomas@huggingface.co
License: Apache 2.0
Location: /usr/local/lib/python3.9/dist-packages
Requires: aiohttp, dill, fsspec, huggingface-hub, multiprocess, numpy, packaging, pandas, pyarrow, pyyaml, requests, responses, tqdm, xxhash
Required-by: 


In [2]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m59.7 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.3 transformers-4.28.1


### Libraries 📚⬇

In [3]:
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
import pandas as pd
from datasets import load_dataset
import random, warnings
import pandas as pd
import re
warnings.filterwarnings("ignore")

### Import the MBART model 

In [4]:
model = BartForConditionalGeneration.from_pretrained('shahrukhx01/schema-aware-denoising-bart-large-cnn-text2sql')
tokenizer = BartTokenizer.from_pretrained('shahrukhx01/schema-aware-denoising-bart-large-cnn-text2sql')

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/890 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/300 [00:00<?, ?B/s]

### Predict Function

In [10]:
def get_sql(query, cur_data):
    result = "%s </s> " % query
    arr = cur_data['table']['header']
    brr = cur_data['table']['types']
    for i in range(len(arr)):
      result += f'<col{i}> {arr[i]} : {brr[i]}, '
    
    input_text = result

    
    features = tokenizer([input_text], max_length=1024, return_tensors='pt')

    output = model.generate(input_ids=features['input_ids'], 
               attention_mask=features['attention_mask'], num_beams=4, min_length=0, max_length=125,early_stopping=True,return_dict=False)
    
    final_output = format_change([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in output][0].replace("`", ""), arr)

    return final_output    




In [6]:
def format_change(cur_output, column_name):
    for i in range(len(column_name)):
        cur_output = cur_output.replace(f"<col{i}>", column_name[i])
    
    while '   ' in cur_output:
        cur_output = cur_output.replace('   ', ' ')
    
    while '  ' in cur_output:
        cur_output = cur_output.replace('  ', ' ')

    return cur_output

In [7]:
valid_dataset = load_dataset('wikisql', split='validation')

Downloading builder script:   0%|          | 0.00/6.57k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.76k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.80k [00:00<?, ?B/s]

Downloading and preparing dataset wikisql/default to /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d...


Downloading data:   0%|          | 0.00/26.2M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/15878 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8421 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/56355 [00:00<?, ? examples/s]

Dataset wikisql downloaded and prepared to /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d. Subsequent calls will reuse this data.


### Sample Validation Data

In [8]:
valid_dataset[0]

{'phase': 1,
 'question': 'What position does the player who played for butler cc (ks) play?',
 'table': {'header': ['Player',
   'No.',
   'Nationality',
   'Position',
   'Years in Toronto',
   'School/Club Team'],
  'page_title': 'Toronto Raptors all-time roster',
  'page_id': '',
  'types': ['text', 'text', 'text', 'text', 'text', 'text'],
  'id': '1-10015132-11',
  'section_title': 'L',
  'caption': 'L',
  'rows': [['Antonio Lang',
    '21',
    'United States',
    'Guard-Forward',
    '1999-2000',
    'Duke'],
   ['Voshon Lenard', '2', 'United States', 'Guard', '2002-03', 'Minnesota'],
   ['Martin Lewis',
    '32, 44',
    'United States',
    'Guard-Forward',
    '1996-97',
    'Butler CC (KS)'],
   ['Brad Lohaus', '33', 'United States', 'Forward-Center', '1996', 'Iowa'],
   ['Art Long',
    '42',
    'United States',
    'Forward-Center',
    '2002-03',
    'Cincinnati'],
   ['John Long', '25', 'United States', 'Guard', '1996-97', 'Detroit'],
   ['Kyle Lowry', '3', 'United Sta

### Prediction on WikiSQL Validation Set

In [9]:
for idx in random.sample(range(len(valid_dataset)), 5):
    print(f"Text: {valid_dataset[idx]['question']}")
    print(f"Pred SQL: {get_sql(valid_dataset[idx]['question'], valid_dataset[idx])}")
    print(f"True SQL: {valid_dataset[idx]['sql']['human_readable']}\n")

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Text: HOW MANY YEARS WAS IT FOR THE SCORE (76-73-79-72=300)?
Pred SQL: SELECT COUNT Year FROM table WHERE Winning score = 76-73-79-72=300 
True SQL: SELECT COUNT Year FROM table WHERE Winning score = (76-73-79-72=300)

Text: if the geez is libb, what is the akkadian?
Pred SQL: SELECT Akkadian FROM table WHERE Geez = libb 
True SQL: SELECT Akkadian FROM table WHERE Geez = libb

Text: Name the call sign with frequency of 89.5
Pred SQL: SELECT Call sign FROM table WHERE Frequency MHz = 89.5 
True SQL: SELECT Call sign FROM table WHERE Frequency MHz = 89.5

Text: What is the main service for the station with 14.849 million passengers 2011-12? 
Pred SQL: SELECT Main Services FROM table WHERE Total Passengers (millions) 2011–12 = 14.849 
True SQL: SELECT Main Services FROM table WHERE Total Passengers (millions) 2011–12 = 14.849

Text: On September 10, 1989 how many people attended the game?
Pred SQL: SELECT COUNT Attendance FROM table WHERE Date = september 10, 1989 
True SQL: SELECT Attend