In [4]:
import random
from abc import ABCMeta, abstractmethod

import torch
import numpy as np


class Operator(metaclass=ABCMeta):
    """Abstract base class for token-sequence perturbation operators.

    Instances are callable: ``op(tokens)`` normalises the input and then
    delegates to :meth:`apply`, which subclasses must implement.
    """

    def __call__(self, tokens):
        """Normalise ``tokens`` and run :meth:`apply` on the result.

        Args:
            tokens: A NumPy array or torch tensor (converted with
                ``.tolist()``), a list (shallow-copied), or any other object,
                which is forwarded unchanged.

        Returns:
            Whatever :meth:`apply` returns.
        """
        if isinstance(tokens, (np.ndarray, torch.Tensor)):
            tokens = tokens.tolist()
        elif isinstance(tokens, list):
            # Hand `apply` a private copy so that implementations which edit
            # the list in place cannot alter the caller's data.
            tokens = tokens.copy()
        return self.apply(tokens)

    @abstractmethod
    def apply(self, tokens):
        """Return the transformed tokens; must be overridden by subclasses."""
        raise NotImplementedError


class OnePointShuffle(Operator):
    """Rotate the sequence around one randomly chosen cut point."""

    def apply(self, tokens):
        """Return ``tokens[cut:] + tokens[:cut]`` for a random ``cut``.

        Args:
            tokens: A sliceable sequence of tokens.

        Returns:
            A new sequence with the part before the cut moved to the end.
            Sequences with fewer than two tokens have no valid cut point and
            are returned as an unchanged copy.
        """
        if len(tokens) < 2:
            # There is no cut in [1, len(tokens)) to choose from.
            return tokens[:]
        # The cut is never 0, so the result always differs in position from
        # the input (the rotation is non-trivial).
        cut = random.randrange(1, len(tokens))
        return tokens[cut:] + tokens[:cut]


class PairPointShuffle(Operator):
    """Swap the tokens at two distinct, randomly chosen positions."""

    def apply(self, tokens):
        """Return a copy of ``tokens`` with two random positions exchanged.

        Args:
            tokens: An iterable of tokens with a defined length.

        Returns:
            A new list; the input is left untouched. Inputs with fewer than
            two tokens cannot be swapped and are returned as an unchanged copy.
        """
        swapped = list(tokens)
        if len(swapped) < 2:
            # random.sample cannot draw two distinct positions here.
            return swapped
        first, second = random.sample(range(len(swapped)), k=2)
        swapped[first], swapped[second] = swapped[second], swapped[first]
        return swapped


class TokensShuffle(Operator):
    """Exchange two non-overlapping runs of consecutive tokens."""

    def __init__(self, min_tokens=1, max_tokens=3):
        """Store the inclusive bounds for the length of each exchanged run.

        Args:
            min_tokens: Smallest run length that may be drawn.
            max_tokens: Largest run length that may be drawn.
        """
        assert min_tokens < max_tokens
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens

    def apply(self, tokens):
        """Return a new sequence with two random runs swapped.

        Args:
            tokens: A sliceable sequence longer than ``max_tokens``.

        Returns:
            The concatenation prefix + second run + middle + first run + suffix.

        Raises:
            AssertionError: If ``len(tokens)`` is not greater than ``max_tokens``.
        """
        assert len(tokens) > self.max_tokens
        size = len(tokens)
        # Rejection sampling: redraw until the first run ends at or before the
        # start of the second run, so the two runs never overlap.
        while True:
            i, j = sorted(random.sample(range(size), k=2))
            # random.randint includes BOTH endpoints, so the upper bound is
            # max_tokens itself (not max_tokens + 1).
            i_end = i + random.randint(self.min_tokens, self.max_tokens)
            j_end = j + random.randint(self.min_tokens, self.max_tokens)
            if i_end <= j:
                break
        # j_end may run past the end of the sequence; slicing clamps it, so
        # the second run is simply shorter in that case.
        return (
            tokens[:i]
            + tokens[j:j_end]
            + tokens[i_end:j]
            + tokens[i:i_end]
            + tokens[j_end:]
        )


class TokensReverse(Operator):
    """Reverse one randomly chosen run of consecutive tokens."""

    def __init__(self, min_tokens=2, max_tokens=3):
        """Store the inclusive bounds for the length of the reversed run.

        Args:
            min_tokens: Smallest run length that may be drawn.
            max_tokens: Largest run length that may be drawn.
        """
        assert min_tokens < max_tokens
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens

    def apply(self, tokens):
        """Return a new sequence with one random run reversed.

        Args:
            tokens: A sliceable sequence longer than ``max_tokens``.

        Returns:
            A new sequence; the input is not modified.

        Raises:
            AssertionError: If ``len(tokens)`` is not greater than ``max_tokens``.
        """
        assert len(tokens) > self.max_tokens
        size = len(tokens)
        # The start never falls on the last position, so at least two tokens
        # are available from the start to the end of the sequence.
        start = random.randrange(size - 1)
        length = random.randint(self.min_tokens, self.max_tokens)
        # Runs that would extend past the end are truncated.
        stop = min(start + length, size)
        # Build a new sequence instead of assigning into the slice, so the
        # caller's object is left untouched.
        return tokens[:start] + tokens[start:stop][::-1] + tokens[stop:]


