In [1]:
import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ParticipantVisibleError(Exception):
    pass


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Just move on if it fails. Not essential if we have the score.
        try:
            scorer.clear_gpu_memory()
        except:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))


class PerplexityCalculator:
    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug=False
    ) -> Union[float, List[float]]:
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        with torch.no_grad():
            # Process each sequence independently
            for text in input_texts:
                # Explicitly add sequence boundary tokens to the text
                text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"

                # Tokenize
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                )

                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')

                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                # Get model output
                output = self.model(**model_inputs, use_cache=False)
                logits = output['logits']

                # Shift logits and labels for calculating loss
                shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                shift_labels = model_inputs['input_ids'][..., 1:].contiguous()  # Drop first input

                # Calculate token-wise loss
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )

                # Calculate average loss
                sequence_loss = loss.sum() / len(loss)
                loss_list.append(sequence_loss.cpu().item())

                # Debug output
                if debug:
                    print(f"\nProcessing: '{text}'")
                    print(f"With special tokens: '{text_with_special}'")
                    print(f"Input tokens: {model_inputs['input_ids'][0].tolist()}")
                    print(f"Target tokens: {shift_labels[0].tolist()}")
                    print(f"Input decoded: {self.tokenizer.decode(model_inputs['input_ids'][0])}")
                    print(f"Target decoded: {self.tokenizer.decode(shift_labels[0])}")
                    print(f"Individual losses: {loss.tolist()}")
                    print(f"Average loss: {sequence_loss.item():.4f}")

        ppl = [exp(i) for i in loss_list]

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(input_texts, ppl):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

In [4]:
import pandas as pd
model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
scorer = PerplexityCalculator(model_path=model_path)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [5]:
submission = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")
#perplexities = scorer.get_perplexity(submission["text"].tolist())
#perplexities

In [35]:
from collections import deque
from typing import Tuple
import random
import pandas as pd

class TabuSearch:
    def __init__(self, sentence: str, tabu_tenure=5, update_ani=20, max_iter=100):
        """
        update_ani: incress tabu_tenure after no improvment
        """
        self.words = sentence.split()
        self.n = len(self.words)
        self.tabu_tenure = tabu_tenure
        self.update_ani = update_ani
        self.max_iter = max_iter
        self.tabu_list = deque(maxlen=self.tabu_tenure)
        
        self.cache = {}
        self.attempt = 0
        self.calc = 0
        self.cached = 0
        
    def logcaching(self):
        print(f"[CACHING] attempt: {self.attempt}, calc: {self.calc}, cached: {self.cached}")
    
    def calculate_perplexity(self, sol) -> float:
        self.attempt+=1
        perplexity = 0
        tsol = tuple(sol)
        
        if tsol in self.cache:
            self.cached+=1
            perplexity = self.cache[tsol]
        else:
            snetence = " ".join([self.words[i] for i in sol])
            submission = pd.DataFrame({'id': [0], 'text': [snetence] })
            perplexity = scorer.get_perplexity(submission["text"].tolist())[0]
            self.cache[tsol] = perplexity
            self.calc+=1
        
        return perplexity
    
    def init(self):
        r = list(range(self.n))
        random.shuffle(r)  
        return r
    
    def is_sentence_visited(self, sentence):
        return tuple(sentence) in self.tabu_list
    
    def to_text(self, sentence):
        return " ".join([self.words[i] for i in sentence])

    def get_candidates(self, sentence: list):
        n = len(sentence)
        neighbors = []
        
        for i in range(n):
            for j in range(i + 1, n):
                neighbor = sentence.copy()
                neighbor[i], neighbor[j] = neighbor[j], neighbor[i]
                neighbors.append(neighbor)
        
        return neighbors

    def optimize(self, current_sentence) -> Tuple[str, float]:
        if current_sentence is None:
            current_sentence = self.init()
        
        current_perplexity = self.calculate_perplexity(current_sentence)

        best_sentence = current_sentence
        best_perplexity = current_perplexity
        print(f"[START] {best_sentence} --- {best_perplexity}")
        
        no_improvement = 0
        itr = 1
        
        while True and itr <= self.max_iter:
            # print(f"[ITER] [{itr}/{self.max_iter}] --- {best_perplexity}")
            
            candidates = self.get_candidates(current_sentence)

            best_candidate = None
            best_candidate_perplexity = None
            
            for candidate in candidates:
                candidate_perplexity = self.calculate_perplexity(candidate)
                if self.is_sentence_visited(candidate) or best_candidate_perplexity is None or candidate_perplexity < best_candidate_perplexity:
                        best_candidate = candidate
                        best_candidate_perplexity = candidate_perplexity
            # print(f"[{best_candidate_perplexity}] {self.to_text(best_candidate)}")

            if best_candidate_perplexity < best_perplexity:
                best_sentence = best_candidate
                best_perplexity = best_candidate_perplexity
                no_improvement = 0
                print(f"[IM] [{itr}/{self.max_iter}] [{best_perplexity}]")
            else:
                no_improvement += 1
                # print(f"[NO IM] __{no_improvement}__")

            if no_improvement == self.update_ani:
                if self.tabu_tenure > 50:
                    print(f"[NO IM] [{itr}/{self.max_iter}] MAX [{self.tabu_tenure}] | break;")
                    break
                no_improvement = 0
                old = self.tabu_list
                self.tabu_tenure += 5
                self.tabu_list = deque(maxlen=self.tabu_tenure)
                self.tabu_list.extend(old)
                print("tabu_list", end=" || ")
                print(f"[NO IM] [{itr}/{self.max_iter}] INCRESS TABU_LIST [{self.tabu_tenure}]")
            
            current_sentence = best_candidate
            current_perplexity = best_candidate_perplexity

            self.tabu_list.append(tuple(best_sentence))
            itr+=1
                
        self.logcaching()
        return self.to_text(best_sentence), best_sentence, best_perplexity

