In [None]:
# Google Colab Setup
import sys
import os

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("üîß Running in Google Colab - Setting up environment...")
    if not os.path.exists('transformer_from_scratch'):
        print("üì• Cloning repository...")
        !git clone https://github.com/melhzy/transformer_from_scratch.git
        print("‚úÖ Repository cloned!")
    os.chdir('transformer_from_scratch')
    print("üì¶ Installing dependencies...")
    !pip install -q torch torchvision matplotlib seaborn numpy pandas tqdm datasets transformers
    print("‚úÖ Dependencies installed!")
    if '/content/transformer_from_scratch' not in sys.path:
        sys.path.insert(0, '/content/transformer_from_scratch')
    print("‚úÖ Setup complete!")
else:
    print("üíª Running locally - no setup needed.")

In [None]:
# Import libraries
import sys
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import json
from dataclasses import dataclass
from collections import Counter

# Outside Colab, make the parent of the current working directory importable
# (presumably the repo root, with this notebook in a subdirectory — confirm).
# Guarded so re-running this cell does not stack duplicate sys.path entries.
if not IN_COLAB:
    _repo_root = str(Path.cwd().parent)
    if _repo_root not in sys.path:
        sys.path.insert(0, _repo_root)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Seed the RNGs this notebook touches (the DataLoader below uses shuffle=True)
# so Restart & Run All reproduces the same batches and printed outputs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"✅ Device: {device}")
print(f"✅ PyTorch version: {torch.__version__}")

## 1. Data Formats for Fine-Tuning 📝

### Common Formats:

1. **Prompt-Completion** (Simple)
```json
{"prompt": "Translate to French: Hello", "completion": "Bonjour"}
```

2. **Instruction Format** (Alpaca-style)
```json
{
  "instruction": "Summarize the following text",
  "input": "[long text]",
  "output": "[summary]"
}
```

3. **Chat Format** (Multi-turn)
```json
{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "4"}
  ]
}
```

### Let's create sample datasets:

In [None]:
# Sample datasets in different formats

# 1. Prompt-Completion
prompt_completion_data = [
    {"prompt": "Translate to French: Hello", "completion": "Bonjour"},
    {"prompt": "Translate to French: Thank you", "completion": "Merci"},
    {"prompt": "Translate to Spanish: Good morning", "completion": "Buenos días"},
    {"prompt": "What is the capital of France?", "completion": "The capital of France is Paris."},
]

# 2. Instruction Format (Alpaca-style)
instruction_data = [
    {
        "instruction": "Explain the concept in simple terms",
        "input": "Quantum entanglement",
        "output": "Quantum entanglement is when two particles become connected so that the state of one instantly affects the other, no matter how far apart they are."
    },
    {
        "instruction": "Write a haiku about the topic",
        "input": "Spring",
        "output": "Cherry blossoms fall,\nGentle breeze whispers softly,\nSpring awakens joy."
    },
    {
        "instruction": "Summarize this text",
        "input": "The Transformer architecture, introduced in 2017, revolutionized natural language processing by replacing recurrent connections with self-attention mechanisms.",
        "output": "Transformers (2017) replaced RNNs with self-attention for NLP tasks."
    },
]

# 3. Chat Format
chat_data = [
    {
        "messages": [
            {"role": "system", "content": "You are a helpful AI assistant specializing in math."},
            {"role": "user", "content": "What is the derivative of x^2?"},
            {"role": "assistant", "content": "The derivative of x^2 is 2x."}
        ]
    },
    {
        "messages": [
            {"role": "system", "content": "You are a coding tutor."},
            {"role": "user", "content": "How do I reverse a string in Python?"},
            {"role": "assistant", "content": "You can reverse a string using slicing: `s[::-1]` or `''.join(reversed(s))`."}
        ]
    },
]

print("✅ Sample datasets created!")
print(f"  Prompt-Completion: {len(prompt_completion_data)} examples")
print(f"  Instruction: {len(instruction_data)} examples")
print(f"  Chat: {len(chat_data)} examples")

## 2. Simple Tokenizer Implementation 🔤

For demonstration, we'll implement a simple word-level tokenizer (lower-case, split on whitespace). In production, use a subword tokenizer such as BPE or WordPiece from Hugging Face `tokenizers`.

Reference: [transformer-foundation/01_embeddings_and_positional_encoding.ipynb](../transformer-foundation/01_embeddings_and_positional_encoding.ipynb)

In [None]:
class SimpleTokenizer:
    """
    Simple word-level tokenizer for demonstration.

    Text is lower-cased and split on whitespace. Punctuation stays attached to
    words, so "hello," and "hello" are different tokens. In production, use
    Hugging Face tokenizers (BPE, WordPiece, etc.).

    Args:
        vocab_size: Maximum vocabulary size, *including* the special tokens.
            The cap holds even if ``build_vocab`` is called several times.
    """
    def __init__(self, vocab_size: int = 10000):
        self.vocab_size = vocab_size
        # Special tokens (standard for LLMs)
        self.pad_token = "<PAD>"
        self.unk_token = "<UNK>"
        self.bos_token = "<BOS>"  # Beginning of sequence
        self.eos_token = "<EOS>"  # End of sequence

        # Special tokens occupy the first ids in this order: PAD=0, UNK=1, BOS=2, EOS=3.
        special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
        self.token2id = {token: idx for idx, token in enumerate(special_tokens)}
        self.id2token = {idx: token for token, idx in self.token2id.items()}
        self.next_id = len(self.token2id)

    def build_vocab(self, texts: List[str]):
        """
        Add the most frequent words in ``texts`` to the vocabulary.

        Words are added in descending frequency order (ties keep first-seen
        order) until the vocabulary holds ``vocab_size`` tokens. Words already
        in the vocabulary are skipped and do not consume capacity.
        """
        word_freq = Counter()
        for text in texts:
            word_freq.update(text.lower().split())

        for word, _ in word_freq.most_common():
            # Stop once full; this also bounds repeated calls.
            if len(self.token2id) >= self.vocab_size:
                break
            if word in self.token2id:
                continue
            self.token2id[word] = self.next_id
            self.id2token[self.next_id] = word
            self.next_id += 1

        print(f"Built vocabulary: {len(self.token2id)} tokens")

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """Convert text to token ids; unknown words map to the <UNK> id.

        With ``add_special_tokens`` the ids are wrapped in <BOS> ... <EOS>.
        """
        unk_id = self.token2id[self.unk_token]
        ids = [self.token2id.get(word, unk_id) for word in text.lower().split()]

        if add_special_tokens:
            ids = [self.token2id[self.bos_token]] + ids + [self.token2id[self.eos_token]]

        return ids

    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        """Convert token ids back to space-joined text.

        With ``skip_special_tokens``, PAD/BOS/EOS are dropped and <UNK> is kept.
        Ids not in the vocabulary are rendered as the <UNK> string.
        """
        special_ids = {self.token2id[t] for t in [self.pad_token, self.bos_token, self.eos_token]}

        words = []
        for token_id in ids:
            if skip_special_tokens and token_id in special_ids:
                continue
            words.append(self.id2token.get(token_id, self.unk_token))

        return " ".join(words)

    @property
    def pad_token_id(self):
        return self.token2id[self.pad_token]

    @property
    def unk_token_id(self):
        return self.token2id[self.unk_token]

    @property
    def bos_token_id(self):
        return self.token2id[self.bos_token]

    @property
    def eos_token_id(self):
        return self.token2id[self.eos_token]


# Build the tokenizer vocabulary from every sample dataset. Also add the literal
# section markers that InstructionDataset wraps around each example
# ("Instruction:", "Input:", "Output:"); without them those markers lower-case
# to words absent from the corpus and always encode as <UNK>.
INSTRUCTION_TEMPLATE_MARKERS = "Instruction: Input: Output:"

all_texts = []
for item in prompt_completion_data:
    all_texts.extend([item['prompt'], item['completion']])
for item in instruction_data:
    all_texts.extend([item['instruction'], item['input'], item['output']])
for conversation in chat_data:
    all_texts.extend(message['content'] for message in conversation['messages'])
all_texts.append(INSTRUCTION_TEMPLATE_MARKERS)

tokenizer = SimpleTokenizer(vocab_size=1000)
tokenizer.build_vocab(all_texts)

# Test tokenizer
test_text = "Hello, how are you?"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)

print("\nTokenizer Test:")
print(f"  Original: {test_text}")
print(f"  Encoded: {encoded}")
print(f"  Decoded: {decoded}")
print("\n✅ Tokenizer works!")

## 3. Dataset Classes for Different Formats 🗂️

In [None]:
class PromptCompletionDataset(Dataset):
    """
    Prompt/completion pairs rendered as one causal-LM training sequence.

    Typical uses: short Q&A, translation, and similar one-shot mappings.
    Each item yields ``input_ids`` and ``labels`` of equal length. The <BOS>,
    the prompt and the separating <EOS> are masked with -100, so the loss only
    covers the completion and its closing <EOS>.
    """
    def __init__(
        self,
        data: List[Dict],
        tokenizer: SimpleTokenizer,
        max_length: int = 512
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        tok = self.tokenizer

        # Encode both halves without special tokens; we place them explicitly.
        prompt_tokens = tok.encode(example['prompt'], add_special_tokens=False)
        answer_tokens = tok.encode(example['completion'], add_special_tokens=False)
        bos = tok.token2id[tok.bos_token]
        eos = tok.token2id[tok.eos_token]

        # Layout: <BOS> prompt <EOS> completion <EOS>
        sequence = [bos, *prompt_tokens, eos, *answer_tokens, eos]

        # Ignore <BOS> + prompt + separator <EOS>; supervise the rest.
        targets = [-100] * (len(prompt_tokens) + 2)
        targets.extend(answer_tokens)
        targets.append(tok.eos_token_id)

        limit = self.max_length
        if len(sequence) > limit:
            sequence = sequence[:limit]
        if len(targets) > limit:
            targets = targets[:limit]

        return {
            'input_ids': torch.tensor(sequence, dtype=torch.long),
            'labels': torch.tensor(targets, dtype=torch.long)
        }


class InstructionDataset(Dataset):
    """
    Dataset for instruction format (Alpaca-style).
    Used for: instruction tuning, task adaptation.

    Each item needs ``instruction`` and ``output`` keys; ``input`` is optional.
    Items are rendered as

        <BOS> Instruction: {instruction} Input: {input} Output: {output} <EOS>

    The ``Input:`` section is left out when ``input`` is missing or blank,
    mirroring Alpaca's "no input" prompt variant. Only the output and the
    trailing <EOS> are supervised; every earlier position gets label -100.
    Both tensors are truncated from the right to ``max_length``.
    """
    def __init__(
        self,
        data: List[Dict],
        tokenizer: 'SimpleTokenizer',
        max_length: int = 512
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _build_prompt(item: Dict) -> str:
        """Render the prompt prefix, omitting the Input section when it is blank."""
        source = item.get('input') or ''
        if source.strip():
            return f"Instruction: {item['instruction']} Input: {source} Output:"
        return f"Instruction: {item['instruction']} Output:"

    def __getitem__(self, idx):
        item = self.data[idx]
        tok = self.tokenizer

        prompt_ids = tok.encode(self._build_prompt(item), add_special_tokens=False)
        completion_ids = tok.encode(item['output'], add_special_tokens=False)

        input_ids = (
            [tok.token2id[tok.bos_token]] +
            prompt_ids +
            completion_ids +
            [tok.token2id[tok.eos_token]]
        )

        # Mask <BOS> + prompt. Deriving labels from input_ids keeps both
        # sequences position-aligned by construction.
        n_masked = 1 + len(prompt_ids)
        labels = [-100] * n_masked + input_ids[n_masked:]

        input_ids = input_ids[:self.max_length]
        labels = labels[:self.max_length]

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }


# Test datasets
pc_dataset = PromptCompletionDataset(prompt_completion_data, tokenizer, max_length=128)
inst_dataset = InstructionDataset(instruction_data, tokenizer, max_length=128)

print("Testing datasets...\n")
print(f"Prompt-Completion Dataset: {len(pc_dataset)} examples")
sample = pc_dataset[0]
print(f"  Input IDs shape: {sample['input_ids'].shape}")
print(f"  Labels shape: {sample['labels'].shape}")
print(f"  Decoded input: {tokenizer.decode(sample['input_ids'].tolist())}")

print(f"\nInstruction Dataset: {len(inst_dataset)} examples")
sample = inst_dataset[0]
print(f"  Input IDs shape: {sample['input_ids'].shape}")
print(f"  Labels shape: {sample['labels'].shape}")

# Invariant: one label per input position. Otherwise the collator's padding
# and the loss computation would silently misalign.
for ds in (pc_dataset, inst_dataset):
    for i in range(len(ds)):
        ex = ds[i]
        assert ex['input_ids'].shape == ex['labels'].shape, f"length mismatch in example {i}"

print("\n✅ Datasets work!")

## 4. Data Collator with Padding 📦

Efficiently batch variable-length sequences.

In [None]:
@dataclass
class DataCollatorForLanguageModeling:
    """
    Collate a batch of examples into right-padded tensors.

    Loosely modeled on Hugging Face's data collators. Each example must carry
    1-D ``input_ids`` and ``labels`` tensors of the same length, as produced by
    the dataset classes above.

    Attributes:
        tokenizer: Supplies ``pad_token_id``, used to pad ``input_ids``.
        max_length: Hard cap on the padded length. Longer examples are
            truncated from the right instead of producing a negative padding size.
        pad_to_multiple_of: If set, round the batch length up to a multiple of
            this value before applying ``max_length`` (helps TPUs and GPU
            Tensor Cores).
    """
    tokenizer: 'SimpleTokenizer'
    max_length: int = 512
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Pad (and, if needed, truncate) ``examples`` to a common length.

        Returns a dict with ``input_ids`` (padded with ``pad_token_id``),
        ``labels`` (padded with -100 so padding is ignored by the loss) and
        ``attention_mask`` (1 for real tokens, 0 for padding), each of shape
        ``(batch, length)``.

        Raises:
            ValueError: if ``examples`` is empty.
        """
        if not examples:
            raise ValueError("Cannot collate an empty batch")

        # Pad to the longest example in this batch, not a global maximum.
        batch_max = max(len(ex['input_ids']) for ex in examples)

        # Optionally round up to a multiple (hardware-friendly shapes).
        if self.pad_to_multiple_of is not None:
            batch_max = (
                (batch_max + self.pad_to_multiple_of - 1)
                // self.pad_to_multiple_of
                * self.pad_to_multiple_of
            )

        batch_max = min(batch_max, self.max_length)

        input_ids = []
        labels = []
        attention_mask = []

        for ex in examples:
            # Truncate first so padding_len below can never be negative.
            ids = ex['input_ids'][:batch_max]
            lbls = ex['labels'][:batch_max]
            seq_len = len(ids)
            padding_len = batch_max - seq_len

            input_ids.append(torch.cat([
                ids,
                torch.full((padding_len,), self.tokenizer.pad_token_id, dtype=torch.long)
            ]))

            # -100 makes padded positions invisible to the loss.
            labels.append(torch.cat([
                lbls,
                torch.full((padding_len,), -100, dtype=torch.long)
            ]))

            attention_mask.append(torch.cat([
                torch.ones(seq_len, dtype=torch.long),
                torch.zeros(padding_len, dtype=torch.long)
            ]))

        return {
            'input_ids': torch.stack(input_ids),
            'labels': torch.stack(labels),
            'attention_mask': torch.stack(attention_mask)
        }


# Test data collator
collator = DataCollatorForLanguageModeling(tokenizer, max_length=128, pad_to_multiple_of=8)

# Create dataloader
dataloader = DataLoader(
    pc_dataset,
    batch_size=2,
    collate_fn=collator,
    shuffle=True
)

# Test batch
batch = next(iter(dataloader))
print("Testing Data Collator...\n")
print(f"Batch keys: {list(batch.keys())}")
print(f"Input IDs shape: {batch['input_ids'].shape}")
print(f"Labels shape: {batch['labels'].shape}")
print(f"Attention mask shape: {batch['attention_mask'].shape}")
print("\nSample attention mask:")
print(batch['attention_mask'][0])
print("\n✅ Data collator works!")

## 5. Data Quality Control 🔍

Filter and validate data quality.

In [None]:
class DataQualityChecker:
    """
    Check and filter prompt/completion examples for basic quality issues.

    Lengths are measured with ``tokenizer.encode(..., add_special_tokens=False)``
    on the raw prompt/completion text. They exclude BOS/EOS and any prompt
    template words the dataset classes add, so they slightly under-count the
    sequence the model actually sees.

    ``self.stats`` accumulates over every ``check_example`` call made on this
    instance, including calls made through ``filter_dataset``.
    """
    def __init__(
        self,
        tokenizer: 'SimpleTokenizer',
        min_length: int = 10,
        max_length: int = 1024,
        max_prompt_completion_ratio: float = 10.0
    ):
        self.tokenizer = tokenizer
        self.min_length = min_length
        self.max_length = max_length
        self.max_ratio = max_prompt_completion_ratio

        self.stats = {
            'total': 0,
            'too_short': 0,
            'too_long': 0,
            'bad_ratio': 0,
            'passed': 0
        }

    def check_example(self, prompt: str, completion: str) -> Tuple[bool, str]:
        """Check a single example and update ``self.stats``.

        Returns:
            ``(is_valid, reason)``. ``reason`` is ``"OK"`` on success, otherwise
            a short description of the first failed check (length checks run
            before the prompt/completion ratio check).
        """
        self.stats['total'] += 1

        prompt_len = len(self.tokenizer.encode(prompt, add_special_tokens=False))
        completion_len = len(self.tokenizer.encode(completion, add_special_tokens=False))
        total_len = prompt_len + completion_len

        if total_len < self.min_length:
            self.stats['too_short'] += 1
            return False, f"Too short: {total_len} tokens"

        if total_len > self.max_length:
            self.stats['too_long'] += 1
            return False, f"Too long: {total_len} tokens"

        # Reject prompts that dwarf their completion. The ratio is skipped for an
        # empty completion (no division by zero).
        if completion_len > 0:
            ratio = prompt_len / completion_len
            if ratio > self.max_ratio:
                self.stats['bad_ratio'] += 1
                return False, f"Bad ratio: {ratio:.1f}"

        self.stats['passed'] += 1
        return True, "OK"

    def filter_dataset(self, data: List[Dict], format_type: str = 'prompt_completion') -> List[Dict]:
        """Return the examples of ``data`` that pass ``check_example``.

        ``format_type`` is ``'prompt_completion'`` or ``'instruction'``. For
        instruction items, the prompt is ``instruction`` plus the optional
        ``input`` (a missing or blank input is treated as empty). Any other
        ``format_type`` passes every item through unchecked and unrecorded in
        ``self.stats``.
        """
        filtered = []

        for item in data:
            if format_type == 'prompt_completion':
                is_valid, _ = self.check_example(item['prompt'], item['completion'])
            elif format_type == 'instruction':
                source = item.get('input') or ''
                prompt = f"{item['instruction']} {source}" if source.strip() else item['instruction']
                is_valid, _ = self.check_example(prompt, item['output'])
            else:
                is_valid = True

            if is_valid:
                filtered.append(item)

        return filtered

    def print_stats(self):
        """Print filtering statistics (safe to call before anything was checked)."""
        total = self.stats['total']
        passed_pct = self.stats['passed'] / total * 100 if total else 0.0
        print("\n📊 Data Quality Statistics:")
        print(f"  Total examples: {total}")
        print(f"  Passed: {self.stats['passed']} ({passed_pct:.1f}%)")
        print(f"  Too short: {self.stats['too_short']}")
        print(f"  Too long: {self.stats['too_long']}")
        print(f"  Bad ratio: {self.stats['bad_ratio']}")


# Test quality checker
checker = DataQualityChecker(tokenizer, min_length=5, max_length=200)

# Check examples
filtered_data = checker.filter_dataset(prompt_completion_data, 'prompt_completion')
checker.print_stats()

print(f"\n✅ Filtered data: {len(filtered_data)}/{len(prompt_completion_data)} examples")

## 6. Data Visualization 📈

Understand your data distribution.

In [None]:
def _collect_lengths(data: List[Dict], tokenizer: 'SimpleTokenizer', format_type: str) -> Tuple[List[int], List[int]]:
    """Return per-example (prompt_lengths, completion_lengths) in tokens, without special tokens."""
    prompt_lengths = []
    completion_lengths = []
    for item in data:
        if format_type == 'prompt_completion':
            prompt, completion = item['prompt'], item['completion']
        elif format_type == 'instruction':
            prompt = f"{item['instruction']} {item.get('input', '')}"
            completion = item['output']
        else:
            raise ValueError(
                f"Unsupported format_type {format_type!r}; "
                "expected 'prompt_completion' or 'instruction'"
            )
        prompt_lengths.append(len(tokenizer.encode(prompt, add_special_tokens=False)))
        completion_lengths.append(len(tokenizer.encode(completion, add_special_tokens=False)))
    return prompt_lengths, completion_lengths


def analyze_dataset(data: List[Dict], tokenizer: 'SimpleTokenizer', format_type: str = 'prompt_completion'):
    """
    Analyze and visualize dataset token-length statistics.

    Draws a 2x2 figure: overlaid prompt/completion length histograms, box
    plots, a prompt-vs-completion scatter plot, and a summary table.

    Args:
        data: Examples in the format named by ``format_type``.
        tokenizer: Object with ``encode(text, add_special_tokens=...)``.
        format_type: ``'prompt_completion'`` or ``'instruction'``.

    Returns:
        Dict with ``prompt_lengths``, ``completion_lengths`` and
        ``total_lengths`` lists (one entry per example).

    Raises:
        ValueError: if ``data`` is empty or ``format_type`` is unsupported.
    """
    if not data:
        raise ValueError("analyze_dataset() needs at least one example")

    prompt_lengths, completion_lengths = _collect_lengths(data, tokenizer, format_type)
    total_lengths = [p + c for p, c in zip(prompt_lengths, completion_lengths)]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Histogram of lengths
    axes[0, 0].hist(prompt_lengths, bins=20, alpha=0.7, label='Prompt', color='blue')
    axes[0, 0].hist(completion_lengths, bins=20, alpha=0.7, label='Completion', color='orange')
    axes[0, 0].set_xlabel('Length (tokens)')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_title('Token Length Distribution', fontweight='bold')
    axes[0, 0].legend()
    axes[0, 0].grid(alpha=0.3)

    # Box plot. Tick labels are set separately: boxplot's ``labels`` keyword is
    # deprecated (renamed ``tick_labels``) in newer Matplotlib releases.
    axes[0, 1].boxplot([prompt_lengths, completion_lengths, total_lengths])
    axes[0, 1].set_xticklabels(['Prompt', 'Completion', 'Total'])
    axes[0, 1].set_ylabel('Length (tokens)')
    axes[0, 1].set_title('Length Distribution', fontweight='bold')
    axes[0, 1].grid(alpha=0.3)

    # Scatter plot: prompt vs completion length
    axes[1, 0].scatter(prompt_lengths, completion_lengths, alpha=0.6, color='green')
    axes[1, 0].set_xlabel('Prompt Length (tokens)')
    axes[1, 0].set_ylabel('Completion Length (tokens)')
    axes[1, 0].set_title('Prompt vs Completion Length', fontweight='bold')
    axes[1, 0].grid(alpha=0.3)

    # Statistics table (rounded so cells stay readable)
    def summarize(values):
        return [np.mean(values), np.median(values), np.min(values), np.max(values), np.std(values)]

    df = pd.DataFrame({
        'Metric': ['Mean', 'Median', 'Min', 'Max', 'Std Dev'],
        'Prompt': summarize(prompt_lengths),
        'Completion': summarize(completion_lengths),
        'Total': summarize(total_lengths),
    }).round(2)

    axes[1, 1].axis('tight')
    axes[1, 1].axis('off')
    table = axes[1, 1].table(
        cellText=df.values,
        colLabels=df.columns,
        cellLoc='center',
        loc='center'
    )
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 2)
    axes[1, 1].set_title('Summary Statistics', fontweight='bold', pad=20)

    plt.tight_layout()
    plt.show()

    return {
        'prompt_lengths': prompt_lengths,
        'completion_lengths': completion_lengths,
        'total_lengths': total_lengths
    }


# Analyze datasets: each call draws a 2x2 figure and returns the raw
# per-example token lengths for later inspection.
print("Analyzing Prompt-Completion Dataset...")
pc_stats = analyze_dataset(
    data=prompt_completion_data,
    tokenizer=tokenizer,
    format_type='prompt_completion',
)

print("\nAnalyzing Instruction Dataset...")
inst_stats = analyze_dataset(
    data=instruction_data,
    tokenizer=tokenizer,
    format_type='instruction',
)

## 7. Practical Tips & Best Practices 💡

### Data Preparation Checklist:

✅ **Data Format**
- Choose appropriate format (prompt-completion, instruction, chat)
- Ensure consistent formatting across dataset
- Include proper special tokens

✅ **Tokenization**
- Use subword tokenization (BPE, WordPiece) in production
- Handle special tokens correctly
- Set appropriate `max_length`

✅ **Quality Control**
- Filter too short/long examples
- Check prompt/completion ratios
- Remove duplicates
- Validate data integrity

✅ **Batching**
- Use efficient padding (pad to batch max, not global max)
- Consider `pad_to_multiple_of` (e.g. 8) for TPUs and GPU Tensor Cores
- Set appropriate batch size for GPU memory

✅ **Labels**
- Use `-100` for tokens to ignore in loss
- Only compute loss on completions (not prompts)
- Handle padding in labels

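### Recommended Hyperparameters:

These starting points assume parameter-efficient fine-tuning (e.g. LoRA/QLoRA) with mixed precision and gradient checkpointing. Full fine-tuning of a 7B model needs far more than 40 GB of GPU memory once optimizer states are included.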

```python
# For 7B model on A100 (40GB)
max_length = 2048
batch_size = 4
gradient_accumulation_steps = 4
effective_batch_size = 16  # batch_size * grad_accum

# For 13B model on A100 (80GB)
max_length = 2048
batch_size = 2
gradient_accumulation_steps = 8
effective_batch_size = 16
```

### Production Considerations:

1. **Use Hugging Face Datasets**: Better memory efficiency, caching
2. **Streaming for large datasets**: Don't load everything into RAM
3. **Data mixing**: Combine multiple datasets with sampling ratios
4. **Packing**: Pack short sequences together to reduce padding waste

---

## 8. Summary 📝

### What We Learned:

✅ Different data formats for fine-tuning  
✅ Tokenization strategies and special tokens  
✅ Dataset classes for different formats  
✅ Efficient batching with padding  
✅ Data quality control and filtering  
✅ Data analysis and visualization  

### Key Takeaways:

1. **Format matters**: Choose the right format for your task
2. **Quality over quantity**: Filter low-quality examples
3. **Efficient batching**: Minimize padding waste
4. **Label correctly**: Use `-100` for ignored tokens
5. **Monitor statistics**: Understand your data distribution

### Next Steps:

- **Tutorial 4**: Instruction tuning with LoRA
- **Tutorial 5**: Evaluation metrics and model assessment

---

## 📚 Resources

**Hugging Face:**
- Datasets: https://huggingface.co/docs/datasets
- Tokenizers: https://huggingface.co/docs/tokenizers

**Datasets:**
- Alpaca: https://github.com/tatsu-lab/stanford_alpaca
- OpenAssistant: https://huggingface.co/datasets/OpenAssistant/oasst1
- Dolly: https://huggingface.co/datasets/databricks/databricks-dolly-15k

**Related:**
- [transformer-foundation/01_embeddings_and_positional_encoding.ipynb](../transformer-foundation/01_embeddings_and_positional_encoding.ipynb)
- [01_introduction_to_fine_tuning.ipynb](01_introduction_to_fine_tuning.ipynb)
- [02_lora_implementation.ipynb](02_lora_implementation.ipynb)

---

**Ready to start training? Continue to Tutorial 4! 🚀**