#### Install unsloth on Colab

In [None]:
%%capture
# Install Unsloth and its companion packages; %%capture hides the pip output of this cell.
import os, re
# Colab is detected by looking for "COLAB_" in the concatenated environment-variable names.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Take the leading run of digits/dots from the torch version (e.g. "2.8.0")
    # and pick the xformers release paired with it below.
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    # --no-deps installs exactly the listed packages without pulling their dependencies.
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

# Pinned versions; these run in every environment because they sit outside the if/else above.
!pip install transformers==4.55.4
!pip install --no-deps trl==0.22.2
# !pip install transformers
# !pip install --no-deps trl

#### Import Libraries

In [None]:
from unsloth import FastVisionModel
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
# TRL (Transformer Reinforcement Learning)
from unsloth import unsloth_train

from datetime import datetime
from datasets import load_dataset
from huggingface_hub import whoami, HfApi, create_repo
from pathlib import Path

import os, json, gc, torch
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from transformers import TrainerCallback

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


#### Configuration

In [None]:
# Central run configuration shared by the cells below.
args = {
    # --- Optimisation ---
    "lr": 2e-4,                  # learning rate passed to SFTConfig
    "lora_r": 16,                # LoRA rank
    "lora_a": 16,                # LoRA alpha (forwarded as `lora_alpha`)
    "max_length": 2048,          # maximum sequence length passed to SFTConfig
    "batch_size": 8,             # per-device batch size (train and eval)
    "gradient_accum": 2,         # gradient accumulation steps
    "warmup": 50,                # warmup steps
    "num_train_epochs": 10,
    # --- Evaluation / checkpointing ---
    "eval_steps": 100,
    "save_steps": 100,
    "temperature": 1.5,          # NOTE(review): not read by the training code in this notebook
    "logging_steps": 5,
    "early_stopping_patience": 3,   # evaluations without improvement before stopping
    "overfitting_threshold": 0.15,  # max tolerated (val loss - train loss) gap
    "save_total_limit": 2,
    "use_validation": True,      # NOTE(review): training derives this from the presence of a validation set
    # --- Hugging Face Hub ---
    "hf_username": "Laya-hmkh",
    "dataset_name": "maomao1234/r1_report_generation",
    # Read the access token from the HF_TOKEN environment variable so the
    # secret is never stored in the notebook; empty string when it is not set.
    "token": os.environ.get("HF_TOKEN", ""),
}

print("✓ Hyperparameters loaded")

#### Training monitoring

