In [1]:
#### Environment check & imports for Data Engineering LLM
# OVERVIEW: Sets up environment for data engineering LLM fine-tuning with Gemma 2B

import torch
import json
import os
from typing import Dict, Any, List
from datetime import datetime


print(f" Data Engineering LLM Fine-Tuning Setup")
print(f" {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f" Working Directory: {os.getcwd()}")
print(f" PyTorch Version: {torch.__version__}")
print(f" CUDA Available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f" GPU: {gpu_name} ({gpu_memory:.1f} GB VRAM)")
    print(f" CUDA Capability: {torch.cuda.get_device_capability(0)}")
else:
    print(" No GPU detected - training will be very slow!")

torch.cuda.empty_cache()
print(f" GPU memory cleared\n")

 Data Engineering LLM Fine-Tuning Setup
 2026-01-23 07:41:50
 Working Directory: /home/manuelbomi/fine_tune_LLM
 PyTorch Version: 2.9.1+cu128
 CUDA Available: True
 GPU: NVIDIA GeForce RTX 4070 Laptop GPU (8.6 GB VRAM)
 CUDA Capability: (8, 9)
 GPU memory cleared



In [2]:
## Load data engineering training data
# OVERVIEW: Loads specialized data engineering training examples

data_files = [
    "enhanced_data_engineering_dataset.json"    
]

all_data = []
for data_file in data_files:
    if os.path.exists(data_file):
        with open(data_file, "r") as f:
            data = json.load(f)
            all_data.extend(data)
            print(f" Loaded {len(data)} examples from {data_file}")
    else:
        print(f"  File not found: {data_file}")

print(f"\n Total training examples: {len(all_data)}")
if len(all_data) > 0:
    print(f" Sample task type: {list(all_data[0].keys())}")

 Loaded 2000 examples from enhanced_data_engineering_dataset.json

 Total training examples: 2000
 Sample task type: ['task_type', 'input', 'output']


In [3]:
### Model Selection for Data Engineering
# OVERVIEW: Selects Gemma 2B for data engineering tasks

# Available models optimized for data engineering tasks
available_models = [
    "unsloth/gemma-2b-bnb-4bit",            # 2B - Google's efficient model (RECOMMENDED)
    "unsloth/llama-3.2-3b-bnb-4bit",        # 3B - Meta's latest small model
    "unsloth/deepseek-coder-1.3b-bnb-4bit", # 1.3B - Code specialized
    "unsloth/tinyllama-bnb-4bit",           # 1.1B - Smallest & fastest
    "unsloth/Phi-3-mini-4k-instruct-bnb-4bit", # 3.8B - Instruction tuned
]

print(" Available Models for Data Engineering Tasks:")
for i, model in enumerate(available_models, 1):
    size = model.split('-')[0].split('/')[-1]
    print(f"  {i}. {size:<20} â†’ {model}")

# Select Gemma 2B for data engineering
MODEL_CHOICE = "gemma-2b"  # Changed to Gemma 2B

model_map = {
    "gemma-2b": "unsloth/gemma-2b-bnb-4bit",
    "llama3.2-3b": "unsloth/llama-3.2-3b-bnb-4bit",
    "deepseek-coder": "unsloth/deepseek-coder-1.3b-bnb-4bit",
    "tinyllama": "unsloth/tinyllama-bnb-4bit",
    "phi-3-mini": "unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
}

model_name = model_map[MODEL_CHOICE]
print(f"\n Selected Model: {model_name}")
print(f" Reason: Gemma 2B offers excellent balance of efficiency and capability for structured data tasks")

 Available Models for Data Engineering Tasks:
  1. gemma                â†’ unsloth/gemma-2b-bnb-4bit
  2. llama                â†’ unsloth/llama-3.2-3b-bnb-4bit
  3. deepseek             â†’ unsloth/deepseek-coder-1.3b-bnb-4bit
  4. tinyllama            â†’ unsloth/tinyllama-bnb-4bit
  5. Phi                  â†’ unsloth/Phi-3-mini-4k-instruct-bnb-4bit

 Selected Model: unsloth/gemma-2b-bnb-4bit
 Reason: Gemma 2B offers excellent balance of efficiency and capability for structured data tasks


In [4]:
### Load Gemma 2B model and tokenizer
# OVERVIEW: Loads Gemma 2B with optimized settings for data engineering tasks

from unsloth import FastLanguageModel

# Gemma uses different chat template - longer context for data engineering
max_seq_length = 2048  # Increased for data engineering tasks
print(f" Loading {model_name} with max_seq_length={max_seq_length}...")

# Load with optimized settings for Gemma
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=None,  # Auto-detect
    load_in_4bit=True,
    token=os.getenv("HF_TOKEN", None),  # Optional: for gated models
)

# Gemma specific tokenizer setup
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Print model info
print(f" Loaded: {model_name}")
print(f" Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f" Tokenizer Vocab Size: {tokenizer.vocab_size}")
print(f" Max Position Embeddings: {model.config.max_position_embeddings}")

# Gemma specific settings
print(f"  Model Config:")
print(f"  - Architecture: {model.config.architectures[0]}")
print(f"  - Hidden Size: {model.config.hidden_size}")
print(f"  - Num Attention Heads: {model.config.num_attention_heads}")
print(f"  - Num Hidden Layers: {model.config.num_hidden_layers}")

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!
 Loading unsloth/gemma-2b-bnb-4bit with max_seq_length=2048...
==((====))==  Unsloth 2026.1.3: Fast Gemma patching. Transformers: 4.57.3.
   \\   /|    NVIDIA GeForce RTX 4070 Laptop GPU. Num GPUs = 1. Max memory: 7.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.1+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
 Loaded: unsloth/gemma-2b-bnb-4bit
 Model Parameters: 1,515,268,096
 Tokenizer Vocab Size: 256000
 Max Position Embeddings: 8192
  Model Config:
  - Architecture: GemmaForCausalLM
  - Hidden Size: 2048
  - Num Attention Heads: 8
  - Num Hidden Layers: 18


In [5]:
## Prepare data engineering training data
# OVERVIEW: Formats data engineering examples using Gemma's chat template

from datasets import Dataset
import random

print("\n  Preparing Data Engineering Training Data...")

def format_data_engineering_example(example):
    """Format examples for Gemma 2B with data engineering focus"""
    
    # Gemma chat format
    template = """<start_of_turn>user
{system_prompt}

{user_prompt}<end_of_turn>
<start_of_turn>model
{assistant_response}<end_of_turn>"""
    
    # Determine task type and format accordingly
    task_type = example.get('task_type', 'extraction')
    
    if task_type == 'schema_inference':
        system_prompt = "You are a data engineer. Infer database schema from the given data description."
        user_prompt = example['input']
        assistant_response = json.dumps(example['output'], indent=2)
    
    elif task_type == 'data_quality':
        system_prompt = "You are a data quality engineer. Generate data quality rules for the given table description."
        user_prompt = example['input']
        assistant_response = json.dumps(example['output'], indent=2)
    
    elif task_type == 'etl_pipeline':
        system_prompt = "You are an ETL pipeline architect. Design a data pipeline from the given requirements."
        user_prompt = example['input']
        assistant_response = json.dumps(example['output'], indent=2)
    
    elif task_type == 'sql_optimization':
        system_prompt = "You are a database performance expert. Optimize the given SQL query."
        user_prompt = example['input']
        assistant_response = json.dumps(example['output'], indent=2)
    
    else:  # Default extraction task
        system_prompt = "Extract structured JSON data from the given text."
        user_prompt = example['input']
        assistant_response = json.dumps(example['output'], indent=2)
    
    return template.format(
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        assistant_response=assistant_response
    )

# Format all examples
formatted_data = [format_data_engineering_example(item) for item in all_data]

# Split into train/validation (80/20)
random.shuffle(formatted_data)
split_idx = int(0.8 * len(formatted_data))
train_data = formatted_data[:split_idx]
val_data = formatted_data[split_idx:]

# Create datasets
train_dataset = Dataset.from_dict({"text": train_data})
val_dataset = Dataset.from_dict({"text": val_data})

print(f" Dataset Statistics:")
print(f"  - Total Examples: {len(formatted_data)}")
print(f"  - Training Set: {len(train_data)}")
print(f"  - Validation Set: {len(val_data)}")
print(f"  - Task Distribution:")
task_types = [item.get('task_type', 'extraction') for item in all_data]
for task in set(task_types):
    count = task_types.count(task)
    print(f"      â€¢ {task}: {count} examples ({(count/len(all_data))*100:.1f}%)")

# Add LoRA with optimized settings for Gemma
print("\n Adding LoRA Adapters for Gemma 2B...")
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # Rank - good balance for Gemma
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=64,
    lora_dropout=0.1,  # Slight dropout for regularization
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    max_seq_length=max_seq_length,
)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f" LoRA Configuration Complete:")
print(f"   Trainable Parameters: {trainable_params:,}")
print(f"   Total Parameters: {total_params:,}")
print(f"   Trainable %: {(trainable_params/total_params)*100:.3f}%")
print(f"   Memory per Parameter: ~{(total_params * 0.5) / 1e9:.1f} GB (4-bit quantized)")

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.1.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.



  Preparing Data Engineering Training Data...
 Dataset Statistics:
  - Total Examples: 2000
  - Training Set: 1600
  - Validation Set: 400
  - Task Distribution:
      â€¢ schema_inference: 400 examples (20.0%)
      â€¢ sql_optimization: 400 examples (20.0%)
      â€¢ etl_pipeline: 400 examples (20.0%)
      â€¢ extraction: 400 examples (20.0%)
      â€¢ data_quality: 400 examples (20.0%)

 Adding LoRA Adapters for Gemma 2B...


