In [None]:
import torch
from transformers import CLIPModel, CLIPTokenizer
from sklearn.cluster import AgglomerativeClustering
import numpy as np
from collections import defaultdict

def _encode_texts(texts: list[str], model, tokenizer, device: str, batch_size: int) -> torch.Tensor:
    """Embed ``texts`` with ``model.get_text_features`` and L2-normalize each row.

    Strings are encoded ``batch_size`` at a time so that a large input (such as
    the whole tokenizer vocabulary) never needs one giant forward pass.
    """
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(batch, padding=True, return_tensors="pt").to(device)
            features = model.get_text_features(**inputs)
            # Unit-length rows make cosine similarity a plain dot product.
            chunks.append(features / features.norm(dim=-1, keepdim=True))
    return torch.cat(chunks, dim=0)


def _theme_vocabulary(tokenizer) -> list[str]:
    """Return sorted, de-duplicated whole-word tokens usable as theme guesses.

    Only tokens containing the ``</w>`` end-of-word marker are kept. The marker
    is stripped, the token is lower-cased, and tokens of 3 characters or fewer
    (measured after stripping) are discarded.
    """
    candidates = set()
    for token in tokenizer.get_vocab():
        if "</w>" not in token:
            continue
        cleaned = token.replace("</w>", "").lower()
        # Measure length after stripping the 4-character "</w>" marker;
        # checking the raw token would let every marked token through.
        if len(cleaned) > 3:
            candidates.add(cleaned)
    # Sorting gives a deterministic order, so argmax ties resolve the same way every run.
    return sorted(candidates)


def solve_connections(words: list[str], model, tokenizer, device: str, batch_size: int = 1024):
    """
    Group the 16 words of an NYT Connections puzzle by clustering text embeddings.

    Each word is embedded with ``model.get_text_features``, and the embeddings
    are L2-normalized. Agglomerative clustering (cosine distance, average
    linkage) then splits them into 4 clusters. A theme is guessed for each
    cluster by finding the vocabulary token whose embedding is most
    cosine-similar to the cluster's mean embedding.

    Note: agglomerative clustering does not constrain cluster sizes, so the
    returned groups are NOT guaranteed to contain exactly 4 words each.

    Args:
        words (list[str]): Exactly 16 words from the puzzle.
        model: The pre-loaded transformer model exposing ``get_text_features``
            (e.g., CLIP).
        tokenizer: The pre-loaded tokenizer for the model. Its vocabulary
            supplies the theme candidates (whole-word tokens marked ``</w>``).
        device (str): The device the model runs on ('cuda' or 'cpu').
        batch_size (int): Maximum number of strings encoded per forward pass.
            This bounds peak memory while the theme vocabulary is encoded.

    Returns:
        dict: Maps each cluster label to the words in that cluster, in input
              order.
        dict: Maps each cluster label to an upper-cased theme guess. It is
              empty when the tokenizer yields no usable vocabulary tokens.

    Raises:
        ValueError: If ``words`` does not contain exactly 16 entries, or if
            ``batch_size`` is less than 1.
    """
    if len(words) != 16:
        raise ValueError("The 'words' list must contain exactly 16 words.")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1.")

    # 1. Encode all 16 words into unit-length vectors (embeddings).
    print("Encoding words...")
    word_embeddings = _encode_texts(words, model, tokenizer, device, batch_size)

    # 2. Cluster the embeddings into 4 groups.
    # - n_clusters=4: a Connections puzzle has four groups. Cluster sizes are
    #   unconstrained, however.
    # - metric='cosine': scikit-learn uses cosine *distance* (1 - similarity).
    # - linkage='average': the distance between clusters is the mean pairwise
    #   distance.
    print("Clustering embeddings...")
    labels = AgglomerativeClustering(
        n_clusters=4,
        metric='cosine',
        linkage='average'
    ).fit(word_embeddings.cpu().numpy()).labels_

    # 3. Group the words by cluster label, preserving input order.
    groups = defaultdict(list)
    for word, label in zip(words, labels):
        groups[label].append(word)

    # 4. (Bonus) Guess a theme per group: take the vocabulary token closest to
    #    the group's mean embedding (its "centroid").
    print("Finding group themes...")
    clean_vocab = _theme_vocabulary(tokenizer)
    if not clean_vocab:
        # No candidate tokens (e.g. a tokenizer without "</w>" markers).
        return dict(groups), {}

    print(f"Encoding {len(clean_vocab)} vocabulary words for theme search...")
    vocab_embeddings = _encode_texts(clean_vocab, model, tokenizer, device, batch_size)

    group_themes = {}
    for group_id in groups:
        # Select members by label mask rather than words.index(), which would
        # map duplicate words to the first occurrence's embedding.
        mask = torch.from_numpy(labels == group_id).to(word_embeddings.device)
        centroid = word_embeddings[mask].mean(dim=0, keepdim=True)

        similarities = torch.nn.functional.cosine_similarity(centroid, vocab_embeddings)
        best_match_index = int(similarities.argmax().item())
        group_themes[group_id] = clean_vocab[best_match_index].upper()

    return dict(groups), group_themes


if __name__ == '__main__':
    import traceback

    # --- Setup ---
    # Use a GPU when one is available: encoding the full theme vocabulary is
    # slow on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the model and tokenizer once; solve_connections reuses them.
    print("Loading CLIP model and tokenizer...")
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # --- Puzzle Definition ---
    # Example puzzle from a past NYT Connections game (April 5, 2024)
    puzzle_words = [
        "DOG", "LITTLE SPOON", "IRIS", "WATER BOTTLE",
        "TOO LOSE", "LENS", "BENT", "ROD",
        "CONE", "ONE WEEK", "POTATO", "POINT OF VIEW",
        "ANGLE", "CUP", "CLOSING TIME", "SCOOP"
    ]

    # --- Solve the Puzzle ---
    try:
        solved_groups, themes = solve_connections(puzzle_words, model, tokenizer, device)
    except Exception as e:
        # Top-level boundary: report the failure together with its traceback
        # rather than reducing it to a one-line message.
        print(f"An error occurred: {e}")
        traceback.print_exc()
    else:
        # --- Display Results ---
        print("\n--- Connections Puzzle Solved! ---\n")
        for group_id, words_in_group in solved_groups.items():
            theme = themes.get(group_id, "UNKNOWN")
            print(f"🔵 Group (Predicted Theme: {theme}): {', '.join(words_in_group)}")

Using device: cpu
Loading CLIP model and tokenizer...
Encoding words...
Clustering embeddings...
Finding group themes...
Encoding 34252 vocabulary words for theme search...
