In [1]:
import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

from collections import deque
from typing import Tuple, List
from collections import deque
import random

## Gemma perplexity metric

In [2]:
# Runtime knobs for native thread pools and HF tokenizers parallelism.
# NOTE(review): torch is already imported in the cell above; confirm that
# OMP_NUM_THREADS still takes effect (torch.set_num_threads is the explicit alternative).
os.environ.update({'OMP_NUM_THREADS': '1', 'TOKENIZERS_PARALLELISM': 'false'})

# The label id that CrossEntropyLoss skips by default (its `ignore_index`).
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index

# Inputs are moved here before each forward pass; also selects fp16 vs fp32 loading.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ParticipantVisibleError(Exception):
    """Raised by `score` when a submission is invalid (e.g. a text is not a permutation of its solution)."""

def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """Competition metric: mean perplexity of the submitted texts.

    Every submitted `text` must contain exactly the same multiset of
    whitespace-separated words as the corresponding solution `text`.
    Texts are whitespace-normalized and scored with a freshly loaded
    `PerplexityCalculator`.

    Args:
        solution: Frame with a `text` column holding the reference word lists.
        submission: Frame with a `text` column holding the reordered texts.
            Rows are compared by index. Presumably both frames share the same
            index; pandas raises ValueError for differently-labeled Series.
        row_id_column_name: Part of the metric signature; not used here.
        model_path: Model passed to `PerplexityCalculator`.
        load_in_8bit: Load the model with 8-bit quantization (CUDA only).
        clear_mem: Try to release GPU memory after scoring.

    Returns:
        The mean perplexity over all submitted texts (lower is better).

    Raises:
        ParticipantVisibleError: If any submitted text is not a permutation of
            its solution text.
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Split and rejoin to normalize whitespace before scoring
    sub_strings = [' '.join(s.split()) for s in submission['text'].tolist()]

    # NOTE: this loads the model on every call; callers that score repeatedly
    # should reuse a PerplexityCalculator instead.
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )
    perplexities = scorer.get_perplexity(sub_strings)

    if clear_mem:
        # Best-effort cleanup: the score is already computed, so a failure here
        # must not abort scoring. `except Exception` (not a bare `except:`) so
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            scorer.clear_gpu_memory()
        except Exception:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))

class PerplexityCalculator:
    """Per-text perplexity under a Hugging Face causal language model.

    Loads a tokenizer and an `AutoModelForCausalLM` from `model_path` and scores
    each text independently (one forward pass per text, no batching) as
    exp(mean next-token cross-entropy) over `<bos> text <eos>`. This is the
    scorer that `score` uses. Keep the arithmetic below as written so results
    match the metric.
    """
    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        """Load the tokenizer and the model, and put the model in eval mode.

        Args:
            model_path: Path or hub id passed to `from_pretrained`.
            load_in_8bit: Load weights with bitsandbytes 8-bit quantization.
            device_map: Forwarded to `from_pretrained` (default 'auto').

        Raises:
            ValueError: If `load_in_8bit` is True and `DEVICE` is not CUDA.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            # Half precision on GPU, full precision on CPU.
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=device_map,
            )

        # reduction='none' keeps one loss value per target token; averaged manually below.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug: bool = False
    ) -> Union[float, List[float]]:
        """Return the perplexity of one text, or of each text in a list.

        The tokenizer's BOS/EOS token strings are inserted around the text, and
        `add_special_tokens=False` stops the tokenizer from adding its own. BOS
        serves only as context: it is never a prediction target.

        Args:
            input_texts: A single string or a list of strings.
            debug: If True, print tokens, per-token losses and perplexities.

        Returns:
            A float for a single string input, otherwise a list of floats in
            input order.
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        with torch.no_grad():
            # Process each sequence independently
            for text in input_texts:
                # Explicitly add sequence boundary tokens to the text
                text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"

                # Tokenize
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                )

                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')

                # NOTE(review): with device_map='auto' the model may be split across
                # devices; inputs go to DEVICE. Confirm this is the model's input device.
                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                # Get model output
                output = self.model(**model_inputs, use_cache=False)
                logits = output['logits']

                # Position t predicts token t+1: align logits[:-1] with input_ids[1:]
                shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                shift_labels = model_inputs['input_ids'][..., 1:].contiguous()  # Drop first input

                # Calculate token-wise loss
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # Mean per-token loss (written as sum/len; kept as-is to match the metric)
                sequence_loss = loss.sum() / len(loss)
                loss_list.append(sequence_loss.cpu().item())

                # Debug output
                if debug:
                    print(f"\nProcessing: '{text}'")
                    print(f"With special tokens: '{text_with_special}'")
                    print(f"Input tokens: {model_inputs['input_ids'][0].tolist()}")
                    print(f"Target tokens: {shift_labels[0].tolist()}")
                    print(f"Input decoded: {self.tokenizer.decode(model_inputs['input_ids'][0])}")
                    print(f"Target decoded: {self.tokenizer.decode(shift_labels[0])}")
                    print(f"Individual losses: {loss.tolist()}")
                    print(f"Average loss: {sequence_loss.item():.4f}")

        # Perplexity = exp(mean cross-entropy)
        ppl = [exp(i) for i in loss_list]

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(input_texts, ppl):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Release GPU memory held by this instance; no-op when CUDA is unavailable.

        Deletes `self.model` and `self.tokenizer`, so `get_perplexity` cannot be
        used on this instance afterwards.
        """
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

