In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer

In [None]:
# Load the "spider" dataset; `dataset` is indexed by split name below.
dataset = load_dataset(path="spider")

# Sanity check: show the first record of the training split.
print(dataset["train"][0])

In [None]:
# Tokenizer matching the "t5-small" checkpoint; shared by preprocessing,
# training and inference further down.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="t5-small",
)

def preprocess(example):
    """Tokenize one example into fixed-length seq2seq training features.

    Args:
        example: Mapping with a "question" entry (natural-language text) and
            a "query" entry (the target SQL string).

    Returns:
        The tokenizer output for the prompted question, with an added
        "labels" entry holding the token ids of the query. Padding positions
        in the labels are replaced by -100.
    """
    input_text = "Translate to SQL: " + example["question"]
    output_text = example["query"]

    model_inputs = tokenizer(input_text, truncation=True, padding="max_length", max_length=128)
    labels = tokenizer(output_text, truncation=True, padding="max_length", max_length=128)

    # Targets are padded out to 128 tokens. Left as pad ids, those positions
    # would count towards the loss and swamp the real SQL tokens; -100 is the
    # index the loss function ignores.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        token_id if token_id != pad_id else -100
        for token_id in labels["input_ids"]
    ]
    return model_inputs

# Run `preprocess` over every split, dropping the original columns (taken
# from the train split) so only the tokenized features remain.
tokenized_ds = dataset.map(
    function=preprocess,
    remove_columns=dataset["train"].column_names,
)

In [None]:
# Start from the pretrained "t5-small" sequence-to-sequence checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Training configuration, grouped by concern.
training_args = Seq2SeqTrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="./sql_model",
    logging_dir="./logs",
    save_total_limit=2,
    # Evaluate once per epoch, using generation for predictions.
    evaluation_strategy="epoch",
    predict_with_generate=True,
    # Optimisation schedule.
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Batch sizes per device.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)

# Wire model, configuration, tokenizer and the tokenized splits together.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    eval_dataset=tokenized_ds["validation"],
    train_dataset=tokenized_ds["train"],
)

# Fine-tune; this is the long-running step of the notebook.
trainer.train()


In [None]:
def driver(question):
    """Generate a SQL string for a natural-language question.

    Args:
        question: The question text; it is prefixed with the same
            "Translate to SQL: " prompt used during preprocessing.

    Returns:
        The decoded model output with special tokens stripped.
    """
    # Truncate like the training inputs so over-long questions are handled
    # the same way at inference time.
    encoded = tokenizer(
        "Translate to SQL: " + question,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    )
    # After training the model may live on an accelerator; the input ids must
    # be on the same device or generation fails with a device mismatch.
    input_ids = encoded.input_ids.to(model.device)
    outputs = model.generate(input_ids, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Demo call: replace the placeholder text with a real question.
print(
    driver(question="Ask sth here???")
)