In [1]:
# 0. Install dependencies (run once per environment)
# NOTE(review): fastai is installed unpinned, so a fresh environment may pick up a
# different release than the one this notebook was last run with. Consider pinning a
# known-good version (`%pip install fastai==<version> --quiet`) for reproducible runs.

%pip install fastai --quiet


Note: you may need to restart the kernel to use updated packages.


In [2]:
# 1. Imports and configuration

import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn

from fastai.basics import *
from fastai.callback.tracker import EarlyStoppingCallback
import fastai.callback.schedule  # ensures fit_one_cycle is patched onto Learner

# Deterministic-ish runs: seed every global RNG the training pipeline may draw from.
# Python's `random` is seeded too. NOTE(review): fastai's DataLoader is believed to
# derive its shuffle RNG from it; confirm against the installed fastai version.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Input CSVs
GENERATED_CSV_ORIGINAL_PATH = Path("Palettes/Generated/palette_export_generated.csv")
GENERATED_CSV_RANDOMIZED_PATH = Path("Palettes/Generated/palette_export_generated_rand_order.csv")
ADOBE_CSV_ORIGINAL_PATH = Path("Palettes/Adobe/adobe_palettes.csv")
ADOBE_CSV_RANDOMIZED_PATH = Path("Palettes/Adobe/adobe_palettes_randomized.csv")

# Output models (default folder: trained_models)
MODEL_DIR = Path("trained_models")
MODEL_ORIGINAL_PATH = MODEL_DIR / "palette_autoencoder.pkl"
MODEL_RANDOMIZED_PATH = MODEL_DIR / "palette_autoencoder_rand_order.pkl"
MODEL_RULES_ONLY_PATH = MODEL_DIR / "palette_autoencoder_rules_only.pkl"

# Enable/disable each training run
TRAIN_NORMAL_ORDER = False
TRAIN_RANDOM_ORDER = False
TRAIN_RULES_ONLY = True

# Rules-only run: when True, also load the randomized-order generated CSV
# (GENERATED_CSV_RANDOMIZED_PATH) next to the run's primary generated CSV.
# Adobe palettes are never used in rules-only mode.
RULES_ONLY_INCLUDE_RANDOMIZED = True

# One entry per training run. Keys:
#   name          - label used in logs and in the results table
#   enabled       - the run is skipped when False
#   generated_csv - primary generated-palette CSV
#   adobe_csv     - Adobe palette CSV, or None for the rules-only run
#   model_path    - destination of the exported learner
#   rules_only    - train on generated palettes only
TRAIN_CONFIGS = [
    {
        "name": "normal_order",
        "enabled": TRAIN_NORMAL_ORDER,
        "generated_csv": GENERATED_CSV_ORIGINAL_PATH,
        "adobe_csv": ADOBE_CSV_ORIGINAL_PATH,
        "model_path": MODEL_ORIGINAL_PATH,
        "rules_only": False,
    },
    {
        "name": "random_order",
        "enabled": TRAIN_RANDOM_ORDER,
        "generated_csv": GENERATED_CSV_RANDOMIZED_PATH,
        "adobe_csv": ADOBE_CSV_RANDOMIZED_PATH,
        "model_path": MODEL_RANDOMIZED_PATH,
        "rules_only": False,
    },
    {
        "name": "rules_only",
        "enabled": TRAIN_RULES_ONLY,
        "generated_csv": GENERATED_CSV_ORIGINAL_PATH,
        "adobe_csv": None,
        "model_path": MODEL_RULES_ONLY_PATH,
        "rules_only": True,
    },
]

COLOR_COUNT = 5  # number of colors per palette
VALID_PCT = 0.2  # validation set percentage
BATCH_SIZE = 128  # batch size for training
LATENT_DIM = 16  # dimension of latent space
EPOCHS = 40  # number of training epochs
LR = 0.002  # learning rate
WEIGHTED_LOSS = True  # reduce influence of duplicate palettes
SRGB_TO_LINEAR = True  # convert Adobe palettes from sRGB to linear (expects 0..1 input)


In [3]:
# 2. Dataset utilities

