# Species classifier on PlantVillage, Cassava, Rice

### create a single species dataset

In [None]:
import os, shutil

# Root folder under which the plantVillage, cassava and riceleaf dataset
# folders are looked up by the calls at the end of this cell.
BASE = "/content/processed"
# Destination of the merged dataset: the collect_* functions copy images
# into one subfolder per species below this path.
OUTPUT = f"{BASE}/species"
# Create the destination up front; exist_ok=True keeps a re-run of the cell
# from failing when the folder is already there.
os.makedirs(OUTPUT, exist_ok=True)

def ensure_dir(path):
    """Create directory *path* (including missing parents) if it is absent.

    Calling it for a directory that already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
    """
    # exist_ok=True replaces the check-then-create pattern, which can fail
    # when the directory appears between the check and the creation.
    os.makedirs(path, exist_ok=True)


# ------------------------------------------------------------
# 1) PLANTVILLAGE -> extract species from folder names
# ------------------------------------------------------------
def collect_plantvillage(path):
    """Copy PlantVillage images into one folder per species under OUTPUT.

    The ``train`` and ``val`` subfolders of *path* are scanned. Every class
    folder inside them is mapped to a species by taking the part of its name
    before the first ``"___"`` (the whole name when there is no ``"___"``).
    All files of that class are copied into ``OUTPUT/<species>``.

    Several class folders (and both splits) land in the same species folder,
    so two source files may share a name. The first one keeps its name; any
    later one is stored as ``<split>_<class_folder>_<name>`` instead of
    overwriting the earlier copy.

    Args:
        path: Dataset root that contains the ``train`` / ``val`` folders.
            Missing split folders are skipped.
    """
    print("Processing PlantVillage...")

    # Destination file names already written in this call, per species folder.
    used_names = {}

    for split in ["train", "val"]:
        split_dir = f"{path}/{split}"
        if not os.path.isdir(split_dir):
            continue

        # Sorted so that collision renaming is the same on every run.
        for class_folder in sorted(os.listdir(split_dir)):
            full = os.path.join(split_dir, class_folder)
            if not os.path.isdir(full):
                continue

            # Extract species name (before the three underscores)
            species = class_folder.split("___")[0]
            species_dir = os.path.join(OUTPUT, species)
            os.makedirs(species_dir, exist_ok=True)
            taken = used_names.setdefault(species_dir, set())

            for img in sorted(os.listdir(full)):
                src = os.path.join(full, img)
                # shutil.copy cannot copy directories; skip anything that
                # is not a regular file.
                if not os.path.isfile(src):
                    continue

                # Keep the original name unless it was already used.
                dst_name = img
                while dst_name in taken:
                    dst_name = f"{split}_{class_folder}_{dst_name}"
                taken.add(dst_name)

                shutil.copy(src, os.path.join(species_dir, dst_name))


# ------------------------------------------------------------
# 2) SINGLE-SPECIES DATASETS (Cassava, Rice)
# ------------------------------------------------------------
def collect_single_species_with_diseases(path, species):
    """Copy every image of a single-species dataset into OUTPUT/<species>.

    The ``train`` and ``val`` subfolders of *path* are scanned. Each of their
    subfolders is a disease class, but all files end up in the same species
    folder because the species is fixed for the whole dataset.

    Files from different disease folders or splits may share a name. The
    first one keeps its name; any later one is stored as
    ``<split>_<disease_folder>_<name>`` instead of overwriting the earlier
    copy.

    Args:
        path: Dataset root that contains the ``train`` / ``val`` folders.
            Missing split folders are skipped.
        species: Name of the species folder created under OUTPUT.
    """
    print(f"Processing {species}...")

    species_dir = os.path.join(OUTPUT, species)
    os.makedirs(species_dir, exist_ok=True)

    # Destination file names already written during this call.
    taken = set()

    for split in ["train", "val"]:
        split_dir = os.path.join(path, split)
        if not os.path.isdir(split_dir):
            continue

        # Sorted so that collision renaming is the same on every run.
        for disease_folder in sorted(os.listdir(split_dir)):
            full = os.path.join(split_dir, disease_folder)
            if not os.path.isdir(full):
                continue

            for img in sorted(os.listdir(full)):
                src = os.path.join(full, img)
                # shutil.copy cannot copy directories; skip anything that
                # is not a regular file.
                if not os.path.isfile(src):
                    continue

                # Keep the original name unless it was already used.
                dst_name = img
                while dst_name in taken:
                    dst_name = f"{split}_{disease_folder}_{dst_name}"
                taken.add(dst_name)

                shutil.copy(src, os.path.join(species_dir, dst_name))


# ------------------------------------------------------------
# 3) RUN
# ------------------------------------------------------------

# PlantVillage carries the species in its class-folder names.
collect_plantvillage(f"{BASE}/plantVillage")

# The remaining datasets each cover one species, passed in explicitly.
for dataset_dir, species_name in (
    (f"{BASE}/cassava", "Cassava"),
    (f"{BASE}/riceleaf", "Rice"),
):
    collect_single_species_with_diseases(dataset_dir, species_name)

print("Unified species dataset created at:", OUTPUT)


### Create species_split
  - train
  - val
  - test

In [None]:
import os, shutil, random

# Input: one folder per species. Output: train/val/test, each with one
# folder per species.
SOURCE = "/content/processed/species"
OUTPUT = "/content/processed/species_split"

# Fractions of each species assigned to the three splits.
train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

# os.makedirs also creates OUTPUT itself as the parent of each split folder.
for split in ["train", "val", "test"]:
    os.makedirs(os.path.join(OUTPUT, split), exist_ok=True)


