In [None]:
from typing import Literal, Optional
from numpy.typing import NDArray

from pathlib import Path
import yaml
import torch
import pprint
import numpy as np
import pandas as pd
import imageio.v3 as iio
import matplotlib.pyplot as plt

from torchvision.models import resnet18
from torchvision.transforms import v2 as T
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Alternative layout that keeps data and experiments under the home directory:
# DATASET_DIR = Path.home() / "datasets" / "plant_village"
# EXPERIMENTS_DIR = Path.home() / "experiments" / "plant_village_classification" / "plant_disease_tuning"

# Active layout: both folders are looked up next to the current working directory.
DATASET_DIR = Path.cwd().joinpath("plant_village")
EXPERIMENTS_DIR = Path.cwd().joinpath("plant_disease_tuning")

# Fail fast (AssertionError) when either folder is missing, dataset folder first.
assert all(map(Path.is_dir, (DATASET_DIR, EXPERIMENTS_DIR)))

In [None]:
class PlantDiseaseClassification(Dataset):
    """Image-classification dataset built from a directory tree of leaf photos.

    Images are discovered recursively under ``root``. The class of an image is
    derived from the name of its parent directory (see ``rename_class_name``)
    and must be one of ``CLASS_NAMES``. Every image is assigned to a ``train``,
    ``val`` or ``test`` split, and an instance serves the rows of the requested
    split as ``(image, label_idx, df_idx)`` items.
    """

    # Position in this tuple is the integer label returned by __getitem__.
    # NOTE(review): 'groundhut-healthy' looks like a typo of 'groundnut', but the
    # entries must match the (normalised) on-disk folder names — confirm against
    # the dataset before renaming.
    CLASS_NAMES = (
        'apple-apple_scab',
        'apple-black_rot',
        'apple-cedar_apple_rust',
        'apple-healthy',
        'banana-bbs',
        'banana-bbw',
        'banana-healthy',
        'blueberry-healthy',
        'cherry_(including_sour)-healthy',
        'cherry_(including_sour)-powdery_mildew',
        'corn_(maize)-cercospora_leaf_spot gray_leaf_spot',
        'corn_(maize)-common_rust_',
        'corn_(maize)-healthy',
        'corn_(maize)-northern_leaf_blight',
        'grape-black_rot',
        'grape-esca_(black_measles)',
        'grape-healthy',
        'grape-leaf_blight_(isariopsis_leaf_spot)',
        'groundhut-healthy',
        'groundnut-deficient',
        'groundnut-early_leaf_spot',
        'groundnut-early_rust',
        'groundnut-late_leaf_spot',
        'groundnut-rust',
        'orange-haunglongbing_(citrus_greening)',
        'paddy-blight',
        'paddy-healthy',
        'paddy-smut',
        'paddy_brownspot-paddy_brownspot',
        'peach-bacterial_spot',
        'peach-healthy',
        'pepper_bell-bacterial_spot',
        'pepper_bell-healthy',
        'potato-early_blight',
        'potato-healthy',
        'potato-late_blight',
        'raspberry-healthy',
        'soybean-healthy',
        'squash-powdery_mildew',
        'strawberry-healthy',
        'strawberry-leaf_scorch',
        'tomato-bacterial_spot',
        'tomato-early_blight',
        'tomato-healthy',
        'tomato-late_blight',
        'tomato-leaf_mold',
        'tomato-septoria_leaf_spot',
        'tomato-spider_mites two-spotted_spider_mite',
        'tomato-target_spot',
        'tomato-tomato_mosaic_virus',
        'tomato-tomato_yellow_leaf_curl_virus'
    )
    NUM_CLASSES = len(CLASS_NAMES)

    # Number of images per class that get_equal_samples_df reserves for testing.
    TEST_SAMPLES_PER_CLASS = 100

    # Applied to every image: tensor image -> float32 (scale=True) -> 224x224.
    DEFAULT_IMAGE_TRANSFORM = T.Compose([
        T.ToImage(),
        T.ToDtype(torch.float32, scale = True),
        T.Resize((224,224)),
    ])
    # Applied on top of the image transform for the "train" split only; no-op by default.
    DEFAULT_COMMON_TRANSFORM = T.Compose([
        T.Identity(),
    ])

    NAME = "plant_village"
    TASK = "classification"

    def __init__(
            self,
            root: Path,
            df: Optional[pd.DataFrame] = None,
            split: Literal["train", "val", "trainval", "test", "all"] = "all",
            val_split: float = 0.1,
            test_split: float = 0.2,
            random_seed: int = 69,
            image_transform: Optional[T.Transform] = None,
            common_transform: Optional[T.Transform] = None,
            **kwargs
            ) -> None:
        """Index the dataset and select the rows belonging to ``split``.

        Args:
            root: Directory that contains one sub-directory per class.
            df: Optional pre-built index. It must provide the ``image_path``
                (relative to ``root``), ``label_idx`` and ``split`` columns that
                the rest of the class reads. When it is not a DataFrame the
                index is built with ``get_equal_samples_df``.
            split: Which rows to expose; ``"trainval"`` is train plus val and
                ``"all"`` keeps every row.
            val_split: Per-class fraction moved to ``val`` when the index is built.
            test_split: Forwarded to the index builder. The default builder
                (``get_equal_samples_df``) ignores it and takes a fixed number
                of test images per class instead.
            random_seed: Seed used for the per-class sampling.
            image_transform: Transform for every image; defaults to
                ``DEFAULT_IMAGE_TRANSFORM``.
            common_transform: Extra transform for the ``"train"`` split;
                defaults to ``DEFAULT_COMMON_TRANSFORM``.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``split`` is not one of the supported names.
        """
        assert split in ("train", "val", "trainval", "test", "all"), "invalid split"
        _experiment_kwargs = {
            "random_seed": random_seed,
            "val_split": val_split,
            "test_split": test_split
        }

        self.root = root
        self.split = split
        self.df = df if isinstance(df, pd.DataFrame) else self.get_equal_samples_df(root, **_experiment_kwargs)
        # Keep the full index with plain-string relative paths.
        self.df = self.df.assign(image_path = lambda df: df["image_path"].apply(str))
        # df_idx remembers each row's position in self.df, so items of a split
        # can be traced back to the full index.
        self.split_df = (self.df
                         .assign(df_idx = lambda df: df.index)
                         .pipe(self.subset_df, split)
                         .pipe(self.prefix_root, self.root))
        # Compare against None explicitly: a valid transform object may be falsy
        # (e.g. an empty container) and must not be replaced by the default.
        self.image_transform = image_transform if image_transform is not None else self.DEFAULT_IMAGE_TRANSFORM
        self.common_transform = common_transform if common_transform is not None else self.DEFAULT_COMMON_TRANSFORM

    def __len__(self) -> int:
        """Return the number of images in the selected split."""
        return len(self.split_df)

    def __getitem__(self, idx) -> tuple[torch.Tensor, int, int]:
        """Load one sample as ``(image, label_idx, df_idx)``.

        The image is read from disk, reduced or expanded to three channels,
        passed through ``image_transform`` and, for the ``"train"`` split only,
        through ``common_transform``.
        """
        row = self.split_df.iloc[idx]
        image = iio.imread(row["image_path"]).squeeze()
        if image.ndim == 2:
            # Grayscale file: there is no channel axis to inspect, so replicate
            # the single plane into three channels instead of failing on shape[2].
            image = np.stack((image,) * 3, axis = -1)
        elif image.shape[2] == 4:
            # Drop the alpha channel.
            image = image[:, :, :3]
        image = self.image_transform(image)
        if self.split == "train":
            image = self.common_transform(image)
        return image, row["label_idx"], row["df_idx"]

    def __repr__(self) -> str:
        """Summarise the dataset name, task, location, classes and split size."""
        return '\n'.join([
            f"{self.NAME} dataset for {self.TASK}",
            f"local @ [{self.root}]",
            f"with {len(self.CLASS_NAMES)} classes: {self.CLASS_NAMES}",
            f"and {len(self)} images loaded under {self.split} split",
        ])

    @classmethod
    def _index_images(cls, root: Path) -> pd.DataFrame:
        """List every ``*.JPG`` / ``*.jpg`` / ``*.PNG`` file under ``root`` with its label.

        Columns: ``name`` (path as found), ``image_path`` (``<class dir>/<file>``),
        ``label_str`` and ``label_idx``. ``CLASS_NAMES.index`` raises ``ValueError``
        for a directory whose normalised name is not a known class.
        """
        df = pd.DataFrame({"name": list(root.rglob("*.JPG")) + list(root.rglob("*.jpg")) + list(root.rglob("*.PNG"))})
        df["image_path"] = df["name"].apply(lambda x: Path(x.parent.stem, x.name)) # type: ignore
        df["label_str"] = df["name"].apply(lambda x: cls.rename_class_name(x.parent.stem))
        df["label_idx"] = df["label_str"].apply(lambda x: cls.CLASS_NAMES.index(x))
        return df

    @staticmethod
    def _sample_per_class(df: pd.DataFrame, random_seed: int, **sample_kwargs) -> pd.DataFrame:
        """Sample rows from every ``label_str`` group, re-seeding for each group.

        Each group is sampled with the same ``random_state``, exactly like a
        ``groupby(...).apply(lambda x: x.sample(...))`` would, but without
        relying on the version-specific ``include_groups`` keyword of
        ``groupby.apply``. Original row indices are preserved.
        """
        picks = [
            group.sample(random_state = random_seed, axis = 0, **sample_kwargs)
            for _, group in df.groupby("label_str")
        ]
        # No groups (empty index): return an empty frame with the same columns.
        return pd.concat(picks) if picks else df.iloc[0:0]

    @staticmethod
    def _merge_splits(train: pd.DataFrame, val: pd.DataFrame, test: pd.DataFrame) -> pd.DataFrame:
        """Concatenate the labelled splits into one index sorted by ``image_path``."""
        return (pd.concat([train, val, test])
                .sort_values("image_path")
                .reset_index(drop = True)
                .drop("name", axis = 1))

    @classmethod
    def get_stratified_samples_df(cls, root: Path, val_split: float, test_split: float, random_seed: int) -> pd.DataFrame:
        """Build the index with per-class fractional test and val splits.

        ``test_split`` of every class goes to ``test``; ``val_split`` of what
        remains in every class goes to ``val``; the rest is ``train``.
        """
        df = cls._index_images(root)

        test = cls._sample_per_class(df, random_seed, frac = test_split).assign(split = "test")
        val = (cls._sample_per_class(df.drop(test.index, axis = 0), random_seed, frac = val_split)
               .assign(split = "val"))
        train = (df
                 .drop(test.index, axis = 0)
                 .drop(val.index, axis = 0)
                 .assign(split = "train"))

        return cls._merge_splits(train, val, test)

    @classmethod
    def get_equal_samples_df(cls, root: Path, val_split: float, test_split: float, random_seed: int) -> pd.DataFrame:
        """Build the index with the same number of test images for every class.

        ``TEST_SAMPLES_PER_CLASS`` images of every class go to ``test``;
        ``val_split`` of what remains in every class goes to ``val``; the rest
        is ``train``. ``test_split`` is accepted only for signature parity with
        ``get_stratified_samples_df`` and is not used. pandas raises
        ``ValueError`` when a class holds fewer images than requested.
        """
        df = cls._index_images(root)

        test = (cls._sample_per_class(df, random_seed, n = cls.TEST_SAMPLES_PER_CLASS)
                .assign(split = "test"))
        val = (cls._sample_per_class(df.drop(test.index, axis = 0), random_seed, frac = val_split)
               .assign(split = "val"))
        train = (df
                 .drop(test.index, axis = 0)
                 .drop(val.index, axis = 0)
                 .assign(split = "train"))

        return cls._merge_splits(train, val, test)

    @staticmethod
    def rename_class_name(filename: str) -> str:
        """Normalise a class directory name to ``<plant>-<disease>``.

        The name is lower-cased and split on ``'__'``; commas and one trailing
        underscore are removed from the plant part and one leading underscore
        from the disease part, e.g. ``"Pepper,_bell___Bacterial_spot"`` becomes
        ``"pepper_bell-bacterial_spot"``.
        """
        filename = filename.lower()
        splits = filename.split('__')
        plant_name = splits[0].replace(',', '').removesuffix('_')
        disease_name = splits[-1].removeprefix('_')
        return f"{plant_name}-{disease_name}"

    def subset_df(self, df: pd.DataFrame, split: str) -> pd.DataFrame:
        """Return the rows of ``df`` that belong to ``split``.

        ``"all"`` returns ``df`` unchanged; every other value filters on the
        ``split`` column and resets the index.
        """
        if split == "all":
            return df
        elif split == "trainval":
            return df[(df["split"] == "train") | (df["split"] == "val")].reset_index(drop=True)
        return df[df["split"] == split].reset_index(drop=True)

    def prefix_root(self, df: pd.DataFrame, root: Path) -> pd.DataFrame:
        """Return a copy of ``df`` whose ``image_path`` values are ``str(root / path)``."""
        return df.assign(image_path = lambda df: df["image_path"].apply(lambda x: str(root/x)))

