In [1]:
import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

from collections import deque
from typing import Tuple, List
from collections import deque
import random

In [2]:
# Process-wide settings: both environment variables are set before any model is loaded.
os.environ.update({
    'OMP_NUM_THREADS': '1',
    'TOKENIZERS_PARALLELISM': 'false',
})

# Reuse whatever label id CrossEntropyLoss ignores by default.
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class ParticipantVisibleError(Exception):
    """Error raised by ``score`` when a submission fails validation."""

def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """Return the mean perplexity of the submitted strings.

    Args:
        solution: Frame whose ``text`` column holds the reference strings.
        submission: Frame whose ``text`` column holds the submitted strings.
        row_id_column_name: Accepted for interface compatibility; not used
            by this function.
        model_path: Location of the model handed to ``PerplexityCalculator``.
        load_in_8bit: Forwarded to ``PerplexityCalculator``.
        clear_mem: When true, try to release GPU memory after scoring.

    Returns:
        The arithmetic mean of the per-row perplexities.

    Raises:
        ParticipantVisibleError: If any submitted string does not contain
            exactly the same multiset of whitespace-separated words as the
            corresponding solution string.
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Best effort only: the score is already computed, so a cleanup failure
        # must not abort. Catch Exception (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            scorer.clear_gpu_memory()
        except Exception as err:
            print(f'GPU memory clearing failed: {err!r}')

    return float(np.mean(perplexities))

class PerplexityCalculator:
    """Computes per-sequence perplexity with a causal language model.

    The constructor loads a tokenizer and a causal LM from ``model_path``;
    ``get_perplexity`` then scores strings one at a time.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        """Load the tokenizer and the model.

        Args:
            model_path: Path or identifier passed to ``from_pretrained``.
            load_in_8bit: Load the model with an 8-bit BitsAndBytes config.
            device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

        Raises:
            ValueError: If 8-bit loading is requested while ``DEVICE`` is not CUDA.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

        # Collect the loader arguments first, then load through a single call.
        load_kwargs = {'device_map': device_map}
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            load_kwargs['quantization_config'] = transformers.BitsAndBytesConfig(
                load_in_8bit=True
            )
        else:
            on_gpu = DEVICE.type == 'cuda'
            load_kwargs['torch_dtype'] = torch.float16 if on_gpu else torch.float32
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_path, **load_kwargs
        )

        # Unreduced loss so each target token keeps its own value.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug=False
    ) -> Union[float, List[float]]:
        """Return the perplexity of one string or of each string in a list.

        Every text is wrapped in the tokenizer's BOS/EOS tokens and scored on
        its own. Perplexity is ``exp`` of the mean next-token cross-entropy.

        Args:
            input_texts: A single string or a list of strings.
            debug: Print tokens, per-token losses and final perplexities.

        Returns:
            A float for a single string input, otherwise a list of floats in
            input order.
        """
        is_single = isinstance(input_texts, str)
        if is_single:
            input_texts = [input_texts]

        mean_losses = []
        with torch.no_grad():
            for text in input_texts:
                # Special tokens are added by hand, so the tokenizer must not add its own.
                text_with_special = ''.join(
                    (self.tokenizer.bos_token, text, self.tokenizer.eos_token)
                )
                encoded = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                )

                # Drop token_type_ids (if present) and move the rest to DEVICE.
                batch = {
                    key: tensor.to(DEVICE)
                    for key, tensor in encoded.items()
                    if key != 'token_type_ids'
                }
                token_ids = batch['input_ids']

                logits = self.model(**batch, use_cache=False)['logits']

                # Position t predicts token t + 1: drop the last prediction and the first input.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = token_ids[..., 1:].contiguous()

                vocab_size = shift_logits.size(-1)
                loss = self.loss_fct(
                    shift_logits.view(-1, vocab_size),
                    shift_labels.view(-1),
                )

                sequence_loss = loss.sum() / loss.shape[0]
                mean_losses.append(sequence_loss.cpu().item())

                if debug:
                    print(f"\nProcessing: '{text}'")
                    print(f"With special tokens: '{text_with_special}'")
                    print(f"Input tokens: {token_ids[0].tolist()}")
                    print(f"Target tokens: {shift_labels[0].tolist()}")
                    print(f"Input decoded: {self.tokenizer.decode(token_ids[0])}")
                    print(f"Target decoded: {self.tokenizer.decode(shift_labels[0])}")
                    print(f"Individual losses: {loss.tolist()}")
                    print(f"Average loss: {sequence_loss.item():.4f}")

        ppl = list(map(exp, mean_losses))

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(input_texts, ppl):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        if is_single:
            return ppl[0]
        return ppl

    def clear_gpu_memory(self) -> None:
        """Drop the model and tokenizer and release cached CUDA memory.

        Does nothing when CUDA is not available.
        """
        if not torch.cuda.is_available():
            return

        # Remove whichever of the two heavy attributes are still present.
        for attr in ('model', 'tokenizer'):
            if hasattr(self, attr):
                delattr(self, attr)

        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

In [3]:
# Instantiate the perplexity scorer once at notebook level.
model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
scorer = PerplexityCalculator(model_path)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

---
# TabuSearch

In [4]:
class TabuSearch:
    """Tabu search over permutations using a pairwise-swap neighbourhood.

    A solution is a sequence of indices into ``words``. Each iteration
    evaluates swap neighbours of the current solution, moves to the best
    admissible one (even if it is worse than the current solution) and marks
    the swap as tabu for the next ``tabu_tenure`` accepted moves.

    Args:
        words: The words being permuted.
        calculate_perplexity: Callable mapping a solution to a float cost
            (lower is better).
        tabu_tenure: Maximum number of recent moves kept in the tabu list.
        stop_on_non_improvement: Stop after this many consecutive iterations
            without improving the best cost.
        max_iter: Hard cap on the number of iterations.
        max_condidates: Maximum number of neighbours evaluated per iteration
            (the spelling is kept for backward compatibility).
        debug: Emit log messages when true.
        logger: Optional callable receiving log messages; ``print`` is used
            when it is ``None``.
    """

    def __init__(self, words: list, calculate_perplexity, tabu_tenure=5,
                 stop_on_non_improvement=20, max_iter=100,
                 max_condidates=100,
                 debug=False, logger = None):
        self.words = words
        self.n = len(self.words)
        self.tabu_tenure = tabu_tenure
        self.stop_on_non_improvement = stop_on_non_improvement
        self.max_iter = max_iter
        self.calculate_perplexity = calculate_perplexity
        self.tabu_list = deque(maxlen=self.tabu_tenure)
        self.max_condidates = max_condidates
        self.debug = debug
        self.logger = logger

    def log(self, msg):
        """Emit ``msg`` through the logger (or ``print``) when debugging is on."""
        if self.debug:
            if self.logger is None:
                print(f"[LOCAL SEARCH] {msg}")
            else:
                self.logger(f"[LOCAL SEARCH] {msg}")

    @staticmethod
    def _normalize_move(move) -> Tuple[int, int]:
        """Return the swap as an ordered tuple of plain ints.

        Swapping (i, j) and (j, i) is the same move, so both must map to the
        same tabu entry.
        """
        return tuple(sorted(int(k) for k in move))

    def is_move_used(self, move: Tuple[int, int]):
        """Return True when ``move`` (in either orientation) is currently tabu."""
        return self._normalize_move(move) in self.tabu_list

    def __to_sentence_text(self, sentence):
        return " ".join([self.words[i] for i in sentence])

    def to_tabulist(self, move):
        """Record ``move`` as tabu; the oldest entry drops out once the list is full."""
        self.tabu_list.append(self._normalize_move(move))

    def get_candidates(self, solution: np.ndarray) -> List[Tuple[list, Tuple[int, int]]]:
        """Build swap neighbours of ``solution``.

        Swap positions are drawn from the index range of ``solution`` and are
        always two distinct positions, so no candidate equals the input and no
        swap appears twice. When the whole neighbourhood has at most
        ``max_condidates`` swaps it is enumerated exhaustively; otherwise
        ``max_condidates`` distinct swaps are sampled at random.

        Returns:
            A list of ``(candidate, (i, j))`` pairs with ``i < j``. The list is
            empty when ``solution`` has fewer than two elements.
        """
        n_positions = len(solution)
        all_moves = [
            (i, j)
            for i in range(n_positions)
            for j in range(i + 1, n_positions)
        ]
        limit = max(int(self.max_condidates), 0)
        if len(all_moves) > limit:
            all_moves = random.sample(all_moves, limit)

        candidates = []
        for i, j in all_moves:
            candidate = solution.copy()
            candidate[i], candidate[j] = candidate[j], candidate[i]
            candidates.append((candidate, (i, j)))
        return candidates

    def optimize(self, current_sentence: np.ndarray) -> Tuple[np.ndarray, float]:
        """Run tabu search starting from ``current_sentence``.

        Args:
            current_sentence: Initial solution; it is not modified.

        Returns:
            ``(best_solution, best_cost)`` over everything visited, including
            the starting solution.
        """
        # One instance may be reused for several independent searches; tabu
        # memory from a previous run must not constrain this one.
        self.tabu_list.clear()

        self.log(f"start {current_sentence}")

        best_sentence = current_sentence
        best_perplexity = self.calculate_perplexity(current_sentence)

        no_improvement = 0
        itr = 1

        while no_improvement < self.stop_on_non_improvement and itr <= self.max_iter:
            self.log(f"[ITER] [{itr}] --- {best_perplexity}")

            best_candidate = None
            best_candidate_perplexity = None
            used_move = None

            for candidate, swap in self.get_candidates(current_sentence):
                candidate_perplexity = self.calculate_perplexity(candidate)

                # A tabu move is only allowed when it beats the best cost seen
                # so far (aspiration). This applies to every candidate,
                # including the first one examined.
                if self.is_move_used(swap) and not candidate_perplexity < best_perplexity:
                    continue

                if best_candidate_perplexity is None \
                        or candidate_perplexity < best_candidate_perplexity:
                    best_candidate = candidate
                    best_candidate_perplexity = candidate_perplexity
                    used_move = swap

            self.log(f"[{best_candidate_perplexity}] {best_candidate}")

            if best_candidate is None:
                # No admissible neighbour (all tabu, or nothing to swap):
                # stay in place and count the iteration as non-improving.
                no_improvement += 1
                self.log(f"[NO IM] __{no_improvement}__")
            else:
                if best_candidate_perplexity < best_perplexity:
                    best_sentence = best_candidate
                    best_perplexity = best_candidate_perplexity
                    no_improvement = 0
                    self.log(f"[IM] [{itr}/{self.max_iter}] [{best_perplexity}]")
                else:
                    no_improvement += 1
                    self.log(f"[NO IM] __{no_improvement}__")

                # Always move to the best admissible neighbour, even uphill.
                current_sentence = best_candidate
                self.to_tabulist(used_move)

            itr += 1

            if no_improvement == self.stop_on_non_improvement:
                self.log(f"[NO IM] break; {best_sentence} | {best_perplexity}")

        return best_sentence, best_perplexity

In [5]:
# import numpy as np

# words = problem['text'][0].split()
# print(words)

# init = np.random.permutation(len(words)).tolist()
# print(init)

# def __calculate_fitness(solution) -> float:
#     sentence = " ".join([words[i] for i in solution])
#     problem = pd.DataFrame({'id': [0], 'text': [sentence] })
#     perplexities = scorer.get_perplexity(problem["text"].tolist())
#     # print(f"{perplexities[0]} || {sentence}")
#     return perplexities[0]

# tabuSearch = TabuSearch(words, __calculate_fitness, tabu_tenure=10, stop_on_non_improvement=50, debug=True)

# tabuSearch.optimize(init)

---

In [6]:
import numpy as np
from itertools import permutations

class WhaleOptimization:
    """Whale Optimization Algorithm over word orderings, refined by tabu search.

    A position is a permutation of indices into the words of ``sentence``.
    The continuous whale update rules are applied to these index vectors and
    the result is mapped back to a permutation with ``np.argsort``. Every new
    position is then improved by a ``TabuSearch`` run.

    Args:
        sentence: Whitespace-separated words to reorder.
        scorer: Object exposing ``get_perplexity(list_of_str) -> list``. For
            backward compatibility, an object without that method is accepted
            and the module-level ``scorer`` is used instead.
        n_whales: Population size (at least 1).
        max_iter: Number of whale iterations.
        tabu_tenure: Forwarded to ``TabuSearch``.
        tabu_no_imp: ``stop_on_non_improvement`` for ``TabuSearch``.
        tabu_max_iter: ``max_iter`` for ``TabuSearch``.
        logger: Optional callable receiving log messages.
        debug: Emit log messages when true.

    Raises:
        ValueError: If ``n_whales`` is smaller than 1.
    """

    def __init__(self, sentence: str, scorer,
                 n_whales: int = 20, max_iter: int = 80,
                 tabu_tenure=10, tabu_no_imp=10, tabu_max_iter=50,
                 logger=None, debug=False):
        if n_whales < 1:
            raise ValueError("n_whales must be at least 1")

        # split() without an argument ignores repeated/leading/trailing
        # whitespace, so no empty "words" end up in the permutation.
        self.words = sentence.split()
        self.n = len(self.words)
        self.scorer = scorer
        self.n_whales = n_whales
        self.max_iter = max_iter
        self.logger = logger
        self.debug = debug
        self.cur_whale = None
        self.itr = None
        self.TabuSearch = TabuSearch(
            self.words,
            self.__calculate_fitness,
            tabu_tenure=tabu_tenure,
            stop_on_non_improvement=tabu_no_imp,
            max_iter=tabu_max_iter,
            debug=debug,
            logger=self.log
        )

        # Fitness cache plus hit/miss counters for the final report.
        self.cache = {}
        self.calccached = 0
        self.calcnotcached = 0
        self.calcattempt = 0

    def __create_initial_sols(self) -> np.ndarray:
        """Create a random permutation solution"""
        return np.random.permutation(self.n)

    def log(self, msg):
        """Emit ``msg`` tagged with the current iteration and whale when debugging."""
        if self.debug:
            if self.logger is None:
                if self.cur_whale is None:
                    print(f"[WOA] [{self.itr}] {msg}")
                else:
                    print(f"[WOA] [{self.itr}] [WHALE {self.cur_whale + 1}] {msg}")
            else:
                if self.cur_whale is None:
                    self.logger(f"[{self.itr}] {msg}")
                else:
                    self.logger(f"[{self.itr}] [WHALE {self.cur_whale + 1}] {msg}")

    def _compute_A(self, a: float) -> float:
        """Draw the coefficient ``A = 2*a*r - a`` with ``r`` uniform in [0, 1)."""
        r = np.random.uniform(0.0, 1.0, size=1).item()
        return float((2.0 * a * r) - a)

    def _compute_C(self) -> float:
        """Draw the coefficient ``C = 2*r`` with ``r`` uniform in [0, 1)."""
        return float(2.0 * np.random.uniform(0.0, 1.0, size=1).item())

    def __to_sentence_text(self, sol):
        """Join the words selected by ``sol`` into a space-separated sentence."""
        return " ".join([self.words[i] for i in sol])

    def __calculate_fitness(self, solution: np.ndarray) -> float:
        """Return the perplexity of the sentence encoded by ``solution``.

        Results are cached by sentence text, so different index permutations
        that spell the same sentence (duplicate words) share one evaluation.
        """
        self.calcattempt += 1
        sentence = self.__to_sentence_text(solution)
        if sentence in self.cache:
            self.calccached += 1
            return self.cache[sentence]

        # Prefer the scorer given to the constructor. Older call sites passed
        # an object without get_perplexity; for those, keep the previous
        # behaviour of using the module-level `scorer`.
        active_scorer = self.scorer if hasattr(self.scorer, "get_perplexity") else scorer
        perplexity = active_scorer.get_perplexity([sentence])[0]

        self.cache[sentence] = perplexity
        self.calcnotcached += 1
        return perplexity

    def __caching_state(self):
        """Log how many fitness requests were served from the cache."""
        attempts = self.calcattempt
        # Guard the percentages: nothing may have been evaluated yet.
        cached_pct = self.calccached / attempts * 100 if attempts else 0.0
        calculated_pct = self.calcnotcached / attempts * 100 if attempts else 0.0
        self.log(f"[CACHING] attempt: [{attempts}], "
                    f"cached: [{self.calccached}] ({cached_pct}%), "
                    f"calculated: [{self.calcnotcached}] ({calculated_pct}%)")

    def __encircling_prey(self, current_pos: np.ndarray, best_pos: np.ndarray, A: float, C: float) -> np.ndarray:
        """Move towards ``best_pos``; the result is re-ranked into a permutation."""
        D = abs(C * best_pos - current_pos)
        new_pos = best_pos - A * D
        return np.argsort(new_pos)

    def __search_for_prey(self, current_pos: np.ndarray, random_pos: np.ndarray, A: float, C: float) -> np.ndarray:
        """Move relative to a randomly chosen whale; re-ranked into a permutation."""
        D = abs(C * random_pos - current_pos)
        new_pos = random_pos - A * D
        return np.argsort(new_pos)

    def __bubble_net_attack(self, current_pos: np.ndarray, best_pos: np.ndarray) -> np.ndarray:
        """Spiral update around ``best_pos``; re-ranked into a permutation."""
        D = abs(best_pos - current_pos)
        b = 1
        l = random.uniform(-1, 1)
        new_pos = D * np.exp(l * b) * np.cos(2 * np.pi * l) + best_pos
        return np.argsort(new_pos)

    def __amend_position(self, position: np.ndarray) -> np.ndarray:
        """Ensure position is a valid permutation."""
        return np.argsort(position)

    def __local_search(self, solution: np.ndarray) -> Tuple[np.ndarray, float]:
        """Improve ``solution`` with the shared tabu search instance."""
        return self.TabuSearch.optimize(solution)

    def optimize(self) -> Tuple[np.ndarray, float, str]:
        """Run the whale optimisation.

        Returns:
            ``(best_position, best_fitness, best_sentence)`` where
            ``best_sentence`` is the words joined in the order given by
            ``best_position``.
        """
        # Initialize population with local search improvement
        population = []
        fitness_values = []
        for w in range(self.n_whales):
            self.cur_whale = w
            solution = self.__create_initial_sols()
            self.log(f"sol {solution} --> to improve")
            improved_solution, improved_fitness = self.__local_search(solution)
            self.log(f"{improved_solution} | {improved_fitness}")
            population.append(improved_solution)
            fitness_values.append(improved_fitness)
        self.cur_whale = None

        self.log("------------------------------")
        self.log(f"[WHALES]")
        self.log(population)
        self.log(fitness_values)
        self.log("------------------------------")

        best_idx = int(np.argmin(fitness_values))
        best_pos = population[best_idx].copy()
        best_fitness = fitness_values[best_idx]

        self.itr = 1
        while self.itr <= self.max_iter:
            self.log(f"[START ITER] {best_pos} | {best_fitness}")
            for w in range(self.n_whales):
                self.cur_whale = w
                whale = population[w]
                self.log(f"{whale} | {fitness_values[w]}")
                # `a` decreases linearly from just under 2 down to 0 on the last iteration.
                a = 2 - self.itr * (2 / self.max_iter)
                A = self._compute_A(a)
                C = self._compute_C()
                p = random.random()
                self.log(f"a={a}, A={A}, C={C}, p={p}")

                if p < 0.5:
                    if abs(A) < 1:
                        new_pos = self.__encircling_prey(whale, best_pos, A, C)
                        self.log(f"[MOVE] encircling prey --> {new_pos}")
                    else:
                        rand_idx = random.randint(0, self.n_whales-1)
                        random_pos = population[rand_idx]
                        new_pos = self.__search_for_prey(whale, random_pos, A, C)
                        self.log(f"[MOVE] search for prey --> {new_pos}")
                else:
                    new_pos = self.__bubble_net_attack(whale, best_pos)
                    self.log(f"[MOVE] bubble net attack --> {new_pos}")

                # Apply local search to improve the new position
                improved_pos, improved_fitness = self.__local_search(new_pos)
                # Report the position the fitness actually belongs to.
                self.log(f"[IMPROVE POSITION] {improved_pos}, [{improved_fitness}]")

                population[w] = improved_pos
                fitness_values[w] = improved_fitness

                if improved_fitness < best_fitness:
                    best_pos = improved_pos.copy()
                    best_fitness = improved_fitness
                    self.log(f"[BEST] {best_pos}, [{best_fitness}]")
            self.cur_whale = None
            self.itr += 1

        self.__caching_state()
        return best_pos, best_fitness, self.__to_sentence_text(best_pos)

In [7]:
# # text = problem['text'][0]

# text = 'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge'

# woa = WhaleOptimization(
#     text, score, 
#     n_whales=3, max_iter=1, debug=True, 
#     tabu_tenure=5, tabu_no_imp=1, tabu_max_iter=3
# )
# woa.optimize()

In [8]:
import pandas as pd
# Run the whale optimisation on every row of the sample submission and collect the results.
problem = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")

res = []

for idx, row in problem.iterrows():
    print("----------------------------------------------")
    print("----------------------------------------------")
    print("----------------------------------------------")

    # Bind idx as a default so the logger keeps the row it was created for.
    def logger(msg, idx=idx):
        print(f"[WOA {idx}] {msg}")

    # Pass the loaded PerplexityCalculator instance (`scorer`), not the
    # `score` function: the optimiser needs an object with get_perplexity().
    woa = WhaleOptimization(
        row.text, scorer,
        n_whales=3, max_iter=10, debug=True,
        tabu_tenure=10, tabu_no_imp=3, tabu_max_iter=20,
        logger = logger
    )
    sol, p, sen = woa.optimize()
    res.append({
        "id": len(res),
        "or": row.text,
        "sol": sol,
        "text": sen,
        "p": p
    })

submission_ = pd.DataFrame(res)
submission_

----------------------------------------------
----------------------------------------------
----------------------------------------------
[WOA 0] [None] [WHALE 1] sol [0 5 4 6 2 3 9 8 1 7] --> to improve
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] start [0 5 4 6 2 3 9 8 1 7]
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [ITER] [1] --- 4810.95576331565
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [1679.8564191837218] [5 0 4 6 2 3 9 8 1 7]
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [IM] [1/20] [1679.8564191837218]
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [ITER] [2] --- 1679.8564191837218
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [1019.2979575497574] [5 0 4 1 2 3 9 8 6 7]
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [IM] [2/20] [1019.2979575497574]
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [ITER] [3] --- 1019.2979575497574
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [923.6470701622601] [5 3 4 1 2 0 9 8 6 7]
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH] [IM] [3/20] [923.6470701622601]
[WOA 0] [None] [WHALE 1] [LOCAL SEARCH]

Unnamed: 0,id,or,sol,text
0,0,advent chimney elf family fireplace gingerbrea...,"[8, 6, 2, 9, 5, 1, 4, 7, 3, 0]",reindeer mistletoe elf scrooge gingerbread chi...
1,1,advent chimney elf family fireplace gingerbrea...,"[8, 6, 2, 9, 5, 7, 0, 1, 4, 3, 17, 16, 10, 13,...",reindeer mistletoe elf scrooge gingerbread orn...
2,2,yuletide decorations gifts cheer holiday carol...,"[16, 0, 5, 4, 3, 15, 18, 19, 11, 10, 12, 2, 13...",jingle yuletide carol holiday cheer holly naug...
3,3,yuletide decorations gifts cheer holiday carol...,"[10, 23, 2, 29, 16, 12, 15, 18, 19, 8, 17, 9, ...",sleigh of gifts unwrap jingle workshop holly n...
4,4,hohoho candle poinsettia snowglobe peppermint ...,"[47, 9, 31, 5, 6, 10, 35, 28, 41, 48, 46, 38, ...",from puzzle star eggnog fruitcake game the nig...
5,5,advent chimney elf family fireplace gingerbrea...,"[59, 20, 66, 65, 55, 19, 16, 5, 46, 80, 42, 23...",puzzle yuletide dream believe eggnog and sleep...


In [9]:
# Keep only the columns required for the submission file.
submission = submission_[['id', 'text']]
# index=False: otherwise pandas writes the row index as an extra unnamed first column.
submission.to_csv("submission.csv", index=False)
submission

Unnamed: 0,id,text
0,0,reindeer mistletoe elf scrooge gingerbread chi...
1,1,reindeer mistletoe elf scrooge gingerbread orn...
2,2,jingle yuletide carol holiday cheer holly naug...
3,3,sleigh of gifts unwrap jingle workshop holly n...
4,4,from puzzle star eggnog fruitcake game the nig...
5,5,puzzle yuletide dream believe eggnog and sleep...