---
# CACHING

In [3]:
class Caching:
    """Sharded key/value cache without eviction.

    Entries live in a list of dicts ("stores"). A new store is opened whenever
    the newest one already holds `max_size` entries, and lookups scan every
    store. Hit/miss counters are kept for `state()`.

    A stored value of None is indistinguishable from a miss.
    """

    def __init__(self, max_size: int = 1000):
        self.stores: List[dict] = []
        self.idx = -1  # index of the newest store; -1 while no store exists
        self.max_size = max_size
        # 'len' counts stores created; 'total' counts lookups made while at
        # least one store exists.
        self._state = {'len': 0, 'hit': 0, 'miss': 0, 'total': 0}

    def create_store(self):
        """Open a new, empty store and make it the target of `add`."""
        self.stores.append({})
        self.idx += 1
        self._state['len'] += 1

    def add(self, key, value):
        """Insert `key -> value`, opening a new store when the newest one is full."""
        if self.idx == -1 or len(self.stores[self.idx]) >= self.max_size:
            self.create_store()
        self.stores[self.idx][key] = value

    def get(self, key):
        """Return the cached value for `key`, or None on a miss.

        Every store is checked exactly once (newest first), so a key that was
        added is always found. Sampling store indices with replacement could
        skip the store that holds the key and report a false miss.
        """
        if self.idx == -1:
            return None
        self._state['total'] += 1
        for store in reversed(self.stores):
            value = store.get(key)
            if value is not None:
                self._state['hit'] += 1
                return value
        self._state['miss'] += 1
        return None

    def state(self):
        """Return a one-line usage summary; safe to call before any lookup."""
        total = self._state['total']

        def pct(count):
            # Guard against ZeroDivisionError when no lookup has been counted.
            return count / total * 100 if total else 0.0

        return f"total: {total}, len: {self._state['len']}, stores: {len(self.stores)} "\
              f"hit: {self._state['hit']} [{pct(self._state['hit'])}%], "\
              f"miss: {self._state['miss']} [{pct(self._state['miss'])}%]"


---
# TabuSearch