In [None]:
class TrainingMonitor:
    """Collect training/validation metrics and render them as plots and a JSON report.

    Every metric series keeps its own list of steps. Log events do not always
    carry every key (evaluation logs and the final training summary have no
    ``loss``), so pairing all series with one shared step list would misalign
    the x/y data and make plotting fail.
    """

    def __init__(self, output_dir="monitoring", use_validation=True):
        """Create the output directory and initialise empty metric storage.

        Args:
            output_dir: Directory where plots and reports are written.
            use_validation: When False, validation logging and early stopping
                are disabled.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.use_validation = use_validation

        # Metrics storage
        self.metrics = defaultdict(list)
        self.train_losses = []
        self.val_losses = []
        self.learning_rates = []
        self.gpu_memory = []
        self.steps = []      # step of every log event, whatever keys it carried
        self.val_steps = []

        # Steps at which each individual series was actually recorded.
        self.train_loss_steps = []
        self.lr_steps = []
        self.gpu_steps = []

        # Best model tracking
        self.best_val_loss = float('inf')
        self.best_step = 0
        self.patience_counter = 0

    @staticmethod
    def _aligned(steps, values):
        """Return (steps, values) truncated to their common length."""
        common = min(len(steps), len(values))
        return steps[:common], values[:common]

    def _nearest_train_loss(self, step):
        """Return the training loss recorded closest to ``step``."""
        steps, losses = self._aligned(self.train_loss_steps, self.train_losses)
        nearest = min(range(len(steps)), key=lambda i: abs(steps[i] - step))
        return losses[nearest]

    def log_metrics(self, step, metrics_dict):
        """Record one log event.

        Args:
            step: Global training step of the event.
            metrics_dict: Mapping that may contain ``train_loss``,
                ``learning_rate`` and/or ``gpu_memory_gb``; any other keys are
                kept in ``self.metrics`` only.
        """
        self.steps.append(step)

        for key, value in metrics_dict.items():
            self.metrics[key].append(value)

        if 'train_loss' in metrics_dict:
            self.train_losses.append(metrics_dict['train_loss'])
            self.train_loss_steps.append(step)

        if 'learning_rate' in metrics_dict:
            self.learning_rates.append(metrics_dict['learning_rate'])
            self.lr_steps.append(step)

        if 'gpu_memory_gb' in metrics_dict:
            self.gpu_memory.append(metrics_dict['gpu_memory_gb'])
            self.gpu_steps.append(step)

    def log_validation(self, step, val_loss):
        """Record a validation loss and update best-model / patience tracking.

        Returns:
            True when ``val_loss`` is a new best, otherwise False (always False
            when validation is disabled).
        """
        if not self.use_validation:
            return False

        self.val_steps.append(step)
        self.val_losses.append(val_loss)

        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.best_step = step
            self.patience_counter = 0
            return True

        self.patience_counter += 1
        return False

    def should_stop_early(self, patience):
        """Return True once ``patience`` evaluations in a row failed to improve."""
        if not self.use_validation:
            return False
        return self.patience_counter >= patience

    def plot_training_curves(self, exp_name):
        """Render a 2x2 dashboard (loss, learning rate, GPU memory, loss trend).

        Returns:
            Path of the saved PNG, or None when plotting failed.
        """
        fig = None
        try:
            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
            has_validation = bool(self.use_validation and self.val_losses)
            title_suffix = "" if has_validation else " (No Validation)"
            fig.suptitle(f'Training Monitoring Dashboard - {exp_name}{title_suffix}',
                         fontsize=16, fontweight='bold')

            loss_steps, losses = self._aligned(self.train_loss_steps, self.train_losses)

            # 1. Training Loss
            ax1 = axes[0, 0]
            if losses:
                ax1.plot(loss_steps, losses, label='Train Loss', color='blue', alpha=0.7)
            if has_validation:
                val_steps, val_losses = self._aligned(self.val_steps, self.val_losses)
                ax1.plot(val_steps, val_losses, label='Val Loss', color='red',
                         marker='o', markersize=4, linestyle='--')
                ax1.axvline(x=self.best_step, color='green', linestyle=':',
                            label=f'Best Val (step {self.best_step})')
            ax1.set_xlabel('Steps')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training Loss' + (' & Validation' if has_validation else ''))
            # Only build a legend when something was drawn, to avoid a matplotlib warning.
            if ax1.get_legend_handles_labels()[0]:
                ax1.legend()
            ax1.grid(True, alpha=0.3)

            # 2. Learning Rate Schedule
            ax2 = axes[0, 1]
            if self.learning_rates:
                lr_steps, lrs = self._aligned(self.lr_steps, self.learning_rates)
                ax2.plot(lr_steps, lrs, color='orange', linewidth=2)
                ax2.set_xlabel('Steps')
                ax2.set_ylabel('Learning Rate')
                ax2.set_title('Learning Rate Schedule')
                ax2.grid(True, alpha=0.3)
                ax2.set_yscale('log')

            # 3. GPU Memory Usage
            ax3 = axes[1, 0]
            if self.gpu_memory:
                gpu_steps, gpu_values = self._aligned(self.gpu_steps, self.gpu_memory)
                ax3.plot(gpu_steps, gpu_values, color='purple', linewidth=2)
                ax3.fill_between(gpu_steps, gpu_values, alpha=0.3, color='purple')
                ax3.set_xlabel('Steps')
                ax3.set_ylabel('GPU Memory (GB)')
                ax3.set_title('GPU Memory Usage')
                ax3.grid(True, alpha=0.3)

            # 4. Loss Trend
            ax4 = axes[1, 1]
            if has_validation and len(self.val_losses) > 1 and losses:
                # Pair every evaluation with the training loss logged nearest to it.
                train_loss_at_val = [self._nearest_train_loss(s) for s in self.val_steps]

                x = np.arange(len(self.val_losses))
                width = 0.35
                ax4.bar(x - width/2, train_loss_at_val, width, label='Train Loss',
                        color='blue', alpha=0.7)
                ax4.bar(x + width/2, self.val_losses, width, label='Val Loss',
                        color='red', alpha=0.7)
                ax4.set_xlabel('Evaluation Checkpoint')
                ax4.set_ylabel('Loss')
                ax4.set_title('Train vs Val Loss at Checkpoints')
                ax4.legend()
                ax4.grid(True, alpha=0.3, axis='y')
            elif len(losses) > 5:
                # Centered moving average; the window grows with the series length.
                window_size = max(5, len(losses) // 20)
                half = window_size // 2
                smoothed_losses = [
                    np.mean(losses[max(0, i - half):min(len(losses), i + half + 1)])
                    for i in range(len(losses))
                ]

                ax4.plot(loss_steps, losses, alpha=0.3, color='blue', label='Raw Loss')
                ax4.plot(loss_steps, smoothed_losses, color='darkblue', linewidth=2,
                         label='Smoothed Loss')
                ax4.set_xlabel('Steps')
                ax4.set_ylabel('Loss')
                ax4.set_title('Training Loss Trend (Smoothed)')
                ax4.legend()
                ax4.grid(True, alpha=0.3)

            plt.tight_layout()

            plot_path = self.output_dir / f"{exp_name}_training_curves.png"
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            print(f" Training curves saved to: {plot_path}")

            return plot_path

        except Exception as e:
            print(f"⚠ Warning: Could not generate plots - {e}")
            return None
        finally:
            # Close the figure on the failure path too, so figures do not pile up.
            if fig is not None:
                plt.close(fig)

    def generate_summary_report(self, exp_name, training_result):
        """Write a JSON summary of the run and return it as a dict.

        Args:
            exp_name: Experiment name, used in the report and in the file name.
            training_result: Result dict produced by the training routine.

        Returns:
            The report dict, or None when it could not be generated.
        """
        try:
            has_validation = bool(self.use_validation and self.val_losses)
            report = {
                "experiment_name": exp_name,
                "training_summary": training_result,
                "validation_used": self.use_validation,
                "monitoring_metrics": {
                    # Number of log events recorded (not optimizer steps).
                    "total_steps": len(self.steps),
                    "last_logged_step": self.steps[-1] if self.steps else None,
                    "final_train_loss": self.train_losses[-1] if self.train_losses else None,
                    "best_val_loss": self.best_val_loss if has_validation else None,
                    "best_step": self.best_step if has_validation else None,
                    "final_learning_rate": self.learning_rates[-1] if self.learning_rates else None,
                    "max_gpu_memory_gb": max(self.gpu_memory) if self.gpu_memory else None,
                    "avg_gpu_memory_gb": float(np.mean(self.gpu_memory)) if self.gpu_memory else None,
                },
                "loss_progression": {
                    "train_losses": self.train_losses,
                    "val_losses": self.val_losses if self.use_validation else [],
                    "train_steps": self.steps,
                    "train_loss_steps": self.train_loss_steps,
                    "val_steps": self.val_steps if self.use_validation else [],
                },
                "timestamp": datetime.now().isoformat()
            }

            report_path = self.output_dir / f"{exp_name}_training_report.json"
            with open(report_path, 'w') as f:
                # default=str keeps the report writable if a value is not JSON-native.
                json.dump(report, f, indent=2, default=str)
            print(f" Training report saved to: {report_path}")

            return report

        except Exception as e:
            print(f"⚠ Warning: Could not generate report - {e}")
            return None

    def print_progress(self, step, metrics):
        """Print a one-line progress summary for the given step."""
        try:
            progress_str = f"Step {step}"
            if 'train_loss' in metrics:
                progress_str += f" | Train Loss: {metrics['train_loss']:.4f}"
            if 'learning_rate' in metrics:
                progress_str += f" | LR: {metrics['learning_rate']:.2e}"
            if 'gpu_memory_gb' in metrics:
                progress_str += f" | GPU: {metrics['gpu_memory_gb']:.2f}GB"
            print(progress_str)
        except Exception as e:
            print(f"⚠ Warning: Could not print progress - {e}")

#### Callback for training monitoring

In [None]:
class MonitoringCallback(TrainerCallback):
    """Trainer callback that feeds log/eval events into a TrainingMonitor.

    It may request an early stop, either because the validation/training loss
    gap exceeds ``overfitting_threshold`` or because ``patience`` evaluations
    passed without a new best validation loss.
    """

    def __init__(self, monitor, use_validation=True, eval_steps=10, patience=5, overfitting_threshold=0.15):
        self.monitor = monitor
        self.use_validation = use_validation
        self.eval_steps = eval_steps
        self.patience = patience
        self.overfitting_threshold = overfitting_threshold
        # Most recent training loss seen in on_log; None until the first one arrives.
        self.last_train_loss = None

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Forward loss, learning rate and peak GPU memory of a log event to the monitor."""
        try:
            if not logs:
                return

            current_step = state.global_step
            collected = {}

            if 'loss' in logs:
                loss_value = logs['loss']
                collected['train_loss'] = loss_value
                self.last_train_loss = loss_value
            if 'learning_rate' in logs:
                collected['learning_rate'] = logs['learning_rate']
            if torch.cuda.is_available():
                collected['gpu_memory_gb'] = torch.cuda.max_memory_allocated() / 1024**3

            self.monitor.log_metrics(current_step, collected)
            self.monitor.print_progress(current_step, collected)
        except Exception as e:
            print(f"⚠ Monitoring warning in on_log: {e}")

    def _report_validation(self, step, val_loss, improved):
        """Print the validation summary; return True when the loss gap signals overfitting."""
        loss_gap = val_loss - self.last_train_loss

        print(f"\n{'='*60}")
        print(f"Validation at step {step}: Loss = {val_loss:.4f}")
        print(f"Training Loss: {self.last_train_loss:.4f}")
        print(f"Gap (Val - Train): {loss_gap:.4f}")

        print(" New best model!" if improved
              else f" No improvement (patience: {self.monitor.patience_counter}/{self.patience})")

        # Check for overfitting
        if loss_gap > self.overfitting_threshold:
            print(f"\n  OVERFITTING DETECTED!")
            print(f"   Loss gap ({loss_gap:.4f}) exceeds threshold ({self.overfitting_threshold})")
            print(f"   Training will stop to prevent overfitting.")
            print(f"   Best checkpoint: step {self.monitor.best_step}")
            return True

        print(f"{'='*60}\n")
        return False

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record the evaluation loss and set ``control.should_training_stop`` when warranted."""
        try:
            if not self.use_validation:
                return control
            if not metrics or 'eval_loss' not in metrics:
                return control

            current_step = state.global_step
            eval_loss = metrics['eval_loss']
            is_best = self.monitor.log_validation(current_step, eval_loss)

            # The gap can only be judged once a training loss has been logged.
            if self.last_train_loss is not None and self._report_validation(current_step, eval_loss, is_best):
                control.should_training_stop = True
                return control

            # Regular early stopping
            if self.monitor.should_stop_early(self.patience):
                print(f"\n  Early stopping triggered after {self.patience} evaluations without improvement")
                control.should_training_stop = True
        except Exception as e:
            print(f"⚠ Monitoring warning in on_evaluate: {e}")
        return control

#### Utilities

In [None]:
def aggressive_cleanup():
    """Free as much Python and GPU memory as possible.

    Runs the garbage collector first so dropped tensors release their memory,
    then empties the CUDA cache, waits for pending kernels, and resets the
    peak-memory counters. All CUDA calls are skipped when no GPU is available,
    because ``torch.cuda.synchronize()`` raises on CPU-only machines.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

