In [None]:
# Install the package in the current directory in editable (-e) mode into this
# kernel's environment (%pip, unlike !pip, targets the running kernel).
# NOTE(review): assumes the notebook is started from the SimpleTuner repo root — confirm.
%pip install -e .

In [None]:
import os

# Path to the training images. Update the default for your data, or set the
# SIMPLETUNER_INSTANCE_DATA_DIR environment variable to override it without
# editing the notebook. Keep it a plain string: later cells write it into JSON.
instance_data_dir = os.environ.get(
    "SIMPLETUNER_INSTANCE_DATA_DIR",
    "/Volumes/ml/datasets/test_datasets/single_image_dataset",
)
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
# Your public model name after it's pushed to the hub (also used as the tracker run name).
hub_model_id = "simpletuner-lora"
tracker_project_name = "flux-training"

# Prompt used for the validation images generated during training.
validation_prompt = "A photo-realistic image of a cat"

train_batch_size = 1
learning_rate = 1e-4

# Base model quantisation. Choices: int8-quanto, fp8-quanto, no_change
# (mac and a100/h100 users get int4 and int2 as well).
base_model_precision = "no_change"

In [None]:
import json
import os

# LyCORIS adapter definition: algo "lokr" targeting Attention and FeedForward
# modules, with a per-module factor override in module_algo_map. The trainer is
# pointed at the written file via the "lycoris_config" key of training_config.
lycoris_config = {
    "algo": "lokr",
    "multiplier": 1.0,
    "linear_dim": 10000,
    "linear_alpha": 1,
    "factor": 12,
    "apply_preset": {
        "target_module": [
            "Attention",
            "FeedForward"
        ],
        "module_algo_map": {
            "Attention": {
                "factor": 12
            },
            "FeedForward": {
                "factor": 6
            }
        }
    }
}

# write to config/lycoris_config.json; open() does not create missing
# directories, so make sure config/ exists on a fresh checkout.
os.makedirs("config", exist_ok=True)
with open("config/lycoris_config.json", "w") as f:
    json.dump(lycoris_config, f, indent=4)

In [None]:
import json
import os

# Trainer options. Keys mirror SimpleTuner's option names; this dict is written
# to config/config.json below and is also passed to normalize_args() in a later
# cell. Several values come from the config cell at the top of the notebook.
training_config = {
    "mixed_precision": "bf16",
    "model_type": "lora",
    "pretrained_model_name_or_path": pretrained_model_name_or_path,
    "gradient_checkpointing": True,
    "cache_dir": "cache",
    "set_grads_to_none": True,
    "gradient_accumulation_steps": 1,
    "resume_from_checkpoint": "latest",
    "snr_gamma": 5,
    # 0 epochs: the run length is governed by max_train_steps instead.
    "num_train_epochs": 0,
    "max_train_steps": 10000,
    "metadata_update_interval": 65,
    "optimizer": "adamw_bf16",
    "learning_rate": learning_rate,
    "lr_scheduler": "polynomial",
    "seed": 42,
    "lr_warmup_steps": 100,
    "output_dir": "output/models",
    "aspect_bucket_rounding": 2,
    "inference_scheduler_timestep_spacing": "trailing",
    "training_scheduler_timestep_spacing": "trailing",
    "report_to": "wandb",
    "lr_end": 1e-8,
    "compress_disk_cache": True,
    "push_to_hub": True,
    "hub_model_id": hub_model_id,
    "push_checkpoints_to_hub": True,
    "model_family": "flux",
    "disable_benchmark": False,
    # Documented option name is train_batch_size; the former "train_batch" key
    # could at best be matched by prefix, which is fragile.
    "train_batch_size": train_batch_size,
    "max_workers": 32,
    "read_batch_size": 25,
    "write_batch_size": 64,
    "caption_dropout_probability": 0.1,
    "torch_num_threads": 8,
    "image_processing_batch_size": 32,
    "vae_batch_size": 4,
    "validation_prompt": validation_prompt,
    "num_validation_images": 1,
    "validation_num_inference_steps": 20,
    "validation_seed": 42,
    "minimum_image_size": 0,
    "resolution": 1024,
    "validation_resolution": "1024x1024",
    "resolution_type": "pixel_area",
    # Written by the LyCORIS cell above.
    "lycoris_config": "config/lycoris_config.json",
    "lora_type": "lycoris",
    "base_model_precision": base_model_precision,
    "checkpointing_steps": 500,
    "checkpoints_total_limit": 5,
    "validation_steps": 500,
    "tracker_run_name": hub_model_id,
    "tracker_project_name": tracker_project_name,
    "validation_guidance": 3.0,
    "validation_guidance_real": 1.0,
    "validation_guidance_rescale": 0.0,
    "validation_negative_prompt": "blurry, cropped, ugly",
}

