In [17]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import wikipedia
import torch
from typing import List, Tuple, Union, Optional
import re

In [20]:
class WikiImageMatcher:
    """
    Rank Wikipedia articles against an image using CLIP embeddings.

    Article text and images are both embedded with the same CLIP model,
    L2-normalised, and compared with cosine similarity.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        Initialize the WikiImageMatcher with a CLIP model from Hugging Face.

        Args:
            model_name (str): Name of the CLIP model from Hugging Face
        """
        self.device = self._select_device()
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)

    @staticmethod
    def _select_device() -> str:
        """Pick the torch device: "cuda" if available, then "mps", else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        # torch builds without the MPS backend have no `torch.backends.mps`
        # attribute, so look it up defensively instead of assuming it exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def clean_wiki_text(self, text: str) -> str:
        """
        Clean Wikipedia text by removing numeric reference markers and
        collapsing whitespace.

        Args:
            text (str): Raw Wikipedia text

        Returns:
            str: Text with markers such as ``[12]`` removed and every run of
                whitespace (including newlines) replaced by a single space
        """
        # Remove numeric reference markers like [1] or [42]
        text = re.sub(r'\[\d+\]', '', text)
        # split() with no argument splits on any whitespace, newlines included,
        # so this both flattens line breaks and squeezes repeated spaces.
        return ' '.join(text.split())

    def find_wiki_article(self, query: str) -> Optional["wikipedia.WikipediaPage"]:
        """
        Find a Wikipedia article using multiple search strategies.

        The strategies are tried in order and the first page that resolves is
        returned:

        1. Exact title lookup (``auto_suggest=False``).
        2. Lookup with ``auto_suggest=True``; if that lands on a
           disambiguation page, the first listed option is tried.
        3. Full-text search; each of up to five results is tried in order.

        Args:
            query (str): Article title or search query

        Returns:
            Optional[wikipedia.WikipediaPage]: Wikipedia page if found, None otherwise.
                Unexpected errors are printed and also yield None.
        """
        try:
            # Strategy 1: Try direct page lookup
            try:
                return wikipedia.page(query, auto_suggest=False)
            except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError):
                pass

            # Strategy 2: Try with auto_suggest
            try:
                return wikipedia.page(query, auto_suggest=True)
            except wikipedia.exceptions.DisambiguationError as disambiguation:
                # If disambiguation page, try the first suggestion
                if disambiguation.options:
                    try:
                        return wikipedia.page(disambiguation.options[0], auto_suggest=False)
                    except Exception:
                        # Best effort: fall through to the search strategy.
                        pass
            except wikipedia.exceptions.PageError:
                pass

            # Strategy 3: Search and use the first result that resolves to a page
            for candidate_title in wikipedia.search(query, results=5):
                try:
                    return wikipedia.page(candidate_title, auto_suggest=False)
                except Exception:
                    # This hit could not be loaded (missing or ambiguous);
                    # move on to the next search result.
                    continue

            # If all strategies fail
            return None

        except Exception as e:
            print(f"Error searching for '{query}': {str(e)}")
            return None

    def get_article_embedding(self, article_name: str) -> Tuple[Optional[torch.Tensor], Optional[str]]:
        """
        Get CLIP embedding for a Wikipedia article.

        The cleaned article text is tokenized with ``truncation=True`` and
        ``max_length=77``, so only the beginning of the article contributes
        to the embedding.

        Args:
            article_name (str): Name of the Wikipedia article

        Returns:
            Tuple[Optional[torch.Tensor], Optional[str]]:
                (L2-normalised embedding, actual article title) if successful,
                (None, None) if the article was not found or processing failed
                (a message is printed in both cases)
        """
        try:
            # Find the article
            page = self.find_wiki_article(article_name)
            if page is None:
                print(f"Could not find article: {article_name}")
                return None, None

            # Get and clean content
            content = self.clean_wiki_text(page.content)

            # Process text through CLIP; no gradients are needed for inference
            with torch.no_grad():
                inputs = self.processor(
                    text=[content],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=77
                )
                text_features = self.model.get_text_features(**{k: v.to(self.device) for k, v in inputs.items()})
                # Scale to unit length along the feature dimension
                normalized_features = text_features / text_features.norm(dim=-1, keepdim=True)

            return normalized_features, page.title

        except Exception as e:
            print(f"Error processing article '{article_name}': {str(e)}")
            return None, None

    def get_image_embedding(self, image: Union[str, "Image.Image"]) -> torch.Tensor:
        """
        Get CLIP embedding for an image.

        Args:
            image (Union[str, Image.Image]): Either a path to an image or a PIL Image

        Returns:
            torch.Tensor: L2-normalised embedding vector for the image

        Raises:
            Exception: If the image cannot be opened or processed; the
                original error is attached as ``__cause__``.
        """
        try:
            # Load image if path is provided
            if isinstance(image, str):
                # Image.open is lazy and keeps the file handle open, so open it
                # in a context manager and take an in-memory RGB copy.
                # NOTE(review): assumes the processor converts inputs to RGB
                # anyway (the CLIP image processor default) - confirm if a
                # custom processor configuration is used.
                with Image.open(image) as opened_image:
                    image = opened_image.convert("RGB")

            # Process image through CLIP; no gradients are needed for inference
            with torch.no_grad():
                inputs = self.processor(
                    images=image,
                    return_tensors="pt"
                )
                image_features = self.model.get_image_features(**{k: v.to(self.device) for k, v in inputs.items()})

            # Scale to unit length along the feature dimension
            return image_features / image_features.norm(dim=-1, keepdim=True)

        except Exception as e:
            raise Exception(f"Error processing image: {str(e)}") from e

    def find_matches(self,
                    query_embedding: torch.Tensor,
                    article_embeddings: List[Tuple[str, torch.Tensor]],
                    num_matches: int = 1) -> List[Tuple[str, float]]:
        """
        Find the closest matching articles for a query embedding.

        Args:
            query_embedding (torch.Tensor): Embedding vector to match against
            article_embeddings (List[Tuple[str, torch.Tensor]]): List of (article_name, embedding) pairs
            num_matches (int): Number of matches to return; values <= 0 return an empty list

        Returns:
            List[Tuple[str, float]]: List of (article_name, similarity_score) pairs,
                highest similarity first. Ties keep their input order.
        """
        # A negative count would otherwise slice from the end of the list and
        # silently drop the worst matches instead of returning nothing.
        if num_matches <= 0 or not article_embeddings:
            return []

        similarities = [
            (
                article_name,
                torch.nn.functional.cosine_similarity(
                    query_embedding,
                    # No-op when both tensors already share a device.
                    article_embedding.to(query_embedding.device),
                ).item(),
            )
            for article_name, article_embedding in article_embeddings
        ]

        # Sort by similarity score in descending order
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return similarities[:num_matches]

# Example usage function
def demo_wiki_image_matcher(
    image_path: str,
    article_names: List[str],
    num_matches: int = 1,
    matcher: Optional["WikiImageMatcher"] = None,
) -> List[Tuple[str, float]]:
    """
    Demonstrate the WikiImageMatcher with example usage.

    Embeds each requested article, embeds the query image, and returns the
    best-matching articles.

    Args:
        image_path (str): Path to the query image
        article_names (List[str]): Article titles or search queries to rank
        num_matches (int): Number of matches to return
        matcher (Optional[WikiImageMatcher]): Existing matcher to reuse. When
            omitted a new WikiImageMatcher is created, which loads the CLIP
            model again; pass one in to avoid that on repeated calls.

    Returns:
        List[Tuple[str, float]]: (resolved article title, similarity) pairs as
            returned by ``matcher.find_matches``, or an empty list when none
            of the articles could be embedded
    """
    if matcher is None:
        matcher = WikiImageMatcher()

    # Get embeddings for all articles
    article_embeddings = []
    for article_name in article_names:
        embedding, actual_title = matcher.get_article_embedding(article_name)
        if embedding is not None and actual_title is not None:
            # Keep the resolved title, which may differ from the requested name
            article_embeddings.append((actual_title, embedding))
        else:
            print(f"Skipping '{article_name}' due to retrieval error")

    if not article_embeddings:
        print("No valid articles found to match against")
        return []

    # Get embedding for the query image
    image_embedding = matcher.get_image_embedding(image_path)

    # Find matches
    return matcher.find_matches(image_embedding, article_embeddings, num_matches)

In [23]:
# NOTE(review): absolute path to a file on the author's machine; point this at
# a local image before re-running the cell.
image_path = '/Users/clkruse/Downloads/astro_test.png'

# Article titles / search queries the image is ranked against.
candidate_articles = [
    "Astronaut",
    "Ancient Greek",
    "Roses",
    "Mount Everest",
    "Midjourney",
    "Pink (Color)",
]

# Request one match per candidate so every article appears in the ranking.
demo_wiki_image_matcher(image_path, candidate_articles, num_matches=len(candidate_articles))


[('Astronaut', 0.27501875162124634),
 ('Shades of pink', 0.2281285524368286),
 ('Rose', 0.22468915581703186),
 ('Ancient Greek', 0.2177101969718933),
 ('Mount Everest', 0.17938131093978882),
 ('Midjourney', 0.1769527792930603)]

In [25]:
from datasets import load_dataset
# Attempt to load a dataset from the Hugging Face Hub with streaming=True.
# NOTE(review): the recorded output of this cell is a DatasetNotFoundError for
# 'facebook/contrastive_search_index' - the dataset id needs to be confirmed or
# replaced (or access granted) before this cell can run.
dataset = load_dataset("facebook/contrastive_search_index", streaming=True)

DatasetNotFoundError: Dataset 'facebook/contrastive_search_index' doesn't exist on the Hub or cannot be accessed.