In [8]:
import pandas as pd
import numpy as np
import sys

sys.path.append("../")
from evaluation import PerplexityCalculator
from util import save_text
import os

# Local Gemma-2 9B checkpoint (kagglehub cache). Resolved against the user's home
# directory instead of a hard-coded /home/<user> prefix so the notebook is portable.
path_model = os.path.expanduser("~/.cache/kagglehub/models/google/gemma-2/transformers/gemma-2-9b/2")
# Competition sample submission; each row's "text" column holds one word set to reorder.
path_input = "../input/santa-2024/sample_submission.csv"
df_sample = pd.read_csv(path_input)

In [2]:
# Phrase that is scored and permuted below (the len() output further down shows 20 words).
# The commented-out line is an alternative, longer phrase kept for reference.
# text = "sleigh of holiday cheer unwrap gifts relax eat yuletide cheer sing carol the magi visit workshop grinch is naughty and nice decorations ornament chimney stocking nutcracker polar beard holly jingle"
text = "magi yuletide cheer grinch carol holiday holly jingle naughty nice nutcracker polar beard ornament stocking chimney sleigh workshop gifts decorations"

In [3]:
# Load the Gemma-2 perplexity scorer from the checkpoint configured above.
scorer = PerplexityCalculator(model_path=str(path_model))

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [5]:
# Baseline perplexity of the phrase (lower is better; the optimisers below minimise it).
scorer.get_perplexity(text)

302.15420510895507

In [6]:
# Token list that the neighbourhood moves and optimisers permute.
words = text.split()
len(words)

20

In [9]:
import copy
import random
import tqdm

In [10]:
def make_neighbor_1(words_input: list[str]) -> list[str]:
    """Move one randomly chosen word to a random position.

    Returns a new list; ``words_input`` itself is left untouched.
    """
    result = list(words_input)
    moved = result.pop(random.randint(0, len(result) - 1))
    result.insert(random.randint(0, len(result)), moved)
    return result


def make_neighbor_2(words_input: list[str]) -> list[str]:
    """Cut out a random contiguous segment and re-insert it at a random position.

    Returns ``None`` when the two sampled cut points coincide (empty segment);
    the caller is expected to retry. ``words_input`` is not modified.
    """
    n = len(words_input)
    cut_a, cut_b = random.randint(0, n), random.randint(0, n)
    if cut_a == cut_b:
        return None
    lo, hi = min(cut_a, cut_b), max(cut_a, cut_b)
    segment = words_input[lo:hi]
    remainder = words_input[:lo] + words_input[hi:]
    pos = random.randint(0, len(remainder))
    return remainder[:pos] + segment + remainder[pos:]


def make_neighbor_3(words_input: list[str]) -> list[str]:
    """Move a random contiguous segment to the back or to the front of the sentence.

    Cut points are drawn from ``1..len(words_input)``, so the part before the
    segment is never empty. Returns ``None`` when the two cut points coincide.
    """
    n = len(words_input)
    first = random.randint(1, n)
    second = random.randint(1, n)
    if first == second:
        return None
    start, stop = sorted((first, second))
    head = words_input[:start]
    segment = words_input[start:stop]
    tail = words_input[stop:]
    to_back = random.randint(1, 2) == 1
    return head + tail + segment if to_back else segment + head + tail


def make_neighbor_4(words_input: list[str]) -> list[str]:
    """Rotate the sentence left by a random offset in ``1..len(words_input) - 1``."""
    shift = random.randint(1, len(words_input) - 1)
    return [*words_input[shift:], *words_input[:shift]]


def make_neighbor_5(words_input: list[str]) -> list[str]:
    """Swap one randomly chosen pair of adjacent words; returns a new list."""
    pos = random.randint(0, len(words_input) - 2)
    swapped = list(words_input)
    swapped[pos:pos + 2] = [swapped[pos + 1], swapped[pos]]
    return swapped


def make_neighbor_6(words_input: list[str]) -> list[str]:
    """Reverse a random slice of the sentence.

    One end is drawn from ``0..len``; the other lies 2..len-2 positions further
    along (modulo len). After ordering the two ends, the slice between them is
    reversed. Requires ``len(words_input) >= 4``; otherwise ``random.randint``
    raises ``ValueError`` (empty range).
    """
    n = len(words_input)
    start = random.randint(0, n)
    stop = (start + random.randint(2, n - 2)) % n
    lo, hi = (stop, start) if start > stop else (start, stop)
    result = list(words_input)
    result[lo:hi] = reversed(result[lo:hi])
    return result


