In [1]:
import re
from pathlib import Path

from scipy.stats import spearmanr
from sympy.combinatorics import Permutation

In [2]:
# Which puzzle id to analyse; selects the "id<target_id>_<score>.txt" files.
target_id = 4
# NOTE: `dir` shadows the built-in of the same name. The name is kept so that
# any other cell reading it keeps working.
dir = Path("./output/")
# Order the candidates by the numeric score embedded in the file name, lowest
# first. Plain path sorting is lexicographic, so e.g. "id4_100.1.txt" would be
# placed before "id4_70.2.txt" and the first entry would no longer be the
# lowest-scoring file that later cells use as the reference. The file name is
# the tie-breaker so the order stays deterministic.
files = sorted(
    dir.glob(f"id{target_id}_*.txt"),
    key=lambda path: (float(path.stem.split("_")[1]), path.name),
)
len(files)

41

In [21]:
# Read every candidate file, express its token sequence as a permutation of
# the first file's tokens, and print how far it is from that reference.
best_order = None  # id sequence of the reference (first) file
best_score = None  # score parsed from the reference file's name
token2id = {}  # token -> position of that token in the reference text
texts = []  # first line of each file, in `files` order
transpositions = []  # per file: transposition list of its `order` permutation
for i, filename in enumerate(files):
    # File names look like "id<target>_<score>.txt". Parse the score from the
    # stem (last path component without its suffix) so that an underscore in
    # a directory name cannot shift the split index.
    score = float(Path(filename).stem.split("_")[1])
    with open(filename) as f:
        text = f.readline().strip()
    texts.append(text)
    tokens = text.split(" ")
    if not token2id:
        # The vocabulary is built once, from the first file only.
        # NOTE(review): repeated tokens would collapse to one id here and make
        # Permutation(order) below reject the list; confirm tokens are unique.
        for token in tokens:
            token2id[token] = len(token2id)
    order = [token2id[token] for token in tokens]
    if best_order is None:
        best_order = order
        best_score = score
    # Rank correlation between this ordering and the reference ordering.
    corr, pvalue = spearmanr(order, best_order)
    p = Permutation(order)
    transpositions.append(p.transpositions())
    num_swap = len(transpositions[-1])
    print(f"[id {i:>02}] {corr:.3f}, {pvalue:.3f}, {num_swap:>3}, {score:.5f}, {score - best_score:.5f}")

