# FinScribe LLaMA-Factory Micro LoRA Experiment

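This notebook runs a tiny LoRA supervised fine-tuning (SFT) job with LLaMA-Factory on 10 synthetic invoice examples (OCR-style text in, corrected JSON out). It is intended only for development and smoke testing.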

**Requirements:**
- Colab with GPU runtime (recommended)
- Hugging Face token (if using gated models)
- ~20GB disk space


## Cell 1: Setup & Install


In [None]:
# Colab cell 1: install deps & clone
# If you run into space issues, consider mounting Google Drive
import subprocess
import sys
import os

def run_command(cmd, check=True, shell=True):
    """Run a command, echo its output, and report whether it succeeded.

    Args:
        cmd: The command to run. A single string when ``shell`` is True (the
            default); an argument list when ``shell`` is False.
        check: If True, a non-zero exit status (or a failure to launch the
            command) is re-raised after the command's output has been printed.
            If False, failures are printed and ``False`` is returned instead.
        shell: Forwarded to ``subprocess.run``.

    Returns:
        True if the command exited with status 0, otherwise False (a False
        return only happens when ``check`` is False).

    Raises:
        subprocess.CalledProcessError: The command exited non-zero and
            ``check`` is True.
        Exception: Any error raised while launching the command (e.g. a
            missing executable with ``shell=False``) when ``check`` is True.
    """
    try:
        # Never let subprocess raise on the exit status itself: inspecting the
        # return code here lets us print stdout/stderr for every failure,
        # including check=False runs whose stderr would otherwise be swallowed.
        result = subprocess.run(cmd, shell=shell, check=False,
                                capture_output=True, text=True)
    except Exception as e:
        print(f"‚ùå Unexpected error: {e}")
        if check:
            raise
        return False

    if result.stdout:
        print(result.stdout)

    if result.returncode == 0:
        return True

    print(f"‚ùå Error running command: {cmd}")
    print(f"Error output: {result.stderr}")
    if check:
        raise subprocess.CalledProcessError(
            result.returncode, cmd, output=result.stdout, stderr=result.stderr
        )
    return False

# Check GPU availability (informational only; a failure here is not fatal)
print("Checking GPU availability...")
run_command("nvidia-smi", check=False)

# os.chdir() below changes the working directory for the rest of the kernel
# session. When this cell is re-run we are already inside the repository, and
# looking for ./LLaMA-Factory would trigger a second, nested clone.
if os.path.basename(os.getcwd()) == "LLaMA-Factory":
    print(f"‚úÖ Already in directory: {os.getcwd()}")
else:
    # Clone LLaMA-Factory if not already present
    if os.path.exists("LLaMA-Factory"):
        print("‚ö†Ô∏è  LLaMA-Factory directory already exists. Skipping clone.")
        print("   If you want a fresh clone, delete the directory first.")
    else:
        print("Cloning LLaMA-Factory repository...")
        # run_command(check=True) raises CalledProcessError on a non-zero exit
        # instead of returning False, so translate that into our own error.
        try:
            run_command("git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git")
        except subprocess.CalledProcessError as e:
            raise RuntimeError("Failed to clone LLaMA-Factory repository") from e

    if not os.path.exists("LLaMA-Factory"):
        raise RuntimeError("LLaMA-Factory directory not found after clone")

    os.chdir("LLaMA-Factory")
    print(f"‚úÖ Changed to directory: {os.getcwd()}")

# Install dependencies with the pip belonging to this kernel's interpreter; a
# bare `pip` on PATH can target a different Python installation.
print("Installing LLaMA-Factory dependencies (this may take several minutes)...")
try:
    run_command(f'"{sys.executable}" -m pip install -e ".[torch,metrics]"')
except subprocess.CalledProcessError as e:
    print("‚ö†Ô∏è  Installation failed. You may need to:")
    print("   - Check your CUDA/PyTorch compatibility")
    print("   - Ensure you have sufficient disk space")
    print("   - Try: pip install -e '.[torch,metrics]' manually")
    raise RuntimeError("Failed to install LLaMA-Factory dependencies") from e

# NOTE: an editable install made during this session may not be importable in
# the already-running kernel until the runtime is restarted; newly started
# Python processes should pick it up.
print("‚úÖ Setup complete!")


## Cell 2: Create Tiny Dataset


In [None]:
# Colab cell 2: Create synthetic invoice dataset
import json
import random
import os
from pathlib import Path

# Fixed seed so the synthetic dataset (and therefore the micro training run)
# is identical across notebook re-runs; change it to sample a different set.
SEED = 42
NUM_EXAMPLES = 10

try:
    # Create data directory
    data_dir = Path("data")
    data_dir.mkdir(exist_ok=True)
    print(f"‚úÖ Data directory ready: {data_dir.absolute()}")

    # Private RNG instance: reproducible without reseeding the global `random`.
    rng = random.Random(SEED)

    print("Generating synthetic invoice examples...")
    train = []
    for i in range(NUM_EXAMPLES):
        vendor = rng.choice(["TechCorp Inc.", "Acme LLC", "Globex"])
        inv = f"INV-{1000+i}"
        # Months 1-9 and days 10-28 keep every generated date valid and
        # zero-padded as YYYY-MM-DD without any calendar logic.
        date = f"2024-0{rng.randint(1, 9)}-{rng.randint(10, 28)}"
        prompt = (
            f"Validate and correct: OCR_TEXT: Vendor: {vendor} Invoice: {inv} "
            f"Date: {date} Items: Widget 2x50 Total 100"
        )
        # NOTE(review): the target JSON omits the invoice number and date that
        # appear in the OCR text; confirm whether the FinScribe schema should
        # carry them.
        completion_data = {
            "document_type": "invoice",
            "vendor": {"name": vendor},
            "client": {},
            "line_items": [{"desc": "Widget", "qty": 2, "unit_price": 50.0, "line_total": 100.0}],
            "financial_summary": {"subtotal": 100.0, "tax_rate": 0.0, "tax_amount": 0.0, "grand_total": 100.0}
        }
        train.append({
            "instruction": "Validate and return JSON only",
            "input": prompt,
            "output": json.dumps(completion_data),
        })

    if not train:
        raise RuntimeError("Failed to create any training examples")

    # Write one JSON object per line (JSONL).
    output_file = data_dir / "finscribe_lf_train.jsonl"
    try:
        with open(output_file, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in train)
    except OSError as e:
        raise RuntimeError(f"Failed to write dataset file: {e}") from e

    file_size = output_file.stat().st_size
    if file_size == 0:
        raise ValueError(f"Output file is empty: {output_file}")

    print(f"‚úÖ Successfully wrote {len(train)} examples to {output_file}")
    print(f"   File size: {file_size} bytes")

    # Read the file back and validate every record, not just the first: each
    # line must be a JSON object with the instruction/input/output keys, and
    # its "output" must itself be valid JSON.
    required_keys = {"instruction", "input", "output"}
    with open(output_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    if len(lines) != len(train):
        raise ValueError(f"Line count mismatch: expected {len(train)}, got {len(lines)}")
    for lineno, line in enumerate(lines, start=1):
        try:
            record = json.loads(line)
            if not isinstance(record, dict) or not required_keys <= record.keys():
                raise ValueError(f"Invalid item structure on line {lineno}: {line.strip()}")
            json.loads(record["output"])
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON on line {lineno} of {output_file}: {e}") from e

    print("‚úÖ Dataset file validated successfully")

except Exception as e:
    print(f"‚ùå Error in dataset creation: {e}")
    raise


## Cell 3: Register Dataset & Create Training Config


In [None]:
# Colab cell 3: Register dataset and create YAML config
import json
import os
from pathlib import Path

# Base model for the LoRA run. Set this to a small causal LM (for example
# 'facebook/opt-125m' or 'microsoft/phi-2') to avoid editing the YAML by hand.
# While it is left as the placeholder, a model name already present in an
# existing config (e.g. hand-edited after a previous run) is preserved.
BASE_MODEL = "<SMALL_MODEL_NAME>"
_MODEL_PLACEHOLDER = "<SMALL_MODEL_NAME>"


def _existing_model_name(config_path):
    """Return the ``model_name_or_path`` value of an existing YAML file, or None.

    Only a top-level ``model_name_or_path:`` line is recognised; a trailing
    ``# comment`` is dropped. The raw value (including any quotes) is returned
    so it can be written back verbatim.
    """
    if not config_path.exists():
        return None
    with open(config_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith("model_name_or_path:"):
                value = line.split(":", 1)[1].split("#", 1)[0].strip()
                return value or None
    return None


try:
    # Validate that dataset file exists before registering
    dataset_file = Path("data/finscribe_lf_train.jsonl")
    if not dataset_file.exists():
        raise FileNotFoundError(
            f"Dataset file not found: {dataset_file}\n"
            "Please run the previous cell to create the dataset first."
        )

    print(f"‚úÖ Found dataset file: {dataset_file}")

    # Register the dataset in dataset_info.json, merging with any entries that
    # are already there instead of replacing them.
    dataset_info_path = Path("data/dataset_info.json")

    existing_info = {}
    if dataset_info_path.exists():
        try:
            with open(dataset_info_path, "r", encoding="utf-8") as f:
                existing_info = json.load(f)
            print(f"‚úÖ Loaded existing dataset_info.json with {len(existing_info)} entries")
        except json.JSONDecodeError as e:
            print(f"‚ö†Ô∏è  Existing dataset_info.json is invalid JSON: {e}")
            print("   Creating new file...")
            existing_info = {}

    # Entry in LLaMA-Factory's dataset_info schema: "formatting" selects the
    # alpaca converter, and "columns" maps its prompt/query/response roles onto
    # the instruction/input/output keys written by the dataset cell.
    dataset_info = {
        "finscribe_lf_train": {
            "file_name": "finscribe_lf_train.jsonl",
            "formatting": "alpaca",
            "columns": {
                "prompt": "instruction",
                "query": "input",
                "response": "output",
            },
            "description": "FinScribe micro experiment dataset",
        }
    }
    existing_info.update(dataset_info)

    try:
        with open(dataset_info_path, "w", encoding="utf-8") as f:
            json.dump(existing_info, f, indent=2, ensure_ascii=False)
        print(f"‚úÖ Registered dataset in {dataset_info_path}")

        # Validate the written JSON
        with open(dataset_info_path, "r", encoding="utf-8") as f:
            json.load(f)
        print("‚úÖ Dataset info file validated")

    except OSError as e:
        raise RuntimeError(f"Failed to write dataset_info.json: {e}") from e
    except (TypeError, ValueError) as e:
        # The json module has no JSONEncodeError: serialization failures raise
        # TypeError (unserializable object) or ValueError (e.g. circular ref).
        raise ValueError(f"Failed to encode dataset_info as JSON: {e}") from e

    yaml_dir = Path("examples/train_lora")
    yaml_file = yaml_dir / "finscribe_colab.yaml"

    # Regenerating the config must not discard a model name the user already
    # filled in by editing the file directly.
    model_name = BASE_MODEL
    if model_name == _MODEL_PLACEHOLDER:
        previous_model = _existing_model_name(yaml_file)
        if previous_model and previous_model != _MODEL_PLACEHOLDER:
            model_name = previous_model
            print(f"‚úÖ Keeping model from existing config: {model_name}")

    # Notes on the keys:
    # - do_train: HF TrainingArguments defaults it to False, and the SFT stage
    #   only runs the training loop when it is set.
    # - template: the SFT stage expects an explicit chat template; "default"
    #   is generic, so switch to the model's own template if it has one.
    # - learning_rate is written as 2.0e-5 because YAML 1.1 loaders such as
    #   PyYAML read "2e-5" (no decimal point) as a string, not a float.
    yaml_config = f"""model_name_or_path: {model_name}  # Replace with small model like 'facebook/opt-125m' or 'microsoft/phi-2'
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
dataset: finscribe_lf_train
template: default
cutoff_len: 512
output_dir: saves/finscribe_test
overwrite_output_dir: true
per_device_train_batch_size: 1
num_train_epochs: 1
learning_rate: 2.0e-5
bf16: false
logging_steps: 5
save_steps: 10
"""

    try:
        yaml_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Failed to create directory {yaml_dir}: {e}") from e

    try:
        with open(yaml_file, "w", encoding="utf-8") as f:
            f.write(yaml_config)
    except OSError as e:
        raise RuntimeError(f"Failed to write YAML config file: {e}") from e

    if not yaml_file.exists():
        raise FileNotFoundError(f"YAML file was not created: {yaml_file}")

    print(f"‚úÖ Created training config: {yaml_file}")

    if model_name == _MODEL_PLACEHOLDER:
        print("\n‚ö†Ô∏è  IMPORTANT: Edit the YAML file to replace <SMALL_MODEL_NAME> with your chosen model!")
        print(f"   File location: {yaml_file.absolute()}")
        print("   Suggested models: 'facebook/opt-125m' or 'microsoft/phi-2'")
    else:
        print("‚úÖ Model name appears to be configured")

except FileNotFoundError as e:
    print(f"‚ùå File not found: {e}")
    raise
except (ValueError, RuntimeError) as e:
    print(f"‚ùå Error: {e}")
    raise
except Exception as e:
    print(f"‚ùå Unexpected error: {e}")
    raise


## Cell 4: Run Training


In [None]:
# Colab cell 4: Run training
# Make sure you've edited the YAML to set model_name_or_path
import subprocess
import sys
import os
from pathlib import Path

def validate_training_config():
    """Make sure the training YAML and the JSONL dataset are ready for a run.

    Checks, in order: the YAML config exists, its text no longer contains the
    ``<SMALL_MODEL_NAME>`` placeholder, and the dataset file exists. Prints a
    confirmation line once every check passes.

    Raises:
        FileNotFoundError: The YAML config or the dataset file is missing.
        ValueError: The YAML still contains the ``<SMALL_MODEL_NAME>`` placeholder.
    """
    def _require(path, label, hint):
        # Shared "missing file" error so both existence checks read alike.
        if not path.exists():
            raise FileNotFoundError(f"{label} not found: {path}\n{hint}")

    config_path = Path("examples/train_lora/finscribe_colab.yaml")
    _require(config_path, "Training config",
             "Please run the previous cell to create the config first.")

    config_text = config_path.read_text(encoding="utf-8")
    if "<SMALL_MODEL_NAME>" in config_text:
        raise ValueError(
            "Model name not configured!\n"
            f"Please edit {config_path} and replace <SMALL_MODEL_NAME> with a valid model name.\n"
            "Suggested: 'facebook/opt-125m' or 'microsoft/phi-2'"
        )

    _require(Path("data/finscribe_lf_train.jsonl"), "Dataset file",
             "Please run the dataset creation cell first.")

    print("‚úÖ Training configuration validated")

import shutil  # used only to locate the llamafactory-cli console script


def _training_command(config_path):
    """Return the argv that launches LLaMA-Factory training on *config_path*.

    Prefers the ``llamafactory-cli`` console script, searching next to the
    kernel's interpreter first and then PATH. Otherwise falls back to calling
    ``llamafactory.cli.main`` through the kernel's interpreter.
    NOTE(review): that function is assumed to be what ``llamafactory-cli``
    wraps; confirm against the installed package's entry points.
    """
    search_path = os.pathsep.join(
        [os.path.dirname(sys.executable), os.environ.get("PATH", "")]
    )
    cli = shutil.which("llamafactory-cli", path=search_path)
    if cli:
        return [cli, "train", config_path]
    return [
        sys.executable, "-c", "from llamafactory.cli import main; main()",
        "train", config_path,
    ]


try:
    # Validate configuration before training
    print("Validating training configuration...")
    validate_training_config()

    # Check the install from a fresh interpreter. Training runs in a subprocess
    # anyway, and a package installed earlier in this session may not be
    # importable in the already-running kernel until the runtime restarts, so
    # an in-kernel `import llamafactory` could report a false "not installed".
    probe = subprocess.run(
        [sys.executable, "-c", "import llamafactory; print(llamafactory.__file__)"],
        capture_output=True,
        text=True,
    )
    if probe.returncode != 0:
        raise ImportError(
            "LLaMA-Factory not installed. Please run the setup cell first.\n"
            f"Import error output: {probe.stderr.strip()[-500:]}"
        )
    print(f"‚úÖ LLaMA-Factory found: {probe.stdout.strip()}")

    yaml_file = "examples/train_lora/finscribe_colab.yaml"
    print(f"\nüöÄ Starting training with config: {yaml_file}")
    print("   This may take several minutes...")

    try:
        # Output is not captured so training logs appear as they are produced.
        result = subprocess.run(_training_command(yaml_file), check=False)

        if result.returncode != 0:
            print(f"\n‚ùå Training failed with exit code {result.returncode}")
            print("   Check the output above for error details.")
            raise subprocess.CalledProcessError(result.returncode, result.args)

        print("\n‚úÖ Training completed successfully!")

    except subprocess.CalledProcessError as e:
        print(f"\n‚ùå Training command failed: {e}")
        print("\nTroubleshooting tips:")
        print("  - Check that the model name in the YAML is valid and accessible")
        print("  - Check that the YAML sets a `template` suited to the model")
        print("  - Ensure you have sufficient GPU memory")
        print("  - Verify the dataset file is valid JSONL")
        print("  - Check disk space for output directory")
        raise

    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è  Training interrupted by user")
        raise

    except Exception as e:
        print(f"\n‚ùå Unexpected error during training: {e}")
        raise

except (FileNotFoundError, ValueError, ImportError) as e:
    print(f"‚ùå Configuration error: {e}")
    raise
except Exception as e:
    print(f"‚ùå Error: {e}")
    raise


## Cell 5: Inference Test (requires a locally running LLaMA-Factory API server)


In [None]:
# Colab cell 5: Example inference stub
# Use your running LLaMA-Factory API or load model directly
import requests
import json
from typing import Optional, Dict, Any
import time

def test_api_connection(api_url: str, timeout: int = 5) -> bool:
    """Best-effort check that an API server answers at the host of *api_url*.

    The server root is everything before the first ``/v1/`` in *api_url*
    (or the whole URL if it has none). ``GET /v1/models`` (the
    OpenAI-compatible model listing) is probed first, then ``GET /health``;
    the first HTTP 200 counts as reachable. Never raises.

    Args:
        api_url: Chat-completions URL, e.g.
            ``http://localhost:8000/v1/chat/completions``.
        timeout: Per-probe timeout in seconds.

    Returns:
        True if a probe returned HTTP 200. False if neither probe did, if a
        request error (refused connection, timeout, ...) occurred, or if
        *api_url* is not a string.
    """
    if not isinstance(api_url, str):
        return False

    base_url = api_url.split("/v1/", 1)[0].rstrip("/")
    for probe_path in ("/v1/models", "/health"):
        try:
            response = requests.get(base_url + probe_path, timeout=timeout)
        except requests.exceptions.RequestException:
            # Nothing usable is listening; a second probe would only repeat the wait.
            return False
        if response.status_code == 200:
            return True
    return False

def call_inference_api(
    api_url: str,
    model_name: str,
    prompt: str,
    temperature: float = 0.0,
    timeout: int = 30,
    max_retries: int = 3
) -> Optional[Dict[str, Any]]:
    """POST a single-turn chat completion request, retrying transient failures.

    Args:
        api_url: Full URL of the OpenAI-style ``/v1/chat/completions`` endpoint.
        model_name: Value sent in the request's ``model`` field.
        prompt: Content of the single ``user`` message.
        temperature: Sampling temperature sent with the request.
        timeout: Per-request timeout in seconds passed to ``requests.post``.
        max_retries: Total number of attempts. Timeouts and HTTP 503 responses
            are retried after 2s, 4s, ... (linear back-off); every other error
            fails immediately.

    Returns:
        The decoded JSON response body, or None when ``max_retries`` < 1 (no
        request is made in that case).

    Raises:
        ConnectionError: The server could not be reached.
        FileNotFoundError: The endpoint returned HTTP 404.
        ValueError: The response body was not valid JSON.
        RuntimeError: Retries were exhausted, or another HTTP/request error occurred.
    """
    payload = {
        "model": model_name,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "temperature": temperature
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        wait_time = (attempt + 1) * 2
        try:
            print(f"Attempting API call (attempt {attempt + 1}/{max_retries})...")
            response = requests.post(
                api_url,
                json=payload,
                timeout=timeout,
                headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()

        except requests.exceptions.Timeout as e:
            if not is_last_attempt:
                print(f"‚ö†Ô∏è  Request timed out. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
                continue
            raise RuntimeError(f"API request timed out after {max_retries} attempts") from e

        except requests.exceptions.ConnectionError as e:
            raise ConnectionError(
                f"Cannot connect to API at {api_url}\n"
                "Make sure the LLaMA-Factory API server is running.\n"
                "Start it with: llamafactory-cli api <inference_config.yaml>"
            ) from e

        except requests.exceptions.HTTPError as e:
            status = e.response.status_code
            if status == 404:
                raise FileNotFoundError(
                    f"API endpoint not found: {api_url}\n"
                    "Check that the API server is running and the endpoint is correct."
                ) from e
            if status == 503:
                if not is_last_attempt:
                    print(f"‚ö†Ô∏è  Service unavailable. Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
                    continue
                raise RuntimeError("API service unavailable after retries") from e
            raise RuntimeError(
                f"API returned error {status}: {e.response.text[:200]}"
            ) from e

        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Request failed: {e}") from e

        # Decode outside the request try-block and catch ValueError: every
        # JSON decode error requests can raise (json- or simplejson-based)
        # subclasses it, whereas json.JSONDecodeError misses the simplejson one.
        try:
            return response.json()
        except ValueError as e:
            raise ValueError(
                f"API returned invalid JSON: {e}\nResponse text: {response.text[:200]}"
            ) from e

    return None

# Configuration
API_BASE = "http://localhost:8000"
API_URL = f"{API_BASE}/v1/chat/completions"
MODEL_NAME = "finscribe-llama"  # Update this to match your trained model name

# Example prompt shaped like the training records: the dataset cell's fixed
# instruction followed by an OCR-style input line. NOTE(review): this assumes
# LLaMA-Factory's alpaca formatting joins "instruction" and "input" with a
# newline into a single user message; confirm against the installed version.
test_prompt = (
    "Validate and return JSON only\n"
    "Validate and correct: OCR_TEXT: Vendor: TechCorp Inc. Invoice: INV-1000 "
    "Date: 2024-03-15 Items: Widget 2x50 Total 100"
)

print("=" * 60)
print("Inference API Test")
print("=" * 60)

try:
    # Best-effort reachability probe; the real call below reports its own errors.
    print(f"Testing connection to {API_BASE}...")
    if test_api_connection(API_URL):
        print("‚úÖ API server is reachable")
    else:
        print("‚ö†Ô∏è  Health check endpoint not available, but continuing...")

    # Make inference call
    print("\nCalling inference API...")
    print(f"  URL: {API_URL}")
    print(f"  Model: {MODEL_NAME}")
    print(f"  Prompt: {test_prompt[:50]}...")

    result = call_inference_api(
        api_url=API_URL,
        model_name=MODEL_NAME,
        prompt=test_prompt,
        temperature=0.0,
        timeout=30
    )

    if result:
        print("\n‚úÖ API call successful!")
        print("\nResponse:")
        print(json.dumps(result, indent=2, ensure_ascii=False))

        # OpenAI-style body: choices[0].message.content. Each level is guarded
        # so a malformed response is reported rather than raising.
        choices = result.get("choices") if isinstance(result, dict) else None
        first_choice = choices[0] if isinstance(choices, list) and choices else {}
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        content = message.get("content") if isinstance(message, dict) else None

        if isinstance(content, str) and content:
            print(f"\nGenerated content:\n{content}")
            # The training targets are bare JSON, so check every reply rather
            # than only replies that happen to start with "{".
            try:
                json.loads(content)
                print("‚úÖ Response is valid JSON")
            except json.JSONDecodeError:
                print("‚ö†Ô∏è  Response is not valid JSON")
        else:
            print("‚ö†Ô∏è  Response contained no generated content")
    else:
        print("‚ùå API call returned no result")

except ConnectionError as e:
    print(f"\n‚ùå Connection error: {e}")
    print("\nTo start the API server, run:")
    print("  llamafactory-cli api <inference_config.yaml>")
    print("  # the config should set model_name_or_path, adapter_name_or_path")
    print("  # (e.g. saves/finscribe_test), finetuning_type: lora, and the")
    print("  # same template that was used for training")

except (FileNotFoundError, ValueError, RuntimeError) as e:
    print(f"\n‚ùå Error: {e}")

except KeyboardInterrupt:
    print("\n‚ö†Ô∏è  Interrupted by user")

except Exception as e:
    print(f"\n‚ùå Unexpected error: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "=" * 60)