def plot_batch(images, labels, dataset) -> None:
    """Show a batch of images on a near-square grid, each titled with its class name.

    ``images`` are indexed per sample and drawn channel-last after clamping to
    [0, 1]; ``labels`` index into ``dataset.CLASS_NAMES``. Grid cells beyond the
    batch size stay empty, and every cell has its axes hidden.

    Raises:
        AssertionError: If the batch holds fewer than two images.
    """
    count = len(images)
    assert count >= 2, "invalid n, n must be at least 2"

    # Rows follow the rounded square root; columns are whatever is needed to fit.
    row_count = int(np.around(np.sqrt(count)))
    col_count = int(np.ceil(count / row_count))

    figure, grid = plt.subplots(row_count, col_count, figsize = (16, 16))
    figure.suptitle(f"{dataset.NAME}_{dataset.TASK}")

    cells = grid.ravel()
    for position in range(count):
        cell = cells[position]
        cell.imshow(images[position].permute(1,2,0).clamp(0, 1))
        cell.set_title(dataset.CLASS_NAMES[labels[position]], fontsize = 10)
    for cell in cells:
        cell.axis("off")

    plt.tight_layout()

def get_checkpoints_list(experiments_dir: Path) -> list:
    """Return the sorted checkpoint paths in ``experiments_dir/"model_ckpts"``.

    The file named ``last.ckpt`` is left out. The stems of the returned paths
    are also shown with ``display`` so the caller can see what is available.
    """
    ckpt_dir = experiments_dir / "model_ckpts"
    # Compare the full file name: the stem of "last.ckpt" is "last", so a
    # comparison of the stem against "last.ckpt" would never exclude anything.
    ckpts = sorted(p for p in ckpt_dir.iterdir() if p.name != "last.ckpt")
    display([p.stem for p in ckpts])
    return ckpts
    
