# Coconut Disease Classification & Diagnostic System
### Classes: CCI_Caterpillars, CCI_Leaflets, Healthy_Leaves, WCLWD_DryingofLeaflets, WCLWD_Flaccidity, WCLWD_Yellowing

## 1. Import Libraries

In [None]:
import os
import zipfile
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Flatten, Dense,
                                     Dropout, LayerNormalization, ZeroPadding2D)
from tensorflow.keras.optimizers import Adam

## 2. Extract Dataset

In [None]:
# Folder holding the six class archives. Set the COCONUT_DATA_DIR environment
# variable to point elsewhere instead of editing the hard-coded default path.
data_dir = os.environ.get("COCONUT_DATA_DIR", r"d:\New folder (2)")
extract_dir = os.path.join(data_dir, "dataset")

zip_files = ["CCI_Caterpillars.zip", "CCI_Leaflets.zip", "Healthy_Leaves.zip",
             "WCLWD_DryingofLeaflets.zip", "WCLWD_Flaccidity.zip", "WCLWD_Yellowing.zip"]

if os.path.exists(extract_dir):
    # NOTE: only the existence of the folder is checked; a previously
    # interrupted extraction is not detected — delete the folder to redo it.
    print("Dataset already extracted, skipping...")
else:
    missing = []
    for z in zip_files:
        zip_path = os.path.join(data_dir, z)
        if not os.path.exists(zip_path):
            # Record instead of silently skipping, so a missing class is visible.
            missing.append(z)
            continue
        print(f"Extracting {z}...")
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall(extract_dir)
    if missing:
        print(f"WARNING: {len(missing)} archive(s) not found in {data_dir}: "
              f"{', '.join(missing)}")
    print("Done!")

## 3. Check Class Distribution

In [None]:
# Image extensions counted as samples. NOTE(review): mirrors the default
# whitelist of Keras's flow_from_directory — confirm for your Keras version.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')


def _count_images(folder):
    """Count image files anywhere under *folder*, including sub-folders."""
    return sum(
        1
        for _, _, files in os.walk(folder)
        for fname in files
        if fname.lower().endswith(IMAGE_EXTENSIONS)
    )


# Only sub-directories are classes; stray files next to them are ignored
# (calling listdir on a plain file would otherwise raise).
class_names = sorted(
    d for d in os.listdir(extract_dir) if os.path.isdir(os.path.join(extract_dir, d))
)
class_counts = [_count_images(os.path.join(extract_dir, c)) for c in class_names]

for name, count in zip(class_names, class_counts):
    print(f"{name}: {count} images")

plt.figure(figsize=(10, 4))
plt.bar(class_names, class_counts, color='skyblue')
plt.xticks(rotation=30, ha='right')
plt.title('Class Distribution')
plt.ylabel('Images')
plt.tight_layout()
plt.show()

# The ratio is undefined when any class folder is empty (division by zero).
if class_counts and min(class_counts) > 0:
    print(f"Imbalance ratio: {max(class_counts)/min(class_counts):.1f}x")
else:
    print("Imbalance ratio: undefined (at least one class has no images)")

## 4. Data Preparation

In [None]:
IMG_SIZE = 128
BATCH_SIZE = 32

# Both generators rescale pixels to [0, 1] and share the same 20% validation
# split of extract_dir; subset='training' / subset='validation' below pick
# the two portions of that split.
shared_gen_args = dict(rescale=1.0/255, validation_split=0.2)

# Random augmentation is applied to the training generator only.
train_datagen = ImageDataGenerator(
    **shared_gen_args,
    rotation_range=30,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)
val_datagen = ImageDataGenerator(**shared_gen_args)

shared_flow_args = dict(
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)
train_data = train_datagen.flow_from_directory(
    extract_dir, subset='training', shuffle=True, **shared_flow_args
)
# No shuffling for validation so file order stays stable across evaluations.
val_data = val_datagen.flow_from_directory(
    extract_dir, subset='validation', shuffle=False, **shared_flow_args
)

num_classes = len(train_data.class_indices)
CLASS_NAMES = list(train_data.class_indices)
print(f"Classes: {train_data.class_indices}")
print(f"Train: {train_data.samples} | Val: {val_data.samples}")

## 5. Compute Class Weights