# write to config/config.json (creating config/ if needed; open() won't).
os.makedirs("config", exist_ok=True)
with open("config/config.json", "w") as f:
    json.dump(training_config, f, indent=4)

In [None]:
import json
import os


def _image_dataset(resolution, crop):
    """Build one local image-dataset entry reading from ``instance_data_dir``.

    The image datasets differ only in ``resolution`` and ``crop``. The dataset id
    and the VAE cache directory are derived from those two values, so every
    variant gets a distinct id and its own latent cache.

    Args:
        resolution: target resolution (interpreted per ``resolution_type``).
        crop: whether the dataset is cropped; cropped variants get a "-crop" suffix.

    Returns:
        A dict in the dataloader-config format written to multidatabackend.json.
    """
    suffix = f"{resolution}-crop" if crop else f"{resolution}"
    return {
        "id": f"my-dataset-{suffix}",
        "type": "local",
        "instance_data_dir": instance_data_dir,
        "crop": crop,
        # Presumably only consulted when crop is True — confirm in SimpleTuner docs.
        "crop_style": "random",
        "minimum_image_size": 128,
        "resolution": resolution,
        "resolution_type": "pixel_area",
        # Integer repeat count (was the string "4").
        "repeats": 4,
        "metadata_backend": "discovery",
        "caption_strategy": "filename",
        "cache_dir_vae": f"cache/vae-{suffix}",
    }


# Same images at two resolutions, each uncropped and randomly cropped, plus the
# shared text-embedding cache. The "-crop" variants previously had crop=False,
# which made them duplicates of the uncropped datasets.
dataloader_config = [
    _image_dataset(512, crop=False),
    _image_dataset(1024, crop=False),
    _image_dataset(512, crop=True),
    _image_dataset(1024, crop=True),
    {
        "id": "text-embed-cache",
        "dataset_type": "text_embeds",
        "default": True,
        "type": "local",
        "cache_dir": "cache/text"
    }
]

# write to config/multidatabackend.json (creating config/ if needed; open() won't).
os.makedirs("config", exist_ok=True)
with open("config/multidatabackend.json", "w") as f:
    json.dump(dataloader_config, f, indent=4)

In [None]:
# Logging / environment bootstrap for the notebook. Statement order in this cell
# matters: the environment variable and logging reset run before the
# simpletuner imports at the bottom.
import logging
import logging.config
from os import environ

# Raise two noisy third-party loggers to ERROR.
logging.getLogger("DeepSpeed").setLevel("ERROR")
logging.getLogger("torch.distributed.elastic.multiprocessing.redirects").setLevel("ERROR")

# Reset logging configuration. Per the logging.config docs, with
# disable_existing_loggers=True and no loggers listed, every logger that exists at
# this point is disabled — including the two adjusted just above. Loggers created
# afterwards (e.g. by the simpletuner imports below) are not affected.
# NOTE(review): that makes the two setLevel calls above moot unless those loggers
# are re-enabled later — confirm intent.
logging.config.dictConfig({
    "version": 1,
    "disable_existing_loggers": True,
})

# Set before the simpletuner/accelerate imports below, presumably so accelerate
# reads it when it is first imported — confirm.
environ["ACCELERATE_LOG_LEVEL"] = "WARNING"

# log_format is not referenced in this cell; presumably imported for its
# side effects (log formatting setup) — confirm.
from simpletuner.helpers import log_format
from simpletuner.helpers.training.multi_process import _get_rank
from simpletuner.helpers.training.state_tracker import StateTracker
from simpletuner.helpers.training.trainer import Trainer

# Rank 0 logs at SIMPLETUNER_LOG_LEVEL (default "INFO"); every other rank only
# emits errors.
logger = logging.getLogger("SimpleTuner")
logger.setLevel(environ.get("SIMPLETUNER_LOG_LEVEL", "INFO") if _get_rank() == 0 else "ERROR")


In [None]:
import os

from simpletuner.helpers.configuration.json_file import normalize_args

# Environment flags read by SimpleTuner's configuration loader; "cmd" presumably
# selects argument-style parsing of loaded_config rather than a config file — confirm.
os.environ.update({"CONFIG_BACKEND": "cmd", "ENV": "default"})

# Register the directory the JSON config files above were written into.
StateTracker.set_config_path("config/")

# The Trainer in the next cell is constructed from this normalized form of
# training_config.
loaded_config = normalize_args(training_config)

In [None]:
import multiprocessing
import traceback

