# In [None]:
#!/usr/bin/env python3
"""
Medical LLM Fine-tuning Pipeline
Complete end-to-end fine-tuning setup for medical/clinical applications
Supports CUDA acceleration and Hugging Face datasets
"""

import os
import torch
import json
import logging
from datetime import datetime
from typing import Dict, List, Optional
import wandb
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
import bitsandbytes as bnb
from accelerate import Accelerator

# Configure logging: emit INFO-level (and above) records through the root
# logger, and create the module-level logger used by the code in this file
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# In [None]:
class MedicalLLMFineTuner:
    def __init__(self, config: Dict):
        """
        Build the fine-tuner and prepare its output folders.

        Stores the configuration, creates an ``Accelerator`` to pick the
        device, makes sure the output and log directories exist, and logs
        the selected device (plus GPU name and memory when CUDA is available).

        Args:
            config: Configuration dictionary. ``config['output_dir']`` is
                required here; the remaining keys are read by other methods.
        """
        self.config = config
        self.accelerator = Accelerator()
        self.device = self.accelerator.device

        # Output root first, then the logs folder nested inside it
        self.output_dir = config['output_dir']
        self.logs_dir = os.path.join(self.output_dir, 'logs')
        for directory in (self.output_dir, self.logs_dir):
            os.makedirs(directory, exist_ok=True)

        cuda_ready = torch.cuda.is_available()
        logger.info(f"Using device: {self.device}")
        logger.info(f"CUDA available: {cuda_ready}")
        if not cuda_ready:
            return

        # GPU details are only queried when a CUDA device is present
        logger.info(f"GPU: {torch.cuda.get_device_name()}")
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        logger.info(f"CUDA memory: {total_gb:.1f} GB")

    def load_model_and_tokenizer(self):
        """Load the base model and its tokenizer into ``self.model`` / ``self.tokenizer``.

        Configuration keys read:
            model_name (required): checkpoint name or path.
            use_quantization (default True): load the weights in 4-bit NF4.
            trust_remote_code (default True): forwarded to ``from_pretrained``.

        The tokenizer pads on the right and falls back to the EOS token when
        no pad token is defined. The k-bit training preparation step is only
        applied when the model was actually loaded quantized.
        """
        model_name = self.config['model_name']
        logger.info(f"Loading model: {model_name}")

        use_quantization = self.config.get('use_quantization', True)
        # NOTE(review): trust_remote_code lets the checkpoint repository run
        # its own Python code at load time. The default stays True to keep the
        # previous behavior; set config['trust_remote_code'] = False for
        # checkpoints that do not need it.
        trust_remote_code = self.config.get('trust_remote_code', True)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            padding_side="right"
        )

        # Add pad token if it doesn't exist
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Arguments shared by the quantized and non-quantized load paths
        model_kwargs = {
            'device_map': "auto",
            'trust_remote_code': trust_remote_code,
            'torch_dtype': torch.bfloat16,
        }

        if use_quantization:
            from transformers import BitsAndBytesConfig

            # 4-bit NF4 weights with double quantization, bf16 compute
            model_kwargs['quantization_config'] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )

        self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

        # prepare_model_for_kbit_training is intended for k-bit models only.
        # NOTE(review): per the PEFT implementation, calling it on a regular
        # bf16 model upcasts the half-precision weights to fp32, which
        # multiplies memory use - hence the guard. Confirm against the
        # installed PEFT version.
        if use_quantization:
            self.model = prepare_model_for_kbit_training(self.model)

        logger.info("Model and tokenizer loaded successfully")

    def setup_lora(self):
        """Wrap the loaded model with LoRA adapters for parameter-efficient training.

        Rank, alpha, dropout and the target module names come from the
        ``lora_*`` configuration keys, each with a built-in fallback value.
        The wrapped model replaces ``self.model``.
        """
        cfg = self.config
        # Fallback projection layers used when 'lora_target_modules' is absent
        default_targets = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ]

        adapter_settings = LoraConfig(
            r=cfg.get('lora_r', 16),
            lora_alpha=cfg.get('lora_alpha', 32),
            target_modules=cfg.get('lora_target_modules', default_targets),
            lora_dropout=cfg.get('lora_dropout', 0.1),
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )

        wrapped_model = get_peft_model(self.model, adapter_settings)
        self.model = wrapped_model
        # Prints how many parameters remain trainable after wrapping
        wrapped_model.print_trainable_parameters()

        logger.info("LoRA configuration applied")

    def load_dataset(self):
        """Load the dataset and guarantee 'train' and 'validation' splits.

        ``config['dataset_name']`` is either a Hugging Face Hub dataset id or
        ``custom:<path>`` pointing at a local JSON file. Whatever is loaded is
        normalized to a ``DatasetDict``; when it has no 'validation' split,
        one is carved out of 'train' using ``config['validation_split']``
        (default 0.1) with a fixed seed.

        Raises:
            ValueError: if the loaded data has neither a 'validation' split
                nor a 'train' split to derive one from.
        """
        dataset_name = self.config['dataset_name']
        logger.info(f"Loading dataset: {dataset_name}")

        custom_prefix = 'custom:'
        if dataset_name.startswith(custom_prefix):
            # Strip only the leading marker; str.replace would also rewrite
            # any later 'custom:' occurrence inside the path itself
            dataset_path = dataset_name[len(custom_prefix):]
            loaded = load_dataset('json', data_files=dataset_path)
        else:
            # Load from Hugging Face Hub
            loaded = load_dataset(
                dataset_name,
                split=self.config.get('dataset_split', 'train')
            )

        # A single split comes back as a bare Dataset; treat it as 'train'
        if isinstance(loaded, Dataset):
            loaded = DatasetDict({'train': loaded})

        # The JSON loader returns a DatasetDict holding only 'train', so the
        # validation check must cover DatasetDict results as well
        if 'validation' not in loaded:
            if 'train' not in loaded:
                raise ValueError(
                    f"Dataset '{dataset_name}' has no 'train' split to derive "
                    f"a validation split from (found: {list(loaded.keys())})"
                )
            split = loaded['train'].train_test_split(
                test_size=self.config.get('validation_split', 0.1),
                seed=42
            )
            loaded = DatasetDict({
                **loaded,
                'train': split['train'],
                'validation': split['test']
            })

        self.dataset = loaded

        logger.info(f"Dataset loaded - Train: {len(self.dataset['train'])}, Val: {len(self.dataset['validation'])}")

    def preprocess_data(self):
        """Preprocess the dataset for training"""
        def format_medical_prompt(example):
            """Format the data into a medical instruction-response format"""
            if 'instruction' in example and 'response' in example:
                # Standard instruction-response format
                prompt = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
            elif 'question' in example and 'answer' in example:
                # Q&A format
                prompt = f"### Question:\n{example['question']}\n\n### Answer:\n{example['answer']}"
            elif 'input' in example and 'output' in example:
                # Input-output format
                prompt = f"### Input:\n{example['input']}\n\n### Output:\n{example['output']}"
            elif 'text' in example:
                # Raw text format
                prompt = example['text']
            else:
                # Try to find the best fields automatically
                keys = list(example.keys())
                if len(keys) >= 2:
                    prompt = f"### {keys[0].title()}:\n{example[keys[0]]}\n\n### {keys[1].title()}:\n{example[keys[1]]}"
                else:
                    prompt = str(example)
            
            return {"text": prompt + self.tokenizer.eos_token}

        def tokenize_function(examples):
            """Tokenize the text"""
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                padding=False,
                max_length=self.config.get('max_length', 2048),
                return_overflowing_tokens=False,
            )
            tokenized["labels"] = tokenized["input_ids"].copy()
            return tokenized

        # Apply formatting
        self.dataset = self.dataset.map(format_medical_prompt, remove_columns=self.dataset['train'].column_names)
        
        # Tokenize
        self.dataset = self.dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=["text"],
            desc="Tokenizing dataset"
        )
        
        logger.info("Dataset preprocessing completed")

    def setup_training_arguments(self):
        """Build the ``TrainingArguments`` stored in ``self.training_args``.

        Hyperparameters are read from the configuration with the defaults
        shown below. Three adjustments keep the arguments valid:

        * ``save_steps`` is rounded up to a multiple of ``eval_steps``,
          because ``load_best_model_at_end=True`` requires checkpoints to
          line up with evaluations.
        * Reporting is explicitly disabled with ``"none"`` when
          ``use_wandb`` is off (``None`` would fall back to the library
          default, which is all installed integrations in current releases).
        * The evaluation-strategy keyword is chosen to match the installed
          Transformers release (``eval_strategy`` vs ``evaluation_strategy``).
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_name = f"medical_llm_finetune_{timestamp}"

        eval_steps = self.config.get('eval_steps', 500)
        save_steps = self.config.get('save_steps', 500)

        # Only integer step counts are aligned; fractional values are left
        # untouched for TrainingArguments to interpret
        both_int = isinstance(eval_steps, int) and isinstance(save_steps, int)
        if both_int and eval_steps > 0 and save_steps % eval_steps != 0:
            aligned_save_steps = -(-save_steps // eval_steps) * eval_steps  # ceil to multiple
            logger.warning(
                f"save_steps={save_steps} is not a multiple of eval_steps={eval_steps}; "
                f"using save_steps={aligned_save_steps} so the best checkpoint can be restored"
            )
            save_steps = aligned_save_steps

        # Newer Transformers releases renamed 'evaluation_strategy' to
        # 'eval_strategy'; use whichever field this installation defines
        known_fields = getattr(TrainingArguments, "__dataclass_fields__", {})
        strategy_key = "eval_strategy" if "eval_strategy" in known_fields else "evaluation_strategy"

        self.training_args = TrainingArguments(
            output_dir=self.output_dir,
            run_name=run_name,

            # Training hyperparameters
            num_train_epochs=self.config.get('num_epochs', 3),
            per_device_train_batch_size=self.config.get('train_batch_size', 4),
            per_device_eval_batch_size=self.config.get('eval_batch_size', 4),
            gradient_accumulation_steps=self.config.get('gradient_accumulation_steps', 4),

            # Optimization
            learning_rate=self.config.get('learning_rate', 2e-4),
            weight_decay=self.config.get('weight_decay', 0.01),
            lr_scheduler_type=self.config.get('lr_scheduler_type', "cosine"),
            warmup_ratio=self.config.get('warmup_ratio', 0.03),

            # Memory optimization
            optim="paged_adamw_32bit",
            fp16=False,
            bf16=True,
            dataloader_pin_memory=False,

            # Evaluation and saving
            eval_steps=eval_steps,
            save_strategy="steps",
            save_steps=save_steps,
            save_total_limit=3,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,

            # Logging
            logging_dir=self.logs_dir,
            logging_steps=self.config.get('logging_steps', 100),
            report_to="wandb" if self.config.get('use_wandb', False) else "none",

            # Miscellaneous
            remove_unused_columns=False,
            push_to_hub=self.config.get('push_to_hub', False),
            hub_model_id=self.config.get('hub_model_id', None),

            # Evaluate on a step schedule (keyword name resolved above)
            **{strategy_key: "steps"},
        )

        logger.info("Training arguments configured")

    def setup_trainer(self):
        """Create the ``Trainer`` stored in ``self.trainer``.

        Uses the train/validation splits of ``self.dataset``, the arguments
        built by ``setup_training_arguments`` and, unless
        ``config['early_stopping']`` is False, an ``EarlyStoppingCallback``
        with ``config['early_stopping_patience']`` (default 3).
        """
        # Local imports keep this block self-contained
        import inspect
        from transformers import DataCollatorForSeq2Seq

        # The tokenized dataset already carries a variable-length 'labels'
        # column. DataCollatorForLanguageModeling pads only the tokenizer's
        # own fields, so ragged labels cannot be batched, and it would also
        # mask every pad-id token in the labels - which hides EOS whenever
        # pad_token == eos_token. DataCollatorForSeq2Seq pads inputs with the
        # pad token and labels with -100, so padding is ignored by the loss
        # while the real EOS token stays a training target.
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            padding=True,
            label_pad_token_id=-100,
        )

        callbacks = []
        if self.config.get('early_stopping', True):
            callbacks.append(EarlyStoppingCallback(
                early_stopping_patience=self.config.get('early_stopping_patience', 3)
            ))

        # Newer Transformers releases take the tokenizer as
        # 'processing_class'; older ones only know 'tokenizer'
        trainer_params = inspect.signature(Trainer.__init__).parameters
        tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"

        self.trainer = Trainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.dataset['train'],
            eval_dataset=self.dataset['validation'],
            data_collator=data_collator,
            callbacks=callbacks,
            **{tokenizer_key: self.tokenizer},
        )

        logger.info("Trainer configured")

    def train(self):
        """Start the training process"""
        logger.info("Starting training...")
        
        # Initialize wandb if enabled
        if self.config.get('use_wandb', False):
            wandb.init(
                project=self.config.get('wandb_project', 'medical-llm-finetuning'),
                config=self.config
            )
        
        # Start training
        train_result = self.trainer.train()
        
        # Save the final model
        self.trainer.save_model()
        self.trainer.save_state()
        
        # Log training results
        logger.info(f"Training completed!")
        logger.info(f"Final train loss: {train_result.training_loss:.4f}")
        
        # Save training metrics
        metrics_file = os.path.join(self.output_dir, 'training_metrics.json')
        with open(metrics_file, 'w') as f:
            json.dump(train_result.metrics, f, indent=2)
        
        return train_result

    def save_model_for_inference(self):
        """Save the model in a format ready for inference"""
        inference_dir = os.path.join(self.output_dir, 'inference_model')
        
        # Save the merged model (LoRA + base model)
        self.model.save_pretrained(inference_dir)
        self.tokenizer.save_pretrained(inference_dir)
        
        # Create model card
        model_card = f"""
