In [20]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    TrainerCallback,
    EarlyStoppingCallback
)
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import time
import json
from tqdm import tqdm
import argparse
import logging
from datetime import datetime


# ------------------ Custom Callback for Tracking Metrics ------------------ #
class MetricsTrackingCallback(TrainerCallback):
    """
    Callback to track and save training metrics during training.
    Includes loss, validation loss, and training time.

    Files written to ``log_dir``:
      * ``training_log_<timestamp>.txt`` - text log, one file per callback instance
      * ``training_metrics.json``        - collected losses/timings (at train end)
      * ``training_loss.png``            - training/validation loss curves (at train end)
    """

    def __init__(self, log_dir="./logs"):
        """
        Args:
            log_dir: Directory for the log file, metrics JSON and loss plot
                (created if missing).
        """
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

        # Each list holds (global_step, value) pairs.
        self.train_losses = []
        self.eval_losses = []
        self.train_times = []
        self.start_time = None

        # Setup logging.
        # A dedicated logger with its own FileHandler is used instead of
        # logging.basicConfig(): basicConfig() silently does nothing once the
        # root logger has a handler, so a second instance (e.g. after re-running
        # a notebook cell) would never create its own log file. basicConfig()
        # would also reconfigure logging for every library in the process.
        self.log_file = os.path.join(log_dir, f"training_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")
        self.logger = logging.getLogger(f"{__name__}.{type(self).__name__}.{id(self)}")
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False
        # id(self) can be reused after garbage collection; drop stale handlers so
        # messages are not duplicated into an older file.
        for stale_handler in list(self.logger.handlers):
            self.logger.removeHandler(stale_handler)
            stale_handler.close()
        file_handler = logging.FileHandler(self.log_file)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        self.logger.addHandler(file_handler)

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the starting time when training begins."""
        self.start_time = time.time()
        self.logger.info(f"Training started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        self.logger.info(f"Training arguments: {args}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Record 'loss' / 'eval_loss' entries and elapsed time at each logging event."""
        logs = logs or {}

        # Record training loss
        if 'loss' in logs:
            self.train_losses.append((state.global_step, logs['loss']))
            self.logger.info(f"Step {state.global_step}: Training loss = {logs['loss']}")

        # Record eval loss
        if 'eval_loss' in logs:
            self.eval_losses.append((state.global_step, logs['eval_loss']))
            self.logger.info(f"Step {state.global_step}: Evaluation loss = {logs['eval_loss']}")

        # Record elapsed time (only once on_train_begin has set the start time)
        if self.start_time is not None:
            elapsed_time = time.time() - self.start_time
            self.train_times.append((state.global_step, elapsed_time))
            self.logger.info(f"Step {state.global_step}: Training time = {elapsed_time:.2f}s")

    def on_train_end(self, args, state, control, **kwargs):
        """Log a summary, save all metrics to JSON and write the loss plot."""
        # Calculate total training time (0 if training never signalled its start)
        total_time = time.time() - self.start_time if self.start_time is not None else 0

        # Log final metrics
        self.logger.info(f"Training completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        self.logger.info(f"Total training time: {total_time:.2f} seconds")
        self.logger.info(f"Final training loss: {self.train_losses[-1][1] if self.train_losses else 'N/A'}")
        self.logger.info(f"Final evaluation loss: {self.eval_losses[-1][1] if self.eval_losses else 'N/A'}")

        # Save metrics to JSON file (tuples are serialized as 2-element lists)
        metrics_file = os.path.join(self.log_dir, "training_metrics.json")
        metrics = {
            "train_losses": self.train_losses,
            "eval_losses": self.eval_losses,
            "train_times": self.train_times,
            "total_time": total_time
        }

        with open(metrics_file, 'w') as f:
            json.dump(metrics, f, indent=2)

        # Create visualization of training progress
        self.visualize_training_progress()

    def visualize_training_progress(self):
        """Create visualization of training and validation loss over time."""
        fig, ax = plt.subplots(figsize=(12, 8))
        try:
            for series, label in ((self.train_losses, 'Training Loss'),
                                  (self.eval_losses, 'Validation Loss')):
                if series:
                    steps, losses = zip(*series)
                    ax.plot(steps, losses, label=label)

            ax.set_xlabel('Training Steps')
            ax.set_ylabel('Loss')
            ax.set_title('Training and Validation Loss')
            # legend() emits a "No artists with labels" warning on an empty plot.
            if ax.get_lines():
                ax.legend()
            ax.grid(True, linestyle='--', alpha=0.7)
            fig.tight_layout()

            # Save plot
            fig.savefig(os.path.join(self.log_dir, "training_loss.png"))
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)


