# Enhanced CodeBERT for Swift Code Understanding (TPU Version)

In this notebook, we fine-tune the [CodeBERT](https://github.com/microsoft/CodeBERT) model on the [Swift Code Intelligence dataset](https://huggingface.co/datasets/mvasiliniuc/iva-swift-codeint) using TPU acceleration. CodeBERT is a pre-trained model designed for programming languages, much as BERT was pre-trained for natural language text. Created by Microsoft Research, CodeBERT is pre-trained on both programming-language and natural-language data, which makes it well suited to code-related tasks.

Unlike the previous version that focused only on identifying Package.swift files, this enhanced version trains the model on the entire dataset by classifying Swift files into meaningful categories based on their purpose in a codebase.

## Overview

The process of fine-tuning CodeBERT involves:

1. **🔧 Setup**: Install necessary libraries and prepare our TPU environment
2. **📥 Data Loading**: Load the Swift code dataset from Hugging Face with detailed logging
3. **🧹 Enhanced Preprocessing**: Prepare the data for training by categorizing files and tokenizing the code samples
4. **🧠 TPU-Accelerated Model Training**: Fine-tune CodeBERT on our prepared data using TPU acceleration
5. **📊 Evaluation**: Assess how well our model performs with comprehensive metrics
6. **📤 Export & Upload**: Save the model and upload it to Dropbox

Let's start by installing the necessary libraries:

In [None]:
# Install required libraries for TPU training
!pip install -q cloud-tpu-client==0.10 torch==1.13.1 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.13-cp38-cp38-linux_x86_64.whl
!pip install -q transformers datasets evaluate scikit-learn tqdm dropbox requests pandas matplotlib

# Log installation process
import logging
import sys
import os
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('codebert_tpu_training.log')
    ]
)
logger = logging.getLogger('codebert-tpu')

logger.info("Starting CodeBERT TPU training notebook setup")
logger.info(f"Notebook execution started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("Required libraries installation completed")

In [None]:
# Import necessary libraries
import os
import json
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import random
import numpy as np
import time
import gc
import re
import collections
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from datasets import load_dataset, ClassLabel
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support, confusion_matrix
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification,
    RobertaForSequenceClassification,
    Trainer, 
    TrainingArguments,
    set_seed,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    get_scheduler
)

# Import AdamW from torch.optim
from torch.optim import AdamW
from transformers.trainer_utils import get_last_checkpoint

# Set a seed for reproducibility
SEED = 42
set_seed(SEED)

# Add memory management function
def cleanup_memory():
    """Force garbage collection and clear XLA cache if available."""
    gc.collect()
    if 'torch_xla' in sys.modules:
        xm.rendezvous('cleanup_memory')
        xm.mark_step()
    logger.info("Memory cleaned up.")

logger.info("Libraries imported successfully")
logger.info(f"Random seed set to: {SEED}")

## TPU Detection and Configuration

Let's detect and configure the TPU environment:

In [None]:
# TPU detection and configuration
def detect_and_configure_tpu():
    """Detect and configure TPU for training."""
    try:
        # Check if TPU is available
        import torch_xla.core.xla_model as xm
        logger.info("TPU libraries imported successfully")
        
        # Get TPU device
        device = xm.xla_device()
        logger.info(f"TPU device detected: {device}")
        
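        # Get the XLA world size (the number of devices taking part in replication).
        # Note: in a single-process notebook this is 1, even if the TPU has more cores.
        tpu_cores = xm.xrt_world_size()
        logger.info(f"Number of TPU cores in use (XLA world size): {tpu_cores}")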
        
        return device, tpu_cores
    except Exception as e:
        logger.error(f"Error detecting TPU: {e}")
        logger.warning("Falling back to CPU. This notebook is designed for TPU and may not work optimally on CPU.")
        return torch.device('cpu'), 1

# Detect and configure TPU
device, tpu_cores = detect_and_configure_tpu()

# Set random seeds for reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

logger.info(f"Using device: {device}")
logger.info(f"Random seeds set for reproducibility: {SEED}")

## Dataset and Model Configuration

Let's define the model and dataset we'll be using with TPU-optimized parameters:

In [None]:
# Dataset configuration
DATASET_ID = "mvasiliniuc/iva-swift-codeint"

# Model configuration - optimized for TPU
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512

# TPU-optimized batch sizes and training parameters
# Adjust batch size based on TPU cores available
BATCH_SIZE = 8 * tpu_cores  # Scale batch size with TPU cores
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 5
WARMUP_STEPS = 500
GRADIENT_ACCUMULATION_STEPS = 1  # Reduced for TPU efficiency

logger.info("Configuration parameters:")
logger.info(f"Dataset: {DATASET_ID}")
logger.info(f"Model: {MODEL_NAME}")
logger.info(f"Max sequence length: {MAX_LENGTH}")
logger.info(f"Batch size: {BATCH_SIZE} (scaled for {tpu_cores} TPU cores)")
logger.info(f"Learning rate: {LEARNING_RATE}")
logger.info(f"Weight decay: {WEIGHT_DECAY}")
logger.info(f"Number of epochs: {NUM_EPOCHS}")
logger.info(f"Warmup steps: {WARMUP_STEPS}")
logger.info(f"Gradient accumulation steps: {GRADIENT_ACCUMULATION_STEPS}")
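
To get a feel for how these parameters interact, here is a minimal sketch of the arithmetic behind a training run. The helper name `schedule_summary` and the dataset size of 10,000 are assumptions made for illustration only; the real example count is logged when the dataset is loaded, and the calculation takes a simplified single-process view with the last incomplete batch dropped.

```python
def schedule_summary(num_examples, batch_size, grad_accum_steps, num_epochs,
                     warmup_steps, drop_last=True):
    """Estimate optimizer updates for one run (simplified, single-process view)."""
    if drop_last:
        batches_per_epoch = num_examples // batch_size
    else:
        batches_per_epoch = -(-num_examples // batch_size)  # ceiling division
    # One optimizer update happens every grad_accum_steps batches
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    total_updates = updates_per_epoch * num_epochs
    return {
        "effective_batch_size": batch_size * grad_accum_steps,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": total_updates,
        "warmup_fraction": warmup_steps / total_updates,
    }

# Hypothetical dataset size, chosen only to make the numbers easy to follow
summary = schedule_summary(
    num_examples=10_000, batch_size=8, grad_accum_steps=1,
    num_epochs=5, warmup_steps=500,
)
print(summary)
```

With these assumed numbers, 500 warmup steps cover 8% of the run. On a much smaller dataset the same warmup setting could take up most of training, so it is worth re-checking the ratio once the real dataset size is known.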

## Data Loading

Now let's load the Swift code dataset and examine its structure with proper error handling and detailed logging:

In [None]:
# Function to load dataset with retry logic and detailed logging
def load_dataset_with_retry(dataset_id, max_retries=3, retry_delay=5):
    """Load a dataset with retry logic and detailed logging."""
    for attempt in range(max_retries):
        try:
            logger.info(f"Loading dataset (attempt {attempt+1}/{max_retries})...")
            start_time = time.time()
            data = load_dataset(dataset_id, trust_remote_code=True)
            load_time = time.time() - start_time
            logger.info(f"Dataset loaded successfully in {load_time:.2f} seconds")
            logger.info(f"Dataset contains {len(data['train'])} examples")
            return data
        except Exception as e:
            logger.error(f"Error loading dataset (attempt {attempt+1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                logger.error("Maximum retries reached. Could not load dataset.")
                raise

# Make sure dataset ID is defined (in case previous cell didn't execute)
if 'DATASET_ID' not in globals():
    logger.warning("DATASET_ID not found. Using default value.")
    DATASET_ID = "mvasiliniuc/iva-swift-codeint"  # Default value as fallback
    MAX_LENGTH = 512  # Keep the fallback consistent with the configuration cell
    MODEL_NAME = "microsoft/codebert-base"
    BATCH_SIZE = 8 * tpu_cores
    GRADIENT_ACCUMULATION_STEPS = 1
    logger.info("Using default configuration values.")

# Load the dataset with retry logic
try:
    logger.info(f"Loading dataset: {DATASET_ID}")
    data = load_dataset_with_retry(DATASET_ID)
    logger.info("Dataset structure:")
    logger.info(str(data))
except Exception as e:
    logger.error(f"Fatal error loading dataset: {e}")
    raise

In [None]:
# Verify dataset structure and column names with detailed logging
def verify_dataset_structure(dataset):
    """Verify that the dataset has the expected structure and columns."""
    required_columns = ['repo_name', 'path', 'content']
    
    logger.info("Verifying dataset structure...")
    
    if 'train' not in dataset:
        logger.warning("Dataset does not have a 'train' split.")
        return False
    
    # Log available columns
    available_columns = dataset['train'].column_names
    logger.info(f"Available columns: {available_columns}")
    
    missing_columns = [col for col in required_columns if col not in available_columns]
    if missing_columns:
        logger.warning(f"Dataset is missing required columns: {missing_columns}")
        return False
    
    logger.info("Dataset structure verification passed.")
    return True

# Verify dataset structure
dataset_valid = verify_dataset_structure(data)
if not dataset_valid:
    logger.warning("Dataset structure is not as expected. Proceeding with caution.")

In [None]:
# Let's take a look at an example from the dataset with detailed logging
try:
    logger.info("Exploring dataset example...")
    
    if 'train' in data:
        example = data['train'][0]
        split_name = 'train'
    else:
        split_name = list(data.keys())[0]
        example = data[split_name][0]
    
    logger.info(f"Example from '{split_name}' split:")
    
    # Log each feature with appropriate truncation for long values
    for key, value in example.items():
        if isinstance(value, str):
            logger.info(f"{key}: {value[:100]}..." if len(value) > 100 else f"{key}: {value}")
            # Additional logging for content analysis
            if key == 'content':
                logger.info(f"Content length: {len(value)} characters")
                logger.info(f"Content lines: {value.count(chr(10))+1}")
        else:
            logger.info(f"{key}: {value}")
    
    # Log dataset statistics
    logger.info(f"Total examples in dataset: {sum(len(data[split]) for split in data)}")
    for split in data:
        logger.info(f"Examples in {split} split: {len(data[split])}")
        
except Exception as e:
    logger.error(f"Error exploring dataset example: {e}")

## Loading the CodeBERT Tokenizer

Now, let's load the CodeBERT tokenizer, which has been specially trained to handle code tokens:

In [None]:
# Load the CodeBERT tokenizer with error handling and detailed logging
try:
    logger.info(f"Loading tokenizer from {MODEL_NAME}...")
    start_time = time.time()
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    
    load_time = time.time() - start_time
    logger.info(f"Tokenizer loaded successfully in {load_time:.2f} seconds")
    logger.info(f"Tokenizer vocabulary size: {len(tokenizer)}")
    logger.info(f"Tokenizer type: {tokenizer.__class__.__name__}")
    
    # Log special tokens
    logger.info(f"Special tokens: {tokenizer.special_tokens_map}")
    logger.info(f"Padding token: {tokenizer.pad_token}, ID: {tokenizer.pad_token_id}")
    logger.info(f"Unknown token: {tokenizer.unk_token}, ID: {tokenizer.unk_token_id}")
    
except Exception as e:
    logger.error(f"Error loading tokenizer: {e}")
    raise

## Enhanced Data Preparation

Instead of focusing only on Package.swift files, we'll create a more meaningful multi-class classification task that categorizes Swift files based on their purpose in a codebase. This approach utilizes the entire dataset and provides more valuable insights into code understanding.

We'll categorize files into the following classes:
1. **Models** - Data structures and model definitions
2. **Views** - UI related files
3. **Controllers** - Application logic
4. **Utilities** - Helper functions and extensions
5. **Tests** - Test files
6. **Configuration** - Package and configuration files
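
The path-based labelling in this section is a keyword scan where the first matching rule wins, so the order of the rules matters. The sketch below uses a hypothetical, much smaller rule table (`RULES` and `first_match` are illustrative names, not part of the notebook) to show the consequence: a test file whose name also contains "model" is labelled by the earlier rule.

```python
# Hypothetical, simplified rule table: (label, keywords), checked in order
RULES = [
    ("Models", ("model", "entity")),
    ("Views", ("view", "screen")),
    ("Tests", ("test", "mock")),
]

def first_match(path, rules=RULES, default="Utilities"):
    """Return the label of the first rule with a keyword found in the path."""
    lowered = path.lower()
    for label, keywords in rules:
        if any(keyword in lowered for keyword in keywords):
            return label
    return default

# Made-up example paths
paths = [
    "Sources/App/UserModel.swift",
    "Tests/AppTests/UserModelTests.swift",
    "Sources/App/String+Trim.swift",
]
labels = {path: first_match(path) for path in paths}
for path, label in labels.items():
    print(f"{path} -> {label}")
```

If test files should always be labelled as Tests, one option is to move the more specific rule earlier in the table. Whether that is desirable depends on how the labels will be used.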

In [None]:
def extract_file_type(path):
    """
    Extract the file type/category based on the file path and naming conventions in Swift projects.
    
    Args:
        path (str): The file path
        
    Returns:
        int: The category label (0-5)
    """
    path_lower = path.lower()
    filename = path.split('/')[-1].lower()
    
    # Category 0: Models - Data structures and model definitions
    if ('model' in path_lower or 
        'struct' in path_lower or 
        'entity' in path_lower or
        ('data' in path_lower and 'class' in path_lower)):
        return 0
    
    # Category 1: Views - UI related files
    # Note: any path containing 'view' matches here, so view controllers
    # (e.g. 'LoginViewController.swift') are labelled as Views, not Controllers.
    elif ('view' in path_lower or 
          'ui' in path_lower or 
          'screen' in path_lower or 
          'page' in path_lower):
        return 1
    
    # Category 2: Controllers - Application logic
    elif ('controller' in path_lower or 
          'manager' in path_lower or 
          'coordinator' in path_lower or
          'service' in path_lower):
        return 2
    
    # Category 3: Utilities - Helper functions and extensions
    elif ('util' in path_lower or 
          'helper' in path_lower or 
          'extension' in path_lower or
          'common' in path_lower):
        return 3
    
    # Category 4: Tests - Test files
    elif ('test' in path_lower or 
          'spec' in path_lower or 
          'mock' in path_lower):
        return 4
    
    # Category 5: Configuration - Package and configuration files
    elif ('package.swift' in path_lower or 
          'config' in path_lower or 
          'settings' in path_lower or
          'info.plist' in path_lower):
        return 5
    
    # Default to category 3 (Utilities) if no clear category is found
    return 3

def analyze_content_for_category(content):
    """
    Analyze file content to help determine its category when path-based classification is ambiguous.
    
    Args:
        content (str): The file content
        
    Returns:
        int: The suggested category based on content analysis
    """
    content_lower = content.lower()
    
    # Check for model patterns
    if (re.search(r'struct\s+\w+', content) or 
        re.search(r'class\s+\w+\s*:\s*\w*codable', content_lower) or
        'encodable' in content_lower or 'decodable' in content_lower):
        return 0
    
    # Check for view patterns
    elif ('uiview' in content_lower or 
          'uitableview' in content_lower or 
          'uicollectionview' in content_lower or
          'swiftui' in content_lower or
          'view {' in content_lower):
        return 1
    
    # Check for controller patterns
    elif ('viewcontroller' in content_lower or 
          'uiviewcontroller' in content_lower or
          'navigationcontroller' in content_lower or
          'viewdidload' in content_lower):
        return 2
    
    # Check for utility patterns
    elif ('extension' in content_lower or 
          ('func ' in content and 'class' not in content_lower[:100]) or
          'protocol' in content_lower):
        return 3
    
    # Check for test patterns
    elif ('xctest' in content_lower or 
          'testcase' in content_lower or
          'func test' in content_lower):
        return 4
    
    # Check for configuration patterns
    elif ('package(' in content_lower or 
          ('dependencies' in content_lower and 'package' in content_lower) or
          ('products' in content_lower and 'targets' in content_lower)):
        return 5
    
    # Default to category 3 (Utilities) if no clear category is found
    return 3

# Process the dataset to add category labels with detailed logging
def process_dataset_with_categories(dataset):
    """
    Process the dataset to add category labels based on file path and content analysis.
    
    Args:
        dataset: The Hugging Face dataset
        
    Returns:
        processed_dataset: The dataset with added category labels
    """
    logger.info("Processing dataset to add category labels...")
    start_time = time.time()
    
    # Define category names for better interpretability
    category_names = [
        "Models",        # 0
        "Views",         # 1
        "Controllers",   # 2
        "Utilities",     # 3
        "Tests",         # 4
        "Configuration"  # 5
    ]
    
    # Function to determine category from path and content
    def determine_category(example):
        path_category = extract_file_type(example['path'])
        content_category = analyze_content_for_category(example['content'])
        
        # Path-based categorization takes priority: the path category is used
        # whether or not the content analysis agrees with it.
        final_category = path_category
        
        return {
            'category': final_category,
            'category_name': category_names[final_category]
        }
    
    # Process each split in the dataset
    processed_dataset = {}
    category_counts = {name: 0 for name in category_names}
    
    for split in dataset:
        logger.info(f"Processing {split} split...")
        
        # Map the determine_category function to each example
        processed_split = dataset[split].map(
            determine_category,
            desc=f"Adding categories to {split} split"
        )
        
        # Count categories for logging (a single pass over the label column)
        split_counts = collections.Counter(processed_split['category'])
        for i, name in enumerate(category_names):
            count = split_counts[i]
            category_counts[name] += count
            logger.info(f"  {name}: {count} files")
        
        processed_dataset[split] = processed_split
    
    # Log overall category distribution
    logger.info("Overall category distribution:")
    total_files = sum(category_counts.values())
    for name, count in category_counts.items():
        share = count / total_files * 100 if total_files else 0.0
        logger.info(f"  {name}: {count} files ({share:.2f}%)")
    
    processing_time = time.time() - start_time
    logger.info(f"Dataset processing completed in {processing_time:.2f} seconds")
    
    return processed_dataset, category_names

# Process the dataset to add category labels
try:
    processed_data, category_names = process_dataset_with_categories(data)
    num_labels = len(category_names)
    logger.info(f"Processed dataset with {num_labels} categories")
    
    # Create class weights for handling imbalanced classes
    if 'train' in processed_data:
        category_counts = collections.Counter(processed_data['train']['category'])
        total_samples = len(processed_data['train'])
        
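        # Calculate class weights inversely proportional to class frequencies.
        # Iterate over every label index so the tensor always has num_labels entries,
        # even if a category has no examples (its weight is then 0.0).
        class_weights = torch.tensor(
            [total_samples / (num_labels * category_counts[i]) if category_counts[i] > 0 else 0.0
             for i in range(num_labels)],
            dtype=torch.float32
        )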
        
        logger.info(f"Class weights calculated: {class_weights}")
        
        # Move class weights to the appropriate device
        class_weights = class_weights.to(device)
except Exception as e:
    logger.error(f"Error processing dataset: {e}")
    raise

## Data Splitting

Now let's split our data into training and validation sets with stratification to maintain label distribution:

In [None]:
from sklearn.model_selection import train_test_split

def split_dataset(dataset, val_size=0.1, seed=42):
    """
    Split a dataset into training and validation sets with stratification.
    
    Args:
        dataset: The dataset to split
        val_size: The proportion of the dataset to include in the validation split
        seed: Random seed for reproducibility
        
    Returns:
        train_dataset, val_dataset: The split datasets
    """
    logger.info(f"Splitting dataset with validation size {val_size} and seed {seed}...")
    
    # If dataset already has train/validation splits, use those
    if 'train' in dataset and 'validation' in dataset:
        logger.info("Dataset already has train/validation splits. Using existing splits.")
        return dataset['train'], dataset['validation']
    
    # If dataset only has train split, create validation split
    if 'train' in dataset and 'validation' not in dataset:
        train_data = dataset['train']
    else:
        # Use the first available split
        split_name = list(dataset.keys())[0]
        train_data = dataset[split_name]
        logger.info(f"Using '{split_name}' split for training data")
    
    # Convert to pandas for easier splitting
    df = train_data.to_pandas()
    
    # Split with stratification by category
    train_df, val_df = train_test_split(
        df, 
        test_size=val_size, 
        random_state=seed,
        stratify=df['category'] if 'category' in df.columns else None
    )
    
    # Log split sizes and category distributions
    logger.info(f"Training set size: {len(train_df)}")
    logger.info(f"Validation set size: {len(val_df)}")
    
    if 'category' in df.columns:
        logger.info("Category distribution in training set:")
        train_cat_dist = train_df['category'].value_counts(normalize=True)
        for cat, prop in train_cat_dist.items():
            logger.info(f"  Category {cat} ({category_names[cat]}): {prop*100:.2f}%")
            
        logger.info("Category distribution in validation set:")
        val_cat_dist = val_df['category'].value_counts(normalize=True)
        for cat, prop in val_cat_dist.items():
            logger.info(f"  Category {cat} ({category_names[cat]}): {prop*100:.2f}%")
    
    # Convert back to Hugging Face datasets
    from datasets import Dataset
    train_dataset = Dataset.from_pandas(train_df)
    val_dataset = Dataset.from_pandas(val_df)
    
    logger.info("Dataset splitting completed successfully")
    return train_dataset, val_dataset

# Split the dataset
try:
    train_data, val_data = split_dataset(processed_data, val_size=0.1, seed=SEED)
    logger.info(f"Split dataset into {len(train_data)} training and {len(val_data)} validation examples")
except Exception as e:
    logger.error(f"Error splitting dataset: {e}")
    raise

## Tokenization

Now let's tokenize our data for the model:

In [None]:
def tokenize_function(examples):
    """
    Tokenize the code content with proper truncation and padding.
    
    Args:
        examples: Batch of examples to tokenize
        
    Returns:
        tokenized_examples: The tokenized examples
    """
    # Tokenize the code content
    tokenized_examples = tokenizer(
        examples['content'],
        padding='max_length',
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )
    
    # Add labels
    tokenized_examples["labels"] = examples["category"]
    
    return tokenized_examples

# Tokenize the datasets with detailed logging
try:
    logger.info("Tokenizing datasets...")
    start_time = time.time()
    
    # Tokenize training data
    logger.info("Tokenizing training data...")
    tokenized_train_data = train_data.map(
        tokenize_function,
        batched=True,
        desc="Tokenizing training data",
        remove_columns=train_data.column_names
    )
    
    # Tokenize validation data
    logger.info("Tokenizing validation data...")
    tokenized_val_data = val_data.map(
        tokenize_function,
        batched=True,
        desc="Tokenizing validation data",
        remove_columns=val_data.column_names
    )
    
    tokenization_time = time.time() - start_time
    logger.info(f"Tokenization completed in {tokenization_time:.2f} seconds")
    logger.info(f"Tokenized {len(tokenized_train_data)} training examples")
    logger.info(f"Tokenized {len(tokenized_val_data)} validation examples")
    
    # Log tokenized data structure
    logger.info(f"Tokenized data features: {tokenized_train_data.column_names}")
    
    # Log token statistics for a sample. Every input is padded to MAX_LENGTH,
    # so count the real (non-padding) tokens via the attention mask.
    sample_lengths = [sum(mask) for mask in tokenized_train_data[:100]['attention_mask']]
    logger.info(f"Average sequence length in sample: {sum(sample_lengths)/len(sample_lengths):.1f} non-padding tokens")
    logger.info(f"Max sequence length in sample: {max(sample_lengths)} non-padding tokens")
    logger.info(f"Min sequence length in sample: {min(sample_lengths)} non-padding tokens")
    
    # Set format for PyTorch
    tokenized_train_data.set_format("torch")
    tokenized_val_data.set_format("torch")
    logger.info("Dataset format set to PyTorch tensors")
    
except Exception as e:
    logger.error(f"Error tokenizing data: {e}")
    raise
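
To make the effect of `padding='max_length'` and `truncation=True` concrete, here is a simplified sketch of what happens to a single sequence of token IDs. The helper `pad_or_truncate` and the pad ID of 0 are assumptions for illustration. The real tokenizer also handles special tokens (for example, it keeps the end-of-sequence token when truncating), which this sketch leaves out.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = list(token_ids[:max_length])          # truncate long inputs
    mask = [1] * len(ids)                       # 1 marks a real token
    padding = max_length - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 marks padding

# A short input is padded, a long input is cut off
short_ids, short_mask = pad_or_truncate([5, 6, 7], max_length=5, pad_id=0)
long_ids, long_mask = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], max_length=5, pad_id=0)
print(short_ids, short_mask)
print(long_ids, long_mask)
```

Because every example ends up with the same length, the number of real tokens in an example is the sum of its attention mask rather than the length of `input_ids`. Fixed shapes like this generally suit TPUs, since XLA compiles a graph per input shape.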

## Model Setup

Now let's load the CodeBERT model and prepare it for TPU training:

In [None]:
try:
    # Load the model with the correct number of labels
    logger.info(f"Loading model {MODEL_NAME} with {num_labels} output classes...")
    start_time = time.time()
    
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, 
        num_labels=num_labels,
        problem_type="single_label_classification"
    )
    
    # Log model architecture
    logger.info(f"Model type: {model.__class__.__name__}")
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Move model to TPU device
    model.to(device)
    logger.info(f"Model moved to device: {device}")
    
    model_load_time = time.time() - start_time
    logger.info(f"Model loaded in {model_load_time:.2f} seconds")
    
except Exception as e:
    logger.error(f"Error loading model: {e}")
    raise

## Class Weights for Imbalanced Data

Let's visualize the class distribution and make sure class weights are available to handle any imbalance in our dataset:

In [None]:
# Visualize class distribution
try:
    logger.info("Visualizing class distribution...")
    
    # Count labels in training data
    label_counts = collections.Counter(tokenized_train_data['labels'].numpy())
    
    # Create a DataFrame for visualization
    label_df = pd.DataFrame({
        'Category': [category_names[i] for i in range(num_labels)],
        'Count': [label_counts.get(i, 0) for i in range(num_labels)]
    })
    
    # Calculate percentages
    total = label_df['Count'].sum()
    label_df['Percentage'] = label_df['Count'] / total * 100
    
    # Log class distribution
    logger.info("Class distribution in training data:")
    for _, row in label_df.iterrows():
        logger.info(f"  {row['Category']}: {row['Count']} examples ({row['Percentage']:.2f}%)")
    
    # Plot class distribution
    plt.figure(figsize=(10, 6))
    bars = plt.bar(label_df['Category'], label_df['Count'])
    plt.title('Class Distribution in Training Data')
    plt.xlabel('Category')
    plt.ylabel('Number of Examples')
    plt.xticks(rotation=45, ha='right')
    
    # Add count labels on top of bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 5,
                f'{int(height):,}',  # Bar heights are floats; show whole counts
                ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig('class_distribution.png')
    logger.info("Class distribution plot saved to 'class_distribution.png'")
    
    # Calculate class weights if not already done
    if 'class_weights' not in globals():
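        # Calculate class weights inversely proportional to class frequencies,
        # with one entry per label (0.0 for a category with no training examples)
        class_weights = torch.tensor(
            [total / (num_labels * label_counts[i]) if label_counts.get(i, 0) > 0 else 0.0
             for i in range(num_labels)],
            dtype=torch.float32
        )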
        # Move class weights to the appropriate device
        class_weights = class_weights.to(device)
        
    logger.info(f"Class weights: {class_weights}")
    
except Exception as e:
    logger.error(f"Error visualizing class distribution: {e}")
    # If visualization fails, ensure we still have class weights
    if 'class_weights' not in globals():
        logger.warning("Using equal class weights due to error")
        class_weights = torch.ones(num_labels, dtype=torch.float32).to(device)

## TPU-Optimized Training Setup

Let's set up the training configuration optimized for TPU:

In [None]:
# Create a custom loss function with class weights for TPU
class TPUWeightedLossTrainer(Trainer):
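    # num_items_in_batch is accepted because newer transformers releases pass it to compute_loss
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):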
        """Custom loss function with class weights for TPU."""
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        # Use class weights in the loss calculation
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        
        # Log batch loss
        if self.state.global_step % 10 == 0:
            logger.info(f"Step {self.state.global_step}: Batch loss = {loss.item():.4f}")
        
        return (loss, outputs) if return_outputs else loss
    
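    def training_step(self, model, inputs, *args, **kwargs):
        """Custom training step with TPU synchronization."""
        # Regular training step; extra arguments (such as num_items_in_batch in
        # newer transformers releases) are forwarded unchanged
        loss = super().training_step(model, inputs, *args, **kwargs)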
        
        # TPU-specific: Mark step for XLA compilation
        if 'torch_xla' in sys.modules:
            xm.mark_step()
        
        return loss

# Define TPU-optimized training arguments
logger.info("Setting up TPU-optimized training arguments...")
training_args = TrainingArguments(
    output_dir="./results_tpu",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE*2,
    warmup_steps=WARMUP_STEPS,
    weight_decay=WEIGHT_DECAY,
    learning_rate=LEARNING_RATE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    logging_dir="./logs_tpu",
    logging_steps=10,  # More frequent logging
    save_total_limit=2,
    # TPU-specific settings
    dataloader_drop_last=True,  # Important for TPU to have fixed size batches
    dataloader_num_workers=4,   # Parallel data loading
    fp16=False,                 # TPU uses bfloat16 instead of fp16
    bf16=True,                  # Use bfloat16 precision for TPU
    report_to="none"
)

# Log training arguments
logger.info("Training arguments:")
for key, value in training_args.to_dict().items():
    logger.info(f"  {key}: {value}")

# Define early stopping callback
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)
logger.info("Early stopping callback configured with patience=3, threshold=0.01")

# Define compute_metrics function for evaluation with detailed logging
def compute_metrics(eval_pred):
    """Compute evaluation metrics with detailed logging."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    
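    # Calculate metrics
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='weighted', zero_division=0)
    acc = accuracy_score(labels, predictions)
    
    # Calculate per-class metrics; passing labels= keeps one entry per category
    # even if a category is absent from the evaluation set
    per_class_precision, per_class_recall, per_class_f1, per_class_support = \
        precision_recall_fscore_support(
            labels, predictions, labels=list(range(num_labels)),
            average=None, zero_division=0)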
    
    # Log detailed metrics
    logger.info("Evaluation metrics:")
    logger.info(f"  Accuracy: {acc:.4f}")
    logger.info(f"  Weighted F1: {f1:.4f}")
    logger.info(f"  Weighted Precision: {precision:.4f}")
    logger.info(f"  Weighted Recall: {recall:.4f}")
    
    # Log per-class metrics
    logger.info("Per-class metrics:")
    for i in range(num_labels):
        logger.info(f"  Class {i} ({category_names[i]}):\n"
                   f"    Precision: {per_class_precision[i]:.4f}\n"
                   f"    Recall: {per_class_recall[i]:.4f}\n"
                   f"    F1: {per_class_f1[i]:.4f}\n"
                   f"    Support: {per_class_support[i]}")
    
    # Calculate and log confusion matrix (one row and column per category)
    cm = confusion_matrix(labels, predictions, labels=list(range(num_labels)))
    logger.info(f"Confusion matrix:\n{cm}")
    
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

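# Create data collator for padding
# Fixed-length padding keeps tensor shapes static, which avoids repeated XLA recompilation on TPU
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding='max_length', max_length=MAX_LENGTH)
logger.info("Data collator created with fixed-length padding")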

# Create trainer with weighted loss for TPU
logger.info("Creating TPU-optimized trainer...")
trainer = TPUWeightedLossTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_val_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback]
)

logger.info("TPU-optimized training setup complete")
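
To make the `patience=3, threshold=0.01` setting above more concrete, here is a minimal, library-free sketch of patience-based early stopping. It is a simplified model and not the exact logic inside `EarlyStoppingCallback`; the function name `first_stop_index` and the score lists are hypothetical and only serve as illustration. The sketch assumes a metric where higher is better, such as F1.

```python
def first_stop_index(scores, patience=3, threshold=0.01):
    """Return the index of the evaluation at which training would stop, or None.

    Simplified model: an evaluation only counts as an improvement when it
    beats the best score so far by more than `threshold`.
    """
    best = None
    stale_evals = 0  # consecutive evaluations without a meaningful improvement
    for index, score in enumerate(scores):
        if best is None or score - best > threshold:
            best = score
            stale_evals = 0
        else:
            stale_evals += 1
            if stale_evals >= patience:
                return index
    return None


# Hypothetical validation F1 scores, one per evaluation
plateau = [0.70, 0.75, 0.755, 0.752, 0.758]
improving = [0.50, 0.60, 0.70, 0.80]

print(first_stop_index(plateau))    # 4: three evaluations in a row gained less than 0.01
print(first_stop_index(improving))  # None: every evaluation improved by more than 0.01
```

With a threshold of 0.01, small fluctuations around a plateau do not reset the patience counter, so training stops once three consecutive evaluations fail to improve meaningfully.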

## TPU-Accelerated Model Training

Now let's train the model using TPU acceleration:

In [None]:
try:
    logger.info("Starting TPU-accelerated training...")
    logger.info(f"Training on device: {device}")
    logger.info(f"Number of TPU cores: {tpu_cores}")
    logger.info(f"Training for {NUM_EPOCHS} epochs with batch size {BATCH_SIZE}")
    
    # Record start time for training
    training_start_time = time.time()
    
    # Train the model
    train_result = trainer.train()
    
    # Calculate training time
    training_time = time.time() - training_start_time
    
    # Log detailed training results
    logger.info(f"Training completed in {training_time:.2f} seconds ({training_time/60:.2f} minutes)")
    logger.info(f"Training loss: {train_result.metrics['train_loss']:.4f}")
    logger.info(f"Training steps: {train_result.global_step}")
    logger.info(f"Steps per second: {train_result.metrics['train_steps_per_second']:.2f}")
    
    # Log all metrics
    logger.info("All training metrics:")
    for key, value in train_result.metrics.items():
        logger.info(f"  {key}: {value}")
    
    # Save the model
    logger.info("Saving the final model...")
    trainer.save_model("./final_model_tpu")
    logger.info("Model saved to ./final_model_tpu")
    
    # Save tokenizer alongside the model
    tokenizer.save_pretrained("./final_model_tpu")
    logger.info("Tokenizer saved with the model")
    
    # Clean up memory
    cleanup_memory()
    logger.info("Memory cleaned up after training")
    
except Exception as e:
    logger.error(f"Error during training: {e}")
    import traceback
    logger.error(f"Traceback: {traceback.format_exc()}")
    raise
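
As a sanity check on the step count logged above, we can roughly estimate how many optimizer steps a run should take. The sketch below assumes the batch size is per TPU core and that incomplete gradient-accumulation groups are dropped; the exact rounding used by `Trainer` may differ between library versions. The function name and all numbers are hypothetical.

```python
import math


def estimate_optimizer_steps(num_examples, per_device_batch_size, num_cores,
                             gradient_accumulation_steps, num_epochs):
    """Rough estimate of the total number of optimizer updates in a run."""
    # Examples consumed by one forward/backward pass across all cores
    global_batch = per_device_batch_size * num_cores
    # Batches needed to see the whole dataset once
    batches_per_epoch = math.ceil(num_examples / global_batch)
    # One optimizer update per group of accumulated batches (at least one)
    updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    return updates_per_epoch * num_epochs


# Hypothetical run: 10,000 examples, batch size 16 per core, 8 cores,
# 2 accumulation steps, 3 epochs
print(estimate_optimizer_steps(10_000, 16, 8, 2, 3))  # 117
```

If the logged number of training steps is far below such an estimate, early stopping probably ended the run before the final epoch.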

## Comprehensive Model Evaluation

Let's evaluate our model on the validation set with detailed metrics:

In [None]:
try:
    logger.info("Starting comprehensive model evaluation...")
    
    # Evaluate the model
    logger.info("Evaluating model on validation set...")
    eval_start_time = time.time()
    eval_results = trainer.evaluate()
    eval_time = time.time() - eval_start_time
    
    # Log evaluation results
    logger.info(f"Evaluation completed in {eval_time:.2f} seconds")
    logger.info("Evaluation results:")
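    for key, value in eval_results.items():
        # Only apply float formatting to float values; log anything else as-is
        formatted = f"{value:.4f}" if isinstance(value, float) else str(value)
        logger.info(f"  {key}: {formatted}")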
    
    # Get predictions for detailed analysis
    logger.info("Generating predictions for detailed analysis...")
    predictions = trainer.predict(tokenized_val_data)
    preds = np.argmax(predictions.predictions, axis=-1)
    labels = predictions.label_ids
    
    # Calculate confusion matrix (fixed label order so the axes match category_names)
    cm = confusion_matrix(labels, preds, labels=list(range(len(category_names))))
    
    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(category_names))
    plt.xticks(tick_marks, category_names, rotation=45, ha='right')
    plt.yticks(tick_marks, category_names)
    
    # Add text annotations to confusion matrix
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                    horizontalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")
    
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()  # Call after setting the labels so they are included in the layout
    plt.savefig('confusion_matrix.png')
    logger.info("Confusion matrix saved to 'confusion_matrix.png'")
    
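    # Calculate per-class metrics; labels= guarantees one entry per category,
    # so the arrays line up with category_names in the DataFrame below
    precision, recall, f1, support = precision_recall_fscore_support(
        labels, preds, labels=list(range(len(category_names))),
        average=None, zero_division=0
    )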
    
    # Create a DataFrame for metrics visualization
    metrics_df = pd.DataFrame({
        'Category': category_names,
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1,
        'Support': support
    })
    
    # Save metrics to CSV
    metrics_df.to_csv('class_metrics.csv', index=False)
    logger.info("Per-class metrics saved to 'class_metrics.csv'")
    
    # Plot F1 scores by class
    plt.figure(figsize=(10, 6))
    bars = plt.bar(metrics_df['Category'], metrics_df['F1 Score'])
    plt.title('F1 Score by Category')
    plt.xlabel('Category')
    plt.ylabel('F1 Score')
    plt.xticks(rotation=45, ha='right')
    plt.ylim(0, 1.0)
    
    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                f'{height:.2f}',
                ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig('f1_scores.png')
    logger.info("F1 scores plot saved to 'f1_scores.png'")
    
    # Log detailed per-class metrics
    logger.info("Detailed per-class metrics:")
    for i, category in enumerate(category_names):
        logger.info(f"  {category}:\n"
                   f"    Precision: {precision[i]:.4f}\n"
                   f"    Recall: {recall[i]:.4f}\n"
                   f"    F1 Score: {f1[i]:.4f}\n"
                   f"    Support: {support[i]}")
    
    # Calculate overall metrics
    overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    overall_accuracy = accuracy_score(labels, preds)
    
    logger.info("Overall metrics:")
    logger.info(f"  Accuracy: {overall_accuracy:.4f}")
    logger.info(f"  Weighted Precision: {overall_precision:.4f}")
    logger.info(f"  Weighted Recall: {overall_recall:.4f}")
    logger.info(f"  Weighted F1 Score: {overall_f1:.4f}")
    
except Exception as e:
    logger.error(f"Error during evaluation: {e}")
    import traceback
    logger.error(f"Traceback: {traceback.format_exc()}")
    # Re-raise: the export cell depends on the overall metrics computed here
    raise
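
The "weighted" scores reported above are averages of the per-class scores, weighted by each class's support (its number of true examples). The following library-free sketch illustrates that relationship on a tiny hypothetical example; the helper names are made up for illustration, and classes with no predictions or no examples are assumed to score 0.

```python
def per_class_scores(y_true, y_pred, num_classes):
    """Return (f1_scores, supports) with one entry per class."""
    f1_scores, supports = [], []
    for cls in range(num_classes):
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == cls and p == cls)
        fp = sum(1 for t, p in pairs if t != cls and p == cls)
        fn = sum(1 for t, p in pairs if t == cls and p != cls)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        f1_scores.append(f1)
        supports.append(tp + fn)  # number of true examples of this class
    return f1_scores, supports


def weighted_average(values, supports):
    """Average the per-class values, weighting each class by its support."""
    total = sum(supports)
    return sum(v * s for v, s in zip(values, supports)) / total if total else 0.0


# Hypothetical labels: class 2 never appears, so its support is 0
y_true = [0, 0, 0, 1]
y_pred = [0, 0, 1, 1]
f1_scores, supports = per_class_scores(y_true, y_pred, num_classes=3)

print(supports)                                         # [3, 1, 0]
print(round(weighted_average(f1_scores, supports), 4))  # 0.7667
```

Because large classes dominate a weighted average, a high weighted F1 can hide poor performance on rare categories, which is why the per-class table is worth checking as well.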

## Model Export and Upload

Let's export our model and upload it to Dropbox:

In [None]:
import dropbox
import os
import zipfile
import json

def zip_model_directory(model_dir, output_zip):
    """Zip the model directory for easier upload."""
    logger.info(f"Zipping model directory {model_dir} to {output_zip}...")
    
    with zipfile.ZipFile(output_zip, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(model_dir):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, os.path.dirname(model_dir))
                logger.info(f"  Adding {arcname} to zip")
                zipf.write(file_path, arcname)
    
    zip_size = os.path.getsize(output_zip) / (1024 * 1024)  # Size in MB
    logger.info(f"Model zipped successfully. Zip size: {zip_size:.2f} MB")
    return output_zip

def upload_to_dropbox(file_path, dropbox_token, dropbox_path):
    """Upload a file to Dropbox."""
    try:
        logger.info(f"Uploading {file_path} to Dropbox path {dropbox_path}...")
        
        # Initialize Dropbox client
        dbx = dropbox.Dropbox(dropbox_token)
        
        # Get file size
        file_size = os.path.getsize(file_path)
        logger.info(f"File size: {file_size / (1024 * 1024):.2f} MB")
        
        # Upload file
        with open(file_path, 'rb') as f:
            if file_size <= 150 * 1024 * 1024:  # 150 MB limit for simple uploads
                logger.info("Using simple upload")
                dbx.files_upload(f.read(), dropbox_path, mode=dropbox.files.WriteMode.overwrite)
            else:
                logger.info("Using chunked upload for large file")
                chunk_size = 4 * 1024 * 1024  # 4 MB chunks
                upload_session_start_result = dbx.files_upload_session_start(f.read(chunk_size))
                cursor = dropbox.files.UploadSessionCursor(
                    session_id=upload_session_start_result.session_id,
                    offset=f.tell()
                )
                commit = dropbox.files.CommitInfo(path=dropbox_path, mode=dropbox.files.WriteMode.overwrite)
                
                while f.tell() < file_size:
                    if (file_size - f.tell()) <= chunk_size:
                        dbx.files_upload_session_finish(f.read(chunk_size), cursor, commit)
                        break
                    else:
                        dbx.files_upload_session_append_v2(f.read(chunk_size), cursor)
                        cursor.offset = f.tell()
                        logger.info(f"Uploaded {cursor.offset / (1024 * 1024):.2f} MB so far")
        
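        # Create a shared link, reusing the existing one if this path was shared before
        try:
            shared_link = dbx.sharing_create_shared_link_with_settings(dropbox_path)
        except dropbox.exceptions.ApiError as link_error:
            if not link_error.error.is_shared_link_already_exists():
                raise
            shared_link = dbx.sharing_list_shared_links(path=dropbox_path, direct_only=True).links[0]
        download_url = shared_link.url.replace('www.dropbox.com', 'dl.dropboxusercontent.com')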
        
        logger.info("Upload successful!")
        logger.info(f"Shared link: {shared_link.url}")
        logger.info(f"Direct download link: {download_url}")
        
        return download_url
    
    except Exception as e:
        logger.error(f"Error uploading to Dropbox: {e}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        return None

# Save model metadata
try:
    logger.info("Saving model metadata...")
    
    # Create metadata dictionary
    metadata = {
        "model_name": "CodeBERT-Swift-TPU",
        "base_model": MODEL_NAME,
        "num_labels": num_labels,
        "categories": category_names,
        "max_length": MAX_LENGTH,
        "training_params": {
            "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE,
            "epochs": NUM_EPOCHS,
            "weight_decay": WEIGHT_DECAY,
            "warmup_steps": WARMUP_STEPS,
            "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
            "tpu_cores": tpu_cores
        },
        "dataset": DATASET_ID,
        "training_examples": len(tokenized_train_data),
        "validation_examples": len(tokenized_val_data),
        "metrics": {
            "accuracy": float(overall_accuracy),
            "f1": float(overall_f1),
            "precision": float(overall_precision),
            "recall": float(overall_recall)
        },
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    
    # Save metadata to file
    with open("./final_model_tpu/metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)
    
    logger.info("Model metadata saved successfully")
    
    # Zip the model directory
    model_zip = zip_model_directory("./final_model_tpu", "codebert_swift_tpu.zip")
    
    # Upload to Dropbox if a token is provided
    # getpass hides the token so it is not echoed into the notebook output
    from getpass import getpass
    dropbox_token = getpass("Enter your Dropbox access token (leave empty to skip upload): ").strip()
    
    if dropbox_token:
        dropbox_path = "/codebert_swift_tpu.zip"
        download_url = upload_to_dropbox(model_zip, dropbox_token, dropbox_path)
        
        if download_url:
            # Save download URL to metadata
            metadata["download_url"] = download_url
            with open("./final_model_tpu/metadata.json", "w") as f:
                json.dump(metadata, f, indent=2)
            logger.info("Metadata updated with download URL")
    else:
        logger.info("Dropbox upload skipped")
    
except Exception as e:
    logger.error(f"Error during model export: {e}")
    import traceback
    logger.error(f"Traceback: {traceback.format_exc()}")
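
Once the model is exported, the `categories` list in `metadata.json` can be used to turn a predicted class index back into a readable category name. Here is a minimal sketch of that lookup. The helper names are hypothetical, and the demo writes a small placeholder metadata file to a temporary directory instead of reading the real `./final_model_tpu/metadata.json`; the placeholder category order is an assumption, not the order used in training.

```python
import json
import tempfile
from pathlib import Path


def load_label_map(metadata_path):
    """Map class indices to category names using the exported metadata."""
    metadata = json.loads(Path(metadata_path).read_text())
    return dict(enumerate(metadata["categories"]))


def describe_prediction(label_map, class_index):
    """Return the category name for a predicted class index."""
    return label_map.get(class_index, f"unknown class {class_index}")


# Demo with a placeholder metadata file (real category order may differ)
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "metadata.json"
    demo_path.write_text(json.dumps({"categories": ["Models", "Views", "Tests"]}))
    label_map = load_label_map(demo_path)

print(describe_prediction(label_map, 1))  # Views
print(describe_prediction(label_map, 7))  # unknown class 7
```

Reading the names from the metadata file rather than hard-coding them keeps the mapping consistent with the label order used during training.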

## Conclusion

We've successfully enhanced the CodeBERT training process to utilize TPU acceleration while maintaining the same functionality as the original notebook. Our model now classifies Swift code files into meaningful categories based on their purpose in a codebase:

1. **Models** - Data structures and model definitions
2. **Views** - UI related files
3. **Controllers** - Application logic
4. **Utilities** - Helper functions and extensions
5. **Tests** - Test files
6. **Configuration** - Package and configuration files

TPU acceleration can shorten training time compared with CPU training and makes larger batch sizes practical. The actual speedup over CPU or GPU training depends on the model size, batch size, and input pipeline, so it is worth measuring for your own workload.

### Key Improvements in this TPU Version:

1. **TPU Optimization**: Configured the training pipeline to fully utilize TPU acceleration
2. **Enhanced Logging**: Added comprehensive logging throughout the notebook for better tracking and debugging
3. **Visualization**: Added plots and visualizations for better understanding of the data and results
4. **Robust Error Handling**: Improved error handling and recovery mechanisms
5. **Detailed Metrics**: Added more detailed evaluation metrics for better model assessment

This notebook can be used as a template for other TPU-accelerated NLP tasks, especially those involving code understanding and classification.