In [1]:
import os
from collections import Counter

# Set before importing the santa package (NOTE(review): presumably it loads
# torch/CUDA on import — confirm) so only GPU 0 is visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
from tqdm.auto import tqdm

from santa.operator import *
from santa.metrics import PerplexityCalculator

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Perplexity scorer backed by google/gemma-2-9b; its tokenizer is reused below.
scorer = PerplexityCalculator(model_path="google/gemma-2-9b")
tokenizer = scorer.tokenizer

# Reverse vocabulary lookup: token id -> token string.
id2word = dict((token_id, piece) for piece, token_id in tokenizer.vocab.items())

Loading checkpoint shards: 100% 8/8 [00:09<00:00,  1.13s/it]


In [3]:
# The puzzle's 100 words in their given order; `tokens` is reused by later cells.
text = "advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle"
tokens = text.split()
assert len(tokens) == 100
# Baseline: perplexity of the words in their given order.
scorer.get_perplexity(text)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


355.0732086869962

In [4]:
# The puzzle's 100 words in their given order; candidates must be a
# rearrangement of exactly these words.
REFERENCE_TEXT = "advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh chimney workshop stocking ornament holly jingle beard naughty nice sing cheer and of the is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy workshop wonder believe dream hope peace joy merry season greeting card wrapping paper bow fireplace night cookie milk star wish wreath angel the to of and in that have it not with as you from we kaggle"


def check_validation(new_text, reference_text=None):
    """Assert that ``new_text`` is a valid rearrangement of the reference words.

    A candidate is valid when it contains exactly the reference words, each
    the same number of times, separated the same way (equal total length).

    Args:
        new_text: Whitespace-separated candidate ordering.
        reference_text: Whitespace-separated reference text. Defaults to
            ``REFERENCE_TEXT`` when ``None``.

    Raises:
        AssertionError: if the word count, the character length, or the
            word multiset differs from the reference.
    """
    reference = REFERENCE_TEXT if reference_text is None else reference_text
    tokens = reference.split()
    new_tokens = new_text.split()
    assert len(tokens) == len(new_tokens), f"{len(new_tokens)}"
    # Together with the multiset check below, equal character length rules out
    # leading/trailing or doubled whitespace (given a single-spaced reference).
    assert len(reference) == len(new_text), f"{len(reference)}, {len(new_text)}"
    # Compare multisets, not sets: a set comparison lets e.g. one "and" be
    # traded for an extra "the" (same word count, same length, same set).
    missing = Counter(tokens) - Counter(new_tokens)
    extra = Counter(new_tokens) - Counter(tokens)
    assert not missing and not extra, f"missing={dict(missing)}, extra={dict(extra)}"

## tokenizerのトークン数 → 先頭トークンid順に並び替え

In [5]:
# (word, ids) pairs: the tokenizer ids of " " + word with the first id dropped
# (NOTE(review): presumably the BOS id the tokenizer prepends — confirm).
token2id = [
    (token, tokenizer(" " + token)["input_ids"][1:])
    for token in tokens
]

In [6]:
# Fewest tokenizer pieces first; ties broken by the id of the first piece.
ordered_pairs = sorted(token2id, key=lambda pair: (len(pair[1]), pair[1][0]))
init_solution = " ".join(word for word, _ids in ordered_pairs)
check_validation(init_solution)
init_solution

'the the the in of of to and and and is it that with as you from not we have star game family give sing night night visit season card hope paper believe nice wonder walk drive wish sleep dream peace bow doll eat milk jump laugh relax joy advent holiday toy chocolate polar gifts cookie workshop workshop angel cheer cheer puzzle candy candle ornament ornament greeting beard decorations fireplace fireplace bake merry elf wrapping wreath chimney chimney stocking naughty reindeer holly gingerbread carol peppermint sleigh magi jingle unwrap hohoho grinch mistletoe snowglobe fruitcake nutcracker kaggle yuletide scrooge eggnog poinsettia'

In [7]:
# Score the token-id-ordered candidate with the gemma-2-9b perplexity scorer.
scorer.get_perplexity(init_solution)

329.15156408889504

## 出現頻度 → アルファベット順

In [8]:
# Keep duplicate words adjacent: most frequent words first, ties (including
# all the words that appear once) in alphabetical order.
word_counts = Counter(tokens)
# (word, count) pairs in output order.
counter = sorted(word_counts.items(), key=lambda item: (-item[1], item[0]))
new_tokens = [word for word, count in counter for _ in range(count)]
init_solution = " ".join(new_tokens)
check_validation(init_solution)
init_solution

'and and and the the the cheer cheer chimney chimney fireplace fireplace night night of of ornament ornament workshop workshop advent angel as bake beard believe bow candle candy card carol chocolate cookie decorations doll dream drive eat eggnog elf family from fruitcake game gifts gingerbread give greeting grinch have hohoho holiday holly hope in is it jingle joy jump kaggle laugh magi merry milk mistletoe naughty nice not nutcracker paper peace peppermint poinsettia polar puzzle reindeer relax scrooge season sing sleep sleigh snowglobe star stocking that to toy unwrap visit walk we wish with wonder wrapping wreath you yuletide'