In [None]:
# 'balanced' weights are inversely proportional to each class's frequency in
# the training subset, so under-represented classes weigh more in the loss.
present_classes = np.unique(train_data.classes)
weights = compute_class_weight(
    'balanced',
    classes=present_classes,
    y=train_data.classes,
)
class_weight = {cls: w for cls, w in zip(present_classes, weights)}

print("Class Weights:")
for class_pos, class_label in enumerate(CLASS_NAMES):
    print(f"  {class_label}: {class_weight[class_pos]:.3f}")

## 6. Build CNN Model (Convolutional Feature Extractor + Dense Classifier Head)

In [None]:
# Three convolutional stages with 32 -> 64 -> 128 filters. Each stage is
# Conv2D(3x3, ReLU) -> LayerNormalization -> ZeroPadding2D(1) -> MaxPooling2D(2);
# the 1-pixel padding restores the 2 pixels lost by the unpadded 3x3 conv.
model = Sequential()

for stage, filters in enumerate((32, 64, 128)):
    # Only the very first layer declares the input shape.
    input_kwargs = {'input_shape': (IMG_SIZE, IMG_SIZE, 3)} if stage == 0 else {}
    model.add(Conv2D(filters, (3, 3), activation='relu', **input_kwargs))
    model.add(LayerNormalization())
    model.add(ZeroPadding2D(padding=(1, 1)))
    model.add(MaxPooling2D(pool_size=(2, 2)))

# Dense classifier head: two ReLU layers with dropout, then softmax over classes.
model.add(Flatten())
for units, drop_rate in ((256, 0.5), (128, 0.3)):
    model.add(Dense(units, activation='relu'))
    model.add(Dropout(drop_rate))
model.add(Dense(num_classes, activation='softmax'))

model.summary()

## 7. Compile and Train

In [None]:
# Adam with a small learning rate (1e-4). class_weight re-weights each
# sample's loss by its class weight to counter the class imbalance.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=1e-4),
    metrics=['accuracy'],
)

EPOCHS = 25
history = model.fit(
    train_data,
    epochs=EPOCHS,
    validation_data=val_data,
    class_weight=class_weight,
)

## 8. Plot Accuracy and Loss

In [None]:
# Training curves side by side: accuracy (left) and loss (right), each
# showing the training and validation series recorded by model.fit.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric, title in ((ax1, 'accuracy', 'Accuracy'), (ax2, 'loss', 'Loss')):
    ax.plot(history.history[metric], label='Train')
    ax.plot(history.history['val_' + metric], label='Val')
    ax.set_title(title)
    ax.legend()
plt.tight_layout()
plt.show()

## 9. Evaluate and Save Model

In [None]:
# Final metrics on the held-out validation subset.
val_loss, val_acc = model.evaluate(val_data)
print(f"\nValidation Accuracy: {val_acc:.2%}")
print(f"Validation Loss: {val_loss:.4f}")

# NOTE(review): Keras 3 treats .h5 as a legacy format and recommends .keras;
# the filename is kept as-is in case other tooling loads this exact path.
model_path = os.path.join(data_dir, "coconut_disease_model.h5")
model.save(model_path)
print("Model saved!")

---
# Part 2: Diagnostic System (Grad-CAM Explanations, Severity Scoring, Treatment Advice)
## 10. Disease Information

In [None]:
def _disease_entry(full_name, description, mild, moderate, severe):
    """Build one DISEASE_INFO record with treatment advice per severity level."""
    return {
        'full_name': full_name,
        'description': description,
        'treatment': {'Mild': mild, 'Moderate': moderate, 'Severe': severe},
    }


_NO_TREATMENT = 'No treatment needed.'

