In [None]:
# Importing necessary libraries
import numpy as np  # For numerical computations
import pandas as pd  # For data manipulation and analysis

# Importing the os module to interact with the file system
import os

# Print the full path of every file under '/kaggle/input' (the datasets and models
# attached to this notebook), walking the directory tree recursively.
for path in (os.path.join(root, name)
             for root, _subdirs, names in os.walk('/kaggle/input')
             for name in names):
    print(path)

# Note:
# - You can save up to 20GB of data in the '/kaggle/working/' directory, which persists when you save your notebook.
# - Temporary files can be written to '/kaggle/temp/', but they won't be saved after the session ends.

/kaggle/input/gemma/transformers/2b/2/model.safetensors.index.json
/kaggle/input/gemma/transformers/2b/2/gemma-2b.gguf
/kaggle/input/gemma/transformers/2b/2/config.json
/kaggle/input/gemma/transformers/2b/2/model-00001-of-00002.safetensors
/kaggle/input/gemma/transformers/2b/2/model-00002-of-00002.safetensors
/kaggle/input/gemma/transformers/2b/2/tokenizer.json
/kaggle/input/gemma/transformers/2b/2/tokenizer_config.json
/kaggle/input/gemma/transformers/2b/2/special_tokens_map.json
/kaggle/input/gemma/transformers/2b/2/.gitattributes
/kaggle/input/gemma/transformers/2b/2/tokenizer.model
/kaggle/input/gemma/transformers/2b/2/generation_config.json
/kaggle/input/santa-2024/sample_submission.csv


In [None]:
# Importing necessary libraries
import numpy as np  # For numerical computations
import pandas as pd  # For data manipulation and analysis
from collections import Counter  # For counting hashable objects
from tqdm import tqdm  # For displaying progress bars in loops
import random  # For generating random numbers
import pickle  # For serializing and deserializing Python objects
import math  # For mathematical operations
import warnings  # For handling warnings

# Uncomment the line below to suppress warnings (useful in some cases)
# warnings.simplefilter('ignore')

# Competition sample submission: one row per puzzle, with 'id' and 'text' columns
p = '/kaggle/input/santa-2024/sample_submission.csv'
df = pd.read_csv(p)

# Number of space-separated words in each puzzle's text
print(df['text'].astype(str).str.split(' ').str.len().values)

[ 10  20  20  30  50 100]


In [None]:
# Count every word across all puzzle texts. Counter is insertion-ordered, so the
# resulting dict lists the distinct words in the order they first appear.
tokens = dict(Counter(' '.join(df.text.values).split(' ')))

# How many distinct words there are
print(len(tokens))
print('-' * 20)

# The distinct words themselves, space-separated
print(' '.join(tokens))

89
--------------------
advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh workshop stocking holly jingle beard naughty nice sing of is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy wonder believe dream hope peace joy merry season greeting card wrapping paper bow cookie milk star wish wreath angel to in that have it not with as you from we kaggle


In [None]:
# Importing necessary libraries
import transformers  # For working with pre-trained transformer models
import torch  # For tensor computations and deep learning
import gc, os, logging  # For garbage collection, environment variables, and logging
from math import exp  # For calculating exponential values
from typing import List, Optional, Union  # For type hinting

# Environment settings applied before any model or tokenizer work in this notebook
os.environ['OMP_NUM_THREADS'] = '1'  # Limits the number of OpenMP threads
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # Disables parallelism in Hugging Face tokenizers

# Use the GPU when one is available, otherwise the CPU.
# PerplexityCalculator selects its own device with the same rule.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Location of the Gemma 2B checkpoint (Hugging Face transformers format)
MODEL_PATH = "/kaggle/input/gemma/transformers/2b/2"

# Custom exception class for errors visible to participants
class ParticipantVisibleError(Exception):
    """Error whose message is intended to be shown to competition participants."""

