In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import DenseNet121
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix


## Dataset
Source: https://www.kaggle.com/datasets/abdulbasit31/strawberry-dataset/data

In [None]:
# Optional: download the dataset with kagglehub instead of relying on a
# pre-mounted copy. Uncomment to run; the printed path may differ from the
# `dataset_dir` configured below, so update that setting accordingly.
# import kagglehub
# # Download latest version
# path = kagglehub.dataset_download("abdulbasit31/strawberry-dataset")
# print("Path to dataset files:", path)

In [2]:

# Dataset root directory; it is later handed to `flow_from_directory`, which
# expects one sub-directory per class.
# The Kaggle mount point stays the default, but the location can be overridden
# by setting the STRAWBERRY_DATASET_DIR environment variable, so the notebook
# also runs where the dataset lives elsewhere (the hard-coded path alone
# fails with FileNotFoundError outside Kaggle).
dataset_dir = os.environ.get(
    'STRAWBERRY_DATASET_DIR',
    '/kaggle/input/strawberry-dataset/strawberryDataset',
)

# Image size (height, width) every picture is resized to, and samples per batch
img_size = (224, 224)
batch_size = 32


In [3]:
# Echo the configured dataset path as a quick sanity check
dataset_dir

'/kaggle/input/strawberry-dataset/strawberryDataset'

In [4]:

# Fraction of the images held out for validation. Both generators share the
# same value so the 'training' and 'validation' subsets are cut at the same
# point (80% training / 20% validation).
holdout_fraction = 0.2

# Random geometric augmentation applied to training images only
augmentation_options = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Training pipeline: rescale pixel values by 1/255, then augment
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=holdout_fraction,
    **augmentation_options,
)

# Validation pipeline: rescaling only, no augmentation
test_datagen = ImageDataGenerator(rescale=1./255, validation_split=holdout_fraction)

# Training batches: augmented images read from the class sub-directories of
# `dataset_dir`, with binary (0/1) labels. Shuffling is left at the Keras
# default (enabled), which is what training wants.
train_generator = train_datagen.flow_from_directory(
    directory=dataset_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary',
    subset='training'  # Set as training data
)

# Validation batches. shuffle=False makes the generator yield samples in the
# same order as `validation_generator.classes`, so predictions produced from
# this generator can be compared element-wise with those labels. With the
# default shuffle=True the two orders differ and the classification report /
# confusion matrix computed from them are meaningless.
validation_generator = test_datagen.flow_from_directory(
    directory=dataset_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='binary',
    subset='validation',  # Set as validation data
    shuffle=False
)

# Feature extractor: DenseNet121 with ImageNet weights and without its
# original classification head
input_shape = (*img_size, 3)
base_model = DenseNet121(weights='imagenet', include_top=False, input_shape=input_shape)

# Keep the pre-trained weights fixed; only the layers added below are trained
base_model.trainable = False

# Stack: frozen backbone -> global average pooling -> single sigmoid unit
model = Sequential()
model.add(base_model)
model.add(GlobalAveragePooling2D())
model.add(Dense(1, activation='sigmoid'))  # Binary classification

# Binary cross-entropy loss, Adam optimizer, accuracy as the reported metric
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0001),
    metrics=['accuracy'],
)

# Fit the model on the training generator, validating after every epoch
num_epochs = 20  # Adjust number of epochs as needed
history = model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=num_epochs,
)

# Final loss and accuracy on the validation subset
val_metrics = model.evaluate(validation_generator)
val_loss, val_accuracy = val_metrics
print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

# Learning curves from the training history: accuracy in the left panel,
# loss in the right panel, each showing the training and validation series.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
curve_specs = [
    # (training key, validation key, panel title, y-axis label)
    ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Model loss', 'Loss'),
]
for ax, (train_key, val_key, panel_title, y_label) in zip(axes, curve_specs):
    ax.plot(history.history[train_key])
    ax.plot(history.history[val_key])
    ax.set_title(panel_title)
    ax.set_ylabel(y_label)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Validation'], loc='upper left')
plt.show()

# Rewind the validation iterator and score every validation image.
# NOTE(review): the labels taken from `.classes` below are in directory
# order; the predictions line up with them only if the generator yields
# samples unshuffled — confirm it was created with shuffle=False.
validation_generator.reset()
val_probabilities = model.predict(validation_generator)

# Round the sigmoid outputs to hard 0/1 labels and flatten to one dimension
y_pred = np.round(val_probabilities).astype(int).ravel()

# Ground-truth class indices recorded by the generator
Y_val = validation_generator.classes

# Per-class precision / recall / F1
print('Classification Report')
report_text = classification_report(Y_val, y_pred, target_names=['Pickable', 'UnPickable'])
print(report_text)

# Confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(Y_val, y_pred)
print('Confusion Matrix')
print(cm)

# Heat-map of the confusion matrix with the class names on both axes
class_labels = ['Pickable', 'UnPickable']
tick_marks = np.arange(2)

fig, ax = plt.subplots(figsize=(6, 6))
heatmap = ax.imshow(cm, interpolation='nearest', cmap='Blues')
ax.set_title('Confusion Matrix')
fig.colorbar(heatmap, ax=ax)

# Same two tick positions on both axes; x labels tilted for readability
ax.set_xticks(tick_marks)
ax.set_xticklabels(class_labels, rotation=45)
ax.set_yticks(tick_marks)
ax.set_yticklabels(class_labels)

ax.set_ylabel('True label')
ax.set_xlabel('Predicted label')
plt.show()

FileNotFoundError: [Errno 2] No such file or directory: '/kaggle/input/strawberry-dataset/strawberryDataset'