# Canonical channel column names, palette-major: r1, g1, b1, r2, ..., b{COLOR_COUNT}.
feature_cols = [
    f"{channel}{slot}"
    for slot in range(1, COLOR_COUNT + 1)
    for channel in ("r", "g", "b")
]
# Generated CSVs: rule metadata followed by the same channels with a "c" prefix (cr1, cg1, ...).
expected_generated = ["law", "id_palette"] + [f"c{col}" for col in feature_cols]
# Adobe CSVs: exactly the feature columns, in the same order (independent list copy).
expected_adobe = list(feature_cols)


def _srgb_to_linear(arr):
    arr = np.clip(arr, 0.0, 1.0)
    return np.where(arr <= 0.04045, arr / 12.92, ((arr + 0.055) / 1.055) ** 2.4)


def _validate_normalized(df, cols, source_name):
    min_v = float(df[cols].min().min())
    max_v = float(df[cols].max().max())
    if min_v < 0.0 or max_v > 1.0:
        raise ValueError(
            f"{source_name} contains non-normalized values (min={min_v:.4f}, max={max_v:.4f}). "
            "Expected all color channels in [0, 1]."
        )


def _load_generated_df(generated_csv_path: Path) -> pd.DataFrame:
    """Read a generated-palette CSV and map it onto the shared feature schema.

    The CSV must contain every column in ``expected_generated`` (``law``,
    ``id_palette`` and ``c``-prefixed channel columns such as ``cr1``).
    Processing steps:

    - drop the optional ``palette_name`` / ``batch`` columns;
    - strip the ``c`` prefix so channel names match ``feature_cols``;
    - range-check the channels;
    - cast channels to float32 and ``law`` / ``id_palette`` to int64.

    Raises:
        FileNotFoundError: the CSV does not exist.
        ValueError: required columns are missing, or channel values fall outside [0, 1].
    """
    if not generated_csv_path.exists():
        raise FileNotFoundError(f"Generated CSV not found: {generated_csv_path}")

    raw = pd.read_csv(generated_csv_path)
    missing = [name for name in expected_generated if name not in raw.columns]
    if missing:
        raise ValueError(f"Generated column mismatch for {generated_csv_path}. Missing {missing}, got {list(raw.columns)}")

    optional_cols = [name for name in ("palette_name", "batch") if name in raw.columns]
    strip_prefix = {f"c{name}": name for name in feature_cols}
    palettes = raw.drop(columns=optional_cols).rename(columns=strip_prefix)

    _validate_normalized(palettes, feature_cols, f"Generated CSV ({generated_csv_path})")
    target_dtypes = {**dict.fromkeys(feature_cols, "float32"), "law": "int64", "id_palette": "int64"}
    return palettes.astype(target_dtypes)


def _load_adobe_df(adobe_csv_path: Path) -> pd.DataFrame:
    """Read an Adobe palette CSV and map it onto the shared schema.

    The CSV header must equal ``expected_adobe`` exactly, both names and order.
    ``law`` and ``id_palette`` are added as -1 placeholder columns. Channels are
    range-checked, optionally converted from sRGB to linear (``SRGB_TO_LINEAR``),
    and cast to float32. The placeholder columns are cast to int64.

    Raises:
        FileNotFoundError: the CSV does not exist.
        ValueError: the header differs from ``expected_adobe``, or channel values
            fall outside [0, 1].
    """
    if not adobe_csv_path.exists():
        raise FileNotFoundError(f"Adobe CSV not found: {adobe_csv_path}")

    palettes = pd.read_csv(adobe_csv_path)
    header = list(palettes.columns)
    if header != expected_adobe:
        raise ValueError(f"Adobe column mismatch for {adobe_csv_path}. Expected {expected_adobe}, got {header}")

    # Adobe rows have no law/id_palette columns. Prepend -1 placeholders to line up
    # with the generated schema. Inserting at position 0 twice yields the column
    # order law, id_palette, r1, ...
    for placeholder in ("id_palette", "law"):
        palettes.insert(0, placeholder, -1)
    _validate_normalized(palettes, feature_cols, f"Adobe CSV ({adobe_csv_path})")

    if SRGB_TO_LINEAR:
        palettes[feature_cols] = _srgb_to_linear(palettes[feature_cols].to_numpy(dtype="float32"))

    target_dtypes = {**dict.fromkeys(feature_cols, "float32"), "law": "int64", "id_palette": "int64"}
    return palettes.astype(target_dtypes)