def get_experiment_combinations():
    """Return the ablation grid: a list with one configuration that enables every component."""
    switches = (
        "finetune_vision_layers",
        "finetune_language_layers",
        "finetune_attention_modules",
        "finetune_mlp_modules",
    )
    return [{switch: True for switch in switches}]

#### Authentication

In [None]:
# Check the configured token against the Hub; a failure is reported but does not stop the notebook.
try:
    user = whoami(token=args["token"])
    account_name = user['name']
    print(f"✓ Authenticated as: {account_name}")
except Exception as auth_error:
    print(f"⚠ Authentication warning: {auth_error}")

#### Dataset preparation

In [None]:
# Prompt placed after the images in every user turn.
# The pieces are joined by implicit string concatenation, so each one ends with
# a space; without it the sentences run together ("accurately.State every").
# NOTE(review): checkpoints trained before this fix saw the unspaced prompt;
# use the same text at inference time as was used for training.
INSTRUCTION = (
    "You are an expert radiologist. Analyze the provided medical images accurately. "
    "State every finding and impression that is visually evident. "
    "Make it as detailed as possible."
)

def convert_to_conversation(sample):
    """Turn one dataset record into a two-turn chat.

    The user turn holds every entry of ``sample["images"]`` followed by the
    shared INSTRUCTION text; the assistant turn holds ``sample["answer"]``.
    Returns a dict with the single key ``"messages"``.
    """
    user_parts = [{"type": "image", "image": picture} for picture in sample["images"]]
    user_parts.append({"type": "text", "text": INSTRUCTION})

    assistant_parts = [{"type": "text", "text": sample["answer"]}]

    return {
        "messages": [
            {"role": "user", "content": user_parts},
            {"role": "assistant", "content": assistant_parts},
        ]
    }

