# Experiment: Flux LoRA Vast.ai ai-toolkit Workflow

Objective:
- Train a personal Flux LoRA end-to-end on Vast.ai with reproducible controls.
- Keep all operational knobs in one place for fast iteration on rented GPUs.

Success criteria:
- Dataset passes quality checks.
- Training launches (or resumes) from generated config.
- Final artifacts are bundled for download.


In [None]:
from __future__ import annotations

import json
import os
import shutil
import subprocess
import sys
import tarfile
import textwrap
import time
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any


def run(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess:
    print("$", " ".join(cmd))
    return subprocess.run(cmd, cwd=cwd, check=check, text=True, capture_output=False)


def run_capture(cmd: list[str], cwd: str | None = None) -> str:
    print("$", " ".join(cmd))
    out = subprocess.check_output(cmd, cwd=cwd, text=True)
    return out.strip()


def ensure_package(package: str) -> None:
    try:
        __import__(package)
    except Exception:
        run([sys.executable, "-m", "pip", "install", package])

# Show which Python interpreter version the notebook kernel is running.
print("Python:", sys.version)


## Prerequisites

- Vast.ai Linux instance with NVIDIA GPU.
- Dataset available as folder or zip.
- Permission to train the base model and dataset content.
- You are responsible for legal/TOS compliance, especially if training NSFW content.


In [None]:
# Central configuration (edit this cell first)
# Every knob used by the later cells is defined here as a module-level constant.
PROJECT_NAME = "my_flux_lora"
BASE_MODEL_ID = "black-forest-labs/FLUX.1-dev"

# Paths: auto-detect common Vast mount points, then fallback.
# The first entry that exists on disk becomes the workspace root.
VAST_AUTO_MOUNT_PATHS = [
    "/workspace",
    "/root/autodl-tmp",
    "/data",
]
USER_WORKSPACE_FALLBACK = "/workspace"

# Dataset input
DATASET_ZIP = ""   # e.g. /workspace/datasets/myset.zip
DATASET_DIR = ""   # e.g. /workspace/datasets/myset (if already extracted)
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}

# Captioning
AUTO_CAPTION = False
CAPTION_MODEL = "nlpconnect/vit-gpt2-image-captioning"
CAPTION_FILE_EXT = ".txt"

# Safety and policy knob (recorded in the generated config and publish metadata)
ALLOW_NSFW = True

# Quality gate
QUALITY_GATE = True
MIN_IMAGE_COUNT = 20
MIN_RESOLUTION = 512
REQUIRE_CAPTIONS = True

# Training knobs
EPOCHS = 10
LEARNING_RATE = 1e-4
BATCH_SIZE = 1
GRAD_ACCUM = 4
NETWORK_DIM = 32
NETWORK_ALPHA = 16
MIXED_PRECISION = "bf16"
SAVE_EVERY_N_STEPS = 200
MAX_TRAIN_STEPS = 2000

# Resume support
RESUME_FROM_CHECKPOINT = ""   # path to checkpoint; empty means fresh training

# Logging
WANDB_ENABLED = False
WANDB_PROJECT = "flux-lora"
WANDB_RUN_NAME = PROJECT_NAME

# Cost guard (estimate only)
COST_GUARD = True
GPU_HOURLY_USD = 0.55
ASSUMED_STEPS_PER_SEC = 0.65

# ai-toolkit repo
AI_TOOLKIT_REPO = "https://github.com/ostris/ai-toolkit.git"
AI_TOOLKIT_DIRNAME = "ai-toolkit"
TRAIN_COMMAND_OVERRIDE = ""  # optional full command string

# Validate explicitly instead of with `assert`: asserts are stripped under
# `python -O`, and `x in {True, False}` would also accept the integers 0 and 1.
if not isinstance(ALLOW_NSFW, bool):
    raise TypeError(f"ALLOW_NSFW must be a bool, got {ALLOW_NSFW!r}")
print("Config loaded for:", PROJECT_NAME)


In [None]:
# Resolve workspace and project paths
# Use the first auto-detected mount point that exists on disk; when none of
# them exists, fall back to the user-configured workspace.
mount_root = next(
    (Path(mount) for mount in VAST_AUTO_MOUNT_PATHS if Path(mount).exists()),
    None,
)
if mount_root is None:
    mount_root = Path(USER_WORKSPACE_FALLBACK)

