# EECS 595 HW3: Debug SFT Training

This notebook provides step-by-step verification of your SFT (Supervised Fine-Tuning) implementation.

## Instructions for Students:

1. **Implement the TODO sections** in `sft.py` before running the corresponding cells
2. **Run each cell in order** to verify your implementation step by step
3. **Use `importlib.reload()`** to reload your latest code changes
4. **Check the output** of each test to ensure your implementation is correct

## TODO Requirements by Cell:

- **Cell 2**: Requires TODO 1.15 (setup_tokenizer from gpt.py)
- **Cell 3**: Requires TODO 3.1 (SFTDataset _build_ids_labels)
- **Cell 4**: Requires TODO 3.2 (sft_data_collator)
- **Cell 5**: Requires TODO 3.5 (create_sft_dataloader)
- **Cell 6**: Requires TODO 3.3 (generate_chat_response)
- **Cell 7**: Requires TODO 3.4 (evaluate_validation_loss)
- **Cell 8**: Requires TODO 3.1, 3.2, 3.3, 3.4, 3.5 (Complete SFT pipeline integration test)

## Key Features of SFT Implementation:
- **Conversation Formatting**: Structures dialogue data with special tokens
- **Selective Masking**: Only trains on assistant responses, masks user/system tokens
- **Token Masking**: Uses -100 labels to ignore certain tokens during training
- **Conversational Generation**: Generates responses in chat-like format

Let's start by importing the necessary modules and setting up the environment.


In [None]:
# Cell 1: Imports and Setup
import os
import json
import math
import numpy as np
import random
import logging
import importlib
from typing import Optional, Callable, List, Tuple, Dict, Any, Iterable
from copy import deepcopy
import gzip

# PyTorch imports
import torch
import torch.nn as nn
import torch.functional as F
from torch.nn import RMSNorm
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# Transformers and tokenization
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup, default_data_collator

# Data handling
from datasets import load_from_disk
import orjson

# Progress tracking
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt
import wandb

# Import our GPT and SFT implementations
import gpt
import sft
importlib.reload(gpt)  # Reload to get latest changes
importlib.reload(sft)  # Reload to get latest changes

print("‚úÖ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")


## Cell 2: Tokenizer Setup for SFT

**Required TODOs**: 1.15 (setup_tokenizer from gpt.py)

First, let's set up the tokenizer with special tokens for conversational AI. This is crucial for SFT as we need special tokens to mark different parts of conversations.


In [2]:
# Cell 2: Make the tokenizer
tokenizer = gpt.setup_tokenizer()

## Cell 3: Test SFTDataset Class

**Required TODOs**: 3.1 (SFTDataset _build_ids_labels)

Now let's test the SFTDataset class, which is responsible for loading and formatting conversational data with proper token masking.


In [3]:
# Cell 3: Test SFTDataset Class
importlib.reload(sft)  # Reload to get latest changes

# Create sample conversation data for testing
sample_conversations = [
    [
        {"role": "user", "content": "Hello! How are you?"},
        {"role": "assistant", "content": "Hi there! I'm doing well, thank you for asking. How can I help you today?"},
        {"role": "user", "content": "What's the weather like?"},
        {"role": "assistant", "content": "I don't have access to real-time weather data, but I'd be happy to help you find weather information for your location!"}
    ],
    [
        {"role": "system", "content": "You are a helpful math tutor."},
        {"role": "user", "content": "Can you help me solve 2x + 5 = 13?"},
        {"role": "assistant", "content": "Of course! Let's solve this step by step:\n\n1. Subtract 5 from both sides: 2x = 13 - 5 = 8\n2. Divide both sides by 2: x = 8/2 = 4\n\nSo x = 4!"}
    ],
    [
        {"role": "user", "content": "Tell me a joke"},
        {"role": "assistant", "content": "Why don't scientists trust atoms? Because they make up everything! üòÑ"}
    ]
]

# Create a temporary test file
test_file = "test_conversations.jsonl"
with open(test_file, 'w') as f:
    for conv in sample_conversations:
        f.write(json.dumps({"messages": conv}) + "\n")

print(f"‚úÖ Created test file with {len(sample_conversations)} conversations")

