In [1]:
import re
import csv
import codecs
import pandas as pd
from datasets import Dataset
import json
from huggingface_hub import login
from transformers import AutoTokenizer
from dotenv import load_dotenv
import os

In [2]:
# Authenticate with the Hugging Face Hub using HF_TOKEN from the environment (or a local .env file).
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:  # rejects both a missing variable and an empty value such as `HF_TOKEN=` in .env
    raise ValueError("HF_TOKEN is not set")
login(token=HF_TOKEN)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
# Base model whose tokenizer (and role/EOS special tokens) the dataset below is formatted for.
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

### WhatsApp .txt export to filtered CSV

In [16]:
# function to convert special \<hex> charactersinto the corresponding Latin-1 characters
def decode_latin1_escapes(text):
    def repl(match):
        hex_str = match.group(1)
        return bytes.fromhex(hex_str).decode('latin-1')
    
    return re.sub(r"\\'([0-9a-fA-F]{2})", repl, text)

# reads the WhatsApp chat export at input_path, removes metadata and noise (e.g. "audio omitted", "document omitted", "image omitted") and saves the cleaned messages to a CSV file
# to keep only one person's messages, pass their name as only_sender_name (case-insensitive)

# TODO: the processing logic can be improved further, e.g. keep only messages that are close in time
def process_data(input_path, output_path, only_sender_name=None):
    """
    Processes a WhatsApp chat export file, extracts messages, timestamps, and senders,
    cleans the messages, and saves them to a CSV file.

    Args:
        input_path (str): The path to the input WhatsApp chat .txt file.
        output_path (str): The path where the cleaned data CSV will be saved.
        only_sender_name (str, optional): If provided, only messages from this sender
                                           (case-insensitive) will be included. Defaults to None.
    """

    # Regex to capture timestamp, sender, and message
    # Group 1: Timestamp (e.g., [DD/MM/YY, HH:MM:SS])
    # Group 2: Sender name
    # Group 3: Message content
    pattern = re.compile(
        r'^.*?(\[\d{2}/\d{2}/\d{2},\s*\d{2}:\d{2}:\d{2}\])\s*(.*?):\s*(.*)$'
    )
    
    data = [] # List to store processed [timestamp, sender, message] tuples
    
    try:
        with open(input_path, 'r', encoding='utf-8') as infile:
            lines = infile.readlines()
    except FileNotFoundError:
        print(f"Error: Input file not found at {input_path}")
        return
    except Exception as e:
        print(f"An error occurred while reading the input file: {e}")
        return
    
    for line in lines:
        line = line.strip()
        if not line: # Skip empty lines
            continue
        
        match = pattern.match(line)
        if match:
            # Extract captured groups
            timestamp = match.group(1).strip()
            sender = match.group(2).strip()
            message = match.group(3).strip()
            
            # Filter by sender if only_sender_name is specified
            if only_sender_name and sender.lower() != only_sender_name.lower():
                continue
            
            # Skip messages with omitted content by Francesco
            omitted_match = re.search(r'\b(audio|document|image|video|sticker|contact) omitted\b', message, re.IGNORECASE)
            if omitted_match:
                if sender.lower() == "francesco brigante".lower():
                    continue # Skip these messages for Francesco
                else:
                    # For other senders, replace with a generic placeholder using the captured type
                    omitted_type = omitted_match.group(1).lower()
                    if omitted_type == "audio":
                        message = "*manda un audio*"
                    elif omitted_type == "document":
                        message = "*manda un documento*"
                    elif omitted_type == "image":
                        message = "*manda un'immagine*"
                    elif omitted_type == "video":
                        message = "*manda un video*"
                    elif omitted_type == "sticker":
                        message = "*manda uno sticker*"
                    elif omitted_type == "contact":
                        message = "*manda un contatto*"
                    else:
                        # Fallback for unexpected omitted types
                        message = "*manda un allegato*"
            
            # Skip initial chat message
            if "Messages and calls are end-to-end encrypted" in message:
                continue
            
            if "Voice call." in message:
                continue
            
            # Removes unfiltered special characters such as \uc0\u8206
            message = re.sub(r'\\[a-z]+\d*', '', message)

            # Decode Latin-1 escaped characters like \'ec which becomes ì
            message = decode_latin1_escapes(message)
            
            # Removes remaining backslashes (e.g., from original escapes that weren't Latin-1)
            message = message.replace("\\", "")

            # Normalizes multiple spaces to a single space and removes leading/trailing spaces,
            # also removes a trailing '}' which can sometimes appear due to regex artifacts
            message = re.sub(r'\s+', ' ', message).strip().rstrip('}')
            
            # Only add the message if it's not empty after cleaning
            if message:
                data.append([timestamp, sender, message])
    
    try:
        with open(output_path, 'w', encoding='utf-8', newline='') as csvfile:
            writer = csv.writer(csvfile)
            # Write the header row with the new 'Timestamp' column
            writer.writerow(['Timestamp', 'Sender', 'Message'])
            # Write all the collected data rows
            writer.writerows(data)
        
        print(f"File saved successfully in: {output_path}")
    except IOError as e:
        print(f"Error: Could not write to output file {output_path}. {e}")
    except Exception as e:
        print(f"An unexpected error occurred while saving the file: {e}")


