<a href="https://colab.research.google.com/github/cuducquang/ML_Project/blob/main/final_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Libraries

## This cell imports the essential Python libraries used throughout the notebook for various tasks such as file handling, data processing, image manipulation, and deep learning.

In [1]:
from google.colab import drive
import gdown
import os
import shutil
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import random
import hashlib
from PIL import Image
import tensorflow as tf
from google.colab import files

# Download Test Images

## This cell downloads a zip file (test_images.zip) containing the test images from Google Drive using the gdown library.

In [2]:
# Google Drive file ID (from the shared link)
file_test_id = "1othgf5BTO_sZYXBOWykn2OkitSCna7J6"
file_test_name = "test_images.zip"
file_test_path = "/content/" + file_test_name

# Download the archive. The file ID is passed directly (instead of a
# hand-built URL), the same way the model files are fetched further down.
downloaded_path = gdown.download(id=file_test_id, output=file_test_path, quiet=False)

# Depending on the gdown version, a failed fetch (quota exceeded, permission
# denied, ...) may return None rather than raise. Stop here in that case so the
# unzip step does not run against a missing archive.
if downloaded_path is None:
    raise RuntimeError(
        f"Failed to download {file_test_name} from Google Drive (id={file_test_id})"
    )
print(f"Downloaded: {file_test_path}")

Downloading...
From (original): https://drive.google.com/uc?id=1othgf5BTO_sZYXBOWykn2OkitSCna7J6
From (redirected): https://drive.google.com/uc?id=1othgf5BTO_sZYXBOWykn2OkitSCna7J6&confirm=t&uuid=629511c3-4e62-492d-a010-08c4f92d7e42
To: /content/test_images.zip
100%|██████████| 274M/274M [00:04<00:00, 64.1MB/s]

Downloaded: /content/test_images.zip





# Unzip Test Images

## This cell extracts the contents of the test_images.zip file into a directory named extracted_test_folder in the Colab environment.

In [3]:
!unzip -q $file_test_path -d /content/extracted_test_folder

# Download Models and Test CSV

## This cell downloads the pre-trained disease, variety, and age models and the test CSV file (test.csv, which lists the test images) from Google Drive; all of them are required for making predictions on the test images.

In [None]:
# ==== Download models from Google Drive ====
# Destination path -> Google Drive file ID. Absolute /content paths are used so
# the files land exactly where the loading cell reads them from, regardless of
# the current working directory.
drive_files = {
    "/content/disease_model.h5": "1VLPAeiCB1CnTR_SRKW7ssXhBXgqoxuUe",  # Disease model (Keras)
    "/content/variety_model.h5": "1FfyDsHhZY70FpgZbuP7gve71fQScdTFh",  # Variety model (Keras)
    "/content/age_model.pth": "1NAM2FhkpTCEuqTrbnN6wtOHOZTmaRgjD",     # Age model (PyTorch)
    "/content/test.csv": "1mL7YyoufZE12-tMuIxlks48IRkYPDeU4",          # Test CSV
}

for output_path, drive_id in drive_files.items():
    # Depending on the gdown version, a failed fetch may return None instead of
    # raising; fail fast rather than continuing with a missing file.
    if gdown.download(id=drive_id, output=output_path, quiet=False) is None:
        raise RuntimeError(
            f"Failed to download {output_path} from Google Drive (id={drive_id})"
        )

Downloading...
From (original): https://drive.google.com/uc?id=16RTT_6ULyt2rgGD0A5b-OE_WcS9su62R
From (redirected): https://drive.google.com/uc?id=16RTT_6ULyt2rgGD0A5b-OE_WcS9su62R&confirm=t&uuid=db724e87-0258-43e9-bc6c-9ff37fcd3918
To: /content/disease_model.h5
100%|██████████| 320M/320M [00:05<00:00, 58.4MB/s]
Downloading...
From: https://drive.google.com/uc?id=1FfyDsHhZY70FpgZbuP7gve71fQScdTFh
To: /content/variety_model.h5
100%|██████████| 34.7M/34.7M [00:00<00:00, 47.3MB/s]
Downloading...
From: https://drive.google.com/uc?id=1NAM2FhkpTCEuqTrbnN6wtOHOZTmaRgjD
To: /content/age_model.pth
100%|██████████| 19.5M/19.5M [00:00<00:00, 33.7MB/s]
Downloading...
From: https://drive.google.com/uc?id=1mL7YyoufZE12-tMuIxlks48IRkYPDeU4
To: /content/test.csv
100%|██████████| 52.1k/52.1k [00:00<00:00, 2.71MB/s]


'test.csv'

# Load Models and Make Predictions

## This cell loads the pre-trained disease, variety, and age models, processes the test images, makes predictions, and saves the results to a CSV file (prediction_output.csv).

In [6]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
from tensorflow.keras.models import load_model

# ==== Load models ====
# The classifiers are only used for inference via predict(), so compile=False
# skips restoring the training-time optimizer/loss/metrics configuration that
# is never used in this notebook.
disease_model = load_model('/content/disease_model.h5', compile=False)
variety_model = load_model('/content/variety_model.h5', compile=False)

