<a href="https://colab.research.google.com/github/iansawicki/fireworks/blob/main/walkthroughs/VLM_tuning_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Supervised VLM Fine-Tuning on Fireworks

This notebook demonstrates the **Fireworks CLI commands** for fine-tuning with minimal Python wrapper code.

## 📋 Table of Contents

| Section | Description | Key CLI Commands |
|---------|-------------|------------------|
| **0. Setup** | Install dependencies and environment | - |
| **1. Authentication** | Setup CLI and verify credentials | `firectl signin`, `firectl whoami` |
| **2. Dataset Preparation** | Load and explore the PubTabNet dataset | - |
| **3. Dataset Upload** | Upload to Fireworks (Local + GCS) | `firectl create dataset` |
| **4. Fine-tuning** | Create the training job | `firectl create supervised-fine-tuning-job` |
| **5. Monitoring** | Track training progress | `firectl get supervised-fine-tuning-job` |
| **6. Deployment** | Deploy trained model | `firectl create deployment` |
| **7. Testing** | Query your deployed model | `curl` + Python requests |

## 🎯 What You'll Learn

- **CLI Commands**: The exact `firectl` commands to run
- **End-to-End Workflow**: From dataset to deployed model  
- **Two Approaches**: Local files vs. Google Cloud Storage
- **Best Practices**: Dataset formats, fine-tuning parameters, testing

## 📋 Prerequisites

1. **Fireworks AI Account**: Sign up at [fireworks.ai](https://fireworks.ai)
2. **Fireworks CLI (macOS)**: `brew tap fw-ai/firectl && brew install firectl`
3. **API Key**: `export FIREWORKS_API_KEY='your-key'`
4. **Optional**: HuggingFace token for private datasets

## 🚀 Quick Start

If you just want the commands, here's the complete workflow:

```bash
# 1. Setup
export FIREWORKS_API_KEY='your-key'
firectl signin

# 2. Upload dataset (choose one)
firectl create dataset my-dataset ./data.jsonl                    # Local file
firectl create dataset my-dataset --external-url gs://bucket/data.jsonl  # GCS

# 3. Start fine-tuning
firectl create supervised-fine-tuning-job \
  --dataset my-dataset --output-model my-model \
  --base-model accounts/fireworks/models/qwen2p5-vl-32b-instruct \
  --epochs 3 --turbo --early-stop --eval-auto-carveout

# 4. Monitor and deploy
firectl get supervised-fine-tuning-job JOB_ID
firectl create deployment accounts/ACCOUNT/models/my-model --enable-addons
```

In [None]:
# Install required packages
!pip install jsonlines datasets beautifulsoup4
!wget -O firectl.gz https://storage.googleapis.com/fireworks-public/firectl/stable/linux-amd64.gz
!gunzip firectl.gz
!sudo install -o root -g root -m 0755 firectl /usr/local/bin/firectl


# 1. Authentication & Environment Setup

## 📋 Section Overview
- Install Fireworks CLI
- Set API key and authenticate
- Verify setup with Python
- Optional: Setup HuggingFace access

## 🔧 Required CLI Commands

Run these commands in your terminal **before** running this notebook:

```bash
# 1. Install Fireworks CLI (on mac)
brew tap fw-ai/firectl && brew install firectl

# 1.2 Or on Linux (if you're using Colab, run this in the Colab terminal)
wget -O firectl.gz https://storage.googleapis.com/fireworks-public/firectl/stable/linux-amd64.gz
gunzip firectl.gz
sudo install -o root -g root -m 0755 firectl /usr/local/bin/firectl

# 2. Set your API key
export FIREWORKS_API_KEY='your-api-key-here'

# 3. Sign into Fireworks
firectl signin

# 4. Verify authentication
firectl whoami

# 5. (Optional) Set your Hugging Face token for private datasets
export HF_TOKEN='your-hf-token'
```

**Important**: The Python code below will check if these steps were completed successfully.

In [None]:
# Securely input API key and sign in
import getpass
import os

# Get API key securely
api_key = getpass.getpass("Enter your Fireworks API key: ")

# Set and use in one go
os.environ['FIREWORKS_API_KEY'] = api_key
!firectl signin

# Check Prerequisites
import subprocess
import os

print("🔧 Checking Prerequisites...")

# Check API key
if api_key:
    # Avoid echoing any part of the key into the notebook output
    print(f"✅ API key found ({len(api_key)} characters)")
else:
    print("❌ No FIREWORKS_API_KEY found")
    print("💡 Set it with: export FIREWORKS_API_KEY='your-key'")

# Check firectl
try:
    result = subprocess.run(["firectl", "version"], capture_output=True, text=True)
    if result.returncode == 0:
        print(f"✅ firectl installed: {result.stdout.strip()}")
    else:
        print("❌ firectl not working")
except FileNotFoundError:
    print("❌ firectl not found")
    print("💡 Install with: brew tap fw-ai/firectl && brew install firectl")

# Check if signed in
try:
    result = subprocess.run(["firectl", "whoami"], capture_output=True, text=True)
    if result.returncode == 0:
        print("✅ Signed into Fireworks CLI")
        print(result.stdout.strip())
    else:
        print("⚠️  Not signed into Fireworks CLI")
        print("💡 Sign in with: firectl signin")
except Exception:
    print("⚠️  Could not check Fireworks CLI status")

print("🚀 Prerequisites check complete!")


# 2. Dataset Preparation

## Load the Dataset from Hugging Face

We'll be loading the PubTabNet dataset from Hugging Face, which provides a convenient way to access the data without dealing with large downloads.

In [None]:
# importing prerequisites
import sys
import requests
import tarfile
import jsonlines
import numpy as np
from os import path
from PIL import Image
from PIL import ImageFont, ImageDraw
from glob import glob
from matplotlib import pyplot as plt
from datasets import load_dataset
%matplotlib inline

# Load only a subset of the PubTabNet dataset from Hugging Face
print("Loading a subset of PubTabNet dataset from Hugging Face...")

# Option 1: Load first 1000 examples from train split (most efficient)
# ds = load_dataset("apoidea/pubtabnet-html", split="train[:1000]")
# print(f"Loaded {len(ds)} examples from the train split")

# Option 2: If you want a specific range (e.g., examples 1000-2000)
# ds = load_dataset("apoidea/pubtabnet-html", split="train[1000:2000]")

# Option 3: For a percentage of the data (e.g., first 10%)
# ds = load_dataset("apoidea/pubtabnet-html", split="train[:10%]")

# Option 4: For multiple splits with subsets
ds = load_dataset("apoidea/pubtabnet-html", split={
    "train": "train[:800]",     # first 800 for training
    "test": "train[800:1000]"   # next 200 for testing
})

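print("Dataset loaded successfully!")
print(f"Dataset type: {type(ds)}")
# ds is keyed by split name, so len(ds) would count splits, not examples.
# Report each split separately instead.
for split_name, split_ds in ds.items():
    print(f"{split_name}: {len(split_ds)} examples, columns: {split_ds.column_names}")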

In [None]:
# For easier access, pull the train and test splits into their own variables
ds_train = ds['train']
ds_test = ds['test']

print(f"Train dataset: {len(ds_train)} examples")
print(f"Test dataset: {len(ds_test)} examples")

# Look at the first 5 examples
print(f"\nFirst 5 examples:")
for i in range(5):
    example = ds_train[i]
    print(f"\n--- Example {i+1} ---")
    for key, value in example.items():
        if isinstance(value, str) and len(value) > 100:
            print(f"  {key}: {value[:100]}...")
        elif key == 'image':
            print(f"  {key}: <PIL Image object>")
        else:
            print(f"  {key}: {value}")

# Create sample for detailed exploration
sample_data = ds_train.select(range(5))
print(f"\nSelected {len(sample_data)} examples for detailed exploration")

# Display the first 5 images in a clean 2x3 grid layout
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()  # Flatten to make indexing easier

for i in range(5):
    img_example = ds_train[i]
    axes[i].imshow(img_example['image'])
    axes[i].axis('off')
    axes[i].set_title(f"Image {i+1}: {img_example['imgid']}", fontsize=12, pad=10)

# Hide the 6th subplot since we only have 5 images
axes[5].axis('off')

plt.tight_layout()
plt.show()

# 3. Dataset Upload to Fireworks

## 📋 Section Overview
- Convert data to Chat Completions format (required by Fireworks)
- Create JSONL files for upload
- Upload using `firectl create dataset` command
- Upload from a local file (the GCS variant is shown in the Quick Start above)

## 🔄 Convert to Chat Completions Format

Fireworks requires data in Chat Completions format with `system`, `user`, and `assistant` messages.

### Dataset Requirements:
- Format: .jsonl file
- Minimum examples: 3
- Maximum examples: 3 million per dataset
- Images: Must be base64 encoded with proper MIME type prefixes
- Supported image formats: PNG, JPG, JPEG

- Message schema: each training example must include a `messages` array, where each message has:
  - `role`: one of `system`, `user`, or `assistant`
  - `content`: either plain text or an array containing text and image objects
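
The requirements above can be checked locally before uploading. The following is a minimal sketch, not an official Fireworks validator: `validate_example`, `check_data_uri`, and `ALLOWED_MIME` are hypothetical names, and the checks only mirror the rules listed in this section (allowed roles, content shape, and a base64 data URI prefix for PNG/JPEG images).

```python
import base64

# Assumption: these MIME types correspond to the PNG/JPG/JPEG formats listed above.
ALLOWED_MIME = ("image/png", "image/jpeg", "image/jpg")
ALLOWED_ROLES = {"system", "user", "assistant"}


def check_data_uri(url):
    """Check that an image URL is a base64 data URI with an allowed MIME type."""
    if not url.startswith("data:") or ";base64," not in url:
        return ["image is not a base64 data URI"]
    header, payload = url.split(";base64,", 1)
    mime = header[len("data:"):]
    if mime not in ALLOWED_MIME:
        return [f"unsupported MIME type {mime!r}"]
    try:
        base64.b64decode(payload, validate=True)
    except ValueError:  # binascii.Error is a subclass of ValueError
        return ["payload is not valid base64"]
    return []


def validate_example(example):
    """Return a list of problems found in one training example (empty if none)."""
    problems = []
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]

    for idx, msg in enumerate(messages):
        role = msg.get("role")
        if role not in ALLOWED_ROLES:
            problems.append(f"message {idx}: unexpected role {role!r}")
        content = msg.get("content")
        if isinstance(content, str):
            continue  # plain-text content is allowed
        if not isinstance(content, list):
            problems.append(f"message {idx}: content must be a string or a list")
            continue
        for part in content:
            if part.get("type") == "text":
                if not isinstance(part.get("text"), str):
                    problems.append(f"message {idx}: text part has no 'text' string")
            elif part.get("type") == "image_url":
                url = part.get("image_url", {}).get("url", "")
                problems.extend(f"message {idx}: {p}" for p in check_data_uri(url))
            else:
                problems.append(f"message {idx}: unknown part type {part.get('type')!r}")
    return problems
```

Running `validate_example` on each record before writing it to the JSONL file surfaces malformed examples early, which is cheaper than discovering them after an upload.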

## Convert Images to Base64 and Generate the JSONL File
Here we ask one question per image, picked at random from a list of question templates.

In [None]:
# Combined: Convert images to base64, create JSONL file, and display first example
import base64
from io import BytesIO
import json
from bs4 import BeautifulSoup
import random
from PIL import Image
import matplotlib.pyplot as plt

def pil_to_base64(pil_image):
    """Convert PIL Image to base64 string"""
    buffered = BytesIO()
    pil_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"

def create_table_qa_jsonl_single(base64_data, output_file="pubtabnet_table_qa.jsonl"):
    """Create JSONL file with exactly ONE example per image"""

    def extract_table_info(html_table):
        """Extract meaningful information from HTML table"""
        try:
            soup = BeautifulSoup(html_table, 'html.parser')
            table = soup.find('table')
            if not table:
                # Always return a 3-tuple so the caller can unpack it
                return "This appears to be a table with structured data.", [], []

            rows = table.find_all('tr')
            if not rows:
                return "This appears to be a table with structured data.", [], []

            # Extract headers if available
            headers = []
            first_row = rows[0]
            header_cells = first_row.find_all(['th', 'td'])
            headers = [cell.get_text(strip=True) for cell in header_cells if cell.get_text(strip=True)]

            # Count rows and columns
            num_rows = len(rows)
            num_cols = len(header_cells) if header_cells else 0

            # Extract some sample data
            sample_data = []
            for row in rows[1:3]:  # Get first 2 data rows
                cells = row.find_all(['td', 'th'])
                row_data = [cell.get_text(strip=True) for cell in cells]
                if any(row_data):  # Only add non-empty rows
                    sample_data.append(row_data)

            description = f"This table has {num_rows} rows and {num_cols} columns."
            if headers:
                description += f" The columns are: {', '.join(headers)}."

            return description, headers, sample_data

        except Exception:
            return "This appears to be a table with structured data.", [], []

    # Different question templates - pick ONE per image
    question_templates = [
        "What information is shown in this table?",
        "Describe the content and structure of this table.",
        "What data does this table contain?",
        "Can you analyze this table and tell me what it shows?",
        "What are the main elements of this table?",
        "Summarize the information presented in this table.",
        "What can you tell me about the data in this table?",
        "Describe what this table is displaying.",
    ]

    with open(output_file, 'w') as f:
        for i, example in enumerate(base64_data):
            # Extract information from the HTML table
            description, headers, sample_data = extract_table_info(example['html_table'])

            # Create a comprehensive response about the table content
            response = description
            if headers:
                response += f" The table includes columns for {', '.join(headers)}."
            if sample_data:
                response += " Based on the visible data, this table contains structured information that can be queried and analyzed."

            # Pick just ONE question per image
            question = random.choice(question_templates)

            training_example = {
                "messages": [
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that can analyze tables and documents to answer questions about their content. Focus on understanding what information is presented and be ready to answer specific questions about the data."
                    },
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": example['image_base64']
                                }
                            },
                            {
                                "type": "text",
                                "text": question
                            }
                        ]
                    },
                    {
                        "role": "assistant",
                        "content": response
                    }
                ]
            }

            f.write(json.dumps(training_example) + '\n')

            if (i + 1) % 25 == 0:
                print(f"Processed {i + 1}/{len(base64_data)} examples...")

    print(f"✓ Created {output_file} with exactly {len(base64_data)} training examples")
    return output_file

def show_jsonl_with_image_fixed(jsonl_file):
    """Display first line with truncated base64 AND show the actual image"""

    with open(jsonl_file, 'r') as f:
        first_line = f.readline().strip()

    example = json.loads(first_line)

    # Create a copy with truncated base64
    example_copy = json.loads(json.dumps(example))  # Deep copy

    # Find the image content item
    image_content_item = None
    image_index = None
    for i, content in enumerate(example['messages'][1]['content']):
        if content.get('type') == 'image_url':
            image_content_item = content
            image_index = i
            break

    if image_content_item is None:
        print("ERROR: No image_url content found!")
        return

    # Get the actual image URL
    image_content = image_content_item['image_url']['url']

    if ',' in image_content:
        prefix, base64_data = image_content.split(',', 1)
        truncated = f"{prefix},{base64_data[:50]}...[TRUNCATED {len(base64_data)} total chars]"
        example_copy['messages'][1]['content'][image_index]['image_url']['url'] = truncated

    # Show the JSON structure
    pretty_json = json.dumps(example_copy, indent=2, ensure_ascii=False)

    print("=" * 60)
    print("First JSONL example (with truncated base64):")
    print("=" * 60)
    print(pretty_json)

    print("\n" + "=" * 60)
    print("CORRESPONDING IMAGE:")
    print("=" * 60)

    # Convert base64 back to image and display
    try:
        # Extract and decode the base64 image
        if ',' in image_content:
            image_data = image_content.split(',')[1]
        else:
            image_data = image_content

        # Decode base64 to image
        image_bytes = base64.b64decode(image_data)
        image = Image.open(BytesIO(image_bytes))

        # Display using matplotlib
        plt.figure(figsize=(12, 8))
        plt.imshow(image)
        plt.axis('off')
        plt.title("Table Image from JSONL Training Data", fontsize=14, pad=20)
        plt.show()

        # Show details
        print(f"\nImage details:")
        print(f"Size: {image.size}")
        print(f"Mode: {image.mode}")
        print(f"Base64 length: {len(image_data)} characters")

        # Show the question and answer
        text_content = None
        for content in example['messages'][1]['content']:
            if content.get('type') == 'text':
                text_content = content.get('text')
                break

        print(f"\nTraining pair:")
        print(f"Question: {text_content}")
        print(f"Answer: {example['messages'][2]['content'][:300]}...")

    except Exception as e:
        print(f"Error displaying image: {e}")
        import traceback
        traceback.print_exc()

# ========== MAIN EXECUTION ==========

print("🔄 Step 1: Converting first 100 images to base64...")

# Get first 100 examples and convert images to base64
ds_train = ds['train']
first_100 = ds_train.select(range(100))

# Create a list to store the converted data
base64_data = []

for i in range(100):
    example = first_100[i]

    # Convert PIL Image to base64
    base64_image = pil_to_base64(example['image'])

    # Create new example with base64 image
    base64_example = {
        'image_base64': base64_image,
        'imgid': example['imgid'],
        'split': example['split'],
        'html': example['html'],
        'html_table': example['html_table']
    }

    base64_data.append(base64_example)

    if (i + 1) % 20 == 0:  # Progress indicator
        print(f"Converted {i + 1}/100 images...")

print(f"✅ Successfully converted {len(base64_data)} images to base64 format")

# Check the first example
first_base64 = base64_data[0]
print(f"\nFirst base64 example:")
print(f"Image ID: {first_base64['imgid']}")
print(f"Base64 length: {len(first_base64['image_base64'])}")
print(f"Base64 preview: {first_base64['image_base64'][:100]}...")

print(f"\n🔄 Step 2: Creating JSONL file for training...")

# Create the JSONL file with exactly 100 examples
jsonl_file = create_table_qa_jsonl_single(base64_data)

# Verify the count
with open(jsonl_file, 'r') as f:
    line_count = sum(1 for line in f)
print(f"Final verification: {line_count} examples in JSONL file")

print(f"\n🔄 Step 3: Displaying first example...")

# Display the first example
show_jsonl_with_image_fixed(jsonl_file)

print(f"\n✅ Complete! JSONL file '{jsonl_file}' is ready for upload to Fireworks AI")

## 📤 Upload Dataset - Local File Method

Upload your JSONL file to Fireworks:

```bash
firectl create dataset DATASET_NAME ./your-file.jsonl
```
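
Before uploading, it can help to confirm that the file is well-formed JSONL and meets the minimum of 3 examples stated earlier. This is a rough sketch under those assumptions; `check_jsonl` is a hypothetical helper, not part of `firectl`, and it only checks that each line parses as a JSON object with a `messages` key.

```python
import json


def check_jsonl(path, min_examples=3):
    """Count examples in a JSONL file and report lines that are not usable records.

    Returns (ok, count, bad_lines), where bad_lines holds 1-based line numbers.
    """
    count = 0
    bad_lines = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                bad_lines.append(line_no)
                continue
            if isinstance(record, dict) and "messages" in record:
                count += 1
            else:
                bad_lines.append(line_no)
    ok = count >= min_examples and not bad_lines
    return ok, count, bad_lines
```

For example, `ok, count, bad = check_jsonl("pubtabnet_table_qa.jsonl")` returns `ok == False` if the file has fewer than three examples or any unparseable line.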

In [None]:
import subprocess
import sys
import os
from datetime import datetime

# Create timestamp
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

# File paths (adjust these to match your actual files)
sample_file = "pubtabnet_table_qa.jsonl"


# Create unique dataset name
dataset_name = f"pubtabnet-table-qa-{timestamp}"

print(f"📤 Uploading dataset: {dataset_name}")
print(f"📄 File: {sample_file}")

# Check if firectl exists and make it executable
firectl_path = "./firectl"
if os.path.exists(firectl_path):
    # Make firectl executable
    os.chmod(firectl_path, 0o755)
    print(f"✅ Found firectl at {firectl_path}")
else:
    print(f"❌ firectl not found at {firectl_path}")
    sys.exit(1)

# Check if the data file exists
if not os.path.exists(sample_file):
    print(f"❌ Data file not found: {sample_file}")
    print("Available files:")
    for f in os.listdir("."):
        if f.endswith(".jsonl"):
            print(f"  - {f}")
    sys.exit(1)

# Build and run the command with the path to firectl.
# "pyroworks" is the account ID used in this walkthrough; replace it with your own.
cmd = [firectl_path, "-a", "pyroworks", "create", "dataset", dataset_name, sample_file]
print(f"🔧 Running: {' '.join(cmd)}")

try:
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    print("✅ Dataset upload successful!")
    print(result.stdout)
    print(f"📊 Dataset name: {dataset_name}")
except subprocess.CalledProcessError as e:
    print(f"❌ Upload failed: {e}")
    print(f"Error output: {e.stderr}")
    print(f"Return code: {e.returncode}")
except FileNotFoundError:
    print("❌ firectl binary not found or not executable.")

In [None]:
! firectl get dataset {dataset_name}

# 4. Create Fine-tuning Job

## 📋 Section Overview  
- Create fine-tuning job with `firectl create supervised-fine-tuning-job`
- Monitor training progress
- Best practices for parameters

## 🚀 Start Fine-tuning

Here's the key CLI command to start training:

```bash
firectl create supervised-fine-tuning-job \
  --base-model accounts/fireworks/models/qwen2p5-vl-32b-instruct \
  --dataset DATASET_NAME \
  --output-model MODEL_NAME \
  --display-name "pubtabnet-table-qa Fine-tune" \
  --epochs 3 \
  --learning-rate 0.0001 \
  --turbo \
  --early-stop \
  --eval-auto-carveout
```

## 📊 Parameter Guide

| Parameter | Description | Recommended Value |
|-----------|-------------|-------------------|
| `--base-model` | Base model to fine-tune | `qwen2p5-vl-32b-instruct` |
| `--dataset` | Your uploaded dataset ID | From step 3 |
| `--output-model` | Name for your fine-tuned model | `my-pubtabnet-table-qa-model` |
| `--epochs` | Training iterations | `3` (start small) |
| `--learning-rate` | Learning rate | `0.0001` |
| `--turbo` | Faster training | Always include |
| `--early-stop` | Prevent overfitting | Always include |
| `--eval-auto-carveout` | Auto validation split | Always include |

In [None]:
import subprocess
import re
import os
from datetime import datetime

# Create unique model name
model_name = f"pubtabnet-table-qa-{timestamp}"
display_name = f"pubtabnet-table-qa Fine-tune {timestamp}"

print("🚀 Creating fine-tuning job...")
print(f"📋 Dataset: {dataset_name}")
print(f"📋 Output Model: {model_name}")

# Check if firectl exists and make it executable
firectl_path = "./firectl"
if os.path.exists(firectl_path):
    # Make firectl executable
    os.chmod(firectl_path, 0o755)
    print(f"✅ Found firectl at {firectl_path}")
else:
    print(f"❌ firectl not found at {firectl_path}")
    exit(1)

# Build the fine-tuning command as an argument list (one list item per argument)
cmd = [
    firectl_path,
    "create", "supervised-fine-tuning-job",
    "--base-model", "accounts/fireworks/models/qwen2p5-vl-32b-instruct",
    "--dataset", dataset_name,
    "--output-model", model_name,
    "--display-name", display_name,
    "--epochs", "3",
    "--learning-rate", "0.0001",
    "--turbo",
    "--early-stop",
    "--eval-auto-carveout"
]

print(f"🔧 Running: {' '.join(cmd)}")

try:
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    print("✅ Fine-tuning job created successfully!")
    print(result.stdout)

    # Extract job ID from output
    pattern = r"Name: accounts/[^/]+/supervisedFineTuningJobs/([a-zA-Z0-9]+)"

    match = re.search(pattern, result.stdout)
    if match:
        job_id = match.group(1)
        print(f"🎯 Job ID: {job_id}")
        print(f"💡 Save this: export JOB_ID='{job_id}'")
        print(f"\n🔍 Monitor with: ./firectl get supervised-fine-tuning-job {job_id}")
    else:
        print("💡 Job created but couldn't extract ID. Check the output above.")

except subprocess.CalledProcessError as e:
    print(f"❌ Fine-tuning job creation failed: {e}")
    print(f"Error output: {e.stderr}")
    if "dataset" in str(e.stderr).lower():
        print("💡 Make sure your dataset was uploaded successfully first!")
    elif "not found" in str(e.stderr).lower():
        print("💡 Check if the dataset name is correct")
except FileNotFoundError:
    print("❌ firectl binary not found or not executable.")

# 5. Monitor Training Progress

## 📋 Section Overview
- Check job status with CLI commands
- Monitor training metrics
- Know when training is complete

## 📊 Monitoring Commands

Check your job status with these commands:

```bash
# List all jobs
firectl list supervised-fine-tuning-jobs

# Check specific job
firectl get supervised-fine-tuning-job JOB_ID

# Check your models (after completion)
firectl list models
```
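
Rather than re-running the status command by hand, the check can be wrapped in a polling loop. This is a sketch only: `parse_job_state` and `wait_for_job` are hypothetical helpers, and the code assumes nothing about the CLI output beyond a `JOB_STATE_*` token appearing somewhere in it. Any state other than the three in-progress states this notebook checks for is treated as final.

```python
import re
import time

# States this notebook treats as "still in progress"
IN_PROGRESS = {"JOB_STATE_QUEUED", "JOB_STATE_VALIDATING", "JOB_STATE_RUNNING"}


def parse_job_state(output):
    """Return the first JOB_STATE_* token found in the CLI output, or None."""
    match = re.search(r"JOB_STATE_[A-Z_]+", output)
    return match.group(0) if match else None


def wait_for_job(fetch_output, poll_seconds=60, max_polls=120, sleep=time.sleep):
    """Call fetch_output() until the job leaves the in-progress states.

    fetch_output is any zero-argument callable returning the text printed by
    `firectl get supervised-fine-tuning-job JOB_ID`. Returns the last state
    seen, which may still be an in-progress state if max_polls is exhausted.
    """
    state = None
    for _ in range(max_polls):
        state = parse_job_state(fetch_output())
        if state is not None and state not in IN_PROGRESS:
            return state
        sleep(poll_seconds)
    return state
```

In the notebook, `fetch_output` could be a small function that runs the `get` command with `subprocess.run(..., capture_output=True, text=True)` and returns its `stdout`. Keeping the fetch step as a parameter makes the loop easy to test without calling the CLI.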

In [None]:
import subprocess
import os

print("📊 Monitoring your fine-tuning job...")
print(f"🎯 Job ID: {job_id}")

# Check if firectl exists and make it executable
firectl_path = "./firectl"
if os.path.exists(firectl_path):
    # Make firectl executable
    os.chmod(firectl_path, 0o755)
    print(f"✅ Found firectl at {firectl_path}")
else:
    print(f"❌ firectl not found at {firectl_path}")
    exit(1)

try:
    # Pass each argument as its own list item and call the local ./firectl binary
    cmd = [firectl_path, "-a", "pyroworks", "get", "supervised-fine-tuning-job", job_id]
    print(f"🔧 Running: {' '.join(cmd)}")

    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    print("✅ Job details:")
    print(result.stdout)

    # Check if job is completed
    if "JOB_STATE_COMPLETED" in result.stdout:
        print("🎉 Job completed! Ready for deployment.")
    elif "JOB_STATE_RUNNING" in result.stdout:
        print("⏳ Job is still running...")
    elif "JOB_STATE_VALIDATING" in result.stdout:
        print("🔍 Job is validating...")
    elif "JOB_STATE_QUEUED" in result.stdout:
        print("📋 Job is queued...")
    else:
        print("📊 Check the status above for current state.")

except subprocess.CalledProcessError as e:
    print(f"❌ Failed to check job {job_id}: {e}")
    print(f"Error: {e.stderr}")
    if "not found" in str(e.stderr).lower():
        print("💡 Make sure the job ID is correct")
except FileNotFoundError:
    print("❌ firectl binary not found or not executable.")

# Optional: List all jobs to see what's available
print("\n" + "="*50)
print("📋 All your fine-tuning jobs:")
try:
    cmd_list = [firectl_path, "-a", "pyroworks", "list", "supervised-fine-tuning-jobs"]
    result_list = subprocess.run(cmd_list, capture_output=True, text=True, check=True)
    print(result_list.stdout)
except Exception as e:
    print(f"❌ Failed to list jobs: {e}")

# 6. Deployment

## 🚀 Deploy Your Model

Once the job state is `JOB_STATE_COMPLETED`, deploy your model (replace `pyroworks` with your own account ID):

```bash
firectl create deployment accounts/pyroworks/models/MODEL_NAME --enable-addons
```


In [None]:
# Execute model deployment
import os
import subprocess
import re

# Check if firectl exists and make it executable
firectl_path = "./firectl"
if os.path.exists(firectl_path):
    # Make firectl executable
    os.chmod(firectl_path, 0o755)
    print(f"✅ Found firectl at {firectl_path}")
else:
    print(f"❌ firectl not found at {firectl_path}")
    exit(1)

# Build deployment command (using the model name from earlier)
model_path = f"accounts/pyroworks/models/{model_name}"

# Pass each argument as its own list item and call the local ./firectl binary
cmd = [firectl_path, "-a", "pyroworks", "create", "deployment", model_path, "--enable-addons"]

print(f"🔧 Running: {' '.join(cmd)}")
print("⏳ Note: This will only work if your fine-tuning job has COMPLETED")

try:
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    print("✅ Model deployment successful!")
    print(result.stdout)

    # Extract deployment ID
    pattern = r"Name: accounts/[^/]+/deployments/([a-zA-Z0-9]+)"
    match = re.search(pattern, result.stdout)
    if match:
        deployment_id = match.group(1)
        print(f"🎯 Deployment ID: {deployment_id}")
        print(f"💡 Save this: export DEPLOYMENT_ID='{deployment_id}'")
        print(f"\n🌐 Your model is now available for API calls!")
    else:
        print("💡 Deployment created but couldn't extract ID. Check the output above.")

except subprocess.CalledProcessError as e:
    print(f"❌ Deployment failed: {e}")
    print(f"Error output: {e.stderr}")
    if "not found" in str(e.stderr).lower():
        print("💡 Make sure your fine-tuning job completed successfully first!")
        print("💡 Check job status with: ./firectl list supervised-fine-tuning-jobs")
    elif "model" in str(e.stderr).lower():
        print("💡 Make sure the model name is correct")
except FileNotFoundError:
    print("❌ firectl binary not found or not executable.")

# Optional: List all your models to verify the fine-tuned model exists
print("\n" + "="*50)
print("📋 Your available models:")
try:
    cmd_list = [firectl_path, "list", "models"]
    result_list = subprocess.run(cmd_list, capture_output=True, text=True, check=True)
    print(result_list.stdout)

    # Check if our model is in the list
    if model_name in result_list.stdout:
        print(f"✅ Found your fine-tuned model: {model_name}")
    else:
        print(f"⚠️  Your model '{model_name}' not found in the list above.")
        print("💡 Make sure your fine-tuning job completed successfully.")

except Exception as e:
    print(f"❌ Failed to list models: {e}")

In [None]:
! firectl get deployment {deployment_id}

# 7. Test Your Model

## 🧪 Query with curl

Test your deployed model from the command line:

```bash
curl --request POST \
  --url https://api.fireworks.ai/inference/v1/chat/completions \
  -H 'Accept: application/json' \
  -H 'Content-Type: application/json' \
  -H "Authorization: Bearer $FIREWORKS_API_KEY" \
  --data '{
    "model": "accounts/pyroworks/deployedModels/pubtabnet-table-qa-20250804-032004-tv6p517x",
    "max_tokens": 4000,
    "top_p": 1,
    "top_k": 40,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "temperature": 0.6,
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Can you describe this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"
                    }
                }
            ]
        }
    ]
  }'
```


In [None]:
# Test your fine-tuned VLM model on test dataset
import requests
import os
import random
import base64
from io import BytesIO
import matplotlib.pyplot as plt
from PIL import Image
import json

print("🧪 Testing your fine-tuned VLM model on test dataset...")

# Get first 5 test examples
test_examples = ds_test.select(range(5))
print(f"Selected {len(test_examples)} test examples")

# Convert test images to base64
def pil_to_base64(pil_image):
    """Convert PIL Image to base64 string"""
    buffered = BytesIO()
    pil_image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/png;base64,{img_str}"

def display_image_from_base64(base64_string, title="Image"):
    """Display image from base64 string"""
    try:
        # Extract and decode the base64 image
        if ',' in base64_string:
            image_data = base64_string.split(',')[1]
        else:
            image_data = base64_string

        # Decode base64 to image
        image_bytes = base64.b64decode(image_data)
        image = Image.open(BytesIO(image_bytes))

        # Display using matplotlib
        plt.figure(figsize=(10, 6))
        plt.imshow(image)
        plt.axis('off')
        plt.title(title, fontsize=12, pad=15)
        plt.show()

        return image

    except Exception as e:
        print(f"Error displaying image: {e}")
        return None

# Convert test images
test_base64_data = []
for i in range(5):
    example = test_examples[i]
    base64_image = pil_to_base64(example['image'])
    test_base64_data.append({
        'image_base64': base64_image,
        'imgid': example['imgid'],
        'html_table': example['html_table'],
        'original_image': example['image']  # Keep original PIL image for easier display
    })

# Question templates
question_templates = [
    "What information is shown in this table?",
    "Describe the content and structure of this table.",
    "What data does this table contain?",
    "Can you analyze this table and tell me what it shows?",
    "What are the main elements of this table?",
    "Summarize the information presented in this table.",
    "What can you tell me about the data in this table?",
    "Describe what this table is displaying.",
]

# Model configuration - UPDATE THESE VALUES
model_name = "pubtabnet-table-qa-20250804-032004-tv6p517x"  # Replace with your actual model name
model_id = f"accounts/pyroworks/deployedModels/{model_name}"
url = "https://api.fireworks.ai/inference/v1/chat/completions"

# Get API key from environment
api_key = os.environ.get('FIREWORKS_API_KEY')
if not api_key:
    print("❌ FIREWORKS_API_KEY not found in environment variables")
    print("Please set it using: os.environ['FIREWORKS_API_KEY'] = 'your_key'")
else:
    print("✅ API key found")

headers = {
    "Authorization": f"Bearer {api_key}",
    "Content-Type": "application/json"
}

print(f"🎯 Testing model: {model_id}")
print("⏳ Note: Model must be deployed first!")

# Test each image
results = []

for i, test_data in enumerate(test_base64_data):
    # Pick a random question for this image
    question = random.choice(question_templates)

    print(f"\n" + "="*80)
    print(f"📋 Test {i+1}: Image ID {test_data['imgid']}")
    print(f"❓ Question: {question}")
    print("="*80)

    # Display the image
    print(f"\n🖼️ Displaying test image:")
    display_image_from_base64(
        test_data['image_base64'],
        title=f"Test {i+1}: {test_data['imgid']} - {question[:50]}..."
    )

    data = {
        "model": model_id,
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant that can analyze tables and documents to answer questions about their content. Focus on understanding what information is presented and be ready to answer specific questions about the data."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": test_data['image_base64']
                        }
                    },
                    {
                        "type": "text",
                        "text": question
                    }
                ]
            }
        ],
        "temperature": 0.1,
        "max_tokens": 500
    }

    try:
        print("🔄 Sending request to model...")
        # A timeout keeps the cell from hanging indefinitely on a stalled request
        response = requests.post(url, headers=headers, json=data, timeout=120)
        response.raise_for_status()

        result = response.json()
        model_response = result['choices'][0]['message']['content'].strip()

        print(f"\n🤖 Model Response:")
        print("-" * 40)
        print(model_response)
        print("-" * 40)

        # Store result for comparison
        results.append({
            'image_id': test_data['imgid'],
            'question': question,
            'model_response': model_response,
            'ground_truth_html': test_data['html_table'],
            'image': test_data['original_image']
        })

        print(f"✅ Test {i+1} completed successfully!")

    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
        if hasattr(e, 'response') and e.response is not None:
            print(f"Response: {e.response.text}")
        break
    except Exception as e:
        print(f"❌ Error: {e}")
        break

print(f"\n" + "="*80)
print(f"✅ Testing complete! Processed {len(results)} examples")
print("="*80)
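
The `results` list stores each response next to its ground-truth HTML, so a quick sanity check is possible. The sketch below is one crude option, not an established metric: `header_recall` and `FirstRowParser` are hypothetical names, and the score is simply the fraction of first-row cell texts that appear (case-insensitively, as substrings) in the model response.

```python
from html.parser import HTMLParser


class FirstRowParser(HTMLParser):
    """Collect the text of each cell in the first <tr> of an HTML table."""

    def __init__(self):
        super().__init__()
        self.cells = []
        self._in_first_row = False
        self._in_cell = False
        self._done = False

    def handle_starttag(self, tag, attrs):
        if self._done:
            return  # ignore everything after the first row
        if tag == "tr":
            self._in_first_row = True
        elif tag in ("th", "td") and self._in_first_row:
            self._in_cell = True
            self.cells.append("")

    def handle_endtag(self, tag):
        if tag == "tr" and self._in_first_row:
            self._in_first_row = False
            self._done = True
        elif tag in ("th", "td"):
            self._in_cell = False

    def handle_data(self, data):
        if self._in_cell and not self._done:
            self.cells[-1] += data


def header_recall(model_response, html_table):
    """Fraction of non-empty first-row cell texts found in the response."""
    parser = FirstRowParser()
    parser.feed(html_table)
    headers = [c.strip() for c in parser.cells if c.strip()]
    if not headers:
        return None  # nothing to compare against
    response = model_response.lower()
    found = sum(1 for h in headers if h.lower() in response)
    return found / len(headers)
```

Applied to the notebook's data, this would be called as `header_recall(r['model_response'], r['ground_truth_html'])` for each `r` in `results`. Substring matching is fragile (a short header can match inside a longer word), so treat the score as a rough signal rather than an evaluation.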