### Install Transformers and Datasets

In [None]:
!pip install -q -U datasets > /dev/null

In [None]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m82.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.3 transformers-4.28.1


### Libraries 📚⬇

In [None]:
!pip install sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentencepiece
  Downloading sentencepiece-0.1.98-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.98


### Load the T5-base model fine-tuned on WikiSQL

In [None]:
from transformers import AutoTokenizer, T5ForConditionalGeneration

# Checkpoint identifier handed to both `from_pretrained` calls below; change it here to
# try a different checkpoint. Loading this one prints `xla_device` deprecation notices
# that come from the checkpoint's own config.json; per the message they are ignored.
model_name = "mrm8488/t5-base-finetuned-wikiSQL"
# Tokenizer and model are loaded from the same checkpoint and are used as
# notebook-level globals by `get_sql`.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)


Downloading (…)okenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.23k [00:00<?, ?B/s]

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


Downloading pytorch_model.bin:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

### Predict Function

In [None]:
def get_sql(query, max_length=128):
    """Translate an English question into SQL text with the loaded T5 model.

    Uses the notebook-level ``tokenizer`` and ``model`` objects, so the model
    loading cell must have been run first.

    Args:
        query: Natural-language question to translate.
        max_length: Upper bound, in tokens, on the generated sequence. Without
            an explicit limit ``generate`` falls back to the library default
            (20 tokens), which cuts longer queries off mid-clause.

    Returns:
        The decoded prediction as a string, with special tokens
        (padding, end-of-sequence, unknown) removed and surrounding
        whitespace stripped.
    """
    # Prompt format is kept exactly as before, including the explicit "</s>".
    # NOTE(review): the tokenizer may already append EOS on its own — confirm
    # whether the literal "</s>" is still needed.
    input_text = "translate English to SQL: %s </s>" % query

    features = tokenizer([input_text], return_tensors='pt')

    # `return_dict=False` was dropped: it is not a generation option, and the
    # call returns a tensor of token ids without it.
    output_ids = model.generate(
        input_ids=features['input_ids'],
        attention_mask=features['attention_mask'],
        max_length=max_length,
    )

    # Let the tokenizer drop special tokens instead of string-replacing
    # "<pad> ", "</s>" and "<unk>" by hand.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

In [None]:
from datasets import load_dataset

# Validation split of WikiSQL (8,421 examples in the recorded run). The first call
# downloads and prepares the dataset; later calls reuse the local cache.
valid_dataset = load_dataset('wikisql', split='validation')

Downloading builder script:   0%|          | 0.00/6.57k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.76k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.80k [00:00<?, ?B/s]

Downloading and preparing dataset wikisql/default to /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d...


Downloading data:   0%|          | 0.00/26.2M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/15878 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8421 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/56355 [00:00<?, ? examples/s]

Dataset wikisql downloaded and prepared to /root/.cache/huggingface/datasets/wikisql/default/0.1.0/7037bfe6a42b1ca2b6ac3ccacba5253b1825d31379e9cc626fc79a620977252d. Subsequent calls will reuse this data.


In [None]:
# Training split of the same WikiSQL dataset.
train_dataset = load_dataset('wikisql', split='train')



In [None]:
# Number of training examples. A notebook cell echoes only its last expression,
# so this cell contains just the value it reports.
len(train_dataset)

56355

### Sample Validation Data

In [None]:
# Peek at one validation record: a dict holding the question and its source table
# (header, types, rows, ...). The reference query under the 'sql' key is used in the
# prediction cell further down.
valid_dataset[0]

{'phase': 1,
 'question': 'What position does the player who played for butler cc (ks) play?',
 'table': {'header': ['Player',
   'No.',
   'Nationality',
   'Position',
   'Years in Toronto',
   'School/Club Team'],
  'page_title': 'Toronto Raptors all-time roster',
  'page_id': '',
  'types': ['text', 'text', 'text', 'text', 'text', 'text'],
  'id': '1-10015132-11',
  'section_title': 'L',
  'caption': 'L',
  'rows': [['Antonio Lang',
    '21',
    'United States',
    'Guard-Forward',
    '1999-2000',
    'Duke'],
   ['Voshon Lenard', '2', 'United States', 'Guard', '2002-03', 'Minnesota'],
   ['Martin Lewis',
    '32, 44',
    'United States',
    'Guard-Forward',
    '1996-97',
    'Butler CC (KS)'],
   ['Brad Lohaus', '33', 'United States', 'Forward-Center', '1996', 'Iowa'],
   ['Art Long',
    '42',
    'United States',
    'Forward-Center',
    '2002-03',
    'Cincinnati'],
   ['John Long', '25', 'United States', 'Guard', '1996-97', 'Detroit'],
   ['Kyle Lowry', '3', 'United Sta

### Prediction on WikiSQL Validation Set

In [None]:
import random, warnings
N_SAMPLES = 5      # how many validation questions to spot-check
SAMPLE_SEED = 42   # fixed so the same questions are drawn on every run

# A dedicated RNG makes the draw reproducible without touching the global `random` state.
rng = random.Random(SAMPLE_SEED)

# Silence Python warnings only while predicting, rather than for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for idx in rng.sample(range(len(valid_dataset)), N_SAMPLES):
        # Fetch the record once; each index access rebuilds the whole example.
        example = valid_dataset[idx]
        question = example['question']
        print(f"Text: {question}")
        print(f"Pred SQL: {get_sql(question)}")
        print(f"True SQL: {example['sql']['human_readable']}\n")

Text: What is the Partner during the Asian Games Year?
Pred SQL: SELECT Partner FROM table WHERE Year = asian games
True SQL: SELECT Partner FROM table WHERE Year = asian games

Text: What was the lowest year that the engine Ilmor 2175a 3.5 v10 was used?
Pred SQL: SELECT MIN Year FROM table WHERE Engine = ilmor 2175a
True SQL: SELECT MIN Year FROM table WHERE Engine = ilmor 2175a 3.5 v10

Text: How many weeks have a Winning team of yellow team, and an Event of foos it or lose it?
Pred SQL: SELECT COUNT Week FROM table WHERE Winning team = yellow team AND Event
True SQL: SELECT SUM Week FROM table WHERE Winning team = yellow team AND Event = foos it or lose it

Text: What is the lowest rank that spain got?
Pred SQL: SELECT MIN Rank FROM table WHERE Country = spain
True SQL: SELECT MIN Rank FROM table WHERE Nation = spain

Text: What is the result in oakland?
Pred SQL: SELECT Result FROM table WHERE City = oakland
True SQL: SELECT Result FROM table WHERE Venue = oakland