def make_neighbor_7(words_input: list[str]) -> list[str]:
    """Swap two distinct, randomly chosen words (not necessarily adjacent)."""
    n = len(words_input)
    first = random.randint(0, n - 1)
    second = (first + random.randint(1, n - 1)) % n
    out = list(words_input)
    out[first], out[second] = out[second], out[first]
    return out

# Relative weight of each neighbourhood operator; key k selects make_neighbor_k.
neighbor_prob = {
    1: 10.0,  # move one word
    2: 5.0,   # move a segment
    3: 5.0,   # segment to front/back
    4: 1.0,   # rotate
    5: 5.0,   # adjacent swap
    6: 1.0,   # reverse a slice
    7: 1.0,   # swap two words
}

# Normalise in place so the weights sum to 1 (np.random.choice expects probabilities).
prob_total = sum(neighbor_prob.values())
for key in neighbor_prob:
    neighbor_prob[key] /= prob_total

def make_neighbor(
    words_input: list[str], neighbor_prob: dict[int, float] = neighbor_prob
) -> tuple[list[str], int]:
    """Apply one randomly selected neighbourhood operator to ``words_input``.

    The operator id is drawn with ``np.random.choice`` using the probabilities in
    ``neighbor_prob`` (ids 1-7 map to ``make_neighbor_1`` ... ``make_neighbor_7``).
    Draws repeat until an operator returns a list that is not ``None`` and differs
    from the input.
    NOTE(review): this loops forever if no operator can change the input
    (e.g. a single word, or all words identical).

    Returns:
        ``(new_words, operator_id)``.

    Raises:
        ValueError: if a drawn id is not in 1-7.
    """
    operators = {
        1: make_neighbor_1,
        2: make_neighbor_2,
        3: make_neighbor_3,
        4: make_neighbor_4,
        5: make_neighbor_5,
        6: make_neighbor_6,
        7: make_neighbor_7,
    }
    op_ids = list(neighbor_prob.keys())
    op_probs = list(neighbor_prob.values())
    while True:
        coin = int(np.random.choice(op_ids, p=op_probs))
        if coin not in operators:
            raise ValueError("Invalid neighbor function coin")
        candidate = operators[coin](words_input)
        if candidate is not None and candidate != words_input:
            break
    assert sorted(words_input) == sorted(candidate)
    return candidate, coin

In [244]:
# Create an initial population of n_pop random orderings: `words` is shuffled in
# place each time and a copy of it is stored.
# NOTE(review): the hill-climbing cell below starts again from `pop = []`.
n_words = len(words)
n_pop = 100
pop = []
for _ in range(n_pop):
    random.shuffle(words)
    pop.append(words[:])

In [245]:
# Sanity check: in-place shuffling must not change the number of words.
len(words)

20

In [246]:
# Independent hill-climbing runs. Each run starts from a fresh in-place shuffle of
# `words`, then for 100 steps samples 16 neighbours of the current best ordering
# and keeps the lowest-perplexity candidate. The current best is part of the
# candidate pool, so a run's score never gets worse.
pop = []
for _ in range(n_pop):
    random.shuffle(words)
    words_best = copy.deepcopy(words)
    score_best = scorer.get_perplexity(" ".join(words_best))
    pbar = tqdm.tqdm(range(100))
    for _ in pbar:
        # Sample 16 neighbours of the current best ordering.
        candidates, scores = [], []
        for _ in range(16):
            neighbour = make_neighbor(words_best)[0]
            candidates.append(neighbour)
            scores.append(scorer.get_perplexity(" ".join(neighbour)))

        # Include the current best, then keep the first lowest-scoring candidate.
        candidates.append(words_best)
        scores.append(score_best)
        best_idx = min(range(len(scores)), key=scores.__getitem__)
        words_best, score_best = candidates[best_idx], scores[best_idx]
        pbar.set_postfix({"score": score_best})
    pop.append(words_best)