# Everything for this project lives under ROOT / PROJECT_NAME.
ROOT = mount_root
WORKDIR = ROOT / PROJECT_NAME
DATA_DIR, TRAIN_DIR, OUTPUT_DIR, LOG_DIR, CONFIG_DIR, PUBLISH_DIR = (
    WORKDIR / leaf for leaf in ("data", "train", "outputs", "logs", "config", "publish")
)
CAPTION_DIR = TRAIN_DIR / "captions"

# Create the whole tree up front; existing directories are left untouched.
project_dirs = (WORKDIR, DATA_DIR, TRAIN_DIR, CAPTION_DIR, OUTPUT_DIR, LOG_DIR, CONFIG_DIR, PUBLISH_DIR)
for project_dir in project_dirs:
    project_dir.mkdir(parents=True, exist_ok=True)

print("ROOT:", ROOT)
print("WORKDIR:", WORKDIR)


In [None]:
# Runtime preflight
# Checks, in order: the NVIDIA driver tool runs, torch (if installed) sees
# CUDA, and the workspace root has enough free disk.
print("== Runtime preflight ==")

_BYTES_PER_GIB = 1024 ** 3
_MIN_FREE_DISK_GB = 20

try:
    run(["nvidia-smi"])
except Exception as e:
    raise RuntimeError("nvidia-smi failed. Confirm GPU-enabled Vast instance.") from e

try:
    import torch
    print("torch:", torch.__version__)
    cuda_ready = torch.cuda.is_available()
    print("cuda available:", cuda_ready)
    if not cuda_ready:
        raise RuntimeError("CUDA unavailable in torch runtime.")
    print("gpu:", torch.cuda.get_device_name(0))
    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / _BYTES_PER_GIB
    print(f"vram_gb: {total_mem_gb:.2f}")
except ImportError:
    # A missing torch is not fatal at this point; only a torch that cannot
    # see CUDA aborts the preflight.
    print("torch not installed yet; bootstrap cell will install dependencies.")

usage = shutil.disk_usage(str(ROOT))
free_gb = usage.free / _BYTES_PER_GIB
print(f"disk_free_gb: {free_gb:.2f}")
if free_gb < _MIN_FREE_DISK_GB:
    raise RuntimeError("Less than 20GB free disk. Increase storage before training.")


In [None]:
# Environment bootstrap
# Fetches ai-toolkit next to the project, creates a dedicated virtualenv, and
# installs the toolkit's requirements into that virtualenv.
AI_TOOLKIT_DIR = ROOT / AI_TOOLKIT_DIRNAME
VENV_DIR = ROOT / ".venv_flux"

# Fast-forward an existing checkout; clone only when there is none yet.
if AI_TOOLKIT_DIR.exists():
    run(["git", "-C", str(AI_TOOLKIT_DIR), "pull", "--ff-only"])
else:
    run(["git", "clone", AI_TOOLKIT_REPO, str(AI_TOOLKIT_DIR)])

if not VENV_DIR.exists():
    run([sys.executable, "-m", "venv", str(VENV_DIR)])

# All later subprocesses use the virtualenv interpreter, not the kernel's.
PYTHON_BIN = str(VENV_DIR / "bin" / "python")
PIP_BIN = [PYTHON_BIN, "-m", "pip"]

run([*PIP_BIN, "install", "--upgrade", "pip", "setuptools", "wheel"])
req = AI_TOOLKIT_DIR / "requirements.txt"
if not req.exists():
    print("requirements.txt not found; install your fork dependencies manually if needed.")
else:
    run([*PIP_BIN, "install", "-r", str(req)])

# Smoke test: fails the cell if PyYAML is not importable inside the virtualenv.
run([PYTHON_BIN, "-c", "import yaml; print('pyyaml ok')"])
print("Bootstrap complete.")


## Entrypoint Diagnostics

Run this once after bootstrap. It detects likely ai-toolkit train/inference entrypoints and prints copy-paste commands for your fork.


In [None]:
# Entrypoint diagnostics
# Reports which ai-toolkit train/inference entrypoints look usable from the
# virtualenv interpreter and prints commands to copy into the override knobs.
import os
from pathlib import Path

print("AI_TOOLKIT_DIR:", AI_TOOLKIT_DIR)
print("PYTHON_BIN:", PYTHON_BIN)

