In [None]:
import os
import torch
from PIL import Image
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer,
    AutoImageProcessor, AutoModelForImageClassification
)
import torch.nn as nn
from torchvision import transforms
import json

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [None]:
# ==============================
# Cell 1: Imports and device
# ==============================
import os
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification, BlipProcessor, BlipForConditionalGeneration, VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from torchvision import transforms
from openai import OpenAI
import json

# Same device choice as the first cell, re-evaluated so this cell can run on its own.
device = {True: "cuda", False: "cpu"}[torch.cuda.is_available()]


In [None]:
# ==============================
# Fixed Class Mapping
# ==============================
# Root of the dataset, laid out as <split>/<crop>/<disease>/<images>.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
DATASET_DIR = r"D:\T2430458\final-dataset"

# Choose one split (e.g., 'train') to build consistent classes
SPLIT_DIR = os.path.join(DATASET_DIR, "train")

# Disease-folder names (compared case-insensitively) that denote a healthy leaf.
HEALTHY_ALIASES = {"fresh leaf", "normal leaf"}


def build_class_names(split_dir, healthy_aliases=HEALTHY_ALIASES):
    """Return one "<crop>/<disease>" label per disease folder under ``split_dir``.

    Crop and disease folders are visited in sorted name order, so label indices
    are reproducible. Healthy aliases are renamed to "Healthy" in place (without
    re-sorting), so each index still corresponds to the same raw folder.

    Args:
        split_dir: directory whose sub-directories are crops, each containing
            one sub-directory per disease. Non-directory entries are skipped.
        healthy_aliases: lower-case folder names to report as "Healthy".

    Returns:
        List of label strings, in index order.

    Raises:
        FileNotFoundError: if ``split_dir`` is not an existing directory.
    """
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(f"Dataset split directory not found: {split_dir}")

    names = []
    for crop_folder in sorted(os.listdir(split_dir)):
        crop_path = os.path.join(split_dir, crop_folder)
        if not os.path.isdir(crop_path):
            continue
        for disease_folder in sorted(os.listdir(crop_path)):
            if not os.path.isdir(os.path.join(crop_path, disease_folder)):
                continue
            # Normalize healthy class names
            if disease_folder.lower() in healthy_aliases:
                disease_folder = "Healthy"
            names.append(f"{crop_folder}/{disease_folder}")
    return names


classes = build_class_names(SPLIT_DIR)

# Build mappings
id2label = {i: name for i, name in enumerate(classes)}
label2id = {v: k for k, v in id2label.items()}

# Two raw folders normalizing to the same label (e.g. both "Fresh Leaf" and
# "Normal Leaf" under one crop) leave label2id with fewer entries than id2label.
duplicate_labels = sorted({name for name in classes if classes.count(name) > 1})
if duplicate_labels:
    print("⚠️ Duplicate class labels after normalization:", duplicate_labels)

# crop_model is not created in any visible cell; it must be loaded beforehand.
if "crop_model" not in globals():
    raise NameError(
        "crop_model is not defined: load the image-classification model "
        "before running this cell."
    )

# The label mapping is only meaningful if it has one entry per model output.
model_num_labels = getattr(crop_model.config, "num_labels", len(classes))
if model_num_labels != len(classes):
    raise ValueError(
        f"crop_model has {model_num_labels} output labels but {len(classes)} "
        f"class folders were found under {SPLIT_DIR}."
    )

crop_model.config.id2label = id2label
crop_model.config.label2id = label2id

print("Number of classes:", len(classes))
print("Sample classes:", list(id2label.items())[:10])

Number of classes: 88
Sample classes: [(0, 'Apple/Alternaria Blotch'), (1, 'Apple/Black Rot'), (2, 'Apple/Brown Spot'), (3, 'Apple/Cedar Apple Rust'), (4, 'Apple/Frog Eye Leaf Spot'), (5, 'Apple/Grey Spot'), (6, 'Apple/Healthy'), (7, 'Apple/Leaf Rust'), (8, 'Apple/Mosaic Virus'), (9, 'Apple/Powdery Mildew')]


In [None]:
# ==============================
# Cell 3: Load Captioning Models
# ==============================
# Hugging Face checkpoint ids for the two captioning pipelines.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
VIT_GPT2_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

print("Loading captioning models...")

# BLIP: the processor prepares the image, and the model generates caption tokens on `device`.
blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT).to(device)

# ViT encoder + GPT-2 decoder: model, image processor and tokenizer share one checkpoint.
vit_model = VisionEncoderDecoderModel.from_pretrained(VIT_GPT2_CHECKPOINT).to(device)
vit_processor = ViTImageProcessor.from_pretrained(VIT_GPT2_CHECKPOINT)
vit_tokenizer = AutoTokenizer.from_pretrained(VIT_GPT2_CHECKPOINT)


Loading captioning models...


