In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import deepspeed
import pandas as pd
# Corrected import path for the checkpoint conversion utility
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

def generate_sql(instruction, user_input, model, tokenizer, device,
                 max_input_length=256, max_output_length=128, **generate_kwargs):
    """
    Generates a SQL query based on an instruction and user input using the DeepSpeed-accelerated model.

    The prompt uses the same two-line ``Instruction:`` / ``Input:`` layout as training.
    It is tokenized (truncated to ``max_input_length`` tokens) and passed to
    ``model.generate``. The first generated sequence is decoded back to text.

    Args:
        instruction (str): The instruction for the model (e.g., "Translate the following into a SQL query").
        user_input (str): The natural language input from the user.
        model: The model to generate with. Here this is the DeepSpeed inference engine, but any
            object exposing a Hugging Face-style ``generate(**inputs, **kwargs)`` works.
        tokenizer (PreTrainedTokenizer): The tokenizer for the model.
        device (torch.device): The device the tokenized inputs are moved to (e.g., 'cuda:0' or 'cpu').
        max_input_length (int): Prompts longer than this many tokens are truncated. Defaults to 256.
        max_output_length (int): ``max_length`` passed to ``model.generate``. Defaults to 128.
        **generate_kwargs: Extra keyword arguments forwarded unchanged to ``model.generate``.
            Beam-search options (``num_beams``, ``early_stopping``) are not set by default
            because DeepSpeed's inference engine does not support them.

    Returns:
        str: The generated SQL query, with special tokens stripped.
    """
    # Format the input exactly as it was during training
    prompt = f"Instruction: {instruction}\nInput: {user_input}"

    # Tokenize the prompt and move the resulting tensors to the inference device
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    ).to(device)

    # Generation is inference-only, so skip autograd bookkeeping.
    # NOTE: num_beams/early_stopping are deliberately not passed by default, because
    # DeepSpeed's inference engine does not support them. Callers driving a plain
    # Hugging Face model may supply them through **generate_kwargs.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_output_length,
            **generate_kwargs,
        )

    # A single prompt was encoded, so the answer is the first (and only) sequence
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

if __name__ == "__main__":
    # --- Configuration ---

    # The original model ID used for training.
    model_id = "t5-small"

    # Path to the directory containing the 'latest' file and the 'global_step' subdirectories.
    deepspeed_checkpoint_dir = "ray_results/TorchTrainer_2025-07-28_21-03-06/TorchTrainer_c9915_00000_0_2025-07-28_21-03-07/checkpoint_000000/"

    # Path to the validation data file.
    validation_file_path = "wikisql/data/validation-00000-of-00001-3f1ecb1168a6a037.parquet"

    # How many validation rows to run through the model, and the seed used to pick
    # them (fixed so repeated runs inspect the same rows).
    num_samples = 10
    random_seed = 42

    # --- Device Setup ---
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- 1. Load Fine-Tuned Model Weights ---

    # Load the base model architecture first; its weights are replaced below.
    print(f"Loading base model architecture from '{model_id}'...")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Consolidate the sharded ZeRO checkpoint into a single fp32 state dict in memory.
    # Given the checkpoint root, the utility reads its 'latest' file to find the global_step subfolder.
    print(f"Loading and consolidating fine-tuned weights from '{deepspeed_checkpoint_dir}'...")
    state_dict = get_fp32_state_dict_from_zero_checkpoint(deepspeed_checkpoint_dir)

    # Default strict loading fails loudly on missing/unexpected keys, so a mismatched
    # checkpoint cannot silently leave base-model weights in place.
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the model's parameters; drop the extra fp32 copy.
    del state_dict
    print("Fine-tuned weights loaded successfully.")

    # --- 2. Initialize Model with DeepSpeed Inference Engine ---
    print("\nInitializing model with DeepSpeed Inference Engine...")

    # Wrap the fine-tuned model in DeepSpeed's inference engine: fp16 on GPU, fp32 on CPU.
    # NOTE(review): DeepSpeed's kernel-injection policies are architecture-specific;
    # confirm that T5 is actually being accelerated rather than running unmodified.
    ds_model = deepspeed.init_inference(
        model=model,
        mp_size=1,  # Number of GPUs for model parallelism (1 for single-GPU)
        dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        replace_with_kernel_inject=True,  # Use DeepSpeed's custom kernels for acceleration
    )
    print("DeepSpeed model initialized for inference.")

    # The model is already on the correct device after init_inference.
    ds_model.eval()

    # --- 3. Load Validation Data and Run Inference on Samples ---
    print(f"\n--- Loading validation data from '{validation_file_path}' ---")
    # Only the file read is guarded. Errors raised during inference must surface with a
    # full traceback instead of being misreported as a validation-file problem.
    try:
        df = pd.read_parquet(validation_file_path)
    except FileNotFoundError:
        print(f"Error: Validation file not found at '{validation_file_path}'. Please check the path.")
    except Exception as e:
        print(f"An error occurred while loading the validation file: {e}")
    else:
        # DataFrame.sample raises ValueError when n exceeds the row count, so clamp it.
        n_selected = min(num_samples, len(df))
        samples = df.sample(n=n_selected, random_state=random_seed)
        print(f"Loaded {len(df)} records, sampled {n_selected} for inference.")

        print(f"\n--- Running Inference on {n_selected} Samples ---")
        # Number samples sequentially. The DataFrame index label is printed separately,
        # because it identifies the row in the validation data rather than counting samples.
        for sample_num, (row_label, row) in enumerate(samples.iterrows(), start=1):
            instruction = str(row.get('instruction', ''))
            user_input = str(row.get('input', ''))
            ground_truth_sql = str(row.get('output', ''))

            print(f"\n----- Sample {sample_num}/{n_selected} (row {row_label}) -----")
            print(f"Instruction:\t{instruction}")
            print(f"Input:\t\t{user_input}")
            print(f"Ground Truth:\t{ground_truth_sql}")

            # Generate the SQL query with the DeepSpeed-wrapped fine-tuned model
            generated_sql = generate_sql(instruction, user_input, ds_model, tokenizer, device)

            print(f"Generated SQL:\t{generated_sql}")
            print("-" * 25)



