# Kraken OCR — Colab Training (Setup Wizard)

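Run these three cells in order at the start of every new Colab session (Cell B writes its train/val lists to the temporary `/content` disk, which Cell C needs):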
1) Cell A — Setup Wizard
2) Cell B — Build train/val lists
3) Cell C — Train with checkpoints & save the BEST


In [None]:
# === Cell A: Setup Wizard (run once per session) ===
from google.colab import drive, files
from pathlib import Path
import json, os, sys, subprocess, importlib, zipfile, io, re, datetime

# --- User-editable settings --------------------------------------------------
PROJECT_NAME   = "0093"   # sub-folder name under MyDrive/kraken_projects/
FORCE_ATTEMPT  = None     # a positive int overrides the auto-incremented attempt number
PIN_KRAKEN_VER = ""       # exact kraken version to pin; "" installs without a pin

# --- Mount Google Drive and lay out the project tree --------------------------
drive.mount("/content/drive")
HOME = Path("/content/drive/MyDrive")
ROOT = HOME.joinpath("kraken_projects", PROJECT_NAME)
ROOT.mkdir(parents=True, exist_ok=True)

CONFIG_P          = ROOT.joinpath("config.json")
DATA_ROOT_DEFAULT = ROOT.joinpath("data")
MODELS_DIR        = ROOT.joinpath("models", "rec")
CKPTS_DIR_ROOT    = ROOT.joinpath("ckpts")
PIP_CACHE_DIR     = HOME.joinpath("pip-cache")  # lives outside ROOT, so shared by all projects

for _project_dir in (DATA_ROOT_DEFAULT, MODELS_DIR, CKPTS_DIR_ROOT, PIP_CACHE_DIR):
    _project_dir.mkdir(parents=True, exist_ok=True)

def ensure(pkg_spec: str):
    """Make sure a package is importable, pip-installing it (Drive-cached) if not.

    ``pkg_spec`` is a pip requirement such as ``"kraken[train]"`` or
    ``"kraken[train]==X.Y"``. The module to import is derived by stripping the
    extras (``[...]``) and any ``==`` pin. When a pin is given, an importable
    module is only accepted if the installed distribution's version equals
    the pin; otherwise pip is run.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` fails.
    """
    base, _, pinned = pkg_spec.partition("==")
    mod_name = base.split("[", 1)[0].strip()
    pinned = pinned.strip()

    try:
        importlib.import_module(mod_name)
    except Exception:
        # Deliberately broad: a broken install can fail with errors other than
        # ImportError, and a (re)install is the remedy in either case.
        satisfied = False
    else:
        satisfied = True
        if pinned:
            from importlib import metadata
            try:
                satisfied = metadata.version(mod_name) == pinned
            except metadata.PackageNotFoundError:
                # Cannot confirm the version; let pip decide (no-op if satisfied).
                satisfied = False

    if satisfied:
        print(f"✓ {pkg_spec} already available.")
        return

    print(f"Installing {pkg_spec} (cached in Drive) ...")
    cmd = [sys.executable, "-m", "pip", "install", "--upgrade", "--no-input",
           "--cache-dir", str(PIP_CACHE_DIR), pkg_spec]
    subprocess.check_call(cmd)
    # Refresh the import system's finder caches so the package that was just
    # installed can be imported in this same interpreter session.
    importlib.invalidate_caches()

# Install kraken with its training extras (optionally pinned) unless present.
kraken_spec = f"kraken[train]{'==' + PIN_KRAKEN_VER if PIN_KRAKEN_VER else ''}"
ensure(kraken_spec)
import kraken
print("Kraken version:", getattr(kraken, "__version__", "unknown"))
# Best-effort check that the `ketos` CLI is runnable. A missing or
# non-executable binary is reported rather than silently ignored.
try:
    subprocess.run(["ketos", "--version"], check=False)
except OSError as e:
    print("Warning: could not run `ketos --version`:", e)

# Defaults for every config key. An existing config.json for this project
# (written by an earlier session) overrides them key by key.
cfg = {
    "project_name": PROJECT_NAME,
    "data_dir": str(DATA_ROOT_DEFAULT),
    "data_format": None,
    "val_ratio": 0.1,
    "random_seed": 42,
    "models_dir": str(MODELS_DIR),
    "ckpts_root": str(CKPTS_DIR_ROOT),
    "attempt_num": 1,
    "base_model_mode": "auto",
    "base_model_manual": "",
    "epochs": 20,
    "lr": 3e-4,
    "batch_size": 16,
}
if CONFIG_P.exists():
    try:
        existing = json.loads(CONFIG_P.read_text())
    except (OSError, ValueError) as e:  # unreadable file, bad encoding or invalid JSON
        print("Could not load existing config, starting fresh:", e)
    else:
        if isinstance(existing, dict):
            cfg.update(existing)
            print(f"Loaded existing config from {CONFIG_P}")
        else:
            print(f"Ignoring {CONFIG_P}: expected a JSON object, starting fresh.")

print("\nOPTIONAL: upload a ZIP dataset and/or a base .mlmodel for attempt_01 (you can also skip).")
uploaded = files.upload()  # mapping: uploaded file name -> file contents (bytes)

def pick_data_dir_for_zip(zip_name: str, default_root: Path):
    """Choose the extraction folder for an uploaded ZIP archive.

    The folder is named after the archive minus its final extension,
    e.g. ``lines.zip`` -> ``default_root/lines``.
    """
    folder_name = Path(zip_name).stem
    return default_root.joinpath(folder_name)

def detect_format_in_dir(d: Path, max_xml_probe: int = 20):
    """Guess the dataset layout stored under ``d``.

    Returns:
        ``"alto"`` or ``"page"`` when XML files are present. The choice is made
        by looking for an ``<alto`` or ``<pcgts``/``<page`` tag in the first
        4096 characters of up to ``max_xml_probe`` XML files, checked in sorted
        order. It falls back to ``"alto"`` if none match.
        ``"pairs"`` when image files and ``*.gt.txt`` transcriptions exist.
        ``None`` otherwise.

    Extensions are compared case-insensitively (``.PNG``, ``.XML`` ...) and the
    directory tree is walked only once.
    """
    img_exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
    xmls = []
    has_img = has_gt = False
    for p in d.rglob("*"):
        if not p.is_file():
            continue
        name = p.name.lower()
        if name.endswith(".xml"):
            xmls.append(p)
        elif name.endswith(".gt.txt"):
            has_gt = True
        elif p.suffix.lower() in img_exts:
            has_img = True

    if xmls:
        # Probe several files: the first XML encountered may be unrelated
        # metadata that carries neither marker.
        for sample in sorted(xmls)[:max_xml_probe]:
            try:
                with open(sample, "r", encoding="utf-8", errors="ignore") as fh:
                    head = fh.read(4096).lower()
            except OSError:
                continue
            if "<alto" in head:
                return "alto"
            if "<pcgts" in head or "<page" in head:
                return "page"
        return "alto"
    if has_img and has_gt:
        return "pairs"
    return None

DATA_DIR = Path(cfg.get("data_dir", DATA_ROOT_DEFAULT))
new_dataset = False  # becomes True when a ZIP is extracted in this session

for fname, blob in uploaded.items():
    name_lc = fname.lower()
    if name_lc.endswith(".zip"):
        target = pick_data_dir_for_zip(fname, DATA_ROOT_DEFAULT)
        target.mkdir(parents=True, exist_ok=True)
        print(f"Extracting ZIP to {target} ...")
        with zipfile.ZipFile(io.BytesIO(blob), 'r') as z:
            z.extractall(target)
        print("Done extracting.")
        DATA_DIR = target
        cfg["data_dir"] = str(DATA_DIR)
        new_dataset = True
    elif name_lc.endswith(".mlmodel"):
        # An uploaded model is only ever stored as attempt_01 (the chain's root).
        dst = MODELS_DIR / "attempt_01.mlmodel"
        if not dst.exists():
            with open(dst, "wb") as f:
                f.write(blob)
            print(f"Saved base model to {dst}")
        else:
            print(f"Base model already exists at {dst} (skipped).")
    else:
        print(f"Skipped {fname} (not .zip/.mlmodel)")

# Detect the format when none is stored yet, and also whenever a new dataset
# was uploaded. Otherwise a format saved in an earlier session (possibly the
# "pairs" fallback from an empty data dir) would be reused for the new data.
if new_dataset or not cfg.get("data_format"):
    df = detect_format_in_dir(DATA_DIR)
    if df is None:
        print(f"Warning: no recognisable dataset layout under {DATA_DIR}; defaulting to 'pairs'.")
    cfg["data_format"] = df or "pairs"
    print("Auto-detected DATA_FORMAT:", cfg["data_format"])

def next_attempt(models_dir: Path):
    """Return the next free attempt number in ``models_dir``.

    Scans ``attempt_<N>.mlmodel`` files and returns ``max(N) + 1``, or ``1``
    when no such file exists.
    """
    pattern = re.compile(r"attempt_(\d+)\.mlmodel$")
    hits = (pattern.search(p.name) for p in models_dir.glob("attempt_*.mlmodel"))
    found = [int(hit.group(1)) for hit in hits if hit]
    return max(found, default=0) + 1

# This session trains either the forced attempt or the next free number.
cfg["attempt_num"] = (
    FORCE_ATTEMPT
    if isinstance(FORCE_ATTEMPT, int) and FORCE_ATTEMPT > 0
    else next_attempt(MODELS_DIR)
)

# Chain from the previous attempt's model. Attempt 1 points at attempt_01
# itself, which only exists if a base model was uploaded.
prev_num = max(cfg["attempt_num"] - 1, 1)
base_model_path = MODELS_DIR / f"attempt_{prev_num:02d}.mlmodel"

# A manually configured base model (set in config.json) takes precedence.
if cfg.get("base_model_mode") == "manual" and cfg.get("base_model_manual"):
    base_model_path = Path(cfg["base_model_manual"])

cfg["resolved_base_model"] = str(base_model_path)

config_text = json.dumps(cfg, ensure_ascii=False, indent=2)
CONFIG_P.write_text(config_text)
print("\nSaved config to:", CONFIG_P)
print(config_text)
print("\nNext: run Cell B — Build train/val lists.")

In [None]:
# === Cell B: Build train/val lists (no edits) ===
import json, random
from pathlib import Path

HOME = Path("/content/drive", "MyDrive")
ROOT = HOME.joinpath("kraken_projects")  # parent folder holding one sub-folder per project

def find_latest_config(root: Path):
    """Return the most recently modified project ``config.json`` under ``root``.

    Only ``root/<project>/config.json`` is considered, which is where Cell A
    writes it. ``config.json`` files inside extracted datasets
    (``<project>/data/...``) are therefore never mistaken for the project config.

    Raises:
        SystemExit: if no project config exists.
    """
    configs = [p for p in root.glob("*/config.json") if p.is_file()]
    if not configs:
        raise SystemExit("No config.json found. Run Cell A (Setup Wizard) first.")
    return max(configs, key=lambda p: p.stat().st_mtime)

CONFIG_P = find_latest_config(ROOT)
cfg = json.loads(CONFIG_P.read_text())
print("Using config:", CONFIG_P)

# Settings written by Cell A.
DATA_DIR    = Path(cfg["data_dir"])
DATA_FORMAT = cfg["data_format"]
VAL_RATIO   = float(cfg.get("val_ratio", 0.1))
RANDOM_SEED = int(cfg.get("random_seed", 42))

# The lists are written outside the mounted Drive (/content/drive), so they
# are not persisted alongside the project.
LISTS_DIR       = Path("/content/lists")
train_list_path = LISTS_DIR.joinpath("train.txt")
val_list_path   = LISTS_DIR.joinpath("val.txt")
LISTS_DIR.mkdir(parents=True, exist_ok=True)

def collect_pairs(root: Path):
    """Return sorted image paths under ``root`` that have a ``.gt.txt`` sidecar.

    For an image ``dir/name.png`` the transcription must be ``dir/name.gt.txt``.
    The image paths themselves are returned, not the ``.gt.txt`` paths, because
    ketos' default ``path`` format expects line images and reads each
    transcription from the ``.gt.txt`` file beside the image. Image extensions
    are matched case-insensitively (``.PNG``, ``.TIF`` ...).
    """
    exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
    samples = set()
    for img in root.rglob("*"):
        if img.suffix.lower() not in exts or not img.is_file():
            continue
        gt = img.parent / (img.with_suffix("").name + ".gt.txt")
        if gt.exists():
            samples.add(str(img))
    return sorted(samples)

def collect_xml(root: Path):
    """Return sorted paths of all XML files under ``root`` (``.xml`` in any letter case)."""
    return sorted({str(p) for p in root.rglob("*")
                   if p.suffix.lower() == ".xml" and p.is_file()})

# Gather candidate samples in the form the chosen format needs.
if DATA_FORMAT == "pairs":
    all_samples = collect_pairs(DATA_DIR)
elif DATA_FORMAT in ("alto", "page"):
    all_samples = collect_xml(DATA_DIR)
else:
    raise SystemExit("DATA_FORMAT must be 'pairs', 'alto', or 'page'.")

if not all_samples:
    pngs = sum(1 for _ in DATA_DIR.rglob("*.png"))
    gts  = sum(1 for _ in DATA_DIR.rglob("*.gt.txt"))
    xmls = sum(1 for _ in DATA_DIR.rglob("*.xml"))
    raise SystemExit(
        f"No training samples found under {DATA_DIR}.\n"
        f"DATA_FORMAT={DATA_FORMAT}\n"
        f"Found: {pngs} PNGs, {gts} .gt.txt files, {xmls} XML files.\n"
        f"If you uploaded ALTO/PAGE XML, set \"data_format\" to 'alto' or 'page' in {CONFIG_P}.\n"
        f"If you have image+gt pairs, set \"data_format\" to 'pairs' and ensure *.gt.txt exist beside images."
    )

# Both lists must be non-empty: training needs data and a validation set.
if len(all_samples) < 2:
    raise SystemExit(
        f"Only one sample found under {DATA_DIR}; at least 2 are needed "
        "to build both a training and a validation list."
    )

random.seed(RANDOM_SEED)
random.shuffle(all_samples)
n = len(all_samples)
# At least one validation sample, but always keep at least one for training.
# This guards against tiny datasets and VAL_RATIO >= 1.
n_val = min(max(1, int(n * VAL_RATIO)), n - 1)
val = all_samples[:n_val]
train = all_samples[n_val:]

with open(train_list_path, "w", encoding="utf-8") as f:
    f.write("\n".join(train) + "\n")
with open(val_list_path, "w", encoding="utf-8") as f:
    f.write("\n".join(val) + "\n")

print(f"Wrote {len(train)} train and {len(val)} val samples.")
print("Train list:", train_list_path)
print("Val   list:", val_list_path)
print("\nNext: run Cell C — Train with checkpoints & save the BEST.")

In [None]:
# === Cell C: Train with checkpoints & save the BEST (no edits) ===
import json, re, csv, shutil, subprocess, datetime
from pathlib import Path

HOME = Path("/content/drive", "MyDrive")
ROOT = HOME.joinpath("kraken_projects")  # parent folder holding one sub-folder per project

def find_latest_config(root: Path):
    """Return the most recently modified project ``config.json`` under ``root``.

    Only ``root/<project>/config.json`` is considered, which is where Cell A
    writes it. ``config.json`` files inside extracted datasets
    (``<project>/data/...``) are therefore never mistaken for the project config.

    Raises:
        SystemExit: if no project config exists.
    """
    configs = [p for p in root.glob("*/config.json") if p.is_file()]
    if not configs:
        raise SystemExit("No config.json found. Run Cell A (Setup Wizard) first.")
    return max(configs, key=lambda p: p.stat().st_mtime)

import shlex  # only used to print the argument list as a copy-pasteable command

CONFIG_P = find_latest_config(ROOT)
cfg = json.loads(CONFIG_P.read_text())
print("Using config:", CONFIG_P)

DATA_FORMAT    = cfg["data_format"]
MODELS_DIR     = Path(cfg["models_dir"])
CKPTS_DIR_ROOT = Path(cfg["ckpts_root"])
ATTEMPT_NUM    = int(cfg["attempt_num"])
ATTEMPT_NAME   = f"attempt_{ATTEMPT_NUM:02d}.mlmodel"
BEST_MODEL_OUT = MODELS_DIR / ATTEMPT_NAME
BASE_MODEL     = Path(cfg["resolved_base_model"]) if cfg.get("resolved_base_model") else None

LISTS_DIR = Path("/content/lists")
train_list_path = LISTS_DIR / "train.txt"
val_list_path   = LISTS_DIR / "val.txt"
# Explicit check rather than `assert`, which is skipped under `python -O`.
if not (train_list_path.exists() and val_list_path.exists()):
    raise SystemExit("Run Cell B to build train/val lists first.")

CKPT_DIR = CKPTS_DIR_ROOT / f"ckpts_{ATTEMPT_NAME.replace('.mlmodel','')}"
CKPT_DIR.mkdir(parents=True, exist_ok=True)

EPOCHS     = int(cfg.get("epochs", 20))
LR         = float(cfg.get("lr", 3e-4))
BATCH_SIZE = int(cfg.get("batch_size", 16))

# Build the argument list directly (no shell). Paths with spaces or shell
# metacharacters are passed intact. Samples go in via list files (-t / -e),
# so a large dataset cannot overflow the command line.
train_cmd = ["ketos", "train"]
if DATA_FORMAT != "pairs":
    train_cmd += ["-f", DATA_FORMAT]          # ALTO / PAGE XML input
if BASE_MODEL and BASE_MODEL.exists():
    train_cmd += ["--load", str(BASE_MODEL)]
    print("Using base model:", BASE_MODEL)
else:
    print("No base model found. Training from scratch.")
train_cmd += [
    "-F", "1",                 # save a checkpoint every epoch
    "-q", "dumb",              # train the fixed -N epochs; NOTE(review): ketos' default is early stopping
    "-N", str(EPOCHS),
    "-r", str(LR),
    "-B", str(BATCH_SIZE),
    # Output prefix; checkpoints are picked up below via *.mlmodel in CKPT_DIR.
    "-o", str(CKPT_DIR / "epoch_model"),
    "-t", str(train_list_path),
    "-e", str(val_list_path),
]
print("Training command:\n", shlex.join(train_cmd))

try:
    ret = subprocess.call(train_cmd)
except FileNotFoundError:
    raise SystemExit("`ketos` was not found on PATH. Run Cell A (Setup Wizard) first.")
if ret != 0:
    print("Training exited non-zero (possibly interrupted). Proceeding to pick the best checkpoint.")

def parse_accuracy(output: str):
    """Extract character accuracy (percent) and CER from ``ketos test`` output.

    Accuracy is recognised (case-insensitively) in two forms:
      * ``accuracy: 95.1%`` or ``Average accuracy (CER): 95.1%``: the word,
        an optional parenthesised tag, an optional ``:``/``=``, then a percentage;
      * ``95.1%  Accuracy``: a percentage followed by the word (report tables).
    "word accuracy" is skipped so it is not mistaken for character accuracy.
    CER is read from ``CER: <number>`` or ``CER <number>``.

    Returns:
        ``(acc, cer)`` as floats, with ``None`` for any value not found.
    """
    num = r"([0-9]+(?:\.[0-9]+)?)"  # requires a digit, so "." alone never reaches float()
    acc = None
    cer = None
    m = (re.search(r"(?<!word )accuracy\s*(?:\([^)]*\))?\s*[:=]?\s*" + num + r"\s*%", output, re.I)
         or re.search(num + r"\s*%\s+accuracy", output, re.I))
    if m:
        acc = float(m.group(1))
    m2 = re.search(r"\bCER[:\s]+" + num, output, re.I)
    if m2:
        cer = float(m2.group(1))
    return acc, cer

def score_model(model_path: Path):
    """Evaluate one checkpoint on the validation list with ``ketos test``.

    Runs without a shell. The validation paths are read from ``val_list_path``
    and passed as separate arguments, so paths containing spaces stay intact.
    For XML datasets the same ``-f <format>`` flag as in training is passed;
    without it ketos would treat the XML files as images.

    Returns:
        ``(score, acc, cer)``. ``score`` is the accuracy percentage when one was
        parsed, else ``100 - 100 * cer`` when a CER was parsed, else ``-inf``.
    """
    val_paths = [p for p in val_list_path.read_text(encoding="utf-8").splitlines() if p.strip()]
    test_cmd = ["ketos", "test", "-m", str(model_path)]
    if DATA_FORMAT != "pairs":
        test_cmd += ["-f", DATA_FORMAT]
    test_cmd += val_paths
    proc = subprocess.run(test_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    out = proc.stdout
    if proc.returncode != 0:
        # Surface the tail of the output so a failing evaluation is diagnosable.
        print(f"ketos test failed for {model_path.name} (exit {proc.returncode}):\n{out[-2000:]}")
    acc, cer = parse_accuracy(out)
    if acc is not None:
        score = acc
    elif cer is not None:
        score = 100.0 - cer * 100.0
    else:
        score = float("-inf")
    return score, acc, cer

# Order checkpoints numerically (epoch_model_2 before epoch_model_10), so that
# on equal scores the earliest checkpoint wins. re.split with a capturing
# group puts digit runs at odd indices, so corresponding items always share
# a type when compared.
ckpts = sorted(
    CKPT_DIR.glob("*.mlmodel"),
    key=lambda p: [int(tok) if i % 2 else tok
                   for i, tok in enumerate(re.split(r"(\d+)", p.name))],
)
if not ckpts:
    raise SystemExit(f"No checkpoints found in {CKPT_DIR}. Nothing to pick.")

best = None
for m in ckpts:
    score, acc, cer = score_model(m)
    print(f"{m.name}: score={score:.4f}  acc={acc}  cer={cer}")
    if best is None or score > best["score"]:
        best = {"path": m, "score": score, "acc": acc, "cer": cer}

if best["score"] == float("-inf"):
    # Nothing could be scored (ketos test failed or its output was not
    # parsed), so the "best" model is just the first checkpoint.
    print(f"Warning: no checkpoint could be scored; falling back to {best['path'].name}.")

shutil.copy2(best["path"], BEST_MODEL_OUT)
print("Saved best model as:", BEST_MODEL_OUT)

log_path = MODELS_DIR / "attempts.csv"
is_new = not log_path.exists()
with open(log_path, "a", newline="", encoding="utf-8") as f:
    w = csv.writer(f)
    if is_new:
        w.writerow(["timestamp","attempt","model_path","base_model","score","accuracy","cer"])
    w.writerow([datetime.datetime.now().isoformat(), ATTEMPT_NAME, str(BEST_MODEL_OUT),
                str(BASE_MODEL) if BASE_MODEL else "", best["score"], best["acc"], best["cer"]])

print("Logged to:", log_path)
print("\nDone. Next time, run Cell A again; it will auto-increment attempts and chain from the last saved model.")