In [2]:
import re
from pathlib import Path

from scipy.stats import spearmanr
from sympy.combinatorics import Permutation

In [3]:
def get_token2id(text):
    """Give every whitespace-separated token of ``text`` its own integer id.

    The first occurrence of a word is stored under the word itself.  Each
    further occurrence is stored under ``"<word>_<k>"``, where ``k`` is the
    smallest integer >= 2 for which that key is still unused.  Ids are the
    running size of the mapping, so they start at 0 and grow by one per token.

    Args:
        text: String to index; tokenised with ``str.split()``.

    Returns:
        dict mapping each (possibly suffixed) token key to its id.
    """
    mapping = {}
    for word in text.split():
        key = word
        suffix = 1
        # Keep bumping the numeric suffix until the key is free.
        while key in mapping:
            suffix += 1
            key = f"{word}_{suffix}"
        mapping[key] = len(mapping)
    return mapping


def _smallest_duplicate_suffix(token, remaining):
    """Return the smallest ``k >= 2`` with ``f"{token}_{k}"`` in ``remaining``, else None."""
    prefix = f"{token}_"
    smallest = None
    for key in remaining:
        if not (isinstance(key, str) and key.startswith(prefix)):
            continue
        tail = key[len(prefix):]
        # Only a canonical ASCII decimal (no sign, no leading zeros) can be the
        # result of formatting an int into the key.
        if not (tail.isascii() and tail.isdigit()) or str(int(tail)) != tail:
            continue
        k = int(tail)
        if k >= 2 and (smallest is None or k < smallest):
            smallest = k
    return smallest


def tokens2order(tokens, token2id):
    """Translate a token sequence into the ids stored in ``token2id``.

    Every id is consumed at most once: a token first uses the entry stored
    under its own name, and further occurrences use ``"<token>_<k>"`` with the
    smallest ``k >= 2`` that is still unused.  The caller's mapping is not
    modified; a copy is consumed instead.

    Args:
        tokens: Iterable of tokens to look up.
        token2id: Mapping of (possibly suffixed) token keys to ids, in the
            format produced by ``get_token2id``.

    Returns:
        list of ids, one per token, in the order of ``tokens``.

    Raises:
        KeyError: If a token is unknown or occurs more often than the mapping
            has entries for it.  (Previously this case looped forever.)
    """
    remaining = token2id.copy()
    order = []
    for token in tokens:
        key = token
        if key not in remaining:
            k = _smallest_duplicate_suffix(token, remaining)
            if k is None:
                raise KeyError(f"token {token!r} has no unused entry left in token2id")
            key = f"{token}_{k}"
        order.append(remaining.pop(key))
    return order


def _strip_duplicate_suffix(key):
    """Remove a trailing ``_<k>`` (canonical integer ``k >= 2``) from ``key`` if present."""
    base, sep, tail = key.rpartition("_")
    if (
        sep
        and base
        and tail.isascii()
        and tail.isdigit()
        and str(int(tail)) == tail
        and int(tail) >= 2
    ):
        return base
    return key


def order2token(order, id2token):
    """Map a sequence of ids back to the plain words they stand for.

    Keys for repeated words carry a ``_<k>`` suffix (``k >= 2``); only that
    trailing suffix is removed.  Underscores elsewhere in a word are kept, so
    ``"ho_ho"`` stays ``"ho_ho"`` instead of being cut to ``"ho"``.

    Note: a word that itself ends in ``_<k>`` with ``k >= 2`` cannot be told
    apart from a suffixed duplicate and is still shortened.

    Args:
        order: Iterable of ids.
        id2token: Lookup from id to (possibly suffixed) token key.

    Returns:
        list of words, one per id, in the order of ``order``.
    """
    return [_strip_duplicate_suffix(id2token[idx]) for idx in order]


# Exercise the helpers on a hard-coded 30-word sentence; "cheer" appears twice,
# so the duplicate-suffix path of get_token2id is covered.
text = "sleigh of the magi yuletide cheer is unwrap gifts relax and eat cheer decorations carol sing chimney visit workshop grinch holiday holly jingle naughty nice nutcracker polar beard ornament stocking"
tokens = text.split()
token2id = get_token2id(text)
# Inverse lookup: id -> (possibly suffixed) token key.
id2token = {v: k for k, v in token2id.items()}
order = tokens2order(tokens, token2id)
# One key is created per word occurrence, so this equals the word count of `text`.
len(token2id)

30

In [4]:
def load_file(filename):
    """Read one result file and recover its score from the file name.

    The score is the part of the *file name* between the first and second
    underscore, with anything from ``.txt`` onwards dropped; for example a
    name such as ``id3_0197.25.txt`` parses to ``197.25``.  Only the final path
    component is inspected, so underscores in parent directories do not
    disturb the parse.

    Args:
        filename: ``str`` or ``Path`` of the file to read.

    Returns:
        ``(text, score)``: the stripped first line of the file, and the score
        as a float.

    Raises:
        IndexError: If the file name contains no underscore.
        ValueError: If the extracted part is not a valid float.
    """
    name = Path(filename).name
    score = float(name.split("_")[1].split(".txt")[0])
    with open(filename) as f:
        text = f.readline().strip()
    return text, score