# `config_path` is only defined by the config-generation cell further down.
# This cell is meant to run right after bootstrap, so fall back to the file
# name that cell writes instead of failing with NameError.
diag_config_path = globals().get("config_path", CONFIG_DIR / "train_flux_lora.yaml")

# Child-side probe. find_spec() imports the parent package of a dotted name
# and raises when that parent is missing, so the lookup is guarded and a
# failure is reported as "False" rather than as a crashed subprocess.
_PROBE_SNIPPET = (
    "import importlib.util, sys\n"
    "try:\n"
    "    print(importlib.util.find_spec(sys.argv[1]) is not None)\n"
    "except (ImportError, ValueError):\n"
    "    print(False)\n"
)


def _module_available(module_name: str) -> bool:
    """Return True when *module_name* resolves for PYTHON_BIN run from AI_TOOLKIT_DIR."""
    # cwd matters: `python -c` puts the working directory on sys.path, and the
    # training command is launched from AI_TOOLKIT_DIR as well.
    result = subprocess.run(
        [PYTHON_BIN, "-c", _PROBE_SNIPPET, module_name],
        cwd=str(AI_TOOLKIT_DIR),
        text=True,
        capture_output=True,
        check=False,
    )
    return result.returncode == 0 and result.stdout.strip().endswith("True")


candidates = [
    "toolkit.train",
    "ai_toolkit.train",
    "toolkit.inference",
    "ai_toolkit.inference",
]

print("\nModule availability:")
for mod in candidates:
    ok = _module_available(mod)
    print(f"- {mod}: {ok}")

print("\nCommon script files in repo root:")
for name in ["run.py", "train.py", "inference.py", "main.py"]:
    p = AI_TOOLKIT_DIR / name
    print(f"- {name}:", p.exists())

print("\nSuggested train commands:")
print(f"1) {PYTHON_BIN} -m toolkit.train {diag_config_path}")
print(f"2) {PYTHON_BIN} -m ai_toolkit.train {diag_config_path}")
print(f"3) {PYTHON_BIN} run.py train {diag_config_path}")
print(f"4) {PYTHON_BIN} train.py {diag_config_path}")

print("\nSuggested inference commands:")
print(f"1) {PYTHON_BIN} -m toolkit.inference --base {BASE_MODEL_ID} --lora <lora.safetensors> --prompt 'portrait, cinematic lighting' --out {OUTPUT_DIR / 'samples'}")
print(f"2) {PYTHON_BIN} -m ai_toolkit.inference --base {BASE_MODEL_ID} --lora <lora.safetensors> --prompt 'portrait, cinematic lighting' --out {OUTPUT_DIR / 'samples'}")
print(f"3) {PYTHON_BIN} inference.py --base {BASE_MODEL_ID} --lora <lora.safetensors> --prompt 'portrait, cinematic lighting' --out {OUTPUT_DIR / 'samples'}")


## Dataset and captions

This section ingests your dataset, creates caption sidecars, and checks quality before any training spend.


In [None]:
# Dataset ingest and normalization
# Optionally extracts DATASET_ZIP into DATA_DIR, then copies every image found
# under DATASET_DIR (when set) or DATA_DIR (otherwise) flat into TRAIN_DIR.
import zipfile

if DATASET_ZIP:
    zsrc = Path(DATASET_ZIP)
    if not zsrc.exists():
        raise FileNotFoundError(f"DATASET_ZIP not found: {zsrc}")
    with zipfile.ZipFile(zsrc, "r") as zf:
        zf.extractall(DATA_DIR)


def _copy_images_flat(source_root: Path) -> list[Path]:
    """Copy images under *source_root* into TRAIN_DIR and return the skipped duplicates.

    TRAIN_DIR is flat, so two source files with the same file name in different
    sub-folders cannot both be kept. The first one (in sorted path order) wins;
    the others are returned so the caller can report them instead of losing
    them silently.
    """
    first_seen: dict[str, Path] = {}
    skipped: list[Path] = []
    # Sorted so the winner of a name collision is the same on every run.
    for item in sorted(source_root.rglob("*")):
        if not (item.is_file() and item.suffix.lower() in IMAGE_EXTS):
            continue
        if TRAIN_DIR in item.parents:
            # The source tree contains TRAIN_DIR itself; those are already ingested.
            continue
        if item.name in first_seen:
            skipped.append(item)
            continue
        first_seen[item.name] = item
        dst = TRAIN_DIR / item.name
        # Never overwrite: re-running the cell keeps files already in TRAIN_DIR.
        if not dst.exists():
            shutil.copy2(item, dst)
    return skipped