[2025-07-28 21:34:37,091] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cpu (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


[2025-07-28 21:34:42,369] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
Using device: cpu
Loading base model architecture from 't5-small'...
Loading and consolidating fine-tuned weights from 'ray_results/TorchTrainer_2025-07-28_21-03-06/TorchTrainer_c9915_00000_0_2025-07-28_21-03-07/checkpoint_000000/'...
Processing zero checkpoint 'ray_results/TorchTrainer_2025-07-28_21-03-06/TorchTrainer_c9915_00000_0_2025-07-28_21-03-07/checkpoint_000000/global_step0'


Loading checkpoint shards: 100%|██████████| 1/1 [00:00<00:00, 34.26it/s]


Detected checkpoint of type zero stage ZeroStageEnum.weights, world_size: 1
Parsing checkpoint created by deepspeed==0.17.3


Gathering sharded weights: 100%|██████████| 131/131 [00:00<00:00, 688538.63it/s]

Reconstructed Trainable fp32 state dict with 131 params 60506624 elements





Fine-tuned weights loaded successfully.

Initializing model with DeepSpeed Inference Engine...
[2025-07-28 21:34:46,652] [INFO] [logging.py:107:log_dist] [Rank -1] DeepSpeed info: version=0.17.3, git-hash=unknown, git-branch=unknown
[2025-07-28 21:34:46,654] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
[2025-07-28 21:34:46,654] [INFO] [logging.py:107:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
DeepSpeed model initialized for inference.

--- Loading validation data from 'wikisql/data/validation-00000-of-00001-3f1ecb1168a6a037.parquet' ---
Loaded 8421 records, sampled 10 for inference.

--- Running Inference on 10 Samples ---

----- Sample 2466 -----
Instruction:	Translate the following into a SQL query
Input:		Name the play for 1976
Ground Truth:	SELECT Play FROM table WHERE Year = 1976
Generated SQL:	SELECT Play FROM table WHERE Year = 1976
-------------------------

----- Sample 34 -

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import pandas as pd

def generate_sql(instruction, user_input, model, tokenizer, device,
                 max_input_length=256, max_output_length=128,
                 num_beams=4, early_stopping=True, **generate_kwargs):
    """
    Generates a SQL query based on an instruction and user input.

    The prompt uses the two-line ``Instruction:`` / ``Input:`` layout, is tokenized
    (truncated to ``max_input_length`` tokens), and decoded with ``model.generate``.
    Beam search is used by default. Pass ``num_beams=1`` for greedy decoding, e.g. to
    compare like-for-like against a greedy-decoding setup.

    Args:
        instruction (str): The instruction for the model.
        user_input (str): The natural language input from the user.
        model (PreTrainedModel): The Hugging Face model.
        tokenizer (PreTrainedTokenizer): The tokenizer for the model.
        device (torch.device): The device the tokenized inputs are moved to.
        max_input_length (int): Prompts longer than this many tokens are truncated. Defaults to 256.
        max_output_length (int): ``max_length`` passed to ``model.generate``. Defaults to 128.
        num_beams (int): Beam width for ``model.generate``. Defaults to 4.
        early_stopping (bool): Forwarded to ``model.generate`` only when ``num_beams > 1``,
            since it only affects beam search. Defaults to True.
        **generate_kwargs: Extra keyword arguments forwarded unchanged to ``model.generate``.

    Returns:
        str: The generated SQL query, with special tokens stripped.
    """
    # Format the input prompt
    prompt = f"Instruction: {instruction}\nInput: {user_input}"

    # Tokenize the prompt and move the resulting tensors to the inference device
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    ).to(device)

    decoding_options = {"max_length": max_output_length, "num_beams": num_beams}
    if num_beams > 1:
        # early_stopping is a beam-search option; omit it for greedy decoding
        decoding_options["early_stopping"] = early_stopping
    decoding_options.update(generate_kwargs)

    # Generation is inference-only, so skip autograd bookkeeping
    with torch.no_grad():
        outputs = model.generate(**inputs, **decoding_options)

    # A single prompt was encoded, so the answer is the first (and only) sequence
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

if __name__ == "__main__":
    # --- Configuration ---

    # The original model ID to be used for inference.
    model_id = "t5-small"

    # Path to the validation data file.
    validation_file_path = "wikisql/data/validation-00000-of-00001-3f1ecb1168a6a037.parquet"

    # How many validation rows to run through the model, and the seed used to pick
    # them (fixed so repeated runs inspect the same rows).
    num_samples = 10
    random_seed = 42

    # --- Device Setup ---
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- 1. Load Base Model and Tokenizer ---

    # Load the base model and tokenizer directly from Hugging Face
    print(f"Loading base model and tokenizer from '{model_id}'...")
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id).to(device)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model.eval()  # Set the model to evaluation mode
        print("Base model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Stop with a non-zero status and keep the cause chained. The bare exit() this
        # replaces exited with status 0 and, under ipykernel, shuts the kernel down.
        raise SystemExit(1) from e

    # --- 2. Load Validation Data and Run Inference on Samples ---
    print(f"\n--- Loading validation data from '{validation_file_path}' ---")
    # Only the file read is guarded. Errors raised during inference must surface with a
    # full traceback instead of being misreported as a validation-file problem.
    try:
        df = pd.read_parquet(validation_file_path)
    except FileNotFoundError:
        print(f"Error: Validation file not found at '{validation_file_path}'. Please check the path.")
    except Exception as e:
        print(f"An error occurred while loading the validation file: {e}")
    else:
        # DataFrame.sample raises ValueError when n exceeds the row count, so clamp it.
        n_selected = min(num_samples, len(df))
        samples = df.sample(n=n_selected, random_state=random_seed)
        print(f"Loaded {len(df)} records, sampled {n_selected} for inference.")

        print(f"\n--- Running Inference on {n_selected} Samples (Base Model) ---")
        # Number samples sequentially. The DataFrame index label is printed separately,
        # because it identifies the row in the validation data rather than counting samples.
        for sample_num, (row_label, row) in enumerate(samples.iterrows(), start=1):
            instruction = str(row.get('instruction', ''))
            user_input = str(row.get('input', ''))
            ground_truth_sql = str(row.get('output', ''))

            print(f"\n----- Sample {sample_num}/{n_selected} (row {row_label}) -----")
            print(f"Instruction:\t{instruction}")
            print(f"Input:\t\t{user_input}")
            print(f"Ground Truth:\t{ground_truth_sql}")

            # Generate the SQL query using the base model.
            # NOTE(review): this cell's generate_sql decodes with beam search (num_beams=4),
            # while the DeepSpeed cell decodes greedily, so outputs are not strictly
            # like-for-like between the two runs.
            generated_sql = generate_sql(instruction, user_input, model, tokenizer, device)

            print(f"Generated SQL:\t{generated_sql}")
            print("-" * 25)


Using device: cpu
Loading base model and tokenizer from 't5-small'...
Base model loaded successfully.

--- Loading validation data from 'wikisql/data/validation-00000-of-00001-3f1ecb1168a6a037.parquet' ---
Loaded 8421 records, sampled 10 for inference.

--- Running Inference on 10 Samples (Base Model) ---

----- Sample 2466 -----
Instruction:	Translate the following into a SQL query
Input:		Name the play for 1976
Ground Truth:	SELECT Play FROM table WHERE Year = 1976
Generated SQL:	the play for 1976.
-------------------------

----- Sample 34 -----
Instruction:	Translate the following into a SQL query
Input:		what are all the playoffs for u.s. open cup in 1st round
Ground Truth:	SELECT Playoffs FROM table WHERE U.S. Open Cup = 1st Round
Generated SQL:	: what are all the playoffs for u.s. open cup in 1st round?
-------------------------

----- Sample 7849 -----
Instruction:	Translate the following into a SQL query
Input:		What is the location of the game that has a number smaller than 2