In [9]:
# Score the frequency-then-alphabetical candidate.
scorer.get_perplexity(init_solution)

76.21396633545747

## stopwords + seq1 + seq2
- seq1: ストップワード以外の単語を頭文字ごとに分け、各グループを長さ順に並べたときの前半分
- seq2: 同じく各グループの後半分（単語数が奇数のときは余りの1語がseq2に入る）

In [10]:
# Function words that are pulled out into their own leading block.
STOP_WORDS = [
    'we', 'that', 'as', 'it', 'with',
    'of', 'in', 'is', 'not', 'you',
    'from', 'and', 'to', 'the',
]

# Bucket the alphabetically sorted words: stop words in one list, every other
# word under its initial letter (letter buckets are created in a-z order).
d = {"stop_words": []}
for word in sorted(tokens):
    bucket_key = "stop_words" if word in STOP_WORDS else word[0]
    d.setdefault(bucket_key, []).append(word)
stop_words = d.pop("stop_words")

# Within each letter, order by length (stable sort, so equal lengths stay
# alphabetical); the first half goes to seq1, the remainder to seq2.
seq1, seq2 = [], []
for bucket in d.values():
    by_length = sorted(bucket, key=len)
    cut = len(by_length) // 2
    seq1.extend(by_length[:cut])
    seq2.extend(by_length[cut:])

init_solution = " ".join(stop_words + seq1 + seq2)
check_validation(init_solution)

In [11]:
# Display the stop words + seq1 + seq2 candidate.
init_solution

'and and and as from in is it not of of that the the the to we with you angel bow bake card candy carol cheer cheer doll dream eat family fireplace game give gifts have hope joy magi milk nice night ornament paper peace polar relax sing star sleep season walk wish wonder advent beard believe candle cookie chimney chimney chocolate drive decorations elf eggnog fireplace fruitcake grinch greeting gingerbread holly hohoho holiday jump jingle kaggle laugh merry mistletoe night naughty nutcracker ornament puzzle peppermint poinsettia reindeer sleigh scrooge stocking snowglobe toy unwrap visit wreath workshop workshop wrapping yuletide'

In [12]:
# Score the stop words + seq1 + seq2 candidate.
scorer.get_perplexity(init_solution)

101.00528710107709

## 局所解に人手を加えたもの
- ストップワードを最後にする（注: 下のセルでは実際にはストップワードは3ブロック中の2番目＝真ん中に置かれており、後述の「真ん中に置く」版と同一の解になっている）
- 複数回出てくる単語は同じ場所にまとめる

In [15]:
# Hand-tuned ordering in three blocks; the stop-word block is the middle one.
# Words that occur more than once are kept next to each other.
manual_blocks = (
    "bake cheer cheer drive dream eat family game give grinch holiday hope jump laugh naughty nice night night peace puzzle relax scrooge season sing sleep toy unwrap visit walk wish wonder workshop workshop yuletide",
    "and and and the the the of of from to is as in that it we with not you have",
    "advent angel beard believe bow candy candle carol chimney chimney chocolate cookie decorations doll eggnog elf fireplace fireplace fruitcake gingerbread gifts greeting card holly hohoho jingle joy kaggle magi merry milk mistletoe nutcracker ornament ornament peppermint polar poinsettia reindeer sleigh snowglobe star stocking wreath wrapping paper",
)
init_solution = " ".join(manual_blocks)
check_validation(init_solution)
init_solution

'bake cheer cheer drive dream eat family game give grinch holiday hope jump laugh naughty nice night night peace puzzle relax scrooge season sing sleep toy unwrap visit walk wish wonder workshop workshop yuletide and and and the the the of of from to is as in that it we with not you have advent angel beard believe bow candy candle carol chimney chimney chocolate cookie decorations doll eggnog elf fireplace fireplace fruitcake gingerbread gifts greeting card holly hohoho jingle joy kaggle magi merry milk mistletoe nutcracker ornament ornament peppermint polar poinsettia reindeer sleigh snowglobe star stocking wreath wrapping paper'

In [16]:
# Score the hand-tuned candidate.
scorer.get_perplexity(init_solution)

45.83804799995008

## 品詞でグルーピング（その他 → 動詞 → 名詞）＆グループ内はアルファベット順にソート

In [17]:
# spaCy's small English pipeline, used below only for part-of-speech tags.
import spacy

nlp = spacy.load("en_core_web_sm")

In [18]:
# Coarse POS group used as the primary sort key:
# 1 = VERB, 2 = NOUN/PROPN, 0 = everything else.
POS_GROUP = {"VERB": 1, "NOUN": 2, "PROPN": 2}

pos = []
for token in tokens:
    # Each word is tagged in isolation, so tags are noisy (e.g. "elf" -> PRON).
    # Take only the first spaCy token so every word yields exactly one entry;
    # iterating over all spaCy tokens would duplicate a word that spaCy splits,
    # producing an invalid solution.
    tag = nlp(token)[0].pos_
    pos.append((token, POS_GROUP.get(tag, 0), tag))
