In [None]:
import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

from collections import deque
from typing import Tuple, List
from collections import deque
import random

In [2]:
# Limit OpenMP threads and disable HF tokenizers' parallelism (avoids fork warnings).
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The search cells below draw from `random` and `np.random`; seed every RNG so a
# Restart & Run All reproduces the same search trajectory.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class ParticipantVisibleError(Exception):
    """Raised by ``score`` when a submission is invalid (e.g. not a word permutation of the solution)."""

def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """Mean perplexity of the submitted texts (lower is better).

    Args:
        solution: Frame with a ``text`` column of reference word sequences.
        submission: Frame with a ``text`` column; each row must contain exactly
            the same words (with multiplicity) as the matching solution row.
            Rows are compared by index label, so both frames must share an index.
        row_id_column_name: Part of the Kaggle metric signature; unused here.
        model_path: Model path handed to ``PerplexityCalculator``.
        load_in_8bit: Load the model with 8-bit quantization (CUDA only).
        clear_mem: Release the model's GPU memory after scoring (best effort).

    Returns:
        The mean of the per-row perplexities.

    Raises:
        ParticipantVisibleError: If the frames have different numbers of rows,
            or any submitted string is not a permutation of its solution string.
    """
    # Without this check pandas raises a bare ValueError on the comparison below.
    if len(solution) != len(submission):
        raise ParticipantVisibleError(
            'Submission and solution have a different number of rows.'
        )

    # Compare word multisets row by row: order may change, words may not.
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Split and rejoin to normalize whitespace before scoring.
    sub_strings = [' '.join(s.split()) for s in submission['text'].tolist()]
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )
    perplexities = scorer.get_perplexity(sub_strings)

    if clear_mem:
        # Best effort: the score is already computed, so a failure is only reported.
        # Catch Exception (not everything) so Ctrl-C / SystemExit still propagate.
        try:
            scorer.clear_gpu_memory()
        except Exception as exc:
            print(f'GPU memory clearing failed: {exc!r}')

    return float(np.mean(perplexities))