Unsloth 2026.1.3 patched 18 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


 LoRA Configuration Complete:
   Trainable Parameters: 39,223,296
   Total Parameters: 1,554,491,392
   Trainable %: 2.523%
   Memory per Parameter: ~0.8 GB (4-bit quantized)


In [6]:
## Simplified Training Pipeline
# OVERVIEW: Basic training without SFTTrainer complexities

from transformers import Trainer, TrainingArguments
from datetime import datetime

print("\nðŸŽ“ Configuring Simplified Training Pipeline...")

# Simple batch size
batch_size = 2
gradient_accumulation = 4

training_args = TrainingArguments(
    output_dir=f"simple-outputs-{MODEL_CHOICE}",
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation,
    warmup_steps=10,
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=5,
    save_steps=60,
    save_total_limit=2,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    report_to="none",
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    gradient_checkpointing=False,  # Disable for simplicity
)

# Create a simple tokenization function
def prepare_dataset(dataset):
    """Prepare dataset with proper tokenization"""
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=max_seq_length,
        )
    
    return dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
    )

print(" Preparing datasets...")
train_dataset_prepared = prepare_dataset(train_dataset)
val_dataset_prepared = prepare_dataset(val_dataset)

# Create data collator
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset_prepared,
    eval_dataset=val_dataset_prepared,
    data_collator=data_collator,
)

