In [None]:
import gc

import torch

from magistral_finetuning import MagistralFineTuningConfig, MagistralFineTuning, ThinkingMode

In [None]:
# Hyperparameters for this fine-tuning run, grouped by concern.
config = MagistralFineTuningConfig(
    # --- model, data and output locations ---
    model_name="mistralai/Magistral-Small-2506",
    train_file="data/mixed_distributed.jsonl",
    output_dir="./model/magistral_mixed_v1_fixed",
    thinking_mode=ThinkingMode.MIXED,
    max_length=1024,
    # --- batching: 2 samples per step, accumulated over 8 steps ---
    batch_size=2,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    # --- optimisation schedule ---
    num_epochs=1,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # --- LoRA settings ---
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

In [None]:
# Show the configured values via the config object's own print_config helper
# so they can be checked before any expensive step runs.
config.print_config()

In [None]:
def setup_alternative_trainer(finetuner, train_data, cfg=None):
    """Build a plain ``transformers.Trainer`` as a fallback trainer setup.

    Intended for the case where the finetuner's own trainer setup fails
    (e.g. SFTTrainer pad token issues). The raw samples are converted with
    ``finetuner.prepare_dataset``, the ``"text"`` column is tokenized
    (truncated to ``cfg.max_length``, not padded), and batches are padded at
    collate time by ``DataCollatorForLanguageModeling`` with ``mlm=False``.

    Args:
        finetuner: Object exposing ``prepare_dataset``, ``tokenizer`` and
            ``model``.
        train_data: Raw training samples, forwarded unchanged to
            ``finetuner.prepare_dataset``.
        cfg: Object supplying the hyperparameters (``output_dir``,
            ``num_epochs``, ``batch_size``, ...). When omitted, the
            notebook-level ``config`` is used, which matches the previous
            behaviour of this function.

    Returns:
        A configured ``transformers.Trainer``. Training is not started here.
    """
    # Imported lazily: this code path only runs when the primary setup failed.
    from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

    # Prefer an explicitly passed configuration over hidden notebook state.
    if cfg is None:
        cfg = config

    tokenizer = finetuner.tokenizer

    # Prepare dataset manually
    train_dataset = finetuner.prepare_dataset(train_data)

    # Causal-LM collator (mlm=False); padding is applied per batch.
    # NOTE(review): the collator pads with `tokenizer`, so this presumably
    # requires the tokenizer to define a pad token -- confirm that
    # finetuner.setup_model() provides one.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=None,  # no rounding of the padded length
    )

    training_args = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.num_epochs,
        per_device_train_batch_size=cfg.batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        gradient_checkpointing=cfg.gradient_checkpointing,
        learning_rate=cfg.learning_rate,
        lr_scheduler_type=cfg.lr_scheduler_type,
        warmup_ratio=cfg.warmup_ratio,
        weight_decay=0.01,
        logging_steps=50,
        save_strategy="epoch",
        seed=42,
        bf16=True,
        tf32=True,
        optim="adamw_bnb_8bit",
        # Fall back to the TrainingArguments default (0) when the config
        # object does not define this field.
        dataloader_num_workers=getattr(cfg, "dataloader_num_workers", 0),
        dataloader_pin_memory=True,
        # The dataset is already reduced to tokenizer outputs below, so the
        # Trainer must not drop columns on its own.
        remove_unused_columns=False,
    )

    def tokenize_function(examples):
        """Tokenize a batch of ``"text"`` entries, truncating but not padding."""
        return tokenizer(
            examples["text"],
            truncation=True,
            padding=False,  # Let data collator handle padding
            max_length=cfg.max_length,
            return_tensors=None,
        )

    # Replace every original column with the tokenizer outputs.
    tokenized_dataset = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=train_dataset.column_names,
    )

    return Trainer(
        model=finetuner.model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )


In [None]:
# Create the fine-tuning driver from the configuration object.
finetuner = MagistralFineTuning(config)

In [None]:
# Load the training samples from the file named in config.train_file and
# report how many were read.
train_data = finetuner.load_jsonl(config.train_file)
print(f"Training samples: {len(train_data)}")

In [None]:
# Let the finetuner set up the model; the details live in
# MagistralFineTuning.setup_model and are not visible in this notebook.
print("Setting up model...")
finetuner.setup_model()

In [None]:
# Prefer the finetuner's own trainer setup; if it raises, fall back to the
# plain-Trainer path built by setup_alternative_trainer. `use_alternative`
# records which path was taken.
print("Setting up trainer...")
try:
    # Keep the try body to the single call whose failure should trigger the
    # fallback, so an unrelated error is not reported as an SFTTrainer failure.
    finetuner.setup_trainer(train_data)
except Exception as sft_error:  # deliberate best-effort: any setup failure falls back
    print(f"❌ SFTTrainer failed: {sft_error}")
    print("Using alternative trainer setup...")
    finetuner.trainer = setup_alternative_trainer(finetuner, train_data)
    print("✅ Alternative trainer setup successful!")
    use_alternative = True
else:
    print("✅ SFTTrainer setup successful!")
    use_alternative = False

In [None]:
# Free memory and report GPU usage before training starts.
# Order matters: collect unreachable Python objects first so the tensors they
# hold are released, then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()
print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f"GPU memory reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

In [None]:
# Run training through the finetuner.
# NOTE(review): presumably this uses finetuner.trainer, which the fallback
# path assigns directly -- confirm in MagistralFineTuning.train.
print("Starting training...")
finetuner.train()

In [None]:
# Persist the trained model through the finetuner.
# NOTE(review): presumably written to config.output_dir -- confirm in
# MagistralFineTuning.save_model.
print("Saving model...")
finetuner.save_model()

In [None]:
# Summarise the finished run: which trainer path was used plus key settings.
trainer_type = "Alternative Trainer" if use_alternative else "SFTTrainer"
print(f"\n✅ Mixed training complete with {trainer_type}!")
print(f"Model saved to: {config.output_dir}")
print(f"Training mode: {config.thinking_mode.value}")
print(f"Effective batch size: {config.effective_batch_size}")
print(f"Max sequence length: {config.max_length}")
# config.learning_rate is the configured rate that warmup ramps up to; with
# the cosine schedule the rate at the end of training is different, so it is
# labelled "peak" rather than "final".
print(f"Peak learning rate: {config.learning_rate}")