In [None]:
!pip install datasets


Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.4.1-py3-none-any.whl (487 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.w

In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [None]:
# Load the WikiSQL dataset splits.
# The hub repository runs custom loading code (wikisql.py, inspectable at
# https://hf.co/datasets/wikisql). Opting in explicitly avoids the interactive
# [y/N] prompt, so "Restart & Run All" does not block waiting for keyboard input.
dataset = load_dataset("wikisql", trust_remote_code=True)
train_data = dataset["train"]
val_data = dataset["validation"]


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/7.80k [00:00<?, ?B/s]

wikisql.py:   0%|          | 0.00/6.57k [00:00<?, ?B/s]

The repository for wikisql contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/wikisql.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/26.2M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/15878 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8421 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/56355 [00:00<?, ? examples/s]

In [None]:
# The preprocessing below reads these three fields; fail fast if the schema differs.
assert {"question", "table", "sql"} <= set(train_data.column_names), train_data.column_names
train_data

Dataset({
    features: ['phase', 'question', 'table', 'sql'],
    num_rows: 56355
})


In [None]:
# plBART tokenizer for the uclanlp/plbart-base checkpoint; no custom remote code is needed.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="uclanlp/plbart-base")


config.json:   0%|          | 0.00/783 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/986k [00:00<?, ?B/s]

In [None]:
def convert_sql(query_dict, table_header):
    """
    Converts a WikiSQL query dictionary into a SQL string.

    Handles:
      - 'agg': aggregation function (0: none, 1: MAX, 2: MIN, 3: COUNT, 4: SUM, 5: AVG)
      - 'sel': index of the selected column in table_header or a column name.
      - 'conds': the WHERE conditions, in either of two layouts:
          * row-wise: a list of [column, operator, value, ...] entries, where
            column is an index into table_header or a column name (extra
            trailing elements are ignored);
          * column-wise: a dict of parallel lists keyed 'column_index',
            'operator_index' and 'condition' -- the layout Hugging Face
            `datasets` produces for a sequence of dicts.

    Args:
        query_dict: mapping with 'agg', 'sel' and optionally 'conds'.
        table_header: list of column names used to resolve integer indices.

    Returns:
        "SELECT <col>" or "SELECT <AGG>(<col>)", followed by
        " WHERE c1 op v1 AND ..." when there are conditions, ending in ";".
        No FROM clause is emitted and condition values are not quoted.

    Raises:
        KeyError: if 'agg' or 'sel' is missing, or a column-wise 'conds' dict
            lacks one of its three keys.
        IndexError: if an integer column index is out of range for table_header.
    """
    agg_map = {0: "", 1: "MAX", 2: "MIN", 3: "COUNT", 4: "SUM", 5: "AVG"}
    # NOTE(review): WikiSQL's own operator list is believed to be
    # ['=', '>', '<', 'OP']; entries above 2 here may not match it -- confirm.
    op_map = {0: "=", 1: ">", 2: "<", 3: ">=", 4: "<=", 5: "!="}

    def _column_name(col):
        # Integer -> position in the header; anything else is already a name.
        return table_header[col] if isinstance(col, int) else col

    agg_str = agg_map.get(query_dict["agg"], "")
    sel_col = _column_name(query_dict["sel"])
    select_clause = f"SELECT {agg_str}({sel_col})" if agg_str else f"SELECT {sel_col}"

    conds = query_dict.get("conds") or []
    if isinstance(conds, dict):
        # Iterating a dict of lists would walk its key names ("column_index", ...),
        # so re-zip the parallel lists into (column, operator, value) rows.
        conds = list(zip(conds["column_index"], conds["operator_index"], conds["condition"]))

    conditions = []
    for cond in conds:
        col, op, val = cond[:3]
        conditions.append(f"{_column_name(col)} {op_map.get(op, '=')} {val}")

    where_clause = (" WHERE " + " AND ".join(conditions)) if conditions else ""
    return select_clause + where_clause + ";"


In [None]:
def tokenize_example(example, tokenizer, max_input_length=512, max_output_length=128,
                     label_pad_token_id=-100):
    """
    Tokenizes a single WikiSQL example into seq2seq model inputs.

    Expects:
      - example["question"]: the natural language question.
      - example["table"]["header"]: list of column names.
      - example["sql"]: either a dictionary with SQL components (converted with
        convert_sql) or a SQL string.

    Args:
        example: one WikiSQL record.
        tokenizer: a Hugging Face tokenizer accepting the `text_target=` argument.
        max_input_length: pad/truncate length for the "question: ... schema: ..." prompt.
        max_output_length: pad/truncate length for the target SQL.
        label_pad_token_id: value written over padding positions in "labels" so
            the loss ignores them (-100 is the default ignore_index of PyTorch's
            cross-entropy). Pass None to keep the tokenizer's pad id.

    Returns:
        The tokenizer's encoding of the prompt (input_ids, attention_mask, ...)
        with an added "labels" list of length max_output_length.
    """
    question = example["question"]
    table_schema = example["table"]["header"]

    # Dict-form SQL is rendered with convert_sql; anything else is taken as a SQL string.
    sql = example["sql"]
    sql_query = convert_sql(sql, table_schema) if isinstance(sql, dict) else sql

    # Create input by combining the question and table schema.
    input_text = f"question: {question} schema: {', '.join(table_schema)}"

    model_inputs = tokenizer(
        input_text,
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
    )

    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager and tokenizes in target mode.
    labels = tokenizer(
        text_target=sql_query,
        max_length=max_output_length,
        truncation=True,
        padding="max_length",
    )["input_ids"]

    # Labels are already padded to a fixed length, so the batch collator adds no
    # padding (and no -100s) of its own; without this mask, pad positions would be
    # scored by the loss.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if label_pad_token_id is not None and pad_id is not None:
        labels = [label_pad_token_id if tok == pad_id else tok for tok in labels]

    model_inputs["labels"] = labels
    return model_inputs


In [None]:
class WikiSQLDataset(Dataset):
    """
    Map-style PyTorch dataset over raw WikiSQL records.

    Records are tokenized with `tokenize_example` each time they are fetched;
    nothing is pre-tokenized or cached.
    """

    def __init__(self, data, tokenizer, max_input_length=512, max_output_length=128):
        # Wrapped split (anything supporting len() and integer indexing) and
        # the tokenizer settings forwarded to tokenize_example.
        self.data = data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length

    def __len__(self):
        """Return the number of records in the wrapped split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Fetch record ``idx`` and return its tokenized form."""
        record = self.data[idx]
        return tokenize_example(
            record,
            self.tokenizer,
            max_input_length=self.max_input_length,
            max_output_length=self.max_output_length,
        )


In [None]:
# Wrap the raw train/validation splits in the lazily-tokenizing dataset class.
train_dataset, val_dataset = (WikiSQLDataset(split, tokenizer) for split in (train_data, val_data))


In [None]:
# Sanity-check the wrapped splits: sizes plus one decoded validation example, so the
# prompt and SQL target the model will be trained on can be eyeballed.
val_example = val_dataset[0]
print(f"train: {len(train_dataset)} examples | validation: {len(val_dataset)} examples")
print("input :", tokenizer.decode(val_example["input_ids"], skip_special_tokens=True))
# Drop negative ids (e.g. -100 loss masking) before decoding the labels.
print("target:", tokenizer.decode([t for t in val_example["labels"] if t >= 0], skip_special_tokens=True))

<__main__.WikiSQLDataset object at 0x7fe517d9ec90>


In [None]:
from transformers import BartForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq


In [None]:
# Load plBART with the architecture class matching its `plbart` model type.
# Loading it into BartForConditionalGeneration printed "You are using a model of
# type plbart to instantiate a model of type bart ... can yield errors".
model = AutoModelForSeq2SeqLM.from_pretrained("uclanlp/plbart-base")


You are using a model of type plbart to instantiate a model of type bart. This is not supported for all configurations of models and can yield errors.


pytorch_model.bin:   0%|          | 0.00/557M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/557M [00:00<?, ?B/s]

In [None]:
# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./plbart_wikisql_output",   # directory for model checkpoints
    num_train_epochs=3,                     # number of epochs
    per_device_train_batch_size=8,          # training batch size
    per_device_eval_batch_size=8,           # evaluation batch size
    learning_rate=5e-5,                     # learning rate
    eval_strategy="epoch",                  # evaluation frequency (replaces deprecated `evaluation_strategy`)
    save_strategy="epoch",                  # checkpoint saving frequency
    logging_steps=100,                      # log every 100 steps
    predict_with_generate=True,             # use generate() for evaluation
)



In [None]:
# Batch collator. Passing `model` lets it build decoder inputs from the labels
# (per the DataCollatorForSeq2Seq docs). Note: tokenize_example already pads every
# example to max_length, so per-batch padding has nothing left to do here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)


In [None]:
# Initialize the Seq2SeqTrainer.
# `processing_class` supersedes the deprecated `tokenizer=` keyword in recent
# transformers releases.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
)

  trainer = Seq2SeqTrainer(


In [None]:
# Start training.
# Never paste API keys into the notebook: supply tracker credentials (e.g.
# WANDB_API_KEY) through environment variables or Colab secrets instead.
trainer.train()



Epoch,Training Loss,Validation Loss
1,0.0058,0.00557