# Keyed by class folder name (must match CLASS_NAMES); diagnose_leaf looks up
# the predicted class here and picks the treatment by severity label.
DISEASE_INFO = {
    'CCI_Caterpillars': _disease_entry(
        'Coconut Caterpillar Infestation',
        'Caterpillar attack causing leaf damage.',
        mild='Remove affected leaves. Apply neem oil spray.',
        moderate='Apply Bt biopesticide. Remove damaged leaves.',
        severe='Urgent chemical treatment. Consult agriculture officer.',
    ),
    'CCI_Leaflets': _disease_entry(
        'Caterpillar Infestation (Leaflets)',
        'Caterpillar damage on leaflets.',
        mild='Cut damaged leaflets. Spray neem solution.',
        moderate='Apply systemic insecticide. Monitor spread.',
        severe='Mass spraying. Report to agriculture dept.',
    ),
    'Healthy_Leaves': _disease_entry(
        'Healthy Coconut Leaf',
        'No disease detected.',
        mild=_NO_TREATMENT,
        moderate=_NO_TREATMENT,
        severe=_NO_TREATMENT,
    ),
    'WCLWD_DryingofLeaflets': _disease_entry(
        'WCLWD - Drying of Leaflets',
        'Leaf Wilt Disease causing drying.',
        mild='Remove dried leaflets. Apply NPK fertilizer.',
        moderate='Root feeding. Apply fungicide.',
        severe='Tree may need removal. Contact officer.',
    ),
    'WCLWD_Flaccidity': _disease_entry(
        'WCLWD - Leaf Flaccidity',
        'Leaf Wilt Disease causing drooping.',
        mild='Improve drainage. Apply micronutrients.',
        moderate='Root feeding treatment.',
        severe='Quarantine tree. Consult pathologist.',
    ),
    'WCLWD_Yellowing': _disease_entry(
        'WCLWD - Yellowing',
        'Leaf Wilt Disease causing yellowing.',
        mild='Apply magnesium sulfate and potash.',
        moderate='Root feeding. Remove yellowed leaves.',
        severe='Report to agriculture authority.',
    ),
}
print("Disease info loaded!")

## 11. Grad-CAM + Severity + Diagnosis Functions

In [None]:
# Locate the last Conv2D layer; Grad-CAM is computed on its activations.
# isinstance() is used rather than matching 'conv2d' in the layer name, which
# only works for auto-generated names.
last_conv_name = next(
    (layer.name for layer in reversed(model.layers) if isinstance(layer, Conv2D)),
    None,
)
if last_conv_name is None:
    raise ValueError("Grad-CAM needs at least one Conv2D layer, but the model has none.")
print(f"Last Conv2D: {last_conv_name}")

# Ordered layer list replayed one by one by predict_with_gradcam.
all_layers = model.layers


def predict_with_gradcam(img_batch):
    """Run one forward pass that yields both the prediction and a Grad-CAM heatmap.

    The layers in ``all_layers`` are applied one at a time so the output of the
    layer named ``last_conv_name`` can be captured on the gradient tape.

    Args:
        img_batch: array-like image batch, presumably
            ``(batch, IMG_SIZE, IMG_SIZE, 3)`` scaled to [0, 1] like the
            training generators — TODO confirm for new callers.

    Returns:
        ``(preds, heatmap)``. ``preds`` is the model output as a NumPy array,
        one row per image. ``heatmap`` is a 2-D NumPy array in [0, 1] at the
        captured conv layer's spatial resolution, explaining the top-scoring
        class of the FIRST image in the batch.

    Raises:
        RuntimeError: if ``last_conv_name`` is never reached, or no gradient
            flows back to its output.
    """
    img_tensor = tf.cast(img_batch, tf.float32)
    with tf.GradientTape() as tape:
        x = img_tensor
        conv_output = None
        for layer in all_layers:
            x = layer(x)
            if layer.name == last_conv_name:
                conv_output = x
                # Explicitly watch the captured activation so we can
                # differentiate with respect to it.
                tape.watch(conv_output)
        pred_idx = tf.argmax(x[0])
        class_score = x[:, pred_idx]

    if conv_output is None:
        raise RuntimeError(
            f"Layer {last_conv_name!r} was not found; cannot compute Grad-CAM.")
    grads = tape.gradient(class_score, conv_output)
    if grads is None:
        raise RuntimeError(
            f"No gradient reaches layer {last_conv_name!r}; cannot compute Grad-CAM.")

    # Channel importance = spatially averaged gradient of the FIRST image only,
    # so other images in a larger batch do not leak into its heatmap.
    pooled = tf.reduce_mean(grads[0], axis=(0, 1))
    heatmap = tf.squeeze(conv_output[0] @ pooled[..., tf.newaxis], axis=-1)
    # ReLU first, then normalise by the positive maximum. Normalising by the
    # raw maximum can give 0/0 = NaN when that maximum is close to -epsilon.
    heatmap = tf.nn.relu(heatmap)
    heatmap = heatmap / (tf.reduce_max(heatmap) + 1e-8)
    return x.numpy(), heatmap.numpy()