# Select the id whose result files are analysed and where they are stored.
target_id = 3
# Named `output_dir` rather than `dir` so the builtin dir() keeps working in
# later cells.
output_dir = Path("./output/")
files = sorted(output_dir.glob(f"id{target_id}_0*.txt"))
if not files:
    # Fail with a clear message instead of an IndexError on files[0] below.
    raise FileNotFoundError(
        f"no result files matching id{target_id}_0*.txt in {output_dir}"
    )

# Files are sorted by name; the first one is used as the reference ("best")
# solution that every other file is compared against.
best_text, best_score = load_file(files[0])
best_tokens = best_text.split()
token2id = get_token2id(best_text)
id2token = {v: k for k, v in token2id.items()}
best_order = tokens2order(best_tokens, token2id)

In [5]:
# Compare every loaded solution with the reference ordering: rank correlation
# of the id sequences and the number of transpositions in the permutation.
texts = []
transpositions = []
for i, filename in enumerate(files):
    text, score = load_file(filename)
    texts.append(text)
    # split() rather than split(" ") so tokenisation matches get_token2id and
    # best_tokens; split(" ") turns doubled/trailing spaces into "" tokens
    # that have no entry in token2id.
    tokens = text.split()
    order = tokens2order(tokens, token2id)
    corr, pvalue = spearmanr(order, best_order)
    p = Permutation(order)
    transpositions.append(p.transpositions())
    num_swap = len(transpositions[-1])
    print(f"[id {i:>02}] corr={corr:.2f}, p={pvalue:.2f}, n_swaps={num_swap:>3}, score={score:.5f}, diff={score - best_score:.5f}")

[id 00] corr=1.00, p=0.00, n_swaps=  0, score=197.25009, diff=0.00000
[id 01] corr=1.00, p=0.00, n_swaps=  0, score=197.63808, diff=0.38799
[id 02] corr=0.99, p=0.00, n_swaps=  5, score=199.65241, diff=2.40232
[id 03] corr=0.99, p=0.00, n_swaps=  5, score=199.99238, diff=2.74229
[id 04] corr=0.95, p=0.00, n_swaps= 15, score=200.68898, diff=3.43889
[id 05] corr=0.48, p=0.01, n_swaps= 21, score=201.52058, diff=4.27049
[id 06] corr=0.92, p=0.00, n_swaps= 17, score=201.70189, diff=4.45180
[id 07] corr=0.92, p=0.00, n_swaps= 17, score=202.65870, diff=5.40861
[id 08] corr=0.52, p=0.00, n_swaps= 22, score=203.34542, diff=6.09533
[id 09] corr=0.56, p=0.00, n_swaps= 15, score=203.68400, diff=6.43391
[id 10] corr=0.56, p=0.00, n_swaps= 16, score=203.73315, diff=6.48306
[id 11] corr=0.66, p=0.00, n_swaps= 19, score=204.16485, diff=6.91476
[id 12] corr=0.93, p=0.00, n_swaps= 20, score=204.19542, diff=6.94533
[id 13] corr=0.49, p=0.01, n_swaps= 18, score=204.53513, diff=7.28504
[id 14] corr=0.49, p

In [30]:
# Text loaded from files[0], the same file that serves as the reference solution.
texts[0]

'sleigh of the magi yuletide cheer is unwrap gifts relax and eat cheer decorations carol sing chimney visit workshop grinch holiday holly jingle naughty nice nutcracker polar beard ornament stocking'

In [32]:
# Text loaded from the third file, shown for visual comparison with texts[0].
texts[2]

'sleigh of the magi yuletide cheer is unwrap gifts and eat cheer decorations sing carol relax chimney visit workshop grinch holiday holly jingle naughty nice nutcracker polar beard ornament stocking'

### Sanity check (動作確認)

In [7]:
def swap(x, i, j):
    """Exchange the elements at positions ``i`` and ``j`` of ``x`` in place.

    Args:
        x: Mutable sequence; it is modified directly.
        i: Index of the first element.
        j: Index of the second element.

    Returns:
        The same object ``x`` that was passed in, after the exchange.
    """
    held = x[i]
    x[i] = x[j]
    x[j] = held
    return x


# Build a known scramble: take a copy of the first ten ids (slicing copies, so
# best_order itself is untouched) and apply six swaps by hand.
tmp = best_order[:10]
print(tmp)
tmp = swap(tmp, 0, 1)
tmp = swap(tmp, 3, 1)
tmp = swap(tmp, 0, 6)
tmp = swap(tmp, 0, 7)
tmp = swap(tmp, 0, 1)
tmp = swap(tmp, 8, 9)
# Display the scrambled list; the following cells decompose and undo it.
tmp

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


[3, 7, 2, 0, 4, 5, 1, 6, 9, 8]

In [8]:
# Interpret the scrambled list as a permutation and list its decomposition
# into transpositions (pairs of indices) as computed by sympy.
p = Permutation(tmp)
p.transpositions()

