In [1]:
!pip install accelerate peft bitsandbytes transformers

Collecting peft
  Downloading peft-0.14.0-py3-none-any.whl.metadata (13 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl.metadata (2.9 kB)
Collecting huggingface-hub>=0.21.0 (from accelerate)
  Downloading huggingface_hub-0.27.0-py3-none-any.whl.metadata (13 kB)
Downloading peft-0.14.0-py3-none-any.whl (374 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m374.8/374.8 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hDownloading bitsandbytes-0.45.0-py3-none-manylinux_2_24_x86_64.whl (69.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 MB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading huggingface_hub-0.27.0-py3-none-any.whl (450 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m450.5/450.5 kB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: huggingface-hub, bitsandbytes, peft
  Attempting uninstall: hug

In [2]:
import gc
import heapq
import math
import os
import random
import statistics
from collections import Counter
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import bitsandbytes
from math import exp
from pprint import pprint

# 1. Perplexity (Log Scale)
    - Originally, perplexity is calculated by taking the exponential of the final loss. 
    - However, because the exponential function is monotonically increasing, I thought there was no real need to do so.

In [3]:
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class HuggingFaceModelLoader:
    def __init__(self, model_path: str, load_in_8bit: bool, device_map: str):
        self.model_path = model_path
        self.load_in_8bit = load_in_8bit
        self.device_map = device_map

    def load_model(self) -> transformers.PreTrainedModel:
        if self.load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires a CUDA device')

            quantization_config = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="fp4",
                bnb_4bit_use_double_quant=False,
                bnb_4bit_compute_dtype=torch.float16,
            )

            model = transformers.AutoModelForCausalLM.from_pretrained(
                self.model_path,
                quantization_config=quantization_config,
                device_map=self.device_map
            )
        else:
            model = transformers.AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=self.device_map
            )

        model.eval()
        return model


class HuggingFaceTokenizer:
    def __init__(self, model_path: str):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path, padding_side="right")
        self.bos_token = self.tokenizer.bos_token or self.tokenizer.cls_token
        self.eos_token = self.tokenizer.eos_token or self.tokenizer.sep_token
        if self.bos_token is None:
            self.bos_token = ""
        if self.eos_token is None:
            self.eos_token = ""

    def tokenize(self, texts: List[str]) -> dict:
        processed_texts = []

        for text in texts:
            combined_text = f"{self.bos_token}{text}{self.eos_token}"
            processed_texts.append(combined_text)

        model_inputs = self.tokenizer(
            processed_texts,
            return_tensors='pt',
            add_special_tokens=False,
            padding=True
        )

        if 'token_type_ids' in model_inputs:
            model_inputs.pop('token_type_ids')

        return model_inputs


class PerplexityCalculator:
    def __init__(self, model_loader, tokenizer, exp_mode=False):
        self.model = model_loader.load_model()
        self.tokenizer = tokenizer
        self.exp_mode = exp_mode
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    def get_perplexity(
        self, 
        input_texts: Union[str, List[str]], 
        batch_size: int = 32
    ) -> Union[float, List[float]]:
        
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        num_texts = len(input_texts)
        batches = num_texts // batch_size + (num_texts % batch_size != 0)
        with torch.no_grad():
            for j in range(batches):
                start_idx = j * batch_size
                end_idx = (j + 1) * batch_size
                input_batch = input_texts[start_idx:end_idx]
    
                with torch.no_grad():
                    model_inputs = self.tokenizer.tokenize(input_batch)
                    model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
    
                    output = self.model(**model_inputs, use_cache=False)
                    logits = output['logits']
    
                    label = model_inputs['input_ids']
                    if hasattr(self.model.config, 'pad_token_id') and self.model.config.pad_token_id is not None:
                        label[label == self.model.config.pad_token_id] = PAD_TOKEN_LABEL_ID
    
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = label[..., 1:].contiguous()
    
                    token_loss = self.loss_fct(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1)
                    ).view(len(logits), -1)
    
                    valid_length = (shift_labels != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                    sequence_loss = torch.sum(token_loss, -1) / valid_length
                    loss_list.extend(sequence_loss.cpu().tolist())
        
        if self.exp_mode:
            ppl = [exp(i) for i in loss_list]
        else:
            ppl = loss_list
            
        return ppl[0] if single_input else ppl

In [4]:
model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
model_loader = HuggingFaceModelLoader(model_path=model_path, load_in_8bit=False, device_map='auto')
tokenizer = HuggingFaceTokenizer(model_path)
scorer = PerplexityCalculator(model_loader, tokenizer, exp_mode=False)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

# 2. Cost Function
    - The `CostFunction` class wraps a scorer’s `get_perplexity` method to evaluate a list of solutions given a `batch_size`. 
    - It provides both `evaluate` and `__call__` methods so you can use it like a function that returns perplexities for each solution.

In [5]:
class CostFunction:
    def __init__(self, scorer, batch_size):
        self.scorer = scorer
        self.batch_size = batch_size

    def evaluate(self, solutions: List[str]) -> List[float]:
        return self.scorer.get_perplexity(solutions, self.batch_size)

    def __call__(self, solutions: List[str]) -> List[float]:
        return self.evaluate(solutions)

# 3. Neighbor Modifier
    - Applies random TSP list operations to the middle portion of a route, preserving prefix and suffix elements.
    - Use `modify` or call the instance directly to transform the input string into a new route.

#### `two_rotation`
$$
\text{two\_rotation}(\mathbf{s}) \;\Longrightarrow\;
[\dots, s_j, \dots, s_i, \dots].
$$

---

#### `near_two_rotation`
$$
\text{near\_two\_rotation}(\mathbf{s}) \;\Longrightarrow\;
[\dots, s_j, \dots, s_i, \dots].
$$

---

#### `random_rotation`

$$
\text{random\_rotation}(\mathbf{s}) \;\Longrightarrow\;
[\dots, \underbrace{s_{i+\delta},\dots,s_j, s_i,\dots,s_{i+\delta-1}}_{\text{rotated sublist}},\dots].
$$

---

#### `random_insertion`  
$$
\text{random\_insertion}(\mathbf{s}) \;\Longrightarrow\;
[\dots, s_k \text{ removed}, \dots] \;\longrightarrow\;
[\dots, s_k \text{ inserted at index } m, \dots].
$$

---

#### `random_reverse`

$$
\text{random\_reverse}(\mathbf{s}) \;\Longrightarrow\;
[\dots, s_j, \dots, s_i, \dots].
$$

---

#### `shuffle_sublist`
 
$$
\text{shuffle\_sublist}(\mathbf{s}) \;\Longrightarrow\;
[\dots, s_{\pi(1)}, s_{\pi(2)}, \dots, s_{\pi(\ell)}, \dots]
$$

---

#### `double_bridge_move`

$$
\mathbf{s}' = A \;||\; D \;||\; C \;||\; B \;||\; E.
$$

In [6]:
class TSPListOperations:
    @staticmethod
    def two_rotation(item: List[str]) -> List[str]:
        # Swap two randomly chosen elements in the list.
        if len(item) < 2:
            return item
        i, j = random.sample(range(len(item)), 2)
        item[i], item[j] = item[j], item[i]
        return item

    @staticmethod
    def near_two_rotation(item: List[str]) -> List[str]:
        # Select a random segment within a certain size range and swap its endpoints.
        if len(item) < 3:
            return item
        size = random.choice(range(2, min(20, len(item))))
        i = random.choice(range(len(item) - size))
        j = i + size
        item[i], item[j] = item[j], item[i]
        return item

    @staticmethod
    def random_rotation(item: List[str]) -> List[str]:
        # Randomly rotate a sublist within the list.
        n = len(item)
        if n < 2:
            return item
        start_idx, end_idx = sorted(random.sample(range(n), 2))
        sublist = item[start_idx:end_idx+1]
        if len(sublist) < 2:
            return item

        shift = random.randint(1, len(sublist) - 1)
        shift %= len(sublist)
        sublist = sublist[shift:] + sublist[:shift]

        item[start_idx:end_idx+1] = sublist
        return item

    @staticmethod
    def random_insertion(item: List[str]) -> List[str]:
        # Remove a random element from the list and re-insert it at a new position.
        if len(item) < 2:
            return item
        i = random.randrange(len(item))
        elem = item.pop(i)
        j = random.randrange(len(item) + 1)
        item.insert(j, elem)
        return item

    @staticmethod
    def random_reverse(item: List[str]) -> List[str]:
        # Reverse a randomly chosen sublist in the list.
        if len(item) < 2:
            return item
        start, end = sorted(random.sample(range(len(item)), 2))
        item[start:end+1] = reversed(item[start:end+1])
        return item

    @staticmethod
    def shuffle_sublist(item: List[str]) -> List[str]:
        # Shuffle a randomly selected sublist within the list.
        if len(item) < 2:
            return item
        size = random.randint(2, min(6, len(item)))
        start = random.randint(0, len(item) - size)
        end = start + size - 1
        sub = item[start:end+1]
        random.shuffle(sub)
        item[start:end+1] = sub
        return item

    @staticmethod
    def double_bridge_move(item: List[str]) -> List[str]:
        # Perform a double-bridge move by selecting four cut points and rearranging the segments.
        n = len(item)
        if n < 5:
            return item
        cuts = sorted(random.sample(range(1, n), 4))
        a, b, c, d = cuts

        A = item[:a]
        B = item[a:b]
        C = item[b:c]
        D = item[c:d]
        E = item[d:]

        if random.choice([True, False]):
            B = B[::-1]
        if random.choice([True, False]):
            D = D[::-1]

        return A + D + C + B + E


In [7]:
class NeighborModifier:
    """
    Applies random TSP list operations to the middle portion of a route, preserving prefix and suffix elements.
    Use `modify` or call the instance directly to transform the input string into a new route.
    """

    def __init__(self, prefix: int = 0, suffix: int = 0):
        self.prefix = prefix
        self.suffix = suffix
        self.operations = [
            TSPListOperations.two_rotation,
            TSPListOperations.near_two_rotation,
            TSPListOperations.random_rotation,
            TSPListOperations.random_insertion,
            TSPListOperations.random_reverse,
            TSPListOperations.shuffle_sublist,
            TSPListOperations.double_bridge_move
        ]

    def modify(self, arr: str) -> str:
        item = arr.split()
        if len(item) <= self.prefix + self.suffix:
            return arr

        op = random.choice(self.operations)
        middle = item[self.prefix:len(item)-self.suffix]
        modified_middle = op(middle)
        new_item = item[:self.prefix] + modified_middle + item[len(item)-self.suffix:]
        return ' '.join(new_item)

    def __call__(self, arr: str) -> str:
        return self.modify(arr)

# 3. Parameters

In [8]:
class AlgorithmParameters:
    def __init__(self,
                 # Beam Search
                 beam_width: int = 3,
                 batch_size: int = 2, 
                 branch_interval: int = 5,
                 no_improve_limit: int = 10,
                 verbose: bool = True,
                 # SA + Temperature
                 cooling_mode: str = 'linear',
                 max_temperature: float = 0.1,
                 min_temperature: float = 0.,
                 num_trials: int = 2,
                 iterations: int = 3,
                 max_iterations: int = 20,
                 lambda_m: float = 5.0,
                 rho_target: float = 0.5,
                 M_min: int = 2,
                 M_max: int = 5):
        self.beam_width = beam_width
        self.batch_size = batch_size
        self.branch_interval = branch_interval
        self.no_improve_limit = no_improve_limit
        self.verbose = verbose

        self.cooling_mode = cooling_mode
        self.max_iterations = max_iterations
        self.max_temperature = max_temperature
        self.min_temperature = min_temperature
        self.num_trials = num_trials
        self.iterations = iterations
        self.lambda_m = lambda_m
        self.rho_target = rho_target
        self.M_min = M_min
        self.M_max = M_max

# 4. Beam Manager
    Manages multiple BeamBatches for beam search, maintaining solutions and their costs.
    Provides functionalities to initialize the beam from a set of candidate solutions,
    update the beam with new candidates, track improvement, and revert to saved branch points
    when needed.


In [9]:
class BeamBatch:
    def __init__(self, solutions: List[str], costs: List[float]):
        self.beam = list(zip(solutions, costs))  # [(sol, cost), ...]
        self.branch_archives: List[List[Tuple[str, float]]] = [
            [] for _ in range(len(self.beam))
        ]
        self.no_improve_counts: List[int] = [0]*len(self.beam)
        self.prev_best_costs: List[float] = costs[:]

    def size(self) -> int:
        """Returns the number of solutions in the beam."""
        return len(self.beam)

In [10]:
class BeamManager:
    """
    Attributes:
        params (AlgorithmParameters): A parameter object containing beam search configurations 
            such as beam_width, batch_size, branch_interval, no_improve_limit, and verbosity.
        beam_width (int): The maximum number of BeamBatches (sub-beams).
        batch_size (int): The number of solutions contained in each BeamBatch.
        cost_fn (Callable[[List[str]], List[float]]): A function that takes a list of solutions 
            and returns a corresponding list of costs.
        batches (List[BeamBatch]): A list of BeamBatch instances representing the current beam.
    """

    def __init__(self, params: AlgorithmParameters, cost_fn: Callable[[List[str]], List[float]]):
        """
        Initializes the BeamManager with given parameters and a cost function.

        Args:
            params (AlgorithmParameters): Contains essential beam search parameters.
            cost_fn (Callable[[List[str]], List[float]]): A function to evaluate solutions' costs.
        """
        self.params = params
        self.beam_width = params.beam_width
        self.batch_size = params.batch_size
        self.cost_fn = cost_fn

        self.batches: List[BeamBatch] = []  # Holds multiple BeamBatch instances.

    def initialize_beam(self, initial_solutions: List[str]):
        """
        Selects up to beam_width * batch_size solutions (potentially more),
        ranks them by cost, and creates multiple BeamBatch objects by grouping
        solutions in chunks of batch_size.

        Args:
            initial_solutions (List[str]): A list of candidate solutions from which
                to initialize the beam.
        """
        top_needed = self.beam_width * self.batch_size * 2
        candidate_solutions = random.sample(
            initial_solutions,
            min(top_needed, len(initial_solutions))
        )
        candidate_costs = self.cost_fn(candidate_solutions)

        zipped = list(zip(candidate_solutions, candidate_costs))
        zipped_sorted = sorted(zipped, key=lambda x: x[1])
        needed = self.beam_width * self.batch_size
        selected = zipped_sorted[:needed]

        chunks = [
            selected[i:i+self.batch_size]
            for i in range(0, len(selected), self.batch_size)
        ]

        for chunk in chunks:
            sols, costs = zip(*chunk) if chunk else ([], [])
            new_batch = BeamBatch(list(sols), list(costs))
            self.batches.append(new_batch)

    def record_branch_points(self, iteration: int):
        """
        Saves the current solutions and costs into each batch's archive 
        every 'branch_interval' iterations.

        Args:
            iteration (int): The current iteration count of the beam search.
        """
        if iteration % self.params.branch_interval == 0:
            for batch_idx, batch in enumerate(self.batches):
                for i, (sol, cost) in enumerate(batch.beam):
                    batch.branch_archives[i].append((sol, cost))

    def check_improvement(self, batch_idx: int, i: int, new_cost: float) -> bool:
        """
        Checks whether a new cost is an improvement for the i-th member of a given batch.

        Args:
            batch_idx (int): The index of the batch in self.batches.
            i (int): The index of the member within the batch.
            new_cost (float): The newly computed cost for that member.

        Returns:
            bool: True if the cost is improved, False otherwise.
        """
        batch = self.batches[batch_idx]
        improved = False
        if new_cost < batch.prev_best_costs[i]:
            batch.no_improve_counts[i] = 0
            improved = True
        else:
            batch.no_improve_counts[i] += 1
        batch.prev_best_costs[i] = new_cost
        return improved

    def return_to_branch_point(self, batch_idx: int, i: int):
        """
        If the cost has not improved for more than 'no_improve_limit' times for 
        a specific member, reverts that member to the best known branch archive.

        Args:
            batch_idx (int): The index of the batch in self.batches.
            i (int): The index of the member within the batch.
        """
        batch = self.batches[batch_idx]
        if (batch.no_improve_counts[i] > self.params.no_improve_limit
            and batch.branch_archives[i]):
            chosen_point = min(batch.branch_archives[i], key=lambda x: x[1])
            sol, cost = chosen_point
            batch.beam[i] = (sol, cost)
            batch.no_improve_counts[i] = 0
            if self.params.verbose:
                print(f"[Info] Batch {batch_idx}, member {i}: "
                      f"No improvement -> Return to best past branch")

    def update_beam(self, new_candidates: List[List[Tuple[str, float]]]):
        """
        Merges new candidate solutions for each batch, filters duplicates 
        (keeping the lowest cost if duplicates occur), and updates each batch 
        to maintain only the top 'batch_size' solutions.

        Args:
            new_candidates (List[List[Tuple[str, float]]]): A list of lists,
                where each sub-list corresponds to a batch and contains tuples
                of (solution, cost).
        """
        for b_idx, batch in enumerate(self.batches):
            old_info = {}
            for (sol, _), nic, pbc, ba in zip(
                batch.beam,
                batch.no_improve_counts,
                batch.prev_best_costs,
                batch.branch_archives
            ):
                old_info[sol] = (nic, pbc, ba)

            combined = list(batch.beam) + new_candidates[b_idx]
            unique_map = {}
            for sol, cost in combined:
                if sol not in unique_map or unique_map[sol] > cost:
                    unique_map[sol] = cost

            sorted_candidates = sorted(unique_map.items(), key=lambda x: x[1])
            updated_beam = sorted_candidates[:batch.size()]

            new_solutions, new_costs = zip(*updated_beam) if updated_beam else ([], [])
            new_info = [old_info.get(sol, (0, cost, [])) for sol, cost in updated_beam]

            if new_info:
                new_no_improve_counts, new_prev_best_costs, new_branch_archives = map(list, zip(*new_info))
            else:
                new_no_improve_counts, new_prev_best_costs, new_branch_archives = [], [], []

            batch.beam = list(zip(new_solutions, new_costs))
            batch.no_improve_counts = new_no_improve_counts
            batch.prev_best_costs = new_prev_best_costs
            batch.branch_archives = new_branch_archives

    def get_global_best(self) -> Tuple[str, float, int, int]:
        """
        Retrieves the globally best solution (lowest cost) across all batches.

        Returns:
            Tuple[str, float, int, int]: A tuple containing the best solution string, 
            its cost, the batch index, and the member index.
        """
        best_sol = None
        best_cost = float('inf')
        best_batch = -1
        best_member = -1
        for b_idx, batch in enumerate(self.batches):
            for i, (sol, cost) in enumerate(batch.beam):
                if cost < best_cost:
                    best_sol = sol
                    best_cost = cost
                    best_batch = b_idx
                    best_member = i
        return best_sol, best_cost, best_batch, best_member

    def get_all_solutions(self) -> List[Tuple[str, float]]:
        """
        Gathers all (solution, cost) pairs from every batch in the current beam.

        Returns:
            List[Tuple[str, float]]: A list of (solution, cost) tuples representing
            the current members of all BeamBatches.
        """
        result = []
        for b in self.batches:
            result.extend(b.beam)
        return result


# 5. Temperature Manager

**Interpolation Modes**  
   - **Linear**:  
     $$
       T(t) = T_\text{max} - \bigl(T_\text{max} - T_\text{min}\bigr) \times \frac{t}{T_\text{total}}
     $$
   - **Exponential**:  
     $$
       T(t) = T_\text{max} \times \left(\frac{T_\text{min}}{T_\text{max}}\right)^{\frac{t}{T_\text{total}}}
     $$
   - **Logarithmic**:  
     $$
       T(t) = \frac{T_\text{max}}{1 + \ln(1 + t)}
       \quad\text{or}\quad
       T(t) = T_\text{min} + \frac{T_\text{max} - T_\text{min}}{1 + \ln(1 + t)}
     $$


In [11]:
class TemperatureManager:
    def __init__(self, params: 'AlgorithmParameters'):
        self.params = params
        self.temperatures: List[float] = []

    def init_temperatures(self, num_batches: int):
        self.temperatures = [self.params.max_temperature] * num_batches

    def get_temperature(self, batch_idx: int) -> float:
        return self.temperatures[batch_idx]

    def cool(self, batch_idx: int, iteration: int):
        mode = self.params.cooling_mode
        maxT = self.params.max_temperature
        minT = self.params.min_temperature
        total_iter = float(self.params.max_iterations)
        if iteration > self.params.max_iterations:
            iteration = self.params.max_iterations

        if mode == "linear":
            new_temp = maxT - (maxT - minT) * (iteration / total_iter)
        elif mode == "exponential":
            ratio = (minT / maxT) if maxT != 0 else 0.0
            exponent = iteration / total_iter
            new_temp = maxT * (ratio ** exponent)
        elif mode == "logarithmic":
            if iteration == 0:
                iteration = 1
            new_temp = maxT / (1.0 + math.log(1.0 + iteration))
        else:
            new_temp = maxT - (maxT - minT) * (iteration / total_iter)

        new_temp = max(minT, min(new_temp, maxT))
        self.temperatures[batch_idx] = new_temp


# 6. Simulated Annealing Controller
    This class orchestrates Simulated Annealing by managing solutions in a beam (BeamManager) and controlling each batch’s temperature (TemperatureManager). It iteratively generates neighbors, accepts them based on the SA acceptance rule, and updates or reverts solutions to improve overall optimization.  


In [13]:
import math
import random
from typing import List, Callable, Tuple

class SAController:
    def __init__(
        self,
        R: List[str],
        cost_function: Callable[[List[str]], List[float]],
        neighbor_function: Callable[[str], str],
        params: 'AlgorithmParameters'
    ):
        self.params = params
        # (1) BeamManager を使用してビームを管理する
        self.beam_manager = BeamManager(params, cost_fn=cost_function)
        self.beam_manager.initialize_beam(R)

        self.cost_function = cost_function
        self.neighbor_function = neighbor_function

        # (2) TemperatureManager を使用して温度を管理する
        self.temp_manager = TemperatureManager(params)

        # SAに関連する変数
        self.recent_improvements = [0]*5
        self.M = 10
        self.global_best_cost = float('inf')
        self.global_best_sol = None

    def adjust_M(self):
        """
        直近の改善率に基づいて M（近傍探索の回数）を調整する。
        """
        improvement_rate = sum(self.recent_improvements) / len(self.recent_improvements)
        M_new = self.M + self.params.lambda_m * (self.params.rho_target - improvement_rate)
        self.M = max(self.params.M_min, min(self.params.M_max, int(M_new)))

    def run(self):
        num_batches = len(self.beam_manager.batches)

        for trial in range(self.params.num_trials):
            # すべてのバッチに対して温度を初期化
            self.temp_manager.init_temperatures(num_batches)
            self.M = self.params.M_min

            for it in range(self.params.iterations):
                iteration_improved = False
                # 分岐ポイントを記録する
                self.beam_manager.record_branch_points(it)
                # Mを調整する
                self.adjust_M()

                # 新しい候補解をバッチごとに集める: [ [ (sol,cost), ... ], [ (sol,cost), ... ], ... ]
                new_candidates_per_batch: List[List[Tuple[str, float]]] = []

                # 各バッチに対して独立にSAを実行
                for b_idx, batch in enumerate(self.beam_manager.batches):
                    beam_solutions = [sol for (sol, c) in batch.beam]
                    beam_costs = [c for (sol, c) in batch.beam]

                    # M回、近傍探索を行う
                    for _ in range(self.M):
                        # 現在の温度を取得
                        T = self.temp_manager.get_temperature(b_idx)

                        # 近傍解を生成
                        neighbors = [self.neighbor_function(s) for s in beam_solutions]
                        neighbors_cost = self.cost_function(neighbors)

                        # 受理可否を決定
                        for i in range(len(beam_solutions)):
                            diff = neighbors_cost[i] - beam_costs[i]
                            if diff < 0:
                                beam_solutions[i] = neighbors[i]
                                beam_costs[i] = neighbors_cost[i]
                            else:
                                p = math.exp(-diff/(T+1e-12))
                                if random.random() < p:
                                    beam_solutions[i] = neighbors[i]
                                    beam_costs[i] = neighbors_cost[i]

                        # 各近傍探索後に温度を冷却
                        self.temp_manager.cool(b_idx, it)

                    # M回の探索後、ビーム内の各解の改善をチェック
                    batch_candidates = []
                    for i, (old_sol, old_cost) in enumerate(batch.beam):
                        new_sol = beam_solutions[i]
                        new_cost = beam_costs[i]

                        improved = self.beam_manager.check_improvement(b_idx, i, new_cost)
                        if not improved:
                            #改善していない場合、分岐ポイントに戻す
                            self.beam_manager.return_to_branch_point(b_idx, i)

                        batch_candidates.append((new_sol, new_cost))

                    new_candidates_per_batch.append(batch_candidates)

                # すべてのバッチを更新
                self.beam_manager.update_beam(new_candidates_per_batch)

                # グローバル最適解を更新
                g_sol, g_cost, g_bidx, g_iidx = self.beam_manager.get_global_best()
                if g_cost < self.global_best_cost:
                    self.global_best_cost = g_cost
                    self.global_best_sol = g_sol
                    iteration_improved = True

                self.recent_improvements.pop(0)
                self.recent_improvements.append(1 if iteration_improved else 0)
                ir = sum(self.recent_improvements)/len(self.recent_improvements)

                if self.params.verbose:
                    print(f"[Trial {trial}, Iter {it}], "
                          f"Global Best cost={exp(self.global_best_cost):.4f}, "
                          f"T={T:.5f}, "
                          f"M={self.M}, improvement_rate={ir:.2f}, \n"
                          f"{self.global_best_sol}\n")

        # 最終的なビームの状態
        final_solutions = self.beam_manager.get_all_solutions()
        return final_solutions, (self.global_best_sol, self.global_best_cost)


# Main

In [14]:

import random

# テキストの定義
text = 'sleigh of magi unwrap is gifts the cheer cheer yuletide naughty jingle and eat decorations holiday holly grinch nutcracker nice relax workshop sing polar visit ornament carol stocking chimney beard'
initial_solutions = []

# テキストを単語リストに変換
words = text.split()
# ランダムに単語の並びを変えたものを10個生成して表示
for i in range(40):
    # copyした単語リストをランダムにシャッフルしてから表示
    shuffled_words = random.sample(words, len(words))
    initial_solutions.append(' '.join(shuffled_words))


In [15]:
prefix_size = 0
suffix_size = 0

cost_func = CostFunction(scorer, batch_size=30)
neighbor_func = NeighborModifier(prefix=prefix_size, suffix=suffix_size)
   
params = AlgorithmParameters(
    beam_width=30,   # 한 배치의 빔 크기
    batch_size=4,   # 배치 수: beam_width * batch_size = 전체 솔루션 수
    verbose=True,
    cooling_mode='linear',
    max_temperature=0.02,
    min_temperature=0.00001,
    num_trials=30, # before 30
    iterations=10,
    max_iterations=100,
    M_min=2,
    M_max=15
)


print(len(initial_solutions))


controller = SAController(
    R=initial_solutions,
    cost_function=cost_func,
    neighbor_function=neighbor_func,
    params=params
)

40


In [16]:
final_solutions, (best_sol, best_cost) = controller.run()
print("\n[All Solutions in Batches]:")
for sol, c in final_solutions:
    print(f"{sol} | cost={c:.4f}")

print(f"\n[Global Best] sol='{best_sol}', cost={best_cost:.4f}")


[Trial 0, Iter 0], Global Best cost=1509.7621, T=0.02000, M=4, improvement_rate=0.20, 
grinch unwrap sing relax holiday of is beard polar yuletide eat nutcracker jingle holly and cheer carol the visit decorations stocking ornament gifts workshop chimney naughty nice magi sleigh cheer

[Trial 0, Iter 1], Global Best cost=1277.0232, T=0.01980, M=5, improvement_rate=0.40, 
grinch unwrap sing relax holiday of is the carol cheer and holly jingle nutcracker eat yuletide polar beard visit decorations stocking ornament gifts workshop chimney naughty nice magi sleigh cheer

[Trial 0, Iter 2], Global Best cost=1181.4612, T=0.01960, M=5, improvement_rate=0.60, 
jingle chimney beard carol the is grinch magi holly sleigh eat visit and sing of holiday cheer unwrap gifts yuletide ornament stocking nutcracker polar decorations relax cheer naughty nice workshop

[Trial 0, Iter 3], Global Best cost=1167.6884, T=0.01940, M=4, improvement_rate=0.80, 
jingle chimney beard carol the is grinch polar holly sl

In [17]:
solutions = controller.beam_manager.get_all_solutions()
solutions = [(x, exp(y)) for (x, y ) in solutions]
solutions = sorted(solutions, key=lambda x: x[1])

pprint(solutions)

[('sleigh yuletide holly jingle relax and unwrap gifts of the magi holiday '
  'cheer carol sing visit workshop eat cheer grinch nutcracker decorations '
  'ornament chimney stocking naughty nice polar is beard',
  241.5128645268717),
 ('sleigh yuletide holly jingle relax and unwrap gifts of the magi holiday '
  'cheer carol sing visit workshop eat cheer grinch chimney stocking '
  'nutcracker decorations ornament naughty nice polar is beard',
  241.9475322727233),
 ('sleigh yuletide holly jingle relax and unwrap gifts of the magi holiday '
  'cheer carol sing visit workshop eat cheer grinch chimney stocking '
  'nutcracker ornament decorations naughty nice polar is beard',
  243.36413293931074),
 ('sleigh yuletide holly jingle relax and unwrap gifts of the magi holiday '
  'cheer carol sing visit workshop eat cheer grinch chimney stocking '
  'decorations ornament nutcracker naughty nice polar is beard',
  245.08475272144213),
 ('jingle is the holiday of holly yuletide cheer decoratio

In [18]:
import torch

# デフォルトデバイスをCPUに設定
device = torch.device('cpu')
