# Building a Chat-Aligned LLM Using Quantized LoRA Fine-Tuning

## Introduction

This note explains how to take an open-source Large Language Model (LLM) like Llama-3 and transform it into a chat-aligned assistant through memory-efficient fine-tuning. We'll use Parameter-Efficient Fine-Tuning (PEFT) with Low-Rank Adaptation (LoRA) to achieve this with minimal computational resources.

## What Is Quantized LoRA Fine-Tuning?

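Quantized LoRA fine-tuning combines two techniques:

1. **Quantization**: Storing the frozen base-model weights at reduced precision (e.g., 4-bit instead of 32-bit) to cut memory usage
2. **LoRA (Low-Rank Adaptation)**: Training small adapter matrices instead of the entire model

Together, these techniques make it feasible to fine-tune billion-parameter models on widely available hardware. Note that the TPU implementation in this note does not quantize the base model to 4 bits: it loads the weights in bfloat16 (16-bit) and trains LoRA adapters on top of them.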
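To make the memory argument concrete, the sketch below computes the storage needed for the weights of an 8-billion-parameter model at 32, 16, and 4 bits, then round-trips a small matrix through a toy 4-bit quantizer. The helper names are hypothetical, and the quantizer is a simplified symmetric absmax scheme with one scale per tensor, not the NF4 format QLoRA uses (NF4 has a non-uniform code book and per-block scales).

```python
import numpy as np

def weight_memory_gb(num_params, bits_per_param):
    """Memory needed just to store the weights, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

def absmax_quantize_4bit(weights):
    """Toy symmetric 4-bit quantizer: map floats to integers in [-7, 7] with one scale."""
    scale = np.abs(weights).max() / 7.0
    # Stored as int8 here; real 4-bit formats pack two values into each byte
    q = np.clip(np.round(weights / scale), -7, 7).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map the integer codes back to approximate float weights."""
    return q.astype(np.float32) * scale

# Weight storage for an 8B-parameter model at different precisions
for bits in (32, 16, 4):
    print(f"{bits:>2}-bit: {weight_memory_gb(8e9, bits):.0f} GB")

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 4)).astype(np.float32)
q, scale = absmax_quantize_4bit(w)
# Rounding to the nearest code bounds the error by half a quantization step
print("max reconstruction error:", np.abs(w - dequantize(q, scale)).max())
```

The 4-bit copy needs an eighth of the fp32 memory, at the price of a rounding error of at most half a quantization step per weight.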


## Understanding LoRA

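LoRA freezes the pretrained weights and adds small trainable matrices alongside selected linear layers (in this note, the attention query and value projections):

1. **Original Operation**: `Y = WX`, where `W` is a large pretrained weight matrix of shape d × k that stays frozen
2. **LoRA Modification**: `Y = WX + (α/r)·BAX`, where `ΔW = BA` is a low-rank update with `B` of shape d × r, `A` of shape r × k, and `r ≪ min(d, k)`; `α` is the `lora_alpha` scaling factor
3. **Benefits**: Only `A` and `B` are trained, so each adapted layer adds `r(d + k)` trainable parameters instead of `dk`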
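A minimal NumPy sketch of a single LoRA-adapted linear layer follows. The dimensions are toy values chosen for illustration. As in the LoRA paper, `B` starts at zero so training begins from the base model's behavior, and the update is scaled by `alpha / r`, which is PEFT's default scaling.

```python
import numpy as np

# Toy dimensions for illustration: d_out x d_in weight, rank r, scaling factor alpha
d_out, d_in, r, alpha = 64, 64, 8, 16
rng = np.random.default_rng(0)

W = rng.normal(size=(d_out, d_in))           # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d_in))   # trainable, small random init
B = np.zeros((d_out, r))                     # trainable, zero init so BA = 0 at the start
scaling = alpha / r

def lora_forward(x, A, B):
    """Compute Wx + (alpha / r) * B(Ax) without forming the full d_out x d_in update."""
    return W @ x + scaling * (B @ (A @ x))

x = rng.normal(size=d_in)

# At initialization the adapter adds nothing, so the layer behaves like the base model
print(np.allclose(lora_forward(x, A, B), W @ x))   # True

# Once training has changed B, the update can be merged into W for inference
B_trained = rng.normal(size=(d_out, r))            # stand-in for trained values
W_merged = W + scaling * (B_trained @ A)
print(np.allclose(lora_forward(x, A, B_trained), W_merged @ x))  # True

print("full:", d_out * d_in, "LoRA:", r * (d_in + d_out))  # full: 4096 LoRA: 1024
```

Computing `B(Ax)` instead of `(BA)x` keeps the extra work proportional to `r`, and merging `BA` into `W` afterwards removes the adapter overhead at inference time.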

## Why Is This Useful?

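- **Resource Efficiency**: Fine-tune 8B+ parameter models on a single consumer GPU (with 4-bit quantization) or on a cloud TPU such as Kaggle's
- **Parameter Efficiency**: Update well under 1% of model parameters (about 0.04% for the rank-8 `q_proj`/`v_proj` setup used below) instead of 100%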
- **Knowledge Preservation**: Retain the base model's knowledge while adding new capabilities
- **Format Adaptation**: Convert "raw" models into chat-optimized assistants
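The parameter-efficiency figure can be checked with a little arithmetic. The sketch below assumes the published Llama-3.1-8B dimensions (32 decoder layers, hidden size 4096, grouped-query attention with 8 key/value heads of size 128) and an approximate total of 8.03B parameters; the helper name is hypothetical.

```python
# Assumed Llama-3.1-8B dimensions: q_proj maps 4096 -> 4096, and
# v_proj maps 4096 -> 1024 because of grouped-query attention (8 KV heads x 128)
num_layers = 32
hidden = 4096
kv_dim = 8 * 128
r = 8
total_params = 8.03e9  # approximate total parameter count of the base model

def lora_param_count(d_in, d_out, rank):
    """Parameters in one LoRA adapter: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)

# One adapter on q_proj and one on v_proj in every decoder layer
per_layer = lora_param_count(hidden, hidden, r) + lora_param_count(hidden, kv_dim, r)
lora_params = num_layers * per_layer
pct = 100 * lora_params / total_params
print(f"{lora_params:,} trainable parameters ({pct:.3f}% of ~8.03B)")
# 3,407,872 trainable parameters (0.042% of ~8.03B)
```

Adding more target modules (for example `k_proj`, `o_proj`, or the MLP projections) or raising `r` increases this count roughly linearly.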

## Prerequisites

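- A Kaggle notebook with a TPU accelerator (the code below uses `torch_xla` and Kaggle's `kaggle_secrets` module)
- A Hugging Face account with an access token, plus approved access to the gated `meta-llama/Meta-Llama-3.1-8B` repository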

## Step-by-Step Implementation

### 1. Environment Setup

First, install the necessary libraries:

```python
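# torch and torch_xla are usually preinstalled on Kaggle TPU images; accelerate is required for low_cpu_mem_usage=True
!pip install transformers accelerate peft datasets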
```

### 2. Authentication with Hugging Face

```python
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_KEY")
login(token=hf_token)
```

### 3. Define Special Tokens for Chat Format

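Chat alignment here relies on a fixed conversation template marked by three new tokens. Each entry also names a short phrase whose embedding is used to initialize the new token (see step 5):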

```python
SPECIAL_TOKENS = {
    'stop_token': {
        'token': '###STOP###',
        'replace_embedding_with': 'stop_talking'
    },
    'human_token': {
        'token': '###HUMAN###',
        'replace_embedding_with': 'human_speaking'
    },
    'bot_token': {
        'token': '###BOT###',
        'replace_embedding_with': 'assistant_speaking'
    }
}
```

`###HUMAN###` and `###BOT###` tell the model who is speaking, and `###STOP###` marks where the assistant's reply ends.

### 4. Load and Prepare the Base Model

```python
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"

# Load tokenizer and add the chat-format tokens
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # the base tokenizer defines no pad token
tokenizer.add_tokens([info['token'] for info in SPECIAL_TOKENS.values()])

# Load the model in bfloat16, the TPU's native 16-bit format
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True
)

# Grow the input embeddings (and lm_head) to cover the new token IDs,
# then move the model to the TPU
model.resize_token_embeddings(len(tokenizer))
device = xm.xla_device()
model = model.to(device)
```

### 5. Initialize Special Token Embeddings

```python
import torch

embedding_matrix = model.get_input_embeddings().weight  # shape: (vocab_size, hidden_size)

# Initialize each new token's embedding with the mean embedding of a descriptive phrase
with torch.no_grad():
    for token_name, token_info in SPECIAL_TOKENS.items():
        token_id = tokenizer.convert_tokens_to_ids(token_info['token'])
        # Token IDs of the phrase, without the BOS token the Llama tokenizer prepends
        similar_word_ids = tokenizer(
            token_info['replace_embedding_with'], add_special_tokens=False
        )['input_ids']
        new_embedding = embedding_matrix[similar_word_ids].mean(dim=0)
        embedding_matrix[token_id] = new_embedding
```
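The initialization above is just a row average. The toy NumPy version below uses a made-up five-token vocabulary and 8-dimensional embeddings (both hypothetical, standing in for the Llama tokenizer and embedding matrix) to show what happens: the new token's row is overwritten with the mean of the rows for the phrase it should resemble, so it starts near meaningful embeddings instead of an arbitrary vector.

```python
import numpy as np

# Hypothetical toy vocabulary; the real code uses the Llama tokenizer's IDs
vocab = {"stop": 0, "_talking": 1, "human": 2, "_speaking": 3, "###STOP###": 4}
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(len(vocab), 8))  # (vocab_size, hidden_dim)

def init_from_phrase(emb, new_id, phrase_ids):
    """Overwrite the new token's row with the mean of the phrase's token embeddings."""
    emb[new_id] = emb[phrase_ids].mean(axis=0)
    return emb

embeddings = init_from_phrase(
    embeddings, vocab["###STOP###"], [vocab["stop"], vocab["_talking"]]
)
print(embeddings[vocab["###STOP###"]])
```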

### 6. Configure LoRA for Efficient Training

```python
from peft import LoraConfig, get_peft_model

# Create LoRA configuration
lora_config = LoraConfig(
    r=8,                        # Rank of LoRA matrices
    lora_alpha=16,              # Scaling factor
    target_modules=["q_proj", "v_proj"],  # Attention layers to modify
    lora_dropout=0.1,           # Dropout for regularization
    bias="none",                # Don't train bias terms
    task_type="CAUSAL_LM"       # Causal language modeling task
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)

# Print statistics
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params/total_params:.3f}% of total)")
```
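With `target_modules=["q_proj", "v_proj"]`, the input embeddings and the output projection (`lm_head`) stay frozen, so the three new tokens keep the vectors they were initialized with. If that limits how reliably the model emits `###STOP###`, one option is PEFT's `modules_to_save`, sketched below as a variant of the configuration above (it is not what this note's script uses). It trains full copies of both matrices, roughly 1 billion extra parameters for Llama-3.1-8B with the extended vocabulary, so it needs far more memory than plain LoRA.

```python
from peft import LoraConfig

# Variant configuration: modules_to_save makes full trainable copies of the listed
# modules, so the rows for the new tokens in embed_tokens and lm_head can be learned
lora_config_with_embeddings = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    modules_to_save=["embed_tokens", "lm_head"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
```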

### 7. Prepare Conversation Dataset

```python
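from datasets import Dataset, load_dataset

# Load conversation datasets
assistant_dataset = load_dataset("timdettmers/openassistant-guanaco")
original_guanaco_dataset = load_dataset("guanaco/guanaco")

# Parse the conversations, format them with the special tokens, and split them into
# lists of strings (see load_and_merge_datasets() in the full script below).
# train_texts and test_texts stand for those resulting lists.
train_dataset = Dataset.from_dict({"text": train_texts})
test_dataset = Dataset.from_dict({"text": test_texts})

def tokenize_function(examples, max_length=256):
    # Fixed-length padding keeps tensor shapes static, which avoids XLA recompilation on TPU
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

# Tokenize both splits
tokenized_dataset = {
    "train": train_dataset.map(tokenize_function, batched=True, remove_columns=["text"]),
    "test": test_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
}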
```
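The parsing and formatting that the snippet above delegates to the full script can be sketched in isolation. The functions below are hypothetical, simplified counterparts of `parse_conversations()` and `join_conversation()` from the script; the sample record is invented but follows the `### Human: ... ### Assistant: ...` layout of the openassistant-guanaco `text` field. The sketch assumes every conversation ends with an assistant turn, which the script's filtering step enforces.

```python
def parse_guanaco_text(text):
    """Split '### Human: ... ### Assistant: ...' text into (speaker, message) turns."""
    turns = []
    for chunk in text.split("###"):
        chunk = chunk.strip()
        if not chunk:
            continue
        parts = chunk.split(": ", 1)
        if len(parts) == 2:
            speaker, message = parts
            turns.append((speaker.lower(), message))
    return turns

def format_with_special_tokens(turns):
    """Join turns with the chat markers and close the final assistant reply with ###STOP###."""
    text = ""
    for speaker, message in turns:
        if speaker == "human":
            text += f"###HUMAN###{message}"
        elif speaker == "assistant":
            text += f"###BOT###{message}"
    return text + "###STOP###"

sample = "### Human: What is LoRA?### Assistant: A parameter-efficient fine-tuning method."
print(format_with_special_tokens(parse_guanaco_text(sample)))
# ###HUMAN###What is LoRA?###BOT###A parameter-efficient fine-tuning method.###STOP###
```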

### 8. Configure Training

```python
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

# Training arguments
training_args = TrainingArguments(
    output_dir="./llama-lora-finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    num_train_epochs=1,
    bf16=True,  # Use bfloat16 on TPUs
    label_names=["labels"]
)

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # Not using masked language modeling
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)
```
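The data collator with `mlm=False` builds the training targets for causal language modeling: labels are a copy of `input_ids` with padding positions set to `-100`, which the loss ignores. The pure-Python sketch below mimics that on one padded sequence (the token IDs and `PAD_ID` are made up); the one-position shift between inputs and targets happens inside the model's loss, not in the collator. Because the pad token is the EOS token in this setup, any genuine EOS token in the text would be masked too. The last lines compute the per-device effective batch size implied by the arguments above.

```python
PAD_ID = 0  # hypothetical pad ID; in this setup the pad token is the EOS token

def causal_lm_labels(input_ids, pad_id=PAD_ID):
    """Mimic DataCollatorForLanguageModeling(mlm=False) for one sequence:
    copy the inputs and replace padding positions with -100."""
    return [-100 if tok == pad_id else tok for tok in input_ids]

print(causal_lm_labels([101, 57, 892, 0, 0]))  # [101, 57, 892, -100, -100]

# Sequences (and padded tokens) per optimizer step on one device with the arguments above
per_device_batch, grad_accum, max_length = 4, 8, 256
effective_batch = per_device_batch * grad_accum
print(effective_batch, effective_batch * max_length)  # 32 8192
```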

### 9. Train the Model

```python
# Start training
print("Starting training...")
trainer.train()

# Save the LoRA adapter (not a full copy of the base model) and the extended tokenizer
model.save_pretrained("./llama-lora-finetuned")
tokenizer.save_pretrained("./llama-lora-finetuned")
```

### 10. Test the Fine-Tuned Model

```python
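# Format the prompt the same way as the training data
test_prompt = "###HUMAN###What is machine learning?###BOT###"
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.7,
    eos_token_id=tokenizer.convert_tokens_to_ids("###STOP###"),  # stop at the end-of-reply marker
    pad_token_id=tokenizer.pad_token_id
)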

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated: {generated_text}")
```

### 11. Upload to Hugging Face (Optional)

```python
# Upload the adapter and tokenizer to the Hugging Face Hub
def upload_to_huggingface(model, tokenizer, repository_id, token, private=True):
    # On a PeftModel, push_to_hub uploads the adapter weights and adapter_config.json
    model.push_to_hub(
        repository_id,
        token=token,
        private=private
    )
    
    tokenizer.push_to_hub(
        repository_id,
        token=token,
        private=private
    )
    
    print(f"Model uploaded to https://huggingface.co/{repository_id}")

# Example usage
repository_id = "your-username/llama-3-chat-lora"
upload_to_huggingface(model, tokenizer, repository_id, hf_token)
```

### Conversation Format

The chat-aligned model uses special tokens to structure conversations:

```
###HUMAN###What is machine learning?###BOT###Machine learning is a subfield of artificial intelligence...###STOP###
```

This format teaches the model when to start and stop generating responses.

## Using Your Fine-Tuned Model

To use your newly chat-aligned model:

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

ADAPTER_ID = "your-username/llama-3-chat-lora"

# Load the tokenizer that includes the three chat tokens
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_ID)

# Load the base model, resize it to match the extended tokenizer, then attach the adapter
config = PeftConfig.from_pretrained(ADAPTER_ID)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, torch_dtype=torch.bfloat16)
model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(model, ADAPTER_ID)
model.eval()

STOP_ID = tokenizer.convert_tokens_to_ids("###STOP###")

# Function to format user inputs for the model
def format_prompt(user_input):
    return f"###HUMAN###{user_input}###BOT###"

# Generate a response
def generate_response(user_input, max_new_tokens=100):
    prompt = format_prompt(user_input)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=STOP_ID,              # stop at the end-of-reply marker
        pad_token_id=tokenizer.eos_token_id
    )
    
    # Keep only the text after the last ###BOT### marker and before ###STOP###
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    assistant_response = full_response.split("###BOT###")[-1].split("###STOP###")[0]
    
    return assistant_response.strip()

# Example usage
response = generate_response("Explain quantum computing in simple terms")
print(response)
```

## References

- [Hugging Face PEFT Documentation](https://huggingface.co/docs/peft/index)
- [LoRA Paper](https://arxiv.org/abs/2106.09685)
- [QLoRA Paper](https://arxiv.org/abs/2305.14314)

```python
!pip install --upgrade pip
# Quote the version specifier so the shell does not treat ">" as output redirection
!pip install --no-warn-script-location transformers accelerate peft datasets "bitsandbytes>=0.43.0"
```

## Full Training Script (Kaggle TPU)
"""
TPU-Compatible LoRA Fine-tuning for Llama-3 Models with Checkpointing

This script implements efficient fine-tuning of Llama-3 language models using:
1. Low-Rank Adaptation (LoRA) for parameter-efficient training
2. Special token handling for conversation formatting
3. TPU compatibility for Kaggle TPU environments
4. Checkpointing for resumable training

Author: Nawaraj Paudel, PhD
Date: April 26, 2025
"""

import os
import time
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
import pandas as pd
from peft import LoraConfig, get_peft_model, PeftModel
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
from tqdm.auto import tqdm
import logging
import datetime

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("LoRA-Finetune")

# Kaggle-specific paths
KAGGLE_WORKING_DIR = "/kaggle/working"
OUTPUT_DIR = os.path.join(KAGGLE_WORKING_DIR, "llama-lora-finetuned")


def setup_authentication():
    """
    Set up authentication with Hugging Face using Kaggle secrets.
    
    Returns:
        str: Authentication token
    """
    logger.info("Setting up Hugging Face authentication...")
    
    # Access the Hugging Face token from Kaggle secrets
    user_secrets = UserSecretsClient()
    hf_token = user_secrets.get_secret("HF_KEY")
    
    # Log in to Hugging Face Hub
    login(token=hf_token)
    logger.info("Successfully authenticated with Hugging Face")
    
    return hf_token


def load_tokenizer(model_name, special_tokens=None):
    """
    Load the tokenizer and add special tokens if provided.
    
    Args:
        model_name (str): Name or path of the pretrained model
        special_tokens (dict, optional): Dictionary of special tokens to add
    
    Returns:
        tokenizer: The loaded tokenizer with special tokens
    """
    logger.info(f"Loading tokenizer from {model_name}...")
    
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, 
        trust_remote_code=True
    )
    
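    # Set padding token and padding side
    tokenizer.pad_token = tokenizer.eos_token  # the base tokenizer defines no pad token
    # Right padding for training keeps real tokens at positions 0..n;
    # left padding is only needed for batched generation
    tokenizer.padding_side = 'right'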
    
    # Add special tokens if provided
    if special_tokens:
        tokens_to_add = [info['token'] for info in special_tokens.values()]
        num_added = tokenizer.add_tokens(tokens_to_add)
        logger.info(f"Added {num_added} special tokens to the tokenizer")
    
    return tokenizer


def load_tpu_model(model_name, token=None):
    """
    Load a pretrained model for TPU training.
    
    Args:
        model_name (str): Name or path of the pretrained model
        token (str, optional): Authentication token
    
    Returns:
        model: The loaded model
    """
    logger.info(f"Loading model {model_name} for TPU training...")
    start_time = time.time()
    
    # Get TPU device
    device = xm.xla_device()
    
    # Load the model with TPU compatibility settings
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # TPUs work well with bfloat16
        low_cpu_mem_usage=True,      # Reduce memory usage during loading
        token=token                  # Authentication token
    )
    
    # Move model to TPU
    model = model.to(device)
    
    elapsed_time = time.time() - start_time
    logger.info(f"Successfully loaded model on TPU (took {elapsed_time:.2f} seconds)")
    return model


def initialize_special_token_embeddings(model, tokenizer, special_tokens):
    """
    Initialize embeddings for special tokens using semantically similar words.
    
    Args:
        model: The loaded model
        tokenizer: The tokenizer with special tokens
        special_tokens (dict): Dictionary of special tokens and their embeddings
    
    Returns:
        dict: Updated special_tokens dictionary with token IDs and embeddings
    """
    logger.info("Initializing special token embeddings...")
    
    embedding_matrix = model.get_input_embeddings().weight  # (vocab_size, hidden_size)
    
    for token_name, token_info in special_tokens.items():
        # ID assigned to the token by tokenizer.add_tokens()
        token_id = tokenizer.convert_tokens_to_ids(token_info['token'])
        # Phrase token IDs without the BOS token the Llama tokenizer prepends
        similar_word_ids = tokenizer(token_info['replace_embedding_with'], add_special_tokens=False)['input_ids']
        
        with torch.no_grad():
            # Average the phrase embeddings on CPU, then copy the result back to the TPU
            new_embedding = embedding_matrix[similar_word_ids].cpu().mean(dim=0)
            embedding_matrix[token_id] = new_embedding.to(embedding_matrix.device)
        
        special_tokens[token_name]['new_embedding'] = new_embedding
        special_tokens[token_name]['token_id'] = token_id
        
        logger.info(f"Initialized embedding for {token_info['token']} (ID: {token_id})")
    
    return special_tokens


def create_lora_config(rank=8, alpha=16, dropout=0.1):
    """
    Create a LoRA configuration for the attention query and value projections.
    
    Args:
        rank (int): Rank of the low-rank matrices
        alpha (int): Scaling factor for the low-rank updates
        dropout (float): Dropout probability for regularization
    
    Returns:
        LoraConfig: Configuration for LoRA training
    """
    logger.info(f"Creating LoRA configuration (rank={rank}, alpha={alpha}, dropout={dropout})...")
    
    # Nothing in this configuration is TPU-specific
    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        target_modules=["q_proj", "v_proj"],  # Query and value attention projections
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM"
    )


def apply_lora_adapters(model, lora_config):
    """
    Apply LoRA adapters to the model for parameter-efficient fine-tuning.
    
    Args:
        model: The loaded model
        lora_config: LoRA configuration
    
    Returns:
        model: Model with LoRA adapters
    """
    logger.info("Applying LoRA adapters to the model...")
    
    # Apply LoRA adapters to the model
    model = get_peft_model(model, lora_config)
    
    # Print trainable parameter statistics
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params/total_params:.3f}% of total)")
    
    return model


def parse_conversations(json_data):
    """
    Parse conversation data from different dataset formats.
    
    Args:
        json_data (dict): Data entry from dataset
        
    Returns:
        list: List of (speaker, message) tuples representing the conversation
    """
    if 'prompt' in json_data and 'response' in json_data:
        return [('human', json_data['prompt']), ('assistant', json_data['response'])]
    
    conversations = json_data['text'].split('###')
    
    convos = []
    for convo in conversations:
        if convo.strip():
            parts = convo.strip().split(': ', 1)
            if len(parts) == 2:
                speaker, message = parts
                convos.append((speaker.lower(), message))
    
    return convos


def load_and_merge_datasets(tokenizer, max_length=256, test_size=0.2, seed=42, max_samples=2000):
    """
    Load, parse, and merge multiple conversation datasets with TPU optimizations.
    
    Args:
        tokenizer: Tokenizer for processing the text
        max_length (int): Maximum sequence length
        test_size (float): Proportion of data to use for testing
        seed (int): Random seed for reproducibility
        max_samples (int): Maximum number of samples to use
    
    Returns:
        dataset_dict: Dictionary containing 'train' and 'test' datasets
    """
    logger.info(f"Loading and merging datasets (max_samples={max_samples})...")
    start_time = time.time()
    
    # Load the datasets with streaming
    logger.info("Loading openassistant-guanaco dataset...")
    assistant_dataset = load_dataset("timdettmers/openassistant-guanaco", streaming=True)
    
    logger.info("Loading guanaco dataset...")
    original_guanaco_dataset = load_dataset("guanaco/guanaco", streaming=True)
    
    # Convert streaming datasets to lists with limited samples
    logger.info("Processing openassistant examples...")
    assistant_data = []
    for i, example in tqdm(enumerate(assistant_dataset['train']), 
                           desc="Loading openassistant examples", 
                           total=max_samples // 2):
        if i >= max_samples // 2:
            break
        assistant_data.append({'conversation': parse_conversations(example)})
    
    logger.info("Processing guanaco examples...")
    guanaco_data = []
    for i, example in tqdm(enumerate(original_guanaco_dataset['train']), 
                           desc="Loading guanaco examples", 
                           total=max_samples // 2):
        if i >= max_samples // 2:
            break
        guanaco_data.append({'conversation': parse_conversations(example)})
    
    logger.info(f"Loaded {len(assistant_data)} assistant examples and {len(guanaco_data)} guanaco examples")
    
    # Combine the datasets
    combined_data = assistant_data + guanaco_data
    
    # Shuffle the data
    import random
    random.seed(seed)
    random.shuffle(combined_data)
    
    # Create train/test split
    split_idx = int(len(combined_data) * (1 - test_size))
    train_data = combined_data[:split_idx]
    test_data = combined_data[split_idx:]
    
    logger.info(f"Split into {len(train_data)} training examples and {len(test_data)} test examples")
    
    # Filter to keep only valid conversations
    logger.info("Filtering conversations...")
    train_data = [ex for ex in train_data if 
                  ex['conversation'] and 
                  len(ex['conversation']) % 2 == 0 and 
                  ex['conversation'][-1][0] == 'assistant']
    
    test_data = [ex for ex in test_data if 
                ex['conversation'] and 
                len(ex['conversation']) % 2 == 0 and 
                ex['conversation'][-1][0] == 'assistant']
    
    logger.info(f"Filtered to {len(train_data)} training examples and {len(test_data)} test examples")
    
    # Format conversations with special tokens
    def join_conversation(example):
        convo_text = ""
        last_speaker = None
        
        for speaker, message in example['conversation']:
            last_speaker = speaker
            if speaker == 'human':
                convo_text += f"###HUMAN###{message}"
            elif speaker == 'assistant':
                convo_text += f"###BOT###{message}"
        
        if last_speaker == 'human':
            convo_text = convo_text.strip() + "###BOT###"
        else:
            convo_text = convo_text.strip() + "###STOP###"
        
        return {"text": convo_text}
    
    logger.info("Formatting conversations with special tokens...")
    train_formatted = [join_conversation(ex) for ex in tqdm(train_data, desc="Formatting train data")]
    test_formatted = [join_conversation(ex) for ex in tqdm(test_data, desc="Formatting test data")]
    
    # Create datasets
    train_dataset = Dataset.from_dict({"text": [ex['text'] for ex in train_formatted]})
    test_dataset = Dataset.from_dict({"text": [ex['text'] for ex in test_formatted]})
    
    # Tokenize the datasets in batches
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length
        )
    
    # Apply tokenization
    logger.info(f"Tokenizing datasets (max_length={max_length})...")
    train_tokenized = train_dataset.map(
        tokenize_function, 
        batched=True, 
        batch_size=32,
        desc="Tokenizing training data"
    )
    test_tokenized = test_dataset.map(
        tokenize_function, 
        batched=True, 
        batch_size=32,
        desc="Tokenizing test data"
    )
    
    # Create a combined dataset dictionary
    tokenized_dataset = {
        "train": train_tokenized,
        "test": test_tokenized
    }
    
    elapsed_time = time.time() - start_time
    logger.info(f"Dataset preparation completed in {elapsed_time:.2f} seconds")
    
    return tokenized_dataset


def create_training_args(output_dir, batch_size=4, gradient_accumulation_steps=8, 
                         learning_rate=1e-4, num_epochs=1):
    """
    Create TPU-compatible training arguments with checkpointing.
    
    Args:
        output_dir (str): Directory to save model checkpoints
        batch_size (int): Batch size for TPU
        gradient_accumulation_steps (int): Steps to accumulate gradients
        learning_rate (float): Learning rate
        num_epochs (float): Number of training epochs
    
    Returns:
        TrainingArguments: TPU-compatible training arguments
    """
    logger.info(f"Creating training arguments (batch_size={batch_size}, lr={learning_rate}, epochs={num_epochs})...")
    
    return TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        # Enable checkpoint saving
        save_strategy="steps",       # Save at regular steps
        save_steps=200,              # Save every 200 steps
        save_total_limit=3,          # Keep only the 3 most recent checkpoints
        # Other parameters
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        logging_steps=10,
        report_to="none",
        # TPU-specific settings
        bf16=True,  # Use bfloat16 precision (native on TPUs)
        label_names=["labels"],  # Set label names explicitly
    )


def setup_trainer(model, training_args, tokenized_dataset, tokenizer):
    """
    Set up the trainer with TPU optimizations.
    
    Args:
        model: The model to train
        training_args: Training arguments
        tokenized_dataset: Tokenized dataset
        tokenizer: Tokenizer
    
    Returns:
        Trainer: Configured trainer
    """
    logger.info("Setting up trainer...")
    
    # Create a data collator
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )
    
    # Calculate estimated training time
    num_examples = len(tokenized_dataset["train"])
    steps_per_epoch = num_examples // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)
    total_steps = steps_per_epoch * training_args.num_train_epochs
    # Estimated time: ~3 seconds per step for TPU (rough estimate)
    estimated_time_seconds = total_steps * 3
    estimated_time = str(datetime.timedelta(seconds=int(estimated_time_seconds)))
    
    logger.info(f"Training will run for approximately {total_steps} steps")
    logger.info(f"Estimated training time: {estimated_time}")
    
    # Create the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        data_collator=data_collator,
    )
    
    # Import before the class statement so the base class is defined when the class is created
    from transformers import TrainerCallback
    
    # Custom callback: report progress, loss, and remaining time whenever the Trainer logs
    class LoggingCallback(TrainerCallback):
        def __init__(self):
            self.start_time = time.time()
        
        def on_log(self, args, state, control, logs=None, **kwargs):
            # Skip logs without a training loss (e.g., evaluation or final summary logs)
            if not logs or "loss" not in logs or state.global_step == 0:
                return
            elapsed = time.time() - self.start_time
            steps_per_second = state.global_step / elapsed
            remaining_steps = state.max_steps - state.global_step
            remaining_time = remaining_steps / steps_per_second
            remaining_time_str = str(datetime.timedelta(seconds=int(remaining_time)))
            
            progress = state.global_step / state.max_steps * 100
            logger.info(f"Progress: {progress:.2f}% - Step: {state.global_step}/{state.max_steps}")
            logger.info(f"Loss: {logs['loss']:.4f}")
            logger.info(f"Estimated time remaining: {remaining_time_str}")
    
    trainer.add_callback(LoggingCallback())
    
    return trainer


def test_model(model, tokenizer, test_prompt):
    """
    Test the fine-tuned model with a simple prompt.
    
    Args:
        model: The fine-tuned model
        tokenizer: The tokenizer
        test_prompt (str): Prompt to test
    """
    logger.info(f"Testing model with prompt: \"{test_prompt}\"")
    
    # Tokenize the prompt
    inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
    
    # Generate text
    logger.info("Generating response...")
    start_time = time.time()
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
        )
    
    elapsed_time = time.time() - start_time
    
    # Decode the generated text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    logger.info(f"Generation completed in {elapsed_time:.2f} seconds")
    logger.info(f"Generated response: {generated_text}")
    
    return generated_text


def find_latest_checkpoint(output_dir):
    """
    Find the latest checkpoint in the output directory.
    
    Args:
        output_dir (str): Directory to search for checkpoints
        
    Returns:
        str or None: Path to the latest checkpoint, or None if no checkpoints found
    """
    if not os.path.exists(output_dir):
        return None
    
    checkpoints = [
        os.path.join(output_dir, d) for d in os.listdir(output_dir)
        if os.path.isdir(os.path.join(output_dir, d)) and "checkpoint" in d
    ]
    
    if not checkpoints:
        return None
    
    # Sort checkpoints by the step number (extracted from the directory name)
    checkpoints = sorted(
        checkpoints,
        key=lambda x: int(x.split("-")[-1]) if x.split("-")[-1].isdigit() else 0
    )
    
    logger.info(f"Found checkpoint: {checkpoints[-1]}")
    return checkpoints[-1]


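def save_notebook_state(path=os.path.join(KAGGLE_WORKING_DIR, "session_history.ipynb")):
    """Best-effort export of the IPython input history to a notebook file (the filename is arbitrary)."""
    try:
        from IPython import get_ipython
        ipython = get_ipython()
        if ipython is not None:
            logger.info("Exporting notebook history...")
            # %notebook writes the current session's input history to the given .ipynb file
            ipython.run_line_magic("notebook", path)
            logger.info(f"Notebook history exported to {path}")
    except Exception:
        logger.warning("Could not export notebook history")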


def main():
    """
    Main function with TPU optimizations and checkpointing support.
    """
    logger.info("="*50)
    logger.info("Starting LoRA fine-tuning process")
    logger.info("="*50)
    
    # Create output directory if it doesn't exist
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        logger.info(f"Created output directory: {OUTPUT_DIR}")
    
    # Define model and special tokens
    model_name = "meta-llama/Meta-Llama-3.1-8B"
    
    # Define special tokens
    SPECIAL_TOKENS = {
        'stop_token': {
            'token': '###STOP###',
            'replace_embedding_with': 'stop_talking'
        },
        'human_token': {
            'token': '###HUMAN###',
            'replace_embedding_with': 'human_speaking'
        },
        'bot_token': {
            'token': '###BOT###',
            'replace_embedding_with': 'assistant_speaking'
        }
    }
    
    # Check for existing checkpoints
    resume_from_checkpoint = find_latest_checkpoint(OUTPUT_DIR)
    has_existing_model = os.path.exists(os.path.join(OUTPUT_DIR, "adapter_config.json"))
    
    if resume_from_checkpoint:
        logger.info(f"Will resume training from checkpoint: {resume_from_checkpoint}")
    elif has_existing_model:
        logger.info(f"Found existing model at {OUTPUT_DIR}, will continue training")
    else:
        logger.info("No checkpoint found, starting new training run")
    
    # Initialize TPU
    logger.info("Initializing TPU...")
    device = xm.xla_device()
    logger.info(f"Using device: {device}")
    
    # Set up authentication
    hf_token = setup_authentication()
    
    # Load tokenizer with special tokens
    tokenizer = load_tokenizer(model_name, SPECIAL_TOKENS)
    
    # Load model based on checkpoint status
    if resume_from_checkpoint or has_existing_model:
        logger.info("Loading base model...")
        # First load the base model and resize it to match the extended tokenizer
        base_model = load_tpu_model(model_name, hf_token)
        base_model.resize_token_embeddings(len(tokenizer))
        # Re-apply the same deterministic special-token initialization as a fresh run,
        # since LoRA on q_proj/v_proj does not train the embeddings
        SPECIAL_TOKENS = initialize_special_token_embeddings(base_model, tokenizer, SPECIAL_TOKENS)
        
        # Then load LoRA adapters; is_trainable=True keeps them trainable
        # (by default PeftModel.from_pretrained loads adapters frozen, for inference)
        adapter_path = resume_from_checkpoint or OUTPUT_DIR
        logger.info(f"Loading LoRA adapters from: {adapter_path}")
        model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True)
            
        logger.info("Successfully loaded model with adapters")
    else:
        # Start fresh training
        logger.info("Starting fresh training run")
        
        # Load model for TPU
        model = load_tpu_model(model_name, hf_token)
        
        # Resize token embeddings and initialize special tokens
        model.resize_token_embeddings(len(tokenizer))
        SPECIAL_TOKENS = initialize_special_token_embeddings(model, tokenizer, SPECIAL_TOKENS)
        
        # Create and apply LoRA configuration
        lora_config = create_lora_config()
        model = apply_lora_adapters(model, lora_config)
    
    # Load and merge datasets
    tokenized_dataset = load_and_merge_datasets(tokenizer)
    
    # Create TPU-optimized training arguments
    training_args = create_training_args(OUTPUT_DIR)
    
    # Set up trainer
    trainer = setup_trainer(model, training_args, tokenized_dataset, tokenizer)
    
    # Train the model
    logger.info("="*50)
    logger.info("Starting training process")
    logger.info("="*50)
    
    try:
        # Start or resume training
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        logger.info("Training completed successfully!")
        
        # Save notebook state
        save_notebook_state()
    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        # Save model and state even when interrupted
        logger.info("Saving current model state...")
        model.save_pretrained(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)
        save_notebook_state()
        logger.info("Current state saved")
        return
    except Exception as e:
        logger.error(f"Training error: {str(e)}")
        # Try to save model state on error
        try:
            logger.info("Attempting to save current model state...")
            model.save_pretrained(OUTPUT_DIR)
            tokenizer.save_pretrained(OUTPUT_DIR)
            save_notebook_state()
            logger.info("Current state saved despite error")
        except Exception as save_error:
            logger.error(f"Failed to save model state: {save_error}")
        
        logger.info("Training stopped due to an error. You can resume from the latest checkpoint.")
        return
    
    # Save the fine-tuned model
    logger.info("Saving final model...")
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    logger.info(f"Model saved to {OUTPUT_DIR}")
    
    # Test the model
    logger.info("Testing the fine-tuned model...")
    test_prompt = "What is machine learning?"
    test_model(model, tokenizer, test_prompt)
    
    logger.info("="*50)
    logger.info("LoRA fine-tuning process completed")
    logger.info("="*50)


if __name__ == "__main__":
    main()

# Exploring individual parts of the above code

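# Setup
Request access to the gated Llama 3 8B model here: https://huggingface.co/meta-llama/Meta-Llama-3-8B. The code below loads `meta-llama/Meta-Llama-3.1-8B`, which is gated separately, so request access to that repository as well. You can create a Hugging Face access token at https://huggingface.co/settings/tokens. In Kaggle, open **Add-ons > Secrets** from the top menu and add a secret with a key name (this notebook uses `HF_KEY`) and your token as the value; Kaggle then generates a code snippet that you can copy to the clipboard.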
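If you run the same notebook outside Kaggle, `kaggle_secrets` will not be importable. A minimal sketch of a helper that prefers the Kaggle secret and otherwise falls back to the `HF_TOKEN` environment variable (which `huggingface_hub` also reads on its own) could look like this; `get_hf_token` is a hypothetical name, and `HF_KEY` is simply the secret name used in this notebook.

```python
import os


def get_hf_token(secret_name: str = "HF_KEY") -> str:
    """Hypothetical helper: return a Hugging Face token from Kaggle Secrets or HF_TOKEN."""
    try:
        # kaggle_secrets is only available inside Kaggle notebooks
        from kaggle_secrets import UserSecretsClient
    except ImportError:
        # Outside Kaggle, fall back to the standard HF_TOKEN environment variable
        token = os.environ.get("HF_TOKEN")
        if not token:
            raise RuntimeError("No token found: add a Kaggle secret or set HF_TOKEN")
        return token
    return UserSecretsClient().get_secret(secret_name)
```

The token can then be passed on as before, e.g. `login(token=get_hf_token())`.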

In [3]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("HF_KEY")

In [4]:
from huggingface_hub import login

# Log in with the token read from the Kaggle secret above
login(token = secret_value_0)

In [5]:
from datasets import Dataset, load_dataset

In [6]:
# Load the two conversational datasets that we will merge
assistant_dataset = load_dataset("timdettmers/openassistant-guanaco")
original_guanaco_dataset = load_dataset("guanaco/guanaco")

In [7]:
assistant_dataset['train'][0]

{'text': '### Human: Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.### Assistant: "Monopsony" refers to a market structure where there is only one buyer for a particular good or service. In economics, this term is particularly relevant in the labor market, where a monopsony employer has significant power over the wages and working conditions of their employees. The presence of a monopsony can result in lower wages and reduced employment opportunities for workers, as the employer has little incentive to increase wages or provide better working conditions.\n\nRecent research has identified potential monopsonies in industries such as retail and fast food, where a few large companies control a significant portion of the market (Bivens & Mishel, 2013). In these industries, workers often face low wages, limited benefits, and reduced bargaining po

In [8]:
original_guanaco_dataset['train'][0]

{'text': "Describe the following:\nA tall, thin man with dark hair styled in a slick pompadour.### Response:\nThe man is quite tall with a slender physique. His dark hair is styled in a sleek, classic pompadour with a slight sheen to it, suggesting perhaps he spends time grooming it. The hair appears to be well-maintained, as though he's put some effort into it. The dark color of his hair contrasts with his pale skin, which gives him a somewhat striking appearance. He conveys a sense of poise and confidence, as though he's comfortable in his own skin. Overall, he seems to exude a sophisticated, polished vibe.",
 'prompt': 'Describe the following:\nA tall, thin man with dark hair styled in a slick pompadour.',
 'response': "The man is quite tall with a slender physique. His dark hair is styled in a sleek, classic pompadour with a slight sheen to it, suggesting perhaps he spends time grooming it. The hair appears to be well-maintained, as though he's put some effort into it. The dark col

# Parse Conversations

Let's standardize each conversation into a list of (speaker, message) pairs:
```
[
('human', 'human utterance 1'),
('assistant', 'assistant utterance 1'),
('human', 'human utterance 2'),
('assistant', 'assistant utterance 2')
]
```

In [9]:
# Parse one example into a list of (speaker, message) tuples
def parse_conversations(json_data):
    # guanaco/guanaco rows already provide separate prompt/response fields
    if 'prompt' in json_data and 'response' in json_data:
        return [('human', json_data['prompt']), ('assistant', json_data['response'])]

    # openassistant-guanaco rows hold the whole dialogue in 'text', with each turn
    # introduced by '### Human:' or '### Assistant:', so split on '###' to separate turns
    turns = json_data['text'].split('###')

    convos = []
    for turn in turns:
        if turn.strip():  # skip empty or whitespace-only pieces
            parts = turn.strip().split(': ', 1)  # split on the first ': ' only
            if len(parts) == 2:
                speaker, message = parts
                convos.append((speaker.lower(), message))
    return convos

# Parse the conversations
print(parse_conversations(assistant_dataset['train'][0]))
print(parse_conversations(original_guanaco_dataset['train'][0]))
            

[('human', 'Can you write a short introduction about the relevance of the term "monopsony" in economics? Please use examples related to potential monopsonies in the labour market and cite relevant research.'), ('assistant', '"Monopsony" refers to a market structure where there is only one buyer for a particular good or service. In economics, this term is particularly relevant in the labor market, where a monopsony employer has significant power over the wages and working conditions of their employees. The presence of a monopsony can result in lower wages and reduced employment opportunities for workers, as the employer has little incentive to increase wages or provide better working conditions.\n\nRecent research has identified potential monopsonies in industries such as retail and fast food, where a few large companies control a significant portion of the market (Bivens & Mishel, 2013). In these industries, workers often face low wages, limited benefits, and reduced bargaining power, 
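The `split('###')` approach truncates any message that itself contains `###` (for example a Markdown heading inside an answer): the text after the inner `###` has no `Speaker: ` prefix and is silently dropped, and a piece such as `### Note: ...` would be misread as a speaker named `note`. A sketch of an alternative that splits only on the speaker markers, assuming every turn in `timdettmers/openassistant-guanaco` starts with exactly `### Human:` or `### Assistant:`; `SPEAKER_RE` and `parse_turns` are hypothetical names, not part of the notebook.

```python
import re

# Assumed turn markers: "### Human:" and "### Assistant:" (optional spaces after ###)
SPEAKER_RE = re.compile(r"###\s*(Human|Assistant):\s*")


def parse_turns(text):
    """Split a guanaco-style transcript into (speaker, message) tuples."""
    # With one capture group, re.split returns [prefix, speaker1, msg1, speaker2, msg2, ...]
    parts = SPEAKER_RE.split(text)
    return [
        (speaker.lower(), message.strip())
        for speaker, message in zip(parts[1::2], parts[2::2])
    ]


sample = "### Human: How do I write a Markdown heading?### Assistant: Start the line with ### and a space, e.g. ### Title"
print(parse_turns(sample))
```

On this sample the original function would keep only `'Start the line with'` as the assistant message, while the regex version keeps the full answer.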

In [10]:
assistant_dataset = assistant_dataset.map(lambda x: {'conversation': parse_conversations(x)})
assistant_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'conversation'],
        num_rows: 9846
    })
    test: Dataset({
        features: ['text', 'conversation'],
        num_rows: 518
    })
})

In [11]:
original_guanaco_dataset = original_guanaco_dataset.map(lambda x: {'conversation': parse_conversations(x)})
original_guanaco_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'prompt', 'response', 'conversation'],
        num_rows: 48701
    })
})

In [12]:
# combine the datasets together
dataset = Dataset.from_dict({'conversation': list(assistant_dataset['train']['conversation']) + list(original_guanaco_dataset['train']['conversation'])})
dataset = dataset.train_test_split(test_size = 0.2, seed = 42)
dataset

DatasetDict({
    train: Dataset({
        features: ['conversation'],
        num_rows: 46837
    })
    test: Dataset({
        features: ['conversation'],
        num_rows: 11710
    })
})

In [13]:
# Let's define some special tokens! We aren't using Meta's Llama 3 chat-template tokens because the base (non-chat-aligned) model was never trained to use them
STOP_TOKEN = '###STOP###'  # any string not already in the vocabulary works; we use STOP
HUMAN_TOKEN = '###HUMAN###'
BOT_TOKEN = '###BOT###'

# Define the extra tokens dictionary with all three tokens
EXTRA_TOKENS = {
    'stop_token': {
        'token': STOP_TOKEN,
        'replace_embedding_with': 'stop_talking'
    },
    'human_token': {
        'token': HUMAN_TOKEN,
        'replace_embedding_with': 'human_speaking'
    },
    'bot_token': {
        'token': BOT_TOKEN,
        'replace_embedding_with': 'assistant_speaking'
    }
}

In [14]:
# Use the base (non-chat) Llama 3.1 model from Hugging Face to show that we
# don't need to start from the chat checkpoint
from transformers import AutoTokenizer

base_model = 'meta-llama/Meta-Llama-3.1-8B'

# Load the Llama tokenizer and add the three special tokens
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code = True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'  # not required for training; left padding matters for batched generation

tokenizer.add_tokens([extra['token'] for extra in EXTRA_TOKENS.values()])

3

In [1]:
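# Load the base model with 4-bit NF4 quantization (bitsandbytes) and bfloat16 compute.
# Note: bitsandbytes quantization primarily targets CUDA GPUs; it does not run on TPUs.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.bfloat16,
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type='nf4',
    ),
)

# LoRA adapters are attached later with get_peft_model(), after the token
# embeddings have been resized and the special-token rows initialized.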
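To see why the 4-bit load matters, a rough weights-only estimate helps. This sketch assumes about 8.03 billion parameters for Llama 3.1 8B and the roughly 0.127 extra bits per parameter that the QLoRA paper reports for double-quantized NF4 block constants. It ignores activations, gradients, optimizer state, and the fact that embeddings and `lm_head` usually stay in 16-bit, so treat the 4-bit figure as a lower bound rather than a measurement.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 8.03e9           # approximate parameter count of Llama 3.1 8B (assumption)
NF4_OVERHEAD_BITS = 0.127   # double-quantized block constants per parameter (QLoRA paper)

bf16_gb = weight_memory_gb(N_PARAMS, 16)
nf4_gb = weight_memory_gb(N_PARAMS, 4 + NF4_OVERHEAD_BITS)

print(f"bf16 weights: {bf16_gb:.1f} GB")
print(f"NF4 weights: ~{nf4_gb:.1f} GB")
```

The roughly fourfold reduction is what lets an 8B model's frozen weights fit on a single mid-range GPU, leaving room for the small LoRA adapters and their optimizer state.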

In [None]:
print(model.config.vocab_size)

In [None]:
# Resize the model's embeddings to match the tokenizer (original vocabulary + 3 new special tokens)
model.resize_token_embeddings(len(tokenizer))
print(model.config.vocab_size)

In [None]:
# Custom Token Embedding Initialization
# ------------------------------------
# New tokens added to a pre-trained model need sensible embeddings. The LoRA adapters
# used here only target the attention projections, so the embedding matrix stays frozen
# during training. Each new token's embedding is therefore set up front to the mean
# embedding of the sub-word tokens of a descriptive phrase (e.g. 'stop_talking').
for extra_token, extra_info in EXTRA_TOKENS.items():
    # The tokenizer prepends BOS, so the new token's id is the last one
    token_id = tokenizer(extra_info['token'])['input_ids'][-1]
    # Skip BOS ([1:]) and average the embeddings of the phrase's sub-word tokens
    replace_ids = tokenizer(extra_info['replace_embedding_with'])['input_ids'][1:]
    new_embedding = model.model.embed_tokens.weight.data[replace_ids].mean(dim=0, keepdim=True)
    EXTRA_TOKENS[extra_token]['new_embedding'] = new_embedding
    model.model.embed_tokens.weight.data[token_id] = new_embedding.clone()
    EXTRA_TOKENS[extra_token]['token_id'] = token_id
    print(f"Replaced token \"{extra_info['token']}\" (token_id {token_id}) weight with weight for \"{extra_info['replace_embedding_with']}\"")
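The loop above only initializes the input embedding rows. Because the model must also *predict* `###STOP###` to end its turn, the matching rows of the output projection (`lm_head`) matter too, and with LoRA restricted to the attention projections they are never trained. A sketch that initializes both, assuming a Hugging Face causal LM exposing the standard `get_input_embeddings()` / `get_output_embeddings()` accessors, a tokenizer that already contains the new tokens, and an `lm_head` that was left unquantized (transformers' bitsandbytes integration normally keeps it in 16-bit); `init_special_token_rows` is a hypothetical name.

```python
import torch


@torch.no_grad()  # in-place edits of leaf parameters require autograd to be off
def init_special_token_rows(model, tokenizer, extra_tokens):
    """Hypothetical helper: set input AND output embedding rows for new tokens.

    Each new token gets the mean of the rows of the sub-word tokens spelling its
    'replace_embedding_with' phrase. Call after model.resize_token_embeddings().
    """
    in_weight = model.get_input_embeddings().weight
    out_weight = model.get_output_embeddings().weight
    for info in extra_tokens.values():
        token_id = tokenizer.convert_tokens_to_ids(info["token"])
        # add_special_tokens=False keeps BOS out of the average
        src_ids = tokenizer(info["replace_embedding_with"], add_special_tokens=False)["input_ids"]
        in_weight[token_id] = in_weight[src_ids].mean(dim=0)
        out_weight[token_id] = out_weight[src_ids].mean(dim=0)
        info["token_id"] = token_id
    return extra_tokens
```

In this notebook it would be called as `init_special_token_rows(model, tokenizer, EXTRA_TOKENS)` before the LoRA adapters are attached.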

In [None]:
# Let's take a look at our data
import pandas as pd
text_pd = pd.concat([pd.DataFrame(dataset['train']), pd.DataFrame(dataset['test'])])
text_pd['split'] = ['train'] * len(dataset['train']) + ['test'] * len(dataset['test'])
text_pd['last_speaker'] = text_pd['conversation'].apply(lambda x: x[-1][0])
text_pd['convo_length'] = text_pd['conversation'].apply(len)
text_pd = text_pd[text_pd['last_speaker'].isin(['assistant', 'human'])]

In [None]:
# We want conversations where the assistant speaks last; those ending with a human turn are less useful
text_pd['last_speaker'].value_counts()

In [None]:
# There are multiple languages in the dataset
text_pd.head(2)

In [None]:
# Keep only conversations that end with the assistant and contain an even number
# of turns (complete human/assistant exchanges)
text_pd = text_pd[text_pd['last_speaker'] == 'assistant']
text_pd = text_pd[text_pd['convo_length'] % 2 == 0]
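In the cells shown here, this filtering only touches the exploratory `text_pd` DataFrame, not the Hugging Face `dataset` that is later serialized with `join_convo` for training. A minimal sketch of the rule as a predicate for `Dataset.filter`; `is_complete_exchange` is a hypothetical name, and unlike the pandas version it also requires the turns to strictly alternate human/assistant, which catches badly parsed examples.

```python
def is_complete_exchange(example):
    """Keep conversations of alternating human/assistant turns that end with the assistant."""
    turns = example["conversation"]
    if not turns or len(turns) % 2 != 0:
        return False
    expected = ["human", "assistant"] * (len(turns) // 2)
    return [speaker for speaker, _ in turns] == expected
```

Applied to the merged data, this would be `dataset = dataset.filter(is_complete_exchange)`.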

In [None]:
text_pd = text_pd.reset_index()

In [None]:
text_pd['convo_length'].value_counts().sort_index().plot.bar()

# Filter by language

In [None]:
FILTER_EN = False

In [None]:
from transformers import pipeline
import pandas as pd
from tqdm import tqdm

In [None]:
# Detect the language of each conversation's first utterance, batch by batch
def detect_language_in_batches(batch, batch_size = 128):
    # Use the first message of each conversation (empty string if parsing found no turns)
    first_messages = [convo[0][1] if convo else '' for convo in batch['conversation']]
    batch_results = pipe(first_messages, batch_size = batch_size)
    return {'lang': [result['label'] for result in batch_results]}

# Apply the function to the Hugging Face dataset
if FILTER_EN:
    pipe = pipeline("text-classification", model = "papluca/xlm-roberta-base-language-detection", truncation = True, max_length = 64)
    dataset = dataset.map(detect_language_in_batches, batch_size = 128, batched = True)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Set a professional style
sns.set_style("whitegrid")
plt.rcParams['font.family'] = 'sans-serif'

# Only create the language chart if we've filtered for language
if FILTER_EN:
    # Get language distribution
    language_counts = pd.Series(dataset['train']['lang']).value_counts()
    
    # Create the visualization
    plt.figure(figsize=(10, 6))
    ax = language_counts.plot.bar(
        color=sns.color_palette("viridis", len(language_counts)),
        edgecolor='black',
        width=0.7
    )
    
    # Enhance with labels and styling
    plt.title('Distribution of Languages in Training Dataset', fontsize=16, pad=20)
    plt.xlabel('Language', fontsize=12)
    plt.ylabel('Number of Examples', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    
    # Add value labels on top of bars
    for i, v in enumerate(language_counts):
        ax.text(
            i, 
            v + (language_counts.max() * 0.02),  # Slight offset above bar
            f'{v:,}',  # Format with commas for thousands
            ha='center',
            fontsize=10
        )
    
    # Adjust layout and display
    plt.tight_layout()
    plt.show()

In [None]:
if FILTER_EN:
    dataset = dataset.filter(lambda x: x['lang'] == 'en')

In [None]:
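def join_convo(conversation):
    """Serialize (speaker, message) turns into one training string using the special tokens.

    A conversation that ends with a human turn gets a trailing ###BOT### marker (a prompt
    for the model to answer); otherwise it is closed with ###STOP###.
    """
    convo = ''
    last_speaker = None
    for speaker, message in conversation:
        last_speaker = speaker
        if speaker == 'human':
            convo += f"{EXTRA_TOKENS['human_token']['token']}{message}"
        elif speaker == 'assistant':
            convo += f"{EXTRA_TOKENS['bot_token']['token']}{message}"
        
    if last_speaker == 'human':
        return convo.strip() + f"{EXTRA_TOKENS['bot_token']['token']}"
    return convo.strip() + f"{EXTRA_TOKENS['stop_token']['token']}"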

In [None]:
print(join_convo([['human', 'who was the first president of the USA?']]))
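At inference time the same markers are used in reverse: the prompt ends with `###BOT###`, and the model's continuation should be cut at the first `###STOP###`, or at a `###HUMAN###` if the model starts writing the next user turn itself. A minimal post-processing sketch; `extract_reply` is a hypothetical helper, and it assumes the decoded text still contains the marker strings.

```python
STOP_TOKEN = "###STOP###"
HUMAN_TOKEN = "###HUMAN###"
BOT_TOKEN = "###BOT###"


def extract_reply(generated: str) -> str:
    """Return the assistant's reply that follows the last ###BOT### marker."""
    # The prompt may contain earlier turns, so keep only the text after the final bot marker
    reply = generated.rsplit(BOT_TOKEN, 1)[-1]
    # Cutting at each marker in turn leaves the text before whichever appears first
    for marker in (STOP_TOKEN, HUMAN_TOKEN):
        reply = reply.split(marker, 1)[0]
    return reply.strip()


print(extract_reply("###HUMAN###who was the first president of the USA?###BOT###George Washington.###STOP###"))
```

To stop generation early instead of trimming afterwards, one option is to pass the STOP token's id (stored in `EXTRA_TOKENS['stop_token']['token_id']`) as `eos_token_id` to `model.generate`.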

In [1]:
# # First, install the correct version of bitsandbytes
# !pip install -q "bitsandbytes>=0.39.0"  # quote the spec so the shell doesn't treat > as a redirect
# !pip install -q accelerate

# # Import necessary libraries
# import numpy as np
# import os
# import torch
# from datasets import Dataset, load_dataset
# from transformers import (
#     AutoModelForCausalLM,
#     AutoTokenizer,
#     BitsAndBytesConfig,
#     TrainingArguments,
#     Trainer,
#     pipeline,
#     logging,
#     DataCollatorForLanguageModeling
# )
# import json
# import pandas as pd
# from peft import LoraConfig, PeftModel, get_peft_model
# from tqdm import tqdm

# # Set up base model identifier
# base_model = 'meta-llama/Meta-Llama-3.1-8B'

# # Define special tokens for chat format
# STOP_TOKEN = '###STOP###'
# HUMAN_TOKEN = '###HUMAN###'
# BOT_TOKEN = '###BOT###'

# # Define the extra tokens dictionary
# EXTRA_TOKENS = {
#     'stop_token': {
#         'token': STOP_TOKEN,
#         'replace_embedding_with': 'stop_talking'
#     },
#     'human_token': {
#         'token': HUMAN_TOKEN,
#         'replace_embedding_with': 'human_speaking'
#     },
#     'bot_token': {
#         'token': BOT_TOKEN,
#         'replace_embedding_with': 'assistant_speaking'
#     }
# }

# # Load tokenizer and add special tokens
# tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = 'left'
# tokenizer.add_tokens([extra['token'] for extra in EXTRA_TOKENS.values()])

# # Check if bitsandbytes is properly installed
# try:
#     import bitsandbytes as bnb
#     print(f"bitsandbytes version: {bnb.__version__}")
# except ImportError:
#     print("bitsandbytes is not installed properly. Trying a different approach...")
#     !pip install -q bitsandbytes --force-reinstall
#     import bitsandbytes as bnb
#     print(f"bitsandbytes version after reinstall: {bnb.__version__}")

# # Now try to configure quantization
# try:
#     # Configure quantization settings
#     quantization_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.bfloat16,
#         bnb_4bit_use_double_quant=True,
#         bnb_4bit_quant_type='nf4',
#     )
    
#     # Load model with proper quantization
#     model = AutoModelForCausalLM.from_pretrained(
#         base_model,
#         quantization_config=quantization_config,
#         torch_dtype=torch.bfloat16,
#         device_map="auto",
#     )
    
#     print("Model loaded successfully with quantization!")
# except Exception as e:
#     print(f"Error when loading with 4-bit quantization: {e}")
#     print("Falling back to 8-bit quantization...")
    
#     try:
#         # Try 8-bit quantization as fallback
#         quantization_config = BitsAndBytesConfig(
#             load_in_8bit=True,
#         )
        
#         model = AutoModelForCausalLM.from_pretrained(
#             base_model,
#             quantization_config=quantization_config,
#             torch_dtype=torch.float16,
#             device_map="auto",
#         )
#         print("Model loaded successfully with 8-bit quantization!")
#     except Exception as e:
#         print(f"Error when loading with 8-bit quantization: {e}")
#         print("Attempting to load with float16 without quantization (warning: high memory usage)...")
        
#         # Last resort - load with fp16 but no quantization
#         model = AutoModelForCausalLM.from_pretrained(
#             base_model,
#             torch_dtype=torch.float16,
#             device_map="auto",
#         )
#         print("Model loaded in float16 without quantization")

# # Resize token embeddings to account for the new tokens
# model.resize_token_embeddings(len(tokenizer))

# # Initialize special token embeddings with semantically similar words
# for extra_token, extra_info in EXTRA_TOKENS.items():
#     token_id = tokenizer(extra_info['token'])['input_ids'][-1]
#     replace_token_ids = tokenizer(extra_info['replace_embedding_with'])['input_ids'][1:]
#     new_embedding = model.model.embed_tokens.weight.data[replace_token_ids].mean(dim=0, keepdim=True)
#     EXTRA_TOKENS[extra_token]['new_embedding'] = new_embedding
#     model.model.embed_tokens.weight.data[token_id] = EXTRA_TOKENS[extra_token]['new_embedding'].clone()
#     print(f"Replaced token \"{extra_info['token']}\" (token_id {token_id}) weight with weight for \"{extra_info['replace_embedding_with']}\"")

# # Configure LoRA for efficient fine-tuning
# lora_config = LoraConfig(
#     r=16,  # Rank
#     lora_alpha=32,  # Alpha parameter
#     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Target the attention modules
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
# )

# # Apply LoRA to the base model
# model = get_peft_model(model, lora_config)
# print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
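As a sanity check on the `Trainable parameters` printout, the LoRA parameter count can be worked out by hand: an adapted linear layer with `d_in` inputs and `d_out` outputs gains `A` (`r × d_in`) and `B` (`d_out × r`), i.e. `r · (d_in + d_out)` parameters. This sketch assumes the published Llama 3.1 8B dimensions (hidden size 4096, 32 layers, 8 key/value heads of dimension 128, so `k_proj` and `v_proj` produce 1024 features) and the `r=16` attention-only `LoraConfig` above.

```python
HIDDEN = 4096
N_LAYERS = 32
KV_DIM = 8 * 128  # 8 key/value heads x head_dim 128 (grouped-query attention)
RANK = 16

# (d_in, d_out) of each projection targeted by the LoraConfig above
TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}


def lora_params(rank, shapes, n_layers):
    """Parameters added by LoRA: A is (rank x d_in) and B is (d_out x rank) per module per layer."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers


trainable = lora_params(RANK, TARGETS, N_LAYERS)
print(f"LoRA trainable parameters: {trainable:,}")
print(f"Share of ~8.03B base parameters: {trainable / 8.03e9:.2%}")
```

This comes to about 13.6 million parameters, under 0.2% of the base model; PEFT's `model.print_trainable_parameters()` offers a built-in cross-check.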