In [97]:
import torch 
import numpy as np
import tqdm
import matplotlib.pyplot as plt 
import os
import json
import pandas as pd

In [98]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq,Trainer
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import pipeline
from datasets import Dataset , load_dataset
import pandas as pd
import torch

In [99]:
gretel_dataset = load_dataset("gretelai/synthetic_text_to_sql")

In [100]:
train_data = gretel_dataset['train']
test_data = gretel_dataset['test']

In [101]:
list(train_data.column_names)

['id',
 'domain',
 'domain_description',
 'sql_complexity',
 'sql_complexity_description',
 'sql_task_type',
 'sql_task_type_description',
 'sql_prompt',
 'sql_context',
 'sql',
 'sql_explanation']

In [102]:
train_data_sample = gretel_dataset['train'].shuffle(seed=42).select(range(5000))
test_data_sample = gretel_dataset['train'].shuffle(seed=123).select(range(1000))


In [103]:
train_data_sample.shape

(5000, 11)

In [104]:
# train_data_sample = train_data_sample.to_pandas()
# test_data_sample = test_data_sample.to_pandas()

In [105]:
def construct_augmented_input(row):
    return f"""Domain: {row['domain']}
                Domain Description: {row['domain_description']}
                SQL Complexity: {row['sql_complexity']}
                Complexity Description: {row['sql_complexity_description']}
                SQL Task Type: {row['sql_task_type']}
                Task Type Description: {row['sql_task_type_description']}
                SQL Prompt: {row['sql_prompt']}
                SQL Context: {row['sql_context']}"""


In [106]:
train_data_sample = train_data_sample.map(
    lambda example: {
        "input_text": construct_augmented_input(example),
        "target_sql": example["sql"]
    }
)

In [107]:
test_data_sample = test_data_sample.map(
    lambda example: {
        "input_text": construct_augmented_input(example),
        "target_sql": example["sql"]
    }
)


In [108]:
model_checkpoint = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [109]:
def tokenize_function(example):
    model_inputs = tokenizer(
        example["input_text"], max_length=512, padding="max_length", truncation=True
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example["target_sql"], max_length=512, padding="max_length", truncation=True
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [110]:
train_dataset = train_data_sample.map(tokenize_function, batched=True)
test_dataset = test_data_sample.map(tokenize_function, batched=True)

In [111]:
data_collector = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    max_length=512,
    pad_to_multiple_of=None,
    return_tensors="pt"
)

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./t5-sql-finetuned",
    learning_rate=5e-5,
    per_device_train_batch_size=2,
    num_train_epochs=10,  # or less
    max_steps=5000,       # override if needed
    logging_dir="./logs",
    predict_with_generate=True,
)

In [112]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,  # Note: changed from trainer_args to training_args
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collector
)

  trainer = Seq2SeqTrainer(


In [None]:
# Start training and capture the training results
training_results = trainer.train()

# Save the model and tokenizer
output_dir = "gretel_t5"
trainer.save_model(output_dir)
trainer.tokenizer.save_pretrained(output_dir)

# Save training metrics
training_history = {
    'train_loss': training_results.training_loss,
    'metrics': training_results.metrics,
    'total_steps': training_results.global_step
}



Step,Training Loss,Validation Loss
500,0.0,
1000,0.0,
1500,0.0,




KeyboardInterrupt: 

In [None]:
# Save training history as JSON
import json
with open(f"{output_dir}/training_history.json", 'w') as f:
    json.dump(training_history, f, indent=4)

# Save training arguments configuration
trainer.args.to_json_file(f"{output_dir}/training_args.json")

# Print final training metrics
print(f"Training completed. Final loss: {training_results.training_loss}")
print(f"All files saved to: {output_dir}")