100%|██████████| 100/100 [01:50<00:00,  1.11s/it, score=455]  
100%|██████████| 100/100 [01:50<00:00,  1.10s/it, score=423]  
100%|██████████| 100/100 [01:50<00:00,  1.10s/it, score=372]  
100%|██████████| 100/100 [01:51<00:00,  1.11s/it, score=363]  
100%|██████████| 100/100 [01:51<00:00,  1.11s/it, score=366]  
100%|██████████| 100/100 [01:50<00:00,  1.10s/it, score=337]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score=378]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score=400]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score=445]  
100%|██████████| 100/100 [01:50<00:00,  1.11s/it, score=351]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score=362]  
100%|██████████| 100/100 [01:50<00:00,  1.10s/it, score=369]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score=397]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score=362]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score=389]  
100%|██████████| 100/100 [01:49<00:00,  1.10s/it, score

In [247]:
# Persist the hill-climbed population so it can be reloaded later without re-running the search.
pd.to_pickle(pop, "pop_0002_100samples.pkl")

In [28]:
# Show every member of the population as a space-separated sentence.
for member in pop:
    print(*member)

jingle grinch holiday eat sing relax unwrap decorations gifts cheer yuletide ornament carol holly stocking magi nutcracker chimney naughty nice sleigh polar beard workshop visit and the of cheer is
ornament holiday cheer cheer and eat relax unwrap decorations yuletide chimney stocking magi carol sing grinch the of is nutcracker naughty nice holly jingle sleigh visit workshop polar beard gifts
grinch of the eat gifts sing yuletide carol holiday cheer jingle relax unwrap stocking decorations ornament holly naughty and nice nutcracker magi visit chimney beard sleigh workshop polar is cheer
sleigh holiday cheer cheer yuletide holly the magi of workshop polar beard unwrap chimney visit nutcracker is naughty nice gifts and grinch eat jingle relax carol sing ornament decorations stocking
ornament yuletide stocking gifts of the magi carol grinch holiday cheer holly unwrap cheer and eat relax sing chimney sleigh naughty visit polar workshop jingle beard is nice nutcracker decorations
decoration

In [44]:
from util import get_perplexity_, load_score_memo, save_score_memo
import math

In [45]:
class GAEAX:
    """Genetic algorithm with an Edge Assembly Crossover (EAX)-style operator.

    A solution is a cyclic ordering of unique tokens that contains a ``"<bos>"``
    marker. It is turned into text by rotating the cycle to start right after
    ``"<bos>"``, dropping the marker, and mapping every token through
    ``words_ppls_map`` (presumably so uniqueness-suffixed duplicate words are
    written back as the real word -- verify against the caller).

    Replacement weighs score improvement against the loss of population edge
    entropy (diversity). A pheromone matrix over directed token pairs biases how
    sub-cycles are re-joined during crossover.
    """

    def __init__(
        self,
        words: list[str],
        words_ppls_map: dict[str, str],
        scorer,
        idx: int,
        population_size: int = 10,
        n_offspring: int = 16,
        generations: int = 50,
        initial_pop: list[list[str]] = None,
        mode: int = 1,
        decay_rate: float = 0.95,
    ):
        """Store the configuration and, if given, a deep copy of ``initial_pop``.

        Args:
            words: Reference token list; defines the token -> id mapping used by
                crossover and the pheromone matrix. Its first token is kept in
                place when a random population is generated.
            words_ppls_map: Maps each token to the word written into the scored text.
            scorer: Passed to ``get_perplexity_``; its ``get_perplexity`` is passed to ``save_text``.
            idx: Row index forwarded to ``save_text`` whenever an improvement is saved.
            population_size: Population size, used only when no ``initial_pop`` is given.
            n_offspring: Children generated per parent pair per generation.
            generations: Number of generations executed by ``run``.
            initial_pop: Optional starting population (deep-copied).
            mode: 1 applies one randomly chosen edge set per child; any other value
                applies each edge set independently with probability 0.5.
            decay_rate: Multiplicative pheromone decay applied by ``update_pheromone``.
        """
        self.idx = idx
        self.words = words
        self.idx_words = {w: i for i, w in enumerate(words)}
        self.n_words = len(words)
        self.pheromone = np.zeros((self.n_words, self.n_words))
        self.decay_rate = decay_rate
        self.words_ppls_map = words_ppls_map
        self.population_size = population_size
        self.n_offspring = n_offspring
        self.generations = generations
        self.calculator = scorer
        # Memo dicts handed to get_perplexity_ (presumably score caches keyed by text).
        self.score_memo = {}
        self.score_memo_with_error = {}
        self.pop = []
        if initial_pop is not None:
            self.set_population(copy.deepcopy(initial_pop))
        self.mode = mode

    def set_population(self, pop: list[list[str]]):
        """Replace the current population (no copy is made)."""
        self.pop = pop

    def _solution_text(self, solution: list[str]) -> str:
        """Rotate ``solution`` to start just after ``"<bos>"``, drop the marker, and map tokens to words."""
        idx_bos = solution.index("<bos>")
        ordered = solution[idx_bos + 1:] + solution[:idx_bos]
        return " ".join(self.words_ppls_map[w] for w in ordered)

    def _calc_perplexity(self, words: list[str]) -> float:
        """Score one solution via ``get_perplexity_`` using this instance's memo dicts."""
        return get_perplexity_(self.calculator, self.score_memo, self.score_memo_with_error, self._solution_text(words))

    def _random_population(self) -> list[list[str]]:
        """Return ``population_size`` random orderings of ``self.words`` with the first token kept in place."""
        population = []
        for _ in range(self.population_size):
            # Shuffle a real list: random.shuffle(words[1:]) would shuffle a throwaway copy.
            tail = self.words[1:]
            random.shuffle(tail)
            population.append(self.words[:1] + tail)
        return population

    def calculate_edge_entropy(self, population: list[list[str]]) -> float:
        """Shannon entropy (natural log) of the directed-edge distribution over the population.

        Every solution is treated as a cycle, so the edge from the last token back
        to the first is counted too.
        """
        edge_count: dict[tuple[str, str], int] = {}
        total_edges = 0
        for solution in population:
            for i in range(len(solution)):
                edge = (solution[i], solution[(i + 1) % len(solution)])
                edge_count[edge] = edge_count.get(edge, 0) + 1
                total_edges += 1

        entropy = 0.0
        for count in edge_count.values():
            p = count / total_edges
            entropy -= p * math.log(p)
        return entropy

    def edge_assembly_crossover(self, parent1_words: list[str], parent2_words: list[str]) -> list[str]:
        """Produce one child from two parents with an EAX-style operator.

        Tokens are mapped to integer ids. Alternating edge sequences are built by
        following parent-1 predecessor edges and parent-2 successor edges (this
        code's approximation of EAX AB-cycles). The parent-2 edges of the selected
        sequence(s) overwrite the corresponding parent-1 successor edges. The
        resulting sub-cycles are merged pairwise -- half the time at the
        pheromone-best join, otherwise at a random join -- until one cycle remains.
        The child is read out starting from token id 0.

        Returns:
            The child as a token list, or ``parent1_words`` itself when no edge
            sequence longer than two nodes exists (e.g. identical parents) or when
            the rebuilt successor map does not visit every token.

        Raises:
            AssertionError: if either parent is not a permutation of ``self.words``.
        """
        n = len(parent1_words)
        parent1 = [self.idx_words[w] for w in parent1_words]
        parent2 = [self.idx_words[w] for w in parent2_words]
        assert set(parent1) == set(range(n)), f"parent1 is not a permutation of the vocabulary: {parent1_words}"
        assert set(parent2) == set(range(n)), f"parent2 is not a permutation of the vocabulary: {parent2_words}"

        # Successor map of parent 1, its inverse (predecessor map), and successor map of parent 2.
        edges1 = {}
        edges1_inv = {}
        edges2 = {}
        for i in range(n):
            curr = parent1[i]
            next_word = parent1[(i + 1) % n]
            edges1[curr] = next_word
            edges1_inv[next_word] = curr
        for i in range(n):
            curr = parent2[i]
            next_word = parent2[(i + 1) % n]
            edges2[curr] = next_word

        # Build alternating sequences: from each start node, alternately step backwards
        # along a parent-1 edge and forwards along a parent-2 edge, stopping when the
        # next node has already been used on that side.
        esets = []
        usedA = set()
        usedB = set()
        for i in range(n):
            eset_curr = [i]
            curr = i
            usedA.add(curr)
            is_a = True
            while True:
                nxt = edges1_inv[curr] if is_a else edges2[curr]
                if (is_a and nxt in usedB) or (not is_a and nxt in usedA):
                    # Only sequences with more than two nodes contribute edges.
                    if len(eset_curr) > 2:
                        esets.append(eset_curr)
                    break
                (usedB if is_a else usedA).add(nxt)
                eset_curr.append(nxt)
                curr = nxt
                is_a = not is_a

        if len(esets) == 0:  # nothing to exchange (e.g. parent 1 == parent 2)
            return parent1_words

        # Choose which sequences to apply.
        if self.mode == 1:
            chosen = [esets[random.randint(0, len(esets) - 1)]]
        else:  # mode == 2, global EAX: keep each sequence with probability 0.5
            chosen = [eset for eset in esets if random.random() >= 0.5]

        # Odd positions start a parent-2 edge; setting that successor implicitly drops
        # the parent-1 edge it replaces.
        edge_add = {}
        for eset in chosen:
            for i in range(1, len(eset), 2):
                edge_add[eset[i]] = eset[(i + 1) % len(eset)]

        edges_offspring = edges1.copy()
        for k, v in edge_add.items():
            edges_offspring[k] = v

        while True:
            # Find the disconnected cycles of the successor map.
            cycles = []
            used = set()
            for start in range(n):
                if start in used:
                    continue
                cycle = []
                curr = start
                while curr not in used:
                    cycle.append(curr)
                    used.add(curr)
                    curr = edges_offspring[curr]
                if len(cycle) > 1:
                    cycles.append(cycle)

            if len(cycles) <= 1:
                break

            # Select two cycles and connect them.
            c1, c2 = random.sample(cycles, 2)

            # Choose the 2-opt style join: either maximise pheromone gain or pick at random.
            best_i = best_j = 0
            if random.random() < 0.5:
                max_pheromone = float('-inf')
                for i in range(len(c1)):
                    for j in range(len(c2)):
                        i_nxt = (i + 1) % len(c1)
                        j_nxt = (j + 1) % len(c2)
                        pheromone = - self.pheromone[c1[i]][c1[i_nxt]] - self.pheromone[c2[j]][c2[j_nxt]] \
                            + self.pheromone[c1[i]][c2[j_nxt]] + self.pheromone[c2[j]][c1[i_nxt]]
                        if pheromone > max_pheromone:
                            max_pheromone = pheromone
                            best_i = i
                            best_j = j
            else:
                best_i = random.randint(0, len(c1)-1)
                best_j = random.randint(0, len(c2)-1)
            # Reconnect the cycles by exchanging the successors of the two join points.
            edges_offspring[c1[best_i]] = c2[(best_j+1)%len(c2)]
            edges_offspring[c2[best_j]] = c1[(best_i+1)%len(c1)]

        curr = 0
        offspring = [curr]
        for _ in range(n - 1):
            curr = edges_offspring[curr]
            offspring.append(curr)

        if set(offspring) != set(range(n)):
            print("bug - offspring is broken. use parent1_words", offspring, parent1_words)
            return parent1_words
        return [self.words[i] for i in offspring]

    def update_pheromone(self, population: list[list[str]]):
        """Decay the pheromone matrix, then add ``1 / perplexity`` on every directed edge of each solution."""
        self.pheromone *= self.decay_rate
        for solution in population:
            score = self._calc_perplexity(solution)
            for i in range(len(solution)):
                idx1 = self.idx_words[solution[i]]
                idx2 = self.idx_words[solution[(i + 1) % len(solution)]]
                self.pheromone[idx1][idx2] += 1.0 / score

    def run(self):
        """Run the GA and return ``(best_solution, best_score)``.

        Each generation, every individual ``i`` except the last is crossed with a
        random later individual ``j``. ``n_offspring`` children are scored, and the
        child with the best improvement-vs-entropy-loss ratio replaces ``i``.
        Improvements are saved via ``save_text`` and printed.
        NOTE(review): ``update_pheromone`` is only called once, before the first
        generation, so pheromone (and ``decay_rate``) never changes during the run.
        """
        # Random initial population with the first token fixed (when none was supplied).
        if not self.pop:
            self.pop.extend(self._random_population())

        scores = [self._calc_perplexity(sol) for sol in self.pop]
        best_solution = min(zip(self.pop, scores), key=lambda x: x[1])
        self.update_pheromone(self.pop)

        pbar = tqdm.tqdm(range(self.generations), desc="GA-EAX")
        curr_gen_best = float("inf")
        for gen in pbar:
            for i in range(len(self.pop)-1):
                j = random.randint(i+1, len(self.pop)-1)
                parent1, parent2 = self.pop[i], self.pop[j]
                orig_scores = scores[i], scores[j]
                # Candidate 0 is an unchanged copy of parent i.
                offspring_list = [copy.deepcopy(parent1)]
                offspring_scores = [scores[i]]
                for _ in range(self.n_offspring):
                    child = self.edge_assembly_crossover(parent1, parent2)
                    offspring_list.append(child)
                    offspring_scores.append(self._calc_perplexity(child))

                score_parent = scores[i]
                entropy_parent = self.calculate_edge_entropy(self.pop)
                others = [p for k, p in enumerate(self.pop) if k != i]
                best_idx = -1
                best_score_diff = 0
                eps = 1e-6
                for jj, (score, child) in enumerate(zip(offspring_scores, offspring_list)):
                    score_diff = score - score_parent
                    if score_diff > 0:
                        continue  # worse than the parent
                    entropy = self.calculate_edge_entropy(others + [child])
                    entropy_diff = entropy - entropy_parent
                    if entropy_diff >= 0.0:
                        # No diversity lost: rank by raw improvement, scaled so it dominates.
                        score_diff = - score_diff / eps
                    else:
                        # Improvement per unit of entropy lost (both <= 0, ratio >= 0).
                        score_diff = score_diff / entropy_diff
                    if score_diff > best_score_diff:
                        best_score_diff = score_diff
                        best_idx = jj

                if best_idx != -1:
                    self.pop[i] = offspring_list[best_idx]
                    scores[i] = offspring_scores[best_idx]
                    if scores[i] < curr_gen_best:
                        curr_gen_best = scores[i]
                        save_text(self.calculator.get_perplexity, self.idx, self._solution_text(self.pop[i]), verbose=1)
                        print(f"Best score: {orig_scores[0]:.2f} & {orig_scores[1]:.2f} -> {curr_gen_best:.2f}")

            # Update the overall best (first minimum, as a stable sort would give).
            gen_best = min(zip(self.pop, scores), key=lambda x: x[1])
            if gen_best[1] < best_solution[1]:
                best_solution = gen_best
                # Save under this instance's row index, as readable text (not raw tokens).
                save_text(self.calculator.get_perplexity, self.idx, self._solution_text(best_solution[0]), verbose=1)
                print(f"Best score: {best_solution[1]:.2f}")
                print(f"Best solution: {best_solution[0]}")
            pbar.set_postfix({"score": best_solution[1]})
        return best_solution[0], best_solution[1]



In [161]:
# class GAEAX:
#     def __init__(self, words: list[str],
#                   scorer, population_size: int = 10,
#                     generations: int = 50,
#                       initial_pop: list[list[str]] = None):


In [170]:
# import glob
# pop = pd.read_pickle("pop_0002_100samples.pkl")
# files = glob.glob("./save/0003/*.txt")
# files = glob.glob("./save/0005/*.txt")
# files.sort()

In [171]:
# pop = []
# for i, file in enumerate(files[::-1]):
#     with open(file, "r") as f:
#         pop.append(f.read().split())    


In [102]:
# memo = pd.read_pickle("score_memo_with_error.pkl")
# Load a previously computed memo mapping text -> score from another workspace.
# Note: read_pickle can execute arbitrary code; only load files you trust.
memo = pd.read_pickle("../../code_nagiss/save/score_memo.pkl")
len(memo)

682080

In [103]:
# Target row of the sample submission; its sorted word list is used to find memo
# entries that are permutations of the same words.
idx = 4
text_sample = df_sample.iloc[idx]["text"]
words_set_sample = sorted(text_sample.split())
len(words_set_sample)

50

In [104]:
# Collect every memoised text that is a permutation of the target row's words.
# Loop variables use distinct names so the earlier `text` phrase is not clobbered
# (re-running the baseline cell would otherwise silently score the last memo key).
scores = []
texts = []
for memo_text, memo_score in tqdm.tqdm(memo.items(), total=len(memo)):
    if sorted(memo_text.split()) == words_set_sample:
        texts.append(memo_text)
        scores.append(memo_score)
len(texts), len(scores)

682080it [00:02, 255194.87it/s]


(178667, 178667)

In [105]:
# Matching memo entries, lowest (best) score first.
df = pd.DataFrame({"text": texts, "score": scores}).sort_values("score")

In [106]:
# Seed texts for the GA: the first 1,000 rows of `df` (sorted ascending by score above).
texts = df["text"].head(1000).tolist()

In [108]:
# Add one extra hand-written 50-word ordering to the seed texts. It must be a
# permutation of the target row's words (the token-set check further below verifies this).
texts.append("of and to in the as you that it we with from have not night season eggnog milk chocolate candy peppermint cookie fruitcake toy doll game puzzle snowglobe candle fireplace star angel wreath poinsettia greeting card wrapping paper bow wish dream believe wonder hope joy peace merry hohoho kaggle workshop")

In [109]:
# Tokenise each seed text; this replaces the earlier hill-climbed `pop`.
pop = [t.split() for t in texts]

In [110]:
# Population size for the GA = number of seed texts.
n_pop = len(pop)
print(n_pop)

1001


In [111]:
# GAEAX maps tokens to node ids, so tokens must be unique within a solution:
# prepend the "<bos>" marker and suffix repeated words with "1", "11", ...
# words_ppls_map sends each unique token back to the real word that is scored/saved
# (previously it mapped a suffixed token to itself, so e.g. "cheer1" reached the scorer).
pop_use = []
words_ppls_map = {}
for p in pop:
    p_use = ["<bos>"]
    for word in p:
        token = word
        while token in p_use:
            token += "1"
        p_use.append(token)
        words_ppls_map[token] = word
    pop_use.append(p_use)

In [112]:
# Should equal n_pop: one tokenised solution per seed text.
len(pop_use)

1001

In [113]:
# GA over the whole seeded population. mode=2 applies each alternating edge set
# with probability 0.5 per child (mode=1 would apply one random set).
# population_size only matters when no initial_pop is given.
ga_eax = GAEAX(pop_use[0], words_ppls_map, scorer, idx=idx, population_size=n_pop, n_offspring=32, generations=10000, initial_pop=pop_use, mode=2)

In [114]:
# for p in pop:
#     print(scorer.get_perplexity(" ".join(p)))


In [115]:
# Sanity check: every member must use exactly the token set of pop_use[0]
# (GAEAX crossover asserts this). Offending members are printed; no output means all match.
test = set(pop_use[0])
for p in pop_use:
    if set(p) != test:
        print(len(p), len(set(p)), p)


In [116]:
# Long-running: the recorded progress bar shows 15/10000 generations at ~211 s each.
ga_eax.run()

GA-EAX:   0%|          | 0/10000 [00:00<?, ?it/s]

score:77.5996
Best score: 77.90 & 80.38 -> 77.60
score:76.9957
Best score: 78.21 & 84.23 -> 77.00
score:68.2144
Best score: 80.06 & 68.21 -> 68.21


GA-EAX:   0%|          | 15/10000 [1:30:33<585:53:25, 211.24s/it, score=68.2]

In [91]:
# Load a previous submission for comparison.
df_best = pd.read_csv("./submission_248.723217.csv")

In [98]:
# Score every row of that submission in one call (a list of texts is passed here;
# presumably one score per row is returned -- the next cell averages them).
scores = scorer.get_perplexity(df_best["text"].to_list())

In [99]:
# Mean perplexity of that submission (recorded ~249.38 here vs. 248.723217 in the file name).
np.mean(scores)

np.float64(249.379098366195)

In [97]:
# Score one hand-written 30-word ordering.
# NOTE(review): the recorded output equals that of the other hand-scored ordering nearby
# even though the word orders differ -- worth double-checking.
scorer.get_perplexity("sleigh of holly yuletide cheer unwrap gifts eat holiday cheer relax sing carol the magi visit workshop grinch is naughty and nice decorations ornament chimney stocking nutcracker polar beard jingle")

198.9331323667161

In [96]:
# Score another hand-written 30-word ordering.
# NOTE(review): the recorded output equals that of the other hand-scored ordering nearby
# even though the word orders differ -- worth double-checking.
scorer.get_perplexity("sleigh of holiday cheer unwrap gifts relax eat yuletide cheer sing carol the magi visit workshop grinch is naughty and nice decorations ornament chimney stocking nutcracker polar beard holly jingle")

198.9331323667161

In [90]:
# Free unreferenced Python objects, then release cached CUDA memory held by PyTorch's allocator.
import gc
gc.collect()
import torch
torch.cuda.empty_cache()


In [None]:
def edge_assembly_crossover(parent1: list[str], parent2: list[str]) -> list[str]:
    """Build a child ordering by greedily following edges present in either parent.

    Despite the name this is not a full EAX. Starting from ``"<bos>"``, it
    repeatedly appends a random unused successor taken from either parent's
    (non-wrapping) adjacency. At a dead end it picks a random unused word of
    ``parent1``.

    Candidates are kept in a deterministic order (first-seen / parent order), so
    results are reproducible under ``random.seed``. Iterating a ``set`` of
    strings is not, because of string hash randomisation.

    Assumes ``parent1`` contains ``"<bos>"`` and is a permutation of ``parent2``
    (TODO confirm at call sites).
    """
    n = len(parent1)

    # Successors of each word in either parent; dicts act as insertion-ordered sets.
    successors: dict[str, dict[str, None]] = {}
    for p in (parent1, parent2):
        for w1, w2 in zip(p, p[1:]):
            successors.setdefault(w1, {})[w2] = None

    result = ["<bos>"]  # start token
    used = {"<bos>"}
    while len(result) < n:
        curr = result[-1]
        candidates = [w for w in successors.get(curr, {}) if w not in used]
        if not candidates:  # dead end: fall back to any unused word of parent1
            candidates = list(dict.fromkeys(w for w in parent1 if w not in used))
        next_word = random.choice(candidates)
        result.append(next_word)
        used.add(next_word)

    return result

def run_ga_eax(words: list[str], population_size: int = 10, generations: int = 50):
    """Simple generational GA: crossover followed by a short local search.

    Uses the module-level ``scorer``, ``edge_assembly_crossover`` and
    ``make_neighbor``. ``words[0]`` stays in place when building the random
    initial population (the shuffle comment expects it to be ``"<bos>"``).
    NOTE(review): the scored text is ``" ".join(solution)``, so it includes the
    literal ``"<bos>"`` token, and neighbour moves can relocate it -- confirm intended.

    Returns:
        ``(best_solution, best_score)``.
    """
    def score_of(solution: list[str]) -> float:
        return scorer.get_perplexity(" ".join(solution))

    # Initial population: shuffle everything after the first token.
    population = []
    scores = []
    for _ in range(population_size):
        # random.shuffle(words[1:]) would shuffle a throwaway copy and leave every member identical.
        tail = words[1:]
        random.shuffle(tail)
        solution = words[:1] + tail
        population.append(solution)
        scores.append(score_of(solution))

    best_solution = min(zip(population, scores), key=lambda x: x[1])

    for gen in tqdm.tqdm(range(generations)):
        new_population = []
        new_scores = []

        while len(new_population) < population_size:
            # Parent selection: two random individuals; the better one becomes parent1.
            parents = random.sample(list(zip(population, scores)), 2)
            parent1 = min(parents, key=lambda x: x[1])[0]
            parent2 = max(parents, key=lambda x: x[1])[0]

            # Crossover
            child = edge_assembly_crossover(parent1, parent2)

            # Local search: 5 neighbour moves, accepting only improvements.
            # The current child's score is cached instead of being recomputed each step.
            child_score = score_of(child)
            for _ in range(5):
                neighbor, _ = make_neighbor(child)
                neighbor_score = score_of(neighbor)
                if neighbor_score < child_score:
                    child, child_score = neighbor, neighbor_score

            new_population.append(child)
            new_scores.append(child_score)

            # Update the best solution found so far.
            if child_score < best_solution[1]:
                best_solution = (child, child_score)
                print(f"Generation {gen}: Best Score = {best_solution[1]:.2f}, solution = {child}")

        # Generational replacement.
        population = new_population
        scores = new_scores

        if gen % 5 == 0:  # progress report every 5 generations
            print(f"Generation {gen}: Best Score = {best_solution[1]:.2f}")

    return best_solution

In [None]:
# Pheromone matrix over the sentinel-augmented vocabulary: entry [a, b] accumulates
# 1/perplexity for every solution in which word a is immediately followed by word b
# ("<bos>" precedes the first word, "<eos>" follows the last).
n_words = len(words)
words_with_sentinel = ["<bos>"] + words + ["<eos>"]
n_words_with_sentinel = len(words_with_sentinel)
# Sized by the sentinel vocabulary: indices come from words_with_sentinel, so an
# (n_words x n_words) matrix raised IndexError on the last word / "<eos>".
pheromone = np.zeros((n_words_with_sentinel, n_words_with_sentinel))

# Index of each word's first occurrence (same result as list.index, without O(n) scans).
word_index = {}
for k, w in enumerate(words_with_sentinel):
    word_index.setdefault(w, k)

# Score each solution in the population.
pop_scores = [scorer.get_perplexity(" ".join(solution)) for solution in pop]

# Add pheromone inversely proportional to perplexity on each consecutive pair.
for solution, score in zip(pop, pop_scores):
    solution_with_sentinel = ["<bos>"] + solution + ["<eos>"]
    for word1, word2 in zip(solution_with_sentinel, solution_with_sentinel[1:]):
        pheromone[word_index[word1], word_index[word2]] += 1.0 / score

# Normalise to a distribution (skipped for an empty population to avoid 0/0 -> NaN).
pheromone_total = np.sum(pheromone)
if pheromone_total > 0:
    pheromone = pheromone / pheromone_total


In [12]:
# Prepend the "<bos>" marker used by the GA code; guarded so re-running this cell
# does not add a second marker.
if words[:1] != ["<bos>"]:
    words = ["<bos>"] + words