In [1]:
import numpy as np
from gensim.models import KeyedVectors
from sklearn.cluster import KMeans
from k_means_constrained import KMeansConstrained

model_file = "archive/GoogleNews-vectors-negative300.bin"  # Update with the correct path to the Word2Vec file
model = KeyedVectors.load_word2vec_format(model_file, binary=True)

In [2]:
# Load pre-trained Word2Vec embeddings
def load_word2vec_model(model_file):
    model = KeyedVectors.load_word2vec_format(model_file, binary=True)  # Load Word2Vec in binary format
    return model

# Function to get word embedding
def get_embedding(word, model):
    if word in model:
        return model[word]
    else:
        return np.zeros(model.vector_size)  # Return a zero vector if word is not found

# Function to get embeddings for a list of words
def get_word_embeddings(word_list, model):
    return np.array([get_embedding(word, model) for word in word_list])

# Apply k-means clustering
def apply_kmeans_clustering(word_embeddings, n_clusters):
    kmeans = KMeans(n_clusters=n_clusters, random_state=0)
    kmeans.fit(word_embeddings)
    return kmeans.labels_

# Group words based on clustering labels
def group_words_by_clusters(word_list, labels):
    clusters = {}
    for word, label in zip(word_list, labels):
        if label not in clusters:
            clusters[label] = []
        clusters[label].append(word)
    return clusters

In [13]:
# Define the set of words from the Connections game
word_list = ["v", "m", "l", "hen", "mare", "yew", "we", "i", "cow", "they", "doe", "it", "d", "you", "u", "ewe"]# Get embeddings for the word list
word_embeddings = get_word_embeddings(word_list, model)

In [14]:
clf = KMeansConstrained(
     n_clusters=4,
     size_min=4,
     size_max=4,
     random_state=0
 )
clf.fit_predict(word_embeddings)
clf.cluster_centers_
clf.labels_

array([1, 1, 1, 2, 0, 2, 3, 2, 0, 3, 0, 3, 1, 3, 2, 0], dtype=int32)

In [15]:
grouped_words = group_words_by_clusters(word_list, clf.labels_)

# Output the groups
for cluster, words in grouped_words.items():
    print(f"{', '.join(words)}")

v, m, l, d
hen, yew, i, u
mare, cow, doe, ewe
we, they, it, you


In [6]:
import random
import json

shuffle=False
with open("connections_prompts_copy.jsonl", "r", encoding='unicode-escape') as f:
    data = list(f)

In [7]:
import ast
words = []
solutions = []
for d in data:
    try: 
        res = ast.literal_eval(d)
        words.append(res['words'])
        solutions.append(res['solution'])
    except:
        pass

In [8]:
correct_sets_list = []
for solution in solutions:
    correct_sets = {}
    for correct_set in solution['groups']:
        correct_words = set(correct_set['words'])
        correct_sets[correct_set['reason']] = correct_words
    correct_sets_list.append(correct_sets)

In [9]:
len(words)

0

In [135]:
words