In [36]:
import random

data = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")

ts = TabuSearch(data['text'][0], tabu_tenure=10, update_ani=50, max_iter=1000)
ts.optimize(ts.init())

[START] [1, 2, 8, 7, 4, 3, 9, 0, 5, 6] --- 2900.9736045356544
[IM] [1/1000] [1641.5146668588116]
[IM] [2/1000] [1322.6646815536633]
[IM] [3/1000] [926.119419511793]
[IM] [4/1000] [768.6019472806792]
[IM] [15/1000] [688.6363021417106]
[IM] [16/1000] [584.088676418536]
tabu_list || [NO IM] [66/1000] INCRESS TABU_LIST [15]
[CACHING] attempt: 4456, calc: 384, cached: 4072
tabu_list || [NO IM] [116/1000] INCRESS TABU_LIST [20]
tabu_list || [NO IM] [166/1000] INCRESS TABU_LIST [25]
[CACHING] attempt: 8956, calc: 384, cached: 8572
tabu_list || [NO IM] [216/1000] INCRESS TABU_LIST [30]
tabu_list || [NO IM] [266/1000] INCRESS TABU_LIST [35]
[CACHING] attempt: 13456, calc: 384, cached: 13072
tabu_list || [NO IM] [316/1000] INCRESS TABU_LIST [40]
tabu_list || [NO IM] [366/1000] INCRESS TABU_LIST [45]
[CACHING] attempt: 17956, calc: 384, cached: 17572
tabu_list || [NO IM] [416/1000] INCRESS TABU_LIST [50]
tabu_list || [NO IM] [466/1000] INCRESS TABU_LIST [55]
[CACHING] attempt: 22456, calc: 384, c

('reindeer mistletoe gingerbread chimney fireplace advent scrooge elf ornament family',
 [8, 6, 5, 1, 4, 0, 9, 2, 7, 3],
 584.088676418536)

In [37]:
import random

data = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")

ts = TabuSearch(data['text'][0], tabu_tenure=10, update_ani=50, max_iter=1000)
ts.optimize([8, 6, 5, 1, 4, 0, 9, 2, 7, 3])

[START] [8, 6, 5, 1, 4, 0, 9, 2, 7, 3] --- 584.088676418536
tabu_list || [NO IM] [50/1000] INCRESS TABU_LIST [15]
[CACHING] attempt: 4456, calc: 90, cached: 4366
tabu_list || [NO IM] [100/1000] INCRESS TABU_LIST [20]
tabu_list || [NO IM] [150/1000] INCRESS TABU_LIST [25]
[CACHING] attempt: 8956, calc: 90, cached: 8866
tabu_list || [NO IM] [200/1000] INCRESS TABU_LIST [30]
tabu_list || [NO IM] [250/1000] INCRESS TABU_LIST [35]
[CACHING] attempt: 13456, calc: 90, cached: 13366
tabu_list || [NO IM] [300/1000] INCRESS TABU_LIST [40]
tabu_list || [NO IM] [350/1000] INCRESS TABU_LIST [45]
[CACHING] attempt: 17956, calc: 90, cached: 17866
tabu_list || [NO IM] [400/1000] INCRESS TABU_LIST [50]
tabu_list || [NO IM] [450/1000] INCRESS TABU_LIST [55]
[CACHING] attempt: 22456, calc: 90, cached: 22366
[NO IM] [500/1000] MAX [55] | break;


('reindeer mistletoe gingerbread chimney fireplace advent scrooge elf ornament family',
 [8, 6, 5, 1, 4, 0, 9, 2, 7, 3],
 584.088676418536)

In [None]:
import random

data = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")

res = []

for text in data['text']:
    ts = TabuSearch(text, tabu_tenure=10, update_ani=50, max_iter=1000)
    sen, sol, perplexity = ts.optimize(ts.init())
    res.append((sen, sol, perplexity))

res

[START] [2, 3, 6, 7, 1, 8, 5, 0, 4, 9] --- 10541.869315984432
[IM] [1/1000] [2549.5211378804725]
[IM] [2/1000] [1067.78985306711]
[IM] [3/1000] [1037.2223535190114]
[IM] [14/1000] [720.8217523105569]
[IM] [15/1000] [653.1138361007249]
[IM] [16/1000] [593.2027189881885]
[IM] [27/1000] [552.8875041123517]
tabu_list || [NO IM] [77/1000] INCRESS TABU_LIST [15]
[CACHING] attempt: 4456, calc: 469, cached: 3987
tabu_list || [NO IM] [127/1000] INCRESS TABU_LIST [20]
tabu_list || [NO IM] [177/1000] INCRESS TABU_LIST [25]
[CACHING] attempt: 8956, calc: 469, cached: 8487
tabu_list || [NO IM] [227/1000] INCRESS TABU_LIST [30]
tabu_list || [NO IM] [277/1000] INCRESS TABU_LIST [35]
[CACHING] attempt: 13456, calc: 469, cached: 12987
tabu_list || [NO IM] [327/1000] INCRESS TABU_LIST [40]
tabu_list || [NO IM] [377/1000] INCRESS TABU_LIST [45]
[CACHING] attempt: 17956, calc: 469, cached: 17487
tabu_list || [NO IM] [427/1000] INCRESS TABU_LIST [50]
tabu_list || [NO IM] [477/1000] INCRESS TABU_LIST [55]
[