def make_overlay(img, heatmap, alpha=0.4):
    """Blend a Grad-CAM heatmap, coloured with the 'jet' colormap, onto an image.

    Args:
        img: ``(H, W, 3)`` float image, presumably in [0, 1] like ``img_arr``
            in diagnose_leaf.
        heatmap: 2-D array expected in [0, 1], at any resolution. Values
            outside that range (or NaN) are clipped/zeroed first, and the
            result is resized to the image's height and width.
        alpha: weight of the coloured heatmap in the blend; the image gets
            ``1 - alpha``.

    Returns:
        ``(H, W, 3)`` float array clipped to [0, 1].
    """
    # Clip and scrub NaN before the uint8 cast: out-of-range values would
    # otherwise wrap around instead of saturating.
    safe = np.clip(np.nan_to_num(np.asarray(heatmap), nan=0.0), 0.0, 1.0)
    h = np.uint8(255 * safe)
    h = np.array(tf.image.resize(h[..., np.newaxis], (img.shape[0], img.shape[1])))[:, :, 0]
    # Resizing returns floats; round and clip so every value is a valid index
    # into the 256-entry colour table.
    lut_idx = np.clip(np.rint(h), 0, 255).astype(int)
    jet_rgb = plt.colormaps['jet'](np.arange(256))[:, :3]
    colored = jet_rgb[lut_idx]
    return np.clip(colored * alpha + img * (1 - alpha), 0, 1)


def get_severity(confidence, heatmap, *, threshold=0.3, mild_cutoff=0.35,
                 severe_cutoff=0.65):
    """Grade severity by blending model confidence with Grad-CAM coverage.

    score = min(0.40 * confidence
                + 0.35 * (fraction of heatmap cells above ``threshold``)
                + 0.25 * (mean value of those cells), 1.0)

    The weights and cut-offs are fixed constants chosen in this function,
    not learned from data.

    Args:
        confidence: predicted-class probability.
        heatmap: array of Grad-CAM activations, expected in [0, 1]. NaN cells
            are treated as 0. An empty array counts as no affected area.
        threshold: cells strictly above this count as "affected".
        mild_cutoff: scores below this are 'Mild'.
        severe_cutoff: scores below this (and >= mild_cutoff) are 'Moderate';
            everything else is 'Severe'.

    Returns:
        ``(score, label)`` with label in {'Mild', 'Moderate', 'Severe'}.
    """
    hm = np.nan_to_num(np.asarray(heatmap), nan=0.0)
    hot = hm > threshold
    # An empty heatmap would make the mean NaN, and a NaN score falls
    # through every comparison to 'Severe'.
    affected = hot.mean() if hot.size else 0.0
    intensity = hm[hot].mean() if hot.any() else 0.0
    score = min((0.4 * confidence) + (0.35 * affected) + (0.25 * intensity), 1.0)
    if score < mild_cutoff:
        return score, 'Mild'
    if score < severe_cutoff:
        return score, 'Moderate'
    return score, 'Severe'