def ensure_dir(path):
    """Create directory *path* (including missing parents) if it is absent.

    Calling it for a directory that already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
    """
    # exist_ok=True replaces the check-then-create pattern, which can fail
    # when the directory appears between the check and the creation.
    os.makedirs(path, exist_ok=True)


# Fixed seed: the output folder is not cleared between runs, so an unseeded
# shuffle would place the same image into different splits on a re-run and
# leak training images into val/test.
SPLIT_SEED = 42

print("Splitting species into train/val/test...")

for species in sorted(os.listdir(SOURCE)):
    sp_dir = os.path.join(SOURCE, species)
    if not os.path.isdir(sp_dir):
        continue

    # Keep regular files only (shutil.copy cannot copy directories) and sort
    # them, because os.listdir returns entries in arbitrary order and the
    # seeded shuffle is only reproducible from a fixed starting order.
    images = sorted(
        name for name in os.listdir(sp_dir)
        if os.path.isfile(os.path.join(sp_dir, name))
    )
    random.Random(SPLIT_SEED).shuffle(images)

    n = len(images)
    n_train = int(train_ratio * n)
    n_val = int(val_ratio * n)

    # Everything left after train and val goes to test, so no image is
    # dropped by the integer truncation above.
    partitions = {
        "train": images[:n_train],
        "val": images[n_train:n_train + n_val],
        "test": images[n_train + n_val:],
    }

    for split, split_imgs in partitions.items():
        # The species folder is created in every split, even when empty.
        split_species_dir = os.path.join(OUTPUT, split, species)
        os.makedirs(split_species_dir, exist_ok=True)

        for img in split_imgs:
            shutil.copy(os.path.join(sp_dir, img),
                        os.path.join(split_species_dir, img))

print("Done! Split dataset created at:", OUTPUT)


### Species Classifier (ViT) - Training code

In [None]:
# ---------------------------------------------------
# Train ViT Species Classifier (Correct Final Version)
# ---------------------------------------------------

import os
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from tqdm import tqdm

# Location of the train/val/test folders and file name for the saved weights.
DATA_DIR = "/content/processed/species_split"
MODEL_NAME = "species_classifier_vit.pth"

# Input resolution and optimisation hyper-parameters.
IMG_SIZE, BATCH_SIZE = 224, 32
LR, EPOCHS = 1e-4, 12

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)


# ------------------------------
# Transforms
# ------------------------------
# Per-channel ImageNet statistics. The model in this notebook is loaded with
# ImageNet-pretrained weights ("IMAGENET1K_V1"), which expect inputs
# normalized this way; feeding raw [0, 1] tensors shifts the input
# distribution away from what the pretrained layers were trained on.
# NOTE: any inference code that loads the saved weights must apply the same
# normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation, tensor conversion, normalize.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: same preprocessing without the random augmentation.
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# The test split is preprocessed exactly like the validation split.
test_tfms = val_tfms


# ------------------------------
# Datasets
# ------------------------------
# One ImageFolder per split, each paired with its own transform pipeline.
train_ds, val_ds, test_ds = (
    datasets.ImageFolder(os.path.join(DATA_DIR, split_name), transform=split_tfms)
    for split_name, split_tfms in (
        ("train", train_tfms),
        ("val", val_tfms),
        ("test", test_tfms),
    )
)

# The class list of the training split defines the size of the output layer.
print("Detected species:", train_ds.classes)
NUM_CLASSES = len(train_ds.classes)


# ------------------------------
# DataLoaders
# ------------------------------
# Only the training loader reshuffles its samples.
train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)

# Validation and test loaders share the same non-shuffled settings.
val_dl, test_dl = (
    DataLoader(eval_ds, shuffle=False, batch_size=BATCH_SIZE)
    for eval_ds in (val_ds, test_ds)
)


# ------------------------------
# ViT Model
# ------------------------------
# Load the pretrained ViT-B/16 and swap its classification head for a fresh
# linear layer with one output per detected species.
model = models.vit_b_16(weights="IMAGENET1K_V1")
head_in_features = model.heads.head.in_features
model.heads.head = nn.Linear(head_in_features, NUM_CLASSES)
model = model.to(DEVICE)

# All parameters (backbone and new head) are optimised with Adam.
optimizer = Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()


# ------------------------------
# Training Loop
# ------------------------------
for epoch in range(EPOCHS):
    # ---- one optimisation pass over the training set ----
    model.train()
    total_loss = 0.0
    progress = tqdm(train_dl, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch_imgs, batch_labels in progress:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Sum of per-batch losses; averaged over batches when reported.
        total_loss += batch_loss.item()

    # ---- accuracy on the validation set, gradients disabled ----
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch_imgs, batch_labels in val_dl:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            hits = model(batch_imgs).argmax(1) == batch_labels
            correct += hits.sum().item()
            total += batch_labels.size(0)

    val_acc = correct / total * 100
    print(f"Epoch {epoch+1}: Train Loss={total_loss/len(train_dl):.4f} | Val Acc={val_acc:.2f}%")


# ------------------------------
# Final Test Accuracy
# ------------------------------
# Score the trained model on the held-out test split, gradients disabled.
model.eval()
correct = total = 0
with torch.no_grad():
    for batch_imgs, batch_labels in test_dl:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        hits = model(batch_imgs).argmax(1) == batch_labels
        correct += hits.sum().item()
        total += batch_labels.size(0)

print("Test Accuracy:", correct / total * 100, "%")


# ------------------------------
# Save model
# ------------------------------
# Only the weights (state dict) are written, not the model object itself.
trained_weights = model.state_dict()
torch.save(trained_weights, MODEL_NAME)
print("Saved:", MODEL_NAME)
