In [6]:
import warnings

# Installed before TensorFlow is imported so its import-time warnings are
# silenced too. NOTE(review): this hides *all* warnings, including deprecation
# notices — consider narrowing the filter.
warnings.filterwarnings("ignore")

import itertools
import random
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import precision_score, accuracy_score, recall_score, confusion_matrix, ConfusionMatrixDisplay
from tensorflow import keras
from tensorflow.keras.layers import Rescaling
from tensorflow.keras.preprocessing import image_dataset_from_directory

# Seed every RNG the notebook touches, once, right after the imports.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)  # TF2-native replacement for tf.compat.v1.set_random_seed


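## Loading Data

Images are read from the class-labelled `train/` and `valid/` directories (the latter is used as the test set) and rescaled from [0, 255] to [0, 1].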

In [None]:
# Root of the augmented dataset; train/ and valid/ each contain one
# sub-directory per class, from which labels are inferred.
DATA_DIR = "../input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)"
IMAGE_SIZE = (256, 256)

train_raw = image_dataset_from_directory(directory=f"{DATA_DIR}/train",
                                         image_size=IMAGE_SIZE,
                                         seed=0)
# Keep the evaluation split in a fixed order so predictions can later be
# lined up with labels (e.g. for a confusion matrix).
test_raw = image_dataset_from_directory(directory=f"{DATA_DIR}/valid",
                                        image_size=IMAGE_SIZE,
                                        shuffle=False)

rescale = Rescaling(scale=1.0 / 255)
train = train_raw.map(lambda image, label: (rescale(image), label))
test = test_raw.map(lambda image, label: (rescale(image), label))

# Dataset.map() returns a new dataset that does not carry the class_names
# attribute attached by image_dataset_from_directory; copy it across so later
# cells can keep reading train.class_names.
train.class_names = train_raw.class_names
test.class_names = test_raw.class_names

In [None]:
# Disease labels: the class sub-directory names discovered by the loader.
class_names = train.class_names
n_classes = len(class_names)
print("Class names:", class_names)
print("Number of classes:", n_classes)

In [None]:
# Peek at the tensor shapes of a single batch.
batches = train.take(1)
for images, labels in batches:
    shape_report = (("Images shape:", images.shape), ("Labels shape:", labels.shape))
    for caption, shape in shape_report:
        print(caption, shape)

In [None]:
# Show a 3x3 grid of images from one batch, each titled with its class name.
GRID_ROWS, GRID_COLS = 3, 3

plt.figure(figsize=(10, 10))
for images, labels in train.take(1):
    for slot in range(GRID_ROWS * GRID_COLS):
        ax = plt.subplot(GRID_ROWS, GRID_COLS, slot + 1)
        label_index = labels[slot].numpy()
        plt.imshow(images[slot].numpy())
        plt.title(class_names[label_index])
        plt.axis("off")

In [None]:
# Data distribution

# Count images per class in the training set.
# - Labels are tallied one batch at a time with np.bincount, so there is no
#   per-image Python loop over an unbatched dataset.
# - The Counter is filled in class_names order, so the printout and bar chart
#   list classes in a stable order rather than the shuffled order in which they
#   happen to be met.
# - Classes with no images still appear, with a count of 0.
label_totals = np.zeros(len(class_names), dtype=np.int64)
for _, labels in train:
    label_totals += np.bincount(labels.numpy(), minlength=len(class_names))
class_counts = Counter(dict(zip(class_names, label_totals.tolist())))

# Print the distribution
for cls, count in class_counts.items():
    print(f"{cls}: {count}")

# Visualize the distribution as a bar chart
plt.figure(figsize=(12, 6))
plt.bar(list(class_counts.keys()), list(class_counts.values()))
plt.xticks(rotation=90)
plt.title("Number of Images per Class")
plt.ylabel("Count")
plt.xlabel("Class Name")
plt.show()

In [None]:
# Dataset properties

# len() on a batched tf.data.Dataset counts batches, not images (batch_size
# was left at the loader's default). The per-image totals are computed in the
# next cell.
print("Training batches:", len(train))
print("Testing batches:", len(test))

In [None]:
# Count images by summing batch sizes while streaming through each split.
# list(dataset.unbatch()) would hold every decoded image tensor in memory at
# once just to take its length.
total_train_images = sum(int(labels.shape[0]) for _, labels in train)
total_test_images = sum(int(labels.shape[0]) for _, labels in test)

print("Total training images:", total_train_images)
print("Total testing images:", total_test_images)

In [None]:
# Pixel value range of the first image in one batch; after the 1/255 rescale
# this is expected to lie within [0, 1].
for images, _ in train.take(1):
    first_image = images[0].numpy()
    print("Min pixel value:", first_image.min())
    print("Max pixel value:", first_image.max())

In [None]:
# Class distribution in test set

# Same counting scheme as for the training set: per-batch np.bincount, with the
# Counter filled in class_names order. This keeps the chart order stable and
# lists classes that have 0 test images.
test_label_totals = np.zeros(len(class_names), dtype=np.int64)
for _, labels in test:
    test_label_totals += np.bincount(labels.numpy(), minlength=len(class_names))
test_class_counts = Counter(dict(zip(class_names, test_label_totals.tolist())))

plt.figure(figsize=(12, 6))
plt.bar(list(test_class_counts.keys()), list(test_class_counts.values()))
plt.xticks(rotation=90)
plt.title("Number of Images per Class (Test Set)")
plt.ylabel("Count")
plt.xlabel("Class Name")
plt.show()

In [None]:
# Checking class imbalance

# Overlay train and test counts per class. Heights are looked up in
# class_names order, and a class absent from a split counts as 0. Both series
# therefore share one fixed x-axis order, however the two Counters were filled.
class_axis = list(class_names)
train_heights = [class_counts.get(name, 0) for name in class_axis]
test_heights = [test_class_counts.get(name, 0) for name in class_axis]

plt.figure(figsize=(12, 6))
plt.bar(class_axis, train_heights, alpha=0.5, label="Train")
plt.bar(class_axis, test_heights, alpha=0.5, label="Test")
plt.xticks(rotation=90)
plt.title("Class Distribution: Train vs Test")
plt.ylabel("Count")
plt.xlabel("Class Name")
plt.legend()
plt.show()

In [None]:
# Image variability

# Show one example image for each of the first few classes.
N_CLASSES_TO_SHOW = 5
shown_classes = class_names[:N_CLASSES_TO_SHOW]

plt.figure(figsize=(15, 3))
for class_index, cls in enumerate(shown_classes):
    # filter() traces its predicate into a TF graph, where Tensor.numpy() is
    # unavailable, so the label is compared with a TF op instead. Tracing
    # happens when filter() is called, so the current class_index is captured.
    # Unbatching first matches each image's own label rather than only the
    # first label of a batch.
    matches = train.unbatch().filter(
        lambda image, label: tf.equal(label, class_index)
    )
    for image, _ in matches.take(1):
        plt.subplot(1, len(shown_classes), class_index + 1)
        plt.imshow(image.numpy())
        plt.title(cls)
        plt.axis("off")
plt.show()

In [None]:
# Checking for missing or corrupted images

# Scan every training batch for images that contain NaN or Inf pixel values,
# reporting how many were found per batch and in total. The previous bare
# `except:` hid every kind of error (including interrupts) and gave no
# location. NOTE(review): a file that fails to decode would raise from the loop
# itself rather than be counted here.
corrupted_images = 0
for batch_index, (images, labels) in enumerate(train):
    # Reduce over every axis except the leading batch axis -> one flag per image.
    finite_per_image = tf.reduce_all(tf.math.is_finite(images), axis=[1, 2, 3])
    bad_in_batch = int(tf.reduce_sum(tf.cast(tf.logical_not(finite_per_image), tf.int32)))
    if bad_in_batch:
        corrupted_images += bad_in_batch
        print(f"Batch {batch_index}: {bad_in_batch} image(s) contain NaN or Inf")

print("Images with NaN or Inf values:", corrupted_images)