In [2]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import splitfolders  # For splitting data
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import seaborn as sns

In [3]:
# Split data into 70% train, 15% validation, 15% test
splitfolders.ratio(
    "data", 
    output="split_data", 
    seed=42,
    ratio=(0.7, 0.15, 0.15), 
    group_prefix=None
)

In [4]:
# Constants
BATCH_SIZE = 32
IMG_SIZE = (224, 224)  # ResNet expects 224x224

# Data augmentation for training
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True
)

# No augmentation for validation/test
val_test_datagen = ImageDataGenerator(rescale=1./255)

# Train generator
train_generator = train_datagen.flow_from_directory(
    "split_data/train",
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="binary"  # REAL=0, FAKE=1
)

# Validation generator
val_generator = val_test_datagen.flow_from_directory(
    "split_data/val",
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="binary"
)

# Test generator
test_generator = val_test_datagen.flow_from_directory(
    "split_data/test",
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="binary",
    shuffle=False  # Critical for metrics
)

# Verify class indices
print("Class indices:", train_generator.class_indices)

Found 15148 images belonging to 2 classes.
Found 3246 images belonging to 2 classes.
Found 3248 images belonging to 2 classes.
Class indices: {'FAKE': 0, 'REAL': 1}


In [5]:
# Load pre-trained ResNet50 (without top layers)
base_model = ResNet50(
    weights='imagenet', 
    include_top=False, 
    input_shape=(224, 224, 3)
)
base_model.trainable = False  # Freeze initially

# Add custom head
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
outputs = Dense(1, activation='sigmoid')(x)
model = Model(base_model.input, outputs)

# Compile
model.compile(
    optimizer=Adam(1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy']
)
model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m94765736/94765736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 0us/step


In [6]:
# Train the top layers
history = model.fit(
    train_generator,
    epochs=10,
    validation_data=val_generator,
    verbose=1
)

  self._warn_if_super_not_called()


Epoch 1/10
[1m474/474[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1097s[0m 2s/step - accuracy: 0.5392 - loss: 0.7102 - val_accuracy: 0.6420 - val_loss: 0.6364
Epoch 2/10
[1m425/474[0m [32m━━━━━━━━━━━━━━━━━[0m[37m━━━[0m [1m1:25[0m 2s/step - accuracy: 0.6094 - loss: 0.6512

KeyboardInterrupt: 

In [None]:
# Unfreeze deeper layers
base_model.trainable = True
for layer in base_model.layers[:100]:
    layer.trainable = False  # Freeze first 100 layers

# Recompile with lower learning rate
model.compile(
    optimizer=Adam(1e-5),  # Smaller LR for fine-tuning
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Continue training
history_fine = model.fit(
    train_generator,
    epochs=5,
    validation_data=val_generator,
    verbose=1
)

In [None]:
# Get predictions
test_preds = model.predict(test_generator)
test_preds_binary = (test_preds > 0.5).astype(int)

# Classification report
print(classification_report(test_generator.classes, test_preds_binary))

# Confusion matrix
cm = confusion_matrix(test_generator.classes, test_preds_binary)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')

# ROC-AUC
roc_auc = roc_auc_score(test_generator.classes, test_preds)
print(f"ROC-AUC: {roc_auc:.3f}")

In [None]:
# Plot accuracy and loss
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.title('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss')
plt.legend()
plt.show()

In [None]:
model.save("art_classifier_model.h5")