# When DATASET_DIR is set it is the only tree that is scanned, even if a zip
# was also extracted above.
if DATASET_DIR:
    src = Path(DATASET_DIR)
    if not src.exists():
        raise FileNotFoundError(f"DATASET_DIR not found: {src}")
    name_collisions = _copy_images_flat(src)
else:
    name_collisions = _copy_images_flat(DATA_DIR)

if name_collisions:
    print(f"warning: {len(name_collisions)} image(s) skipped because another file with the same name was already ingested:")
    for dup in name_collisions[:10]:
        print("  -", dup)

images = sorted([p for p in TRAIN_DIR.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTS])
print("images_found:", len(images))
if not images:
    raise RuntimeError("No images found. Set DATASET_ZIP or DATASET_DIR correctly.")


In [None]:
# Captioning (optional auto-caption, always manual-editable)
# Writes one caption sidecar per image into CAPTION_DIR. Existing caption files
# are never overwritten, so manual edits survive re-running this cell.
if AUTO_CAPTION:
    from PIL import Image
    from transformers import pipeline

    cap = pipeline("image-to-text", model=CAPTION_MODEL)
    for img_path in images:
        txt_path = CAPTION_DIR / f"{img_path.stem}{CAPTION_FILE_EXT}"
        if txt_path.exists():
            continue
        # Close each image file once it has been captioned instead of keeping
        # one open handle per dataset image.
        with Image.open(img_path) as im:
            out = cap(im)
        # An empty result yields an empty caption rather than an IndexError.
        text = out[0].get("generated_text", "").strip() if out else ""
        txt_path.write_text(text + "\n", encoding="utf-8")
else:
    for img_path in images:
        txt_path = CAPTION_DIR / f"{img_path.stem}{CAPTION_FILE_EXT}"
        if not txt_path.exists():
            # Starter caption for manual editing
            txt_path.write_text(f"{PROJECT_NAME}, subject, high detail\n", encoding="utf-8")

print("caption_files:", len(list(CAPTION_DIR.glob(f"*{CAPTION_FILE_EXT}"))))
print("Edit caption files in:", CAPTION_DIR)


In [None]:
# Quality gate
# Fails fast, before any training spend, when the dataset is too small, has
# low-resolution images, or (optionally) is missing caption files.
from PIL import Image

if QUALITY_GATE:
    if len(images) < MIN_IMAGE_COUNT:
        raise RuntimeError(f"Quality gate failed: {len(images)} images < MIN_IMAGE_COUNT({MIN_IMAGE_COUNT})")

    too_small = []
    for p in images:
        # Close the file handle as soon as the size has been read.
        with Image.open(p) as im:
            w, h = im.size
        # The shorter side decides whether the image meets the minimum.
        if min(w, h) < MIN_RESOLUTION:
            too_small.append((p.name, w, h))

    caption_files = list(CAPTION_DIR.glob(f"*{CAPTION_FILE_EXT}"))
    caption_names = {p.stem for p in caption_files}
    missing_caps = [p.name for p in images if p.stem not in caption_names]

    print("too_small_count:", len(too_small))
    for small_name, small_w, small_h in too_small[:10]:
        # Name the offenders so they can be fixed without re-scanning by hand.
        print(f"  too_small: {small_name} ({small_w}x{small_h})")
    print("missing_caption_count:", len(missing_caps))
    for missing_name in missing_caps[:10]:
        print(f"  missing_caption: {missing_name}")

    if too_small:
        raise RuntimeError(f"Quality gate failed: {len(too_small)} images below MIN_RESOLUTION={MIN_RESOLUTION}")
    if REQUIRE_CAPTIONS and missing_caps:
        raise RuntimeError(f"Quality gate failed: {len(missing_caps)} missing captions")

    print("Quality gate passed.")
else:
    # Do not claim a pass when no check was actually run.
    print("Quality gate skipped (QUALITY_GATE is False).")


In [None]:
# Generate training config (YAML)
# Builds the config dict from the central knobs and writes it to
# CONFIG_DIR / "train_flux_lora.yaml"; later cells read `config_path`.
# NOTE: `yaml` is imported in the notebook kernel here, while the bootstrap
# cell only verifies PyYAML inside the virtualenv interpreter.
# NOTE(review): the key layout below is defined by this notebook; confirm that
# the ai-toolkit entrypoint you launch actually accepts this schema.
from datetime import timezone