In [4]:
class TabuSearch:
    '''
    Tabu Search over word orderings.

    A solution is a permutation of word indices. Each iteration works as follows:
    - Sample `max_condidates` random swaps of two distinct positions of the current solution.
    - Score them with `calculate_perplexity` (lower is better).
    - Move to the best admissible candidate: one whose swap is not tabu, or which
      beats the best perplexity seen so far (aspiration criterion).

    The swap taken is then made tabu for the next `tabu_tenure` iterations.
    Swaps are stored order-independently, so (i, j) and (j, i) are the same move.

    Parameters:
        words (list[str]): The sequence of words to be optimized.
        calculate_perplexity (callable): Maps a solution (sequence of word indices) to a float; lower is better.
        tabu_tenure (int, optional): The size of the tabu list. Default is 5.
        stop_on_non_improvement (int, optional): Consecutive iterations without improvement before stopping. Default is 20.
        max_iter (int, optional): The maximum number of iterations to execute. Default is 100.
        max_condidates (int, optional): The number of candidate solutions sampled per iteration. Default is 100.
            (The misspelled keyword is kept for backward compatibility.)
        debug (bool, optional): If True, enables debug logging. Default is False.
        logger (callable, optional): A logging function. If None, prints to the console. Default is None.
    '''

    def __init__(self, words: list, calculate_perplexity, tabu_tenure=5,
                 stop_on_non_improvement=20, max_iter=100,
                 max_condidates=100,
                 debug=False, logger = None):
        '''Store the configuration and create an empty tabu list (see the class docstring).'''
        self.words = words
        self.n = len(self.words)
        self.tabu_tenure = tabu_tenure
        self.stop_on_non_improvement = stop_on_non_improvement
        self.max_iter = max_iter
        self.calculate_perplexity = calculate_perplexity
        # Fixed-size FIFO: appending beyond `tabu_tenure` drops the oldest move.
        self.tabu_list = deque(maxlen=self.tabu_tenure)
        self.max_condidates = max_condidates
        self.debug = debug
        self.logger = logger

    def log(self, msg):
        '''
        Logs a message if debugging is enabled.

        Parameters:
            msg (str): The message to log.
        '''
        if self.debug:
            if self.logger is None:
                print(f"[LOCAL SEARCH] {msg}")
            else:
                self.logger(f"[LOCAL SEARCH] {msg}")

    @staticmethod
    def _normalize_move(move) -> Tuple[int, int]:
        '''
        Return the canonical (smaller, larger) form of a swap.

        Swapping positions i and j is the same move as swapping j and i, so
        both orders must map to a single tabu-list entry.
        '''
        i, j = int(move[0]), int(move[1])
        return (i, j) if i <= j else (j, i)

    def is_move_used(self, move: Tuple[int, int]):
        '''
        Checks if a move (in either order) is present in the tabu list.

        Parameters:
            move (Tuple[int, int]): The swap to check.

        Returns:
            bool: True if the move is tabu, False otherwise.
        '''
        return self._normalize_move(move) in self.tabu_list

    def __to_sentence_text(self, solution):
        '''
        Converts a solution to a sentence by mapping indices to words.

        Parameters:
            solution (list[int]): The solution as a list of word indices.

        Returns:
            str: The sentence formed by the solution.
        '''
        return " ".join([self.words[i] for i in solution])

    def to_tabulist(self, move):
        '''
        Adds a move (normalized so its order does not matter) to the tabu list.

        Parameters:
            move (Tuple[int, int]): The swap to make tabu.
        '''
        self.tabu_list.append(self._normalize_move(move))

    def get_candidates(self, solution: np.ndarray) -> List[Tuple[np.ndarray, Tuple[int, int]]]:
        '''
        Generates candidate solutions by swapping two distinct random positions.

        Parameters:
            solution (np.ndarray): The current solution as an array of indices.

        Returns:
            list[Tuple[np.ndarray, Tuple[int, int]]]: `max_condidates` (candidate, swap)
            pairs, or an empty list when the solution has fewer than two positions.
        '''
        size = len(solution)
        candidates = []
        if size < 2:
            # No swap exists with fewer than two positions.
            return candidates

        for _ in range(self.max_condidates):
            # Two *distinct* positions: a self-swap would just re-score the
            # current solution and waste a candidate slot.
            i, j = random.sample(range(size), 2)
            candidate = solution.copy()
            candidate[i], candidate[j] = candidate[j], candidate[i]
            candidates.append((candidate, (i, j)))

        return candidates

    def _select_candidate(self, candidates, best_perplexity):
        '''
        Score `candidates` and choose the one to move to.

        Returns the lowest-perplexity admissible candidate: its swap is not tabu,
        or it beats `best_perplexity` (aspiration). If every candidate is tabu
        and none aspirates, returns the lowest-perplexity candidate overall so
        the search can still move.

        Returns:
            Tuple[candidate, float, Tuple[int, int]]: (solution, perplexity, swap).
        '''
        best_admissible = None
        best_overall = None
        for candidate, swap in candidates:
            candidate_perplexity = self.calculate_perplexity(candidate)
            entry = (candidate, candidate_perplexity, swap)
            if best_overall is None or candidate_perplexity < best_overall[1]:
                best_overall = entry
            admissible = (not self.is_move_used(swap)) or candidate_perplexity < best_perplexity
            if admissible and (best_admissible is None or candidate_perplexity < best_admissible[1]):
                best_admissible = entry
        return best_admissible if best_admissible is not None else best_overall

    def optimize(self, current_sentence: np.ndarray) -> Tuple[np.ndarray, float]:
        '''
        Executes the Tabu Search algorithm.

        Parameters:
            current_sentence (np.ndarray): The initial solution as an array of indices.

        Returns:
            Tuple[np.ndarray, float]: The best solution and its corresponding perplexity.
            When no swap is possible (fewer than two words, or `max_condidates` <= 0),
            the initial solution and its perplexity are returned.
        '''
        self.log(f"start {current_sentence}")

        best_sentence = current_sentence
        best_perplexity = self.calculate_perplexity(current_sentence)

        no_improvement = 0
        itr = 1

        while no_improvement < self.stop_on_non_improvement and itr <= self.max_iter:
            self.log(f"[ITER] [{itr}] --- {best_perplexity}")

            candidates = self.get_candidates(current_sentence)
            if not candidates:
                self.log(f"[NO MOVES] break; {best_sentence} | {best_perplexity}")
                break

            best_candidate, best_candidate_perplexity, used_move = self._select_candidate(
                candidates, best_perplexity
            )
            self.log(f"[{best_candidate_perplexity}] {best_candidate}")

            if best_candidate_perplexity < best_perplexity:
                best_sentence = best_candidate
                best_perplexity = best_candidate_perplexity
                no_improvement = 0
                self.log(f"[IM] [{itr}/{self.max_iter}] [{best_perplexity}]")
            else:
                no_improvement += 1
                self.log(f"[NO IM] __{no_improvement}__")

            # Move to the chosen candidate even if it is worse than the best;
            # this (plus the tabu list) lets the search leave local minima.
            current_sentence = best_candidate

            self.to_tabulist(used_move)
            itr += 1

            if no_improvement == self.stop_on_non_improvement:
                self.log(f"[NO IM] break; {best_sentence} | {best_perplexity}")
                break

        return best_sentence, best_perplexity

