# Disease Detection (CNN on PlantVillage)

In [1]:
import splitfolders

input_folder = "PlantVillage"

splitfolders.ratio(input_folder, output="PlantVillage_split", seed=42, ratio=(.8, .2))


Copying files: 20639 files [00:06, 3184.34 files/s]


In [2]:
train_dir = "PlantVillage_split/train"
val_dir = "PlantVillage_split/val"


In [3]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Image generators
datagen = ImageDataGenerator(rescale=1./255)

train_data = datagen.flow_from_directory(
    train_dir,
    target_size=(128,128),
    batch_size=32,
    class_mode='categorical'
)

val_data = datagen.flow_from_directory(
    val_dir,
    target_size=(128,128),
    batch_size=32,
    class_mode='categorical'
)

# Simple CNN model
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(128,128,3)),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(train_data.num_classes, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_data, validation_data=val_data, epochs=3)

model.save("disease_model.h5")
print("Model saved as disease_model.h5")


Found 16504 images belonging to 16 classes.
Found 4135 images belonging to 16 classes.
Epoch 1/3


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


[1m516/516[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 143ms/step - accuracy: 0.6267 - loss: 1.1742 - val_accuracy: 0.7613 - val_loss: 0.7022
Epoch 2/3
[1m516/516[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 153ms/step - accuracy: 0.8459 - loss: 0.4665 - val_accuracy: 0.8655 - val_loss: 0.4005
Epoch 3/3
[1m516/516[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m80s[0m 155ms/step - accuracy: 0.9169 - loss: 0.2550 - val_accuracy: 0.8450 - val_loss: 0.4744




Model saved as disease_model.h5


# Crop Recommendation (ML)

In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import pickle

# Load dataset
df = pd.read_csv("Crop_recommendation.csv")

X = df.drop("label", axis=1)
y = df["label"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

pickle.dump(model, open("crop_model.pkl", "wb"))


Accuracy: 0.9931818181818182


# Market Prices (Simple Lookup)

In [9]:
import pandas as pd

# Load dataset
df = pd.read_csv("Price_Agriculture_commodities_Week.csv")

print(df.head())
print(df.columns)


     State District    Market              Commodity      Variety Grade  \
0  Gujarat   Amreli  Damnagar  Bhindi(Ladies Finger)       Bhindi   FAQ   
1  Gujarat   Amreli  Damnagar                Brinjal        Other   FAQ   
2  Gujarat   Amreli  Damnagar                Cabbage      Cabbage   FAQ   
3  Gujarat   Amreli  Damnagar            Cauliflower  Cauliflower   FAQ   
4  Gujarat   Amreli  Damnagar      Coriander(Leaves)    Coriander   FAQ   

  Arrival_Date  Min Price  Max Price  Modal Price  
0   27-07-2023     4100.0     4500.0       4350.0  
1   27-07-2023     2200.0     3000.0       2450.0  
2   27-07-2023     2350.0     3000.0       2700.0  
3   27-07-2023     7000.0     7500.0       7250.0  
4   27-07-2023     8400.0     9000.0       8850.0  
Index(['State', 'District', 'Market', 'Commodity', 'Variety', 'Grade',
       'Arrival_Date', 'Min Price', 'Max Price', 'Modal Price'],
      dtype='object')


In [10]:
df = df.rename(columns={
    "Commodity": "commodity",
    "Arrival_Date": "date",
    "Modal Price": "modal_price"
})

df["date"] = pd.to_datetime(df["date"], format="%d-%m-%Y")

In [11]:
def get_price_trend(crop_name, state=None, district=None):
    crop_data = df[df["commodity"].str.lower() == crop_name.lower()]
    
    if state:
        crop_data = crop_data[df["State"].str.lower() == state.lower()]
    if district:
        crop_data = crop_data[df["District"].str.lower() == district.lower()]
    
    if crop_data.empty:
        return None
    
    crop_data = crop_data.sort_values("date")
    return crop_data[["date", "State", "District", "Market", "modal_price"]].tail(7)


# Wrap the models into functions

In [12]:
import tensorflow as tf
import pickle
import pandas as pd

# Load disease detection model
disease_model = tf.keras.models.load_model("disease_model.h5")

# Load crop recommendation model
crop_model = pickle.load(open("crop_model.pkl", "rb"))

# Load price dataset
df_price = pd.read_csv("Price_Agriculture_commodities_Week.csv")
df_price = df_price.rename(columns={
    "Commodity": "commodity",
    "Arrival_Date": "date",
    "Modal Price": "modal_price"
})
df_price["date"] = pd.to_datetime(df_price["date"], format="%d-%m-%Y")




# Functions

In [13]:
# --- Plant Disease Prediction ---
import numpy as np
from tensorflow.keras.preprocessing import image

def predict_disease(img_path):
    img = image.load_img(img_path, target_size=(128,128))
    img_array = image.img_to_array(img)/255.0
    img_array = np.expand_dims(img_array, axis=0)
    
    prediction = disease_model.predict(img_array)
    class_index = np.argmax(prediction)
    class_labels = list(train_data.class_indices.keys())  # from earlier ImageDataGenerator
    return class_labels[class_index]


# --- Crop Recommendation ---
def recommend_crop(N, P, K, temperature, humidity, ph, rainfall):
    features = [[N, P, K, temperature, humidity, ph, rainfall]]
    prediction = crop_model.predict(features)
    return prediction[0]


# --- Price Trend ---
def get_price_trend(crop_name, state=None, district=None):
    crop_data = df_price[df_price["commodity"].str.lower() == crop_name.lower()]
    
    if state:
        crop_data = crop_data[df_price["State"].str.lower() == state.lower()]
    if district:
        crop_data = crop_data[df_price["District"].str.lower() == district.lower()]
    
    if crop_data.empty:
        return None
    
    crop_data = crop_data.sort_values("date")
    return crop_data[["date", "State", "District", "Market", "modal_price"]].tail(7)


In [14]:
# Disease detection
print("Disease:", predict_disease("PlantVillage_split/val/Potato___healthy/0be9d721-82f5-42c3-b535-7494afe01dbe___RS_HL 1814.JPG"))

# Crop recommendation
print("Recommended crop:", recommend_crop(90, 42, 43, 25, 80, 6.5, 200))

# Price trend
print(get_price_trend("Cabbage", state="Gujarat", district="Amreli"))


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 48ms/step
Disease: Potato___healthy
Recommended crop: rice
            date    State District    Market  modal_price
2     2023-07-27  Gujarat   Amreli  Damnagar       2700.0
312   2023-07-28  Gujarat   Amreli  Damnagar       1850.0
5677  2023-07-29  Gujarat   Amreli  Damnagar       1350.0
8964  2023-07-30  Gujarat   Amreli  Damnagar       2250.0
9806  2023-07-31  Gujarat   Amreli  Damnagar       2000.0
14768 2023-08-01  Gujarat   Amreli     Dhari       2400.0
22913 2023-08-02  Gujarat   Amreli  Damnagar       2350.0


  crop_data = crop_data[df_price["State"].str.lower() == state.lower()]
  crop_data = crop_data[df_price["District"].str.lower() == district.lower()]


# Chatbot

In [21]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import tensorflow as tf
import pickle
import pandas as pd
import numpy as np
from PIL import Image
import io
import os

# Plant disease CNN
disease_model = tf.keras.models.load_model("disease_model.h5")

# Get class names dynamically from training folder
train_dir = "PlantVillage_split/train"
train_data_classes = sorted(os.listdir(train_dir))
print(f"Loaded {len(train_data_classes)} disease classes.")

# Crop recommendation model
crop_model = pickle.load(open("crop_model.pkl", "rb"))

# Market prices
df_price = pd.read_csv("Price_Agriculture_commodities_Week.csv")
df_price = df_price.rename(columns={"Commodity": "commodity", 
                                    "Arrival_Date": "date", 
                                    "Modal Price": "modal_price"})
df_price["date"] = pd.to_datetime(df_price["date"], format="%d-%m-%Y")


def predict_disease_from_file(file):
    img = Image.open(file).convert("RGB").resize((128,128))
    img_array = np.array(img)/255.0
    img_array = np.expand_dims(img_array, axis=0)
    pred = disease_model.predict(img_array)
    return train_data_classes[np.argmax(pred)]

def recommend_crop(N, P, K, temp, humidity, ph, rainfall):
    features = [[N, P, K, temp, humidity, ph, rainfall]]
    return crop_model.predict(features)[0]

def get_price_trend(crop_name, state=None, district=None):
    crop_data = df_price[df_price["commodity"].str.lower() == crop_name.lower()]
    if state:
        crop_data = crop_data[df_price["State"].str.lower() == state.lower()]
    if district:
        crop_data = crop_data[df_price["District"].str.lower() == district.lower()]
    if crop_data.empty:
        return None
    crop_data = crop_data.sort_values("date")
    return crop_data[["date", "State", "District", "Market", "modal_price"]].tail(7)

out = widgets.Output()

# Module selection
module_dropdown = widgets.Dropdown(
    options=["Plant Disease Detection", "Crop Recommendation", "Market Prices"],
    description="Module:"
)

# Disease detection widgets
upload = widgets.FileUpload(accept='image/*', multiple=False)
disease_button = widgets.Button(description="Predict Disease")

# Crop recommendation widgets
N_input = widgets.IntText(description="N")
P_input = widgets.IntText(description="P")
K_input = widgets.IntText(description="K")
temp_input = widgets.FloatText(description="Temp")
humidity_input = widgets.FloatText(description="Humidity")
ph_input = widgets.FloatText(description="pH")
rainfall_input = widgets.FloatText(description="Rainfall")
crop_button = widgets.Button(description="Recommend Crop")

# Market price widgets
crop_name_input = widgets.Text(description="Crop")
state_input = widgets.Text(description="State")
district_input = widgets.Text(description="District")
price_button = widgets.Button(description="Get Prices")

# Display
def on_module_change(change):
    clear_output(wait=True)
    display(module_dropdown, out)
    with out:
        clear_output()
        if change['new'] == "Plant Disease Detection":
            display(upload, disease_button)
        elif change['new'] == "Crop Recommendation":
            display(N_input, P_input, K_input, temp_input, humidity_input, ph_input, rainfall_input, crop_button)
        elif change['new'] == "Market Prices":
            display(crop_name_input, state_input, district_input, price_button)

def on_disease_click(b):
    with out:
        clear_output(wait=True)
        if upload.value:
            uploaded_file = upload.value[0]  # ipywidgets v8+ returns tuple
            file_obj = io.BytesIO(uploaded_file.content)  # use .content
            result = predict_disease_from_file(file_obj)
            print(f"Predicted Disease: {result}")
        else:
            print("Please upload an image.")

def on_crop_click(b):
    with out:
        clear_output(wait=True)
        result = recommend_crop(N_input.value, P_input.value, K_input.value, temp_input.value, humidity_input.value, ph_input.value, rainfall_input.value)
        print(f"Recommended Crop: {result}")

def on_price_click(b):
    with out:
        clear_output(wait=True)
        result = get_price_trend(crop_name_input.value, state_input.value, district_input.value)
        if result is None:
            print("No data available for this crop/location.")
        else:
            print(result)

module_dropdown.observe(on_module_change, names='value')
disease_button.on_click(on_disease_click)
crop_button.on_click(on_crop_click)
price_button.on_click(on_price_click)

display(module_dropdown, out)
on_module_change({'new': module_dropdown.value})


Dropdown(description='Module:', options=('Plant Disease Detection', 'Crop Recommendation', 'Market Prices'), v…

Output()