import yaml

# datetime.utcnow() is deprecated; take an aware UTC time and drop the tzinfo
# so the string keeps the same "<iso>Z" shape as before.
created_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

config = {
    "job": {
        "name": PROJECT_NAME,
        "type": "lora",
    },
    "meta": {
        "allow_nsfw": ALLOW_NSFW,
        "created_at": created_at,
    },
    "model": {
        "base_model": BASE_MODEL_ID,
    },
    "data": {
        "train_images": str(TRAIN_DIR),
        "captions": str(CAPTION_DIR),
    },
    "train": {
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "batch_size": BATCH_SIZE,
        "grad_accum": GRAD_ACCUM,
        "save_every_n_steps": SAVE_EVERY_N_STEPS,
        "max_train_steps": MAX_TRAIN_STEPS,
        "mixed_precision": MIXED_PRECISION,
    },
    "network": {
        "dim": NETWORK_DIM,
        "alpha": NETWORK_ALPHA,
    },
    "output": {
        "dir": str(OUTPUT_DIR),
        "logs": str(LOG_DIR),
    },
}

# Optional sections are added only when the matching knob is set.
if RESUME_FROM_CHECKPOINT:
    config["train"]["resume_from_checkpoint"] = RESUME_FROM_CHECKPOINT

if WANDB_ENABLED:
    config["logging"] = {
        "wandb": {
            "project": WANDB_PROJECT,
            "run_name": WANDB_RUN_NAME,
        }
    }

config_path = CONFIG_DIR / "train_flux_lora.yaml"
# sort_keys=False keeps the sections in the order written above.
config_path.write_text(yaml.safe_dump(config, sort_keys=False), encoding="utf-8")
print("Wrote:", config_path)
# Echo only the first 1200 characters of the written file.
print(config_path.read_text(encoding="utf-8")[:1200])


In [None]:
# Cost guard estimate
# Rough figure only: assumes the run goes for MAX_TRAIN_STEPS at a constant
# ASSUMED_STEPS_PER_SEC and is billed at GPU_HOURLY_USD.
if COST_GUARD:
    est_steps = MAX_TRAIN_STEPS
    # Clamp the assumed throughput so a zero value cannot divide by zero.
    est_seconds = est_steps / max(ASSUMED_STEPS_PER_SEC, 1e-6)
    est_hours = est_seconds / 3600
    est_usd = est_hours * GPU_HOURLY_USD
    cost_report = [
        f"estimated_steps: {est_steps}",
        f"estimated_hours: {est_hours:.2f}",
        f"estimated_cost_usd: {est_usd:.2f}",
    ]
    print("\n".join(cost_report))


In [None]:
# Launch training
# Picks the first train entrypoint that looks usable (or the explicit
# override) and starts it in the background with output appended to train.log.
# NOTE(review): the candidate commands are guesses at common layouts; confirm
# the invocation your ai-toolkit checkout expects and use
# TRAIN_COMMAND_OVERRIDE when it differs.
import shlex
from datetime import timezone

train_log = LOG_DIR / "train.log"

if RESUME_FROM_CHECKPOINT and not Path(RESUME_FROM_CHECKPOINT).exists():
    raise FileNotFoundError(f"RESUME_FROM_CHECKPOINT not found: {RESUME_FROM_CHECKPOINT}")

# Child-side probe. find_spec() imports the parent package of a dotted name
# and raises when that parent is missing, so the lookup is guarded and a
# failure is reported as "False" rather than as a crashed subprocess.
_PROBE_SNIPPET = (
    "import importlib.util, sys\n"
    "try:\n"
    "    print(importlib.util.find_spec(sys.argv[1]) is not None)\n"
    "except (ImportError, ValueError):\n"
    "    print(False)\n"
)


def _module_available(module_name: str) -> bool:
    """Return True when *module_name* resolves for PYTHON_BIN run from AI_TOOLKIT_DIR."""
    # cwd matters: `python -c` puts the working directory on sys.path, and the
    # training command below is launched from AI_TOOLKIT_DIR as well.
    result = subprocess.run(
        [PYTHON_BIN, "-c", _PROBE_SNIPPET, module_name],
        cwd=str(AI_TOOLKIT_DIR),
        text=True,
        capture_output=True,
        check=False,
    )
    return result.returncode == 0 and result.stdout.strip().endswith("True")