def _attach_sample_weights(df: pd.DataFrame) -> pd.DataFrame:
    if WEIGHTED_LOSS:
        dup_counts = df.groupby(feature_cols).size()
        df["sample_weight"] = df.set_index(feature_cols).index.map(dup_counts).astype("float32")
        df["sample_weight"] = 1.0 / df["sample_weight"]
        df["sample_weight"] = df["sample_weight"] * (len(df) / df["sample_weight"].sum())
    return df


def load_palette_dataframe(generated_csv_path: Path, adobe_csv_path: Path) -> pd.DataFrame:
    """Load generated palettes plus, optionally, Adobe palettes into one training frame.

    Args:
        generated_csv_path: Generated-palette CSV (loaded via ``_load_generated_df``).
        adobe_csv_path: Adobe palette CSV (loaded via ``_load_adobe_df``). ``None``
            skips the Adobe source and returns the generated palettes alone.
            Previously this crashed with an ``AttributeError``, even though the
            training configs use ``None`` to mean "no Adobe data".

    Returns:
        Concatenated frame with a fresh RangeIndex, float32 feature columns, int64
        ``law`` / ``id_palette`` and, when ``WEIGHTED_LOSS`` is enabled, a
        ``sample_weight`` column.
    """
    frames = [_load_generated_df(generated_csv_path)]
    if adobe_csv_path is not None:
        frames.append(_load_adobe_df(adobe_csv_path))

    df = pd.concat(frames, ignore_index=True)
    # Re-assert dtypes after concatenation so the combined frame has a single schema.
    df[feature_cols] = df[feature_cols].astype("float32")
    df[["law", "id_palette"]] = df[["law", "id_palette"]].astype("int64")
    return _attach_sample_weights(df)


def load_rules_only_dataframe(primary_generated_csv_path: Path, include_randomized: bool = True) -> pd.DataFrame:
    """Load only generated (rule-labeled) palettes; Adobe data is never included.

    Args:
        primary_generated_csv_path: First generated CSV to load.
        include_randomized: Also append ``GENERATED_CSV_RANDOMIZED_PATH`` after it.

    Returns:
        Concatenated frame with a fresh RangeIndex, float32 feature columns, int64
        ``law`` / ``id_palette`` and, when ``WEIGHTED_LOSS`` is enabled, a
        ``sample_weight`` column.
    """
    sources = [primary_generated_csv_path]
    if include_randomized:
        sources.append(GENERATED_CSV_RANDOMIZED_PATH)

    combined = pd.concat([_load_generated_df(src) for src in sources], ignore_index=True)
    target_dtypes = {**dict.fromkeys(feature_cols, "float32"), "law": "int64", "id_palette": "int64"}
    return _attach_sample_weights(combined.astype(target_dtypes))


In [4]:
# 3. DataLoader + model utilities


def _palette_row_features(row):
    """Return one row's channel values (``feature_cols`` order) as a float32 tensor."""
    return TensorBase(row[feature_cols].to_numpy(dtype="float32"))


def _palette_row_weight(row):
    """Return one row's ``sample_weight`` as a 0-d float32 tensor."""
    return TensorBase(np.array(row["sample_weight"], dtype="float32"))