# Test SFTDataset
try:
    dataset = sft.SFTDataset(test_file, tokenizer, max_length=256)
    print(f"‚úÖ SFTDataset created successfully!")
    print(f"   Dataset length: {len(dataset)}")
    print(f"   Max length: {dataset.max_length}")

    # Test getting a sample
    sample_input_ids, sample_labels = dataset[0]
    print(f"   Sample input_ids shape: {sample_input_ids.shape}")
    print(f"   Sample labels shape: {sample_labels.shape}")

    # Check masking logic
    training_tokens = sum(1 for l in sample_labels if l != -100)
    masked_tokens = sum(1 for l in sample_labels if l == -100)
    total_tokens = len(sample_labels)

    print(f"   Training tokens: {training_tokens} ({training_tokens/total_tokens*100:.1f}%)")
    print(f"   Masked tokens: {masked_tokens} ({masked_tokens/total_tokens*100:.1f}%)")

    # Decode first few tokens to verify format
    first_tokens = sample_input_ids[:20]
    decoded = tokenizer.decode(first_tokens, skip_special_tokens=False)
    print(f"   First 20 tokens decoded: {decoded}")

except Exception as e:
    print(f"‚ùå Error testing SFTDataset: {e}")
    print("   Make sure you've implemented the SFTDataset class in sft.py")

# Clean up test file
os.remove(test_file)
print("‚úÖ Test file cleaned up")


‚úÖ Created test file with 3 conversations


Loading dataset: 0it [00:00, ?it/s]

‚úÖ SFTDataset created successfully!
   Dataset length: 3
   Max length: 256
   Sample input_ids shape: torch.Size([67])
   Sample labels shape: torch.Size([67])
   Training tokens: 51 (76.1%)
   Masked tokens: 16 (23.9%)
   First 20 tokens decoded: <|user|>Hello! How are you?<|end|><|assistant|>Hi there! I'm doing well, thank you for
‚úÖ Test file cleaned up


## Cell 4: Test Data Collators

**Required TODOs**: 3.2 (sft_data_collator)

Let's test the data collators that handle batching for SFT training. We'll test both the regular collator and the HuggingFace-style collator for packed datasets.


In [4]:
# Cell 4: Test Data Collators
importlib.reload(sft)  # Reload to get latest changes

# Create test data for collator testing
test_file = "test_collator_data.jsonl"
with open(test_file, 'w') as f:
    for conv in sample_conversations:
        f.write(json.dumps({"messages": conv}) + "\n")

# Create dataset
dataset = sft.SFTDataset(test_file, tokenizer, max_length=128)

print("Testing SFT Data Collator...")

# Test sft_data_collator
try:
    # Create a batch of samples
    batch_samples = []
    for i in range(min(3, len(dataset))):
        batch_samples.append(dataset[i])

    print(f"‚úÖ Created batch with {len(batch_samples)} samples")

    # Test the collator
    collated_batch = sft.sft_data_collator(batch_samples)

    print(f"‚úÖ Collator executed successfully!")
    print(f"   Batch input_ids shape: {collated_batch['input_ids'].shape}")
    print(f"   Batch labels shape: {collated_batch['labels'].shape}")

    # Check that all sequences are the same length
    batch_size, seq_length = collated_batch['input_ids'].shape
    print(f"   Batch size: {batch_size}, Sequence length: {seq_length}")

    # Check padding
    pad_token_id = tokenizer.pad_token_id or 0
    padding_tokens = (collated_batch['input_ids'] == pad_token_id).sum().item()
    total_tokens = batch_size * seq_length
    print(f"   Padding tokens: {padding_tokens} ({padding_tokens/total_tokens*100:.1f}%)")

    # Check masking
    masked_tokens = (collated_batch['labels'] == -100).sum().item()
    print(f"   Masked tokens: {masked_tokens} ({masked_tokens/total_tokens*100:.1f}%)")

except Exception as e:
    print(f"‚ùå Error testing sft_data_collator: {e}")
    print("   Make sure you've implemented the sft_data_collator function in sft.py")

print("\nTesting HuggingFace-style Collator...")