def build_candidate_commands() -> list[str]:
    """Return the shell command strings to try, most preferred first.

    A non-empty TRAIN_COMMAND_OVERRIDE is returned as the only candidate.
    Otherwise the interpreter and config paths are shell-quoted so that paths
    containing spaces survive the `shell=True` launch.
    """
    override = TRAIN_COMMAND_OVERRIDE.strip()
    if override:
        return [override]

    py = shlex.quote(PYTHON_BIN)
    cfg = shlex.quote(str(config_path))
    return [
        f"{py} -m toolkit.train {cfg}",
        f"{py} -m ai_toolkit.train {cfg}",
        f"{py} run.py train {cfg}",
        f"{py} train.py {cfg}",
    ]


def _entrypoint_exists(cmd: str) -> bool:
    """Cheap pre-check that the module or script named in *cmd* is present."""
    argv = shlex.split(cmd)
    if len(argv) < 2:
        return False
    if argv[1] == "-m":
        return len(argv) > 2 and _module_available(argv[2])
    # Script fallback: the script must exist in the ai-toolkit checkout.
    return (AI_TOOLKIT_DIR / argv[1]).is_file()


candidates = build_candidate_commands()
print("candidate_train_commands:")
for c in candidates:
    print(" -", c)

if TRAIN_COMMAND_OVERRIDE.strip():
    # An explicit override is trusted as-is and not probed.
    chosen_command = candidates[0]
else:
    chosen_command = next((cmd for cmd in candidates if _entrypoint_exists(cmd)), None)

if not chosen_command:
    raise RuntimeError(
        "No known ai-toolkit train entrypoint detected. Set TRAIN_COMMAND_OVERRIDE explicitly, "
        "e.g. '<venv_python> -m toolkit.train <config_path>'"
    )

print("train_command:", chosen_command)
print("log_file:", train_log)

with open(train_log, "a", encoding="utf-8") as lf:
    # datetime.utcnow() is deprecated; same "<iso>Z" text from an aware UTC time.
    started_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    lf.write(f"\n\n===== START {started_at}Z =====\n")
    # Flush so the header lands in the file before the child starts appending.
    lf.flush()
    # The child inherits the log file descriptor, so it keeps writing after
    # this `with` block closes the parent's handle.
    proc = subprocess.Popen(chosen_command, shell=True, cwd=str(AI_TOOLKIT_DIR), stdout=lf, stderr=lf)

print("PID:", proc.pid)
print("Use the next cell to tail logs.")


In [None]:
# Tail latest logs (run repeatedly)
# Prints at most the last 80 lines of the training log.
train_log = LOG_DIR / "train.log"
if train_log.exists():
    # errors="ignore" drops bytes that are not valid UTF-8 instead of raising.
    lines = train_log.read_text(encoding="utf-8", errors="ignore").splitlines()
    for line in lines[-80:]:
        print(line)
else:
    raise FileNotFoundError(f"No log file yet: {train_log}")


In [None]:
# Validation / quick inference hook
# Auto-detects common inference entrypoints and runs the first usable one
# against the most recently modified LoRA file. When none is found, prints an
# example command to adapt manually for your fork.
import shlex

SAMPLES_DIR = OUTPUT_DIR / "samples"
SAMPLES_DIR.mkdir(parents=True, exist_ok=True)

# Child-side probe. find_spec() imports the parent package of a dotted name
# and raises when that parent is missing, so the lookup is guarded and a
# failure is reported as "False" rather than as a crashed subprocess.
_PROBE_SNIPPET = (
    "import importlib.util, sys\n"
    "try:\n"
    "    print(importlib.util.find_spec(sys.argv[1]) is not None)\n"
    "except (ImportError, ValueError):\n"
    "    print(False)\n"
)


def _module_available(module_name: str) -> bool:
    """Return True when *module_name* resolves for PYTHON_BIN run from AI_TOOLKIT_DIR."""
    # cwd matters: `python -c` puts the working directory on sys.path, and the
    # inference command below is run from AI_TOOLKIT_DIR as well.
    result = subprocess.run(
        [PYTHON_BIN, "-c", _PROBE_SNIPPET, module_name],
        cwd=str(AI_TOOLKIT_DIR),
        text=True,
        capture_output=True,
        check=False,
    )
    return result.returncode == 0 and result.stdout.strip().endswith("True")


