In [7]:
from pathlib import Path
import numpy as np
import pandas as pd
import tensorflow as tf

from sklearn.preprocessing import MultiLabelBinarizer
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

from tensorflow.keras.applications import ResNet50
from tensorflow.keras import layers, models
from tensorflow.keras.metrics import AUC, Precision, Recall
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras import mixed_precision

# =========================
# Mixed precision (fp16)
# =========================
mixed_precision.set_global_policy("mixed_float16")

# ------------------------------------------------------------
# 1) Caminhos e hiperparâmetros
# ------------------------------------------------------------
DATA_DIR = Path("../data")  # ajuste se necessário
TRAIN_CSV = DATA_DIR / "train.csv"
TRAIN_DIR = DATA_DIR / "train_images"

IMG_SIZE = (224, 224)
BATCH_SIZE = 64
SEED = 42
EPOCHS = 10
AUTOTUNE = tf.data.AUTOTUNE
rng = np.random.default_rng(SEED)

# ------------------------------------------------------------
# 2) Carregar labels e montar vetor multi-hot
# ------------------------------------------------------------
df = pd.read_csv(TRAIN_CSV)
df["labels"] = df["labels"].astype(str).str.strip().str.split()
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df["labels"]).astype("float32")
X = df["image"].values
class_names = list(mlb.classes_)
num_classes = len(class_names)
print("Classes:", class_names)

# ------------------------------------------------------------
# 3) tf.data helpers (leitura, resize, normalização, augment)
# ------------------------------------------------------------
# Augment na GPU
data_augment = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
        layers.RandomContrast(0.1),
    ],
    name="augment",
)


def _load_image(path):
    img = tf.io.read_file(path)
    # use decode_png se suas imagens forem .png
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE, antialias=True)
    img = tf.cast(img, tf.float32) / 255.0
    return img


def make_ds(paths, labels, training: bool):
    """Cria um tf.data.Dataset de (image, multi_hot) com pipeline rápido."""
    paths = tf.convert_to_tensor(paths)
    labels = tf.convert_to_tensor(labels, dtype=tf.float32)

    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if training:
        # embaralha bem (não gigante pra não explodir RAM)
        buffer = min(10000, len(paths))
        ds = ds.shuffle(buffer, seed=SEED, reshuffle_each_iteration=True)

    # paraleliza leitura/decodificação
    ds = ds.map(lambda p, y: (_load_image(p), y), num_parallel_calls=AUTOTUNE)

    if training:
        ds = ds.map(
            lambda x, y: (data_augment(x, training=True), y),
            num_parallel_calls=AUTOTUNE,
        )

    ds = ds.batch(BATCH_SIZE, drop_remainder=False)
    # ds = ds.cache()  # habilite se couber na RAM/SSD (ou use .cache('arquivo'))
    ds = ds.prefetch(AUTOTUNE)

    # para máximo throughput
    options = tf.data.Options()
    options.experimental_deterministic = False
    ds = ds.with_options(options)
    return ds


# ------------------------------------------------------------
# helper: cria modelo (novo a cada fold)
# ------------------------------------------------------------
def build_model(num_classes: int):
    base_model = ResNet50(
        weights="imagenet", include_top=False, input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)
    )
    base_model.trainable = False  # fase 1: só a cabeça

    inputs = tf.keras.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    # saída em float32 para estabilidade com mixed precision
    outputs = layers.Dense(num_classes, activation="sigmoid", dtype="float32")(x)

    model = models.Model(inputs, outputs)
    model.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=[
            "accuracy",
            Precision(name="precision"),
            Recall(name="recall"),
            AUC(name="auc", multi_label=True),
        ],
    )
    return model


# ------------------------------------------------------------
# 4) K-Fold Cross-Validation (5 folds) com tf.data
# ------------------------------------------------------------
mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
fold_metrics = []

for fold, (train_idx, val_idx) in enumerate(mskf.split(X, y), start=1):
    print(f"\n===== FOLD {fold}/5 =====")
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx].astype("float32"), y[val_idx].astype("float32")

    # caminhos absolutos (garanta que os arquivos estejam no filesystem do WSL2, não em /mnt/c)
    X_train_paths = [str(TRAIN_DIR / fname) for fname in X_train]
    X_val_paths = [str(TRAIN_DIR / fname) for fname in X_val]

    # datasets tf.data
    train_ds = make_ds(X_train_paths, y_train, training=True)
    val_ds = make_ds(X_val_paths, y_val, training=False)

    # modelo novo por fold
    model = build_model(num_classes)

    callbacks = [
        EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
        ReduceLROnPlateau(monitor="val_loss", factor=0.2, patience=3, min_lr=1e-6),
        ModelCheckpoint(f"best_fold{fold}.h5", monitor="val_loss", save_best_only=True),
    ]

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS,
        verbose=1,
        callbacks=callbacks,
    )

    # Avaliação neste fold
    eval_res = model.evaluate(val_ds, verbose=0)
    metrics_names = (
        model.metrics_names
    )  # ['loss','accuracy','precision','recall','auc']
    fold_result = dict(zip(metrics_names, eval_res))
    fold_result["fold"] = fold
    fold_metrics.append(fold_result)

    print(
        f"Fold {fold} -> "
        f"val_loss={fold_result['loss']:.4f} | "
        f"val_acc={fold_result['accuracy']:.4f} | "
        f"val_prec={fold_result['precision']:.4f} | "
        f"val_rec={fold_result['recall']:.4f} | "
        f"val_auc={fold_result['auc']:.4f}"
    )

