In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report

# Load the MNIST dataset (pixel values 0-255).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale pixels to [0, 1] and add a trailing channel axis so the images match
# the (28, 28, 1) input expected by the Conv2D autoencoder.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# Digits treated as "normal": the autoencoder is trained only on these, and
# every other digit counts as an anomaly at test time. Defined once so the
# training filter and the anomaly labels below cannot drift apart.
NORMAL_DIGITS = [0, 1, 2, 3, 4]

# Split: training set keeps only normal digits, test set keeps every sample.
train_idx = np.isin(y_train, NORMAL_DIGITS)
test_idx = np.ones(y_test.shape, dtype=bool)  # keep all test samples

x_train_normal = x_train[train_idx]
y_train_normal = y_train[train_idx]
x_test_all = x_test[test_idx]
y_test_all = y_test[test_idx]

# Anomaly labels: 0 = normal digit, 1 = anomaly (any digit not in NORMAL_DIGITS).
y_test_labels = np.where(np.isin(y_test_all, NORMAL_DIGITS), 0, 1)

# Convolutional autoencoder: three conv+pool stages compress the image,
# three conv+upsample stages expand it back to the input resolution.
input_shape = (28, 28, 1)

# Encoder: 28x28 -> 14x14 -> 7x7 -> 4x4 ('same' pooling rounds 7/2 up to 4).
inputs = layers.Input(shape=input_shape)
x = inputs
for filters in (32, 16, 8):
    x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = layers.MaxPooling2D((2, 2), padding='same')(x)
encoded = x

# Decoder: 4x4 -> 8x8 -> 16x16 -> (valid conv) 14x14 -> 28x28.
# The third conv deliberately uses 'valid' padding (the Conv2D default): it
# trims 16x16 down to 14x14 so the last upsampling lands on 28x28, not 32x32.
for filters, padding in ((8, 'same'), (16, 'same'), (32, 'valid')):
    x = layers.Conv2D(filters, (3, 3), activation='relu', padding=padding)(x)
    x = layers.UpSampling2D((2, 2))(x)
# Single-channel sigmoid output, matching the [0, 1]-scaled input pixels.
decoded = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = keras.Model(inputs, decoded)
autoencoder.compile(optimizer='adam', loss='mse')
autoencoder.summary()

# Train the autoencoder to reproduce its own input, using only normal digits.
# Keras holds out the trailing 10% of x_train_normal as validation data.
history = autoencoder.fit(
    x=x_train_normal,
    y=x_train_normal,
    batch_size=128,
    epochs=30,
    validation_split=0.1,
    shuffle=True,
)

def reconstruction_mse(original, reconstructed):
    """Return the per-sample mean squared reconstruction error.

    Averages over every axis except the first (sample) axis, so a batch of
    shape (N, H, W, C) yields an array of shape (N,).
    """
    squared_error = np.square(original - reconstructed)
    return np.mean(squared_error, axis=tuple(range(1, squared_error.ndim)))


# Reconstruction errors on the normal training digits.
train_recon = autoencoder.predict(x_train_normal)
train_mse = reconstruction_mse(x_train_normal, train_recon)

# Anomaly threshold: 95th percentile of the error on *held-out* normal data.
# Calibrating on samples the model was fitted to underestimates the error on
# unseen normal digits, which inflates the false-positive rate at test time.
# Keras' validation_split holds out the trailing fraction of the training
# array (before shuffling), so reuse exactly that slice of train_mse.
# NOTE: keep VALIDATION_SPLIT in sync with validation_split in autoencoder.fit.
VALIDATION_SPLIT = 0.1
val_start = int(len(x_train_normal) * (1.0 - VALIDATION_SPLIT))
threshold = np.percentile(train_mse[val_start:], 95)

# Reconstruction errors on the full test set (normal and anomalous digits).
test_recon = autoencoder.predict(x_test_all)
test_mse = reconstruction_mse(x_test_all, test_recon)

# Flag a sample as anomalous (1) when its error exceeds the threshold.
predictions = (test_mse > threshold).astype(int)

# Per-class precision / recall / F1 for the normal-vs-anomaly decision.
print("\nKlasyfikacja:")
report = classification_report(
    y_test_labels,
    predictions,
    target_names=['Normal', 'Anomalia'],
)
print(report)

# Confusion matrix: rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_test_labels, predictions)
print("Macierz pomyłek:")
print(cm)

# ROC curve: the raw reconstruction error is the anomaly score, so this
# evaluates the score independently of the chosen threshold.
fpr, tpr, _ = roc_curve(y_test_labels, test_mse)
roc_auc = auc(fpr, tpr)

# The ROC curve gets its own figure and is shown before the reconstruction
# grid, which needs a separate 2 x n layout.
plt.figure(figsize=(6, 5))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance level
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Krzywa ROC')
plt.legend(loc="lower right")
plt.show()

# Sample reconstructions: originals on the top row, reconstructions below.
n = 10
# Sample without replacement so no test image is shown twice.
indices = np.random.choice(len(x_test_all), n, replace=False)
plt.figure(figsize=(20, 4))

for i, idx in enumerate(indices):
    # Original image with its digit label and reconstruction error.
    plt.subplot(2, n, i + 1)
    plt.imshow(x_test_all[idx].squeeze(), cmap='gray')
    plt.title(f"Label: {y_test_all[idx]}\nMSE: {test_mse[idx]:.4f}")
    plt.axis('off')

    # Reconstruction, titled with the thresholded anomaly decision.
    plt.subplot(2, n, i + 1 + n)
    plt.imshow(test_recon[idx].squeeze(), cmap='gray')
    plt.title(f"Pred: {'Anomalia' if predictions[idx] else 'Normal'}")
    plt.axis('off')

plt.tight_layout()
plt.show()

# Histogram of test reconstruction errors per true class, with the threshold.
# Both classes share one set of bin edges so bar widths and positions are
# directly comparable (a plain bins=50 per call would bin each class over
# its own min-max range).
mse_bins = np.histogram_bin_edges(test_mse, bins=50)
plt.figure(figsize=(10, 6))
plt.hist(test_mse[y_test_labels == 0], bins=mse_bins, alpha=0.5, label='Normal (0-4)')
plt.hist(test_mse[y_test_labels == 1], bins=mse_bins, alpha=0.5, label='Anomalia (5-9)')
plt.axvline(threshold, color='r', linestyle='dashed', linewidth=2, label=f'Próg: {threshold:.4f}')
plt.title('Rozkład błędów rekonstrukcji')
plt.xlabel('Błąd rekonstrukcji (MSE)')
plt.ylabel('Liczba próbek')
plt.yscale('log')  # log counts keep the sparse high-error tail visible
plt.legend()
plt.show()