# Sort by group, then alphabetically within each group.
pos = sorted(pos, key=lambda x: (x[1], x[0]))

In [19]:
# Inspect the (word, POS group, spaCy tag) triples in sorted order.
pos

[('and', 0, 'CCONJ'),
 ('and', 0, 'CCONJ'),
 ('and', 0, 'CCONJ'),
 ('as', 0, 'ADP'),
 ('elf', 0, 'PRON'),
 ('from', 0, 'ADP'),
 ('holly', 0, 'ADV'),
 ('in', 0, 'ADP'),
 ('is', 0, 'AUX'),
 ('it', 0, 'PRON'),
 ('nice', 0, 'ADJ'),
 ('not', 0, 'PART'),
 ('of', 0, 'ADP'),
 ('of', 0, 'ADP'),
 ('ornament', 0, 'ADJ'),
 ('ornament', 0, 'ADJ'),
 ('poinsettia', 0, 'ADV'),
 ('polar', 0, 'ADJ'),
 ('that', 0, 'SCONJ'),
 ('the', 0, 'PRON'),
 ('the', 0, 'PRON'),
 ('the', 0, 'PRON'),
 ('to', 0, 'PART'),
 ('we', 0, 'PRON'),
 ('with', 0, 'ADP'),
 ('you', 0, 'PRON'),
 ('bake', 1, 'VERB'),
 ('believe', 1, 'VERB'),
 ('bow', 1, 'VERB'),
 ('drive', 1, 'VERB'),
 ('eat', 1, 'VERB'),
 ('fruitcake', 1, 'VERB'),
 ('give', 1, 'VERB'),
 ('greeting', 1, 'VERB'),
 ('have', 1, 'VERB'),
 ('kaggle', 1, 'VERB'),
 ('laugh', 1, 'VERB'),
 ('merry', 1, 'VERB'),
 ('naughty', 1, 'VERB'),
 ('peppermint', 1, 'VERB'),
 ('relax', 1, 'VERB'),
 ('sing', 1, 'VERB'),
 ('visit', 1, 'VERB'),
 ('walk', 1, 'VERB'),
 ('wish', 1, 'VERB'),
 (

In [21]:
# Keep only the word from each (word, group, tag) entry, in sorted order.
init_solution = " ".join([word for word, *_rest in pos])
check_validation(init_solution)
init_solution

'and and and as elf from holly in is it nice not of of ornament ornament poinsettia polar that the the the to we with you bake believe bow drive eat fruitcake give greeting have kaggle laugh merry naughty peppermint relax sing visit walk wish wrapping yuletide advent angel beard candle candy card carol cheer cheer chimney chimney chocolate cookie decorations doll dream eggnog family fireplace fireplace game gifts gingerbread grinch hohoho holiday hope jingle joy jump magi milk mistletoe night night nutcracker paper peace puzzle reindeer scrooge season sleep sleigh snowglobe star stocking toy unwrap wonder workshop workshop wreath'

In [22]:
# Score the POS-grouped candidate.
scorer.get_perplexity(init_solution)

184.02509883139066

## 局所解に人手を加えたもの
- ストップワードを真ん中に置く
- 複数回出てくる単語は同じ場所にまとめる

In [23]:
# NOTE(review): this ordering is byte-identical to the earlier hand-tuned cell
# (same string, same perplexity), so it does not yet try a different
# stop-word placement.
init_solution = (
    "bake cheer cheer drive dream eat family game give grinch holiday hope "
    "jump laugh naughty nice night night peace puzzle relax scrooge season "
    "sing sleep toy unwrap visit walk wish wonder workshop workshop yuletide "
    "and and and the the the of of from to is as in that it we with not you have "
    "advent angel beard believe bow candy candle carol chimney chimney "
    "chocolate cookie decorations doll eggnog elf fireplace fireplace "
    "fruitcake gingerbread gifts greeting card holly hohoho jingle joy kaggle "
    "magi merry milk mistletoe nutcracker ornament ornament peppermint polar "
    "poinsettia reindeer sleigh snowglobe star stocking wreath wrapping paper"
)
check_validation(init_solution)
init_solution

'bake cheer cheer drive dream eat family game give grinch holiday hope jump laugh naughty nice night night peace puzzle relax scrooge season sing sleep toy unwrap visit walk wish wonder workshop workshop yuletide and and and the the the of of from to is as in that it we with not you have advent angel beard believe bow candy candle carol chimney chimney chocolate cookie decorations doll eggnog elf fireplace fireplace fruitcake gingerbread gifts greeting card holly hohoho jingle joy kaggle magi merry milk mistletoe nutcracker ornament ornament peppermint polar poinsettia reindeer sleigh snowglobe star stocking wreath wrapping paper'

In [24]:
# Score the "stop words in the middle" hand-tuned candidate.
scorer.get_perplexity(init_solution)

45.83804799995008