class AgeRegressor(nn.Module):
    """Regress one scalar from an image combined with metadata vectors.

    Image features come from an EfficientNet-B0 backbone whose classifier is
    replaced by an identity layer. They are concatenated with a label vector,
    a variety vector and one extra scalar column, then fed through a small
    fully connected head that ends in a single output unit.

    Args:
        num_labels: Width of the label vector given to ``forward``.
        num_varieties: Width of the variety vector given to ``forward``.
    """

    def __init__(self, num_labels, num_varieties):
        super().__init__()

        # Backbone is created without pretrained weights; with the classifier
        # swapped for identity it returns its pooled features directly.
        backbone = models.efficientnet_b0(weights=None)
        backbone.classifier = nn.Identity()
        self.base = backbone

        # One column per label slot, one per variety slot, plus one scalar.
        self.metadata_dim = num_labels + num_varieties + 1

        # The head expects 1280 backbone features followed by the metadata
        # columns. Layer order is kept fixed so the Sequential indices (and
        # therefore the state_dict keys) stay the same.
        head_layers = [
            nn.Linear(1280 + self.metadata_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x, label_vec, variety_vec, grvi_scalar):
        """Run the backbone on ``x`` and the head on features + metadata.

        All four inputs are joined along dim 1 in the order: backbone
        features, ``label_vec``, ``variety_vec``, ``grvi_scalar``.

        Returns:
            The output of the head for the concatenated input.
        """
        image_features = self.base(x)
        head_input = torch.cat(
            [image_features, label_vec, variety_vec, grvi_scalar], dim=1
        )
        return self.head(head_input)

# ==== Inference configuration ====
# NOTE(review): the age model is built with 3 label slots and 17 variety slots,
# while the class lists below have 10 entries each. Any predicted label index
# >= 3 therefore yields an all-zero label vector. Presumably these widths mirror
# the age checkpoint's training setup -- confirm against the training notebook.
NUM_AGE_LABELS = 3
NUM_AGE_VARIETIES = 17
DEFAULT_GRVI = 0.5       # Default GRVI if not known
IMAGE_SIZE = (224, 224)  # Input size used for all three models
BATCH_SIZE = 32          # Images per forward pass

age_model = AgeRegressor(num_labels=NUM_AGE_LABELS, num_varieties=NUM_AGE_VARIETIES)
# map_location="cpu" lets the checkpoint load on a runtime without a GPU even if
# it was saved from CUDA tensors; the model is run on CPU below.
age_model.load_state_dict(torch.load('/content/age_model.pth', map_location="cpu"))
age_model.eval()


transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])


df = pd.read_csv('/content/test.csv')

label_classes = ['bacterial_leaf_blight', 'bacterial_leaf_streak', 'bacterial_panicle_blight', 'blast', 'brown_spot', 'dead_heart', 'downy_mildew', 'hispa', 'normal', 'tungro']
variety_classes = ['ADT45', 'AndraPonni', 'AtchayaPonni', 'IR20', 'KarnatakaPonni', 'Onthanel', 'Ponni', 'RR', 'Surya', 'Zonal']

img_folder = '/content/extracted_test_folder/test_images'


def _guarded_one_hot(class_indices, width):
    """Return a (len(class_indices), width) one-hot tensor.

    Rows whose index is >= width are left all-zero instead of raising.
    """
    encoded = torch.zeros((len(class_indices), width))
    for row, class_idx in enumerate(class_indices):
        if class_idx < width:
            encoded[row, int(class_idx)] = 1.0
    return encoded


label_preds = []
variety_preds = []
age_preds = []

image_ids = df['image_id'].tolist()

# Predict in batches: calling predict() once per image pays the per-call setup
# cost thousands of times.
for start in range(0, len(image_ids), BATCH_SIZE):
    batch_ids = image_ids[start:start + BATCH_SIZE]

    # convert() returns a new in-memory image, so the file handle can be closed
    # right away by the context manager.
    batch_images = []
    for image_id in batch_ids:
        with Image.open(os.path.join(img_folder, image_id)) as raw_image:
            batch_images.append(raw_image.convert("RGB"))

    # For disease/variety model (Keras - expect Numpy)
    keras_batch = np.stack(
        [np.array(img.resize(IMAGE_SIZE)) / 255.0 for img in batch_images]
    )

    # For age model (PyTorch - expect Tensor)
    torch_batch = torch.stack([transform(img) for img in batch_images])

    # === Predict label ===
    label_idx = np.argmax(disease_model.predict(keras_batch, verbose=0), axis=1)
    label_preds.extend(label_classes[i] for i in label_idx)

    # === Predict variety ===
    variety_idx = np.argmax(variety_model.predict(keras_batch, verbose=0), axis=1)
    variety_preds.extend(variety_classes[i] for i in variety_idx)

    # === Predict age ===
    label_onehot = _guarded_one_hot(label_idx, NUM_AGE_LABELS)
    variety_onehot = _guarded_one_hot(variety_idx, NUM_AGE_VARIETIES)
    grvi_scalar = torch.full((len(batch_ids), 1), DEFAULT_GRVI)

    with torch.no_grad():
        age_batch = age_model(torch_batch, label_onehot, variety_onehot, grvi_scalar)
    age_preds.extend(int(round(value)) for value in age_batch.reshape(-1).tolist())

# Every row of the CSV must have exactly one prediction of each kind.
assert len(label_preds) == len(variety_preds) == len(age_preds) == len(df)

# ==== Save to CSV ====
df["label"] = label_preds
df["variety"] = variety_preds
df["age"] = age_preds
df.to_csv('/content/prediction_output.csv', index=False)
print("prediction_output.csv generated.")



prediction_output.csv generated.