# Class to calculate perplexity of text using a pre-trained language model
class PerplexityCalculator:
    """Score texts by their perplexity under the causal LM stored at MODEL_PATH.

    Each text is wrapped in the tokenizer's BOS/EOS tokens, run through the model,
    and scored as exp(mean next-token cross-entropy loss). Lower is better.
    """

    def __init__(self):
        # Decide the device first so the dtype and the model placement agree with it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_PATH)

        # Load the whole model onto one explicit device. `device_map="auto"` is
        # deliberately not used: on a multi-GPU host it shards the model behind
        # accelerate hooks, and moving that model with `.to()` afterwards can leave
        # the hooks targeting the old devices (device-mismatch errors in forward).
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
        )
        self.model.to(self.device)
        self.model.eval()  # inference only

        # Per-token losses; get_perplexity averages them per sequence.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

    # Method to calculate perplexity for a single text or a list of texts
    def get_perplexity(self, input_texts: Union[str, List[str]]) -> Union[float, List[float]]:
        """Return the perplexity of one text, or of each text in a list.

        Args:
            input_texts: a single string, or a list of strings.

        Returns:
            A float when `input_texts` is a str; otherwise a list of floats in input order.
        """
        single_input = isinstance(input_texts, str)
        texts = [input_texts] if single_input else input_texts
        loss_list = []

        with torch.no_grad():
            for text in texts:
                # BOS/EOS are added explicitly (and the tokenizer is told not to add its
                # own), so the first word is scored given BOS and EOS itself is scored.
                text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"
                minputs = self.tokenizer(text_with_special, return_tensors='pt', add_special_tokens=False)
                # Drop token_type_ids if a tokenizer emits them: causal-LM forward()
                # methods do not necessarily accept that argument.
                minputs = {k: v.to(self.device) for k, v in minputs.items() if k != 'token_type_ids'}

                output = self.model(**minputs, use_cache=False)
                logits = output['logits']

                # Position i predicts token i+1: drop the last logit and the first label.
                slogits = logits[..., :-1, :].contiguous()
                slabels = minputs['input_ids'][..., 1:].contiguous()

                loss = self.loss_fct(slogits.view(-1, slogits.size(-1)), slabels.view(-1))
                # Mean per-token loss of this sequence.
                sequence_loss = loss.sum() / len(loss)
                loss_list.append(sequence_loss.cpu().item())

        ppl = [exp(i) for i in loss_list]
        return ppl[0] if single_input else ppl

# Load the tokenizer and model once. This shared instance is used by the later cells
# to tokenize text (scorer.tokenizer) and to score candidates (scorer.get_perplexity).
scorer = PerplexityCalculator()

<class 'Exception'>