def _plot_diagnosis(img_arr, heatmap, title):
    """Show the image, its Grad-CAM heatmap and the blended overlay side by side."""
    overlay = make_overlay(img_arr, heatmap)
    panels = (
        (img_arr, 'Original', None),
        (heatmap, 'Grad-CAM (Red=Affected)', 'jet'),
        (overlay, 'Overlay', None),
    )
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (data, panel_title, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(panel_title)
        ax.axis('off')
    plt.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


def _print_report(info, conf, sev_label, sev_score, treatment):
    """Print the fixed-width text diagnostic report."""
    rule = "=" * 55
    print(rule)
    print("     COCONUT LEAF DIAGNOSTIC REPORT")
    print(rule)
    print(f"  Disease    : {info['full_name']}")
    print(f"  Confidence : {conf*100:.1f}%")
    print(f"  Severity   : {sev_label} ({sev_score:.2f})")
    print(f"  Info       : {info['description']}")
    print("-" * 55)
    print(f"  Treatment  : {treatment}")
    print(rule)


def diagnose_leaf(img_path, show_plot=True):
    """Diagnose one leaf image: classify, explain with Grad-CAM, grade, advise.

    Args:
        img_path: path to an image file; it is resized to IMG_SIZE x IMG_SIZE.
        show_plot: if True, display the original / heatmap / overlay figure.

    Returns:
        dict with keys 'disease' (class name), 'confidence' (float),
        'severity' ('Healthy', 'Mild', 'Moderate' or 'Severe'),
        'severity_score' (float) and 'treatment' (str).

    Raises:
        KeyError: if the predicted class has no entry in DISEASE_INFO.
    """
    img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))
    # Same 1/255 rescaling as the training/validation generators.
    img_arr = image.img_to_array(img) / 255.0
    img_batch = np.expand_dims(img_arr, axis=0)

    preds, heatmap = predict_with_gradcam(img_batch)
    idx = int(np.argmax(preds[0]))
    disease = CLASS_NAMES[idx]
    conf = float(preds[0][idx])

    if disease == 'Healthy_Leaves':
        sev_score, sev_label = 0.0, 'Healthy'
    else:
        sev_score, sev_label = get_severity(conf, heatmap)

    info = DISEASE_INFO[disease]
    # Treatments are keyed Mild/Moderate/Severe only; healthy leaves reuse 'Mild'.
    treat_key = 'Mild' if sev_label == 'Healthy' else sev_label
    treatment = info['treatment'][treat_key]

    if show_plot:
        _plot_diagnosis(img_arr, heatmap, disease)
    _print_report(info, conf, sev_label, sev_score, treatment)
    return {
        'disease': disease,
        'confidence': conf,
        'severity': sev_label,
        'severity_score': float(sev_score),
        'treatment': treatment,
    }


# Smoke test: push an all-black dummy image through the Grad-CAM pipeline so
# wiring problems surface here rather than inside diagnose_leaf.
dummy = np.zeros((1, IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
p, h = predict_with_gradcam(dummy)
if p.shape != (1, num_classes):
    raise RuntimeError(f"Unexpected prediction shape {p.shape}; expected (1, {num_classes}).")
if h.ndim != 2 or not np.all(np.isfinite(h)):
    raise RuntimeError(f"Invalid Grad-CAM heatmap: shape {h.shape}, "
                       f"all finite={bool(np.all(np.isfinite(h)))}.")
print(f"Grad-CAM test passed! Heatmap: {h.shape}")

## 12. Test on All Classes

In [None]:
# Diagnose one held-out image per class. Files are taken from the validation
# subset (val_data), which model.fit never trained on, and in the generator's
# fixed file order. This keeps the demo reproducible and not inflated by
# memorised training images.
for class_name in CLASS_NAMES:
    class_idx = val_data.class_indices[class_name]
    held_out = [path for path, label in zip(val_data.filepaths, val_data.classes)
                if label == class_idx]
    if held_out:
        print(f"\nTesting: {class_name}")
        diagnose_leaf(held_out[0])
        print()

## 13. Batch Diagnosis Table

In [None]:
import random

# Sample only from the validation subset (images model.fit never trained on),
# so the table reflects held-out behaviour. A seeded RNG keeps the sample,
# and therefore the table, identical between runs.
SAMPLES_PER_CLASS = 2
rng = random.Random(42)

idx_to_class = {idx: cls for cls, idx in val_data.class_indices.items()}
val_paths_by_class = {cn: [] for cn in CLASS_NAMES}
for path, label in zip(val_data.filepaths, val_data.classes):
    val_paths_by_class.setdefault(idx_to_class[int(label)], []).append(path)

print(f"{'Image':<30} {'Predicted':<25} {'Conf':>6} {'Severity':>10}")
print("-" * 75)
for cn in CLASS_NAMES:
    paths = val_paths_by_class[cn]
    for path in rng.sample(paths, min(SAMPLES_PER_CLASS, len(paths))):
        name = os.path.basename(path)
        r = diagnose_leaf(path, show_plot=False)
        print(f"{name:<30} {r['disease']:<25} {r['confidence']*100:>5.1f}% {r['severity']:>9}")
print("-" * 75)