---
# WHALE OPTIMIZATION (WOA)

In [5]:
import numpy as np
from itertools import permutations

class WhaleOptimization:
    """
    Whale Optimization Algorithm (WOA) for ordering the words of a sentence so
    that its perplexity is minimized.

    Each whale is a permutation of word indices. The continuous WOA moves are:
    - encircling the prey,
    - searching towards a random whale,
    - the bubble-net spiral.

    They are applied to the index vectors and mapped back to a permutation with
    `np.argsort`. Every new position is then refined with Tabu Search.
    Perplexities are memoized in a per-instance `Caching` keyed by the index tuple.
    """

    def __init__(self, sentence: str, scorer,
                 n_whales: int = 20, max_iter: int = 80,
                 tabu_tenure=10, tabu_no_imp=10, tabu_max_iter=50,
                 cache: Optional[Caching] = None, logger=None, debug=False):
        """
        Initialize the Whale Optimization Algorithm.

        Args:
            sentence (str): Input sentence to optimize; split on whitespace into words.
            scorer (object): Object with a `get_perplexity(list[str]) -> list[float]` method
                (e.g. a `PerplexityCalculator`).
            n_whales (int, optional): Number of whales in the population. Defaults to 20.
            max_iter (int, optional): Maximum iterations for WOA. Defaults to 80.
            tabu_tenure (int, optional): Size of tabu list in Tabu Search. Defaults to 10.
            tabu_no_imp (int, optional): No-improvement iterations to stop Tabu Search. Defaults to 10.
            tabu_max_iter (int, optional): Maximum iterations in Tabu Search. Defaults to 50.
            cache (Caching, optional): Perplexity cache. Defaults to a new cache per
                instance. Keys are index tuples and are only meaningful for one
                sentence, so do not share a cache between different sentences.
            logger (callable, optional): Logging function. Defaults to None.
            debug (bool, optional): Enable debug messages. Defaults to False.
        """
        # split() rather than split(' '): repeated whitespace must not yield empty
        # "words". This matches the whitespace normalization done by `score`.
        self.words = sentence.split()
        self.n = len(self.words)
        self.scorer = scorer
        self.n_whales = n_whales
        self.max_iter = max_iter
        self.logger = logger
        self.debug = debug

        self.cwi = None  # Index of the current whale (None outside per-whale work)
        self.itr = ''  # Current WOA iteration ('' during population initialization)

        # Fresh cache per instance. A `Caching()` default in the signature would be
        # built once and shared by every instance; because keys are index tuples,
        # sentences with the same word count would then read each other's scores.
        self.cache = cache if cache is not None else Caching()

        # Initialize Tabu Search for local improvements
        self.TabuSearch = TabuSearch(
            self.words,
            self.__calculate_perplexity,
            tabu_tenure=tabu_tenure,
            stop_on_non_improvement=tabu_no_imp,
            max_iter=tabu_max_iter,
            debug=debug,
            logger=self.log
        )

    def __create_initial_sols(self) -> np.ndarray:
        """Generate a random initial solution (word permutation)."""
        return np.random.permutation(self.n)

    def log(self, msg):
        """
        Log messages for debugging and tracing execution.

        Args:
            msg (str): The message to log.
        """
        if self.debug:
            if self.logger is None:
                if self.cwi is None:
                    print(f"[WOA] [{self.itr}] {msg}")
                else:
                    print(f"[WOA] [{self.itr}] [WHALE {self.cwi + 1}] {msg}")
            else:
                if self.cwi is None:
                    self.logger(f"[{self.itr}] {msg}")
                else:
                    self.logger(f"[{self.itr}] [WHALE {self.cwi + 1}] {msg}")

    def _compute_A(self, a: float):
        """
        Compute the coefficient A = 2*a*r - a, with r ~ U(0, 1).

        Args:
            a (float): Shrinking parameter.

        Returns:
            np.ndarray: One-element array holding A (in [-a, a]).
        """
        r = np.random.uniform(0.0, 1.0, size=1)
        return (2.0*np.multiply(a, r)) - a

    def _compute_C(self):
        """
        Compute the coefficient C = 2*r, with r ~ U(0, 1).

        Returns:
            np.ndarray: One-element array holding C (in [0, 2]).
        """
        r = np.random.uniform(0.0, 1.0, size=1)
        return 2.0 * r

    def __to_sentence_text(self, sol):
        """
        Convert a solution (word order) into a readable sentence.

        Args:
            sol (np.ndarray): Solution as indices.

        Returns:
            str: Reconstructed sentence.
        """
        return " ".join([self.words[i] for i in sol])

    def __get_from_cache(self, solution: tuple) -> float:
        return self.cache.get(solution)

    def __calculate_perplexity(self, solution: np.ndarray) -> float:
        """
        Compute the perplexity of a given solution, with caching.

        Args:
            solution (np.ndarray): Solution as indices.

        Returns:
            float: Perplexity score of the solution.
        """
        solt = tuple(solution)
        cached = self.__get_from_cache(solt)
        if cached is not None:
            return cached

        sentence = self.__to_sentence_text(solution)
        get_perplexity = getattr(self.scorer, 'get_perplexity', None)
        if get_perplexity is None:
            # Backward compatibility: call sites that pass something without
            # `get_perplexity` (e.g. the `score` metric function) fall back to the
            # notebook-global `scorer`, which this method previously always used.
            get_perplexity = scorer.get_perplexity
        perplexity = get_perplexity([sentence])[0]
        self.cache.add(solt, perplexity)
        return perplexity

    def __caching_state(self):
        # Build the stats string only when it will actually be logged.
        if self.debug:
            self.log(self.cache.state())

    def __encircling_prey(self, current_pos: np.ndarray, best_pos: np.ndarray, A: float, C: float) -> np.ndarray:
        """Move towards the prey: X' = X* - A*|C*X* - X|, mapped back to a permutation."""
        D = abs(C * best_pos - current_pos)
        new_pos = best_pos - A * D
        return np.argsort(new_pos)

    def __search_for_prey(self, current_pos: np.ndarray, random_pos: np.ndarray, A: float, C: float) -> np.ndarray:
        """Explore around a random whale: X' = Xr - A*|C*Xr - X|, mapped back to a permutation."""
        D = abs(C * random_pos - current_pos)
        new_pos = random_pos - A * D
        return np.argsort(new_pos)

    def __bubble_net_attack(self, current_pos: np.ndarray, best_pos: np.ndarray) -> np.ndarray:
        """Spiral around the prey: X' = |X* - X| * e^(b*l) * cos(2*pi*l) + X*, l ~ U(-1, 1)."""
        D = abs(best_pos - current_pos)
        b = 1
        l = random.uniform(-1, 1)
        new_pos = D * np.exp(l * b) * np.cos(2 * np.pi * l) + best_pos
        return np.argsort(new_pos)

    def __amend_position(self, position: np.ndarray) -> np.ndarray:
        """Ensure position is a valid permutation."""
        return np.argsort(position)

    def __local_search(self, solution: np.ndarray) -> Tuple[np.ndarray, float]:
        return self.TabuSearch.optimize(solution)

    def optimize(self) -> Tuple[np.ndarray, float, str]:
        """
        Execute the Whale Optimization Algorithm to find the best word order.

        Returns:
            Tuple[np.ndarray, float, str]: The best permutation of word indices,
            its perplexity, and the corresponding sentence text.

        Raises:
            ValueError: If `n_whales` < 1 (there is no population to optimize).
        """
        if self.n_whales < 1:
            raise ValueError('n_whales must be at least 1')

        # Initialize population of whales with random solutions refined by Tabu Search
        self.itr = ''
        population = []
        perplexities_values = []
        for w in range(self.n_whales):
            self.cwi = w
            solution = self.__create_initial_sols()
            self.log(f"sol {solution} --> to improve")
            improved_solution, improved_perplexity = self.__local_search(solution)
            self.log(f"{improved_solution} | {improved_perplexity}")
            population.append(improved_solution)
            perplexities_values.append(improved_perplexity)
        self.cwi = None

        self.log("------------------------------")
        self.log("[WHALES]")
        self.log(population)
        self.log(perplexities_values)
        self.log("------------------------------")

        # Identify the initial best whale (copy so later updates cannot alias it)
        best_idx = np.argmin(perplexities_values)
        prey = population[best_idx].copy()
        best_perplexity = perplexities_values[best_idx]

        self.log(f"[PREY] : {prey} | {best_perplexity}")

        # Start main WOA loop
        self.itr = 1
        while self.itr <= self.max_iter:
            self.log(f"[START ITER] {prey} | {best_perplexity}")

            # `a` shrinks linearly from 2 towards 0 over the run (exploration -> exploitation)
            a = 2 - self.itr / self.max_iter * 2

            for w in range(self.n_whales):
                self.cwi = w
                cwhale = population[w]
                self.log(f"{cwhale} | {perplexities_values[w]}")
                A = self._compute_A(a)
                C = self._compute_C()
                p = random.random()
                self.log(f"a={a}, A={A}, C={C}, p={p}")

                if p < 0.5:  # TODO: check if this is important for the problem, or it can be removed
                    if abs(A) < 1:
                        new_pos = self.__encircling_prey(cwhale, prey, A, C)
                        self.log(f"[MOVE] encircling prey --> {new_pos}")
                    else:
                        rand_idx = random.randint(0, self.n_whales-1)
                        random_pos = population[rand_idx]
                        new_pos = self.__search_for_prey(cwhale, random_pos, A, C)
                        self.log(f"[MOVE] search for prey --> {new_pos}")
                else:
                    new_pos = self.__bubble_net_attack(cwhale, prey)
                    self.log(f"[MOVE] bubble net attack --> {new_pos}")

                # Refine the new position with Tabu Search
                improved_pos, improved_perplexity = self.__local_search(new_pos)
                self.log(f"[IMPROVE POSITION] {new_pos} -> {improved_pos}, [{improved_perplexity}]")

                # The whale always moves (no greedy acceptance).
                # TODO: evaluate whether keeping only improvements helps.
                population[w] = improved_pos
                perplexities_values[w] = improved_perplexity

            # Update the prey once per sweep if a better whale was found
            best_idx = np.argmin(perplexities_values)
            if perplexities_values[best_idx] < best_perplexity:
                prey = population[best_idx].copy()
                best_perplexity = perplexities_values[best_idx]
                self.log(f"[RES] [IMP] Prey: {prey} with perplexity {best_perplexity}")
            else:
                self.log(f"[RES] [NO IMP] Prey: {prey} with perplexity {best_perplexity}")

            self.cwi = None
            self.itr += 1

        self.__caching_state()
        return prey, best_perplexity, self.__to_sentence_text(prey)

