# Security

> Encryption and security utilities for API key management

In [None]:
#| default_exp core.security

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import base64
import hashlib
import warnings
from typing import Optional, Union
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from cjm_fasthtml_byok.core.types import EncryptionError, SecurityWarning

## Key Generation

In [None]:
#| export
def generate_encryption_key(
    password: Optional[str] = None,  # Optional password to derive key from
    salt: Optional[bytes] = None,  # Salt for key derivation; supply (and persist) it to be able to re-derive the same key
    iterations: int = 100000,  # PBKDF2 iteration count; changing it changes every derived key
) -> bytes:  # URL-safe base64 encoding of a 32-byte key, directly usable with Fernet
    """Generate a random Fernet key, or derive one from a password with PBKDF2-HMAC-SHA256.

    If ``password`` is falsy (``None`` or ``""``), a fresh random Fernet key is
    returned and ``salt``/``iterations`` are ignored.

    If ``password`` is given, the key is derived deterministically from
    ``password``, ``salt`` and ``iterations``. When no salt is supplied, a random
    16-byte salt is used. That salt is not returned, so the key cannot be
    re-derived later. A ``SecurityWarning`` is emitted in that case.
    """
    if not password:
        # Nothing to derive from: hand back a random Fernet key.
        return Fernet.generate_key()

    if not salt:
        # The caller gets only the key, never this salt, so the same key can
        # never be reproduced from the password again.
        warnings.warn(
            "generate_encryption_key() was given a password but no salt; a random "
            "salt is being used and discarded, so this key cannot be re-derived.",
            SecurityWarning,
            stacklevel=2,
        )
        salt = os.urandom(16)

    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,  # Fernet needs 32 raw key bytes
        salt=salt,
        iterations=iterations,
    )
    # Fernet keys are the 32 raw bytes wrapped in URL-safe base64.
    return base64.urlsafe_b64encode(kdf.derive(password.encode()))

In [None]:
#| export
def get_or_create_app_key(
    secret_key: str  # The application's secret key (from FastHTML app config)
) -> bytes:  # Encryption key for the app
    """Derive the app-wide encryption key deterministically from the app's secret key.

    Nothing is stored or cached. The same ``secret_key`` always yields the same
    key because a fixed, BYOK-specific salt is used. Changing ``secret_key`` (or
    the salt below) makes previously encrypted values undecryptable.

    Raises:
        ValueError: If ``secret_key`` is empty. ``generate_encryption_key``
            treats a falsy password as "no password" and returns a random key.
            That would silently make stored data unreadable on the next call.
    """
    if not secret_key:
        raise ValueError(
            "secret_key must be a non-empty string to derive the app encryption key"
        )
    # Fixed salt specific to BYOK so the derivation is reproducible across restarts.
    salt = b'cjm-fasthtml-byok-v1'
    return generate_encryption_key(password=secret_key, salt=salt)

## Encryption/Decryption

In [None]:
#| export
class KeyEncryptor:
    """
    Handles encryption and decryption of API keys using a Fernet instance.
    """
    
    def __init__(
        self,
        encryption_key: Optional[bytes] = None  # Encryption key to use. If None, generates a new one
    ):
        """
        Initialize the encryptor.
        
        Args:
            encryption_key: Fernet key to use. If None, a new random key is
                generated and exposed as ``self.encryption_key``. The caller must
                keep that key to decrypt the data later.
        """
        # Explicit None check: a falsy-but-provided key (e.g. b"") is a caller
        # error and should fail in Fernet() rather than silently become a
        # random key that nothing else knows about.
        if encryption_key is None:
            encryption_key = Fernet.generate_key()
        self.encryption_key = encryption_key
        self._fernet = Fernet(self.encryption_key)
    
    def encrypt(
        self,
        value: str  # Plain text API key to encrypt
    ) -> bytes:  # Encrypted bytes
        """
        Encrypt an API key value.
        
        Args:
            value: Plain text API key
        
        Returns:
            Encrypted bytes (a Fernet token)
        
        Raises:
            EncryptionError: If encryption fails (the original error is chained)
        """
        try:
            return self._fernet.encrypt(value.encode())
        except Exception as e:
            raise EncryptionError(f"Failed to encrypt value: {e}") from e
    
    def decrypt(
        self,
        encrypted_value: bytes  # Encrypted bytes to decrypt
    ) -> str:  # Decrypted plain text API key
        """
        Decrypt an API key value.
        
        Args:
            encrypted_value: Encrypted bytes produced by ``encrypt``
        
        Returns:
            Decrypted API key
        
        Raises:
            EncryptionError: If the token is invalid for this key, corrupted,
                or decryption otherwise fails (the original error is chained)
        """
        try:
            return self._fernet.decrypt(encrypted_value).decode()
        except InvalidToken as e:
            raise EncryptionError("Invalid encryption key or corrupted data") from e
        except Exception as e:
            raise EncryptionError(f"Failed to decrypt value: {e}") from e
    
    def rotate_key(
        self,
        new_key: bytes,  # New encryption key to use
        encrypted_value: bytes  # Value encrypted with current key
    ) -> bytes:  # Value re-encrypted with new key
        """
        Re-encrypt a value with a new key.
        
        This instance keeps using its current key; only the returned value is
        encrypted under ``new_key``.
        
        Args:
            new_key: New encryption key
            encrypted_value: Value encrypted with current key
        
        Returns:
            Value encrypted with new key
        
        Raises:
            EncryptionError: If ``encrypted_value`` cannot be decrypted with the
                current key, or re-encryption fails
        """
        decrypted = self.decrypt(encrypted_value)
        new_encryptor = KeyEncryptor(new_key)
        return new_encryptor.encrypt(decrypted)