`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# All distinct puzzle words, space-separated
tokens = "advent chimney elf family fireplace gingerbread mistletoe ornament reindeer scrooge walk give jump drive bake the sleep night laugh and yuletide decorations gifts cheer holiday carol magi nutcracker polar grinch sleigh workshop stocking holly jingle beard naughty nice sing of is eat visit relax unwrap hohoho candle poinsettia snowglobe peppermint eggnog fruitcake chocolate candy puzzle game doll toy wonder believe dream hope peace joy merry season greeting card wrapping paper bow cookie milk star wish wreath angel to in that have it not with as you from we kaggle"

# Wrap the text in the tokenizer's BOS/EOS tokens (as get_perplexity does) and
# tokenize it without letting the tokenizer add special tokens a second time
text_with_special = "{}{}{}".format(scorer.tokenizer.bos_token, tokens, scorer.tokenizer.eos_token)
minputs = scorer.tokenizer(
    text_with_special,
    return_tensors='pt',
    add_special_tokens=False,
)

# Display the resulting token ids (last expression of the cell)
minputs['input_ids']

tensor([[     2, 104828,  67905,  52931,   2730,  43485, 136507,   7727, 165493,
          29138, 103360,   1513,  80108,    541,   5376,   2734,   9902,   6109,
          44528,    573,   6284,   3354,  10084,    578,    597,   1731,  23675,
          42768,  17196,  22867,  12083, 138763, 198447,  16621,  99946,  16573,
           2660,  14111, 155702,  20257,  77515, 108548, 204063,  38175,  97840,
           4866,   2800,    576,    603,   7812,   3532,  10228,    748,  14660,
           1965, 215898,  28162,  83096,    881,   9437,   8529, 112671, 149218,
           7815,   1312,    869,   9471,  23144,  13171,  25720,  24754,   2398,
           7474,  12849,   5144,   4564,   6523,   4077,   7124,  10300,  46301,
           3891,  32338,   4076,  56178,   4368,   7181,  17467,   9512,   2343,
           6199,  58409,  22448,    577,    575,    674,    791,    665,    780,
            675,    685,    692,    774,    783, 124555,   2315,      1]])

In [None]:
# Previously found word orders mapped to their perplexity scores: one entry per row of
# the sample submission (each key is a permutation of that row's words).
# get_best_plex uses this dict as a cache: it substitutes any key that is a better
# permutation of a row's text, and it adds every newly scored shuffle to it.
past = {
    'reindeer mistletoe elf scrooge gingerbread chimney fireplace ornament family advent': 495.6812574407127,
    'reindeer mistletoe elf gingerbread ornament advent family fireplace chimney sleep drive walk jump laugh give and bake the night scrooge': 514.3236249412977,
    'magi yuletide cheer grinch carol holly jingle naughty nice polar beard sleigh chimney workshop nutcracker holiday ornament decorations gifts stocking': 327.18895045897995,
    'sleigh of the magi yuletide cheer grinch is unwrap gifts decorations ornament holly stocking and chimney naughty nice polar beard nutcracker visit workshop eat relax carol sing holiday cheer jingle': 327.33220436011106,
    'wreath merry have and season hohoho to you from the star of wonder workshop that it not with joy we believe hope peace fruitcake chocolate candy peppermint candle snowglobe angel poinsettia wrapping paper bow greeting card cookie milk wish dream fireplace kaggle toy doll game night puzzle eggnog as in': 222.93269021691262,
    'decorations eggnog yuletide poinsettia fruitcake scrooge nutcracker mistletoe holly wreath gingerbread cookie snowglobe reindeer angel star merry and the season of joy and wonder and peace to you from the family of the grinch holiday cheer is not as cheer unwrap gifts laugh hohoho sing carol in sleigh drive visit chimney chimney elf naughty nice eat bake sleep dream chocolate peppermint ornament stocking fireplace fireplace advent candle wish hope give card ornament wrapping paper toy doll bow game night night puzzle candy walk jingle jump relax believe it with we have that kaggle workshop workshop polar beard milk greeting magi': 120.92163009638384
}

# Printing the number of entries in the dictionary (one per submission row here)
print(len(past))

6


In [None]:
%%time

# Helper: best cached permutation for one row's text.
def _best_known_order(text, score, past):
    """Return the lowest-scoring (text, score) among `text` and any key of `past` using exactly the same words."""
    target = sorted(text.split(' '))
    best_text, best_score = text, score
    for cand_text, cand_score in past.items():
        # Check the (cheap) score comparison before sorting the candidate's words.
        if cand_score < best_score and sorted(cand_text.split(' ')) == target:
            best_text, best_score = cand_text, cand_score
    return best_text, best_score


# Function to optimize text sequences by minimizing perplexity
def get_best_plex(df, past, n_iter=2, seed=None, score_fn=None, pickle_path='past0.pickle'):
    """Improve each row's word order by reusing cached orders and trying random shuffles.

    For every row of `df`:
      1. score the current text,
      2. replace it with any strictly better permutation already stored in `past`,
      3. try `n_iter` random shuffles of the (possibly replaced) text, scoring each one
         not already in `past`, and keep any shuffle that beats the row's best score.

    Args:
        df: DataFrame with 'id' and 'text' columns; each text is space-separated words.
            It is not modified; the function works on a copy.
        past: dict mapping text -> perplexity. Used as a cache and updated in place
            with every newly scored shuffle.
        n_iter: number of random shuffles tried per row.
        seed: if given, shuffles use a private ``random.Random(seed)`` so runs are
            reproducible; if None, the global ``random`` module is used.
        score_fn: callable text -> perplexity; defaults to the global
            ``scorer.get_perplexity``.
        pickle_path: file `past` is pickled to after each row; None disables saving.

    Returns:
        (DataFrame with columns ['id', 'text'] holding the best order found per row, past)
    """
    if score_fn is None:
        score_fn = scorer.get_perplexity
    rng = random if seed is None else random.Random(seed)

    # Work on a copy so the caller's frame does not gain a 'score' column or rewritten texts.
    df = df.copy()
    df['score'] = df['text'].map(score_fn)
    print(np.mean(df['score'].values))  # mean perplexity before optimisation

    # Iterate over index labels (not positions) so .at works for any index, not just 0..n-1.
    for r in df.index:
        print("row: ", r, df.at[r, 'score'])

        best_text, best_score = _best_known_order(df.at[r, 'text'], df.at[r, 'score'], past)
        df.at[r, 'text'] = best_text
        df.at[r, 'score'] = best_score
        print("row: ", r, df.at[r, 'score'])

        for i in range(n_iter):
            words = df.at[r, 'text'].split(' ')
            rng.shuffle(words)
            candidate = ' '.join(words)
            if candidate in past:
                # Already scored; any cached improvement for this row was applied above
                # or when it was scored, so it cannot beat the current best.
                continue
            s = score_fn(candidate)
            past[candidate] = s
            if s < df.at[r, 'score']:
                df.at[r, 'score'] = s
                df.at[r, 'text'] = candidate
                print(r, i, "New Score: ", s, np.mean(df['score'].values))

        # Persist the cache after every row so completed work is kept if a later row is interrupted.
        if pickle_path is not None:
            with open(pickle_path, 'wb') as f:
                pickle.dump(past, f, protocol=pickle.HIGHEST_PROTOCOL)

    print('MEAN SCORE: ', np.mean(df['score'].values))
    return df[['id', 'text']], past

# Run the search. The returned frame holds only 'id' and 'text', and `past` now also
# contains every newly scored word order.
df, past = get_best_plex(df=df, past=past)

# Write the submission file without the pandas index column
df.to_csv(path_or_buf="submission.csv", index=False)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


1018.7141148854306
row:  0 1350.9216564227258
row:  0 495.6812574407127
row:  1 2146.111641264015
row:  1 514.3236249412977
row:  2 872.250522619955
row:  2 327.18895045897995
row:  3 834.3357780659894
row:  3 327.33220436011106
row:  4 489.6441270405202
row:  4 222.93269021691262
row:  5 419.02096389937867
row:  5 120.92163009638384
MEAN SCORE:  334.73005958573293
CPU times: user 2min 12s, sys: 2.65 s, total: 2min 15s
Wall time: 1min 9s