class TokensInsert(Operator):
    """Cut out one run of consecutive tokens and re-insert it elsewhere."""

    def __init__(self, min_tokens=2, max_tokens=3):
        """Store the inclusive bounds for the length of the moved run.

        Args:
            min_tokens: Smallest run length that may be drawn.
            max_tokens: Largest run length that may be drawn.
        """
        assert min_tokens < max_tokens
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens

    def apply(self, tokens):
        """Return a new sequence with one random run moved to a random position.

        Args:
            tokens: A sliceable sequence longer than ``max_tokens``.

        Returns:
            A new sequence containing the same tokens as the input.

        Raises:
            AssertionError: If ``len(tokens)`` is not greater than ``max_tokens``.
        """
        assert len(tokens) > self.max_tokens
        size = len(tokens)
        start = random.randrange(size)
        length = random.randint(self.min_tokens, self.max_tokens)
        # Runs that would extend past the end are truncated.
        stop = min(start + length, size)
        segment = tokens[start:stop]
        remainder = tokens[:start] + tokens[stop:]
        # A sequence of m items has m + 1 insertion points (0..m inclusive);
        # including m allows the run to be moved to the very end.
        position = random.randrange(len(remainder) + 1)
        return remainder[:position] + segment + remainder[position:]


# Quick manual check: run TokensInsert with its default bounds on 0..9.
op = TokensInsert()
tokens = [*range(10)]
op(tokens)

[0, 1, 2, 3, 6, 7, 8, 4, 5, 9]

In [67]:
import random
from abc import ABCMeta, abstractmethod


class BaseSampler(metaclass=ABCMeta):
    """Abstract sampler that draws one operator at a time from a fixed list.

    Attributes are set in ``__init__``:
      * ``ops``     - the operators passed in by the caller.
      * ``indices`` - ``[0, 1, ..., len(ops) - 1]``.
      * ``index``   - position of the most recently drawn operator
                      (``-1`` until :meth:`sample` has been called).
      * ``weight``  - array of ones, one entry per operator; normalised into
                      selection probabilities by :meth:`sample`.
    """

    def __init__(self, ops):
        n_ops = len(ops)
        self.ops = ops
        self.weight = np.ones(n_ops)
        self.indices = [*range(n_ops)]
        self.index = -1

    def sample(self):
        """Draw an operator with probability proportional to its weight.

        The drawn position is remembered in ``self.index``.
        """
        probabilities = self.weight / self.weight.sum()
        self.index = np.random.choice(self.indices, p=probabilities)
        return self.ops[self.index]

    @abstractmethod
    def update(self, diff):
        """Receive feedback ``diff`` from the caller; subclasses define its use."""
        raise NotImplementedError
        

class UniformSampler(BaseSampler):
    """Sampler whose weights are never changed after construction."""

    def __init__(self, ops):
        # Nothing extra to set up; the base class creates the all-ones weights.
        super().__init__(ops=ops)

    def update(self, diff):
        """Ignore the feedback value ``diff``; the weights stay as they are."""
        return None


from collections import Counter

# Draw 10000 operators from a UniformSampler and tally how often each
# operator class is picked.
ops = [TokensInsert(), TokensReverse(), TokensShuffle()]
sampler = UniformSampler(ops)
counter = Counter()
for _ in range(10000):
    op = sampler.sample()
    counter.update([str(type(op))])
counter

Counter({"<class '__main__.TokensInsert'>": 3376,
         "<class '__main__.TokensShuffle'>": 3320,
         "<class '__main__.TokensReverse'>": 3304})

In [17]:
class Scorer:
    """Placeholder scorer that returns a random number instead of a real score."""

    def __init__(self, mi=0, ma=1):
        # Bounds handed to random.uniform by get_perplexity.
        self.mi, self.ma = mi, ma

    def get_perplexity(self, text, batch_size=32):
        """Return a random float drawn by ``random.uniform(self.mi, self.ma)``.

        ``text`` and ``batch_size`` are accepted but not used.
        """
        lower, upper = self.mi, self.ma
        return random.uniform(lower, upper)


# Build the placeholder scorer with its default bounds and draw one value.
scorer = Scorer()
scorer.get_perplexity(text="")

0.14426930215331735

In [68]:
import math
import random


