In [1]:
import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

# Ask OpenMP for a single thread and switch tokenizer parallelism off.
# NOTE(review): these are assigned after numpy/torch/transformers were imported
# above -- confirm the libraries still honour them when set this late.
os.environ.update(
    OMP_NUM_THREADS='1',
    TOKENIZERS_PARALLELISM='false',
)
# Label id that a default-constructed CrossEntropyLoss ignores
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
# Use the GPU whenever CUDA is available, otherwise stay on the CPU
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class ParticipantVisibleError(Exception):
    """Raised by ``score`` when a submitted string is not a valid permutation of its solution string."""


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = False,
    clear_mem: bool = False,
) -> float:
    """Return the mean perplexity of the submitted strings.

    Args:
        solution: Frame whose ``'text'`` column holds the reference strings.
        submission: Frame whose ``'text'`` column holds the submitted strings.
        row_id_column_name: Accepted as part of the signature; the function
            body does not read it.
        model_path: Forwarded to ``PerplexityCalculator``.
        load_in_8bit: Forwarded to ``PerplexityCalculator``.
        clear_mem: When true, call ``clear_gpu_memory()`` on the calculator
            once the perplexities are known. A failure there is reported with
            ``print`` and does not affect the returned score.

    Returns:
        The arithmetic mean of the per-string perplexities, as a ``float``.

    Raises:
        ParticipantVisibleError: If, for any row, the whitespace-separated
            word counts of the submitted string differ from those of the
            solution string.
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Just move on if it fails. Not essential if we have the score.
        # Catch Exception only: a bare `except:` would also swallow
        # KeyboardInterrupt and SystemExit, making the cell hard to interrupt.
        try:
            scorer.clear_gpu_memory()
        except Exception:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))


class PerplexityCalculator:
    """Computes per-text perplexity with a causal language model.

    The tokenizer and the model are both loaded from ``model_path`` through the
    ``transformers`` auto classes, and the model is put into eval mode.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        device_map: str = 'auto',
    ):
        """Load the tokenizer and model.

        Args:
            model_path: Location passed to both ``from_pretrained`` calls.
            load_in_8bit: Load the model with a ``BitsAndBytesConfig`` set to
                8-bit; otherwise load it in float16 when ``DEVICE`` is CUDA
                and float32 when it is not.
            device_map: Forwarded to the model's ``from_pretrained``.

        Raises:
            ValueError: If ``load_in_8bit`` is requested while ``DEVICE`` is
                not a CUDA device.
        """
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

        # The two loading modes differ only in their precision-related keyword,
        # so collect the keywords first and load the model in a single place.
        load_kwargs = {'device_map': device_map}
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            load_kwargs['quantization_config'] = transformers.BitsAndBytesConfig(
                load_in_8bit=True
            )
        elif DEVICE.type == 'cuda':
            load_kwargs['torch_dtype'] = torch.float16
        else:
            load_kwargs['torch_dtype'] = torch.float32

        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_path, **load_kwargs
        )

        # reduction='none' keeps one loss value per predicted token
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

        self.model.eval()

    def _sequence_loss(self, text: str, debug=False) -> float:
        """Return the mean next-token loss of ``text`` wrapped in BOS/EOS tokens."""
        # Boundary tokens are written into the string itself, so the tokenizer
        # is told not to add special tokens of its own.
        text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
        encoded = self.tokenizer(
            text_with_special,
            return_tensors='pt',
            add_special_tokens=False,
        )

        # Leave out token_type_ids (when present) and move the rest to DEVICE
        model_inputs = {
            k: v.to(DEVICE) for k, v in encoded.items() if k != 'token_type_ids'
        }

        logits = self.model(**model_inputs, use_cache=False)['logits']

        # Align predictions with targets: the logits at position t are scored
        # against the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = model_inputs['input_ids'][..., 1:].contiguous()

        # Token-wise loss, then its average over the sequence
        loss = self.loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        sequence_loss = loss.sum() / len(loss)
        mean_loss = sequence_loss.cpu().item()

        if debug:
            print(f"\nProcessing: '{text}'")
            print(f"With special tokens: '{text_with_special}'")
            print(f"Input tokens: {model_inputs['input_ids'][0].tolist()}")
            print(f"Target tokens: {shift_labels[0].tolist()}")
            print(f"Input decoded: {self.tokenizer.decode(model_inputs['input_ids'][0])}")
            print(f"Target decoded: {self.tokenizer.decode(shift_labels[0])}")
            print(f"Individual losses: {loss.tolist()}")
            print(f"Average loss: {sequence_loss.item():.4f}")

        return mean_loss

    def get_perplexity(
        self, input_texts: Union[str, List[str]], debug=False
    ) -> Union[float, List[float]]:
        """Compute the perplexity of one text or of each text in a list.

        Every text is tokenized and run through the model on its own.

        Args:
            input_texts: A single string or a list of strings.
            debug: When true, print tokens, per-token losses and the final
                perplexities.

        Returns:
            A float for a single string input, otherwise a list with one
            perplexity per input text.
        """
        single_input = isinstance(input_texts, str)
        if single_input:
            input_texts = [input_texts]

        with torch.no_grad():
            loss_list = [self._sequence_loss(text, debug) for text in input_texts]

        # Perplexity is the exponential of the mean token loss
        ppl = list(map(exp, loss_list))

        if debug:
            print("\nFinal perplexities:")
            for text, perp in zip(input_texts, ppl):
                print(f"Text: '{text}'")
                print(f"Perplexity: {perp:.2f}")

        if single_input:
            return ppl[0]
        return ppl

    def clear_gpu_memory(self) -> None:
        """Drop the model and tokenizer and release cached CUDA memory.

        Does nothing when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return

        # Remove whichever of the two heavy attributes are still present
        for attr_name in ('model', 'tokenizer'):
            if hasattr(self, attr_name):
                delattr(self, attr_name)

        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()

In [2]:
import pandas as pd
# Checkpoint location handed to PerplexityCalculator (the same path that
# score() uses as its default).
model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
# Loads the tokenizer and model once; calculate_cost() further down reads this
# notebook-level `scorer`.
scorer = PerplexityCalculator(model_path=model_path)

# Example usage on three hand-written sentences, left disabled:
# submission = pd.DataFrame({
#      'id': [0, 1, 2],
#      'text': ["this is a normal english sentence", "thsi is a slihgtly misspelled zr4g sentense", "the quick brown fox jumps over the lazy dog"]
# })
# perplexities = scorer.get_perplexity(submission["text"].tolist())
# perplexities

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

---

In [3]:
# Read the sample submission into a DataFrame and display it; later cells take
# their word lists from its 'text' column.
sentences = pd.read_csv("/kaggle/input/santa-2024/sample_submission.csv")
sentences

Unnamed: 0,id,text
0,0,advent chimney elf family fireplace gingerbrea...
1,1,advent chimney elf family fireplace gingerbrea...
2,2,yuletide decorations gifts cheer holiday carol...
3,3,yuletide decorations gifts cheer holiday carol...
4,4,hohoho candle poinsettia snowglobe peppermint ...
5,5,advent chimney elf family fireplace gingerbrea...


In [31]:
# Pull out the text of the row labelled 0 with a single .loc lookup and show it
sentence = sentences.loc[0, 'text']
sentence

'advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge'

In [54]:
import random

def calculate_cost(sentence, calculator=None):
    """Return the perplexity of one candidate sentence (lower is better).

    Args:
        sentence: Either a sequence of words, which is joined with single
            spaces, or an already-assembled string, which is scored as is.
            (Joining a plain string would insert a space between every
            character, so strings are deliberately not joined.)
        calculator: Object exposing ``get_perplexity(list_of_texts)`` that
            returns one value per text. Defaults to the notebook-level
            ``scorer``.

    Returns:
        The value the calculator reports for the sentence.
    """
    text = sentence if isinstance(sentence, str) else " ".join(sentence)
    if calculator is None:
        # Resolved at call time so the function can be defined before, and
        # tested without, the heavyweight global scorer.
        calculator = scorer
    return calculator.get_perplexity([text])[0]

def generate_initial_solution(words, rng=None):
    """Return a new list holding the items of ``words`` in random order.

    Args:
        words: Sequence of words (list, tuple, ...). It is copied, never
            modified.
        rng: Optional object with a ``shuffle(list)`` method, e.g. a seeded
            ``random.Random``, for reproducible starts. Defaults to the
            ``random`` module's global generator.

    Returns:
        A freshly shuffled ``list``.

    Raises:
        TypeError: If ``words`` is a single string rather than a sequence of
            words.
    """
    if isinstance(words, str):
        raise TypeError('words must be a sequence of words, not a single string')
    # list(...) rather than words[:] so tuples and other sequences work too
    shuffled_words = list(words)
    shuffler = random if rng is None else rng
    shuffler.shuffle(shuffled_words)
    return shuffled_words

In [58]:
def local_search(words, max_iter, cost_fn=None, initial_solution=None):
    """Pairwise-swap local search for a low-cost ordering of ``words``.

    Each sweep visits every position ``i`` and tries swapping it with each
    later position ``j``. The first swap that lowers the cost is accepted and
    the sweep moves on to the next ``i``. The search ends after ``max_iter``
    sweeps, or as soon as a whole sweep accepts no swap: the ordering is then
    a local optimum for this neighbourhood, and with a deterministic cost
    function further sweeps would only repeat the same evaluations.

    Args:
        words: The words to arrange.
        max_iter: Maximum number of sweeps.
        cost_fn: Callable mapping a list of words to a comparable cost.
            Defaults to the notebook's ``calculate_cost``.
        initial_solution: Optional starting order. It is copied and expected
            to be a permutation of ``words`` (not checked). When omitted, the
            start comes from ``generate_initial_solution(words)``.

    Returns:
        ``(best_order, best_cost)`` for the best ordering found.
    """
    if cost_fn is None:
        cost_fn = calculate_cost

    if initial_solution is None:
        current_solution = generate_initial_solution(words)
    else:
        current_solution = list(initial_solution)
    current_cost = cost_fn(current_solution)

    n = len(current_solution)
    iterations = 0

    while iterations < max_iter:
        iterations += 1
        print(f"------------------ [{iterations}] ------------------")

        improved = False
        for i in range(n):
            for j in range(i + 1, n):
                new_solution = current_solution[:]
                new_solution[i], new_solution[j] = new_solution[j], new_solution[i]

                new_cost = cost_fn(new_solution)

                if new_cost < current_cost:
                    print(f"sentence: {' '.join(new_solution)} -> perplexity: {new_cost}")
                    current_solution = new_solution
                    current_cost = new_cost
                    improved = True
                    # First improvement for this i: continue with the next i
                    break

        if not improved:
            # No swap helped during a full sweep, so stop early
            break

    return current_solution, current_cost


In [60]:
# Bind the cost to its own name: assigning it to `score` would replace the
# score() function defined in the first cell with a float.
words, best_cost = local_search(sentences['text'][0].split(" "), max_iter=20)
" ".join(words)
# Presumably a result noted from an earlier run (kept verbatim):
# sentence: reindeer mistletoe elf scrooge gingerbread family ornament advent chimney fireplace -> perplexity: 532.3726691377844


------------------ [1] ------------------
sentence: gingerbread ornament scrooge chimney family fireplace reindeer advent mistletoe elf -> perplexity: 1894.8121816143816
sentence: gingerbread ornament fireplace chimney family scrooge reindeer advent mistletoe elf -> perplexity: 1601.8368003527937
sentence: gingerbread ornament fireplace elf family scrooge reindeer advent mistletoe chimney -> perplexity: 1540.4713962049877
sentence: gingerbread ornament fireplace elf scrooge family reindeer advent mistletoe chimney -> perplexity: 1487.2551219966645
sentence: gingerbread ornament fireplace elf scrooge family reindeer mistletoe advent chimney -> perplexity: 1397.1468893399162
------------------ [2] ------------------
sentence: gingerbread family fireplace elf scrooge ornament reindeer mistletoe advent chimney -> perplexity: 1386.2742059875643
sentence: gingerbread family advent elf scrooge ornament reindeer mistletoe fireplace chimney -> perplexity: 1237.8035409730126
sentence: gingerbrea

'reindeer mistletoe gingerbread family scrooge elf advent chimney fireplace ornament'