In [17]:
inpput = "chat/ludovica.txt"
output = "data/ludovica.csv"
process_data(inpput, output)
inpput = "chat/mamma.txt"
output = "data/mamma.csv"
process_data(inpput, output)
inpput = "chat/genni.txt"
output = "data/genni.csv"
process_data(inpput, output)
inpput = "chat/cammisa.txt"
output = "data/cammisa.csv"
process_data(inpput, output)

File saved successfully in: data/ludovica.csv
File saved successfully in: data/mamma.csv
File saved successfully in: data/genni.csv
File saved successfully in: data/cammisa.csv


### CSV to JSON

In [4]:
# process_conversation takes a csv file and returns a json with the conversation
# the format is hugging face standard i.e. a list of dictionaries with role and content
# role can be system, for the initial prompt, assistant (francesco) or user (other people)
def process_conversation(csv_file, chat_id):
    
    df = pd.read_csv(csv_file)
    
    messages = []
    
    for _, row in df.iterrows():
        
        timestamp = row['Timestamp']
        
        if row['Sender'] != 'Francesco Brigante':
            #content = f"{row['Sender']}: {row['Message']}" to save in format user_specific_name: message
            content = row['Message']    #to save in format user: message
            messages.append({
                "role": "user",
                "content": content,
                "timestamp": timestamp
            })
        else:
            messages.append({
                "role": "assistant",
                "content": row['Message'],
                "timestamp": timestamp
            })
    
    return {
        "chat_id": chat_id,  
        "messages": messages
    }

# function to create a json file merging all the chats specified at csv_files
def create_dataset(csv_files):
    all_conversations = []
    chat_id = 1
    
    for csv_file in csv_files:
        conversation = process_conversation(csv_file, chat_id)
        all_conversations.append(conversation)
        chat_id += 1
    
    # Save to JSON file
    with open('dataset.json', 'w', encoding='utf-8') as f:
        json.dump(all_conversations, f, ensure_ascii=False, indent=2)
    
    return all_conversations

In [5]:
csv_files = [
    "data/ludovica.csv",
    "data/genni.csv",
    "data/mamma.csv",
    "data/cammisa.csv"
    ##aggiungere altri
]

create_dataset(csv_files)