In [None]:
# ==============================
# Cell 4: Image transforms
# ==============================
# Per-channel normalization statistics (the commonly used ImageNet values).
# NOTE(review): confirm these match the preprocessing crop_model was trained with.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Classifier preprocessing: fixed 224x224 resize -> tensor -> per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]
)


In [None]:
# ==============================
# Cell 5: Caption helpers
# ==============================
def generate_blip_caption(image, max_length=50):
    """Caption ``image`` with the BLIP model loaded in the captioning cell.

    Args:
        image: image accepted by ``blip_processor`` (callers pass an RGB PIL image).
        max_length: maximum generated sequence length forwarded to ``generate``.

    Returns:
        The decoded caption, with special tokens removed and surrounding
        whitespace stripped (so it can be spliced into a prompt cleanly).
    """
    inputs = blip_processor(image, return_tensors="pt").to(device)
    out = blip_model.generate(**inputs, max_length=max_length)
    return blip_processor.decode(out[0], skip_special_tokens=True).strip()

def generate_vit_caption(image, max_length=50):
    """Caption ``image`` with the ViT-GPT2 encoder-decoder loaded in the captioning cell.

    Args:
        image: image accepted by ``vit_processor`` (callers pass an RGB PIL image).
        max_length: maximum generated sequence length forwarded to ``generate``.

    Returns:
        The decoded caption with special tokens removed. Surrounding whitespace
        is stripped, because the GPT-2 decode can keep leading/trailing spaces,
        which otherwise show up as double spaces in the merged GPT prompt.
    """
    pixel_values = vit_processor(images=image, return_tensors="pt").pixel_values.to(device)
    out = vit_model.generate(pixel_values, max_length=max_length)
    return vit_tokenizer.decode(out[0], skip_special_tokens=True).strip()


In [None]:
# ==============================
# Cell 6: Prediction function
# ==============================
def predict_crop_and_disease(image_path, crop_model, transform):
    """Classify one leaf image and split the predicted label into crop and disease.

    Args:
        image_path: path to an image file readable by PIL.
        crop_model: classifier whose forward pass returns an object with a
            ``.logits`` tensor and whose ``config.id2label`` maps class index
            to a label string (expected form "<crop>/<disease>").
        transform: callable turning an RGB PIL image into a tensor
            (presumably (C, H, W); a batch dimension is added here).

    Returns:
        ``(crop_name, disease_name)``. Only the first "/" separates the two
        parts. A label without "/" is returned as ``(label, "Healthy")``.
    """
    # The context manager closes the underlying file; convert() returns a new
    # image object that stays usable after the file is closed.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)

    crop_model.eval()
    with torch.no_grad():
        logits = crop_model(img_tensor).logits
    pred_idx = torch.argmax(logits, 1).item()
    class_name = crop_model.config.id2label[pred_idx]

    # Split crop/disease on the first "/" only, so an extra "/" in the label
    # cannot break the two-way unpacking.
    if "/" in class_name:
        crop_name, disease_name = class_name.split("/", 1)
    else:
        # NOTE(review): unstructured labels (e.g. a default "LABEL_3") are
        # reported as "Healthy" — kept for compatibility with existing callers.
        crop_name, disease_name = class_name, "Healthy"

    return crop_name, disease_name


In [None]:
# ==============================
# Cell 7: Build GPT input
# ==============================
def build_gpt_input(image_path, crop_model, transform, user_question=None):
    """Compose the text prompt sent to GPT for one leaf image.

    The prompt combines the classifier's crop/disease prediction with the BLIP
    and ViT-GPT2 captions of the same image.

    Args:
        image_path: path to the leaf image.
        crop_model: classifier passed through to ``predict_crop_and_disease``.
        transform: preprocessing passed through to ``predict_crop_and_disease``.
        user_question: optional free-form question. If falsy, the prompt asks
            for symptoms and cure as JSON with keys 'symptoms' and 'cure'.

    Returns:
        The prompt string.
    """
    crop_name, disease_name = predict_crop_and_disease(image_path, crop_model, transform)

    # Load the image once for both captioners; the context manager closes the
    # file, and convert() returns a new image usable after it is closed.
    with Image.open(image_path) as raw_img:
        image = raw_img.convert("RGB")
    blip_caption = generate_blip_caption(image).strip()
    vit_caption = generate_vit_caption(image).strip()
    merged_caption = f"{blip_caption} and {vit_caption}"

    context = f"The crop is {crop_name} and the disease is {disease_name}. {merged_caption}."

    if user_question:
        return f"{context} Question: {user_question}"
    return f"{context} What are the symptoms and cure for this disease? Return JSON with keys 'symptoms' and 'cure'."


In [None]:
# ==============================
# Cell 8: OpenAI GPT Integration
# ==============================
# The OpenAI key is read from the environment instead of being stored in the notebook.
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None or api_key == "":
    raise RuntimeError("OPENAI_API_KEY not set in environment.")