# Oldest first, so the last element is the most recently modified artifact.
final_loras = sorted(OUTPUT_DIR.rglob("*.safetensors"), key=lambda p: p.stat().st_mtime)
if not final_loras:
    print("No LoRA artifact yet. Finish training before inference.")
else:
    final_lora = final_loras[-1]
    prompt = f"{PROJECT_NAME}, cinematic portrait, high detail"

    # Quote every interpolated value: the command is handed to bash, and the
    # prompt and paths may contain spaces or shell metacharacters.
    py = shlex.quote(PYTHON_BIN)
    cli_args = " ".join(
        [
            "--base", shlex.quote(BASE_MODEL_ID),
            "--lora", shlex.quote(str(final_lora)),
            "--prompt", shlex.quote(prompt),
            "--out", shlex.quote(str(SAMPLES_DIR)),
        ]
    )
    candidates = [
        f"{py} -m toolkit.inference {cli_args}",
        f"{py} -m ai_toolkit.inference {cli_args}",
        f"{py} inference.py {cli_args}",
    ]

    chosen = None
    for cmd in candidates:
        argv = shlex.split(cmd)
        if argv[1] == "-m":
            found = _module_available(argv[2])
        else:
            # Script entrypoint: must exist in the ai-toolkit checkout.
            found = (AI_TOOLKIT_DIR / argv[1]).is_file()
        if found:
            chosen = cmd
            break

    if chosen:
        print("Running inference:", chosen)
        # check=False: a failed sample run is reported by the tool's own
        # output and does not raise here.
        run(["bash", "-lc", chosen], cwd=str(AI_TOOLKIT_DIR), check=False)
        print("Samples dir:", SAMPLES_DIR)
    else:
        print("No known inference entrypoint detected.")
        print("Set a manual inference command in this cell for your fork.")
        print("Example:")
        print(f"{PYTHON_BIN} -m toolkit.inference --base {BASE_MODEL_ID} --lora {final_lora} --prompt 'portrait, cinematic lighting' --out {SAMPLES_DIR}")


In [None]:
# Publish pack: LoRA + metadata + prompts
# Bundles the most recently modified .safetensors file, the training config,
# and a metadata.json into one tar.gz under PUBLISH_DIR.
from datetime import timezone

# Oldest first, so the last element is the most recently modified artifact.
final_loras = sorted(OUTPUT_DIR.rglob("*.safetensors"), key=lambda p: p.stat().st_mtime)
if not final_loras:
    raise RuntimeError(f"No .safetensors found under {OUTPUT_DIR}. Finish training first.")

final_lora = final_loras[-1]
metadata = {
    "project_name": PROJECT_NAME,
    "base_model": BASE_MODEL_ID,
    "allow_nsfw": ALLOW_NSFW,
    "final_lora": str(final_lora),
    # datetime.utcnow() is deprecated; same "<iso>Z" text from an aware UTC time.
    "created_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
    "recommended_prompts": [
        f"{PROJECT_NAME}, detailed portrait, high contrast",
        f"{PROJECT_NAME}, cinematic scene, volumetric lighting",
    ],
    "negative_prompt": "low quality, blurry, artifacts",
}

meta_path = PUBLISH_DIR / "metadata.json"
meta_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")

# `config_path` only exists if the config cell ran in this kernel session.
# After a kernel restart, fall back to the file name that cell writes.
bundle_config = Path(globals().get("config_path", CONFIG_DIR / "train_flux_lora.yaml"))

bundle_path = PUBLISH_DIR / f"{PROJECT_NAME}-publish-pack.tar.gz"
with tarfile.open(bundle_path, "w:gz") as tf:
    # arcname flattens the archive: members sit at the top level of the bundle.
    tf.add(final_lora, arcname=final_lora.name)
    if bundle_config.exists():
        tf.add(bundle_config, arcname="train_flux_lora.yaml")
    else:
        print("warning: training config not found; bundle will not include it:", bundle_config)
    tf.add(meta_path, arcname="metadata.json")

print("final_lora:", final_lora)
print("publish_bundle:", bundle_path)


## Next steps

- Re-run only changed cells when tuning hyperparameters.
- Give each run its own `PROJECT_NAME` so that logs and checkpoints from earlier runs are preserved.
- If your ai-toolkit fork uses different commands, set `TRAIN_COMMAND_OVERRIDE` and edit the inference cell.