def build_dataloaders(df: pd.DataFrame):
    """Build fastai DataLoaders for autoencoder training on ``df``.

    The input and the reconstruction target are the same feature vector. When
    ``WEIGHTED_LOSS`` is enabled, each row's ``sample_weight`` is yielded as a
    second target, so the loss is called as ``loss(pred, y, w)``. ``n_inp=1``
    keeps the weight on the target side.

    The row getters are module-level functions, not closures defined in here.
    They become part of the DataLoaders' transform pipeline, which is pickled
    when the learner is exported. The standard ``pickle`` module cannot
    serialize nested (``<locals>``) functions, and loading an exported learner
    must resolve the getters by name.

    Args:
        df: Frame with ``feature_cols`` (and ``sample_weight`` when weighting is on).

    Returns:
        fastai ``DataLoaders`` with a seeded random train/validation split.
    """
    splitter = RandomSplitter(valid_pct=VALID_PCT, seed=42)

    if WEIGHTED_LOSS:
        dblock = DataBlock(
            blocks=(RegressionBlock, RegressionBlock, RegressionBlock),
            get_x=_palette_row_features,
            get_y=[_palette_row_features, _palette_row_weight],
            splitter=splitter,
            n_inp=1,
        )
    else:
        dblock = DataBlock(
            blocks=(RegressionBlock, RegressionBlock),
            get_x=_palette_row_features,
            get_y=_palette_row_features,
            splitter=splitter,
        )

    return dblock.dataloaders(df, bs=BATCH_SIZE, num_workers=0)


class PaletteAutoencoder(nn.Module):
    """Symmetric MLP autoencoder for flattened palettes.

    Encoder: input_dim -> 128 -> 64 -> latent_dim, with ReLU between the Linear
    layers and a linear bottleneck. Decoder: latent_dim -> 64 -> 128 -> input_dim,
    followed by a Sigmoid so each reconstructed channel lies in (0, 1).

    ``encoder`` and ``decoder`` are ``nn.Sequential`` with Linear layers at indices
    0, 2 and 4. This layout defines the ``state_dict`` keys, so keep it stable
    to stay compatible with saved weights.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()

        def relu_mlp(*widths):
            # Linear layers for consecutive width pairs, separated by ReLU.
            layers = []
            for depth, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
                if depth > 0:
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(n_in, n_out))
            return layers

        self.encoder = nn.Sequential(*relu_mlp(input_dim, 128, 64, latent_dim))
        self.decoder = nn.Sequential(*relu_mlp(latent_dim, 64, 128, input_dim), nn.Sigmoid())

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)


class WeightedMSELoss(nn.Module):
    """Per-sample mean squared error scaled by a per-sample weight.

    Assumes ``pred`` / ``y`` are (batch, features) and ``w`` is (batch,); TODO
    confirm against the DataLoaders. Per-sample loss:
    ``w_i * mean_j (pred_ij - y_ij) ** 2``.

    Args:
        reduction: ``"mean"`` (default and original behavior: average over the
            batch), ``"sum"``, or ``"none"`` (return the per-sample vector).

    Both the ``reduction`` attribute and the per-call ``reduction`` keyword follow
    the PyTorch/fastai loss convention. This lets fastai utilities that need
    unreduced losses (e.g. ``Learner.get_preds(with_loss=True)``) work with it.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, y, w, reduction=None):
        if reduction is None:
            # getattr default: instances unpickled from exports made before this
            # attribute existed keep the original "mean" behavior.
            reduction = getattr(self, "reduction", "mean")
        per_sample = ((pred - y) ** 2).mean(dim=1) * w
        if reduction == "mean":
            return per_sample.mean()
        if reduction == "sum":
            return per_sample.sum()
        if reduction == "none":
            return per_sample
        raise ValueError(f"Unsupported reduction: {reduction!r}")


In [5]:
# 4. Train + export utilities