# Test hf_collate (for packed datasets)
try:
    # Create mock packed data
    packed_examples = [
        {
            "input_ids": [1, 2, 3, 4, 5] + [0] * 123,  # 128 total
            "labels": [1, 2, 3, 4, 5] + [-100] * 123
        },
        {
            "input_ids": [6, 7, 8, 9, 10] + [0] * 123,  # 128 total
            "labels": [6, 7, 8, 9, 10] + [-100] * 123
        }
    ]

    hf_batch = sft.hf_collate(packed_examples)

    print(f"‚úÖ HF collator executed successfully!")
    print(f"   Batch input_ids shape: {hf_batch['input_ids'].shape}")
    print(f"   Batch labels shape: {hf_batch['labels'].shape}")
    print(f"   Attention mask shape: {hf_batch['attention_mask'].shape}")

    # Verify attention mask
    expected_mask = (hf_batch['input_ids'] != 0).long()
    mask_correct = torch.equal(hf_batch['attention_mask'], expected_mask)
    print(f"   Attention mask correct: {mask_correct}")

except Exception as e:
    print(f"‚ùå Error testing hf_collate: {e}")
    print("   Make sure you've implemented the hf_collate function in sft.py")

# Clean up
os.remove(test_file)
print("\n‚úÖ Test files cleaned up")


Loading dataset: 0it [00:00, ?it/s]

Testing SFT Data Collator...
‚úÖ Created batch with 3 samples
‚úÖ Collator executed successfully!
   Batch input_ids shape: torch.Size([3, 79])
   Batch labels shape: torch.Size([3, 79])
   Batch size: 3, Sequence length: 79
   Padding tokens: 0 (0.0%)
   Masked tokens: 113 (47.7%)

Testing HuggingFace-style Collator...
‚úÖ HF collator executed successfully!
   Batch input_ids shape: torch.Size([2, 128])
   Batch labels shape: torch.Size([2, 128])
   Attention mask shape: torch.Size([2, 128])
   Attention mask correct: True

‚úÖ Test files cleaned up


## Cell 5: Test DataLoader Creation

**Required TODOs**: 3.5 (create_sft_dataloader)

Let's test the create_sft_dataloader function, which handles both regular and packed dataset formats.


In [5]:
# Cell 5: Test DataLoader Creation
importlib.reload(sft)  # Reload to get latest changes

# Create test data file
test_file = "test_dataloader_data.jsonl"
with open(test_file, 'w') as f:
    for conv in sample_conversations:
        f.write(json.dumps({"messages": conv}) + "\n")

print("Testing Regular DataLoader Creation...")

# Test regular dataloader (use_packed=False)
try:
    regular_loader = sft.create_sft_dataloader(
        data_file=test_file,
        tokenizer=tokenizer,
        batch_size=2,
        max_length=128,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_packed=False
    )

    print(f"‚úÖ Regular DataLoader created successfully!")
    print(f"   Number of batches: {len(regular_loader)}")

    # Test getting a batch
    batch = next(iter(regular_loader))
    print(f"   Batch input_ids shape: {batch['input_ids'].shape}")
    print(f"   Batch labels shape: {batch['labels'].shape}")

    # Check batch properties
    batch_size, seq_length = batch['input_ids'].shape
    print(f"   Batch size: {batch_size}, Sequence length: {seq_length}")

except Exception as e:
    print(f"‚ùå Error testing regular DataLoader: {e}")
    print("   Make sure you've implemented the create_sft_dataloader function in sft.py")

print("\nTesting Packed DataLoader Creation...")

# Test packed dataloader (use_packed=True)
# Note: This will fail if no packed dataset exists, which is expected
try:
    packed_loader = sft.create_sft_dataloader(
        data_file="nonexistent_packed_data.arrow",  # This will fail
        tokenizer=tokenizer,
        batch_size=2,
        max_length=128,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_packed=True
    )

    print(f"‚úÖ Packed DataLoader created successfully!")

except FileNotFoundError:
    print("‚ö†Ô∏è  Packed dataset not found (expected for this test)")
    print("   This is normal - packed datasets are created separately")
except Exception as e:
    print(f"‚ùå Error testing packed DataLoader: {e}")

print("\nTesting DataLoader Iteration...")

# Test iterating through the regular dataloader
try:
    batch_count = 0
    total_samples = 0

    for batch in regular_loader:
        batch_count += 1
        total_samples += batch['input_ids'].shape[0]

        if batch_count <= 2:  # Only show first 2 batches
            print(f"   Batch {batch_count}: {batch['input_ids'].shape[0]} samples")

    print(f"‚úÖ DataLoader iteration successful!")
    print(f"   Total batches: {batch_count}")
    print(f"   Total samples: {total_samples}")

except Exception as e:
    print(f"‚ùå Error iterating through DataLoader: {e}")