def get_model(experiments_dir: Path, model:torch.nn.Module, ckpt_index: int = 1):
    """Build a model and load one of the experiment's checkpoints into it.

    Args:
        experiments_dir: Experiment folder passed to ``get_checkpoints_list``.
        model: Despite the annotation, this is called as a factory:
            ``model(num_classes=...)`` must return the network to populate
            (e.g. ``torchvision.models.resnet18``).
        ckpt_index: Position of the checkpoint to load within the sorted list
            returned by ``get_checkpoints_list``. Defaults to 1 (the second entry).

    Returns:
        The instantiated network with the checkpoint weights loaded, in eval mode.

    Raises:
        IndexError: If ``ckpt_index`` is out of range for the available checkpoints.
    """
    ckpt_paths = get_checkpoints_list(experiments_dir)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
    # CPU-only machine; the freshly built model lives on the CPU anyway.
    ckpt = torch.load(ckpt_paths[ckpt_index], map_location = "cpu")
    # The saved keys carry a "model." prefix that the bare network does not use.
    model_weights = {k.removeprefix("model."): v for k, v in ckpt["state_dict"].items()}
    m = model(num_classes = PlantDiseaseClassification.NUM_CLASSES)
    m.load_state_dict(model_weights)
    m.eval()
    return m

def load_one_batch(dataloader: DataLoader):
    """Return the first batch produced by a fresh iterator over ``dataloader``.

    Raises:
        StopIteration: If the loader yields nothing.
    """
    batch_iterator = iter(dataloader)
    first_batch = next(batch_iterator)
    return first_batch

