diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
deleted file mode 100644
index 6398cc7d..00000000
--- a/.github/workflows/codeql.yml
+++ /dev/null
@@ -1,43 +0,0 @@
-name: "CodeQL Security Analysis"
-
-on:
-  push:
-    branches: [ main ]
-  pull_request:
-    branches: [ main ]
-  schedule:
-    # Run every Monday at 6:00 AM UTC
-    - cron: '0 6 * * 1'
-
-jobs:
-  analyze:
-    name: Analyze
-    runs-on: ubuntu-latest
-    permissions:
-      actions: read
-      contents: read
-      security-events: write
-
-    strategy:
-      fail-fast: false
-      matrix:
-        language: [ 'python' ]
-
-    steps:
-      - name: Checkout repository
-        uses: actions/checkout@v4
-
-      - name: Initialize CodeQL
-        uses: github/codeql-action/init@v4
-        with:
-          languages: ${{ matrix.language }}
-          # Use default queries plus security-extended
-          queries: security-extended
-
-      - name: Autobuild
-        uses: github/codeql-action/autobuild@v4
-
-      - name: Perform CodeQL Analysis
-        uses: github/codeql-action/analyze@v4
-        with:
-          category: "/language:${{ matrix.language }}"
diff --git a/LLM/interpreter.py b/LLM/interpreter.py
index 8dd87c5a..94fff128 100644
--- a/LLM/interpreter.py
+++ b/LLM/interpreter.py
@@ -1,8 +1,12 @@
 import os
 import json
-from typing import List, Optional, Dict, Any
+import sqlite3
+from typing import List, Optional, Dict, Any, TYPE_CHECKING
 from enum import Enum
 
+if TYPE_CHECKING:
+    from cortex.semantic_cache import SemanticCache
+
 
 class APIProvider(Enum):
     CLAUDE = "claude"
@@ -11,14 +15,43 @@
 
 
 class CommandInterpreter:
+    """Interprets natural language commands into executable shell commands using LLM APIs.
+
+    Supports multiple providers (OpenAI, Claude, Ollama) with optional semantic caching
+    and offline mode for cached responses.
+    """
+
     def __init__(
         self,
        api_key: str,
         provider: str = "openai",
-        model: Optional[str] = None
+        model: Optional[str] = None,
+        offline: bool = False,
+        cache: Optional["SemanticCache"] = None,
     ):
+        """Initialize the command interpreter.
+
+        Args:
+            api_key: API key for the LLM provider
+            provider: Provider name ("openai", "claude", or "ollama")
+            model: Optional model name override
+            offline: If True, only use cached responses
+            cache: Optional SemanticCache instance for response caching
+        """
         self.api_key = api_key
         self.provider = APIProvider(provider.lower())
+        self.offline = offline
+
+        if cache is None:
+            try:
+                from cortex.semantic_cache import SemanticCache
+
+                self.cache: Optional["SemanticCache"] = SemanticCache()
+            except (ImportError, OSError):
+                # Cache initialization can fail due to missing dependencies or permissions
+                self.cache = None
+        else:
+            self.cache = cache
 
         if model:
             self.model = model
@@ -173,8 +206,36 @@ def _validate_commands(self, commands: List[str]) -> List[str]:
         return validated
 
     def parse(self, user_input: str, validate: bool = True) -> List[str]:
+        """Parse natural language input into shell commands.
+
+        Args:
+            user_input: Natural language description of desired action
+            validate: If True, validate commands for dangerous patterns
+
+        Returns:
+            List of shell commands to execute
+
+        Raises:
+            ValueError: If input is empty
+            RuntimeError: If offline mode is enabled and no cached response exists
+        """
         if not user_input or not user_input.strip():
             raise ValueError("User input cannot be empty")
+
+        cache_system_prompt = self._get_system_prompt() + f"\n\n[cortex-cache-validate={bool(validate)}]"
+
+        if self.cache is not None:
+            cached = self.cache.get_commands(
+                prompt=user_input,
+                provider=self.provider.value,
+                model=self.model,
+                system_prompt=cache_system_prompt,
+            )
+            if cached is not None:
+                return cached
+
+        if self.offline:
+            raise RuntimeError("Offline mode: no cached response available for this request")
 
         if self.provider == APIProvider.OPENAI:
             commands = self._call_openai(user_input)
@@ -187,6 +248,19 @@ def parse(self, user_input: str, validate: bool = True) -> List[str]:
 
         if validate:
             commands = self._validate_commands(commands)
+
+        if self.cache is not None and commands:
+            try:
+                self.cache.put_commands(
+                    prompt=user_input,
+                    provider=self.provider.value,
+                    model=self.model,
+                    system_prompt=cache_system_prompt,
+                    commands=commands,
+                )
+            except (OSError, sqlite3.Error):
+                # Silently fail cache writes - not critical for operation
+                pass
 
         return commands
 
diff --git a/cortex/cli.py b/cortex/cli.py
index 17004c68..c6f538ad 100644
--- a/cortex/cli.py
+++ b/cortex/cli.py
@@ -48,6 +48,7 @@ def __init__(self, verbose: bool = False):
         self.spinner_idx = 0
         self.prefs_manager = None  # Lazy initialization
         self.verbose = verbose
+        self.offline = False
 
     def _debug(self, message: str):
         """Print debug info only in verbose mode"""
@@ -199,7 +200,7 @@ def install(self, software: str, execute: bool = False, dry_run: bool = False):
 
         try:
             self._print_status("🧠", "Understanding request...")
-            interpreter = CommandInterpreter(api_key=api_key, provider=provider)
+            interpreter = CommandInterpreter(api_key=api_key, provider=provider, offline=self.offline)
 
             self._print_status("📦", "Planning installation...")
 
@@ -311,6 +312,24 @@ def progress_callback(current, total, step):
             self._print_error(f"Unexpected error: {str(e)}")
             return 1
 
+    def cache_stats(self) -> int:
+        try:
+            from cortex.semantic_cache import SemanticCache
+
+            cache = SemanticCache()
+            stats = cache.stats()
+            hit_rate = f"{stats.hit_rate * 100:.1f}%" if stats.total else "0.0%"
+
+            cx_header("Cache Stats")
+            cx_print(f"Hits: {stats.hits}", "info")
+            cx_print(f"Misses: {stats.misses}", "info")
+            cx_print(f"Hit rate: {hit_rate}", "info")
+            cx_print(f"Saved calls (approx): {stats.hits}", "info")
+            return 0
+        except Exception as e:
+            self._print_error(f"Unable to read cache stats: {e}")
+            return 1
+
     def history(self, limit: int = 20, status: Optional[str] = None, show_id: Optional[str] = None):
         """Show installation history"""
         history = InstallationHistory()
@@ -544,6 +563,7 @@ def show_rich_help():
     table.add_row("history", "View history")
     table.add_row("rollback ", "Undo installation")
     table.add_row("notify", "Manage desktop notifications")  # Added this line
+    table.add_row("cache stats", "Show LLM cache statistics")
 
     console.print(table)
     console.print()
@@ -560,6 +580,7 @@ def main():
 
     # Global flags
     parser.add_argument('--version', '-V', action='version', version=f'cortex {VERSION}')
     parser.add_argument('--verbose', '-v', action='store_true', help='Show detailed output')
+    parser.add_argument('--offline', action='store_true', help='Use cached responses only (no network calls)')
 
     subparsers = parser.add_subparsers(dest='command', help='Available commands')
 
@@ -617,6 +638,11 @@
     send_parser.add_argument('--actions', nargs='*', help='Action buttons')
     # --------------------------
 
+    # Cache commands
+    cache_parser = subparsers.add_parser('cache', help='Cache operations')
+    cache_subs = cache_parser.add_subparsers(dest='cache_action', help='Cache actions')
+    cache_subs.add_parser('stats', help='Show cache statistics')
+
     args = parser.parse_args()
     if not args.command:
         return 0
@@ -624,6 +650,7 @@
 
     cli = CortexCLI(verbose=args.verbose)
+    cli.offline = bool(getattr(args, 'offline', False))
 
     try:
         if args.command == 'demo':
@@ -645,6 +672,11 @@
         # Handle the new notify command
         elif args.command == 'notify':
             return cli.notify(args)
+        elif args.command == 'cache':
+            if getattr(args, 'cache_action', None) == 'stats':
+                return cli.cache_stats()
+            parser.print_help()
+            return 1
         else:
             parser.print_help()
             return 1
diff --git a/cortex/semantic_cache.py b/cortex/semantic_cache.py
new file mode 100644
index 00000000..660d6aec
--- /dev/null
+++ b/cortex/semantic_cache.py
@@ -0,0 +1,378 @@
+"""Semantic caching for LLM responses with SQLite backend and LRU eviction.
+
+Provides semantic similarity matching for cached responses to reduce API calls
+and enable offline operation.
+"""
+
+import json
+import os
+import sqlite3
+import hashlib
+import math
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+
+@dataclass(frozen=True)
+class CacheStats:
+    """Statistics for cache performance.
+
+    Attributes:
+        hits: Number of cache hits
+        misses: Number of cache misses
+    """
+    hits: int
+    misses: int
+
+    @property
+    def total(self) -> int:
+        """Total number of cache lookups."""
+        return self.hits + self.misses
+
+    @property
+    def hit_rate(self) -> float:
+        """Cache hit rate as a fraction (0.0 to 1.0)."""
+        if self.total == 0:
+            return 0.0
+        return self.hits / self.total
+
+
+class SemanticCache:
+    """Semantic cache for LLM command responses.
+
+    Uses SQLite for persistence, simple embedding for semantic matching,
+    and LRU eviction policy for size management.
+    """
+    def __init__(
+        self,
+        db_path: str = "/var/lib/cortex/cache.db",
+        max_entries: Optional[int] = None,
+        similarity_threshold: Optional[float] = None,
+    ):
+        """Initialize semantic cache.
+
+        Args:
+            db_path: Path to SQLite database file
+            max_entries: Maximum cache entries before LRU eviction (default: 500)
+            similarity_threshold: Cosine similarity threshold for matches (default: 0.86)
+        """
+        self.db_path = db_path
+        self.max_entries = max_entries if max_entries is not None else int(os.environ.get("CORTEX_CACHE_MAX_ENTRIES", "500"))
+        self.similarity_threshold = (
+            similarity_threshold
+            if similarity_threshold is not None
+            else float(os.environ.get("CORTEX_CACHE_SIMILARITY_THRESHOLD", "0.86"))
+        )
+        self._ensure_db_directory()
+        self._init_database()
+
+    def _ensure_db_directory(self) -> None:
+        db_dir = Path(self.db_path).parent
+        try:
+            db_dir.mkdir(parents=True, exist_ok=True)
+        except PermissionError:
+            user_dir = Path.home() / ".cortex"
+            user_dir.mkdir(parents=True, exist_ok=True)
+            self.db_path = str(user_dir / "cache.db")
+
+    def _init_database(self) -> None:
+        conn = sqlite3.connect(self.db_path)
+        try:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                CREATE TABLE IF NOT EXISTS llm_cache_entries (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    provider TEXT NOT NULL,
+                    model TEXT NOT NULL,
+                    system_hash TEXT NOT NULL,
+                    prompt TEXT NOT NULL,
+                    prompt_hash TEXT NOT NULL,
+                    embedding BLOB NOT NULL,
+                    commands_json TEXT NOT NULL,
+                    created_at TEXT NOT NULL,
+                    last_accessed TEXT NOT NULL,
+                    hit_count INTEGER NOT NULL DEFAULT 0
+                )
+                """
+            )
+            cur.execute(
+                """
+                CREATE UNIQUE INDEX IF NOT EXISTS idx_llm_cache_unique
+                ON llm_cache_entries(provider, model, system_hash, prompt_hash)
+                """
+            )
+            cur.execute(
+                """
+                CREATE INDEX IF NOT EXISTS idx_llm_cache_lru
+                ON llm_cache_entries(last_accessed)
+                """
+            )
+            cur.execute(
+                """
+                CREATE TABLE IF NOT EXISTS llm_cache_stats (
+                    id INTEGER PRIMARY KEY CHECK (id = 1),
+                    hits INTEGER NOT NULL DEFAULT 0,
+                    misses INTEGER NOT NULL DEFAULT 0
+                )
+                """
+            )
+            cur.execute("INSERT OR IGNORE INTO llm_cache_stats(id, hits, misses) VALUES (1, 0, 0)")
+            conn.commit()
+        finally:
+            conn.close()
+
+    @staticmethod
+    def _utcnow_iso() -> str:
+        return datetime.utcnow().replace(microsecond=0).isoformat() + "Z"
+
+    @staticmethod
+    def _hash_text(text: str) -> str:
+        return hashlib.sha256(text.encode("utf-8")).hexdigest()
+
+    def _system_hash(self, system_prompt: str) -> str:
+        return self._hash_text(system_prompt)
+
+    @staticmethod
+    def _tokenize(text: str) -> List[str]:
+        buf: List[str] = []
+        current: List[str] = []
+        for ch in text.lower():
+            if ch.isalnum() or ch in ("-", "_", "."):
+                current.append(ch)
+            else:
+                if current:
+                    buf.append("".join(current))
+                    current = []
+        if current:
+            buf.append("".join(current))
+        return buf
+
+    @classmethod
+    def _embed(cls, text: str, dims: int = 128) -> List[float]:
+        vec = [0.0] * dims
+        tokens = cls._tokenize(text)
+        if not tokens:
+            return vec
+
+        for token in tokens:
+            h = hashlib.blake2b(token.encode("utf-8"), digest_size=8).digest()
+            value = int.from_bytes(h, "big", signed=False)
+            idx = value % dims
+            sign = -1.0 if (value >> 63) & 1 else 1.0
+            vec[idx] += sign
+
+        norm = math.sqrt(sum(v * v for v in vec))
+        if norm > 0:
+            vec = [v / norm for v in vec]
+        return vec
+
+    @staticmethod
+    def _pack_embedding(vec: List[float]) -> bytes:
+        return json.dumps(vec, separators=(",", ":")).encode("utf-8")
+
+    @staticmethod
+    def _unpack_embedding(blob: bytes) -> List[float]:
+        return json.loads(blob.decode("utf-8"))
+
+    @staticmethod
+    def _cosine(a: List[float], b: List[float]) -> float:
+        if not a or not b or len(a) != len(b):
+            return 0.0
+        dot = 0.0
+        for i in range(len(a)):
+            dot += a[i] * b[i]
+        return dot
+
+    def _record_hit(self, conn: sqlite3.Connection) -> None:
+        conn.execute("UPDATE llm_cache_stats SET hits = hits + 1 WHERE id = 1")
+
+    def _record_miss(self, conn: sqlite3.Connection) -> None:
+        conn.execute("UPDATE llm_cache_stats SET misses = misses + 1 WHERE id = 1")
+
+    def get_commands(
+        self,
+        prompt: str,
+        provider: str,
+        model: str,
+        system_prompt: str,
+        candidate_limit: int = 200,
+    ) -> Optional[List[str]]:
+        """Retrieve cached commands for a prompt.
+
+        First tries exact match, then falls back to semantic similarity search.
+
+        Args:
+            prompt: User's natural language request
+            provider: LLM provider name
+            model: Model name
+            system_prompt: System prompt used for generation
+            candidate_limit: Max candidates to check for similarity
+
+        Returns:
+            List of commands if found, None otherwise
+        """
+        system_hash = self._system_hash(system_prompt)
+        prompt_hash = self._hash_text(prompt)
+        now = self._utcnow_iso()
+
+        conn = sqlite3.connect(self.db_path)
+        try:
+            cur = conn.cursor()
+            cur.execute(
+                """
+                SELECT id, commands_json
+                FROM llm_cache_entries
+                WHERE provider = ? AND model = ? AND system_hash = ? AND prompt_hash = ?
+                LIMIT 1
+                """,
+                (provider, model, system_hash, prompt_hash),
+            )
+            row = cur.fetchone()
+            if row is not None:
+                entry_id, commands_json = row
+                cur.execute(
+                    """
+                    UPDATE llm_cache_entries
+                    SET last_accessed = ?, hit_count = hit_count + 1
+                    WHERE id = ?
+                    """,
+                    (now, entry_id),
+                )
+                self._record_hit(conn)
+                conn.commit()
+                return json.loads(commands_json)
+
+            query_vec = self._embed(prompt)
+
+            cur.execute(
+                """
+                SELECT id, embedding, commands_json
+                FROM llm_cache_entries
+                WHERE provider = ? AND model = ? AND system_hash = ?
+                ORDER BY last_accessed DESC
+                LIMIT ?
+                """,
+                (provider, model, system_hash, candidate_limit),
+            )
+
+            best: Optional[Tuple[int, float, str]] = None
+            for entry_id, embedding_blob, commands_json in cur.fetchall():
+                vec = self._unpack_embedding(embedding_blob)
+                sim = self._cosine(query_vec, vec)
+                if best is None or sim > best[1]:
+                    best = (entry_id, sim, commands_json)
+
+            if best is not None and best[1] >= self.similarity_threshold:
+                cur.execute(
+                    """
+                    UPDATE llm_cache_entries
+                    SET last_accessed = ?, hit_count = hit_count + 1
+                    WHERE id = ?
+                    """,
+                    (now, best[0]),
+                )
+                self._record_hit(conn)
+                conn.commit()
+                return json.loads(best[2])
+
+            self._record_miss(conn)
+            conn.commit()
+            return None
+        finally:
+            conn.close()
+
+    def put_commands(
+        self,
+        prompt: str,
+        provider: str,
+        model: str,
+        system_prompt: str,
+        commands: List[str],
+    ) -> None:
+        """Store commands in cache for future retrieval.
+
+        Args:
+            prompt: User's natural language request
+            provider: LLM provider name
+            model: Model name
+            system_prompt: System prompt used for generation
+            commands: List of shell commands to cache
+        """
+        system_hash = self._system_hash(system_prompt)
+        prompt_hash = self._hash_text(prompt)
+        now = self._utcnow_iso()
+        vec = self._embed(prompt)
+        embedding_blob = self._pack_embedding(vec)
+
+        conn = sqlite3.connect(self.db_path)
+        try:
+            conn.execute(
+                """
+                INSERT OR REPLACE INTO llm_cache_entries(
+                    provider, model, system_hash, prompt, prompt_hash, embedding, commands_json,
+                    created_at, last_accessed, hit_count
+                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, COALESCE((
+                    SELECT hit_count FROM llm_cache_entries
+                    WHERE provider = ? AND model = ? AND system_hash = ? AND prompt_hash = ?
+                ), 0))
+                """,
+                (
+                    provider,
+                    model,
+                    system_hash,
+                    prompt,
+                    prompt_hash,
+                    embedding_blob,
+                    json.dumps(commands, separators=(",", ":")),
+                    now,
+                    now,
+                    provider,
+                    model,
+                    system_hash,
+                    prompt_hash,
+                ),
+            )
+            self._evict_if_needed(conn)
+            conn.commit()
+        finally:
+            conn.close()
+
+    def _evict_if_needed(self, conn: sqlite3.Connection) -> None:
+        cur = conn.cursor()
+        cur.execute("SELECT COUNT(1) FROM llm_cache_entries")
+        count = int(cur.fetchone()[0])
+        if count <= self.max_entries:
+            return
+
+        to_delete = count - self.max_entries
+        cur.execute(
+            """
+            DELETE FROM llm_cache_entries
+            WHERE id IN (
+                SELECT id FROM llm_cache_entries
+                ORDER BY last_accessed ASC
+                LIMIT ?
+            )
+            """,
+            (to_delete,),
+        )
+
+    def stats(self) -> CacheStats:
+        """Get current cache statistics.
+
+        Returns:
+            CacheStats object with hits, misses, and computed metrics
+        """
+        conn = sqlite3.connect(self.db_path)
+        try:
+            cur = conn.cursor()
+            cur.execute("SELECT hits, misses FROM llm_cache_stats WHERE id = 1")
+            row = cur.fetchone()
+            if row is None:
+                return CacheStats(hits=0, misses=0)
+            return CacheStats(hits=int(row[0]), misses=int(row[1]))
+        finally:
+            conn.close()
diff --git a/docs/ISSUE-268-TESTING.md b/docs/ISSUE-268-TESTING.md
new file mode 100644
index 00000000..c40b91d5
--- /dev/null
+++ b/docs/ISSUE-268-TESTING.md
@@ -0,0 +1,75 @@
+# Issue 268 — End-user testing guide (semantic cache + offline)
+
+This guide covers only how to test the feature added for issue #268.
+
+## Prereqs
+
+- Python 3.10+
+- Project installed in editable mode
+
+```bash
+python -m venv venv
+. venv/bin/activate
+pip install -e .
+```
+
+## Test 1: Warm the cache (online)
+
+Run a request once with an API key configured.
+
+```bash
+export OPENAI_API_KEY=sk-...
+# or
+export ANTHROPIC_API_KEY=sk-ant-...
+
+cortex install nginx --dry-run
+```
+
+Expected:
+- It prints generated commands.
+- This run should create/update the cache database.
+
+## Test 2: Check cache stats
+
+```bash
+cortex cache stats
+```
+
+Expected:
+- `Hits` is >= 0
+- `Misses` is >= 1 (the Test 1 request missed the cache before being stored)
+- `Saved calls (approx)` increases when cached answers are used
+
+## Test 3: Offline mode (cached-only)
+
+Run the same request with offline mode enabled.
+
+```bash
+cortex --offline install nginx --dry-run
+```
+
+Expected:
+- If the request was warmed in Test 1, it should still print commands.
+- If the request was never cached, it should fail with an offline-cache-miss message.
+
+## Test 4: Verify cache hit (repeat request)
+
+Run the original request again to verify the cache is working:
+
+```bash
+cortex install nginx --dry-run
+cortex cache stats
+```
+
+Expected:
+- The second run should be faster (no API call is made)
+- `cache stats` should show the `Hits` count increase by at least 1
+
+## Notes
+
+- Cache location defaults to `/var/lib/cortex/cache.db` and falls back to `~/.cortex/cache.db` if permissions don't allow system paths.
+- Cache size and similarity threshold can be tuned with:
+  - `CORTEX_CACHE_MAX_ENTRIES` (default: 500)
+  - `CORTEX_CACHE_SIMILARITY_THRESHOLD` (default: 0.86)
+- Cache is provider+model specific, so switching providers will cause a cache miss.
+- The cache uses semantic similarity matching, so slightly different wording may still return cached results.
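To sanity-check the cache layer described in the testing guide without going through the CLI, the sketch below drives `SemanticCache` directly against a throwaway database. The prompts, provider/model strings, and temp path are illustrative only; the reworded lookup succeeds only when its hashed-token embedding clears the similarity threshold (the stored vectors are L2-normalized, which is why `_cosine` can return the raw dot product).

```python
import tempfile
from pathlib import Path

from cortex.semantic_cache import SemanticCache

# Throwaway database so the system/user cache is left untouched.
db_path = str(Path(tempfile.mkdtemp()) / "scratch_cache.db")
cache = SemanticCache(db_path=db_path, similarity_threshold=0.86)

# Store a response the way the interpreter does after a successful LLM call.
cache.put_commands(
    prompt="install nginx web server",
    provider="openai",
    model="gpt-4",
    system_prompt="example system prompt",
    commands=["sudo apt update", "sudo apt install -y nginx"],
)

# A reworded prompt can still hit via the embedding similarity path.
hit = cache.get_commands(
    prompt="please install nginx web server",
    provider="openai",
    model="gpt-4",
    system_prompt="example system prompt",
)
print("commands:", hit)

stats = cache.stats()
print(f"hits={stats.hits} misses={stats.misses} hit_rate={stats.hit_rate:.2f}")
```

Because entries are keyed by provider, model, and system-prompt hash, changing any of those produces a miss even for an identical prompt.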
diff --git a/tests/test_semantic_cache.py b/tests/test_semantic_cache.py
new file mode 100644
index 00000000..489f3867
--- /dev/null
+++ b/tests/test_semantic_cache.py
@@ -0,0 +1,218 @@
+"""Unit tests for semantic cache functionality."""
+
+import os
+import sqlite3
+import tempfile
+import unittest
+from pathlib import Path
+
+from cortex.semantic_cache import SemanticCache, CacheStats
+
+
+class TestSemanticCache(unittest.TestCase):
+    """Test suite for SemanticCache."""
+
+    def setUp(self):
+        """Create temporary database for testing."""
+        self.temp_dir = tempfile.mkdtemp()
+        self.db_path = os.path.join(self.temp_dir, "test_cache.db")
+        self.cache = SemanticCache(
+            db_path=self.db_path,
+            max_entries=10,
+            similarity_threshold=0.85
+        )
+
+    def tearDown(self):
+        """Clean up temporary files."""
+        import shutil
+        if os.path.exists(self.temp_dir):
+            shutil.rmtree(self.temp_dir)
+
+    def test_cache_initialization(self):
+        """Test that cache database is created properly."""
+        self.assertTrue(os.path.exists(self.db_path))
+        conn = sqlite3.connect(self.db_path)
+        cur = conn.cursor()
+        cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
+        tables = [row[0] for row in cur.fetchall()]
+        conn.close()
+
+        self.assertIn("llm_cache_entries", tables)
+        self.assertIn("llm_cache_stats", tables)
+
+    def test_cache_stats_initial(self):
+        """Test initial cache stats are zero."""
+        stats = self.cache.stats()
+        self.assertEqual(stats.hits, 0)
+        self.assertEqual(stats.misses, 0)
+        self.assertEqual(stats.total, 0)
+        self.assertEqual(stats.hit_rate, 0.0)
+
+    def test_put_and_get_exact_match(self):
+        """Test storing and retrieving with exact prompt match."""
+        prompt = "install nginx"
+        commands = ["sudo apt update", "sudo apt install -y nginx"]
+
+        # Store in cache
+        self.cache.put_commands(
+            prompt=prompt,
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test system prompt",
+            commands=commands
+        )
+
+        # Retrieve from cache
+        retrieved = self.cache.get_commands(
+            prompt=prompt,
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test system prompt"
+        )
+
+        self.assertEqual(retrieved, commands)
+
+        # Check stats
+        stats = self.cache.stats()
+        self.assertEqual(stats.hits, 1)
+        self.assertEqual(stats.misses, 0)
+
+    def test_cache_miss(self):
+        """Test cache miss with non-existent prompt."""
+        result = self.cache.get_commands(
+            prompt="install something that was never cached",
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test system prompt"
+        )
+
+        self.assertIsNone(result)
+
+        stats = self.cache.stats()
+        self.assertEqual(stats.hits, 0)
+        self.assertEqual(stats.misses, 1)
+
+    def test_semantic_similarity_match(self):
+        """Test semantic similarity matching for similar prompts."""
+        # Store original
+        self.cache.put_commands(
+            prompt="install nginx web server",
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test system prompt",
+            commands=["sudo apt install nginx"]
+        )
+
+        # Try similar wording (not an exact match, so the embedding path is used)
+        result = self.cache.get_commands(
+            prompt="please install nginx web server",
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test system prompt"
+        )
+
+        # Should find the semantically similar entry
+        self.assertIsNotNone(result)
+        self.assertEqual(result, ["sudo apt install nginx"])
+
+    def test_provider_isolation(self):
+        """Test that different providers don't share cache entries."""
+        prompt = "install docker"
+        commands_openai = ["apt install docker"]
+        commands_claude = ["apt install docker-ce"]
+
+        # Store with OpenAI
+        self.cache.put_commands(
+            prompt=prompt,
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test",
+            commands=commands_openai
+        )
+
+        # Store with Claude
+        self.cache.put_commands(
+            prompt=prompt,
+            provider="claude",
+            model="claude-3",
+            system_prompt="test",
+            commands=commands_claude
+        )
+
+        # Retrieve for OpenAI
+        result_openai = self.cache.get_commands(
+            prompt=prompt,
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test"
+        )
+
+        # Retrieve for Claude
+        result_claude = self.cache.get_commands(
+            prompt=prompt,
+            provider="claude",
+            model="claude-3",
+            system_prompt="test"
+        )
+
+        self.assertEqual(result_openai, commands_openai)
+        self.assertEqual(result_claude, commands_claude)
+
+    def test_lru_eviction(self):
+        """Test that LRU eviction works when max_entries is exceeded."""
+        # Fill cache to max
+        for i in range(10):
+            self.cache.put_commands(
+                prompt=f"install package{i}",
+                provider="openai",
+                model="gpt-4",
+                system_prompt="test",
+                commands=[f"apt install package{i}"]
+            )
+
+        # Add one more (should trigger eviction)
+        self.cache.put_commands(
+            prompt="install package10",
+            provider="openai",
+            model="gpt-4",
+            system_prompt="test",
+            commands=["apt install package10"]
+        )
+
+        # Verify cache size doesn't exceed max
+        conn = sqlite3.connect(self.db_path)
+        cur = conn.cursor()
+        cur.execute("SELECT COUNT(*) FROM llm_cache_entries")
+        count = cur.fetchone()[0]
+        conn.close()
+
+        self.assertEqual(count, 10)
+
+    def test_embedding_generation(self):
+        """Test that embeddings are generated correctly."""
+        vec = SemanticCache._embed("test prompt")
+
+        self.assertEqual(len(vec), 128)
+        self.assertIsInstance(vec[0], float)
+
+        # Check normalization (L2 norm should be ~1.0)
+        norm = sum(v * v for v in vec) ** 0.5
+        self.assertAlmostEqual(norm, 1.0, places=5)
+
+    def test_cosine_similarity(self):
+        """Test cosine similarity calculation."""
+        vec1 = [1.0, 0.0, 0.0]
+        vec2 = [1.0, 0.0, 0.0]
+        vec3 = [0.0, 1.0, 0.0]
+
+        # Identical vectors
+        sim1 = SemanticCache._cosine(vec1, vec2)
+        self.assertAlmostEqual(sim1, 1.0)
+
+        # Orthogonal vectors
+        sim2 = SemanticCache._cosine(vec1, vec3)
+        self.assertAlmostEqual(sim2, 0.0)
+
+
+if __name__ == "__main__":
+    unittest.main()
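The test module above covers the cache in isolation; a natural follow-up (not part of this diff) is an interpreter-level check of the offline path. The sketch below assumes the module is importable as `LLM.interpreter`, that `CommandInterpreter` falls back to a provider-default model when none is passed, and that the dummy API key is never used because `parse()` raises before any network call on an offline cache miss.

```python
"""Sketch of an interpreter-level offline-mode test (assumed layout, not in this diff)."""

import os
import tempfile
import unittest

from LLM.interpreter import CommandInterpreter
from cortex.semantic_cache import SemanticCache


class TestOfflineMode(unittest.TestCase):
    def test_offline_cache_miss_raises(self):
        # Empty cache in a temp location, so no entry can satisfy the lookup.
        cache = SemanticCache(db_path=os.path.join(tempfile.mkdtemp(), "cache.db"))

        interpreter = CommandInterpreter(
            api_key="dummy-key",  # assumed unused: offline mode short-circuits before any API call
            provider="openai",
            offline=True,
            cache=cache,
        )

        # With nothing cached, offline mode must fail fast instead of calling the provider.
        with self.assertRaises(RuntimeError):
            interpreter.parse("install nginx")


if __name__ == "__main__":
    unittest.main()
```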