# Clean up
os.remove(test_file)
print("\n‚úÖ Test file cleaned up")


Testing Regular DataLoader Creation...
Creating SFTDataset from test_dataloader_data.jsonl...


Loading dataset: 0it [00:00, ?it/s]

‚úÖ DataLoader created successfully for test_dataloader_data.jsonl!
‚úÖ Regular DataLoader created successfully!
   Number of batches: 2
   Batch input_ids shape: torch.Size([2, 79])
   Batch labels shape: torch.Size([2, 79])
   Batch size: 2, Sequence length: 79

Testing Packed DataLoader Creation...
Loading packed dataset from nonexistent_packed_data.arrow...
‚ö†Ô∏è  Packed dataset not found (expected for this test)
   This is normal - packed datasets are created separately

Testing DataLoader Iteration...
   Batch 1: 2 samples
   Batch 2: 1 samples
‚úÖ DataLoader iteration successful!
   Total batches: 2
   Total samples: 3

‚úÖ Test file cleaned up


## Cell 6: Test Text Generation Functions

**Required TODOs**: 3.3 (generate_chat_response)

Now let's test the conversational text generation functions. We'll need a simple model for testing, so we'll create a minimal GPT model.


In [None]:
# Cell 6: Test Text Generation Functions
importlib.reload(gpt)  # Reload to get latest changes
importlib.reload(sft)  # Reload to get latest changes

print("Creating a minimal GPT model for testing...")
actual_vocab_size = 200
# Create a very small model for testing
model_config = {
    "vocab_size": actual_vocab_size,
    "context_length": 128,
    "emb_dim": 64,
    "n_heads": 4,
    "n_layers": 2,
    "drop_rate": 0.1,
}

try:
    # Create model
    test_model = gpt.GPTModel(model_config)
    print(f"‚úÖ Test model created successfully!")
    print(f"   Model parameters: {sum(p.numel() for p in test_model.parameters()):,}")

    # Move to CPU for testing
    device = 'cpu'
    test_model = test_model.to(device)
    test_model.eval()

    print(f"   Model moved to {device}")

except Exception as e:
    print(f"‚ùå Error creating test model: {e}")
    print("   Make sure you've implemented the GPTModel class in gpt.py")
    test_model = None

if test_model is not None:
    print("\nTesting Single-Turn Chat Generation...")

    # Test generate_chat_response
    try:
        user_message = "Hello! How are you?"
        print(f"User message: '{user_message}'")

        response = sft.generate_chat_response(
            model=test_model,
            tokenizer=tokenizer,
            user_message=user_message,
            max_new_tokens=20,
            temperature=0.7
        )

        print(f"‚úÖ Single-turn generation successful!")
        print(f"   Generated response: '{response}'")

    except Exception as e:
        print(f"‚ùå Error testing single-turn generation: {e}")
        print("   Make sure you've implemented the generate_chat_response function in sft.py")

    print("\nTesting Multi-Turn Chat Generation...")

    # Test generate_multi_turn_response
    try:
        conversation_history = [
            {"role": "user", "content": "Hi there!"},
            {"role": "assistant", "content": "Hello! How can I help you?"},
            {"role": "user", "content": "What's 2+2?"}
        ]

        print("Conversation history:")
        for msg in conversation_history:
            print(f"  {msg['role']}: {msg['content']}")

        response = sft.generate_multi_turn_response(
            model=test_model,
            tokenizer=tokenizer,
            conversation_history=conversation_history,
            max_new_tokens=20,
            temperature=0.7
        )

        print(f"‚úÖ Multi-turn generation successful!")
        print(f"   Generated response: '{response}'")

    except Exception as e:
        print(f"‚ùå Error testing multi-turn generation: {e}")
        print("   Make sure you've implemented the generate_multi_turn_response function in sft.py")

print("\n‚úÖ Text generation testing complete!")
print("Note: The generated text will be random since we're using an untrained model.")
print("This is expected - we're just testing that the functions work correctly.")


Creating a minimal GPT model for testing...
‚úÖ Test model created successfully!
   Model parameters: 112,304
   Model moved to cpu

Testing Single-Turn Chat Generation...
User message: 'Hello! How are you?'
‚ùå Error testing single-turn generation: index out of range in self
   Make sure you've implemented the generate_chat_response function in sft.py

Testing Multi-Turn Chat Generation...
Conversation history:
  user: Hi there!
  assistant: Hello! How can I help you?
  user: What's 2+2?