# ------------------ Dataset Preparation ------------------ #
class CodeAlpacaDataset(Dataset):
    """
    Custom dataset for CodeAlpaca data with Python code examples.

    Loads the ``train`` split of ``HuggingFaceH4/CodeAlpaca_20K`` and keeps the
    examples whose completion contains ``def ``, ``import ``, ``lambda `` or
    ``class ``. It then takes an unshuffled, deterministic split of that pool:
    the first ``train_fraction`` for "train" and the rest for "validation".
    """

    # Values accepted for ``data_subset``.
    _SUBSETS = ("train", "validation")

    def __init__(self, tokenizer, max_length=512, data_subset="train", train_fraction=0.9):
        """
        Initialize the dataset.

        Args:
            tokenizer: Tokenizer to use for encoding
            max_length: Maximum sequence length (every example is padded /
                truncated to exactly this length)
            data_subset: Data subset to use ("train" or "validation")
            train_fraction: Fraction of the filtered pool assigned to "train";
                the remainder is "validation". Defaults to 0.9 (90% / 10%).

        Raises:
            ValueError: if ``data_subset`` is not "train"/"validation" or
                ``train_fraction`` is outside [0, 1].
        """
        # Validate before the (potentially slow) dataset download. Previously any
        # unknown subset name silently fell through to the validation split.
        if data_subset not in self._SUBSETS:
            raise ValueError(f"data_subset must be one of {self._SUBSETS}, got {data_subset!r}")
        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction!r}")

        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load dataset
        self.alpaca_data = load_dataset("HuggingFaceH4/CodeAlpaca_20K", split="train")

        # Filter for Python code (keyword heuristic on the completion text)
        python_keywords = ['def ', 'import ', 'lambda ', 'class ']
        def is_python_code(text):
            return any(keyword in text for keyword in python_keywords)

        self.python_dataset = self.alpaca_data.filter(lambda example: is_python_code(example['completion']))
        print(f"Loaded {len(self.python_dataset)} Python code examples from CodeAlpaca dataset")

        # Deterministic prefix/suffix split of the filtered pool.
        n_total = len(self.python_dataset)
        split_idx = int(n_total * train_fraction)
        if data_subset == "train":
            indices = range(split_idx)
        else:  # validation
            indices = range(split_idx, n_total)
        self.dataset = self.python_dataset.select(indices)

        print(f"Using {len(self.dataset)} examples for {data_subset}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Get a formatted and tokenized example as 1-D tensors.

        Returns the tokenizer's outputs (e.g. ``input_ids``, ``attention_mask``)
        with the batch dimension removed, plus ``labels``. ``labels`` is a copy
        of ``input_ids`` with padded positions (``attention_mask == 0``) set to
        -100 so they do not contribute to the loss. The attention mask is used
        rather than ``pad_token_id`` because the pad token may be the EOS token.
        """
        example = self.dataset[idx]

        # Format the input
        input_text = self.format_example(example)

        # Tokenize
        encodings = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Remove the batch dimension added by return_tensors="pt"
        item = {key: val.squeeze(0) for key, val in encodings.items()}

        labels = item["input_ids"].clone()
        attention_mask = item.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        item["labels"] = labels

        return item

    def format_example(self, example):
        """Format an example (dict with 'prompt' and 'completion') for instruction fine-tuning."""
        return (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{example['prompt']}\n\n### Response:\n{example['completion']}"
        )

In [21]:
model_name = "gpt2-medium"

print(f"Loading tokenizer from {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPT-2 tokenizers typically ship without a padding token; reuse EOS so that
# fixed-length padding in the dataset works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Loading tokenizer from gpt2-medium...


In [22]:
# Setup model: load the pretrained causal-LM weights.
print(f"Loading model from {model_name}...")
model = AutoModelForCausalLM.from_pretrained(model_name)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
model = model.to(device)

Loading model from gpt2-medium...
Using device: cuda


In [23]:
class AdapterLayer(torch.nn.Module):
    """
    Bottleneck adapter with a residual connection.

    Computes ``x + up_proj(dropout(gelu(down_proj(layer_norm(x)))))``. The
    up-projection starts at all zeros, so a freshly built adapter returns its
    input unchanged.

    Args:
        input_dim: Size of the input's last dimension (the host model's hidden size).
        adapter_dim: Width of the bottleneck.
        dropout_rate: Dropout probability applied inside the bottleneck.
    """
    def __init__(self, input_dim, adapter_dim, dropout_rate=0.1):
        super().__init__()

        # Registration order and attribute names are kept stable: they determine
        # the parameter order and the keys that end up in saved state dicts.
        self.down_proj = torch.nn.Linear(input_dim, adapter_dim)   # down-projection
        self.activation = torch.nn.GELU()                          # non-linearity
        self.up_proj = torch.nn.Linear(adapter_dim, input_dim)     # up-projection
        self.layer_norm = torch.nn.LayerNorm(input_dim)            # pre-norm for stability
        self.dropout = torch.nn.Dropout(dropout_rate)              # regularization

        self._init_weights()

        # The adapter is meant to be trained even when its host model is frozen.
        self.requires_grad_(True)

    def _init_weights(self):
        """Small-normal down-projection; zero biases and zero up-projection."""
        for bias in (self.down_proj.bias, self.up_proj.bias):
            torch.nn.init.zeros_(bias)
        torch.nn.init.normal_(self.down_proj.weight, std=1e-3)
        # A zero up-projection makes the residual branch contribute nothing at start.
        torch.nn.init.zeros_(self.up_proj.weight)

    def forward(self, x):
        """Run the bottleneck on ``x`` and add ``x`` back as a residual."""
        # Align the input with the adapter's own device before computing anything;
        # the residual uses this aligned tensor as well.
        target_device = self.down_proj.weight.device
        if x.device != target_device:
            x = x.to(target_device)

        hidden = self.down_proj(self.layer_norm(x))
        hidden = self.dropout(self.activation(hidden))
        return x + self.up_proj(hidden)

In [24]:
def _wrap_attn_forward(orig_forward, adapter):
    """Return a forward that applies ``adapter`` to element 0 of the attention output tuple."""
    def forward(*args, **kwargs):
        output = orig_forward(*args, **kwargs)
        # Only the first element is adapted; the remaining tuple entries pass through.
        return (adapter(output[0]),) + output[1:]
    return forward


def _wrap_mlp_forward(orig_forward, adapter):
    """Return a forward that applies ``adapter`` to the MLP output tensor."""
    def forward(*args, **kwargs):
        return adapter(orig_forward(*args, **kwargs))
    return forward


def add_adapters_to_model(model, adapter_dim):
    """
    Add adapter layers to a GPT-2 model properly with correct device placement.

    Every existing model parameter is frozen. For each block in
    ``model.transformer.h``, two ``AdapterLayer``s are created on the model's
    device:
      * one applied to element 0 of the ``attn`` sub-module's output tuple;
      * one applied to the ``mlp`` sub-module's output.

    The adapters are registered on ``model`` as ``adapter_attn_{i}`` /
    ``adapter_mlp_{i}``, so they appear in ``model.parameters()`` and
    ``state_dict()``. They are also stored in the plain dict ``model.adapters``.

    Args:
        model: The GPT-2 model (needs ``config.hidden_size`` and
            ``transformer.h`` blocks exposing ``attn`` and ``mlp``)
        adapter_dim: Dimension of the adapter bottleneck

    Returns:
        The same ``model`` object, modified in place, with adapters attached.

    Raises:
        RuntimeError: if adapters were already added to ``model``. Patching
            again would stack a second wrapper on the first. The earlier
            adapters would stay in the forward path while being frozen and
            unregistered.
    """
    # Guard against re-application, e.g. when a notebook cell is re-run on the
    # same in-memory model.
    if any(name.startswith(("adapter_attn_", "adapter_mlp_")) for name, _ in model.named_children()):
        raise RuntimeError(
            "Adapters have already been added to this model; reload the base model "
            "before calling add_adapters_to_model again."
        )

    # Determine the device the model is on
    device = next(model.parameters()).device
    print(f"Model is on device: {device}")

    # Freeze all parameters in the original model (adapters are created afterwards
    # and therefore stay trainable).
    for param in model.parameters():
        param.requires_grad = False

    # Get the hidden size from the model config
    hidden_size = model.config.hidden_size

    # Container that keeps the adapters reachable by name
    if not hasattr(model, 'adapters'):
        model.adapters = {}

    # Add adapters to each transformer block
    for i, block in enumerate(model.transformer.h):
        attn_adapter_name = f"adapter_attn_{i}"
        mlp_adapter_name = f"adapter_mlp_{i}"

        attn_adapter = AdapterLayer(hidden_size, adapter_dim).to(device)
        mlp_adapter = AdapterLayer(hidden_size, adapter_dim).to(device)

        model.adapters[attn_adapter_name] = attn_adapter
        model.adapters[mlp_adapter_name] = mlp_adapter

        # Patch the instance-level forward methods. Device alignment is left to
        # AdapterLayer.forward, which moves mismatched inputs itself, so no
        # per-call check is done here.
        block.attn.forward = _wrap_attn_forward(block.attn.forward, attn_adapter)
        block.mlp.forward = _wrap_mlp_forward(block.mlp.forward, mlp_adapter)

    # Register adapters as proper sub-modules so they are tracked by
    # parameters()/state_dict() (order: attn_0, mlp_0, attn_1, mlp_1, ...).
    for name, adapter in model.adapters.items():
        model.add_module(name, adapter)

    # Count trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())

    print(f"Trainable parameters: {trainable_params:,} ({trainable_params/all_params:.2%} of total)")

    return model

In [25]:
# 1. Custom adapters
adapter_dim = 64  # Bottleneck dimension, typically 1/8 to 1/64 of hidden size

# Seed before creating the adapters: their down-projections are randomly
# initialised here, before the Trainer is constructed and applies
# TrainingArguments.seed. Without this, every run starts from different
# adapter weights.
torch.manual_seed(42)
model = add_adapters_to_model(model, adapter_dim)

Model is on device: cuda:0
Trainable parameters: 6,441,984 (1.78% of total)


In [26]:
max_length = 512

# Create datasets: both splits are carved out of the same filtered
# CodeAlpaca pool, built in train -> validation order.
train_dataset, eval_dataset = (
    CodeAlpacaDataset(tokenizer, max_length=max_length, data_subset=subset)
    for subset in ("train", "validation")
)

# Setup data collator for causal language modelling (mlm=False: no token masking).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

Loaded 6409 Python code examples from CodeAlpaca dataset
Using 5768 examples for train
Loaded 6409 Python code examples from CodeAlpaca dataset
Using 641 examples for validation


In [27]:
# NOTE(review): dead code. This collate_fn is never used: the Trainer cell
# passes `data_collator` directly and builds its own DataLoader. This cell
# can be deleted.
# # Create datasets and move tensors to the correct device in the training loop
# def collate_fn(batch):
#     # Default collation
#     collated = data_collator(batch)
#     # Move to device
#     return {k: v.to(device) for k, v in collated.items()}


In [28]:
# Per-device micro-batch size passed to TrainingArguments. No manual DataLoader
# is needed: the Trainer builds its own from `train_dataset` and `data_collator`.
per_device_train_batch_size = 2

In [29]:
output_dir = "./outputs"
batch_size = 1  # NOTE(review): unused; the micro-batch is per_device_train_batch_size.
num_train_epochs = 3
per_device_eval_batch_size = 2
# The values below match those actually passed to TrainingArguments in the
# recorded run (previously the variables said 4/500/1000 while literals
# 8/250/250 were used).
gradient_accumulation_steps = 8
eval_steps = 250
save_steps = 250  # must be a round multiple of eval_steps with load_best_model_at_end
logging_steps = 100
learning_rate = 5e-4  # Higher learning rate for adapters
weight_decay = 0.01
warmup_steps = 500  # NOTE(review): close to half of the 1080 steps of the recorded run
early_stopping_patience = 3

training_args = TrainingArguments(
    output_dir=os.path.join(output_dir, "checkpoints"),
    overwrite_output_dir=True,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    eval_steps=eval_steps,
    save_steps=save_steps,
    logging_steps=logging_steps,
    save_total_limit=3,
    eval_strategy="steps",
    load_best_model_at_end=True,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    warmup_steps=warmup_steps,
    # fp16 mixed precision is not supported on CPU; only enable it with CUDA so
    # the CPU fallback chosen in the model cell still works.
    fp16=torch.cuda.is_available(),
    lr_scheduler_type="cosine",
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

In [30]:
# Metrics/log/plot tracking, plus early stopping on the eval metric.
# EarlyStoppingCallback depends on load_best_model_at_end and
# metric_for_best_model, both set in training_args.
callbacks = [
    MetricsTrackingCallback(log_dir=os.path.join(output_dir, "logs")),
    EarlyStoppingCallback(early_stopping_patience=early_stopping_patience),
]

In [31]:
# Wire model, data, collator and callbacks into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    callbacks=callbacks,
)

# Train the model; the returned TrainOutput is displayed as the cell's result.
print(f"Starting adapter fine-tuning for {num_train_epochs} epochs...")
trainer.train()

Starting adapter fine-tuning for 3 epochs...


Step,Training Loss,Validation Loss
250,1.1137,0.935739
500,0.9009,0.829503
750,0.8462,0.781986
1000,0.7876,0.7547


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


TrainOutput(global_step=1080, training_loss=1.0035847416630497, metrics={'train_runtime': 1996.0074, 'train_samples_per_second': 8.669, 'train_steps_per_second': 0.541, 'total_flos': 1.6374739463307264e+16, 'train_loss': 1.0035847416630497, 'epoch': 2.9930651872399445})

In [32]:
# Save the fine-tuned model and its tokenizer into one directory.
# NOTE(review): the adapter weights are saved as registered adapter_* modules,
# but the patched attn/mlp forwards are not part of the model class. Reloading
# presumably requires the base model, add_adapters_to_model(...), and then
# loading this state dict; confirm before relying on it.
model_save_path = os.path.join(output_dir, "gpt2-medium_adapter_eps")

trainer.save_model(model_save_path)
tokenizer.save_pretrained(model_save_path)

print(f"Model and tokenizer saved to {model_save_path}")

Model and tokenizer saved to ./outputs/gpt2-medium_adapter_eps