# Prefer the 'fork' start method; the warning below describes the cost of not
# having it. set_start_method() raises RuntimeError once a start method has been
# fixed (e.g. whenever this cell is re-run), so skip the call when 'fork' is
# already in effect instead of logging a spurious error.
if multiprocessing.get_start_method(allow_none=True) != "fork":
    try:
        multiprocessing.set_start_method('fork')
    except Exception as e:
        logger.error(
            "Failed to set the multiprocessing start method to 'fork'. Unexpected behaviour such as high memory overhead or poor performance may result."
            f"\nError: {e}"
        )

# Build the Trainer from the normalized config. On failure, log the full
# traceback and re-raise so the notebook stops here.
try:
    trainer = Trainer(
        loaded_config,
        exit_on_error=True,
    )
except Exception as e:
    logger.error(f"Failed to create Trainer: {e}, {traceback.format_exc()}")
    raise


In [None]:
# Early Trainer setup. The steps run strictly in this order, and each method is
# looked up on `trainer` immediately before it is called. Any failure is logged
# and re-raised so the notebook stops here.
_configure_steps = (
    "configure_webhook",
    "init_noise_schedule",
    "init_seed",
    "init_huggingface_hub",
)
try:
    for _step_name in _configure_steps:
        getattr(trainer, _step_name)()
except Exception as configure_error:
    logger.error("Failed to configure Trainer: {}".format(configure_error))
    raise configure_error

In [None]:
# Bring up the preprocessing models, then run init_precision restricted by
# preprocessing_models_only=True. Any failure is logged and re-raised.
try:
    trainer.init_preprocessing_models()
    trainer.init_precision(preprocessing_models_only=True)
except Exception as init_error:
    failure_message = "Failed to initialize preprocessing models: {}".format(init_error)
    logger.error(failure_message)
    raise init_error


In [None]:
# Set up the data backends (presumably from config/multidatabackend.json written
# above — confirm). Any failure is logged and re-raised so the notebook stops.
try:
    trainer.init_data_backend()
except Exception as backend_error:
    failure_message = "Failed to initialize data backend: {}".format(backend_error)
    logger.error(failure_message)
    raise backend_error

In [None]:
# Presumably releases the text encoder(s); text embeddings are configured to be
# cached under cache/text by the "text-embed-cache" dataset — confirm in Trainer.
trainer.init_unload_text_encoder()

In [None]:
# Presumably releases the VAE; each image dataset has its own cache_dir_vae for
# latents — confirm in Trainer.
trainer.init_unload_vae()

In [None]:
# Presumably loads the model named by pretrained_model_name_or_path
# (model_family "flux", base_model_precision from the config cell) — confirm.
trainer.init_load_base_model()

In [None]:
# training_config sets no ControlNet option, so this is presumably a no-op here.
trainer.init_controlnet_model()


In [None]:
# training_config sets no TREAD option, so this is presumably a no-op here.
trainer.init_tread_model()


In [None]:
# Second of three init_precision() calls in this notebook: an earlier one passes
# preprocessing_models_only=True and a later one passes ema_only=True.
trainer.init_precision()


In [None]:
# Presumably freezes the base model weights so only the adapter is trained — confirm.
trainer.init_freeze_models()


In [None]:
# training_config has model_type "lora" and lora_type "lycoris", pointing at
# config/lycoris_config.json; presumably that adapter is created here — confirm.
trainer.init_trainable_peft_adapter()


In [None]:
# training_config enables no EMA option, so this is presumably a no-op here — confirm.
trainer.init_ema_model()


In [None]:
# Precision pass limited by ema_only=True (presumably to the EMA model only).
trainer.init_precision(ema_only=True)


In [None]:
# Presumably moves the models onto the accelerator's device — confirm.
trainer.move_models(destination="accelerator")


In [None]:
# training_config sets no distillation option, so this is presumably a no-op here.
trainer.init_distillation()


In [None]:
# Presumably sets up validation from the validation_* keys in training_config
# (e.g. validation_prompt, validation_steps=500) — confirm.
trainer.init_validations()


In [None]:
# The enable / benchmark / disable cells are kept in this order so that
# SageAttention is (presumably) active only while the base model is benchmarked.
trainer.enable_sageattention_inference()


In [None]:
# disable_benchmark is False in training_config, so a base-model benchmark is
# presumably produced here for later comparison — confirm.
trainer.init_benchmark_base_model()


In [None]:
# Counterpart of enable_sageattention_inference() two cells above.
trainer.disable_sageattention_inference()


In [None]:
# resume_from_checkpoint is "latest" in training_config; presumably an existing
# checkpoint under output/models is resumed here and models are prepared — confirm.
trainer.resume_and_prepare()


In [None]:
# report_to is "wandb" with tracker_project_name / tracker_run_name from the
# config cell; presumably the tracker run is started here — confirm.
trainer.init_trackers()


In [None]:
# Start training. Per training_config: max_train_steps=10000, checkpointing_steps=500,
# validation_steps=500.
trainer.train()
