# Fine-tuning Gemma 3 270M on TPU v5e-8 - Persian QA Dataset

This notebook fine-tunes Gemma 3 270M model on TPU v5e-8 using the Persian QA translated dataset.

## Features:
- **TPU v5e-8 Support**: Optimized for Google Cloud TPU v5e-8
- **JAX/Flax Backend**: Uses JAX for TPU acceleration
- **Memory Efficient**: Optimized batch sizes and gradient accumulation
- **LoRA Fine-tuning**: Parameter-efficient fine-tuning
- **Comprehensive Logging**: Training metrics and progress tracking
- **Model Evaluation**: Test and evaluation capabilities

## Prerequisites:
- TPU v5e-8 instance
- JAX, Flax, and Optax installed
- Transformers library with TPU support

In [1]:
# Check TPU availability and setup
import os
import sys

# Verify TPU connection
try:
    import jax
    print(f"JAX version: {jax.__version__}")
    print(f"JAX devices: {jax.devices()}")
    print(f"TPU cores available: {jax.device_count()}")
    print(f"TPU backend: {jax.default_backend()}")

    if jax.default_backend() != 'tpu':
        print("\n⚠️  WARNING: TPU not detected! Current backend:", jax.default_backend())
        print("Please ensure you're connected to a TPU instance.")
    else:
        print("\n✅ TPU connection successful!")

except ImportError:
    print("❌ JAX not installed. Installing required packages...")
    print("Run: pip install jax[tpu] flax optax -f https://storage.googleapis.com/jax-releases/libtpu_releases.html")



JAX version: 0.5.3
JAX devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
TPU cores available: 1
TPU backend: tpu

✅ TPU connection successful!


## Configuration

In [2]:
# Install required packages for TPU
!pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q flax optax
!pip install -q transformers datasets accelerate peft
!pip install -q sentencepiece protobuf
!pip install -q tensorboard

print("All packages installed successfully!")

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/75.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.1/75.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m503.6/503.6 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m504.9/504.9 kB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.7/119.7 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m146.7/146.7 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.9/193.9 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m76.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Import Libraries

In [3]:
import os
import sys
import logging
import time
from datetime import datetime
from pathlib import Path

# JAX and Flax for TPU
import jax
import jax.numpy as jnp

# Transformers and Datasets
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    TrainerCallback,
    TrainerState,
    TrainerControl
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType

# Utilities
import numpy as np
from tqdm.auto import tqdm
import torch

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'tpu_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

logger.info("All libraries imported successfully!")

  * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or


## Configuration for TPU v5e-8

## Custom Training Callbacks for Detailed Logging

