In [1]:
from datasets import load_dataset
from transformers import GPT2Tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [3]:
ds = load_dataset("DipamSoni/custom_text_to_sql_dataset", split="train")


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Generating train split: 100%|██████████| 368059/368059 [00:00<00:00, 592686.19 examples/s]


In [4]:
print(ds[0])

{'instruction': 'Name the home team for carlton away team', 'input': 'CREATE TABLE table_name_77 (home_team VARCHAR,away_team VARCHAR)', 'response': 'SELECT home_team FROM table_name_77 WHERE away_team = "carlton"'}


In [9]:
def format_and_tokenize(obj):
    return f"{obj['instruction']}\n{obj['input']}\n=> {obj['response']}"

In [28]:
def tokenize_function(example):
    encoded = tokenizer(
        format_and_tokenize(example),
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors=None  # Important for Hugging Face dataset
    )
    
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "labels": encoded["input_ids"]
    }

In [29]:
print(tokenize_function(ds[0]))

{'input_ids': [5376, 262, 1363, 1074, 329, 1097, 75, 1122, 1497, 1074, 198, 43387, 6158, 43679, 3084, 62, 3672, 62, 3324, 357, 11195, 62, 15097, 569, 31315, 1503, 11, 8272, 62, 15097, 569, 31315, 1503, 8, 198, 14804, 33493, 1363, 62, 15097, 16034, 3084, 62, 3672, 62, 3324, 33411, 1497, 62, 15097, 796, 366, 66, 7063, 1122, 1, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 5

In [30]:
tokenized_dataset = ds.map(tokenize_function, batched=False, remove_columns=ds.column_names)


Map: 100%|██████████| 368059/368059 [07:32<00:00, 814.00 examples/s] 


In [31]:
tokenized_dataset[1]

{'input_ids': [10919,
  481,
  262,
  3265,
  286,
  7229,
  307,
  618,
  9133,
  2253,
  14,
  9914,
  571,
  14289,
  318,
  767,
  5999,
  357,
  22,
  13,
  20,
  4407,
  30,
  198,
  43387,
  6158,
  43679,
  3084,
  62,
  24403,
  3134,
  5855,
  17688,
  1,
  1103,
  553,
  10603,
  1,
  1103,
  553,
  38555,
  1,
  2420,
  553,
  17584,
  30997,
  1,
  2420,
  553,
  16112,
  1,
  2420,
  553,
  49022,
  2253,
  14,
  9914,
  571,
  14289,
  1,
  2420,
  553,
  40495,
  2253,
  1,
  2420,
  553,
  46,
  344,
  5411,
  1,
  2420,
  8,
  198,
  14804,
  33493,
  366,
  38555,
  1,
  16034,
  3084,
  62,
  24403,
  3134,
  33411,
  366,
  49022,
  2253,
  14,
  9914,
  571,
  14289,
  1,
  796,
  705,
  50165,
  357,
  22,
  13,
  20,
  4407,
  6,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  50256,
  

In [32]:
format_and_tokenize(ds[1])

'what will the population of Asia be when Latin America/Caribbean is 783 (7.5%)?\nCREATE TABLE table_22767 ("Year" real,"World" real,"Asia" text,"Africa" text,"Europe" text,"Latin America/Caribbean" text,"Northern America" text,"Oceania" text)\n=> SELECT "Asia" FROM table_22767 WHERE "Latin America/Caribbean" = \'783 (7.5%)\''