"""
# 🐾 Animal Faces HQ (AFHQ) - Image Classification using CNN 🧠

## 📘 About the Dataset
--------------------
This dataset, also known as **Animal Faces-HQ (AFHQ)**, consists of **16,130 high-quality images**
at **512×512 resolution**.

It contains **three domains of animal faces**, each with about **5,000 images** representing
a diverse range of breeds and appearances.

**Classes:**
1. 🐱 Cat
2. 🐶 Dog
3. Wildlife


## Import libraries

In [None]:
import kagglehub
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from tensorflow.keras.preprocessing  import image_dataset_from_directory
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout


## Load dataset

In [None]:
# Download the latest version of the AFHQ dataset through kagglehub and
# report the local path that the download call returns.
DATASET_HANDLE = "andrewmvd/animal-faces"
path = kagglehub.dataset_download(DATASET_HANDLE)

print("Path to dataset files:", path)

In [None]:
# Both splits are looked up under an 'afhq' folder inside the download path:
# 'train' is used for fitting, 'val' is used as the test set.
afhq_root = os.path.join(path, 'afhq')
train_dir = os.path.join(afhq_root, 'train')
test_dir = os.path.join(afhq_root, 'val')

In [None]:
# Loader settings shared by the training and test splits.
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 12
_loader_options = dict(
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Build batched datasets of (images, labels) from each split directory.
train_ds = image_dataset_from_directory(train_dir, **_loader_options)
test_ds = image_dataset_from_directory(test_dir, **_loader_options)

In [None]:
# Class labels as reported by the training dataset loader.
class_names = train_ds.class_names
print(class_names)
# Number of classes; used to size the model's output layer.
n_classes = len(class_names)


# Define rescaling + resizing function

In [None]:
def Image_resize_and_rescale(image, size=(128, 128)):
    """Resize an image tensor and divide its pixel values by 255.

    Args:
        image: Image tensor (or batch of images) accepted by
            ``tf.image.resize``.
        size: Target ``(height, width)``. Defaults to ``(128, 128)``, the
            value that used to be hard-coded, so existing one-argument
            calls behave exactly as before.

    Returns:
        The resized image divided by 255.0. For inputs in the 0-255 range
        this yields values in the 0-1 range.
    """
    resized = tf.image.resize(image, size)
    return resized / 255.0

# Apply Rescaling + Resizing function

In [None]:
def _preprocess_batch(images, labels):
    """Resize/rescale the images of one dataset element; labels pass through."""
    return Image_resize_and_rescale(images), labels


train_ds = train_ds.map(_preprocess_batch)
test_ds = test_ds.map(_preprocess_batch)

 # Optimize dataset performance


In [None]:
# Cache the preprocessed batches after the first pass and overlap input
# preparation with model execution via prefetch.
#
# cache() stores the elements in the order the upstream pipeline produced
# them on its first pass, so the loader's shuffle=True stops having any
# effect after epoch 1. Shuffling again *after* the cache gives the training
# loop a fresh batch order on every epoch.
SHUFFLE_BUFFER = 256  # number of cached batches held in the shuffle buffer
train_ds = (
    train_ds
    .cache()
    .shuffle(SHUFFLE_BUFFER)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# The evaluation split does not need to be reshuffled.
test_ds = test_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)

 # Build CNN model

In [None]:
# Four Conv2D + MaxPooling2D stages (32, 64, 128, 128 filters) followed by a
# dense classifier head.
cnn_layers = [
    Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
    MaxPooling2D((2, 2)),
]
for n_filters in (64, 128, 128):
    cnn_layers.append(Conv2D(n_filters, (3, 3), activation='relu'))
    cnn_layers.append(MaxPooling2D((2, 2)))
cnn_layers.extend([
    Flatten(),
    Dense(512, activation='relu'),
    # One softmax output per class.
    Dense(n_classes, activation='softmax'),
])

model = Sequential(cnn_layers)

# Compile model

In [None]:
# Training configuration: Adam optimizer, sparse categorical cross-entropy
# loss, and accuracy as the reported metric.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train model

In [None]:
# Fit on the training set for 10 epochs; the returned object keeps the
# per-epoch metrics that are plotted afterwards.
history = model.fit(
    x=train_ds,
    epochs=10,
    verbose=2,
)

# Plot training history


In [None]:
# Draw the training accuracy and training loss curves side by side.
plt.figure(figsize=(12, 4))

# (history key, legend label / title, y-axis label, line colour) per panel.
curve_specs = (
    ('accuracy', 'Training Accuracy', 'Accuracy', 'green'),
    ('loss', 'Training Loss', 'Loss', 'red'),
)
for position, (metric_key, caption, y_label, line_color) in enumerate(curve_specs, start=1):
    plt.subplot(1, 2, position)
    plt.plot(history.history[metric_key], label=caption, color=line_color)
    plt.title(caption)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.show()


# Evaluate model

In [None]:
# Score the trained model on the test split and report both values.
evaluation = model.evaluate(test_ds)
test_loss, test_acc = evaluation
print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Test Loss: {test_loss:.4f}")


# Confusion Matrix and Classification Report


In [None]:
# Collect ground-truth and predicted class indices batch by batch. Images and
# labels come from the same iteration step, so they stay aligned regardless
# of the order in which the dataset yields its batches.
true_labels = []
pred_labels = []

for images, labels in test_ds:
    # predict_on_batch runs a single batch directly; model.predict() is meant
    # for large inputs and adds per-call overhead when used inside a loop.
    preds = model.predict_on_batch(images)
    pred_labels.extend(np.argmax(preds, axis=1))
    true_labels.extend(labels.numpy())

# Convert to numpy arrays
true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)

# Pass the label set explicitly so the matrix always has one row/column per
# entry of class_names, even if a class never appears in labels or predictions.
cm = confusion_matrix(true_labels, pred_labels, labels=np.arange(len(class_names)))

# Draw on an explicitly created axes: ConfusionMatrixDisplay.plot() opens its
# own figure when no `ax` is given, which previously left the 6x6 figure
# empty and ignored its size.
fig, ax = plt.subplots(figsize=(6, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title('Confusion Matrix')
plt.show()
