In [24]:
# !pip install transformers --upgrade
# !pip install sentence_transformers
# !pip install einops

In [20]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import itertools

import polars as pl
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer, AutoModel
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer

from santa.metrics import PerplexityCalculator

In [3]:
scorer = PerplexityCalculator(model_path="google/gemma-2-9b")

Loading checkpoint shards: 100% 8/8 [00:08<00:00,  1.07s/it]


In [25]:
# model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModel.from_pretrained(model_name, device_map="auto")
model_name = "nomic-ai/nomic-embed-text-v1"
model = SentenceTransformer(model_name, trust_remote_code=True)

<All keys matched successfully>


In [64]:
# id=4
text = "the of and to in that have not you with we it from as peppermint candy fruitcake chocolate milk eggnog cookie snowglobe toy doll game puzzle greeting card wrapping paper bow candle fireplace wreath poinsettia angel star night wish dream believe wonder hope joy peace season merry hohoho kaggle workshop"
# id=5
# text = "from and and the and of as in is it of that to the we the with you have advent card angel bake beard believe bow candy carol candle cheer cheer chocolate chimney chimney cookie decorations doll dream drive eat eggnog elf family fireplace fireplace fruitcake game gifts give gingerbread greeting grinch holiday holly hohoho hope jingle jump joy kaggle laugh magi merry milk mistletoe naughty nice night night not nutcracker ornament ornament paper peace peppermint polar poinsettia puzzle reindeer relax scrooge season sing sleigh sleep snowglobe star stocking toy unwrap visit walk wish wonder workshop workshop wrapping wreath yuletide"
tokens = np.array(text.split())
np.random.shuffle(tokens)
len(tokens)

50

In [65]:
# inputs = tokenizer(tokens.tolist(), padding=True, truncation=True, return_tensors="pt").to("cuda")
with torch.no_grad():
    embedding = model.encode(tokens)
#     outputs = model(**inputs)
# embedding = outputs.pooler_output.cpu()

embedding.shape

(50, 768)

In [89]:
n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters)
kmeans.fit(embedding)

In [90]:
for i in range(n_clusters):
    print(tokens[kmeans.labels_ == i])

['angel' 'eggnog' 'star']
['that' 'as' 'have' 'not' 'the' 'in' 'paper' 'from' 'and' 'of' 'hohoho'
 'candle' 'it' 'you' 'with' 'to' 'chocolate' 'we']
['season' 'fruitcake' 'fireplace' 'believe' 'snowglobe' 'peace' 'workshop'
 'night' 'dream']
['wreath' 'hope' 'joy' 'kaggle' 'greeting' 'wonder' 'card' 'wrapping'
 'wish' 'merry']
['cookie' 'game' 'poinsettia' 'peppermint' 'candy' 'bow' 'toy' 'puzzle'
 'doll' 'milk']


In [91]:
new_tokens = np.concatenate([tokens[kmeans.labels_ == i] for i in range(n_clusters)])
new_tokens

array(['angel', 'eggnog', 'star', 'that', 'as', 'have', 'not', 'the',
       'in', 'paper', 'from', 'and', 'of', 'hohoho', 'candle', 'it',
       'you', 'with', 'to', 'chocolate', 'we', 'season', 'fruitcake',
       'fireplace', 'believe', 'snowglobe', 'peace', 'workshop', 'night',
       'dream', 'wreath', 'hope', 'joy', 'kaggle', 'greeting', 'wonder',
       'card', 'wrapping', 'wish', 'merry', 'cookie', 'game',
       'poinsettia', 'peppermint', 'candy', 'bow', 'toy', 'puzzle',
       'doll', 'milk'], dtype='<U10')

In [92]:
best_score = np.inf
for perm in itertools.permutations(list(range(n_clusters))):
    new_tokens = np.concatenate([tokens[kmeans.labels_ == i] for i in perm])
    new_text = " ".join(new_tokens)
    score = scorer.get_perplexity(new_text)
    if score < best_score:
        best_score = score
        print(best_score, "\t", new_text)

1419.1487514060357 	 angel eggnog star that as have not the in paper from and of hohoho candle it you with to chocolate we season fruitcake fireplace believe snowglobe peace workshop night dream wreath hope joy kaggle greeting wonder card wrapping wish merry cookie game poinsettia peppermint candy bow toy puzzle doll milk
1391.6999298224052 	 angel eggnog star that as have not the in paper from and of hohoho candle it you with to chocolate we wreath hope joy kaggle greeting wonder card wrapping wish merry season fruitcake fireplace believe snowglobe peace workshop night dream cookie game poinsettia peppermint candy bow toy puzzle doll milk
1312.498038515131 	 angel eggnog star that as have not the in paper from and of hohoho candle it you with to chocolate we cookie game poinsettia peppermint candy bow toy puzzle doll milk wreath hope joy kaggle greeting wonder card wrapping wish merry season fruitcake fireplace believe snowglobe peace workshop night dream
1302.284097746097 	 angel egg