In [None]:
import zipfile
import os

In [None]:
!wget --no-check-certificate \
    "https://github.com/m-sumaim/corn_seed_dataset/archive/refs/heads/main.zip" \
    -O "/tmp/seed_class.zip"


zip_ref = zipfile.ZipFile('/tmp/seed_class.zip', 'r') #Opens the zip file in read mode
zip_ref.extractall('/tmp') #Extracts the files into the /tmp folder
zip_ref.close()

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
#from PIL import Image as image
import numpy as np

# Locations of the training and validation images, expected under /tmp once
# the dataset archive has been extracted.
train_dir = '/tmp/corn_seed_dataset-main/train'
test_dir = '/tmp/corn_seed_dataset-main/validation'

# Random augmentation applied to training images only.
augmentation_options = {
    'rotation_range': 40,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'shear_range': 0.2,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}

# Both generators scale pixel values by 1/255; only the training one augments.
train_datagen = ImageDataGenerator(rescale=1./255, **augmentation_options)
test_datagen = ImageDataGenerator(rescale=1./255)

# Batching options shared by the training and validation streams.
flow_options = {
    'target_size': (150, 150),
    'batch_size': 32,
    'class_mode': 'categorical',
}

# Directory-backed batch iterators: augmented training data, plain validation data.
train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
test_generator = test_datagen.flow_from_directory(test_dir, **flow_options)

# Layer stack: three convolution + max-pooling stages, then a dense head with
# a 4-way softmax output.
cnn_layers = [
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(4, activation='softmax'),
]
model = tf.keras.models.Sequential(cnn_layers)

# Configure training: categorical cross-entropy loss, RMSprop optimizer,
# accuracy reported as the metric.
rmsprop = tf.keras.optimizers.RMSprop(learning_rate=1e-4)
model.compile(
    optimizer=rmsprop,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Fit for 5 epochs on the training generator, validating on the test generator.
history = model.fit(
    train_generator,
    validation_data=test_generator,
    epochs=5,
)

# Write the trained model to disk, then rebind `model` to the copy read back
# from that file.
model_path = 'corn_seed_classifier.h5'
model.save(model_path)
model = tf.keras.models.load_model(model_path)

def predict_image(image_path, class_names=None, target_size=(150, 150)):
    """Classify one image file with the module-level ``model``.

    Args:
        image_path: Path of the image file to classify.
        class_names: Optional sequence of labels ordered by model output
            index. Defaults to ('Broken', 'Discolored', 'Pure', 'Silkcut').
            NOTE(review): the default order is assumed to match the class
            indices used during training -- confirm against
            ``train_generator.class_indices``.
        target_size: Size passed to ``image.load_img`` when loading the
            image. Defaults to (150, 150), matching the ``input_shape`` the
            model in this notebook is built with.

    Returns:
        A ``(predicted_class_name, class_probability)`` tuple: the label with
        the highest score and that score as produced by ``model.predict``.

    Raises:
        ValueError: If the number of scores returned by the model differs
            from the number of labels in ``class_names``.
    """
    if class_names is None:
        class_names = ('Broken', 'Discolored', 'Pure', 'Silkcut')

    # Load the image resized to the requested size
    img = image.load_img(image_path, target_size=target_size)

    # Convert to an array and scale by 1/255, the same rescaling the
    # generators apply to the training and validation images
    x = image.img_to_array(img) / 255.0

    # Add a leading axis to create a batch of size 1
    x = np.expand_dims(x, axis=0)

    # Scores for the single image in the batch
    scores = model.predict(x)[0]

    # A label list that does not line up with the model output would either
    # raise an unhelpful IndexError or silently return the wrong label.
    if len(scores) != len(class_names):
        raise ValueError(
            f'model returned {len(scores)} scores but {len(class_names)} '
            f'class names were given'
        )

    # Index of the highest-scoring class
    class_index = int(np.argmax(scores))

    return class_names[class_index], scores[class_index]

# Classify the sample image and unpack the (label, score) pair
image_path = '/home/test_image.png'
prediction = predict_image(image_path)
predicted_class_name, class_probability = prediction[0], prediction[1]

# Report the label and its score (two decimals), one per line
print(
    f'The predicted class is: {predicted_class_name}',
    f'The probability of the predicted class is: {class_probability:.2f}',
    sep='\n',
)


Found 14327 images belonging to 4 classes.
Found 3474 images belonging to 4 classes.
Epoch 1/5

In [None]:
# Classify the second sample image and unpack the (label, score) pair
image_path = '/home/broken_test.png'
prediction = predict_image(image_path)
predicted_class_name, class_probability = prediction[0], prediction[1]

# Report the label and its score (two decimals), one per line
print(
    f'The predicted class is: {predicted_class_name}',
    f'The probability of the predicted class is: {class_probability:.2f}',
    sep='\n',
)

The predicted class is: Pure
The probability of the predicted class is: 0.36
