# Qwen3-0.6B LoRA Fine-Tuning on Amazon SageMaker

This notebook demonstrates how to fine-tune Qwen3-0.6B with LoRA on Amazon SageMaker, using a local sample dataset (`samples/train.jsonl`) that is uploaded to S3 before training.

## 1. Setup and Import Libraries

In [None]:
import os
import boto3
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.inputs import TrainingInput
from datetime import datetime

In [None]:
!pip install --upgrade sagemaker # to use pytorch 2.7.1 for training job

## 2. Configure SageMaker Session and Role

In [None]:
# SageMaker session
sagemaker_session = sagemaker.Session()

# IAM execution role for the training job.
# Resolution order (avoids hardcoding an account-specific ARN in the notebook):
#   1. the SAGEMAKER_ROLE_ARN environment variable, if set (useful when running outside SageMaker);
#   2. the role attached to the current SageMaker notebook / Studio environment.
role = os.environ.get("SAGEMAKER_ROLE_ARN")
if not role:
    try:
        role = sagemaker.get_execution_role(sagemaker_session=sagemaker_session)
    except ValueError as err:
        # get_execution_role raises ValueError when the current identity is not an IAM role.
        raise RuntimeError(
            "Could not determine a SageMaker execution role. "
            "Set the SAGEMAKER_ROLE_ARN environment variable to your role ARN "
            "(e.g. arn:aws:iam::<account-id>:role/<role-name>)."
        ) from err
print(f"Using SageMaker execution role: {role}")

# S3 Bucket (default bucket)
bucket = sagemaker_session.default_bucket()
prefix = "qwen3-0-6-lora-samples"

print(f"Using bucket: {bucket}")
print(f"Using prefix: {prefix}")

## 3. Upload Local Sample Data to S3

In [None]:
# Push the local sample file to S3 under <prefix>/data/train.
# For a single file, upload_data returns that object's full S3 URI
# (s3://<bucket>/<prefix>/data/train/train.jsonl), which later feeds the "train" channel.
local_train_path = 'samples/train.jsonl'
train_key_prefix = f'{prefix}/data/train'

print("Uploading train.jsonl data to S3...")
train_s3_uri = sagemaker_session.upload_data(
    path=local_train_path,
    bucket=bucket,
    key_prefix=train_key_prefix,
)
print(f"Training data uploaded to: {train_s3_uri}")

## 4. Configure Training Parameters

In [None]:
# Training configuration
exp_name = 'qwen3-0-6b-lora-fine-tuning'
instance_type = 'ml.g5.2xlarge'

# Job name based on timestamp
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
job_name = f"{exp_name}-lora-{timestamp}"
# SageMaker rejects training job names longer than 63 characters; fail here with a clear
# message rather than at submission time if exp_name is ever lengthened.
if len(job_name) > 63:
    raise ValueError(f"Training job name is {len(job_name)} chars (max 63): {job_name}")

# Output paths
output_path = f"s3://{bucket}/{prefix}/output"
# Checkpoints are scoped per job: SageMaker copies any existing data under checkpoint_s3_uri
# into the container when a job starts, so a prefix shared by all runs would let a new job
# see (and possibly resume from) checkpoints left behind by earlier runs.
checkpoint_s3_uri = f"s3://{bucket}/{prefix}/checkpoints/{job_name}"

print(f"Job name: {job_name}")
print(f"Output path: {output_path}")
print(f"Checkpoint path: {checkpoint_s3_uri}")

## 5. Define Hyperparameters

In [None]:
# Hyperparameters forwarded by SageMaker to src/train.py as command-line arguments.
# Keys are grouped by purpose; their order matches the order in which they are passed.
hyperparameters = dict(
    # --- Base model (Hugging Face Hub id) ---
    model_name_or_path="Qwen/Qwen3-0.6B",

    # --- HuggingFace TrainingArguments ---
    output_dir="/opt/ml/model",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=64,  # effective batch per device = 1 x 64
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    logging_steps=1,
    save_steps=50,
    save_strategy="steps",
    save_total_limit=3,
    do_eval=True,
    eval_strategy="steps",  # newer name for the deprecated `evaluation_strategy`
    eval_steps=50,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=False,  # kept off for LoRA training
    report_to="none",
    bf16=True,
    gradient_checkpointing=True,
    # No DeepSpeed config: single-instance, non-distributed training.

    # --- LoRA adapter ---
    lora_r=4,
    lora_alpha=32,
    lora_dropout=0.1,
    lora_target_modules="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",

    # --- Dataset ---
    train_file="/opt/ml/input/data/train/train.jsonl",  # where the "train" channel lands in the container
    validation_split_percentage=20,  # hold out 20% of train.jsonl for evaluation
    block_size=256,
)

print("Hyperparameters configured successfully")

## 6. Create PyTorch Estimator

In [None]:
# PyTorch Estimator: runs src/train.py in the managed PyTorch 2.7.1 / Python 3.12 training image.
estimator_config = dict(
    # Training code
    entry_point="train.py",
    source_dir="src",
    # Image / runtime
    framework_version="2.7.1",
    py_version="py312",
    environment={
        # Lets the CUDA caching allocator grow segments, reducing fragmentation-related OOMs.
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    },
    # Compute
    role=role,
    instance_type=instance_type,
    instance_count=1,
    use_spot_instances=False,
    volume_size=450,  # EBS volume size in GB
    max_run=24 * 60 * 60,  # hard stop after 24 hours (seconds)
    # Warm pool: keep the instance around for 30 minutes after the job for reuse by a
    # follow-up job. NOTE(review): the retained instance may be billed; lower/remove if unused.
    keep_alive_period_in_seconds=1800,
    # Inputs / outputs
    hyperparameters=hyperparameters,
    output_path=output_path,
    checkpoint_s3_uri=checkpoint_s3_uri,
)
estimator = PyTorch(**estimator_config)

print("PyTorch estimator created successfully")

## 7. Prepare Training Inputs

In [None]:
# "train" channel definition. With S3Prefix, every object whose key starts with train_s3_uri
# is downloaded; FullyReplicated gives the (single) instance the complete dataset.
train_input = TrainingInput(
    train_s3_uri,
    distribution="FullyReplicated",
    s3_data_type="S3Prefix",
    content_type="application/jsonl",
)

print(f"Training input configured with data from: {train_s3_uri}")

## 8. Start Training Job

In [None]:
# Start training job (asynchronous: fit() returns as soon as the job is submitted)
val_pct = hyperparameters['validation_split_percentage']

print(f"Starting training job: {job_name}")
print(f"Training data: {train_s3_uri}")
print(f"Output path: {output_path}")
print(f"Note: The training script will automatically split train.jsonl - {100 - val_pct}% for training, {val_pct}% for validation")

estimator.fit(
    inputs={"train": train_input},
    job_name=job_name,
    wait=False,  # Asynchronous start; call estimator.logs() to stream logs from the notebook
)

print(f"\nTraining job '{job_name}' has been submitted!")
print("You can monitor the job in the SageMaker console")