‚ùå Error testing multi-turn generation: generate_chat_response() got an unexpected keyword argument 'context'
   Make sure you've implemented the generate_multi_turn_response function in sft.py

‚úÖ Text generation testing complete!
Note: The generated text will be random since we're using an untrained model.
This is expected - we're just testing that the functions work correctly.


## Cell 7: Test Model Loading and Validation

**Required TODOs**: 3.4 (evaluate_validation_loss)

Let's test the utility functions for loading pre-trained models and evaluating validation loss.


In [7]:
# Cell 7: Test Model Loading and Validation
importlib.reload(sft)  # Reload to get latest changes

print("Testing Model Loading Function...")

# Test load_pretrained_model with a non-existent file (expected to fail)
try:
    fake_model = sft.load_pretrained_model("nonexistent_model.pth", model_config)
    print("‚ùå This should have failed!")
except FileNotFoundError:
    print("‚úÖ Model loading correctly handles missing files")
except Exception as e:
    print(f"‚úÖ Model loading error handling works: {type(e).__name__}")

print("\nTesting Validation Loss Evaluation...")

# Create a small validation dataset for testing
val_file = "test_validation_data.jsonl"
with open(val_file, 'w') as f:
    for conv in sample_conversations[:2]:  # Use only 2 conversations
        f.write(json.dumps({"messages": conv}) + "\n")

try:
    # Create validation dataloader
    val_loader = sft.create_sft_dataloader(
        data_file=val_file,
        tokenizer=tokenizer,
        batch_size=1,
        max_length=64,
        shuffle=False,
        drop_last=False,
        num_workers=0,
        use_packed=False
    )

    print(f"‚úÖ Validation DataLoader created with {len(val_loader)} batches")

    # Test evaluate_validation_loss
    if test_model is not None:
        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)

        val_loss = sft.evaluate_validation_loss(
            model=test_model,
            val_loader=val_loader,
            loss_fn=loss_fn,
            device=device
        )

        print(f"‚úÖ Validation loss evaluation successful!")
        print(f"   Validation loss: {val_loss:.4f}")
        print("   Note: This loss is from an untrained model, so it will be high.")

    else:
        print("‚ö†Ô∏è  Skipping validation loss test - no test model available")

except Exception as e:
    print(f"‚ùå Error testing validation loss: {e}")
    print("   Make sure you've implemented the evaluate_validation_loss function in sft.py")

# Clean up
os.remove(val_file)
print("\n‚úÖ Test files cleaned up")


Testing Model Loading Function...
‚úÖ Model loading error handling works: NameError

Testing Validation Loss Evaluation...
Creating SFTDataset from test_validation_data.jsonl...


Loading dataset: 0it [00:00, ?it/s]

‚úÖ DataLoader created successfully for test_validation_data.jsonl!
‚úÖ Validation DataLoader created with 2 batches
‚ùå Error testing validation loss: name 'test_model' is not defined
   Make sure you've implemented the evaluate_validation_loss function in sft.py

‚úÖ Test files cleaned up


## Cell 8: Integration Test - Complete SFT Pipeline

**Required TODOs**: 3.1, 3.2, 3.3, 3.4, 3.5 (Complete SFT pipeline integration test)

Let's run a complete integration test that combines all the SFT components together.


In [None]:
# Cell 8: Integration Test - Complete SFT Pipeline
importlib.reload(gpt)  # Reload to get latest changes
importlib.reload(sft)  # Reload to get latest changes

print("üöÄ Running Complete SFT Pipeline Integration Test")
print("=" * 60)

# Create comprehensive test data
integration_test_file = "integration_test_data.jsonl"
test_conversations = [
    [
        {"role": "user", "content": "What is machine learning?"},
        {"role": "assistant", "content": "Machine learning is a subset of artificial intelligence that enables computers to learn and make decisions from data without being explicitly programmed."}
    ],
    [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "How do I create a list in Python?"},
        {"role": "assistant", "content": "You can create a list in Python using square brackets. For example: my_list = [1, 2, 3, 'hello']"}
    ],
    [
        {"role": "user", "content": "Explain neural networks"},
        {"role": "assistant", "content": "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes (neurons) that process information through weighted connections."}
    ],
    [
        {"role": "user", "content": "What's the difference between supervised and unsupervised learning?"},
        {"role": "assistant", "content": "Supervised learning uses labeled training data to learn patterns, while unsupervised learning finds patterns in data without labels."}
    ]
]