print(f"\n Starting training...")
print(f"   Model: {MODEL_CHOICE}")
print(f"   Training samples: {len(train_dataset_prepared)}")
print(f"   Sequence length: {max_seq_length}")
print(f"   Estimated time: 15-30 minutes\n")

# Train
train_history = trainer.train()

# Save
output_dir = f"{MODEL_CHOICE}-finetuned-simple"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"\n Training Complete!")
print(f" Model saved to: {output_dir}")
print(f" Final loss: {train_history.training_loss:.4f}")


ðŸŽ“ Configuring Simplified Training Pipeline...
 Preparing datasets...


Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,600 | Num Epochs = 3 | Total steps = 600
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 39,223,296 of 2,545,395,712 (1.54% trained)



 Starting training...
   Model: gemma-2b
   Training samples: 1600
   Sequence length: 2048
   Estimated time: 15-30 minutes

Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
5,5.9044
10,1.6338
15,0.8592
20,0.4968
25,0.247
30,0.2072
35,0.2099
40,0.1344
45,0.167
50,0.1838


'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: b126cc76-41d4-4598-836f-6e5e6f3146f5)')' thrown while requesting HEAD https://huggingface.co/unsloth/gemma-2b-bnb-4bit/resolve/main/config.json
Retrying in 1s [Retry 1/5].



 Training Complete!
 Model saved to: gemma-2b-finetuned-simple
 Final loss: 0.1881


In [7]:
## Enhanced Data Engineering Model Testing Suite
# OVERVIEW: Comprehensive testing for all data engineering task types



def prepare_model_for_inference(model):
    """Prepare model for inference mode"""
    from unsloth import FastLanguageModel
    FastLanguageModel.for_inference(model)
    return model