# Medical LLM Fine-tuned Model

This model has been fine-tuned for medical/clinical applications.

## Model Details
- Base Model: {self.config['model_name']}
- Dataset: {self.config['dataset_name']}
- Fine-tuning Method: LoRA (Low-Rank Adaptation)
- Training Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

## Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("{inference_dir}")
model = AutoModelForCausalLM.from_pretrained("{inference_dir}")

# Example usage
prompt = "### Instruction:\\nExplain the symptoms of diabetes\\n\\n### Response:\\n"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=512, temperature=0.7)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```

## Intended Use
- Clinical trial support
- Medical question answering
- RAG applications in healthcare
- Multi-agent medical systems

## Limitations
- This model is for research and educational purposes
- Should not be used as a substitute for professional medical advice
- Always validate outputs with medical professionals
"""
        
        with open(os.path.join(inference_dir, 'README.md'), 'w') as f:
            f.write(model_card)
        
        logger.info(f"Model saved for inference at: {inference_dir}")


# In [None]:
def main():
    """Run the whole fine-tuning pipeline with the built-in configuration.

    Builds the configuration dictionary, then loads the model, applies LoRA,
    loads and preprocesses the dataset, trains, and saves the result for
    inference. Any exception is logged with its traceback and re-raised.
    """

    # Configuration intended for a 2x RTX 4090 setup (48GB total VRAM)
    config = {
        # Model configuration
        'model_name': 'medalpaca/medalpaca-13b',
        'use_quantization': True,  # 4-bit loading for memory efficiency

        # Dataset configuration
        'dataset_name': 'medalpaca/medical_meadow_medical_flashcards',  # Medical dataset
        'dataset_split': 'train',
        'validation_split': 0.1,
        # NOTE(review): confirm the base model's supported context length;
        # sequences are truncated to this many tokens
        'max_length': 4096,

        # LoRA configuration
        'lora_r': 64,  # Adapter rank
        'lora_alpha': 128,  # Twice the rank
        'lora_dropout': 0.05,
        'lora_target_modules': [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],

        # Training configuration
        'num_epochs': 5,
        'train_batch_size': 4,
        'eval_batch_size': 4,
        'gradient_accumulation_steps': 4,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'lr_scheduler_type': "cosine",
        'warmup_ratio': 0.05,

        # Evaluation and saving
        'eval_steps': 100,
        # Must be a multiple of eval_steps: the training arguments enable
        # load_best_model_at_end, which requires checkpoints to coincide
        # with evaluations (the previous value, 250, did not)
        'save_steps': 200,
        'logging_steps': 25,
        'early_stopping': True,
        'early_stopping_patience': 5,

        # Output configuration
        'output_dir': './medical_llm_finetuned',

        # Weights & Biases experiment tracking
        'use_wandb': True,
        'wandb_project': 'medical-llm-dual-4090',

        # Hugging Face Hub (optional)
        'push_to_hub': False,  # Set to True if you want to push to hub
        'hub_model_id': 'your-username/medical-llm-13b-clinical',
    }

    # Initialize and run the fine-tuning pipeline
    finetuner = MedicalLLMFineTuner(config)

    try:
        # Load model and tokenizer
        finetuner.load_model_and_tokenizer()

        # Setup LoRA
        finetuner.setup_lora()

        # Load and preprocess dataset
        finetuner.load_dataset()
        finetuner.preprocess_data()

        # Setup training
        finetuner.setup_training_arguments()
        finetuner.setup_trainer()

        # Train the model
        finetuner.train()

        # Save model for inference
        finetuner.save_model_for_inference()

        logger.info("Fine-tuning pipeline completed successfully!")

    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Error during fine-tuning: {str(e)}")
        raise


# In [None]:
# Run the pipeline only when this file is executed directly, not on import
if __name__ == "__main__":
    main()