def train_and_export(run_name: str, generated_csv_path: Path, adobe_csv_path: Path, model_path: Path, rules_only: bool = False):
    """Train one palette-autoencoder variant, export the learner, and summarize the run.

    Args:
        run_name: Label used in logs and in the returned summary.
        generated_csv_path: Primary generated-palette CSV.
        adobe_csv_path: Adobe palette CSV. Ignored when ``rules_only`` is True.
        model_path: Destination for ``learner.export``. Parent folders are created.
        rules_only: Train on generated palettes only (see ``load_rules_only_dataframe``).

    Returns:
        dict with ``run_name``, ``rows``, ``rules_only``, ``model_path`` (str) and the
        train/valid losses of the last recorded epoch.

    NOTE(review): EarlyStoppingCallback stops training; it is not shown restoring the
    best epoch's weights. Confirm whether the exported model should be the best one.
    """
    adobe_label = adobe_csv_path if adobe_csv_path is not None else 'None (rules-only mode)'
    print(f"\n=== Training: {run_name} ===")
    print(f"Generated CSV: {generated_csv_path}")
    print(f"Adobe CSV:     {adobe_label}")
    print(f"Model output:  {model_path}")

    if rules_only:
        palettes = load_rules_only_dataframe(generated_csv_path, include_randomized=RULES_ONLY_INCLUDE_RANDOMIZED)
    else:
        palettes = load_palette_dataframe(generated_csv_path, adobe_csv_path)
    row_count = len(palettes)
    print(f"Rows loaded: {row_count}")

    # Keep loader construction before model construction. The seeded splitter may
    # touch the global torch RNG, which also drives weight init (TODO confirm).
    dls = build_dataloaders(palettes)
    model = PaletteAutoencoder(input_dim=len(feature_cols), latent_dim=LATENT_DIM)

    if WEIGHTED_LOSS:
        loss_func, metrics = WeightedMSELoss(), []
    else:
        loss_func, metrics = nn.MSELoss(), [rmse]

    learner = Learner(dls, model, loss_func=loss_func, metrics=metrics)
    if torch.cuda.is_available():
        learner = learner.to_fp16()

    early_stop = EarlyStoppingCallback(monitor="valid_loss", min_delta=5e-4, patience=3)
    learner.fit_one_cycle(EPOCHS, LR, cbs=[early_stop])

    model_path.parent.mkdir(parents=True, exist_ok=True)
    learner.export(model_path)
    print(f"Saved model: {model_path}")

    last_epoch = learner.recorder.values[-1]
    return {
        "run_name": run_name,
        "rows": row_count,
        "rules_only": bool(rules_only),
        "model_path": str(model_path),
        "final_train_loss": float(last_epoch[0]),
        "final_valid_loss": float(last_epoch[1]),
    }


In [6]:
# 5. Train enabled model variants automatically

results = []
for run_cfg in TRAIN_CONFIGS:
    # Entries without an "enabled" key count as enabled.
    if run_cfg.get("enabled", True):
        run_summary = train_and_export(
            run_name=run_cfg["name"],
            generated_csv_path=run_cfg["generated_csv"],
            adobe_csv_path=run_cfg["adobe_csv"],
            model_path=run_cfg["model_path"],
            rules_only=run_cfg.get("rules_only", False),
        )
        results.append(run_summary)
    else:
        print(f"Skipping disabled run: {run_cfg['name']}")

if len(results) == 0:
    print("No training runs were enabled. Set TRAIN_* variables to True.")

# One row per completed run, displayed as the cell's rich output.
results_df = pd.DataFrame(results)
results_df


Skipping disabled run: normal_order
Skipping disabled run: random_order

=== Training: rules_only ===
Generated CSV: Palettes\Generated\palette_export_generated.csv
Adobe CSV:     None (rules-only mode)
Model output:  trained_models\palette_autoencoder_rules_only.pkl
Rows loaded: 14000


epoch,train_loss,valid_loss,time
0,0.121808,0.120491,00:06
1,0.116523,0.110509,00:06
2,0.097712,0.087663,00:06
3,0.076256,0.063801,00:06
4,0.055622,0.043568,00:06
5,0.036479,0.028232,00:06
6,0.027041,0.023641,00:06
7,0.021357,0.019212,00:06
8,0.018989,0.017454,00:07
9,0.014235,0.011479,00:07


No improvement since epoch 15: early stopping
Saved model: trained_models\palette_autoencoder_rules_only.pkl


Unnamed: 0,run_name,rows,rules_only,model_path,final_train_loss,final_valid_loss
0,rules_only,14000,True,trained_models\palette_autoencoder_rules_only.pkl,0.001881,0.001808


In [7]:
# 6. Quick summary

if results:
    for run in results:
        summary_line = (
            f"{run['run_name']}: rows={run['rows']}, "
            f"train_loss={run['final_train_loss']:.6f}, valid_loss={run['final_valid_loss']:.6f}, "
            f"model={run['model_path']}"
        )
        print(summary_line)
else:
    print("No runs to summarize.")


rules_only: rows=14000, train_loss=0.001881, valid_loss=0.001808, model=trained_models\palette_autoencoder_rules_only.pkl
