# Flux Impressionism Fine-Tuning Training

This notebook runs LoRA fine-tuning of the FLUX.1 Dev model, loading the base weights with 4-bit (int4) quantization and using bf16 as the compute / mixed-precision dtype.

In [None]:
# Install required packages if not already installed
!pip install -q -r ../requirements.txt

In [None]:
import os
import sys
import yaml
import torch
from pathlib import Path

# Add src to path
sys.path.append("..")
from src.training.trainer import FluxLoRATrainer

In [None]:
import os
import torch
import logging
from pathlib import Path
from typing import Optional, Dict, Any
from transformers import PreTrainedModel, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from diffusers import FluxModel, FluxScheduler
from datasets import Dataset

class FluxLoRATrainer:
    """Drive LoRA fine-tuning of a quantized Flux model from a nested config dict.

    The configuration sections read by this class are ``output``, ``training``,
    ``mixed_precision``, ``model``, ``lora``, ``system`` and ``dataset``.

    Typical use: construct the trainer, call :meth:`load_model`, pass the raw
    dataset through :meth:`prepare_dataset`, then call :meth:`train`.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store *config* and set up logging and the accelerator.

        Args:
            config: Nested configuration dictionary (see the class docstring
                for the sections that are read).
        """
        self.config = config
        # Populated by load_model(); train() loads it on demand when still None.
        self.model = None
        self.setup_logging()
        self.setup_accelerator()

    def setup_logging(self):
        """Configure logging to the console and to ``<logging_dir>/training.log``.

        The logging directory is created if it does not exist yet.
        """
        log_dir = self.config["output"]["logging_dir"]
        # FileHandler opens its file immediately and fails when the parent
        # directory is missing, so make sure it exists first.
        os.makedirs(log_dir, exist_ok=True)

        # NOTE: logging.basicConfig() leaves the root logger untouched when it
        # already has handlers configured.
        logging.basicConfig(
            format="%(asctime)s - %(levelname)s - %(message)s",
            level=logging.INFO,
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(os.path.join(log_dir, "training.log")),
            ],
        )
        self.logger = logging.getLogger(__name__)

    def setup_accelerator(self):
        """Create the ``Accelerator`` and store it on ``self.accelerator``.

        Mixed precision is ``"bf16"`` when ``mixed_precision.enabled`` is truthy
        and ``"no"`` otherwise. The accelerator is not referenced by the other
        methods of this class; :meth:`train` hands the model to ``Trainer``.
        """
        project_config = ProjectConfiguration(
            project_dir=self.config["output"]["output_dir"],
            logging_dir=self.config["output"]["logging_dir"],
        )

        self.accelerator = Accelerator(
            gradient_accumulation_steps=self.config["training"]["gradient_accumulation_steps"],
            mixed_precision="bf16" if self.config["mixed_precision"]["enabled"] else "no",
            project_config=project_config,
            log_with=self.config["output"]["report_to"],
        )

    def load_model(self):
        """Load the base model, prepare it for k-bit training and wrap it with LoRA.

        Returns:
            The PEFT-wrapped model, which is also stored on ``self.model``.
        """
        model_cfg = self.config["model"]
        system_cfg = self.config["system"]
        lora_cfg = self.config["lora"]

        # Load base model with quantization.
        # NOTE(review): the 4-bit options are passed straight to from_pretrained();
        # confirm the installed library accepts these keyword arguments (some
        # versions expect a quantization config object instead) - TODO confirm.
        model = FluxModel.from_pretrained(
            model_cfg["pretrained_model_name_or_path"],
            torch_dtype=getattr(torch, model_cfg["torch_dtype"]),
            load_in_4bit=model_cfg["load_in_4bit"],
            use_bf16_4bit=model_cfg["use_bf16_4bit"],
            bnb_4bit_compute_dtype=getattr(torch, model_cfg["bnb_4bit_compute_dtype"]),
            bnb_4bit_quant_type=model_cfg["bnb_4bit_quant_type"],
            bnb_4bit_use_double_quant=model_cfg["bnb_4bit_use_double_quant"],
        )

        # Prepare model for k-bit training
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=system_cfg["gradient_checkpointing"],
        )

        # Configure LoRA
        lora_config = LoraConfig(
            r=lora_cfg["rank"],
            lora_alpha=lora_cfg["alpha"],
            target_modules=lora_cfg["target_modules"],
            lora_dropout=lora_cfg["lora_dropout"],
            bias=lora_cfg["bias"],
        )

        # Apply LoRA
        self.model = get_peft_model(model, lora_config)

        # NOTE(review): both optional switches below rely on the wrapped model
        # exposing these methods - verify against the installed library versions.
        if system_cfg["enable_xformers_memory_efficient_attention"]:
            self.model.enable_xformers_memory_efficient_attention()

        if system_cfg["use_flash_attention_2"]:
            self.model.enable_flash_attention_2()

        return self.model

    def prepare_dataset(self, dataset: "Dataset"):
        """Optionally truncate *dataset* to ``dataset.max_train_samples`` rows.

        Args:
            dataset: Dataset to prepare; it must support ``len()`` and ``select()``.

        Returns:
            The first ``max_train_samples`` rows when that setting is truthy
            (capped at the dataset size), otherwise *dataset* unchanged.
        """
        max_samples = self.config["dataset"]["max_train_samples"]
        if max_samples:
            # Clamp to the dataset size so a cap larger than the dataset does
            # not ask select() for indices that are out of range.
            dataset = dataset.select(range(min(max_samples, len(dataset))))

        # Add any necessary preprocessing here
        return dataset

    def train(self, dataset: "Dataset"):
        """Run training with ``transformers.Trainer`` and save the final model.

        The model is loaded via :meth:`load_model` first if that has not
        happened yet.

        Args:
            dataset: Training dataset, passed to ``Trainer`` as ``train_dataset``.
        """
        if getattr(self, "model", None) is None:
            self.load_model()

        training_cfg = self.config["training"]

        # A null max_train_steps in the config means "no step limit";
        # TrainingArguments expresses that as -1 (epochs then decide the length).
        max_steps = training_cfg["max_train_steps"]
        if max_steps is None:
            max_steps = -1

        # Prepare training arguments
        training_args = TrainingArguments(
            output_dir=self.config["output"]["output_dir"],
            per_device_train_batch_size=training_cfg["train_batch_size"],
            gradient_accumulation_steps=training_cfg["gradient_accumulation_steps"],
            learning_rate=training_cfg["learning_rate"],
            lr_scheduler_type=training_cfg["lr_scheduler"],
            num_train_epochs=training_cfg["num_train_epochs"],
            max_steps=max_steps,
            warmup_steps=training_cfg["lr_warmup_steps"],
            save_steps=training_cfg["checkpointing_steps"],
            save_total_limit=training_cfg["save_total_limit"],
            logging_steps=10,
            # Keep every dataset column so collate_fn sees the full examples.
            remove_unused_columns=False,
            seed=training_cfg["seed"],
            bf16=self.config["mixed_precision"]["enabled"],
            report_to=self.config["output"]["report_to"],
        )

        # Initialize trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
            data_collator=self.collate_fn,
        )

        # Start training
        self.logger.info("Starting training...")
        trainer.train()

        # Save final model
        self.save_model()

    def save_model(self, path: Optional[str] = None):
        """Save the trained model with ``save_pretrained``.

        Args:
            path: Target directory. Defaults to ``<output_dir>/final_model``.
        """
        save_path = path or os.path.join(self.config["output"]["output_dir"], "final_model")
        self.model.save_pretrained(save_path)
        self.logger.info(f"Model saved to {save_path}")

    @staticmethod
    def collate_fn(examples):
        """Merge a list of example dicts into one batch dict.

        For every key of the first example, the values of all examples are
        gathered. When all of them are tensors they are stacked along a new
        leading dimension; anything else is returned as a plain list.

        This is a generic default. NOTE(review): model-specific preprocessing
        (e.g. turning captions or images into tensors) may still be required
        before the batch can be fed to the model - TODO confirm.

        Args:
            examples: Sequence of mappings that share the same keys.

        Returns:
            A dict with one entry per key; an empty dict for empty input.
        """
        if not examples:
            return {}

        batch = {}
        for key in examples[0]:
            values = [example[key] for example in examples]
            if all(isinstance(value, torch.Tensor) for value in values):
                batch[key] = torch.stack(values)
            else:
                batch[key] = values
        return batch

In [None]:
# Load configuration.
# The encoding is given explicitly so the YAML file is decoded as UTF-8
# regardless of the platform's default text encoding.
with open("../configs/default.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Create output directories (no error if they already exist)
os.makedirs(config["output"]["output_dir"], exist_ok=True)
os.makedirs(config["output"]["logging_dir"], exist_ok=True)

In [None]:
# Build the trainer from the loaded configuration ...
trainer = FluxLoRATrainer(config=config)

# ... and let it load the base model and wrap it with LoRA.
model = trainer.load_model()

In [None]:
# Fetch the dataset named in the configuration and keep its "train" split
from datasets import load_dataset

dataset = load_dataset(
    config["dataset"]["name"],
)
train_dataset = trainer.prepare_dataset(
    dataset["train"],
)

# Report how many examples will be used for training
print(f"Training dataset size: {len(train_dataset)}")

In [None]:
# Run the training loop on the prepared training split
trainer.train(dataset=train_dataset)

## Training Complete

The fine-tuned LoRA model has been saved to the `final_model` subdirectory of the configured output directory. You can now use it for inference or upload it to the Hugging Face Hub.