In [6]:
# Load the Gemma-2 perplexity scorer once so later cells can reuse it
# (loading the checkpoint shards is the slow part of this notebook).
model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
scorer = PerplexityCalculator(model_path)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
# Fixed seeds make this demo reproducible: WOA and Tabu Search draw from both
# `random` and `np.random`.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

text = 'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge'

# Pass the already-loaded PerplexityCalculator (`scorer`), not the `score` metric
# function: WhaleOptimization needs an object with `get_perplexity`, and `score`
# would reload the model on every call.
woa = WhaleOptimization(
    text, scorer,
    n_whales=3, max_iter=1, debug=True,
    tabu_tenure=5, tabu_no_imp=1, tabu_max_iter=4
)
woa.optimize()

In [8]:
# import pandas as pd
# problem = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")

# res = []

# for idx, row in problem.iterrows():
#     print("----------------------------------------------")
#     print("----------------------------------------------")
#     print("----------------------------------------------")
#     def logger(msg):
#         print(f"[WOA {idx}] {msg}")
#     woa = WhaleOptimization(
#         row.text, score,
#         n_whales=3, max_iter=10, debug=True, 
#         tabu_tenure=10, tabu_no_imp=3, tabu_max_iter=20,
#         logger = logger
#     )
#     sol, p, sen = woa.optimize()
#     res.append({
#         "id": len(res),
#         "or": row.text,
#         "sol": sol,
#         "text": sen,
#         "p": p
#     })

# submission_ = pd.DataFrame(res)
# submission_