# Setup Development Environment

In [None]:
# Install/upgrade the libraries this notebook uses via the IPython %pip magic.
# NOTE(review): s3fs==0.4.2 is much older than the other pins — confirm it
# still resolves cleanly alongside datasets==2.21.0.
%pip install  --upgrade \
  "transformers==4.44.2" \
  "sagemaker>=2.190.0" \
  "datasets==2.21.0" \
  "evaluate==0.4.2" \
  "s3fs==0.4.2"

# Load and prepare the dataset

In [None]:
def create_conversation(sample):
    """Convert one raw dataset row into a chat-style ``messages`` record.

    Args:
        sample: Mapping with ``"context"`` (inserted into the system prompt
            as the schema), ``"question"`` and ``"answer"`` entries.

    Returns:
        A dict with a single ``"messages"`` key holding the system, user and
        assistant turns, in that order.
    """
    # NOTE(review): the prompt wording is kept verbatim (including "an text");
    # editing it changes the text of every generated training example.
    prompt_template = """You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA. SCHEMA: {schema}"""

    system_turn = {"role": "system", "content": prompt_template.format(schema=sample["context"])}
    user_turn = {"role": "user", "content": sample["question"]}
    assistant_turn = {"role": "assistant", "content": sample["answer"]}

    return {"messages": [system_turn, user_turn, assistant_turn]}

In [None]:
from datasets import load_dataset

# Pull the train split from the hub, then keep a seeded random subset
# of 3000 examples so the selection is reproducible.
seed = 42
dataset = load_dataset("b-mc2/sql-create-context", split="train")
dataset = dataset.shuffle(seed=seed)
dataset = dataset.select(range(3000))
dataset[0]

In [None]:
# Convert every row to the chat "messages" format, then drop the raw columns
# that were folded into it.
dataset = dataset.map(create_conversation, batched=False).remove_columns(
    ["answer", "question", "context"]
)

dataset[0]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset

def calculate_token_lengths(sample, tokenizer):
    """Count the tokens in the first three messages of a conversation sample.

    Each message's ``"content"`` is tokenized on its own with truncation
    disabled; the messages at index 0, 1 and 2 are reported as the system,
    prompt and completion counts respectively.

    Args:
        sample: Mapping with a ``"messages"`` list of at least three dicts,
            each carrying a ``"content"`` string.
        tokenizer: Callable accepting ``(text, truncation=False)`` and
            returning a mapping that contains ``'input_ids'``.

    Returns:
        A dict with ``system_tokens``, ``prompt_tokens``,
        ``completion_tokens`` and their sum under ``total_tokens``.
    """
    def count_tokens(position):
        # Length of the token id sequence for the message at this position.
        text = sample["messages"][position]["content"]
        return len(tokenizer(text, truncation=False)['input_ids'])

    n_system = count_tokens(0)
    n_prompt = count_tokens(1)
    n_completion = count_tokens(2)

    return {
        'system_tokens': n_system,
        'prompt_tokens': n_prompt,
        'completion_tokens': n_completion,
        'total_tokens': n_system + n_prompt + n_completion,
    }


def display_token_distribution_and_percentiles(dataset, tokenizer):
    """Plot token-length histograms and print summary statistics.

    Token counts are computed for every sample with
    ``calculate_token_lengths``. Three stacked histograms are drawn (prompt,
    completion, combined), followed by a printed ``describe()`` summary with
    the 50/75/90/95/99th percentiles for each of the three series.

    Note that the "Prompt (Input)" series is the ``prompt_tokens`` column
    only, whereas the combined series uses ``total_tokens``, which also
    includes the system-message tokens.

    Args:
        dataset: Object exposing ``map`` and column access by name, holding
            samples in the ``messages`` format.
        tokenizer: Tokenizer forwarded to ``calculate_token_lengths``.
    """
    def add_lengths(row):
        return calculate_token_lengths(row, tokenizer)

    with_lengths = dataset.map(add_lengths, batched=False)

    # Display label -> array of per-sample token counts.
    series_by_label = {
        label: np.array(with_lengths[column])
        for label, column in (
            ('Prompt (Input)', 'prompt_tokens'),
            ('Completion (Output)', 'completion_tokens'),
            ('Combined (Input + Output)', 'total_tokens'),
        )
    }

    # One histogram per series, stacked vertically in a single figure.
    n_panels = len(series_by_label)
    plt.figure(figsize=(10, 8))
    for panel, (label, values) in enumerate(series_by_label.items(), start=1):
        plt.subplot(n_panels, 1, panel)
        plt.hist(values, bins=50, alpha=0.7, label=label)
        plt.title(f'Token Length Distribution for {label}')
        plt.xlabel('Token Length')
        plt.ylabel('Frequency')
        plt.legend()

    plt.tight_layout()
    plt.show()

    # Text summary (count, mean, std, min, selected percentiles, max).
    for label, values in series_by_label.items():
        print(f"\n{label} Token Length Summary:")
        summary = pd.Series(values).describe(percentiles=[0.5, 0.75, 0.9, 0.95, 0.99])
        print(summary)

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer for the chosen model id; it is used here to measure
# token lengths of the prepared conversations.
model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Plot the length histograms and print the percentile summaries.
display_token_distribution_and_percentiles(dataset, tokenizer)

In [None]:
# Hold out 30% of the data, then halve the held-out part into validation and
# test sets; the same seed is used for both splits for reproducibility.
train_test_split = dataset.train_test_split(test_size=0.3, seed=seed)
val_test_split = train_test_split['test'].train_test_split(test_size=0.5, seed=seed)

# From here on `dataset` is a plain dict mapping split name -> split.
dataset = dict(
    train=train_test_split['train'],
    val=val_test_split['train'],
    test=val_test_split['test'],
)

# Check the number of examples in each split
print(f"Training set size: {len(dataset['train'])}")
print(f"Validation set size: {len(dataset['val'])}")
print(f"Test set size: {len(dataset['test'])}")

In [None]:
# Write the train/validation/test splits as JSON files under a local
# directory. Nothing is uploaded to S3 in this cell; the paths below are
# relative to the notebook's working directory.
data_input_path = './datasets/text-to-sql'

train_dataset_path = f"{data_input_path}/train/train_dataset.json"
eval_dataset_path = f"{data_input_path}/eval/eval_dataset.json"
test_dataset_path = f"{data_input_path}/test/test_dataset.json"

# The "val" split is stored under the "eval" directory/file name.
dataset["train"].to_json(train_dataset_path, orient="records")
dataset["val"].to_json(eval_dataset_path, orient="records")
dataset["test"].to_json(test_dataset_path, orient="records")

print(f"Datasets saved to: {data_input_path}")