In [4]:
class DetailedLoggingCallback(TrainerCallback):
    """
    Custom callback for detailed logging during training on TPU.
    Provides comprehensive metrics tracking and progress updates.
    """

    def __init__(self, log_every_n_steps=50):
        self.log_every_n_steps = log_every_n_steps
        self.start_time = None
        self.last_log_time = None
        self.step_times = []

    def on_train_begin(self, args, state, control, **kwargs):
        """Called when training begins"""
        self.start_time = time.time()
        self.last_log_time = self.start_time

        logger.info("="*80)
        logger.info("🚀 TRAINING STARTED ON TPU v5e-8")
        logger.info("="*80)
        logger.info(f"Output directory: {args.output_dir}")
        logger.info(f"Total training steps: {state.max_steps}")
        logger.info(f"Batch size per device: {args.per_device_train_batch_size}")
        logger.info(f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
        logger.info(f"Effective batch size: {args.per_device_train_batch_size * args.gradient_accumulation_steps * args.world_size}")
        logger.info(f"Number of epochs: {args.num_train_epochs}")
        logger.info(f"Learning rate: {args.learning_rate}")
        logger.info(f"Save steps: {args.save_steps}")
        logger.info(f"Eval steps: {args.eval_steps}")
        logger.info("="*80)

    def on_step_end(self, args, state, control, **kwargs):
        """Called at the end of each training step"""
        if state.global_step % self.log_every_n_steps == 0:
            current_time = time.time()
            elapsed = current_time - self.start_time
            step_elapsed = current_time - self.last_log_time

            # Track step times for averaging
            self.step_times.append(step_elapsed / self.log_every_n_steps)
            if len(self.step_times) > 10:
                self.step_times.pop(0)

            avg_step_time = sum(self.step_times) / len(self.step_times)

            # Calculate ETA
            if state.global_step > 0:
                remaining_steps = state.max_steps - state.global_step
                eta = remaining_steps * avg_step_time
                eta_str = f"{eta / 3600:.1f}h" if eta > 3600 else f"{eta / 60:.1f}m"
            else:
                eta_str = "N/A"

            # Get latest metrics safely
            loss_str = "N/A"
            lr_str = "N/A"

            if state.log_history:
                latest_log = state.log_history[-1]
                loss = latest_log.get('loss', None)
                lr = latest_log.get('learning_rate', None)

                if loss is not None and isinstance(loss, (int, float)):
                    loss_str = f"{loss:.4f}"
                if lr is not None and isinstance(lr, (int, float)):
                    lr_str = f"{lr:.2e}"

            # Calculate progress
            progress = (state.global_step / state.max_steps) * 100

            logger.info("─"*80)
            logger.info(f"Step {state.global_step}/{state.max_steps} ({progress:.1f}%)")
            logger.info(f"Loss: {loss_str}")
            logger.info(f"Learning Rate: {lr_str}")
            logger.info(f" Elapsed: {elapsed/60:.1f}m | Step time: {avg_step_time:.2f}s | ETA: {eta_str}")
            logger.info(f"Epoch: {state.epoch:.2f}")

            self.last_log_time = current_time

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Called after evaluation"""
        if metrics:
            logger.info("="*80)
            logger.info(f"EVALUATION at Step {state.global_step}")
            logger.info("="*80)

            for key, value in sorted(metrics.items()):
                if isinstance(value, (int, float)):
                    logger.info(f"  {key}: {value:.4f}")
                else:
                    logger.info(f"  {key}: {value}")

            logger.info("="*80)

    def on_save(self, args, state, control, **kwargs):
        """Called when a checkpoint is saved"""
        logger.info(f"💾 Checkpoint saved at step {state.global_step}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Called when logging occurs"""
        if logs and state.global_step % self.log_every_n_steps == 0:
            # Additional metrics logging
            if 'grad_norm' in logs:
                logger.info(f"  Gradient norm: {logs['grad_norm']:.4f}")

    def on_train_end(self, args, state, control, **kwargs):
        """Called when training ends"""
        total_time = time.time() - self.start_time

        logger.info("="*80)
        logger.info("TRAINING COMPLETED!")
        logger.info("="*80)
        logger.info(f"  Total training time: {total_time/3600:.2f} hours")
        logger.info(f"Final step: {state.global_step}")
        logger.info(f"Best metric: {getattr(state, 'best_metric', 'N/A')}")
        logger.info(f"Best model checkpoint: {getattr(state, 'best_model_checkpoint', 'N/A')}")
        logger.info("="*80)


class EarlyStoppingCallback(TrainerCallback):
    """
    Early stopping callback to prevent overfitting.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_metric = None
        self.patience_counter = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Check if we should stop early"""
        if metrics is None:
            return

        current_metric = metrics.get('eval_loss')
        if current_metric is None:
            return

        if self.best_metric is None:
            self.best_metric = current_metric
            logger.info(f"Initial best metric: {self.best_metric:.4f}")
            return

        # Check if improvement
        if current_metric < (self.best_metric - self.min_delta):
            improvement = self.best_metric - current_metric
            self.best_metric = current_metric
            self.patience_counter = 0
            logger.info(f"Improved! New best: {self.best_metric:.4f} (↓ {improvement:.4f})")
        else:
            self.patience_counter += 1
            logger.info(f"  No improvement. Patience: {self.patience_counter}/{self.patience}")

            if self.patience_counter >= self.patience:
                logger.info(f"Early stopping triggered! No improvement for {self.patience} evaluations.")
                control.should_training_stop = True

logger.info("Custom callbacks defined!")

## Utility Functions

In [5]:
def set_seed(seed=42):
    """Set random seeds for reproducibility"""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logger.info(f"Random seed set to: {seed}")

def log_system_info():
    """Log system and TPU information"""
    logger.info("="*80)
    logger.info("SYSTEM INFORMATION")
    logger.info("="*80)

    # Python version
    logger.info(f"Python version: {sys.version}")

    # PyTorch info
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

    # JAX info
    try:
        logger.info(f"JAX version: {jax.__version__}")
        logger.info(f"JAX backend: {jax.default_backend()}")
        logger.info(f"JAX devices: {jax.device_count()}")

        if jax.default_backend() == 'tpu':
            for i, device in enumerate(jax.devices()):
                logger.info(f"  TPU Device {i}: {device}")
    except:
        logger.warning("JAX not available or not configured")

    # Disk space
    import shutil
    total, used, free = shutil.disk_usage("/")
    logger.info(f"Disk space: {free // (2**30)} GB free of {total // (2**30)} GB")

    logger.info("="*80)

def print_model_info(model):
    """Print detailed model information"""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("\n" + "="*80)
    print("🤖 MODEL INFORMATION")
    print("="*80)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable %: {100 * trainable_params / total_params:.4f}%")
    print(f"Model dtype: {model.dtype}")
    print("="*80)

def compute_metrics(eval_pred):
    """Compute metrics for evaluation"""
    predictions, labels = eval_pred

    # Compute perplexity from loss
    predictions = np.array(predictions[0]) if isinstance(predictions, tuple) else np.array(predictions)
    labels = np.array(labels)

    # Calculate accuracy (for tokens)
    mask = labels != -100
    correct = (predictions.argmax(-1) == labels) & mask
    accuracy = correct.sum() / mask.sum() if mask.sum() > 0 else 0

    return {
        "accuracy": accuracy,
    }

logger.info("Utility functions defined!")

In [6]:
# Log system information
log_system_info()

# Set random seed for reproducibility
set_seed(42)

## HuggingFace Authentication (Required for Gemma Models)

In [7]:
!export HF_TOKEN=

In [8]:
from huggingface_hub import login
import os

# Check if token is in environment
hf_token = os.environ.get('HF_TOKEN', None)

if hf_token:
    print("✅ Token found in environment variable")
    login(token=hf_token)
    logger.info("✅ Logged in to HuggingFace using environment token")
else:
    print("⚠️  No token found in environment")
    print("Enter your HuggingFace token (or press Enter to use interactive login):")

    # Option to enter token manually
    try:
        from getpass import getpass
        manual_token = getpass("Token: ")
        if manual_token:
            login(token=manual_token)
            logger.info("✅ Logged in to HuggingFace using manual token")
        else:
            # Interactive login
            login()
            logger.info("✅ Logged in to HuggingFace interactively")
    except:
        # Fallback to interactive
        login()
        logger.info("✅ Logged in to HuggingFace interactively")

print("✅ HuggingFace authentication successful!")

⚠️  No token found in environment
Enter your HuggingFace token (or press Enter to use interactive login):
Token: ··········
✅ HuggingFace authentication successful!


In [15]:
class TPUConfig:
    """Configuration optimized for TPU v5e-8"""

    # Model settings
    model_name = "google/gemma-3-270m"
    dataset_name = "pourmand1376/persian-qa-translated"

    # TPU v5e-8 has 8 cores
    num_tpu_cores = 8

    # Training settings optimized for TPU
    # TPU works best with larger batch sizes
    per_device_train_batch_size = 4  # Per TPU core
    per_device_eval_batch_size = 4
    gradient_accumulation_steps = 4  # Effective batch size = 4 * 8 * 4 = 128

    # Training parameters
    num_train_epochs = 3
    learning_rate = 2e-4
    warmup_ratio = 0.03
    max_grad_norm = 1.0
    lr_scheduler_type = "cosine"

    # Sequence lengths
    max_input_length = 512
    max_target_length = 64

    # LoRA configuration
    lora_r = 8
    lora_alpha = 16  # Typically 2x lora_r
    lora_dropout = 0.05
    lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    # Logging and saving
    logging_steps = 50
    eval_steps = 200
    save_steps = 400
    save_total_limit = 3

    # Directories
    output_dir = "./outputs_tpu"
    logging_dir = "./logs_tpu"
    cache_dir = "./cache"

    # Dataset
    val_set_size = 0.1
    seed = 42
    dataset_size_limit = None # Set to an integer to limit training data size

config = TPUConfig()

# Create directories
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.logging_dir, exist_ok=True)
os.makedirs(config.cache_dir, exist_ok=True)