[(0, 3), (1, 6), (1, 7), (8, 9)]

In [9]:
# Undo the scramble by replaying the transpositions from last to first as
# position swaps; the list should end up back in ascending order.
for k in reversed(p.transpositions()):
    tmp = swap(tmp, *k)
tmp

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [10]:
import os
# Expose only GPU index 1 to CUDA.  This is set before santa.metrics is
# imported, presumably because that import pulls in CUDA-backed libraries
# that read the variable at initialisation (TODO confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
# Add the parent directory to the import path; presumably that is where the
# `santa` package lives (verify against the repository layout).
sys.path.append("..")
from santa.metrics import PerplexityCalculator

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Build the scorer for the google/gemma-2-9b model.  It is kept as a global
# because get_trajectory calls scorer.get_perplexity directly.
scorer = PerplexityCalculator(model_path="google/gemma-2-9b")

Loading checkpoint shards: 100% 8/8 [00:08<00:00,  1.07s/it]


In [28]:
def get_trajectory(text, trans):
    """Score ``text`` and each intermediate sentence while undoing ``trans``.

    The transpositions are applied as position swaps on the token list, from
    the last pair to the first, and the sentence is re-scored after each swap
    with the global ``scorer``.

    Args:
        text: Sentence to start from; tokenised with ``str.split()``.
        trans: Sequence of ``(i, j)`` index pairs.

    Returns:
        list of scores: the score of ``text`` followed by one score per
        applied swap (``len(trans) + 1`` entries).
    """
    words = text.split()
    trajectory = [scorer.get_perplexity(" ".join(words))]
    for a, b in reversed(trans):
        words = swap(words, a, b)
        trajectory.append(scorer.get_perplexity(" ".join(words)))
    return trajectory


# Trace the scores for the file at index 10 of `files` while its
# transpositions are undone one swap at a time.
# NOTE(review): this rebinds the global `i` that the comparison loop also uses.
i = 10
score_trj = get_trajectory(texts[i], transpositions[i])

In [29]:
# First entry is the score of texts[i] as loaded; each later entry follows one
# more swap.  In the recorded output the final value (197.2500...) matches the
# reference score printed for id 00.
score_trj

[203.7331572188011,
 339.0267089547947,
 334.93140280232285,
 252.10140552708296,
 342.99475274376255,
 346.8848305954417,
 388.39830192813577,
 249.6388415856473,
 307.7621369079633,
 365.65996147162645,
 379.3507746057428,
 386.6730920759111,
 586.1533293042395,
 299.6744744120968,
 458.33667289434305,
 389.2296394660958,
 197.25009976693568]

In [20]:
# NOTE(review): execution count In [20] predates the cells above, and the
# recorded output is a different, longer word list than the other texts[0]
# cells show.  Presumably `texts` held the files of another target_id at the
# time; re-run from a fresh kernel to confirm.
texts[0]

'from the the the of of and and not and to in is you that it with as advent card angel bake beard believe bow candy carol candle cheer cheer chocolate chimney chimney cookie decorations doll dream drive eat eggnog elf family fireplace fireplace fruitcake game gifts give gingerbread greeting grinch have holiday holly hohoho hope jingle jump joy kaggle laugh magi merry milk mistletoe naughty nice night night nutcracker ornament ornament paper peace peppermint polar poinsettia puzzle reindeer relax scrooge season sing sleigh sleep snowglobe star stocking toy unwrap visit walk we wish wonder workshop workshop wrapping wreath yuletide'

In [21]:
# Text from the last file in sorted order.
# NOTE(review): recorded at In [21], out of sequence with the surrounding
# cells, so the output presumably reflects an earlier `texts` (confirm on re-run).
texts[-1]

'and and and as from have in is it not of of that the the the to we with you advent angel bake beard believe bow candy card carol candle cheer cheer chocolate chimney chimney cookie decorations doll dream drive eat eggnog elf family fireplace fireplace fruitcake game gifts give gingerbread greeting grinch holiday holly hohoho hope jingle jump joy kaggle laugh magi merry milk mistletoe naughty nice night night nutcracker ornament ornament paper peace peppermint polar poinsettia puzzle reindeer relax scrooge season sing sleigh sleep snowglobe star stocking toy unwrap visit walk wish wonder workshop workshop wrapping wreath yuletide'

In [41]:
# NOTE(review): recorded at In [41]; the output is yet another word list, so
# `texts` was presumably rebuilt for a different target_id before this ran.
texts[0]

'from and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season greeting card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'

In [39]:
# NOTE(review): `text_transition` is not assigned in any visible cell
# (get_trajectory returns scores only), so this fails on Restart & Run All
# unless it is defined elsewhere.  Execution count In [39] is also out of order.
text_transition[-1]

'from and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season greeting card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'

In [42]:
# Second-to-last entry; in the recorded outputs it differs from
# text_transition[-1] only by "from" and "greeting" trading places.
# NOTE(review): `text_transition` is not assigned in any visible cell.
text_transition[-2]

'greeting and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season from card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'