# TODO: add logging of which operation improved the score
def simulated_annealing(text, sampler, scorer, temp_start=10, temp_end=0.5, cooling_rate=0.95, steps_per_temp=5, alpha=1.0, precomputed=None, verbose=False, logging_step=1, batch_size=1):
    """Search for a low-scoring ordering of the space-separated tokens in ``text``.

    At every step an operator drawn from ``sampler`` perturbs the current
    tokens and the result is scored with ``scorer.get_perplexity``. Lower
    scores are treated as better. Each candidate is compared against the best
    score seen so far (not against the previous candidate):

    * a lower score becomes the new best (prints ``>``);
    * otherwise the candidate is kept as the working state with probability
      ``exp(-alpha * delta / temp)`` (prints ``<``);
    * otherwise the working state is reset to the best tokens (prints ``-``).

    Args:
        text: Input string; it is stripped and split on single spaces.
        sampler: Object with ``sample()`` returning a callable that maps a
            token list to a token list, and ``update(value)``, which receives
            ``best_score - new_score`` after every step.
        scorer: Object with ``get_perplexity(text, batch_size=...)``.
        temp_start: Initial temperature.
        temp_end: The loop runs while the temperature is above this value.
        cooling_rate: Factor applied to the temperature after each round.
        steps_per_temp: Number of candidates tried per temperature.
        alpha: Scale applied to ``delta`` in the acceptance probability.
        precomputed: Optional dict mapping candidate text to its score. It is
            read and updated in place; a fresh dict is created when omitted.
        verbose: When true, print a summary line every ``logging_step`` rounds.
        logging_step: Interval, in temperature rounds, between summary lines.
        batch_size: Forwarded to ``scorer.get_perplexity``.

    Returns:
        Tuple ``(best_text, best_score, precomputed)``.

    Raises:
        ValueError: If the loop would start (``temp_start > temp_end``) with
            ``cooling_rate >= 1``, which would never terminate.
    """
    if precomputed is None:
        # A mutable default would be shared (and grow) across unrelated calls.
        precomputed = {}
    if temp_start > temp_end and cooling_rate >= 1:
        raise ValueError(
            "cooling_rate must be below 1 so the temperature can reach temp_end"
        )
    # initial setting
    text = text.strip()
    tokens = text.split(" ")
    best_tokens = tokens.copy()
    best_score = scorer.get_perplexity(text, batch_size=batch_size)
    # optimization
    temp = temp_start
    print(f"start temp: {temp:.2f}, init score: {best_score:.5f}")
    num_steps = 0
    while temp > temp_end:
        num_steps += 1
        for _ in range(steps_per_temp):
            op = sampler.sample()
            tokens = op(tokens)
            new_text = " ".join(tokens)
            if new_text in precomputed:
                new_score = precomputed[new_text]
            else:
                new_score = scorer.get_perplexity(new_text, batch_size=batch_size)
                precomputed[new_text] = new_score
            delta = new_score - best_score
            # The sampler is told the improvement (positive means better).
            sampler.update(-delta)
            if delta < 0:
                # improvement
                best_tokens = tokens.copy()
                best_score = new_score
                print(">", end="")
            elif random.random() < math.exp(-alpha*delta / temp):
                # explore: keep the non-improving candidate as the working state
                print("<", end="")
            else:
                # exploit: fall back to the best tokens found so far
                tokens = best_tokens.copy()
                print("-", end="")
        temp *= cooling_rate
        if verbose and num_steps % logging_step == 0:
            print(f"\ncurrent temp: {temp:.2f}, current score: {best_score:.5f}")
    return " ".join(best_tokens), best_score, precomputed


In [None]:
def sub_permutations(tokens, fixed_ids=None):
    """Enumerate every ordering of ``tokens`` that keeps some positions fixed.

    Args:
        tokens: Iterable of tokens.
        fixed_ids: Positions whose tokens must stay where they are. Duplicates
            are ignored and negative positions count from the end. ``None``
            (the default) or an empty collection means no position is fixed.

    Returns:
        A list of lists, one per permutation of the non-fixed tokens, in the
        order produced by ``itertools.permutations``. Each inner list has the
        fixed tokens at their original positions and holds the original token
        objects.

    Raises:
        IndexError: If a position in ``fixed_ids`` is outside the sequence.
        AssertionError: If more than 10 tokens are left free to permute.
    """
    # Imported locally because the visible notebook cells never import it.
    import itertools

    tokens = list(tokens)
    size = len(tokens)

    pinned = set()
    for idx in (() if fixed_ids is None else fixed_ids):
        idx = int(idx)
        if not -size <= idx < size:
            raise IndexError(f"fixed index {idx} is out of range for {size} tokens")
        # Normalise negatives and drop duplicates so each fixed token is
        # re-inserted exactly once.
        pinned.add(idx % size)

    mutable_tokens = [tok for pos, tok in enumerate(tokens) if pos not in pinned]
    # The number of permutations grows factorially; cap the free tokens at 10.
    assert len(mutable_tokens) < 11

    perms = [list(perm) for perm in itertools.permutations(mutable_tokens)]
    # Re-inserting in ascending position order puts every fixed token back at
    # its original index.
    ordered_pinned = sorted(pinned)
    for perm in perms:
        for idx in ordered_pinned:
            perm.insert(idx, tokens[idx])
    return perms