logger.info("Configuration initialized for TPU v5e-8")
logger.info(f"Effective batch size: {config.per_device_train_batch_size * config.num_tpu_cores * config.gradient_accumulation_steps}")

## Load Model and Tokenizer

In [10]:
logger.info("Loading tokenizer and model...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    config.model_name,
    cache_dir=config.cache_dir,
    trust_remote_code=True
)

# Set padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

logger.info(f"Tokenizer loaded: {config.model_name}")
logger.info(f"Vocab size: {len(tokenizer):,}")
logger.info(f"PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
logger.info(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
logger.info(f"BOS token: {tokenizer.bos_token} (ID: {tokenizer.bos_token_id})")

# Load model - For TPU, we use bfloat16 precision
logger.info("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    cache_dir=config.cache_dir,
    torch_dtype=torch.bfloat16,  # TPU v5e works best with bfloat16
    device_map=None,  # We'll handle device placement manually for TPU
    trust_remote_code=True,
    attn_implementation="eager"  # Recommended for Gemma 3
)

logger.info(f"Model loaded: {config.model_name}")

# Print model information
print_model_info(model)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.35k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/536M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/133 [00:00<?, ?B/s]


🤖 MODEL INFORMATION
Total parameters: 268,098,176
Trainable parameters: 268,098,176
Trainable %: 100.0000%
Model dtype: torch.bfloat16


## Apply LoRA Configuration

In [11]:
logger.info("Applying LoRA configuration...")

# Configure LoRA
lora_config = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    target_modules=config.lora_target_modules,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# Apply LoRA to model
model = get_peft_model(model, lora_config)

logger.info("LoRA applied successfully!")

# Print updated model information
print_model_info(model)

# Print LoRA configuration details
print("\n" + "="*80)
print(" LoRA CONFIGURATION")
print("="*80)
print(f"Rank (r): {config.lora_r}")
print(f"Alpha: {config.lora_alpha}")
print(f"Dropout: {config.lora_dropout}")
print(f"Target modules: {', '.join(config.lora_target_modules)}")
print(f"Bias: none")
print(f"Task type: CAUSAL_LM")
print("="*80)


🤖 MODEL INFORMATION
Total parameters: 268,835,456
Trainable parameters: 737,280
Trainable %: 0.2742%
Model dtype: torch.bfloat16

 LoRA CONFIGURATION
Rank (r): 8
Alpha: 16
Dropout: 0.05
Target modules: q_proj, k_proj, v_proj, o_proj
Bias: none
Task type: CAUSAL_LM


## Load and Prepare Dataset

In [12]:
logger.info("Loading dataset...")

# Load dataset
dataset = load_dataset(
    config.dataset_name,
    cache_dir=config.cache_dir
)

logger.info(f" Dataset loaded: {config.dataset_name}")
logger.info(f"Dataset keys: {list(dataset.keys())}")

# Check if 'train' key exists, if not use the first available key
if 'train' not in dataset:
    dataset_key = list(dataset.keys())[0]
    logger.info(f"'train' key not found, using '{dataset_key}' instead")
    dataset = {'train': dataset[dataset_key]}

logger.info(f"Train samples: {len(dataset['train']):,}")

# Log available columns
logger.info(f"Available columns: {dataset['train'].column_names}")

# Show sample to understand structure
print("\n" + "="*80)
print("📋 Sample Dataset Entry:")
print("="*80)
sample = dataset['train'][0]
for key, value in sample.items():
    display_value = str(value)[:100] + "..." if len(str(value)) > 100 else str(value)
    print(f"{key}: {display_value}")
print("="*80)

# Filter out samples that are too long or have missing fields
def filter_examples(example):
    """Filter out invalid examples - using Persian columns"""
    # For Persian QA dataset, use original_instruction and original_output
    instruction = example.get('original_instruction', '').strip()
    output = example.get('original_output', '').strip()

    # Must have both instruction and output in Persian
    return len(instruction) > 0 and len(output) > 0

logger.info("Filtering invalid examples...")
initial_count = len(dataset['train'])
dataset = dataset.filter(filter_examples)
filtered_count = len(dataset['train'])
logger.info(f"After filtering - Train samples: {filtered_count:,} (removed {initial_count - filtered_count:,})")

if filtered_count == 0:
    logger.error("  Dataset is empty after filtering!")
    raise ValueError("Dataset is empty. Please check the dataset structure and column names.")



# Format the dataset for causal language modeling using Persian columns
def format_instruction(example):
    """Format examples as instruction-following prompts using Persian text"""
    # Use Persian columns
    instruction = example.get('original_instruction', '').strip()
    output = example.get('original_output', '').strip()

    # Format as instruction-following
    text = f"### دستور:\n{instruction}\n\n### پاسخ:\n{output}{tokenizer.eos_token}"
    return {"text": text}

logger.info("Formatting dataset with Persian columns...")
dataset = dataset.map(
    format_instruction,
    remove_columns=dataset['train'].column_names,
    desc="Formatting Persian instructions"
)

# Apply dataset size limiting if specified
if config.dataset_size_limit is not None and isinstance(config.dataset_size_limit, int) and config.dataset_size_limit > 0:
    if config.dataset_size_limit >= len(dataset['train']):
        logger.info(f"Dataset size limit ({config.dataset_size_limit}) >= total dataset size ({len(dataset['train'])}), using full dataset.")
    else:
        logger.info(f"Limiting dataset to {config.dataset_size_limit} samples (shuffling first)...")
        # Shuffle the dataset before taking the subset
        dataset['train'] = dataset['train'].shuffle(seed=config.seed)
        dataset['train'] = dataset['train'].select(range(config.dataset_size_limit))
        logger.info(f"Dataset limited to {len(dataset['train'])} samples.")


# Split into train and validation
logger.info("Splitting dataset...")
dataset = dataset['train'].train_test_split(
    test_size=config.val_set_size,
    seed=config.seed
)

logger.info(f" Final dataset split:")
logger.info(f"  Train: {len(dataset['train']):,}")
logger.info(f"  Validation: {len(dataset['test']):,}")

# Show formatted example
print("\n" + "="*80)
print("📝 Example Formatted Persian Text:")
print("="*80)
print(dataset['train'][0]['text'][:500])
if len(dataset['train'][0]['text']) > 500:
    print("... (truncated)")
print("="*80)

README.md:   0%|          | 0.00/836 [00:00<?, ?B/s]

data/train-00000-of-00001-76038f702094e6(…):   0%|          | 0.00/187M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/153127 [00:00<?, ? examples/s]


📋 Sample Dataset Entry:
input: None
instruction: چگونه می توانم نقاط داغ و تاول ها را درمان کنم وقتی که موسکین ندارم؟ چند بار بیرون پیاده روی کرده ام...
original_instruction: How do I treat hot spots and blisters when I have no moleskin?
A few times I've been out walking or ...
original_output: The key is reducing friction. Duct tape can be a good preventative as long as you get it on before a...
output: کلید کاهش اصطکاک است. نوار چسب می تواند یک پیشگیری خوب باشد تا زمانی که شما آن را قبل از ایجاد تاول ...
source: stackexchange-outdoors


Filter:   0%|          | 0/153127 [00:00<?, ? examples/s]

Formatting Persian instructions:   0%|          | 0/153127 [00:00<?, ? examples/s]


📝 Example Formatted Persian Text:
### دستور:
Is it healthy to be really petite
My gf is 1.55 m (5'1") tall and weights 40 kg (88 lbs).

Considering everything else to be at normal parameters, is it healthy to be really this tiny?

### پاسخ:
There are health risks associated with being underweight. The NIH BMI calculator returns 16.6 for a 5'1" 88 lb person and says anything under 18.5 is underweight. There is less information available about the risks of a low BMI than a high one, but I found a page from the NHS in the UK. It su
... (truncated)


## Tokenize Dataset

In [13]:
logger.info("Tokenizing dataset...")

def tokenize_function(examples):
    """Tokenize the texts"""
    outputs = tokenizer(
        examples["text"],
        truncation=True,
        max_length=config.max_input_length,
        padding="max_length",
        return_tensors=None,
    )
    outputs["labels"] = outputs["input_ids"].copy()
    return outputs

# Tokenize datasets
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing dataset",
)

logger.info("Tokenization complete!")
logger.info(f"Train samples: {len(tokenized_dataset['train']):,}")
logger.info(f"Validation samples: {len(tokenized_dataset['test']):,}")

Tokenizing dataset:   0%|          | 0/137814 [00:00<?, ? examples/s]

Tokenizing dataset:   0%|          | 0/15313 [00:00<?, ? examples/s]

## Setup Training Arguments for TPU

In [16]:
logger.info("Setting up training arguments for TPU...")

training_args = TrainingArguments(
    # Output and logging
    output_dir=config.output_dir,
    logging_dir=config.logging_dir,
    logging_steps=config.logging_steps,

    # Training parameters
    num_train_epochs=config.num_train_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,

    # Optimization
    learning_rate=config.learning_rate,
    warmup_ratio=config.warmup_ratio,
    max_grad_norm=config.max_grad_norm,
    lr_scheduler_type=config.lr_scheduler_type,
    optim="adafactor",  # Adafactor works well on TPU

    # Precision - TPU v5e supports bfloat16
    bf16=True,  # Use bfloat16 for TPU
    fp16=False,

    # Evaluation
    eval_strategy="steps",
    eval_steps=config.eval_steps,

    # Saving
    save_strategy="steps",
    save_steps=config.save_steps,
    save_total_limit=config.save_total_limit,

    # Performance
    dataloader_num_workers=0,  # TPU doesn't need multiple workers
    group_by_length=False,  # Can cause issues on TPU
    dataloader_pin_memory=False,

    # Reporting
    report_to="tensorboard",

    # TPU specific
    tpu_num_cores=config.num_tpu_cores,
    dataloader_drop_last=True,  # Recommended for TPU

    # Misc
    seed=config.seed,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

logger.info("Training arguments configured!")
print("\n" + "="*50)
print("Training Configuration:")
print("="*50)
print(f"TPU cores: {config.num_tpu_cores}")
print(f"Per device batch size: {config.per_device_train_batch_size}")
print(f"Gradient accumulation: {config.gradient_accumulation_steps}")
print(f"Effective batch size: {config.per_device_train_batch_size * config.num_tpu_cores * config.gradient_accumulation_steps}")
print(f"Learning rate: {config.learning_rate}")
print(f"Epochs: {config.num_train_epochs}")
print(f"Precision: bfloat16")
print("="*50)


Training Configuration:
TPU cores: 8
Per device batch size: 4
Gradient accumulation: 4
Effective batch size: 128
Learning rate: 0.0002
Epochs: 3
Precision: bfloat16


## Initialize Data Collator

In [17]:
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # We're doing causal language modeling, not masked LM
)

logger.info("Data collator initialized")

## Initialize Trainer

In [18]:
logger.info("Initializing Trainer with custom callbacks...")

# Initialize callbacks
detailed_logger = DetailedLoggingCallback(log_every_n_steps=config.logging_steps)
early_stopping = EarlyStoppingCallback(
    patience=3,
    min_delta=0.001
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
    callbacks=[detailed_logger, early_stopping],
)

logger.info("  Trainer initialized successfully!")

# Calculate and display training info
steps_per_epoch = len(trainer.get_train_dataloader()) // config.gradient_accumulation_steps
total_steps = steps_per_epoch * config.num_train_epochs

print("\n" + "="*80)
print("  Training Information:")
print("="*80)
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total training steps: {total_steps}")
print(f"Evaluation every: {config.eval_steps} steps")
print(f"Save checkpoint every: {config.save_steps} steps")
print("="*80)


  Training Information:
Steps per epoch: 8613
Total training steps: 25839
Evaluation every: 200 steps
Save checkpoint every: 400 steps


## Start Training on TPU

**⚠️ Important**: Before running training, ensure:
1. You're connected to the TPU v5e-8 instance
2. All previous cells have been executed successfully
3. You have sufficient disk space for checkpoints

This will take several hours depending on the dataset size.

In [None]:
logger.info("="*70)
logger.info("STARTING TRAINING ON TPU")
logger.info("="*70)

try:
    # Start training
    train_result = trainer.train()

    logger.info("="*70)
    logger.info("TRAINING COMPLETED SUCCESSFULLY!")
    logger.info("="*70)

    # Print training results
    print("\n" + "="*50)
    print("Training Results:")
    print("="*50)
    for key, value in train_result.metrics.items():
        print(f"{key}: {value}")
    print("="*50)

    # Save final model
    final_model_path = f"{config.output_dir}/final_model"
    trainer.save_model(final_model_path)
    logger.info(f"Final model saved to: {final_model_path}")

    # Save tokenizer
    tokenizer.save_pretrained(final_model_path)
    logger.info(f"Tokenizer saved to: {final_model_path}")

except Exception as e:
    logger.error(f"Training failed with error: {str(e)}")
    raise e



Step,Training Loss,Validation Loss
200,3.4,3.221632
400,3.5,3.494632
600,5.4,6.616061


config.json:   0%|          | 0.00/1.35k [00:00<?, ?B/s]



Step,Training Loss,Validation Loss
200,3.4,3.221632
400,3.5,3.494632
600,5.4,6.616061


## Evaluate Model

In [None]:
logger.info("Running final evaluation...")

eval_results = trainer.evaluate()

print("\n" + "="*50)
print("Evaluation Results:")
print("="*50)
for key, value in eval_results.items():
    if isinstance(value, float):
        print(f"{key}: {value:.4f}")
    else:
        print(f"{key}: {value}")
print("="*50)

logger.info("Evaluation complete!")

## Test Inference

In [None]:
logger.info("Testing model inference...")

def generate_response(question, max_length=256):
    """Generate response for a given question"""
    prompt = f"### سوال:\n{question}\n\n### پاسخ:\n"

    inputs = tokenizer(prompt, return_tensors="pt")

    # Move to appropriate device
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract answer part
    if "### پاسخ:" in response:
        answer = response.split("### پاسخ:")[1].strip()
    else:
        answer = response

    return answer

# Test with sample questions
test_questions = [
    "پایتخت ایران کجاست؟",
    "چگونه می‌توان یک مدل زبانی را آموزش داد؟",
    "تفاوت بین هوش مصنوعی و یادگیری ماشین چیست؟"
]

print("\n" + "="*70)
print("Testing Model Inference:")
print("="*70)

for i, question in enumerate(test_questions, 1):
    print(f"\n[Test {i}]")
    print(f"سوال: {question}")
    answer = generate_response(question)
    print(f"پاسخ: {answer}")
    print("-"*70)

logger.info("Inference testing complete!")

## Save and Export Model for Deployment

In [None]:
# Save LoRA adapters separately (smaller file size)
lora_output_dir = f"{config.output_dir}/lora_adapters"
os.makedirs(lora_output_dir, exist_ok=True)

model.save_pretrained(lora_output_dir)
tokenizer.save_pretrained(lora_output_dir)

logger.info(f"LoRA adapters saved to: {lora_output_dir}")

# Print file sizes
import glob
for file_path in glob.glob(f"{lora_output_dir}/*"):
    size_mb = os.path.getsize(file_path) / (1024 * 1024)
    print(f"{os.path.basename(file_path)}: {size_mb:.2f} MB")

print("\n" + "="*70)
print("Model Training and Evaluation Complete!")
print("="*70)
print(f"Final model: {final_model_path}")
print(f"LoRA adapters: {lora_output_dir}")
print(f"Logs: {config.logging_dir}")
print("="*70)

## TPU Connection Verification (Run First)

Before starting the full training, run this cell to verify your TPU connection.