## Security Checks

In [None]:
#| export
def check_https(
    request  # FastHTML/Starlette request object
) -> bool:  # True if using HTTPS, False otherwise
    """Return whether the request URL uses the exact scheme ``'https'`` (case-sensitive)."""
    url_scheme = request.url.scheme
    return url_scheme == 'https'

In [None]:
#| export
def validate_environment(
    request,  # FastHTML/Starlette request object
    require_https: bool = True,  # Whether to require HTTPS
    is_production: Optional[bool] = None  # Whether running in production (auto-detected if None)
) -> None:
    """Warn when API keys would travel over plain HTTP in a production setting.

    When ``is_production`` is None, production is auto-detected as both:
    the request host is not a local name/address (``localhost``, ``*.localhost``,
    ``127.0.0.1``, ``0.0.0.0``, ``::1``), and the ``DEBUG`` environment variable
    is not ``true``/``1``/``yes`` (case-insensitive).

    Emits a ``SecurityWarning`` (nothing is raised) when running in production,
    ``require_https`` is true, and the request is not HTTPS.
    """
    if is_production is None:
        # Compare the parsed hostname rather than searching the whole URL
        # string. A substring test would also match e.g. "localhost" in the path
        # or query, or a host like "127.0.0.1.example.com", and wrongly
        # suppress the warning.
        local_hosts = {'localhost', '127.0.0.1', '0.0.0.0', '::1'}
        host = (request.url.hostname or '').lower()
        is_local = host in local_hosts or host.endswith('.localhost')
        debug = os.environ.get('DEBUG', '').lower() in ('true', '1', 'yes')
        is_production = not (is_local or debug)
    
    if is_production and require_https and not check_https(request):
        warnings.warn(
            "API keys are being transmitted over HTTP in production. "
            "This is insecure. Please use HTTPS.",
            SecurityWarning,
            stacklevel=2,  # attribute the warning to the caller
        )

In [None]:
#| export
def mask_key(
    key: str,  # The API key to mask
    visible_chars: int = 4  # Number of characters to show at start and end
) -> str:  # Masked key like 'sk-1...cdef', or all '*' when it cannot be partially shown safely
    """Mask an API key for display, keeping only a few characters at each end.

    Keys no longer than ``2 * visible_chars`` are replaced entirely with ``*``,
    because showing both ends would reveal the whole key. A non-positive
    ``visible_chars`` also masks the key fully.
    """
    # key[-0:] is the WHOLE string, so visible_chars == 0 would otherwise emit
    # '...' followed by the entire unmasked key. Negative values leak most of it.
    if visible_chars <= 0 or len(key) <= visible_chars * 2:
        return '*' * len(key)
    
    return f"{key[:visible_chars]}...{key[-visible_chars:]}"

## Key Fingerprinting

In [None]:
#| export
def get_key_fingerprint(
    key: str  # The API key
) -> str:  # SHA256 fingerprint of the key (first 16 chars)
    """Return the first 16 hex characters of the SHA-256 digest of the UTF-8 encoded key.

    Useful as a stable identifier in logs without exposing the key itself.
    """
    fingerprint_length = 16
    digest_hex = hashlib.sha256(key.encode("utf-8")).hexdigest()
    return digest_hex[:fingerprint_length]

## Tests

In [None]:
# Test encryption/decryption
encryptor = KeyEncryptor()
test_key = "sk-1234567890abcdef"
encrypted = encryptor.encrypt(test_key)
assert encrypted != test_key.encode()  # ciphertext must not be the plaintext
decrypted = encryptor.decrypt(encrypted)
assert decrypted == test_key

# Test key rotation: the re-encrypted value opens with the new key only
new_key = Fernet.generate_key()
rotated = encryptor.rotate_key(new_key, encrypted)
assert KeyEncryptor(new_key).decrypt(rotated) == test_key
try:
    encryptor.decrypt(rotated)
except EncryptionError:
    pass
else:
    raise AssertionError("decrypting with the wrong key should raise EncryptionError")
print("✓ Encryption/decryption working")

# Test key masking
masked = mask_key("sk-1234567890abcdef")
assert masked == "sk-1...cdef"
assert mask_key("abc") == "***"  # too short to reveal any characters
print(f"✓ Key masking: {masked}")

# Test fingerprinting
fp1 = get_key_fingerprint("test-key-1")
fp2 = get_key_fingerprint("test-key-2")
assert fp1 != fp2
assert len(fp1) == 16
assert get_key_fingerprint("test-key-1") == fp1  # deterministic
print(f"✓ Fingerprinting: {fp1}")

✓ Encryption/decryption working
✓ Key masking: sk-1...cdef
✓ Fingerprinting: 1255558df586ae27


In [None]:
#| hide
import nbdev
nbdev.nbdev_export()