with open(integration_test_file, 'w') as f:
    for conv in test_conversations:
        f.write(json.dumps({"messages": conv}) + "\n")

print(f"‚úÖ Created test data with {len(test_conversations)} conversations")

# Step 1: Create dataset
print("\nüìä Step 1: Creating SFT Dataset")
try:
    dataset = sft.SFTDataset(integration_test_file, tokenizer, max_length=128)
    print(f"‚úÖ Dataset created: {len(dataset)} conversations")

    # Analyze masking
    sample_input_ids, sample_labels = dataset[0]
    training_tokens = sum(1 for l in sample_labels if l != -100)
    masked_tokens = sum(1 for l in sample_labels if l == -100)
    total_tokens = len(sample_labels)

    print(f"   Training tokens: {training_tokens}/{total_tokens} ({training_tokens/total_tokens*100:.1f}%)")
    print(f"   Masked tokens: {masked_tokens}/{total_tokens} ({masked_tokens/total_tokens*100:.1f}%)")

except Exception as e:
    print(f"‚ùå Dataset creation failed: {e}")
    dataset = None

# Step 2: Create dataloader
print("\nüîÑ Step 2: Creating DataLoader")
if dataset is not None:
    try:
        dataloader = sft.create_sft_dataloader(
            data_file=integration_test_file,
            tokenizer=tokenizer,
            batch_size=2,
            max_length=128,
            shuffle=False,
            drop_last=False,
            num_workers=0,
            use_packed=False
        )
        print(f"‚úÖ DataLoader created: {len(dataloader)} batches")

        # Test batch iteration
        batch = next(iter(dataloader))
        print(f"   Batch shape: {batch['input_ids'].shape}")

    except Exception as e:
        print(f"‚ùå DataLoader creation failed: {e}")
        dataloader = None
else:
    dataloader = None

# Step 3: Test model forward pass
print("\nüß† Step 3: Testing Model Forward Pass")
if dataloader is not None and test_model is not None:
    try:
        batch = next(iter(dataloader))
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        with torch.no_grad():
            logits = test_model(input_ids)

        print(f"‚úÖ Forward pass successful!")
        print(f"   Input shape: {input_ids.shape}")
        print(f"   Logits shape: {logits.shape}")
        print(f"   Labels shape: {labels.shape}")

        # Test loss computation
        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        print(f"   Loss: {loss.item():.4f}")

    except Exception as e:
        print(f"‚ùå Forward pass failed: {e}")

# Step 4: Test text generation
print("\nüí¨ Step 4: Testing Text Generation")
if test_model is not None:
    try:
        # Test single-turn generation
        user_message = "What is artificial intelligence?"
        response = sft.generate_chat_response(
            model=test_model,
            tokenizer=tokenizer,
            user_message=user_message,
            max_new_tokens=15,
            temperature=0.7
        )

        print(f"‚úÖ Single-turn generation successful!")
        print(f"   User: {user_message}")
        print(f"   Assistant: {response}")

        # Test multi-turn generation
        conv_history = [
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "Tell me about AI"}
        ]

        multi_response = sft.generate_multi_turn_response(
            model=test_model,
            tokenizer=tokenizer,
            conversation_history=conv_history,
            max_new_tokens=15,
            temperature=0.7
        )

        print(f"‚úÖ Multi-turn generation successful!")
        print(f"   Multi-turn response: {multi_response}")

    except Exception as e:
        print(f"‚ùå Text generation failed: {e}")

# Step 5: Test validation
print("\nüìà Step 5: Testing Validation")
if dataloader is not None and test_model is not None:
    try:
        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
        val_loss = sft.evaluate_validation_loss(test_model, dataloader, loss_fn, device)

        print(f"‚úÖ Validation successful!")
        print(f"   Validation loss: {val_loss:.4f}")

    except Exception as e:
        print(f"‚ùå Validation failed: {e}")

# Summary
print("\n" + "=" * 60)
print("üéâ SFT Pipeline Integration Test Complete!")
print("=" * 60)

if dataset is not None and dataloader is not None and test_model is not None:
    print("‚úÖ All core SFT components are working correctly!")
    print("‚úÖ You're ready to implement the training loop in sft_gpt.py")