class PerplexityCalculator:
    """Scores texts by causal-LM perplexity.

    Every text is wrapped in the tokenizer's BOS/EOS tokens and scored on its
    own (no batching or padding). Perplexity is ``exp`` of the mean per-token
    cross-entropy of predicting each token from its prefix.

    Args:
        model_path: Path handed to ``from_pretrained`` for tokenizer and model.
        load_in_8bit: Load weights with bitsandbytes 8-bit quantization.
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

    Raises:
        ValueError: If ``load_in_8bit`` is requested but ``DEVICE`` is not CUDA.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        # The tokenizer is loaded before the 8-bit/CUDA validation below.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

        load_kwargs = {'device_map': device_map}
        if not load_in_8bit:
            # Half precision on GPU, full precision on CPU.
            load_kwargs['torch_dtype'] = (
                torch.float16 if DEVICE.type == 'cuda' else torch.float32
            )
        elif DEVICE.type == 'cuda':
            load_kwargs['quantization_config'] = transformers.BitsAndBytesConfig(
                load_in_8bit=True
            )
        else:
            raise ValueError('8-bit quantization requires CUDA device')

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_path, **load_kwargs
        )
        # Unreduced loss so per-token values are available (printed in debug mode).
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug=False
    ) -> Union[float, List[float]]:
        """Return the perplexity of one text, or a list of perplexities.

        A ``str`` yields a single float; a list yields a list of floats in the
        same order. With ``debug=True`` token ids, decoded text and per-token
        losses are printed for every text, followed by a summary.
        """
        single_input = isinstance(input_texts, str)
        texts = [input_texts] if single_input else input_texts

        mean_losses = []
        with torch.no_grad():
            for text in texts:
                mean_losses.append(self._mean_token_loss(text, debug))

        perplexities = list(map(exp, mean_losses))

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(texts, perplexities):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        if single_input:
            return perplexities[0]
        return perplexities

    def _mean_token_loss(self, text, debug):
        """Mean next-token cross-entropy of ``text`` wrapped in BOS/EOS, as a Python float."""
        tok = self.tokenizer
        wrapped = f"{tok.bos_token}{text}{tok.eos_token}"

        # Special tokens are already in the string, so the tokenizer must not add more.
        encoded = tok(
            wrapped,
            return_tensors='pt',
            add_special_tokens=False,
        )
        encoded.pop('token_type_ids', None)
        inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

        logits = self.model(**inputs, use_cache=False)['logits']
        input_ids = inputs['input_ids']

        # Position t predicts token t+1: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()

        token_losses = self.loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )
        mean_loss = token_losses.sum() / len(token_losses)
        value = mean_loss.cpu().item()

        if debug:
            print(f"\nProcessing: '{text}'")
            print(f"With special tokens: '{wrapped}'")
            print(f"Input tokens: {input_ids[0].tolist()}")
            print(f"Target tokens: {shift_labels[0].tolist()}")
            print(f"Input decoded: {tok.decode(input_ids[0])}")
            print(f"Target decoded: {tok.decode(shift_labels[0])}")
            print(f"Individual losses: {token_losses.tolist()}")
            print(f"Average loss: {mean_loss.item():.4f}")

        return value

    def clear_gpu_memory(self) -> None:
        """Drop the model and tokenizer and release cached CUDA memory.

        Does nothing (the model is kept) when CUDA is unavailable. After a
        successful call the instance can no longer score text.
        """
        if not torch.cuda.is_available():
            return

        for attr in ('model', 'tokenizer'):
            if hasattr(self, attr):
                delattr(self, attr)

        gc.collect()

        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

In [None]:
# Load the Gemma scorer once; later cells use this global instance.
model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
scorer = PerplexityCalculator(model_path)

In [4]:
submission = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")

# Baseline scoring runs the model over every row, so it is opt-in rather than commented out.
SCORE_BASELINE = False
perplexities = scorer.get_perplexity(submission["text"].tolist()) if SCORE_BASELINE else None
perplexities

In [25]:
class TabuSearch:
    """Tabu search over word orderings using random 2-opt (segment reversal) moves.

    Sentences are handled as lists of words. Each iteration moves to the best
    non-tabu neighbour (even when it is worse, to escape local minima) and
    marks it tabu so the search does not cycle straight back.

    Args:
        calculate_perplexity: Callable ``(list[str]) -> float``; lower is better.
        tabu_tenure: How many recently accepted sentences are forbidden.
        stop_on_non_improvement: Stop after this many consecutive iterations
            without a new best.
        max_iter: Hard cap on the number of iterations.
    """

    def __init__(self, calculate_perplexity, tabu_tenure=5, stop_on_non_improvement=20, max_iter=100):
        self.tabu_tenure = tabu_tenure
        self.stop_on_non_improvement = stop_on_non_improvement
        self.max_iter = max_iter
        self.calculate_perplexity = calculate_perplexity
        # Holds joined sentence strings (see ``to_text``), not word lists.
        self.tabu_list = deque(maxlen=self.tabu_tenure)

    def is_sentence_visited(self, sentence):
        """Return True if ``sentence`` is currently tabu."""
        return self.to_text(sentence) in self.tabu_list

    def to_text(self, sentence):
        """Join a word sequence into one space-separated string."""
        return " ".join(sentence)

    def get_candidates(self, sentence):  # two-opt swap
        """Return up to ``len(sentence)`` distinct 2-opt neighbours of ``sentence``.

        A neighbour reverses ``sentence[i:j]`` with ``j - i >= 2`` (shorter slices
        would leave the sentence unchanged); every position, including the first
        and last word, can take part. Sentences with < 2 words have no neighbours.
        """
        words = list(sentence)
        n = len(words)
        moves = [(i, j) for i in range(n - 1) for j in range(i + 2, n + 1)]
        chosen = random.sample(moves, min(n, len(moves)))
        return [words[:i] + words[i:j][::-1] + words[j:] for i, j in chosen]

    def optimize(self, current_sentence) -> Tuple[List[str], float]:
        """Run tabu search from ``current_sentence``.

        Stops after ``max_iter`` iterations, after ``stop_on_non_improvement``
        consecutive iterations without a new best, or when every neighbour is
        tabu. The tabu list is reset at the start of each call.

        Returns:
            ``(best_sentence, best_perplexity)`` - best word list seen and its score.
        """
        current_sentence = list(current_sentence)
        best_sentence = current_sentence
        best_perplexity = self.calculate_perplexity(current_sentence)

        # Fresh memory per run; the starting sentence is tabu too.
        self.tabu_list.clear()
        self.tabu_list.append(self.to_text(current_sentence))

        no_improvement = 0
        for itr in range(1, self.max_iter + 1):
            print(f"[ITER] [{itr}] --- {best_perplexity}")

            best_candidate = None
            best_candidate_perplexity = None
            scored = set()
            for candidate in self.get_candidates(current_sentence):
                text = self.to_text(candidate)
                # Skip tabu sentences and duplicates *before* the expensive scoring.
                # No aspiration criterion is needed: a tabu sentence was accepted
                # earlier, so it cannot beat the best score seen so far.
                if self.is_sentence_visited(candidate) or text in scored:
                    continue
                scored.add(text)
                candidate_perplexity = self.calculate_perplexity(candidate)
                if best_candidate_perplexity is None or candidate_perplexity < best_candidate_perplexity:
                    best_candidate = candidate
                    best_candidate_perplexity = candidate_perplexity

            if best_candidate is None:
                print("[STOP] no non-tabu candidates left")
                break
            print(f"[{best_candidate_perplexity}] {self.to_text(best_candidate)}")

            if best_candidate_perplexity < best_perplexity:
                best_sentence = best_candidate
                best_perplexity = best_candidate_perplexity
                no_improvement = 0
                print(f"[IM] [{best_perplexity}]")
            else:
                no_improvement += 1
                print(f"[NO IM] __{no_improvement}__")

            if no_improvement >= self.stop_on_non_improvement:
                break

            # Accept the move and forbid returning to it for ``tabu_tenure`` moves.
            current_sentence = best_candidate
            self.tabu_list.append(self.to_text(current_sentence))

        return best_sentence, best_perplexity

---

In [None]:
import numpy as np
from itertools import permutations


class WhaleOptimizationForPerplexity:
    """Whale Optimization Algorithm (WOA) adapted to word orderings.

    Each "whale" is an ordering of ``words``; its fitness is the perplexity of
    the joined sentence (lower is better). Positions are updated by copying
    words from the best or a random whale, then repaired back into a valid
    permutation of ``words`` (repeated words are respected).

    Args:
        words: The words to reorder (may contain repeats).
        pop_size: Number of whales; must be >= 1.
        max_iter: Number of update iterations.
        perplexity_fn: Optional callable ``(list[str]) -> float``. When None the
            notebook-global ``scorer`` (a ``PerplexityCalculator``) is used.
    """

    def __init__(self, words, pop_size=30, max_iter=100, perplexity_fn=None):
        self.words = list(words)
        self.n = len(self.words)  # Number of words in the sequence
        self.pop_size = pop_size
        self.max_iter = max_iter
        self.perplexity_fn = perplexity_fn
        self.population = None
        self.best_solution = None
        self.best_fitness = float('inf')
        # Scoring is expensive; memoise by sentence text (assumes a deterministic scorer).
        self._cache = {}

    def calculate_perplexity(self, sentence) -> float:
        """Perplexity of the space-joined ``sentence``, computed once per distinct text."""
        text = " ".join(sentence)
        if text not in self._cache:
            if self.perplexity_fn is not None:
                self._cache[text] = float(self.perplexity_fn(list(sentence)))
            else:
                self._cache[text] = scorer.get_perplexity([text])[0]
        return self._cache[text]

    def initialize_population(self):
        """Fill the population with ``pop_size`` random orderings of ``words``.

        Draws random permutations directly instead of enumerating all n! of them.
        Orderings are kept distinct while possible; if ``pop_size`` exceeds the
        number of distinct orderings, duplicates are accepted after a bounded
        number of attempts instead of failing.
        """
        population, seen = [], set()
        max_attempts = 100 * max(self.pop_size, 1)
        attempts = 0
        while len(population) < self.pop_size:
            candidate = [self.words[k] for k in np.random.permutation(self.n)]
            key = tuple(candidate)
            attempts += 1
            if key in seen and attempts < max_attempts:
                continue
            seen.add(key)
            population.append(candidate)
        self.population = population

    def evaluate_population(self):
        """Evaluate the perplexity of each sequence in the population."""
        return np.array([self.calculate_perplexity(seq) for seq in self.population])

    def _repair(self, candidate, reference):
        """Turn ``candidate`` into a permutation of ``self.words``.

        Each word of ``candidate`` stays in place while its count in ``words``
        allows it; the remaining slots are filled with the missing words in the
        order they appear in ``reference`` (a valid permutation, e.g. the whale's
        previous position).
        """
        remaining = Counter(self.words)
        repaired = [None] * self.n
        for pos, word in enumerate(candidate):
            if remaining[word] > 0:
                repaired[pos] = word
                remaining[word] -= 1
        missing = []
        for word in reference:
            if remaining[word] > 0:
                missing.append(word)
                remaining[word] -= 1
        fill = iter(missing)
        return [word if word is not None else next(fill) for word in repaired]

    def optimize(self):
        """Run WOA and return ``(best_solution, best_fitness)``.

        Raises:
            ValueError: If ``pop_size`` < 1.
        """
        if self.pop_size < 1:
            raise ValueError('pop_size must be at least 1')

        self.initialize_population()
        fitness = self.evaluate_population()

        best_idx = int(np.argmin(fitness))
        self.best_solution = list(self.population[best_idx])
        self.best_fitness = fitness[best_idx]

        for itr in range(self.max_iter):
            a = 2 - itr * (2 / self.max_iter)  # Decreases linearly from 2 towards 0

            for i in range(self.pop_size):
                current = self.population[i]
                r = np.random.random(self.n)
                A = 2 * a * r - a  # Per-word coefficients in [-a, a]
                p = np.random.random()  # Encircling/searching vs. spiral update

                if p < 0.5:
                    # NOTE(review): ||A|| over all n words is rarely < 1 for long
                    # sentences until ``a`` is small, so this mostly explores.
                    if np.linalg.norm(A) < 1:
                        guide = self.best_solution  # Exploitation: move towards the best
                    else:
                        guide = self.population[np.random.randint(0, self.pop_size)]  # Exploration
                    new_position = list(current)
                    for j in range(self.n):
                        if np.random.random() < np.abs(A[j]):
                            new_position[j] = guide[j]
                else:  # Spiral updating (bubble-net hunting)
                    l = np.random.uniform(-1, 1)
                    new_position = list(current)
                    for j in range(self.n):
                        distance = self.best_solution[j] != new_position[j]
                        if np.random.random() < np.exp(-distance * l):
                            new_position[j] = self.best_solution[j]

                # Copying words can duplicate some and drop others; restore a
                # valid permutation without discarding the update.
                self.population[i] = self._repair(new_position, current)

            fitness = self.evaluate_population()
            current_best_idx = int(np.argmin(fitness))
            if fitness[current_best_idx] < self.best_fitness:
                self.best_fitness = fitness[current_best_idx]
                self.best_solution = list(self.population[current_best_idx])

            print(f"Iteration {itr+1}/{self.max_iter}, Best Perplexity: {self.best_fitness}")

        return self.best_solution, self.best_fitness

In [None]:
# Each submission row is checked as its own word multiset, so optimise one row at a
# time; flattening all rows would mix words that belong to different rows.
ROW_INDEX = 0
sequence = submission['text'].iloc[ROW_INDEX].split()  # split() avoids empty tokens

woa = WhaleOptimizationForPerplexity(sequence, pop_size=10, max_iter=50)
best_words, best_perplexity = woa.optimize()
" ".join(best_words), best_perplexity