# NEURA9 — Treinamento Completo

Este notebook demonstra o fluxo completo de treinamento da NEURA9 usando o
dataset em `ai/dataset/neura9_dataset.csv`.

Execute as células em ordem, com o diretório de trabalho na raiz do projeto `WavePwn/` — todos os caminhos usados (`ai/dataset/…`, `ai_training/…`) são relativos a ela.

In [None]:
import pathlib

import numpy as np
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# Training CSV, resolved relative to the current working directory
# (the notebook is meant to be run from the project root).
DATASET = pathlib.Path("ai/dataset/neura9_dataset.csv")
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would turn a missing file into a less obvious failure later on.
if not DATASET.exists():
    raise FileNotFoundError(f"Dataset não encontrado em {DATASET}")

In [None]:
# Read the CSV, skipping its first row (assumed to be a header). All columns
# except the last are features; the last column is the class label.
data = np.loadtxt(DATASET, delimiter=",", skiprows=1)
feature_cols, label_col = data[:, :-1], data[:, -1]
x = feature_cols.astype("float32")
y = label_col.astype("int32")

# Hold out 20% as the final test set. `stratify=y` keeps class proportions
# equal in both splits; the fixed seed makes the split reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x,
    y,
    stratify=y,
    test_size=0.2,
    random_state=42,
)
x_train.shape, x_test.shape

In [None]:
# Number of output classes. sparse_categorical_crossentropy expects integer
# labels in [0, NUM_CLASSES).
NUM_CLASSES = 10

# Fail early with a clear message. Out-of-range labels otherwise surface as
# an opaque error during training (CPU) or as NaN losses (GPU).
if y.min() < 0 or y.max() >= NUM_CLASSES:
    raise ValueError(
        f"Labels must lie in [0, {NUM_CLASSES}); found "
        f"[{y.min()}, {y.max()}]"
    )

# Fully connected classifier: one input per feature column, three ReLU
# hidden layers of decreasing width, softmax over the classes.
inputs = keras.Input(shape=(x_train.shape[1],), name="features")
hidden = inputs
for units in (128, 64, 32):
    hidden = keras.layers.Dense(units, activation="relu")(hidden)
outputs = keras.layers.Dense(NUM_CLASSES, activation="softmax")(hidden)

model = keras.Model(inputs=inputs, outputs=outputs, name="neura9_defense")
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

In [None]:
# Early stopping and checkpointing watch the same metric and direction, so
# the weights restored in memory are the same ones saved to disk.
MONITOR = "val_accuracy"

CHECKPOINT_PATH = pathlib.Path("ai_training/best_model.h5")
# Make sure the output directory exists before the first checkpoint write.
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

callbacks = [
    keras.callbacks.EarlyStopping(
        monitor=MONITOR, mode="max", patience=8, restore_best_weights=True
    ),
    # NOTE(review): recent Keras releases expect a `.keras` extension for
    # full-model checkpoints. The `.h5` path is kept because other tooling may
    # load this exact file; confirm against the installed Keras version.
    keras.callbacks.ModelCheckpoint(
        str(CHECKPOINT_PATH), monitor=MONITOR, mode="max", save_best_only=True
    ),
]

# Keras takes the *last* 20% of the training samples (before shuffling) as
# validation data. train_test_split already shuffled them, so that slice is
# representative. The test split is not touched until final evaluation.
history = model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=80,
    batch_size=256,
    callbacks=callbacks,
    shuffle=True,
)

In [None]:
# Final evaluation on the held-out test split (never seen during training).
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"Acuracia de teste: {100 * test_acc:.2f}%")

# Predicted class = index of the highest softmax probability per sample.
y_proba = model.predict(x_test)
y_pred = np.argmax(y_proba, axis=1)
print(classification_report(y_test, y_pred))

In [None]:
import matplotlib.pyplot as plt

# Class ids present in the ground truth or the predictions. This is the
# sorted label set confusion_matrix uses for its rows/columns. Passing it
# explicitly lets each tick show the real class id rather than a position,
# which matters whenever some class is missing from both arrays.
class_ids = np.unique(np.concatenate([y_test, y_pred]))
cm = confusion_matrix(y_test, y_pred, labels=class_ids)

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(cm, cmap="Blues")
ax.set_xticks(np.arange(len(class_ids)))
ax.set_xticklabels(class_ids)
ax.set_yticks(np.arange(len(class_ids)))
ax.set_yticklabels(class_ids)
ax.set_xlabel("Predito")
ax.set_ylabel("Real")

# Draw the count in each cell. Switch to white text on dark cells so it
# stays readable.
threshold = cm.max() / 2 if cm.size else 0
for i, j in np.ndindex(cm.shape):
    ax.text(
        j,
        i,
        str(cm[i, j]),
        ha="center",
        va="center",
        color="white" if cm[i, j] > threshold else "black",
    )

plt.colorbar(im, ax=ax)
plt.show()