[id 00] 1.000, 0.000,   0, 70.20242, 0.00000
[id 01] 0.975, 0.000,  10, 71.22486, 1.02244
[id 02] 0.915, 0.000,  35, 71.96499, 1.76257
[id 03] 0.915, 0.000,  35, 72.00041, 1.79799
[id 04] 0.915, 0.000,  37, 72.26725, 2.06483
[id 05] 0.915, 0.000,  34, 72.34945, 2.14703
[id 06] 0.914, 0.000,  34, 72.37530, 2.17288
[id 07] 0.915, 0.000,  35, 72.41000, 2.20758
[id 08] 0.914, 0.000,  36, 72.53948, 2.33706
[id 09] 0.914, 0.000,  36, 72.64883, 2.44641
[id 10] 0.914, 0.000,  35, 73.94579, 3.74337
[id 11] 0.878, 0.000,  38, 74.50227, 4.29985
[id 12] 0.879, 0.000,  36, 74.82051, 4.61809
[id 13] 0.898, 0.000,  38, 76.52955, 6.32713
[id 14] 0.933, 0.000,  35, 79.45131, 9.24889
[id 15] 0.928, 0.000,  38, 79.75075, 9.54833
[id 16] 0.927, 0.000,  38, 80.06289, 9.86047
[id 17] 0.932, 0.000,  35, 80.37625, 10.17383
[id 18] 0.927, 0.000,  38, 80.62890, 10.42648
[id 19] 0.921, 0.000,  36, 81.64199, 11.43957
[id 20] 0.920, 0.000,  36, 82.28231, 12.07989
[id 21] 0.900, 0.000,  36, 82.92766, 12.72524
[id 2

In [22]:
# Show the token sequence read from the first file in `files`.
texts[0]

'from and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season greeting card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'

In [23]:
# Show the token sequence read from the second file in `files`.
texts[1]

'from and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season greeting card wrapping paper wreath bow poinsettia star angel snowglobe candle fireplace toy doll game puzzle cookie wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'

### 動作確認

In [24]:
def swap(x, i, j):
    """Exchange the items at positions ``i`` and ``j`` of ``x`` in place.

    Args:
        x: Mutable, index-assignable sequence (a list in this notebook).
        i: Index of the first item.
        j: Index of the second item.

    Returns:
        The same object ``x`` (not a copy), now with the two items exchanged.
    """
    held = x[i]
    x[i] = x[j]
    x[j] = held
    return x


# Scramble a copy of the first ten reference ids with a fixed series of swaps.
# The slice makes a new list, so `best_order` itself is not modified.
tmp = best_order[:10]
print(tmp)
for a, b in ((0, 1), (3, 1), (0, 6), (0, 7), (0, 1), (8, 9)):
    tmp = swap(tmp, a, b)
tmp

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


[3, 7, 2, 0, 4, 5, 1, 6, 9, 8]

In [25]:
# Wrap the scrambled list in a sympy Permutation.
p = Permutation(tmp)
# Display the permutation as a list of index pairs (transpositions).
p.transpositions()

[(0, 3), (1, 6), (1, 7), (8, 9)]

In [26]:
# Apply the transpositions as position swaps, last pair first, and show the
# resulting list.
for a, b in reversed(p.transpositions()):
    tmp = swap(tmp, a, b)
tmp

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [9]:
import os
# Restrict the process to GPU 0. This is set before santa.metrics is imported,
# presumably so it is in place before any CUDA initialisation (TODO confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
# Make the parent directory importable; assumes the `santa` package lives
# there (verify against the repository layout).
sys.path.append("..")
from santa.metrics import PerplexityCalculator

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Build the scorer for the "google/gemma-2-9b" model path; later cells call
# scorer.get_perplexity(text) on space-joined token sequences.
scorer = PerplexityCalculator(model_path="google/gemma-2-9b")

Loading checkpoint shards: 100% 8/8 [00:08<00:00,  1.01s/it]


In [33]:
# For every loaded ordering, apply its transpositions one swap at a time
# (last transposition first), scoring the sequence before the first swap and
# after each swap, then show the last three scores of that trace.
for index, text in enumerate(texts):
    words = text.split(" ")
    score_transition = [scorer.get_perplexity(" ".join(words))]
    for i, j in reversed(transpositions[index]):
        words = swap(words, i, j)
        score_transition.append(scorer.get_perplexity(" ".join(words)))
    print(score_transition[-3:])

[70.20242311605584]
[86.62284563520333, 77.95840558065296, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[124.68123553920599, 100.9943546749491, 70.20242311605584]
[107.22946850897647, 122.95484439026072, 70.20242311605584]
[107.22946850897647, 122.95484439026072, 70.20242311605584]
[107.22946850897647, 122.95484439026072, 70.20242311605584]
[86.68048526687859, 106.30407660829488, 70.20242311605584]
[107.83847422924161, 122.95484439026072, 70.20242311605584]
[107.83847422924161, 122.95484439026072, 70.20242311605584]
[117.00085963090497, 122.95484439026072, 70.20242311605584]
[117.00085963090497, 122.954844

In [36]:
# Show the token sequence read from the first file in `files`.
texts[0]

'from and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season greeting card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'

In [35]:
# Show the token sequence read from the last file in `files`.
texts[-1]

'peppermint candy chocolate eggnog fruitcake poinsettia wreath star angel snowglobe we kaggle the of and to in that it with as from have not you workshop greeting card wrapping paper bow toy doll puzzle game night candle fireplace cookie milk wish dream hope believe wonder joy peace season merry hohoho'

In [38]:
# Trace a single ordering in detail: starting from texts[index], apply its
# transpositions from last to first, recording and printing the score and
# keeping the text after every step. Entry 0 is the untouched starting text.
index = 40
s = texts[index].split(" ")
score_transition = []
text_transition = []
for pair in [None, *reversed(transpositions[index])]:
    if pair is not None:
        s = swap(s, *pair)
    sentence = " ".join(s)
    score = scorer.get_perplexity(sentence)
    score_transition.append(score)
    text_transition.append(sentence)
    print(score)

99.16338626091706
100.28532287020514
152.29643293929283
148.30513431132994
149.03346747099104
168.04099600611818
202.49210404218502
206.16704003647308
214.35986844727924
331.4055391392646
563.4468831796185
591.0935052922482
583.0729801382863
724.6035624055621
722.6416085269523
723.3914639620258
807.190841887026
761.1021136349358
871.0813198779329
956.5077386081596
990.6220435799005
953.9078685096837
923.4255608179376
1016.2443823212827
1039.212001173891
1088.7369401962362
1015.1696611418253
932.9216388542737
908.1351952417449
748.9047550329993
863.2682682288158
675.0068399609565
537.2360187457363
483.5043511739331
460.1834321202282
466.6301153261619
447.7325518287056
486.95042773976354
389.445736848346
350.3374147830001
220.98706951550406
268.0654979884322
219.19823465601223
185.34571390985008
184.65824882301823
70.20242311605584


In [41]:
# Show the token sequence read from the first file in `files`, for comparison
# with the end of the swap trace.
texts[0]

'from and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season greeting card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'

In [39]:
# Show the final entry of the trace, i.e. the text after every swap was applied.
text_transition[-1]

'from and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season greeting card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'

In [42]:
# Show the second-to-last entry of the trace, i.e. the text one swap before the final state.
text_transition[-2]

'greeting and of to the as in that it we with not you have milk chocolate candy fruitcake eggnog peppermint season from card wrapping paper bow toy doll game puzzle cookie snowglobe fireplace candle wreath poinsettia angel star wish dream night wonder believe hope joy peace merry hohoho kaggle workshop'