<a href="https://colab.research.google.com/github/laraschwarz/NYTConnectionsAI/blob/main/WordConnections.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [None]:
# Fine-tuning loop: one optimiser step per batch, for `num_epochs` passes.
#
# NOTE(review): this cell reads `model`, `dataloader` and `num_epochs`, none
# of which are defined in this cell; they must exist before it is run.
# NOTE(review): the loop calls `model(**inputs, labels=labels)` and reads
# `outputs.loss`, so it presumably expects a model with a task head; a bare
# encoder loaded through `AutoModel` would not fit that call - confirm.

# `transformers.AdamW` is deprecated upstream in favour of the PyTorch
# implementation, so the optimiser is imported from torch instead.
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

# `eps` and `weight_decay` are spelled out to keep the defaults that the
# former `transformers.AdamW` used (torch defaults are 1e-8 and 1e-2).
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

for epoch in range(num_epochs):
    # Re-enable training-mode layers (e.g. dropout) at the start of each epoch
    model.train()
    for batch in dataloader:
        # Each batch is unpacked as (keyword inputs for the model, labels)
        inputs, labels = batch
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss

        # Back-propagate, update the weights, then clear the gradients so
        # they do not accumulate into the next batch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Checkpoint name shared by the tokenizer and the encoder below
bert = "bert-large-uncased"

# Tokenizer used to turn raw strings into model inputs
tokenizer = AutoTokenizer.from_pretrained(bert)

# Pre-trained encoder, configured to also return its per-layer hidden states
model = AutoModel.from_pretrained(
    bert,
    output_hidden_states=True,
)


def _balanced_groups(words, distances, group_size=4):
    """Assign every word to one group so that each group has `group_size` words.

    (word, group) pairs are visited from the smallest centroid distance to the
    largest; a word is placed the first time one of its pairs points at a
    group that still has room. Words inside a group keep their input order.
    """
    n_groups = len(words) // group_size
    ranked_pairs = sorted(
        (float(distance), word_idx, group_idx)
        for word_idx, row in enumerate(distances)
        for group_idx, distance in enumerate(row)
    )

    members = [[] for _ in range(n_groups)]
    placed = set()
    for _, word_idx, group_idx in ranked_pairs:
        if word_idx in placed or len(members[group_idx]) >= group_size:
            continue
        members[group_idx].append(word_idx)
        placed.add(word_idx)

    return [[words[idx] for idx in sorted(indices)] for indices in members]


def group_words(words, *, random_state=None):
    """
    Groups words into sets of four using BERT embeddings and K-Means clustering.

    Each word is embedded as the mean of its last-layer token vectors (padding
    excluded), the pairwise cosine-similarity matrix is clustered with K-Means,
    and the clusters are then balanced so every group holds exactly four words.

    Args:
        words (list): A list of words to group. The length must be divisible by 4.
        random_state (int | None): Optional seed forwarded to K-Means so that
            repeated calls give the same grouping. Defaults to None (unseeded).

    Returns:
        list: A list of lists, one per group, each containing four words in
        their original input order. An empty input yields an empty list.

    Raises:
        ValueError: If the number of words is not divisible by 4.
    """
    # Check if the length of words is divisible by 4
    if len(words) % 4 != 0:
        raise ValueError("The length of input words must be divisible by 4.")

    # Nothing to cluster; K-Means cannot be fitted with zero clusters
    if not words:
        return []

    words = list(words)
    n_groups = len(words) // 4

    # Tokenize the words; shorter entries are padded to the longest one
    sequences = tokenizer(words, padding=True, truncation=True, return_tensors="pt")

    # Get BERT embeddings without gradient calculations
    with torch.no_grad():
        outputs = model(**sequences)
        hidden_states = outputs.hidden_states[-1]  # Last layer hidden states

    # Average only over real tokens. Flattening the padded sequence would mix
    # padding vectors into the embedding and make the similarity depend on how
    # many wordpieces each word was split into.
    mask = sequences["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
    embeddings = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    embeddings = embeddings.cpu().numpy()

    # Calculate pairwise cosine similarities between word embeddings
    similarities = cosine_similarity(embeddings, embeddings)

    # Cluster the similarity rows and keep each word's distance to every centroid
    clustering = KMeans(n_clusters=n_groups, n_init=100, random_state=random_state)
    distances = clustering.fit_transform(similarities)

    # K-Means does not guarantee equal cluster sizes, so words that overflow a
    # full cluster are moved to the closest cluster that still has room.
    return _balanced_groups(words, distances, group_size=4)

In [None]:
# Sample puzzles; every list has a length that is a multiple of four
words = [
    "red", "green", "purple", "blue",
    "apple", "banana", "grape", "watermelon",
    "basketball", "football", "soccer", "baseball",
    "computer", "keyboard", "mouse", "monitor",
]
words2 = [
    "tesla", "UCSD", "toyota", "ferrari",
    "SDSU", "honda", "USC", "USD",
    "burger", "pizza", "pasta", "wings",
]
words3 = [
    "happy", "sad", "angry", "bored",
    "dancing", "running", "jumping", "lifting",
]
words4 = [
    "forest", "parade", "erotic", "skeleton",
    "train", "hedgehog", "olive", "book",
    "mint", "democratic", "cactus", "fleet",
    "noble", "caravan", "sad", "lime",
]

# Run the grouping once per puzzle
groups = group_words(words)
groups2 = group_words(words2)
groups3 = group_words(words3)
groups4 = group_words(words4)

# Report every result under a numbered heading; all headings except the
# first are preceded by a blank line to separate them from the previous test
for test_number, result in enumerate((groups, groups2, groups3, groups4), start=1):
    leading_gap = "" if test_number == 1 else "\n"
    print(f"{leading_gap}Test #{test_number}:\n")

    for i, group in enumerate(result):
        # Add 1 to the index for more natural group numbering
        print(f"Group {i+1}: {', '.join(group)}")



Test #1:

Group 1: apple, banana, grape, computer
Group 2: basketball, football, soccer, baseball
Group 3: watermelon, keyboard, mouse, monitor
Group 4: red, green, purple, blue

Test #2:

Group 1: tesla, toyota, ferrari, honda
Group 2: UCSD, SDSU, USC, USD
Group 3: burger, pizza, pasta, wings

Test #3:

Group 1: sad, angry, bored, jumping
Group 2: happy, dancing, running, lifting

Test #4:

Group 1: olive, democratic, fleet, noble
Group 2: forest, parade, skeleton, book
Group 3: hedgehog, caravan, sad, lime
Group 4: erotic, train, mint, cactus