def load_datasets(dataset_name, val_split="val"):
    """Load the training split and, when available, a validation split.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        val_split: Name of the validation split to try (default ``"val"``).

    Returns:
        ``(train_ds, val_ds)``; ``val_ds`` is None when the validation split
        could not be loaded. Failures to load the training split propagate.
    """
    print("Loading datasets...")
    train_ds = load_dataset(dataset_name, split="train")

    try:
        val_ds = load_dataset(dataset_name, split=val_split)
    except Exception as e:
        val_ds = None
        print(f"✓ Loaded: {len(train_ds)} train samples (no validation)")
        # Show the reason: a network error looks the same as a missing split otherwise.
        print(f"⚠ Validation split '{val_split}' unavailable: {e}")
    else:
        print(f"✓ Loaded: {len(train_ds)} train, {len(val_ds)} val samples")

    return train_ds, val_ds

#### Load datasets

In [None]:
# Load the raw splits, then convert every record to the chat format used for training.
train_ds, val_ds = load_datasets(args["dataset_name"])
train_data = list(map(convert_to_conversation, train_ds))
if val_ds:
    val_data = list(map(convert_to_conversation, val_ds))
else:
    # No (or an empty) validation split: training runs without evaluation.
    val_data = None

#### Core training class

In [None]:
class ModelTrainer:
    """Run LoRA fine-tuning experiments and manage the resulting artifacts.

    Covers training (with optional monitoring), saving adapters and merged
    models, converting checkpoints, verifying that a saved model loads, and
    uploading to the Hugging Face Hub.
    """

    def __init__(self):
        # One result dict per train_model() call, successful or not.
        self.results = []
        # exp_name -> (model, tokenizer) for every experiment run in this session.
        self.trained_models = {}

    def _make_name(self, combo):
        """Build the experiment name from the four fine-tuning switches of ``combo``."""
        return (f"V{combo['finetune_vision_layers']}_"
                f"L{combo['finetune_language_layers']}_"
                f"A{combo['finetune_attention_modules']}_"
                f"M{combo['finetune_mlp_modules']}")

    @staticmethod
    def _checkpoint_sort_key(path):
        """Sort key: experiment directory first, then numeric step.

        A plain string sort would put ``checkpoint-1000`` before
        ``checkpoint-200``; non-numeric suffixes are placed last.
        """
        suffix = path.name.rsplit("-", 1)[-1]
        step = int(suffix) if suffix.isdigit() else float("inf")
        return (str(path.parent), step, path.name)

    @staticmethod
    def _release_gpu_memory():
        """Best-effort release of memory held by objects that were just dropped."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def train_model(self, model, tokenizer, combo, train_dataset, args, val_dataset=None):
        """Fine-tune ``model`` with LoRA for one configuration.

        Monitoring is optional: a failure to set it up or to save its outputs
        never prevents training or the return of the model.

        Args:
            model: Base model to wrap with LoRA adapters.
            tokenizer: Tokenizer/processor matching ``model``.
            combo: Dict with the four ``finetune_*`` switches.
            train_dataset: Training samples in conversation format.
            args: Hyperparameter dict (``lr``, ``lora_r``, ``batch_size``, ...).
            val_dataset: Optional validation samples; enables evaluation,
                early stopping and best-model loading when given.

        Returns:
            ``(model, tokenizer, result)`` where ``result`` has
            ``status == "success"`` with training statistics, or
            ``status == "failed"`` with the error message.
        """
        aggressive_cleanup()
        exp_name = self._make_name(combo)
        use_validation = val_dataset is not None

        print(f"\n{'='*30}")
        print(f"   Starting Training: {exp_name}")
        print(f"   Validation: {'Enabled' if use_validation else 'Disabled'}")
        print(f"{'='*30}")

        # Configure LoRA
        model = FastVisionModel.get_peft_model(
            model,
            finetune_vision_layers=combo["finetune_vision_layers"],
            finetune_language_layers=combo["finetune_language_layers"],
            finetune_attention_modules=combo["finetune_attention_modules"],
            finetune_mlp_modules=combo["finetune_mlp_modules"],
            r=args["lora_r"],
            lora_alpha=args["lora_a"],
            lora_dropout=0,
            bias="none",
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )

        # Calculate parameters
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Parameters - Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.2f}%)")

        # Prepare for training
        mem_before = torch.cuda.max_memory_reserved() / 1024**3
        FastVisionModel.for_training(model)

        # Training configuration
        trainer_args = SFTConfig(
            per_device_train_batch_size=args["batch_size"],
            # Always an int: the field is typed int, and it is unused when evaluation is off.
            per_device_eval_batch_size=args["batch_size"],
            gradient_accumulation_steps=args["gradient_accum"],
            warmup_steps=args["warmup"],
            num_train_epochs=args.get("num_train_epochs", 3),
            learning_rate=args["lr"],
            logging_steps=args["logging_steps"],
            eval_strategy="steps" if use_validation else "no",
            eval_steps=args["eval_steps"] if use_validation else None,
            eval_accumulation_steps=args["gradient_accum"],
            save_strategy="steps",
            save_steps=args["save_steps"],
            save_total_limit=args["save_total_limit"],
            load_best_model_at_end=use_validation,
            metric_for_best_model="eval_loss" if use_validation else None,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            output_dir=f"outputs/{exp_name}",
            report_to="none",
            dataloader_pin_memory=False,
            dataloader_num_workers=0,
            # Fine-tuning vision
            remove_unused_columns=False,
            dataset_text_field="",
            dataset_kwargs={"skip_prepare_dataset": True},
            max_length=args["max_length"],
        )

        # Create trainer with monitoring
        callbacks = []
        monitor = None  # stays None when monitoring could not be set up
        try:
            monitor = TrainingMonitor(
                output_dir=f"monitoring/{exp_name}",  # Separate dir per experiment
                use_validation=use_validation
            )
            callbacks.append(MonitoringCallback(
                monitor=monitor,
                use_validation=use_validation,
                eval_steps=args["eval_steps"],
                patience=args["early_stopping_patience"],
                overfitting_threshold=args.get("overfitting_threshold", 0.15)
            ))
            print("  Monitoring callback added")
        except Exception as e:
            print(f"  Monitoring callback failed: {e}")
            print("   Continuing without monitoring")
            monitor = None

        # Create trainer
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            data_collator=UnslothVisionDataCollator(model, tokenizer),
            train_dataset=train_dataset,
            eval_dataset=val_dataset if use_validation else None,
            args=trainer_args,
            callbacks=callbacks,
        )

        # Train
        print("  Training started...")
        try:
            stats = unsloth_train(trainer)
            mem_after = torch.cuda.max_memory_reserved() / 1024**3

            result = {
                "experiment_name": exp_name,
                "configuration": combo,
                "validation_used": use_validation,
                "total_parameters": total,
                "trainable_parameters": trainable,
                "trainable_percentage": 100 * trainable / total,
                "training_time": stats.metrics["train_runtime"],
                "final_train_loss": stats.metrics["train_loss"],
                "memory_usage_gb": mem_after - mem_before,
                "timestamp": datetime.now().isoformat(),
                "status": "success"
            }

            print(f"   Training completed successfully!")
            print(f"   Final loss: {result['final_train_loss']:.4f}")
            print(f"   Time: {result['training_time']:.2f}s")

        except Exception as e:
            print(f"Training error: {e}")
            result = {
                "experiment_name": exp_name,
                "status": "failed",
                "error": str(e),
                "timestamp": datetime.now().isoformat()
            }

        self.results.append(result)
        self.trained_models[exp_name] = (model, tokenizer)

        if monitor is not None:
            try:
                monitor.plot_training_curves(exp_name)
                monitor.generate_summary_report(exp_name, result)
                print("Monitoring visualizations saved!")
            except Exception as e:
                print(f"Could not save visualizations: {e}")
                print("   But your model is safe!")

        return model, tokenizer, result

    def save_checkpoint_as_merged(self, checkpoint_path, output_name, load_in_4bit=True):
        """Load a training checkpoint, save it as a merged 16-bit model and verify it.

        Args:
            checkpoint_path: Directory of the checkpoint to load.
            output_name: Directory the merged model is written to.
            load_in_4bit: Whether to load the checkpoint in 4-bit.

        Returns:
            True when the merged model was saved and loads again, else False.
        """
        model = tokenizer = None
        try:
            print(f"\n{'='*60}")
            print(f"Loading checkpoint: {checkpoint_path}")

            # Load the checkpoint
            model, tokenizer = FastVisionModel.from_pretrained(
                checkpoint_path,
                load_in_4bit=load_in_4bit,
            )
            print(f"Checkpoint loaded")

            # Save as merged
            print(f"Saving as merged model: {output_name}")
            model.save_pretrained_merged(
                output_name,
                tokenizer,
                save_method="merged_16bit",
            )
        except Exception as e:
            print(f"Error: {e}")
            return False
        finally:
            # Drop the checkpoint model before verification loads another copy;
            # otherwise both copies occupy GPU memory at the same time.
            model = tokenizer = None
            self._release_gpu_memory()

        # Verify it worked
        success, error = self.verify_model_loadable(output_name)
        if success:
            print(f"Merged model saved and verified: {output_name}")
            return True

        print(f"Merged model saved but verification failed: {error}")
        return False

    def list_available_checkpoints(self, exp_name=None):
        """Print and return the checkpoint directories found under ``outputs/``.

        Args:
            exp_name: Restrict the search to this experiment; None lists all.

        Returns:
            Checkpoint paths as strings, ordered by experiment and then by
            ascending step number.
        """
        if exp_name:
            search_path = f"outputs/{exp_name}/checkpoint-*"
        else:
            search_path = "outputs/*/checkpoint-*"

        checkpoints = sorted(Path(".").glob(search_path), key=self._checkpoint_sort_key)

        if checkpoints:
            print(f"\nFound {len(checkpoints)} checkpoint(s):")
            for cp in checkpoints:
                # Extract step number
                step = cp.name.split("-")[-1]
                print(f"   Step {step}: {cp}")
        else:
            print("No checkpoints found")

        return [str(cp) for cp in checkpoints]

    def save_model_safe(self, model, tokenizer, save_dir="trained_model"):
        """Save model and tokenizer with ``save_pretrained``; return True on success."""
        try:
            print(f"\nSaving model to {save_dir}...")
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)
            print(f"Model saved successfully to {save_dir}")
            return True
        except Exception as e:
            print(f"Error saving model: {e}")
            return False

    def save_merged_model(self, model, tokenizer, save_dir="trained_model_merged"):
        """Save the merged (base + LoRA) model in 16-bit; return True on success."""
        try:
            print(f"\nSaving merged model to {save_dir}...")
            model.save_pretrained_merged(
                save_dir,
                tokenizer,
                save_method="merged_16bit",
            )
            print(f"Merged model saved to {save_dir}")
            return True
        except Exception as e:
            print(f"Error saving merged model: {e}")
            return False

    def save_results(self, filename="training_results.json"):
        """Write all collected result dicts to ``filename`` as JSON."""
        with open(filename, "w") as f:
            # default=str keeps the results writable if a value is not JSON-native.
            json.dump(self.results, f, indent=2, default=str)
        print(f"   Results saved to {filename}")

    def push_to_hub(self, model_path, repo_name, token=None, private=False, hf_username=None):
        """Upload a saved model folder to the Hugging Face Hub.

        Args:
            model_path: Local folder to upload.
            repo_name: Repository name (without the user prefix).
            token: Access token; None or an empty string uses the cached login.
            private: Create the repository as private.
            hf_username: Account that owns the repository. When omitted, the
                notebook-level ``args["hf_username"]`` is used.

        Returns:
            URL of the uploaded model, or None on failure.
        """
        try:
            if hf_username is None:
                # Backward-compatible fallback to the notebook's global config dict.
                hf_username = args["hf_username"]
            # An empty-string token would be sent as a (rejected) credential
            # instead of falling back to the cached login.
            token = token or None

            # Create full repo ID
            repo_id = f"{hf_username}/{repo_name}"

            print(f"\n{'='*30}")
            print(f"Pushing to HuggingFace Hub")
            print(f"Model: {model_path}")
            print(f"Repo: {repo_id}")
            print(f"{'='*30}")

            # Verify model exists
            if not Path(model_path).exists():
                print(f" Model path does not exist: {model_path}")
                return None

            # Create repository if it doesn't exist
            try:
                create_repo(
                    repo_id=repo_id,
                    token=token,
                    private=private,
                    exist_ok=True,
                    repo_type="model"
                )
                print(f" Repository ready: {repo_id}")
            except Exception as e:
                # Not fatal: the upload below reports its own error if the repo is unusable.
                print(f" Repository creation: {e}")

            # Upload model
            api = HfApi()
            print(f" Uploading model files...")

            api.upload_folder(
                folder_path=model_path,
                repo_id=repo_id,
                repo_type="model",
                token=token,
            )

            model_url = f"https://huggingface.co/{repo_id}"
            print(f" Model uploaded successfully!")
            print(f" URL: {model_url}")

            return model_url

        except Exception as e:
            print(f" Upload failed: {e}")
            return None

    def verify_model_loadable(self, save_dir):
        """Check that ``save_dir`` holds a model that can actually be loaded.

        Returns:
            ``(True, None)`` on success, otherwise ``(False, reason)``.
        """
        test_model = test_tokenizer = None
        try:
            model_dir = Path(save_dir)

            # Check directory exists
            if not model_dir.exists():
                return False, f"Directory {save_dir} does not exist"

            # Check for required files
            required_files = ["config.json"]
            missing = [f for f in required_files if not (model_dir / f).exists()]
            if missing:
                return False, f"Missing required files: {missing}"

            # Try to actually load the model
            print(f"  Verifying {save_dir} can be loaded...")
            test_model, test_tokenizer = FastVisionModel.from_pretrained(
                save_dir,
                load_in_4bit=True,
            )

            return True, None

        except Exception as e:
            return False, str(e)
        finally:
            # Free the verification copy whether or not loading succeeded.
            test_model = test_tokenizer = None
            self._release_gpu_memory()

    def push_checkpoint_to_hub(self, checkpoint_path, repo_name, token=None, private=False,
                               save_as_merged=True, hf_username=None):
        """Upload a checkpoint, either merged into the base model or as a LoRA adapter.

        Args:
            checkpoint_path: Checkpoint directory to upload.
            repo_name: Repository name (without the user prefix).
            token: Access token forwarded to ``push_to_hub``.
            private: Create the repository as private.
            save_as_merged: Merge the checkpoint into a 16-bit model first.
            hf_username: Forwarded to ``push_to_hub``.

        Returns:
            URL of the uploaded model, or None on failure.
        """
        try:
            if not save_as_merged:
                # Upload checkpoint directly (LoRA adapter)
                return self.push_to_hub(checkpoint_path, repo_name, token, private, hf_username)

            # Convert checkpoint to merged model first
            temp_model_name = "temp_merged_for_upload"
            success = self.save_checkpoint_as_merged(
                checkpoint_path=checkpoint_path,
                output_name=temp_model_name
            )

            if not success:
                print(" Failed to create merged model")
                return None

            # Upload the merged version; the temporary folder is left on disk.
            return self.push_to_hub(temp_model_name, repo_name, token, private, hf_username)

        except Exception as e:
            print(f" Error: {e}")
            return None

#### Main training execution

In [None]:
# Stage banner: a blank line followed by three lines.
print("\n" + "="*30, "MODEL TRAINING - Core Module", "="*30, sep="\n")

# Load base model
print("\n Loading base model...")
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Llama-3.2-11B-Vision-Instruct",
    use_gradient_checkpointing="unsloth",
    load_in_4bit=True,
)
print(" Base model loaded")

In [None]:
# Initialize trainer
trainer = ModelTrainer()

# This cell trains ONE configuration on the already-loaded base model: the last
# entry of the ablation grid. Training several entries would need a fresh base
# model per entry, because train_model() attaches LoRA adapters to the model
# it receives.
combinations = get_experiment_combinations()
if not combinations:
    raise ValueError("get_experiment_combinations() returned no configurations")
if len(combinations) > 1:
    print(f"⚠ {len(combinations)} configurations defined; only the last one is trained by this cell")

# `combo` stays defined because the save cell below builds the experiment name from it.
combo = combination = combinations[-1]

trained_model, trained_tokenizer, result = trainer.train_model(
    model=model,
    tokenizer=tokenizer,
    combo=combination,
    train_dataset=train_data,
    args=args,
    val_dataset=val_data
)

# Save all results
trainer.save_results()

print("\n" + "="*30)
if result.get("status") == "success":
    print(" ALL TRAINING COMPLETED!")
else:
    # Do not announce success when train_model() reported a failure.
    print(f" TRAINING FAILED: {result.get('error')}")
print("="*30)

#### Save model

In [None]:
# Save right after training: first the LoRA adapter, then the merged 16-bit model.
exp_name = trainer._make_name(combo)
adapter_dir = f"model_{exp_name}"
trainer.save_model_safe(trained_model, trained_tokenizer, adapter_dir)
trainer.save_merged_model(trained_model, trained_tokenizer, adapter_dir + "_merged")

In [None]:
# 1. List all checkpoints
# exp_name = "VTrue_LTrue_ATrue_MTrue"
checkpoints = trainer.list_available_checkpoints(exp_name)

# 2. Save a specific checkpoint as merged model
# The path is built from exp_name so it matches the experiment listed above,
# and is checked first: a missing step would otherwise fail inside the loader.
checkpoint_path = f"outputs/{exp_name}/checkpoint-300"
if Path(checkpoint_path).exists():
    merged_ok = trainer.save_checkpoint_as_merged(
        checkpoint_path=checkpoint_path,
        output_name="best_model_before_overfit_step300_merged"
    )
else:
    print(f"Checkpoint not found: {checkpoint_path} - pick one of the checkpoints listed above")
    merged_ok = False
merged_ok

# # 3. Or save multiple checkpoints to compare
# for step in [300, 500, 700]:
#     trainer.save_checkpoint_as_merged(
#         checkpoint_path=f"outputs/{exp_name}/checkpoint-{step}",
#         output_name=f"model_{exp_name}_step{step}_merged"
#     )

#### Verify that the saved model loads correctly

In [None]:
# Reload the merged model from disk to confirm that the saved files are usable.
success, error = trainer.verify_model_loadable("/content/drive/MyDrive/main/best_model_before_overfit_step300_merged")
if success:
    print("Merged model saved and verified")
else:
    # Include the reason so a failed verification can be diagnosed.
    print(f"Merged model saved but verification failed: {error}")

  Verifying /content/drive/MyDrive/main/best_model_before_overfit_step300_merged can be loaded...
==((====))==  Unsloth 2025.10.1: Fast Mllama patching. Transformers: 4.55.4.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

#### Push model to Hugging Face account

In [None]:
# Push the final merged model
final_push_options = {
    "model_path": "/content/drive/MyDrive/main/best_model_before_overfit_step300_merged",
    "repo_name": "llama-vision-radiology-final",
    "token": args["token"],  # Or None to use default
    "private": False,  # Set True if you want private repo
}
trainer.push_to_hub(**final_push_options)

# # Or push the LoRA adapter
# trainer.push_to_hub(
#     model_path=f"model_{exp_name}",
#     repo_name="llama-vision-radiology-lora",
#     token=args["token"],
#     private=False
# )