[['nets',
  'return',
  'heat',
  'jazz',
  'mom',
  'shift',
  'kayak',
  'option',
  'rain',
  'sleet',
  'level',
  'racecar',
  'bucks',
  'tab',
  'hail',
  'snow'],
 ['league',
  'loafer',
  'queue',
  'are',
  'pump',
  'foot',
  'why',
  'time',
  'mile',
  'us',
  'sneaker',
  'sea',
  'boot',
  'people',
  'essence',
  'yard'],
 ['amigo',
  'wolf',
  'cheek',
  'pom',
  'stooge',
  'lab',
  'king',
  'peke',
  'scarf',
  'tenor',
  'eye',
  'gobble',
  'nose',
  'mouth',
  'chow',
  'pit'],
 ['puma',
  'sweep',
  'chicago',
  'bat',
  'iron',
  'adidas',
  'vacuum',
  'super',
  'carousel',
  'mop',
  'cabaret',
  'dust',
  'nike',
  'cats',
  'spider',
  'reebok'],
 ['glum',
  'prime',
  'low',
  'hulu',
  'green',
  'plum',
  'blue',
  'down',
  'peacock',
  'mayo',
  'scarlet',
  'ketchup',
  'relish',
  'netflix',
  'mustard',
  'tartar'],
 ['future',
  'sister',
  'q-tip',
  'sin',
  'chance',
  'boardwalk',
  'sea',
  'midnight',
  'jail',
  'common',
  'ice cube',
  'b

In [128]:
def evaluate_accuracy(correct_sets, predicted_sets):

    exact_matches = 0
    total_correct_predictions = 0
    
    # Calculate Exact Match Accuracy
    for correct_set in correct_sets:
        if correct_set in predicted_sets:
            exact_matches += 1
            total_correct_predictions += 4
            continue
        
        max_intersection = 1
        for predicted_set in predicted_sets:
            intersection = len(correct_set.intersection(predicted_set))
            if intersection > max_intersection:
                max_intersection = intersection
        
        if max_intersection > 1:
            total_correct_predictions += max_intersection     

    exact_match_accuracy = exact_matches / 4
    cluster_purity = total_correct_predictions / 16

    return exact_match_accuracy, cluster_purity

In [122]:
words_n_solutions = zip(words, correct_sets_list)

In [132]:
import itertools
words_n_solutions = zip(words, correct_sets_list)
accuracy = 0
purity = 0
lowest_purity = 1
lowest_purity_set = []
for word, correct_set in words_n_solutions:
    word_embeddings = get_word_embeddings(word, model)
    clf = KMeansConstrained(
     n_clusters=4,
     size_min=4,
     size_max=4,
     random_state=1
    )
    clf.fit_predict(word_embeddings)
    predicted_groups = group_words_by_clusters(word, clf.labels_)
    predicted_sets = []

    # Output the groups
    for _, predicted_words in predicted_groups.items():
        predicted_sets.append(set(predicted_words))
    
    
    exact_match_accuracy, cluster_purity = evaluate_accuracy(correct_set.values(), predicted_sets)

    if cluster_purity == lowest_purity:
        lowest_purity_set.append([predicted_sets, correct_set])
    
    elif cluster_purity < lowest_purity:
        lowest_purity = cluster_purity
        lowest_purity_set = [predicted_sets, correct_set]
    
    accuracy += exact_match_accuracy
    purity += cluster_purity

avg_accuracy = accuracy/len(words)
avg_purity = purity/len(words)

print(f'Average Exact Match Accuracy: {avg_accuracy}')
print(f'Average Cluster Purity: {avg_purity}')

Average Exact Match Accuracy: 0.18536931818181818
Average Cluster Purity: 0.6818181818181818


In [133]:
lowest_purity

0.375

In [134]:
lowest_purity_set

[[{'framed', 'met', 'road', 'rocky'},
  {'fury', 'horror', 'mad', 'rabbit'},
  {'harry', 'max', 'roger', 'sally'},
  {'picture', 'show', 'when', 'who'}],
 {'who framed roger rabbit': {'framed', 'rabbit', 'roger', 'who'},
  'rocky horror picture show': {'horror', 'picture', 'rocky', 'show'},
  'when harry met sally': {'harry', 'met', 'sally', 'when'},
  'mad max fury road': {'fury', 'mad', 'max', 'road'}},
 [[{'chopped', 'knock', 'ram', 'slam'},
   {'maxi', 'mini', 'pan', 'roast'},
   {'bachelor', 'catfish', 'jaguar', 'mouse'},
   {'alone', 'fiat', 'lily', 'survivor'}],
  {'reality shows': {'alone', 'catfish', 'chopped', 'survivor'},
   'criticize': {'knock', 'pan', 'roast', 'slam'},
   'car brands': {'fiat', 'jaguar', 'mini', 'ram'},
   '___ pad': {'bachelor', 'lily', 'maxi', 'mouse'}}],
 [[{'outside', 'remote', 'room', 'vehicle'},
   {'grande', 'mars', 'means', 'proof'},
   {'channel', 'legend', 'medium', 'styles'},
   {'large', 'slim', 'small', 'swift'}],
  {'method': {'channel', 'me