In [1]:
import heapq
import sqlite3
from abc import ABC, abstractmethod
from contextlib import closing
from io import BytesIO
from typing import Dict, List, Optional, Union

import numpy as np
import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

In [2]:
class BaseEmbeddingsReader(ABC):
    """Abstract base class for readers of a SQLite store of article embeddings.

    The base class loads a CLIP model/processor pair and provides the
    database queries shared by all readers. Subclasses supply the
    modality-specific ``get_embedding`` and ``get_similar_articles``.

    The queries read two tables: ``embeddings`` (columns ``article_id``,
    ``title``, ``url``, ``embedding``, ``processed_date``) and
    ``failed_articles``. Embedding blobs are decoded as raw ``float32``.
    """

    def __init__(self,
                 db_path: str = "wiki_embeddings.db",
                 model_name: str = "openai/clip-vit-base-patch32"):
        """Load the model and processor and remember the database path.

        Args:
            db_path: Path handed to ``sqlite3.connect`` for every query.
            model_name: Identifier passed to ``from_pretrained`` for both
                the CLIP model and its processor.
        """
        self.db_path = db_path
        self.device = self._get_device()
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _get_device() -> str:
        """Return ``"cuda"``, ``"mps"`` or ``"cpu"``, in that order of preference."""
        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds have no ``torch.backends.mps`` attribute at all,
        # so look it up defensively instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def get_article_by_title(self, title: str) -> Optional[Dict]:
        """Look up a single article by exact title.

        Args:
            title: Title to match with ``WHERE title = ?``.

        Returns:
            A dict with keys ``article_id``, ``title``, ``url``,
            ``embedding`` (a read-only float32 array over the stored blob)
            and ``processed_date``; ``None`` when no row matches. If several
            rows share the title, the first one returned by SQLite is used.
        """
        # sqlite3's own context manager only commits/rolls back; ``closing``
        # is what actually releases the connection.
        with closing(sqlite3.connect(self.db_path)) as conn:
            row = conn.execute("""
                SELECT article_id, title, url, embedding, processed_date
                FROM embeddings 
                WHERE title = ?
            """, (title,)).fetchone()

        if row is None:
            return None
        return {
            'article_id': row[0],
            'title': row[1],
            'url': row[2],
            'embedding': np.frombuffer(row[3], dtype=np.float32),
            'processed_date': row[4]
        }

    def get_database_stats(self) -> Dict:
        """Collect summary statistics about the database.

        Returns:
            A dict with ``total_articles`` (row count of ``embeddings``),
            ``failed_articles`` (row count of ``failed_articles``) and, when
            ``embeddings`` is non-empty, ``most_recent_article`` holding the
            ``title`` and ``date`` of the row with the greatest
            ``processed_date``.

        Raises:
            sqlite3.OperationalError: If either table is missing.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            stats = {}

            stats['total_articles'] = conn.execute(
                "SELECT COUNT(*) FROM embeddings"
            ).fetchone()[0]

            newest = conn.execute("""
                SELECT title, processed_date 
                FROM embeddings 
                ORDER BY processed_date DESC 
                LIMIT 1
            """).fetchone()
            if newest:
                stats['most_recent_article'] = {
                    'title': newest[0],
                    'date': newest[1]
                }

            stats['failed_articles'] = conn.execute(
                "SELECT COUNT(*) FROM failed_articles"
            ).fetchone()[0]

            return stats

    def _find_similar_articles(self, query_embedding: np.ndarray, limit: int) -> List[Dict]:
        """Rank stored articles by dot product with the query embedding.

        Args:
            query_embedding: Either a 1-D vector or a 2-D array whose first
                row is the query; any further rows are ignored.
            limit: Maximum number of results. Values ``<= 0`` yield ``[]``.

        Returns:
            Up to ``limit`` dicts (``article_id``, ``title``, ``url``,
            ``similarity``) ordered by descending similarity.

        Note:
            The score is a plain dot product; it equals cosine similarity
            only if the stored vectors are unit-normalised -- TODO confirm
            against the code that populates the table.
        """
        if limit <= 0:
            # A negative slice bound used to silently drop results from the
            # tail; treat any non-positive limit as "nothing requested".
            return []

        # Accept both (dim,) and (1, dim) inputs while still using row 0.
        query_vector = np.atleast_2d(query_embedding)[0]

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute("""
                SELECT article_id, title, url, embedding
                FROM embeddings
            """)

            scored = (
                {
                    'article_id': row[0],
                    'title': row[1],
                    'url': row[2],
                    'similarity': float(np.dot(query_vector, np.frombuffer(row[3], dtype=np.float32)))
                }
                for row in cursor
            )
            # nlargest streams the cursor and keeps only ``limit`` candidates,
            # with the same ordering as a stable descending sort + slice.
            # It must run while the connection is still open.
            return heapq.nlargest(limit, scored, key=lambda item: item['similarity'])

    @abstractmethod
    def get_embedding(self, input_data: str) -> torch.Tensor:
        """Generate an embedding for the input data (text, path or URL)."""
        pass

    @abstractmethod
    def get_similar_articles(self, input_data: str, limit: int = 5) -> List[Dict]:
        """Return up to ``limit`` stored articles most similar to the input."""
        pass

class TextMatcher(BaseEmbeddingsReader):
    """Matches stored articles against a free-text query."""

    def get_embedding(self, input_text: str) -> torch.Tensor:
        """Embed ``input_text`` with the model's text features.

        The text is passed to the processor as a one-element batch with
        padding and truncation enabled (``max_length=77``). The resulting
        features are divided by their norm along the last dimension.

        Returns:
            The normalised text features as produced on ``self.device``.
        """
        with torch.no_grad():
            encoded = self.processor(
                text=[input_text],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77
            )
            # Move every processor output onto the model's device.
            model_inputs = {}
            for name, tensor in encoded.items():
                model_inputs[name] = tensor.to(self.device)

            features = self.model.get_text_features(**model_inputs)
            scale = features.norm(dim=-1, keepdim=True)
            return features / scale

    def get_similar_articles(self, query: str, limit: int = 5) -> List[Dict]:
        """Return up to ``limit`` stored articles ranked against ``query``."""
        embedding = self.get_embedding(query)
        # The database search works on NumPy arrays held in host memory.
        as_array = embedding.cpu().numpy()
        return self._find_similar_articles(as_array, limit)

class ImageMatcher(BaseEmbeddingsReader):
    """Matches stored articles against an image given by path or URL."""

    def _load_image(self, image_path: str) -> Image.Image:
        """Open an image from a local path or an ``http(s)://`` URL.

        Args:
            image_path: Local filesystem path, or a URL starting with
                ``http://`` or ``https://`` (fetched with a 10 s timeout).

        Returns:
            The image as returned by ``Image.open``.

        Raises:
            Exception: On any download or open failure; the original error
                is attached as ``__cause__``.
        """
        try:
            if image_path.startswith(('http://', 'https://')):
                response = requests.get(image_path, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            return Image.open(image_path)
        # RequestException is already an Exception subclass, so a single
        # clause covers both network and file errors.
        except Exception as e:
            raise Exception(f"Error loading image: {str(e)}") from e

    def get_embedding(self, image_path: str) -> torch.Tensor:
        """Embed the image at ``image_path`` with the model's image features.

        The features are divided by their norm along the last dimension.

        Raises:
            Exception: If loading or encoding fails; the underlying error is
                chained as ``__cause__`` so the full traceback is preserved.
        """
        try:
            image = self._load_image(image_path)
            with torch.no_grad():
                inputs = self.processor(
                    images=image,
                    return_tensors="pt"
                )
                image_features = self.model.get_image_features(
                    **{k: v.to(self.device) for k, v in inputs.items()}
                )
                return image_features / image_features.norm(dim=-1, keepdim=True)
        except Exception as e:
            raise Exception(f"Error processing image: {str(e)}") from e

    def get_similar_articles(self, image_path: str, limit: int = 5) -> List[Dict]:
        """Return up to ``limit`` stored articles ranked against the image."""
        query_embedding = self.get_embedding(image_path).cpu().numpy()
        return self._find_similar_articles(query_embedding, limit)

In [None]:
# Build one reader per modality (each constructor loads its own model and
# processor), then print the database summary.
text_reader, image_reader = TextMatcher(), ImageMatcher()
db_stats = text_reader.get_database_stats()
print(db_stats)


In [None]:
# Text-based search (uncomment to run):
# similar_articles = text_reader.get_similar_articles("ficus audrey")
# for article in similar_articles:
#     print(article['title'])

# Image-based search: the 20 stored articles closest to a local image.
similar_articles = image_reader.get_similar_articles("/Users/clkruse/Downloads/IMG_6984.jpeg", 20)
# A plain loop rather than a throwaway list comprehension: printing is a
# side effect, and the loop needs no trailing ';' to hide a list of Nones.
for article in similar_articles:
    print(f"{article['title']}, {article['url']}, {article['similarity']:.3f}")

In [None]:
class BaseEmbeddingsReader(ABC):
    """Abstract base class for readers of a SQLite store of article embeddings.

    The base class loads a CLIP model/processor pair and provides the
    database queries shared by all readers. Subclasses supply the
    modality-specific ``get_embedding`` and ``get_similar_articles``.

    The queries read two tables: ``embeddings`` (columns ``article_id``,
    ``title``, ``url``, ``embedding``, ``processed_date``) and
    ``failed_articles``. Embedding blobs are decoded as raw ``float32``.
    """

    def __init__(self,
                 db_path: str = "wiki_embeddings.db",
                 model_name: str = "openai/clip-vit-base-patch32"):
        """Load the model and processor and remember the database path.

        Args:
            db_path: Path handed to ``sqlite3.connect`` for every query.
            model_name: Identifier passed to ``from_pretrained`` for both
                the CLIP model and its processor.
        """
        self.db_path = db_path
        self.device = self._get_device()
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _get_device() -> str:
        """Return ``"cuda"``, ``"mps"`` or ``"cpu"``, in that order of preference."""
        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds have no ``torch.backends.mps`` attribute at all,
        # so look it up defensively instead of raising AttributeError.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def get_article_by_title(self, title: str) -> Optional[Dict]:
        """Look up a single article by exact title.

        Args:
            title: Title to match with ``WHERE title = ?``.

        Returns:
            A dict with keys ``article_id``, ``title``, ``url``,
            ``embedding`` (a read-only float32 array over the stored blob)
            and ``processed_date``; ``None`` when no row matches. If several
            rows share the title, the first one returned by SQLite is used.
        """
        # sqlite3's own context manager only commits/rolls back; ``closing``
        # is what actually releases the connection.
        with closing(sqlite3.connect(self.db_path)) as conn:
            row = conn.execute("""
                SELECT article_id, title, url, embedding, processed_date
                FROM embeddings 
                WHERE title = ?
            """, (title,)).fetchone()

        if row is None:
            return None
        return {
            'article_id': row[0],
            'title': row[1],
            'url': row[2],
            'embedding': np.frombuffer(row[3], dtype=np.float32),
            'processed_date': row[4]
        }

    def get_database_stats(self) -> Dict:
        """Collect summary statistics about the database.

        Returns:
            A dict with ``total_articles`` (row count of ``embeddings``),
            ``failed_articles`` (row count of ``failed_articles``) and, when
            ``embeddings`` is non-empty, ``most_recent_article`` holding the
            ``title`` and ``date`` of the row with the greatest
            ``processed_date``.

        Raises:
            sqlite3.OperationalError: If either table is missing.
        """
        with closing(sqlite3.connect(self.db_path)) as conn:
            stats = {}

            stats['total_articles'] = conn.execute(
                "SELECT COUNT(*) FROM embeddings"
            ).fetchone()[0]

            newest = conn.execute("""
                SELECT title, processed_date 
                FROM embeddings 
                ORDER BY processed_date DESC 
                LIMIT 1
            """).fetchone()
            if newest:
                stats['most_recent_article'] = {
                    'title': newest[0],
                    'date': newest[1]
                }

            stats['failed_articles'] = conn.execute(
                "SELECT COUNT(*) FROM failed_articles"
            ).fetchone()[0]

            return stats

    def _find_similar_articles(self, query_embedding: np.ndarray, limit: int) -> List[Dict]:
        """Rank stored articles by dot product with the query embedding.

        Args:
            query_embedding: Either a 1-D vector or a 2-D array whose first
                row is the query; any further rows are ignored.
            limit: Maximum number of results. Values ``<= 0`` yield ``[]``.

        Returns:
            Up to ``limit`` dicts (``article_id``, ``title``, ``url``,
            ``similarity``) ordered by descending similarity.

        Note:
            The score is a plain dot product; it equals cosine similarity
            only if the stored vectors are unit-normalised -- TODO confirm
            against the code that populates the table.
        """
        if limit <= 0:
            # A negative slice bound used to silently drop results from the
            # tail; treat any non-positive limit as "nothing requested".
            return []

        # Accept both (dim,) and (1, dim) inputs while still using row 0.
        query_vector = np.atleast_2d(query_embedding)[0]

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute("""
                SELECT article_id, title, url, embedding
                FROM embeddings
            """)

            scored = (
                {
                    'article_id': row[0],
                    'title': row[1],
                    'url': row[2],
                    'similarity': float(np.dot(query_vector, np.frombuffer(row[3], dtype=np.float32)))
                }
                for row in cursor
            )
            # nlargest streams the cursor and keeps only ``limit`` candidates,
            # with the same ordering as a stable descending sort + slice.
            # It must run while the connection is still open.
            return heapq.nlargest(limit, scored, key=lambda item: item['similarity'])

    @abstractmethod
    def get_embedding(self, input_data: str) -> torch.Tensor:
        """Generate an embedding for the input data (text, path or URL)."""
        pass

    @abstractmethod
    def get_similar_articles(self, input_data: str, limit: int = 5) -> List[Dict]:
        """Return up to ``limit`` stored articles most similar to the input."""
        pass

class ImageMatcher(BaseEmbeddingsReader):
    """Matches stored articles against an image given by path or URL."""

    def _load_image(self, image_path: str) -> Image.Image:
        """Open an image from a local path or an ``http(s)://`` URL.

        Args:
            image_path: Local filesystem path, or a URL starting with
                ``http://`` or ``https://`` (fetched with a 10 s timeout).

        Returns:
            The image as returned by ``Image.open``.

        Raises:
            Exception: On any download or open failure; the original error
                is attached as ``__cause__``.
        """
        try:
            if image_path.startswith(('http://', 'https://')):
                response = requests.get(image_path, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            return Image.open(image_path)
        # RequestException is already an Exception subclass, so a single
        # clause covers both network and file errors.
        except Exception as e:
            raise Exception(f"Error loading image: {str(e)}") from e

    def get_embedding(self, image_path: str) -> torch.Tensor:
        """Embed the image at ``image_path`` with the model's image features.

        The features are divided by their norm along the last dimension.

        Raises:
            Exception: If loading or encoding fails; the underlying error is
                chained as ``__cause__`` so the full traceback is preserved.
        """
        try:
            image = self._load_image(image_path)
            with torch.no_grad():
                inputs = self.processor(
                    images=image,
                    return_tensors="pt"
                )
                image_features = self.model.get_image_features(
                    **{k: v.to(self.device) for k, v in inputs.items()}
                )
                return image_features / image_features.norm(dim=-1, keepdim=True)
        except Exception as e:
            raise Exception(f"Error processing image: {str(e)}") from e

    def get_similar_articles(self, image_path: str, limit: int = 5) -> List[Dict]:
        """Return up to ``limit`` stored articles ranked against the image."""
        query_embedding = self.get_embedding(image_path).cpu().numpy()
        return self._find_similar_articles(query_embedding, limit)