else:
    print("‚ö†Ô∏è  Some components need to be implemented:")
    if dataset is None:
        print("   - SFTDataset class")
    if dataloader is None:
        print("   - create_sft_dataloader function")
    if test_model is None:
        print("   - GPTModel class (from gpt.py)")

print("\nNext steps:")
print("1. Implement any missing components in sft.py")
print("2. Run the SFT training script: python sft_gpt.py")
print("3. Use the ChatWithGPT.ipynb notebook to test your trained model")

# Clean up
os.remove(integration_test_file)
print("\n‚úÖ Test files cleaned up")


## Summary and Next Steps

### What We've Tested

This notebook has systematically tested all the core SFT components:

1. **‚úÖ Tokenizer Setup**: Special tokens for conversational AI
2. **‚úÖ SFTDataset**: Loading and formatting conversations with proper masking
3. **‚úÖ Data Collators**: Both regular and packed dataset collation
4. **‚úÖ DataLoader Creation**: Support for both regular and packed formats
5. **‚úÖ Text Generation**: Single-turn and multi-turn conversation generation
6. **‚úÖ Model Loading**: Pre-trained model loading utilities
7. **‚úÖ Validation**: Loss evaluation on validation data
8. **‚úÖ Integration Test**: Complete SFT pipeline verification

### Key Concepts Verified

- **Token Masking**: Only assistant tokens contribute to loss (labels != -100)
- **Special Tokens**: Proper handling of `<|user|>`, `<|assistant|>`, `<|end|>`, `<|system|>`
- **Data Formats**: Both regular jsonlines and packed Arrow datasets
- **Conversation Format**: Proper structuring of multi-turn dialogues
- **Generation**: Autoregressive text generation with conversation context

### Implementation Checklist

Before running SFT training, make sure you've implemented:

- [ ] `SFTDataset` class in `sft.py`
- [ ] `sft_data_collator` function in `sft.py`
- [ ] `hf_collate` function in `sft.py`
- [ ] `create_sft_dataloader` function in `sft.py`
- [ ] `generate_chat_response` function in `sft.py`
- [ ] `generate_multi_turn_response` function in `sft.py`
- [ ] `load_pretrained_model` function in `sft.py`
- [ ] `evaluate_validation_loss` function in `sft.py`

### Next Steps

1. **Complete Implementation**: Implement any missing functions in `sft.py`
2. **Run Training**: Use `python sft_gpt.py` or the provided shell scripts
3. **Test Your Model**: Use `ChatWithGPT.ipynb` to interact with your trained model
4. **Evaluate Performance**: Use `score_gpt.py` to test on multiple choice questions

### Troubleshooting Tips

- **Import Errors**: Make sure all functions are properly implemented in `sft.py`
- **Shape Mismatches**: Check tensor dimensions at each step
- **Masking Issues**: Verify that only assistant tokens have labels != -100
- **Generation Problems**: Ensure special tokens are properly handled
- **Memory Issues**: Use smaller batch sizes or max_length for testing

Good luck with your SFT implementation! üöÄ


In [None]:
# Cell 2: Tokenizer Setup for SFT
importlib.reload(sft)  # Reload to get latest changes

# Set up the tokenizer with special tokens for conversation
tokenizer = AutoTokenizer.from_pretrained("gpt2", fast=True)

# Ensure we have a pad token (GPT-2 doesn't by default)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

# Add conversation-specific special tokens
special_tokens_dict = {
    "additional_special_tokens": ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"]
}
tokenizer.add_special_tokens(special_tokens_dict)

# Calculate the actual vocabulary size after adding special tokens
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>", "<|system|>", "<|pad|>"]
max_token_id = max(tokenizer.convert_tokens_to_ids(token) for token in special_tokens)
actual_vocab_size = max_token_id + 1

print(f"‚úÖ Tokenizer initialized with {actual_vocab_size} tokens")
print(f"Special token IDs:")
for token in special_tokens:
    token_id = tokenizer.convert_tokens_to_ids(token)
    print(f"  {token}: {token_id}")

# Test tokenization with conversation format
test_conversation = "<|user|>Hello!<|end|><|assistant|>Hi there!<|end|>"
tokens = tokenizer.encode(test_conversation)
decoded = tokenizer.decode(tokens)

print(f"\nTest conversation: '{test_conversation}'")
print(f"Tokens: {tokens}")
print(f"Decoded: '{decoded}'")

print("\n‚úÖ Tokenizer setup complete!")