def test_data_engineering_model(model, tokenizer, test_input: str, task_type: str = "extraction", 
                                max_seq_length: int = 2048) -> Dict[str, Any]:
    """
    Test the fine-tuned model on data engineering tasks
    
    Args:
        model: Fine-tuned model
        tokenizer: Model tokenizer
        test_input: Input text to test
        task_type: Type of data engineering task
        max_seq_length: Maximum sequence length
    
    Returns:
        Dictionary containing response and validation results
    """
    
    # System prompts for different tasks - MUST match training prompts
    system_prompts = {
        "extraction": "Extract structured JSON data from the given text.",
        "schema_inference": "You are a data engineer. Infer database schema from the given data description.",
        "data_quality": "You are a data quality engineer. Generate data quality rules for the given table description.",
        "etl_pipeline": "You are an ETL pipeline architect. Design a data pipeline from the given requirements.",
        "sql_optimization": "You are a database performance expert. Optimize the given SQL query."
    }
    
    system_prompt = system_prompts.get(task_type, system_prompts["extraction"])
    
    # Use EXACT same format as training
    # Training format: <start_of_turn>user\n{system_prompt}\n\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n{assistant_response}<end_of_turn>
    # For testing: we give the prompt up to where model should start generating
    formatted_prompt = f"""<start_of_turn>user
{system_prompt}

{test_input}<end_of_turn>
<start_of_turn>model
"""
    
    # Tokenize input
    inputs = tokenizer(
        [formatted_prompt],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_length,
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    
    # Generation parameters optimized for structured output
    generation_config = {
        "max_new_tokens": 512,
        "temperature": 0.1,
        "top_p": 0.9,
        "top_k": 40,
        "do_sample": False,  # Deterministic for structured data
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_config)
    
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract model response
    if "<start_of_turn>model" in full_response:
        response = full_response.split("<start_of_turn>model")[-1].strip()
        response = response.replace("<end_of_turn>", "").strip()
    else:
        response = full_response
    
    # Validate response based on task type
    validation_result = validate_response(response, task_type)
    
    return {
        "task_type": task_type,
        "input": test_input,
        "response": response,
        "validation": validation_result,
        "formatted_prompt": formatted_prompt[:200] + "..." if len(formatted_prompt) > 200 else formatted_prompt
    }

def validate_response(response: str, task_type: str) -> Dict[str, Any]:
    """
    Validate model response based on task type
    
    Args:
        response: Model response text
        task_type: Type of task
    
    Returns:
        Validation results dictionary
    """
    validation_result = {
        "is_valid_json": False,
        "has_expected_structure": False,
        "error_message": None,
        "parsed_json": None,
        "response_length": len(response)
    }
    
    # Task-specific expected keys
    expected_keys_map = {
        "extraction": ["customer_id", "service_name", "subscription_plan", "monthly_price_usd", 
                      "device_type", "region", "usage_hours", "event_timestamp"],
        "schema_inference": ["tables", "relationships"],
        "data_quality": ["table_name", "quality_rules"],
        "etl_pipeline": ["pipeline_name", "description", "frequency", "schedule"],
        "sql_optimization": ["original_query", "optimized_query", "optimizations"]
    }
    
    # Clean response - remove markdown code blocks if present
    clean_response = response.strip()
    if clean_response.startswith("```json"):
        clean_response = clean_response[7:]
    if clean_response.endswith("```"):
        clean_response = clean_response[:-3]
    clean_response = clean_response.strip()
    
    # Special handling for SQL optimization (might return SQL directly)
    if task_type == "sql_optimization":
        if "SELECT" in clean_response.upper() or "WHERE" in clean_response.upper():
            validation_result.update({
                "is_sql_response": True,
                "has_expected_structure": True,
                "notes": "SQL response detected (valid for optimization tasks)"
            })
            return validation_result
    
    # Try to parse as JSON
    try:
        if clean_response.startswith("{") or clean_response.startswith("["):
            parsed = json.loads(clean_response)
            validation_result["is_valid_json"] = True
            validation_result["parsed_json"] = parsed
            
            # Check for expected structure
            if task_type in expected_keys_map:
                if isinstance(parsed, dict):
                    expected_keys = expected_keys_map[task_type]
                    # Check if any expected key is present (not necessarily all)
                    found_keys = [key for key in expected_keys if key in parsed]
                    if found_keys:
                        validation_result["has_expected_structure"] = True
                        validation_result["found_keys"] = found_keys
                        validation_result["missing_keys"] = [k for k in expected_keys if k not in found_keys]
                    else:
                        validation_result["error_message"] = f"No expected keys found. Expected any of: {expected_keys}"
                else:
                    validation_result["error_message"] = f"Response is not a dictionary. Type: {type(parsed)}"
        else:
            validation_result["error_message"] = "Response does not start with { or ["
            
    except json.JSONDecodeError as e:
        validation_result["error_message"] = f"JSON decode error: {str(e)}"
        # Try to find where JSON breaks
        lines = clean_response.split('\n')
        for i, line in enumerate(lines):
            try:
                json.loads(line)
            except:
                validation_result["json_error_line"] = i + 1
                validation_result["json_error_content"] = line[:100]
                break
    
    return validation_result

def run_comprehensive_test_suite(model, tokenizer, max_seq_length: int = 2048):
    """
    Run comprehensive test suite covering all data engineering tasks
    
    Args:
        model: Fine-tuned model
        tokenizer: Model tokenizer
        max_seq_length: Maximum sequence length
    """
    
    print("\n" + "=" * 100)
    print("COMPREHENSIVE DATA ENGINEERING LLM TEST SUITE")
    print("=" * 100)
    
    # Comprehensive test suite covering all task types
    test_suite = [
        # 1. Data Extraction Tests
        {
            "task": "Data Extraction - Basic",
            "input": "Extract subscription usage details:\nCustomer 10001 used Apple Music on Game Console under the Basic plan costing $16.23 in region AU.",
            "type": "extraction",
            "description": "Basic structured data extraction"
        },
        {
            "task": "Data Extraction - Multiple Services",
            "input": "Extract subscription usage details:\nCustomer 20045 used Netflix on Smart TV under the Premium plan costing $19.99 in region US with 8.5 hours of usage on 2024-07-15.",
            "type": "extraction",
            "description": "Extraction with additional details"
        },
        
        # 2. Schema Inference Tests
        {
            "task": "Schema Inference - Streaming Service",
            "input": "Design a database schema for tracking Disney+ subscription usage. Include customer details, subscription info, and usage metrics.",
            "type": "schema_inference",
            "description": "Complete database schema design"
        },
        {
            "task": "Schema Inference - Simple Table",
            "input": "Infer schema from CSV with columns: 'user_id,email,signup_date,last_login,plan_type' with sample: '123,user@example.com,2024-01-15,2024-06-20,premium'",
            "type": "schema_inference",
            "description": "Schema inference from CSV description"
        },
        
        # 3. Data Quality Tests
        {
            "task": "Data Quality - Subscription Data",
            "input": "Generate data quality rules for subscription data with fields: customer_id, service_name, subscription_plan, monthly_price_usd, device_type, region, usage_hours, event_timestamp",
            "type": "data_quality",
            "description": "Comprehensive data quality rules"
        },
        {
            "task": "Data Quality - Customer Table",
            "input": "Generate data quality rules for a customer table with columns: customer_id (int), email (string), age (int), signup_date (date)",
            "type": "data_quality",
            "description": "Basic table quality rules"
        },
        
        # 4. ETL Pipeline Tests
        {
            "task": "ETL Pipeline - Netflix Daily",
            "input": "Design an ETL pipeline to process daily Netflix subscription data from CSV files to a data warehouse for analytics.",
            "type": "etl_pipeline",
            "description": "Daily batch ETL pipeline"
        },
        {
            "task": "ETL Pipeline - Real-time Streaming",
            "input": "Design a real-time ETL pipeline for Spotify subscription events from Kafka to analytics database.",
            "type": "etl_pipeline",
            "description": "Real-time streaming pipeline"
        },
        
        # 5. SQL Optimization Tests
        {
            "task": "SQL Optimization - Date Filter",
            "input": "Optimize this SQL query: SELECT * FROM subscriptions WHERE DATE(event_timestamp) = '2024-06-15' AND region = 'AU'",
            "type": "sql_optimization",
            "description": "Date function optimization"
        },
        {
            "task": "SQL Optimization - Aggregation",
            "input": "Optimize this SQL query: SELECT service_name, COUNT(*) as total_subscriptions, AVG(monthly_price_usd) as avg_price FROM subscriptions GROUP BY service_name ORDER BY total_subscriptions DESC",
            "type": "sql_optimization",
            "description": "Aggregation query optimization"
        },
        {
            "task": "SQL Optimization - Join",
            "input": "Optimize this SQL query: SELECT c.customer_id, c.email, s.service_name, s.monthly_price_usd FROM customers c JOIN subscriptions s ON c.customer_id = s.customer_id WHERE s.region = 'US' AND s.is_active = true",
            "type": "sql_optimization",
            "description": "Join query optimization"
        },
        
        # 6. Edge Cases
        {
            "task": "Edge Case - Minimal Input",
            "input": "Extract data: Customer 999 used Service X.",
            "type": "extraction",
            "description": "Minimal input test"
        },
        {
            "task": "Edge Case - Complex SQL",
            "input": "Optimize: SELECT * FROM (SELECT customer_id, service_name, ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY event_timestamp DESC) as rn FROM subscriptions) WHERE rn = 1",
            "type": "sql_optimization",
            "description": "Complex window function optimization"
        }
    ]
    
    # Statistics
    total_tests = len(test_suite)
    passed_tests = 0
    failed_tests = 0
    
    # Run all tests
    for i, test_case in enumerate(test_suite, 1):
        print(f"\n{'='*80}")
        print(f"  TEST {i}/{total_tests}: {test_case['task']}")
        print(f"{'='*80}")
        print(f" Description: {test_case['description']}")
        print(f" Task Type: {test_case['type']}")
        print(f" Input: {test_case['input'][:120]}..." if len(test_case['input']) > 120 else f"ðŸ“¥ Input: {test_case['input']}")
        
        try:
            # Run test
            result = test_data_engineering_model(
                model=model,
                tokenizer=tokenizer,
                test_input=test_case['input'],
                task_type=test_case['type'],
                max_seq_length=max_seq_length
            )
            
            # Display response
            print(f"\n Response ({len(result['response'])} chars):")
            print(f"{'-'*60}")
            
            # Truncate long responses
            response_display = result['response']
            if len(response_display) > 500:
                response_display = response_display[:250] + "\n... [TRUNCATED] ...\n" + response_display[-250:]
            print(response_display)
            print(f"{'-'*60}")
            
            # Display validation results
            validation = result['validation']
            print(f"\n Validation Results:")
            print(f"   âœ“ Valid JSON: {validation['is_valid_json']}")
            print(f"   âœ“ Expected Structure: {validation['has_expected_structure']}")
            print(f"   âœ“ Response Length: {validation['response_length']} characters")
            
            if validation['is_valid_json'] and validation['parsed_json']:
                if isinstance(validation['parsed_json'], dict):
                    print(f"   âœ“ Keys in response: {list(validation['parsed_json'].keys())[:10]}{'...' if len(validation['parsed_json']) > 10 else ''}")
            
            if validation.get('found_keys'):
                print(f"   âœ“ Found expected keys: {validation['found_keys']}")
            
            if validation.get('missing_keys'):
                print(f"     Missing keys: {validation['missing_keys']}")
            
            if validation.get('is_sql_response'):
                print(f"    SQL response detected (valid for optimization task)")
            
            if validation['error_message']:
                print(f"    Error: {validation['error_message']}")
                if validation.get('json_error_line'):
                    print(f"    JSON error at line {validation['json_error_line']}: {validation.get('json_error_content', 'N/A')}")
                failed_tests += 1
            else:
                if validation['is_valid_json'] or validation.get('is_sql_response') or validation['has_expected_structure']:
                    print(f"    TEST PASSED")
                    passed_tests += 1
                else:
                    print(f"     TEST PARTIALLY PASSED - Check structure")
                    passed_tests += 0.5  # Half credit
            
        except Exception as e:
            print(f"\n TEST FAILED WITH EXCEPTION:")
            print(f"   Error: {str(e)}")
            print(f"   Type: {type(e).__name__}")
            failed_tests += 1
        
        # Small delay between tests
        import time
        time.sleep(0.5)
    
    # Print summary
    print(f"\n{'='*100}")
    print("TEST SUITE SUMMARY")
    print(f"{'='*100}")
    print(f" Total Tests: {total_tests}")
    print(f" Passed: {passed_tests}")
    print(f" Failed: {failed_tests}")
    print(f" Success Rate: {(passed_tests/total_tests)*100:.1f}%")
    
    # Task-specific breakdown
    print(f"\n Task Type Breakdown:")
    task_results = {}
    for test_case in test_suite:
        task_type = test_case['type']
        if task_type not in task_results:
            task_results[task_type] = {"total": 0, "passed": 0}
        task_results[task_type]["total"] += 1
    
    # Note: In a real implementation, you'd track actual pass/fail per task type
    for task_type, counts in task_results.items():
        print(f"   {task_type:20} {counts['total']:2} tests")
    
    print(f"\n Recommendations:")
    if passed_tests / total_tests > 0.7:
        print("   âœ“ Model is performing well across most tasks")
    else:
        print("     Model needs improvement. Consider:")
        print("      - More training epochs")
        print("      - Better training data balance")
        print("      - Adjusting learning rate")
    
    print(f"{'='*100}\n")
    
    return {
        "total_tests": total_tests,
        "passed_tests": passed_tests,
        "failed_tests": failed_tests,
        "success_rate": (passed_tests/total_tests)*100
    }

# How to use in your training script:
"""
# After training, add this to your script:
print("\n" + "="*100)
print("RUNNING COMPREHENSIVE TEST SUITE")
print("="*100)

# Prepare model for inference
model = prepare_model_for_inference(model)

# Run comprehensive tests
test_results = run_comprehensive_test_suite(model, tokenizer, max_seq_length)

print("Testing complete! Model is ready for deployment." if test_results["success_rate"] > 70 else 
      "Model needs improvement before deployment.")
"""

'\n# After training, add this to your script:\nprint("\n" + "="*100)\nprint("RUNNING COMPREHENSIVE TEST SUITE")\nprint("="*100)\n\n# Prepare model for inference\nmodel = prepare_model_for_inference(model)\n\n# Run comprehensive tests\ntest_results = run_comprehensive_test_suite(model, tokenizer, max_seq_length)\n\nprint("Testing complete! Model is ready for deployment." if test_results["success_rate"] > 70 else \n      "Model needs improvement before deployment.")\n'

In [None]:
## Data Engineering Model Testing Suite
# OVERVIEW: Comprehensive testing for data engineering tasks

FastLanguageModel.for_inference(model)

def test_data_engineering_model(model, tokenizer, test_input, task_type="extraction"):
    """Test the fine-tuned model on data engineering tasks"""
    
    # System prompts for different tasks
    system_prompts = {
        "extraction": "Extract structured JSON data from the given text.",
        "schema_inference": "You are a data engineer. Infer database schema from the given data description.",
        "data_quality": "You are a data quality engineer. Generate data quality rules for the given table description.",
        "etl_pipeline": "You are an ETL pipeline architect. Design a data pipeline from the given requirements.",
        "sql_optimization": "You are a database performance expert. Optimize the given SQL query."
    }
    
    system_prompt = system_prompts.get(task_type, system_prompts["extraction"])
    
    # Gemma chat format
    formatted_prompt = f"""<start_of_turn>user
{system_prompt}

{test_input}<end_of_turn>
<start_of_turn>model
"""
    
    inputs = tokenizer(
        [formatted_prompt],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_length,
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    
    # Generation parameters optimized for structured output
    generation_config = {
        "max_new_tokens": 512,
        "temperature": 0.1,
        "top_p": 0.9,
        "top_k": 40,
        "do_sample": False,  # Deterministic for structured data
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_config)
    
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract model response
    if "<start_of_turn>model" in full_response:
        response = full_response.split("<start_of_turn>model")[-1].strip()
        response = response.replace("<end_of_turn>", "").strip()
    else:
        response = full_response
    
    return response

print("\n Testing Data Engineering Capabilities...")

# Test cases covering different data engineering tasks
test_suite = [
    {
        "task": "Data Extraction",
        "input": "Extract subscription usage details:\nCustomer 10001 used Apple Music on Game Console under the Basic plan costing $16.23 in region AU.",
        "type": "extraction"
    },
    {
        "task": "Schema Inference",
        "input": "Infer schema from CSV with columns: 'user_id,email,signup_date,last_login,plan_type' with sample: '123,user@example.com,2024-01-15,2024-06-20,premium'",
        "type": "schema_inference"
    },
    {
        "task": "Data Quality Rules",
        "input": "Generate data quality rules for a customer table with columns: customer_id (int), email (string), age (int), signup_date (date)",
        "type": "data_quality"
    },
    {
        "task": "4. ETL Pipeline Design",
        "input": "Design an ETL pipeline to process daily Netflix subscription data from CSV files to a data warehouse for analytics.",
        "type": "etl_pipeline"
    },
    {
        "task": "SQL Optimization",
        "input": "Optimize this SQL query: SELECT * FROM orders WHERE DATE(created_at) = '2024-06-01' ORDER BY amount DESC",
        "type": "sql_optimization"
    }  
]

print("=" * 80)
for i, test_case in enumerate(test_suite, 1):
    print(f"\nðŸ”¬ Test {i}: {test_case['task']}")
    print(f" Input: {test_case['input'][:80]}...")
    
    response = test_data_engineering_model(
        model, 
        tokenizer, 
        test_case['input'],
        test_case['type']
    )
    
    print(f"\n Response:")
    print(f"{'-'*40}")
    print(response[:500])
    if len(response) > 500:
        print("... (truncated)")
    print(f"{'-'*40}")
    
    # Validate JSON if applicable
    if test_case['type'] in ['extraction', 'schema_inference', 'data_quality']:
        try:
            # Try to parse as JSON
            if response.strip().startswith("{") or response.strip().startswith("["):
                parsed = json.loads(response)
                print(f" Valid JSON with {len(str(parsed))} characters")
            else:
                print("  Response is not JSON format")
        except json.JSONDecodeError:
            print(" Invalid JSON format")
    
    print(f"{'='*80}")


 Testing Data Engineering Capabilities...

ðŸ”¬ Test 1: Schema Inference
 Input: Infer schema from CSV with columns: 'user_id,email,signup_date,last_login,plan_t...

 Response:
----------------------------------------
user
You are a data engineer. Infer database schema from the given data description.

Infer schema from CSV with columns: 'user_id,email,signup_date,last_login,plan_type' with sample: '123,user@example.com,2024-01-15,2024-06-20,premium'
model
{
  "tables": [
    {
      "table_name": "users",
      "columns": [
        {
          "name": "user_id",
          "type": "VARCHAR(20)",
          "primary_key": true,
          "nullable": false
        },
        {
          "name": "email",
        
... (truncated)
----------------------------------------
  Response is not JSON format

ðŸ”¬ Test 2: Data Quality Rules
 Input: Generate data quality rules for a customer table with columns: customer_id (int)...

 Response:
----------------------------------------
user
You are a 

In [13]:
## Enhanced Data Engineering Model Testing Suite
# OVERVIEW: Test all 5 data engineering task types sequentially

FastLanguageModel.for_inference(model)

def test_data_engineering_model(model, tokenizer, test_input, task_type="extraction"):
    """Test the fine-tuned model on data engineering tasks"""
    
    # System prompts for different tasks
    system_prompts = {
        "extraction": "Extract structured JSON data from the given text.",
        "schema_inference": "You are a data engineer. Infer database schema from the given data description.",
        "data_quality": "You are a data quality engineer. Generate data quality rules for the given table description.",
        "etl_pipeline": "You are an ETL pipeline architect. Design a data pipeline from the given requirements.",
        "sql_optimization": "You are a database performance expert. Optimize the given SQL query."
    }
    
    system_prompt = system_prompts.get(task_type, system_prompts["extraction"])
    
    # Gemma chat format
    formatted_prompt = f"""<start_of_turn>user
{system_prompt}

{test_input}<end_of_turn>
<start_of_turn>model
"""
    
    inputs = tokenizer(
        [formatted_prompt],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_seq_length,
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    
    # Generation parameters optimized for structured output
    generation_config = {
        "max_new_tokens": 512,
        "temperature": 0.1,
        "top_p": 0.9,
        "top_k": 40,
        "do_sample": False,  # Deterministic for structured data
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_config)
    
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract model response
    if "<start_of_turn>model" in full_response:
        response = full_response.split("<start_of_turn>model")[-1].strip()
        response = response.replace("<end_of_turn>", "").strip()
    else:
        response = full_response
    
    return response

print("\n" + "=" * 80)
print("COMPREHENSIVE DATA ENGINEERING LLM TESTING")
print("Testing all 5 task types sequentially")
print("=" * 80)

# Test cases covering ALL 5 data engineering tasks
test_suite = [
    {
        "task": "1. Data Extraction",
        "input": "Extract subscription usage details:\nCustomer 10001 used Apple Music on Game Console under the Basic plan costing $16.23 in region AU.",
        "type": "extraction"
    },
    {
        "task": "2. Schema Inference", 
        "input": "Design a database schema for tracking Netflix subscription usage. Include customer details, subscription info, and usage metrics.",
        "type": "schema_inference"
    },
    {
        "task": "3. Data Quality Rules",
        "input": "Generate data quality rules for subscription data with fields: customer_id, service_name, subscription_plan, monthly_price_usd, device_type, region, usage_hours, event_timestamp",
        "type": "data_quality"
    },
    {
        "task": "4. ETL Pipeline Design",
        "input": "Design an ETL pipeline to process daily Netflix subscription data from CSV files to a data warehouse for analytics.",
        "type": "etl_pipeline"
    },
    {
        "task": "5. SQL Optimization",
        "input": "Optimize this SQL query: SELECT * FROM subscriptions WHERE DATE(event_timestamp) = '2024-06-15' AND region = 'AU'",
        "type": "sql_optimization"
    }
]

# Test results storage
all_results = []

# Test each task type sequentially
print("\n STARTING TESTS...")
print("=" * 80)

for i, test_case in enumerate(test_suite, 1):
    print(f"\n TEST {i}: {test_case['task']}")
    print(f" Task Type: {test_case['type']}")
    print(f" Input: {test_case['input'][:100]}..." if len(test_case['input']) > 100 else f"ðŸ“¥ Input: {test_case['input']}")
    print("-" * 60)
    
    # Run the test
    response = test_data_engineering_model(
        model, 
        tokenizer, 
        test_case['input'],
        test_case['type']
    )
    
    # Store result
    result = {
        "test_number": i,
        "task": test_case['task'],
        "task_type": test_case['type'],
        "input": test_case['input'],
        "response": response,
        "response_length": len(response)
    }
    all_results.append(result)
    
    # Display response
    print(f" Response ({len(response)} characters):")
    print(f"{'-'*40}")
    
    # Show full response (truncate if too long)
    if len(response) > 800:
        print(response[:400])
        print("... [TRUNCATED - MIDDLE REMOVED] ...")
        print(response[-400:])
    else:
        print(response)
    
    print(f"{'-'*40}")
    
    # Validate based on task type
    print(f"\n Validation:")
    
    # Check if response is JSON for appropriate tasks
    if test_case['type'] in ['extraction', 'schema_inference', 'data_quality', 'etl_pipeline']:
        try:
            # Clean response (remove markdown code blocks if present)
            clean_response = response.strip()
            if clean_response.startswith("```json"):
                clean_response = clean_response[7:]
            if clean_response.endswith("```"):
                clean_response = clean_response[:-3]
            clean_response = clean_response.strip()
            
            if clean_response.startswith("{") or clean_response.startswith("["):
                parsed = json.loads(clean_response)
                print(f"  âœ“ Valid JSON format")
                print(f"  âœ“ Contains {len(str(parsed))} characters")
                
                # Show keys if it's a dictionary
                if isinstance(parsed, dict):
                    keys = list(parsed.keys())
                    print(f"  âœ“ Keys in response: {keys[:8]}{'...' if len(keys) > 8 else ''}")
                else:
                    print(f"  âœ“ Response is a JSON array/list")
            else:
                print(f"    Response is not JSON format (doesn't start with {{ or [)")
        except json.JSONDecodeError as e:
            print(f"   Invalid JSON format: {str(e)}")
    
    # Special handling for SQL optimization
    elif test_case['type'] == 'sql_optimization':
        if 'SELECT' in response.upper() or 'WHERE' in response.upper():
            print(f"  âœ“ SQL response detected (expected for optimization)")
            # Check if it looks like optimized SQL
            if 'DATE(' not in response or '>=' in response:
                print(f"  âœ“ Appears to contain optimization suggestions")
        else:
            print(f"    May not be SQL format")
    
    print(f"   Response length: {len(response)} characters")
    
    print(f"\n{'='*80}")

# Display summary of all 5 results
print("\n" + "=" * 80)
print(" TEST RESULTS SUMMARY (All 5 Tasks)")
print("=" * 80)

for result in all_results:
    print(f"\n{result['task']}:")
    print(f"  Type: {result['task_type']}")
    print(f"  Response Length: {result['response_length']} chars")
    
    # Quick validation indicator
    if result['task_type'] == 'sql_optimization':
        if 'SELECT' in result['response'].upper():
            print(f"  Status:  SQL detected")
        else:
            print(f"  Status:   May not be SQL")
    else:
        try:
            clean_resp = result['response'].strip()
            if clean_resp.startswith("```json"):
                clean_resp = clean_resp[7:]
            if clean_resp.endswith("```"):
                clean_resp = clean_resp[:-3]
            clean_resp = clean_resp.strip()
            
            if clean_resp.startswith("{") or clean_resp.startswith("["):
                json.loads(clean_resp)
                print(f"  Status:  Valid JSON")
            else:
                print(f"  Status:   Not JSON format")
        except:
            print(f"  Status:  Invalid JSON")

print(f"\n{'='*80}")
print(" TESTING COMPLETE: All 5 task types have been tested")
print(f"   Total tests: {len(all_results)}")
print(f"   Tasks tested: {', '.join([r['task_type'] for r in all_results])}")
print("=" * 80)

# Optional: Save results to file
try:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_filename = f"test_results_{timestamp}.json"
    
    # Prepare results for saving
    save_data = []
    for result in all_results:
        save_data.append({
            "task": result['task'],
            "task_type": result['task_type'],
            "input": result['input'],
            "response": result['response'],
            "response_length": result['response_length'],
            "timestamp": datetime.now().isoformat()
        })
    print(result)
    
    with open(results_filename, 'w') as f:
        json.dump(save_data, f, indent=2)
    
    print(f"\n Results saved to: {results_filename}")
except Exception as e:
    print(f"\n  Could not save results: {e}")

print("\n")


COMPREHENSIVE DATA ENGINEERING LLM TESTING
Testing all 5 task types sequentially

 STARTING TESTS...

 TEST 1: 1. Data Extraction
 Task Type: extraction
 Input: Extract subscription usage details:
Customer 10001 used Apple Music on Game Console under the Basic ...
------------------------------------------------------------
 Response (2423 characters):
----------------------------------------
user
Extract structured JSON data from the given text.

Extract subscription usage details:
Customer 10001 used Apple Music on Game Console under the Basic plan costing $16.23 in region AU.
model
{
  "customer_id": "CUST-10001",
  "service_name": "Apple Music",
  "subscription_plan": "Basic",
  "monthly_price_usd": 16.23,
  "device_type": "Game Console",
  "region": "AU",
  "usage_hours": 5.98,
  
... [TRUNCATED - MIDDLE REMOVED] ...
">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...">...