def predict_one_batch(model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Run ``model`` on ``images`` and return per-sample class probabilities.

    The softmax is taken over the last dimension, so that each row of a
    ``(batch, classes)`` output sums to 1. Normalising over dimension 0 would
    instead mix the samples of the batch with each other and can change which
    class has the highest score in a row.

    The forward pass runs without gradient tracking; the returned tensor does
    not require grad.
    """
    with torch.no_grad():
        logits = model(images)
    return torch.softmax(logits, dim = -1)

def get_top_k_logits(logit: torch.Tensor, k: int) -> list[tuple[int, float]]:
    """Return the ``k`` largest entries of a 1-D score tensor as ``(index, value)`` pairs.

    Pairs are ordered from the largest value to the smallest; entries with
    equal values keep their original index order. When ``k`` exceeds the number
    of entries, all of them are returned.
    """
    indexed_scores = ((index, score.item()) for index, score in enumerate(logit))
    # sorted() is stable, also with reverse=True, so ties stay in index order.
    ranked = sorted(indexed_scores, key = lambda pair: pair[1], reverse = True)
    return ranked[:k]

def plot_predictions(images, labels, logits, dataset, experiments_dir, top_k: int = 10, keep_open: bool = False) -> None:
    """Save one diagnostic figure per sample to ``experiments_dir/"inference"``.

    Each figure shows the image (clamped to [0, 1], drawn channel-last) above a
    bar chart of the ``top_k`` highest scores, with the true and the top-scoring
    class in the title. Files are named ``<position in batch>.png``.

    Args:
        images: Per-sample image tensors.
        labels: Per-sample true class indices into ``dataset.CLASS_NAMES``.
        logits: Per-sample 1-D score tensors (one score per class).
        dataset: Object exposing ``CLASS_NAMES``.
        experiments_dir: Folder under which ``inference`` is created if missing.
        top_k: Number of bars to draw; fewer are drawn when a score vector is shorter.
        keep_open: Leave the figures open after saving (e.g. to have a notebook
            render them inline). By default each figure is closed once saved so
            that a large batch does not accumulate open figures in memory.
    """
    inference_dir = experiments_dir / "inference"
    inference_dir.mkdir(parents = True, exist_ok = True)

    samples = zip(images, labels, logits)
    for idx, (image, label, logit) in tqdm(enumerate(samples), total = len(images), desc = "Generating Plots"):
        top_logits = get_top_k_logits(logit, top_k)
        top_classes = tuple(i for i, _ in top_logits)
        top_probs = tuple(p for _, p in top_logits)
        # One bar per returned score, so fewer than top_k classes still plots.
        bar_positions = tuple(range(len(top_logits)))

        fig = plt.figure(layout = "tight", figsize = (5, 5))
        try:
            gs = plt.GridSpec(nrows = 2, ncols = 1, figure = fig, height_ratios=[3,1])
            image_ax = fig.add_subplot(gs[0, 0])
            hist_ax = fig.add_subplot(gs[1, 0])

            fig.suptitle(f"True: {dataset.CLASS_NAMES[label]} ({label}) :: Pred: {dataset.CLASS_NAMES[top_classes[0]]} ({top_classes[0]})", fontsize = 10)
            image_ax.imshow(image.clamp(0, 1).permute(1,2,0))
            image_ax.axis("off")
            hist_ax.bar(x = bar_positions, height = top_probs)
            # Bars are ranked by score; tick labels show the class index of each bar.
            hist_ax.set_xticks(ticks = bar_positions, labels = top_classes)
            hist_ax.autoscale(True, "both", True)
            hist_ax.grid(True, "both", "y")

            # Save through the figure itself rather than the implicit current figure.
            fig.savefig(inference_dir/f"{idx}.png")
        finally:
            if not keep_open:
                plt.close(fig)

In [None]:
# Evaluation settings.
RANDOM_SEED = 42
# NOTE: no `df` is passed below, so the dataset builds its index with
# get_equal_samples_df, which takes a fixed number of test images per class and
# does not use the test fraction. The value is still forwarded for completeness.
TESTING_SPLIT = 0.2
VALIDATION_SPLIT = 0.1
BATCH_SIZE = 32

dataset = PlantDiseaseClassification(
    root = DATASET_DIR,
    split = "test",
    random_seed=RANDOM_SEED,
    test_split=TESTING_SPLIT,
    val_split=VALIDATION_SPLIT,
)
dataloader = DataLoader(
    dataset = dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    # Seed the shuffling so the evaluated batch (and its accuracy) is the same
    # on every Restart & Run All.
    generator = torch.Generator().manual_seed(RANDOM_SEED),
)
model = get_model(
    experiments_dir = EXPERIMENTS_DIR,
    model = resnet18,
)
# Uncomment to inspect the loaded architecture:
#display(model)

In [None]:
# Summarise the dataset, then score and plot a single batch.
display(dataset)

# Each item is (image, label_idx, df_idx); the index column is not needed here.
images, labels, _ = load_one_batch(dataloader)
predictions = predict_one_batch(model, images)

# Share of samples whose highest-scoring class equals the true label.
print(f"Accuracy over Batch: {((labels == predictions.argmax(1)).sum() / len(labels))*100}%")

# Only CLASS_NAMES is read from the dataset argument, so the class itself is enough.
plot_predictions(
    images = images,
    labels = labels,
    logits = predictions,
    dataset = PlantDiseaseClassification,
    experiments_dir = EXPERIMENTS_DIR,
)