client = OpenAI(api_key=api_key)

def ask_gpt(final_input_string, json_mode=True, model="gpt-4o-mini", max_tokens=1000):
    """Send a prompt to the OpenAI chat-completions API and return the reply text.

    Args:
        final_input_string: the user message (e.g. built by ``build_gpt_input``).
        json_mode: if True, the system prompt demands a JSON object with keys
            'symptoms' and 'cure', and a JSON-object response format is
            requested. If False, the model answers free-form.
        model: chat model name.
        max_tokens: completion token budget. The former hard-coded 300 was
            small enough to cut longer answers (e.g. in Bangla) off mid-sentence.

    Returns:
        ``response.choices[0].message.content`` — a JSON string in json_mode.

    Warns:
        UserWarning if the reply stopped because of ``max_tokens``
        (finish_reason "length"). A truncated json_mode reply is likely not
        valid JSON.
    """
    import warnings

    if json_mode:
        system_prompt = (
            "You are an agricultural assistant that answers about crop diseases. "
            "Always return answers in JSON format with keys 'symptoms' and 'cure'."
        )
        extra_kwargs = {"response_format": {"type": "json_object"}}
    else:
        system_prompt = (
            "You are an agricultural assistant that answers any user question about crops and diseases."
        )
        extra_kwargs = {}

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": final_input_string},
        ],
        max_tokens=max_tokens,
        **extra_kwargs,
    )

    choice = response.choices[0]
    if getattr(choice, "finish_reason", None) == "length":
        warnings.warn(
            f"GPT reply was truncated at max_tokens={max_tokens}; consider raising max_tokens.",
            stacklevel=2,
        )
    return choice.message.content


In [None]:
# ==============================
# Cell 9: Example Usage
# ==============================
# Example image.
# NOTE(review): this points at "dataset", not DATASET_DIR ("final-dataset") — confirm intended.
test_img = r"d:\T2430458\dataset\Corn\Gray_leaf_spot\33333.png"

if os.path.exists(test_img):
    # JSON Mode
    gpt_input_json = build_gpt_input(test_img, crop_model, transform)
    gpt_response_json = ask_gpt(gpt_input_json, json_mode=True)

    # A reply cut off by the token limit (or an empty reply) is not valid JSON;
    # report it and fall back to "N/A" fields instead of crashing the cell.
    try:
        parsed_json = json.loads(gpt_response_json or "")
    except json.JSONDecodeError:
        print("⚠️ GPT reply was not valid JSON:", gpt_response_json)
        parsed_json = {}
    if not isinstance(parsed_json, dict):
        parsed_json = {}

    # NOTE: build_gpt_input already ran this prediction internally; it is
    # repeated because build_gpt_input only returns the prompt text.
    crop_name, disease_name = predict_crop_and_disease(test_img, crop_model, transform)

    result_json = {
        "crop": crop_name,
        "disease": disease_name,
        "symptoms": parsed_json.get("symptoms", "N/A"),
        "cure": parsed_json.get("cure", "N/A")
    }
    print("\n--- Structured JSON Mode ---")
    print(json.dumps(result_json, indent=2))

    # Free VQA Mode
    user_question = "Can this disease affect other crops like tea? Ans in Bangla"
    gpt_input_vqa = build_gpt_input(test_img, crop_model, transform, user_question=user_question)
    gpt_response_vqa = ask_gpt(gpt_input_vqa, json_mode=False)
    print("\n--- Free VQA Mode ---")
    print("Question:", user_question)
    print("Answer:", gpt_response_vqa)

else:
    print("\n⚠️ Test image not found.")



--- Structured JSON Mode ---
{
  "crop": "Corn",
  "disease": "Gray_leaf_spot",
  "symptoms": "Gray leaf spot on corn typically presents as elongated lesions with grayish to tan centers and dark borders, generally appearing on the older leaves. The leaves may show yellowing around the spots, and extensive damage can lead to premature leaf death.",
  "cure": "To manage gray leaf spot, practice crop rotation and select resistant corn hybrids. Fungicides may also be applied at the appropriate time, particularly when conditions are favorable for disease development. Ensure proper nitrogen management and good air circulation in the crop."
}

--- Free VQA Mode ---
Question: Can this disease affect other crops like tea? Ans in Bangla
Answer: গ্রে লিফ স্পট (Gray Leaf Spot) একটি রোগ যা সাধারণত ভুট্টা এবং কিছু অন্যান্য শষ্যের ওপর প্রভাব ফেলে। তবে, এটি চা গাছের (Tea plant) ওপর সরাসরি প্রভাব ফেলার কোনো প্রমাণ নেই। প্রতিটি শস্যের নিজস্ব রোগবাহক এবং পরিবেশগত চাহিদা থাকে, তাই এই রোগ চা গাছের জন্য সম