[{'chat_id': 1,
  'messages': [{'role': 'user',
    'content': '282... Impazzisco',
    'timestamp': '[13/03/22, 17:05:54]'},
   {'role': 'assistant',
    'content': 'Ahahahahahah, mi sa che lo dovrai fare anche tu',
    'timestamp': '[13/03/22, 17:08:34]'},
   {'role': 'assistant',
    'content': 'Sono i programmi dal 6 al 13, perché l’esercizio chiede di farli tutti in un unico file',
    'timestamp': '[13/03/22, 17:08:59]'},
   {'role': 'user',
    'content': 'Io per il momento sto cercando di usare i puntatori nello switch del 3.3👀',
    'timestamp': '[13/03/22, 17:09:16]'},
   {'role': 'user',
    'content': 'Ah sisi vero',
    'timestamp': '[13/03/22, 17:09:22]'},
   {'role': 'assistant',
    'content': 'Ti serve una mano?',
    'timestamp': '[13/03/22, 17:10:04]'},
   {'role': 'user',
    'content': 'Grazie Fra, aspetta te lo mando',
    'timestamp': '[13/03/22, 17:13:02]'},
   {'role': 'user',
    'content': "*manda un'immagine*",
    'timestamp': '[13/03/22, 17:15:03]'},
   {'

### JSON to prompt/response dataset

In [6]:
from datetime import datetime, timedelta



# Define special tokens
BOS_TOKEN = tokenizer.bos_token if tokenizer.bos_token else '<begin_of_sentence>'  # beginning of sentence token, used to give the instructions

# NOTE: we're using deepseek's special tokens for user and assistant roles, which have ｜ instead of |, don't confuse (｜|)
USER_TOKEN_START = '<｜User｜>' 
ASSISTANT_TOKEN_START = '<｜Assistant｜>'

#NOTE: same thing here for eos token, which is <｜end▁of▁sentence｜>, using also ▁ instead of _
EOS_TOKEN = tokenizer.eos_token if tokenizer.eos_token else "<|end_of_sentence|>"
END_TURN_TOKEN = "<|turn_end|>"

system_prompt = """You are Francesco Brigante, a 22 years old Italian Computer Science student in Rome. 
Respond naturally as him in Italian, maintaining his characteristic communication style. 
Keep responses concise and contextual.

Continue this conversation with the User, who is a friend of Francesco:""" 


TIME_GAP_MINUTES = 30
TIME_GAP_SECONDS = TIME_GAP_MINUTES * 60 # 2 hours in seconds

# Helper function to parse timestamp strings into datetime objects
def parse_timestamp(timestamp_str):
    """
    Parses a timestamp string in the format '[DD/MM/YY, HH:MM:SS]' into a datetime object.
    """
    # Remove brackets and strip whitespace
    cleaned_ts = timestamp_str.strip('[]').strip()
    # Parse format: DD/MM/YY, HH:MM:SS
    return datetime.strptime(cleaned_ts, '%d/%m/%y, %H:%M:%S')



# creates a formatted prompt using the special tokens
# the prompt starts with the system message if provided, then the context window and finally the current user message
def create_formatted_prompt(messages, current_user_msg_idx, max_context_messages=5):
    
    prompt_parts = []

    if system_prompt:
        prompt_parts.append(f"{BOS_TOKEN}{system_prompt}")

    current_msg = messages[current_user_msg_idx]
    current_msg_ts = parse_timestamp(current_msg['timestamp'])
    
    relevant_context_messages = []

    # Iterate backwards from the message right before the current user message
    for i in range(current_user_msg_idx - 1, -1, -1):
        msg = messages[i]
        msg_ts = parse_timestamp(msg['timestamp'])
        
        time_diff = current_msg_ts - msg_ts
        
        # Add message if within time gap AND we haven't exceeded the max number of context messages
        if time_diff <= timedelta(seconds=TIME_GAP_SECONDS) and len(relevant_context_messages) < max_context_messages:
            relevant_context_messages.insert(0, msg) # Insert at the beginning to maintain chronological order
        else:
            # If message is too old or max context messages reached, stop
            break
            
    # The full set of messages to include in the prompt will be the relevant context + the current user message
    messages_to_process = relevant_context_messages + [current_msg]

    # add messages to the prompt
    # there's a logic to group consecutive messages with the same role using only '\n' as separator instead of re-using the role token
    current_role = None
    current_contents = []
    for msg in messages_to_process:
        role = msg["role"]
        content = msg["content"]
        
        if role != current_role:
            # add the previous group if exists
            if current_role is not None:
                role_token = USER_TOKEN_START if current_role == "user" else ASSISTANT_TOKEN_START
                prompt_parts.append(role_token + '\n'.join(str(c) for c in current_contents) + END_TURN_TOKEN)
                
            # start new group
            current_role = role
            current_contents = [content]
        else:
            # continue current group
            current_contents.append(content)
    
    # add the last group
    if current_role is not None:
        role_token = USER_TOKEN_START if current_role == "user" else ASSISTANT_TOKEN_START
        prompt_parts.append(role_token + '\n'.join(str(c) for c in current_contents) + END_TURN_TOKEN)

    return "\n".join(prompt_parts) + f"\n{ASSISTANT_TOKEN_START}"

In [7]:
# creates a list from the dataset
def create_dataset_list(json_file_path, max_context_messages=5):

    # load the JSON file
    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    dataset_list = []

    for conversation in data:
        messages = conversation['messages']

        if not messages:
            print(f"Warning: Conversation {conversation.get('chat_id', 'N/A')} is empty. Skipping.")
            continue

        # iterate through messages to find and save ONLY user -> assistant pairs, saving the last n_context_messages messages as context
        # the new format is a list of dictionaries with prompt and response
        for i in range(0, len(messages) - 1):
            current_message = messages[i]
            next_message = messages[i+1] # This is the immediate assistant response

            #time-stamps
            current_msg_ts = parse_timestamp(current_message['timestamp'])
            next_msg_ts = parse_timestamp(next_message['timestamp'])
            time_diff = next_msg_ts - current_msg_ts
            max_time_diff = timedelta(seconds=TIME_GAP_SECONDS)
            
            if current_message['role'] == 'user' and next_message['role'] == 'assistant' and time_diff <= max_time_diff:
                
                # Create the formatted prompt using the new time-based context logic
                prompt = create_formatted_prompt(
                    messages,
                    current_user_msg_idx=i,
                    max_context_messages=max_context_messages
                )

                response_parts = [next_message['content']]
                response_base_timestamp = parse_timestamp(next_message['timestamp'])

                # Look for subsequent assistant messages to group
                # Start checking from the message *after* the immediate assistant response
                for j in range(i + 2, len(messages)):
                    subsequent_msg = messages[j]
                    
                    # Only group if it's an assistant message and its timestamp
                    # is within the defined grouping time window from the base response timestamp
                    if subsequent_msg['role'] == 'assistant':
                        subsequent_msg_ts = parse_timestamp(subsequent_msg['timestamp'])
                        time_diff = subsequent_msg_ts - response_base_timestamp
                        
                        if time_diff <= timedelta(seconds=TIME_GAP_SECONDS / 4):  #  minutes for grouping
                            response_parts.append(subsequent_msg['content'])
                        else:
                            # If too much time has passed, stop grouping
                            break
                    else:
                        # If the role is not assistant (e.g., a new user message), stop grouping
                        break
                        
                # Join all collected response parts and append the End Of Sentence token
                response = '\n'.join(response_parts) + EOS_TOKEN

                dataset_list.append({
                    'prompt': prompt,
                    'response': response
                })

    return dataset_list

In [None]:
# # ALTERNATIVE (disabled): DATASET WITH LESS OVERLAP — a prompt's context never reaches back into an earlier example's response

# def create_formatted_prompt(messages, current_user_msg_idx, max_context_messages=5, context_start_limit_idx=0):
#     """
#     Creates a formatted prompt string for a language model, including system message,
#     time-relevant context messages, and the current user message.

#     Args:
#         messages (list): A list of message dictionaries, each with 'role', 'content', and 'timestamp'.
#         current_user_msg_idx (int): The index of the current user message in the 'messages' list.
#         max_context_messages (int): The maximum number of context messages to include,
#                                      even if more are within the time window.
#         context_start_limit_idx (int): The minimum index from which context messages can be drawn.
#                                          Messages with indices less than this will be ignored.

#     Returns:
#         str: The formatted prompt string.
#     """
    
#     prompt_parts = []

#     if system_prompt:
#         prompt_parts.append(f"{BOS_TOKEN}{system_prompt}")

#     current_msg = messages[current_user_msg_idx]
#     current_msg_ts = parse_timestamp(current_msg['timestamp'])
    
#     relevant_context_messages = []

#     # Iterate backwards from the message right before the current user message
#     # Ensure 'i' does not go below context_start_limit_idx
#     for i in range(current_user_msg_idx - 1, context_start_limit_idx - 1, -1):
#         msg = messages[i]
#         msg_ts = parse_timestamp(msg['timestamp'])
        
#         time_diff = current_msg_ts - msg_ts
        
#         # Add message if within time gap AND we haven't exceeded the max number of context messages
#         # Also ensure the message index is not less than the context_start_limit_idx
#         if time_diff <= timedelta(seconds=TIME_GAP_SECONDS) and len(relevant_context_messages) < max_context_messages:
#             relevant_context_messages.insert(0, msg) # Insert at the beginning to maintain chronological order
#         else:
#             # If message is too old or max context messages reached, or it's before the limit, stop
#             break
            
#     # The full set of messages to include in the prompt will be the relevant context + the current user message
#     messages_to_process = relevant_context_messages + [current_msg]

#     # add messages to the prompt
#     # there's a logic to group consecutive messages with the same role using only '\n' as separator instead of re-using the role token
#     current_role = None
#     current_contents = []
#     for msg in messages_to_process:
#         role = msg["role"]
#         content = msg["content"]
        
#         if role != current_role:
#             # add the previous group if exists
#             if current_role is not None:
#                 role_token = USER_TOKEN_START if current_role == "user" else ASSISTANT_TOKEN_START
#                 prompt_parts.append(role_token + '\n'.join(str(c) for c in current_contents) + END_TURN_TOKEN)
                
#             # start new group
#             current_role = role
#             current_contents = [content]
#         else:
#             # continue current group
#             current_contents.append(content)
    
#     # add the last group
#     if current_role is not None:
#         role_token = USER_TOKEN_START if current_role == "user" else ASSISTANT_TOKEN_START
#         prompt_parts.append(role_token + '\n'.join(str(c) for c in current_contents) + END_TURN_TOKEN)

#     return "\n".join(prompt_parts) + f"\n{ASSISTANT_TOKEN_START}"

# def create_dataset_list(json_file_path, max_context_messages=5):
#     """
#     Creates a list of formatted training examples (prompt-response pairs) from a JSON dataset.
#     Context messages are selected based on a time gap and a maximum number of messages.
#     Crucially, this version ensures no messages that were part of a previous
#     response are included in the context of a subsequent prompt, to avoid overlap.
#     Consecutive assistant messages after a response are grouped if within a smaller time gap.

#     Args:
#         json_file_path (str): The path to the JSON dataset file.
#         max_context_messages (int): The maximum number of context messages to include in the prompt.

#     Returns:
#         list: A list of dictionaries, each containing a 'prompt' and a 'response'.
#     """

#     # load the JSON file
#     with open(json_file_path, 'r', encoding='utf-8') as f:
#         data = json.load(f)

#     dataset_list = []

#     for conversation in data:
#         messages = conversation['messages']

#         if not messages:
#             print(f"Warning: Conversation {conversation.get('chat_id', 'N/A')} is empty. Skipping.")
#             continue

#         current_idx = 0 
#         # This variable tracks the index AFTER the last message that was part of a response
#         # in a *previous* training example within the current conversation.
#         # It sets the lower bound for messages considered for context.
#         last_response_end_idx_for_context = -1 

#         while current_idx < len(messages) - 1:
#             current_message = messages[current_idx]
#             next_message = messages[current_idx + 1] # This is the immediate assistant response

#             #time-stamps
#             current_msg_ts = parse_timestamp(current_message['timestamp'])
#             next_msg_ts = parse_timestamp(next_message['timestamp'])
#             time_diff = next_msg_ts - current_msg_ts
#             max_time_diff = timedelta(seconds=TIME_GAP_SECONDS)
            
#             if current_message['role'] == 'user' and next_message['role'] == 'assistant' and time_diff <= max_time_diff:
                
#                 # The context_start_limit_idx for this prompt is one index after
#                 # where the last response in the *previous* example ended.
#                 # This prevents previous responses from appearing in current prompts.
#                 context_limit = last_response_end_idx_for_context + 1

#                 prompt = create_formatted_prompt(
#                     messages,
#                     current_user_msg_idx=current_idx,
#                     max_context_messages=max_context_messages,
#                     context_start_limit_idx=context_limit # Pass the calculated limit
#                 )

#                 response_parts = [next_message['content']]
#                 response_base_timestamp = parse_timestamp(next_message['timestamp'])
                
#                 last_response_msg_idx = current_idx + 1 # Initial last message in response
                
#                 # Look for subsequent assistant messages to group
#                 for j in range(current_idx + 2, len(messages)):
#                     subsequent_msg = messages[j]
                    
#                     if subsequent_msg['role'] == 'assistant':
#                         subsequent_msg_ts = parse_timestamp(subsequent_msg['timestamp'])
#                         time_diff = subsequent_msg_ts - response_base_timestamp
                        
#                         if time_diff <= timedelta(seconds=TIME_GAP_SECONDS / 4):  # minutes for grouping
#                             response_parts.append(subsequent_msg['content'])
#                             last_response_msg_idx = j
#                         else:
#                             break
#                     else:
#                         break
                        
#                 response = '\n'.join(response_parts) + EOS_TOKEN

#                 dataset_list.append({
#                     'prompt': prompt,
#                     'response': response
#                 })
                
#                 # Update the boundary for the *next* training example's context
#                 last_response_end_idx_for_context = last_response_msg_idx
#                 # Advance the main pointer past the entire response just consumed.
#                 current_idx = last_response_msg_idx + 1
#             else:
#                 # If the current pair is not a user message followed by an assistant message,
#                 # simply move to the next message to find a valid pair.
#                 current_idx += 1

#     return dataset_list

In [8]:
json_file_path = 'dataset.json'
dataset_list = create_dataset_list(json_file_path, max_context_messages=5)

# print some examples
start_idx = 14000
print(f"Generated {len(dataset_list)} training examples.\n")
for i, example in enumerate(dataset_list[start_idx:start_idx+5]): # print 3 examples
    print(f"--- Example {i+1} ---")
    print("Prompt:")
    print(example['prompt'])
    print("\nResponse:")
    print(example['response'])
    print("---------------------\n")

if not dataset_list:
    print("ERROR: No training examples were generated.")

Generated 14390 training examples.

--- Example 1 ---
Prompt:
<｜begin▁of▁sentence｜>You are Francesco Brigante, a 22 years old Italian Computer Science student in Rome. 
Respond naturally as him in Italian, maintaining his characteristic communication style. 
Keep responses concise and contextual.

Continue this conversation with the User, who is a friend of Francesco:
<｜User｜>*manda un audio*<|turn_end|>
<｜Assistant｜>

Response:
Che ti perdi<｜end▁of▁sentence｜>
---------------------

--- Example 2 ---
Prompt:
<｜begin▁of▁sentence｜>You are Francesco Brigante, a 22 years old Italian Computer Science student in Rome. 
Respond naturally as him in Italian, maintaining his characteristic communication style. 
Keep responses concise and contextual.

Continue this conversation with the User, who is a friend of Francesco:
<｜Assistant｜>Okok buonanotte<|turn_end|>
<｜User｜>Ahahahahhahahahaha
Che è successo?<|turn_end|>
<｜Assistant｜>

Response:
Nulla amo ieri sera mi andava l’ultima tua marlboro e te

In [9]:
# create a Hugging Face dataset
dataset = Dataset.from_list(dataset_list)

# creating dataset splits
train_test = dataset.train_test_split(test_size=0.2, seed=42)   #80% train, 20% test
test_valid = train_test['test'].train_test_split(test_size=0.5, seed=42)  #10% test, 10% valid

train_dataset = train_test['train']
val_dataset = test_valid['train']
test_dataset = test_valid['test']

### Tokenization

In [21]:
def tokenize_function(batch, tokenizer=tokenizer, max_length=256):
    prompts   = batch["prompt"]
    responses = batch["response"]
    
    tokenized_prompts = tokenizer(
        prompts,
        max_length=max_length,
        truncation=True,
        add_special_tokens=False
    )
    
    # compute length of tokenized prompts
    prompt_lengths = [len(tokens) for tokens in tokenized_prompts['input_ids']]
    
    tokenized_conversation = tokenizer(
        [p + r for p, r in zip(prompts, responses)],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        add_special_tokens=False
    )
    
    input_ids = tokenized_conversation['input_ids']                 #full tokenized conversation
    attention_mask = tokenized_conversation['attention_mask']       #real tokens are 1, padding tokens are 0
    labels = []                                                     #used for training, labels are the same as input_ids but with the prompt part masked out
    
    
    for id in input_ids:
        label = id.copy()
        
        # find the index of the assistant token in the input_ids
        response_start_idx = None
        for i in reversed(range(len(label))):
            if label[i] == tokenizer.convert_tokens_to_ids(ASSISTANT_TOKEN_START):
                response_start_idx = i + 1                                                  # exclude assistant token
                break

        if response_start_idx is None:
            print("[❌] Assistant token not found in input_ids.")
            #print("Decoded text:\n", tokenizer.decode(ids, skip_special_tokens=False))
            label = [-100] * len(label)                                                     # ignore everything
        else:
            # mask everything before the assistant's response
            label[:response_start_idx] = [-100] * response_start_idx

        labels.append(label)   
        
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

In [22]:
# apply tokenization
tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names
)

tokenized_val = val_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=val_dataset.column_names
)

tokenized_test = test_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=test_dataset.column_names
)

# save tokenized datasets
tokenized_train.save_to_disk('datasets/tokenized_train')
tokenized_val.save_to_disk('datasets/tokenized_val')
tokenized_test.save_to_disk('datasets/tokenized_test')

Map:   0%|          | 0/11512 [00:00<?, ? examples/s]

[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.


Map:   0%|          | 0/1439 [00:00<?, ? examples/s]

Map:   0%|          | 0/1439 [00:00<?, ? examples/s]

[❌] Assistant token not found in input_ids.
[❌] Assistant token not found in input_ids.


Saving the dataset (0/1 shards):   0%|          | 0/11512 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1439 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1439 [00:00<?, ? examples/s]

### Check for correctness

In [23]:
# check to see if the tokenized datasets are loaded correctly
from datasets import load_from_disk

# load each dataset from disk
tokenized_train = load_from_disk('datasets/tokenized_train')
tokenized_val = load_from_disk('datasets/tokenized_val')
tokenized_test = load_from_disk('datasets/tokenized_test')

# print
print(f"Training examples: {len(tokenized_train)}")
print(f"Validation examples: {len(tokenized_val)}")
print(f"Test examples: {len(tokenized_test)}")

Training examples: 11512
Validation examples: 1439
Test examples: 1439


In [24]:
import random

# print a random or indexed example from the tokenized dataset
def print_tokenized_example(tokenized_dataset, tokenizer, index=None, ignore_index=-100, rng=None):
    """Print the ids, decoded text, attention mask and labels of one tokenized example.

    Args:
        tokenized_dataset: Indexable dataset whose rows are dicts with 'input_ids',
            'labels' and 'attention_mask'.
        tokenizer: Tokenizer providing ``decode``.
        index (int, optional): Row to print. When None, a random row is chosen.
        ignore_index (int): Label value marking masked positions. These positions are
            dropped before the labels are decoded. Defaults to -100.
        rng (random.Random, optional): Random generator used for the random pick, so
            that the choice can be reproduced. Defaults to the global ``random`` module.

    Returns:
        int: The index that was printed, so a random pick can be re-inspected.

    Raises:
        ValueError: If ``index`` is None and the dataset is empty.
    """
    if index is None:
        if len(tokenized_dataset) == 0:
            raise ValueError("Cannot pick a random example from an empty dataset.")
        picker = rng if rng is not None else random
        index = picker.randint(0, len(tokenized_dataset) - 1)
    example = tokenized_dataset[index]

    input_ids = example['input_ids']
    labels = example['labels']
    attention_mask = example['attention_mask']

    # decode full input_ids, keeping special tokens so padding/BOS/EOS are visible
    decoded_input = tokenizer.decode(input_ids, skip_special_tokens=False)

    # decode only the supervised labels (masked positions carry ignore_index)
    decoded_labels = tokenizer.decode(
        [token_id for token_id in labels if token_id != ignore_index],
        skip_special_tokens=False
    )

    print(f"--- Example index: {index} ---")
    print("== INPUT IDS ==")
    print(input_ids)
    print("\n== INPUT TEXT ==")
    print(decoded_input)
    print("\n== ATTENTION MASK ==")
    print(attention_mask)
    print("\n== LABELS (masked prompt) ==")
    print(labels)
    print("\n== LABEL TEXT ==")
    print(decoded_labels)
    return index
    
# Show one random training example; the trailing ';' keeps the cell from echoing a return value.
print_tokenized_example(tokenized_dataset=tokenized_train, tokenizer=tokenizer);

--- Example index: 845 ---
== INPUT IDS ==
[151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 1516

In [25]:
# function to check correctness of tokenized examples verifying some properties:
# 1. Attention mask should match padding tokens in input_ids
# 2. Labels are correctly masking the prompt tokens, i.e. only the response part should be not -100
def verify_tokenized_example(example, tokenizer, assistant_token=ASSISTANT_TOKEN_START, eos_token_id=151643, ignore_index=-100):
    """Verify padding and label-masking invariants of one tokenized example.

    Each check prints a ✅/❌ line:
      1. Padding: every position with attention_mask == 0 holds the pad id. Attended
         positions contain no pad ids, except a single one when the pad id doubles as
         the EOS id (the sequence's own EOS).
      2. Label masking: every label other than ``ignore_index`` lies after the last
         ``assistant_token`` in input_ids, i.e. only the response is supervised.

    Args:
        example (dict): Row with 'input_ids', 'labels' and 'attention_mask' lists.
        tokenizer: Tokenizer exposing ``pad_token_id`` and ``convert_tokens_to_ids``.
        assistant_token (str): Token marking the start of the assistant response.
        eos_token_id (int): Pad id to fall back on when ``tokenizer.pad_token_id`` is
            None. 151643 is the id the printed examples are padded with.
        ignore_index (int): Label value for positions excluded from the loss.

    Returns:
        bool: True if every check passes, False at the first failing check.
    """
    input_ids = example['input_ids']
    labels = example['labels']
    attention_mask = example['attention_mask']

    # sanity check
    if len(input_ids) != len(labels) or len(input_ids) != len(attention_mask):
        print("[❌] Mismatch in tensor lengths.")
        return False

    # 1. Padding. Use `is None` rather than `or` so that a valid pad id of 0 is not discarded.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_token_id
    num_attention_mask_zeros = attention_mask.count(0)
    # Compare position by position instead of counting pad ids globally. A global count
    # cannot tell which pad ids are masked, and always discounting one EOS is only
    # valid when the pad id *is* the EOS id.
    masked_non_pad = sum(1 for tok, m in zip(input_ids, attention_mask) if m == 0 and tok != pad_id)
    attended_pad = sum(1 for tok, m in zip(input_ids, attention_mask) if m != 0 and tok == pad_id)
    max_attended_pad = 1 if pad_id == eos_token_id else 0   # the sequence's own EOS may share the pad id

    if masked_non_pad or attended_pad > max_attended_pad:
        print(f"[❌] Padding mismatch: {masked_non_pad} masked non-pad tokens and {attended_pad} attended pad IDs "
              f"(at most {max_attended_pad} allowed) with {num_attention_mask_zeros} zeros in attention_mask.")
        return False
    print(f"[✅] Attention mask matches padding: {num_attention_mask_zeros}.")

    # 2. Walk backwards from the end, counting supervised labels until the last
    # assistant marker. Any supervised label before the marker means the prompt leaked.
    assistant_id = tokenizer.convert_tokens_to_ids(assistant_token)   # loop-invariant, look up once
    response_token_count = 0
    found_assistant = False
    for i in reversed(range(len(labels))):
        if labels[i] != ignore_index:
            response_token_count += 1
        elif input_ids[i] == assistant_id:
            found_assistant = True
            break

    # Without the marker (e.g. if it was truncated away) the loop counts every supervised
    # label, so the comparison below would pass vacuously. Report it instead.
    if not found_assistant:
        print("[❌] Assistant token not found in input_ids; cannot verify label masking.")
        return False

    actual_response_tokens = sum(1 for l in labels if l != ignore_index)

    if actual_response_tokens != response_token_count:
        print(f"[❌] Label mismatch: expected {response_token_count} response tokens, got {actual_response_tokens}.")
        return False
    else:
        print(f"[✅] Labels correctly mask prompt tokens, {response_token_count} response tokens detected.")

    return True

# Stop at the first malformed example and report its index, so that it can be inspected
# with print_tokenized_example(tokenized_train, tokenizer, index=...).
for example_idx, example in enumerate(tokenized_train):
    if not verify_tokenized_example(example, tokenizer):
        print(f"Error in tokenized example {example_idx}!")
        break
else:
    # for/else: reached only when no example triggered the break above
    print(f"All {len(tokenized_train)} tokenized training examples passed verification.")

[✅] Attention mask matches padding: 148.
[✅] Labels correctly mask prompt tokens, 12 response tokens detected.
[✅] Attention mask matches padding: 169.
[✅] Labels correctly mask prompt tokens, 6 response tokens detected.
[✅] Attention mask matches padding: 176.
[✅] Labels correctly mask prompt tokens, 7 response tokens detected.
[✅] Attention mask matches padding: 139.
[✅] Labels correctly mask prompt tokens, 3 response tokens detected.
[✅] Attention mask matches padding: 172.
[✅] Labels correctly mask prompt tokens, 9 response tokens detected.
[✅] Attention mask matches padding: 141.
[✅] Labels correctly mask prompt tokens, 5 response tokens detected.
[✅] Attention mask matches padding: 145.
[✅] Labels correctly mask prompt tokens, 11 response tokens detected.
[✅] Attention mask matches padding: 123.
[✅] Labels correctly mask prompt tokens, 25 response tokens detected.
[✅] Attention mask matches padding: 140.
[✅] Labels correctly mask prompt tokens, 7 response tokens detected.
[✅] Att

In [26]:
# counts the number of examples with full attention, i.e. no padding tokens
def count_full_attention(dataset):
    """Print and return the percentage of examples whose attention mask has no padding.

    An example counts as "full attention" when every entry of its attention_mask is 1.
    NOTE(review): if examples are padded to a fixed max_length (the long pad runs in the
    printed example suggest so), a full-attention example filled the whole window and
    may have been truncated. A high share hints that max_length is too small.

    Args:
        dataset: Iterable, sized collection of rows with an 'attention_mask' list.

    Returns:
        float: Percentage (0-100) of full-attention examples; 0.0 for an empty dataset.
    """
    total = len(dataset)
    full_attention_count = sum(
        1 for example in dataset if all(token == 1 for token in example['attention_mask'])
    )
    # guard against ZeroDivisionError on an empty split
    percentage = (full_attention_count / total) * 100 if total else 0.0

    print(f"✅ Percentage of examples with full attention (no padding): {full_attention_count} / {total} = {percentage:.2f}%")
    return percentage

# Share of training examples that use the whole sequence window (value shown as cell output).
count_full_attention(dataset=tokenized_train)

✅ Percentage of examples with full attention (no padding): 92 / 11512 = 0.80%


0.7991660875608061