# ------------------------------------------------------------
# 5) Resumo dos 5 folds (média ± desvio)
# ------------------------------------------------------------
fold_df = pd.DataFrame(fold_metrics).set_index("fold")
print("\nResultados por fold:")
print(fold_df.round(4))

summary = fold_df.agg(["mean", "std"]).round(4)
print("\nMédia e desvio (5 folds):")
print(summary)

# ------------------------------------------------------------
# (Opcional) Fine-tuning por fold
# ------------------------------------------------------------
# Depois da fase inicial, você pode descongelar parte da ResNet e rodar mais épocas:
# base_model = model.layers[1]  # se usar Model(inputs, outputs), ajuste índice conforme seu grafo
# base_model.trainable = True
# for layer in base_model.layers[:-30]:
#     layer.trainable = False
# model.compile(
#     optimizer=tf.keras.optimizers.Adam(1e-5),
#     loss="binary_crossentropy",
#     metrics=["accuracy", Precision(name="precision"), Recall(name="recall"),
#              AUC(name="auc", multi_label=True)]
# )
# model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS_FT, callbacks=callbacks)

Classes: ['complex', 'frog_eye_leaf_spot', 'healthy', 'powdery_mildew', 'rust', 'scab']

===== FOLD 1/5 =====
Epoch 1/10
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 361ms/step - accuracy: 0.2498 - auc: 0.5079 - loss: 0.4746 - precision: 0.2789 - recall: 0.0383



[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 490ms/step - accuracy: 0.2578 - auc: 0.5116 - loss: 0.4612 - precision: 0.2884 - recall: 0.0143 - val_accuracy: 0.2724 - val_auc: 0.6115 - val_loss: 0.4441 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 2/10
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m104s[0m 443ms/step - accuracy: 0.2636 - auc: 0.5161 - loss: 0.4534 - precision: 0.3556 - recall: 9.9077e-04 - val_accuracy: 0.2979 - val_auc: 0.6269 - val_loss: 0.4444 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 3/10
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m109s[0m 464ms/step - accuracy: 0.2717 - auc: 0.5326 - loss: 0.4503 - precision: 0.4286 - recall: 1.8577e-04 - val_accuracy: 0.2866 - val_auc: 0.6249 - val_loss: 0.4443 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 4/10
[1m233/234[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m 



[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 441ms/step - accuracy: 0.2722 - auc: 0.5376 - loss: 0.4491 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.2686 - val_auc: 0.6268 - val_loss: 0.4407 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 5/10
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 435ms/step - accuracy: 0.2779 - auc: 0.5430 - loss: 0.4484 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.3009 - val_auc: 0.6326 - val_loss: 0.4426 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 6/10
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 441ms/step - accuracy: 0.2773 - auc: 0.5418 - loss: 0.4495 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.2960 - val_auc: 0.6288 - val_loss: 0.4433 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 7/10
[1m233/234[0m [32m━━━━━━━━━━━━━━━━━━━



[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 440ms/step - accuracy: 0.2760 - auc: 0.5405 - loss: 0.4496 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.2716 - val_auc: 0.6265 - val_loss: 0.4393 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 8/10
[1m233/234[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 354ms/step - accuracy: 0.2807 - auc: 0.5456 - loss: 0.4485 - precision: 0.0000e+00 - recall: 0.0000e+00



[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 440ms/step - accuracy: 0.2777 - auc: 0.5516 - loss: 0.4476 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.2805 - val_auc: 0.6294 - val_loss: 0.4381 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 9/10
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 437ms/step - accuracy: 0.2755 - auc: 0.5473 - loss: 0.4485 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.2813 - val_auc: 0.6293 - val_loss: 0.4386 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 10/10
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m109s[0m 461ms/step - accuracy: 0.2781 - auc: 0.5532 - loss: 0.4472 - precision: 1.0000 - recall: 6.1923e-05 - val_accuracy: 0.2813 - val_auc: 0.6378 - val_loss: 0.4386 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010


KeyError: 'accuracy'