In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
sys.path.append("..")
import itertools
import time
import math

import numpy as np

from santa.metrics import PerplexityCalculator
from santa.utils import save_text

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the perplexity scorer for the "google/gemma-2-9b" checkpoint.
# NOTE(review): presumably this loads the model weights onto the GPU selected via
# CUDA_VISIBLE_DEVICES above — confirm in santa.metrics.PerplexityCalculator.
scorer = PerplexityCalculator("google/gemma-2-9b")

Loading checkpoint shards: 100% 8/8 [00:08<00:00,  1.09s/it]


In [3]:
# Starting word order; its perplexity is displayed in the next cell.
text = "the of and to in that have not you with we it from as peppermint candy fruitcake cookie chocolate milk eggnog greeting card wrapping paper bow toy doll game puzzle snowglobe candle fireplace wreath poinsettia angel star night wish dream believe wonder hope joy peace season merry hohoho kaggle workshop"
# Phrases to keep together: a later cell rewrites each one as a single
# hyphen-joined token so that whitespace splitting treats it as one unit.
sub_texts = [
    "candle fireplace wreath poinsettia angel star night wish dream", "from as peppermint candy fruitcake", "merry hohoho kaggle workshop",
    "the of and to in that", "card wrapping paper bow", "chocolate milk eggnog", "toy doll game puzzle", "hope joy peace season", 
]
# Token positions that beam_search must never swap (0 = the first token).
fix_ids = [0]

In [4]:
# Baseline: perplexity of the starting word order, shown as the cell output.
scorer.get_perplexity(text)

80.69083499570341

In [5]:
# Collapse every phrase from sub_texts into one hyphen-joined token.
# Note: this rebinds `text`, so from here on it holds the hyphenated form.
for st in sub_texts:
    text = text.replace(st, "-".join(st.split()))
# Display the rewritten text and how many whitespace-separated tokens remain.
text, len(text.split())

('the-of-and-to-in-that have not you with we it from-as-peppermint-candy-fruitcake cookie chocolate-milk-eggnog greeting card-wrapping-paper-bow toy-doll-game-puzzle snowglobe candle-fireplace-wreath-poinsettia-angel-star-night-wish-dream believe wonder hope-joy-peace-season merry-hohoho-kaggle-workshop',
 19)

In [6]:
# Whitespace-split the hyphenated text; each hyphen-joined phrase is one token.
tokens = text.split()
# Sanity check: show which tokens sit at the fixed (never swapped) positions.
np.array(tokens)[fix_ids]

array(['the-of-and-to-in-that'], dtype='<U62')

In [7]:
# Number of tokens that are not pinned by fix_ids.
len(tokens) - len(fix_ids)

18

In [8]:
def beam_search(init_tokens, fix_ids, scorer, k=30, precomputed=None, n_iters=100):
    """Beam search over pairwise token swaps, minimising the scorer's perplexity.

    Each iteration expands every candidate in the beam by swapping every pair
    of non-fixed positions, scores the resulting orderings, and keeps the ``k``
    lowest-scoring ones as the next beam. The starting order itself is not
    scored; only orderings produced by at least one swap are considered.

    Args:
        init_tokens: Sequence of string tokens. A token may contain ``-`` to
            glue several words together; hyphens are turned back into spaces
            before scoring.
        fix_ids: Iterable of token positions that must never be swapped.
        scorer: Object exposing ``get_perplexity(text) -> float``.
        k: Beam width (number of candidates kept after each iteration).
        precomputed: Optional dict used as a score cache, keyed by the
            de-hyphenated text. It is updated in place, so passing the same
            dict to several calls reuses earlier scores. Defaults to a fresh
            dict per call.
        n_iters: Number of expansion rounds.

    Returns:
        ``(best_score, best_tokens)`` where ``best_tokens`` is a tuple of
        tokens (hyphens preserved). If no swap is possible, or ``n_iters`` is
        0, returns ``(np.inf, None)``.
    """
    # A mutable default would be shared by every call; create the cache lazily.
    if precomputed is None:
        precomputed = {}

    n = len(init_tokens)
    fixed = set(fix_ids)  # O(1) membership tests instead of scanning a list
    movable = [idx for idx in range(n) if idx not in fixed]
    # Same (i, j) order with i < j as the equivalent nested loops.
    swap_pairs = list(itertools.combinations(movable, 2))

    candidates = [init_tokens]
    best_score = np.inf
    best_text = None

    # Fewer than two movable tokens: nothing to search.
    if not swap_pairs:
        return best_score, best_text

    for _ in range(n_iters):
        start = time.time()
        scores = {}
        for candidate in candidates:
            for i, j in swap_pairs:
                swapped = list(candidate)
                swapped[i], swapped[j] = swapped[j], swapped[i]
                swapped = tuple(swapped)
                text = " ".join(swapped).replace("-", " ")
                # Look up and store under the same key (the scored text) so
                # that the cache is actually hit on repeated orderings.
                if text in precomputed:
                    score = precomputed[text]
                else:
                    score = scorer.get_perplexity(text)
                    precomputed[text] = score
                scores[swapped] = score

        # Keep the k lowest-perplexity orderings as the next beam.
        ranked = sorted(scores.items(), key=lambda item: item[1])[:k]
        candidates = [tokens for tokens, _score in ranked]
        if ranked[0][1] < best_score:
            best_score, best_text = ranked[0][1], ranked[0][0]

        print(f"best score: {best_score}, top10 mean score: {np.mean([s for _, s in ranked[:10]])}")
        print(" ".join(best_text).replace("-", " "))
        print(f"{time.time()-start} [s]")
    return best_score, best_text

In [9]:
%%time
# Run the search with a beam width of 100 for 20 iterations; position 0 stays fixed.
best_score, best_text = beam_search(tokens, fix_ids, scorer, k=100, n_iters=20)

best score: 81.00664999449546, top10 mean score: 84.75049767662156
the of and to in that have not you with we it from as peppermint candy fruitcake cookie chocolate milk eggnog greeting card wrapping paper bow toy doll game puzzle snowglobe candle fireplace wreath poinsettia angel star night wish dream wonder believe hope joy peace season merry hohoho kaggle workshop
11.89622449874878 [s]
best score: 80.69083499570341, top10 mean score: 83.719461517798
the of and to in that have not you with we it from as peppermint candy fruitcake cookie chocolate milk eggnog greeting card wrapping paper bow toy doll game puzzle snowglobe candle fireplace wreath poinsettia angel star night wish dream believe wonder hope joy peace season merry hohoho kaggle workshop
1177.6079578399658 [s]
best score: 80.69083499570341, top10 mean score: 83.06876436643032
the of and to in that have not you with we it from as peppermint candy fruitcake cookie chocolate milk eggnog greeting card wrapping paper bow toy dol

In [11]:
# Lowest perplexity found by the search.
best_score

80.37625124290746

In [10]:
# Best ordering as plain text: join the tokens and turn hyphens back into spaces.
" ".join(best_text).replace("-", " ")

'the of and to in that have not you with we it from as peppermint candy fruitcake chocolate milk eggnog cookie snowglobe toy doll game puzzle greeting card wrapping paper bow candle fireplace wreath poinsettia angel star night wish dream believe wonder hope joy peace season merry hohoho kaggle workshop'

In [31]:
# Destination directory and the id passed to save_text for this result.
# NOTE(review): the meaning of target_id and the on-disk format are defined in
# santa.utils.save_text — confirm there.
output_dir = "./output"
target_id = 4
# Save the de-hyphenated best ordering together with its score.
save_text(" ".join(best_text).replace("-", " "), best_score, target_id, output_dir=output_dir)