<a href="https://www.kaggle.com/code/ghousiah/saved-dataset-multi-family-plant-classification?scriptVersionId=246082204" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input directory and print the full path of every file in it.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for file_path in (os.path.join(dirname, name) for name in filenames):
        print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# **Start from here if you already have the "datasetforfivefamiliesplant" dataset**

### **Parameter Settings**

In [None]:
import os
import pandas as pd
from PIL import Image
import requests
from io import BytesIO
from tqdm import tqdm

In [None]:
# Side length (in pixels) that each quadrat mosaic is resized to by quadrat_transform.
img_size = 96
# Number of quadrat samples per batch in the train/validation DataLoaders.
batch_size = 16

In [None]:
# Load the metadata table; later cells read its 'family', 'label' and 'cached_path' columns.
df = pd.read_csv(
    "/kaggle/input/datasetforfivefamiliesplant/filtered_cached.csv"
)
# Preview the first rows.
df.head()

In [None]:
# Manage labels: map every family name to the integer label stored next to it in the
# CSV, instead of re-deriving indices by enumerating the sorted family names.
label_to_idx = {family: label for family, label in zip(df['family'], df['label'])}
#-----------------------------------------------------------------------

.

### **Prepare Training and Validation Dataset**

In [None]:
# Split dataset into Train and Validation Set

from sklearn.model_selection import train_test_split

# df is your full dataframe
# Hold out 20% of the rows for validation. Stratifying on 'label' keeps the class
# proportions equal in both splits; the fixed random_state makes the split repeatable.
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df['label'], random_state=42)



#------------------------------------------------------

In [None]:
# Sanity check

# --- Training split ---
# Relative frequency of each label; should mirror the validation split because the
# split above was stratified on 'label'.
print("Training label distribution:")
print(train_df['label'].value_counts(normalize=True))
print("")


# Absolute size of the split and per-family sample counts.
print(f"Training samples after filtering: {len(train_df)}")
print(f"Unique families: {train_df['family'].unique()}")
print("Number of samples per family:")
print(train_df['family'].value_counts())

# --- Validation split ---
print("\nValidation label distribution:")
print(val_df['label'].value_counts(normalize=True))


# Absolute size of the split and per-family sample counts.
print(f"Validation samples after filtering: {len(val_df)}")
print(f"Unique families: {val_df['family'].unique()}")
print("Number of samples per family:")
print(val_df['family'].value_counts())

In [None]:
from PIL import Image
import requests
from io import BytesIO
from torch.utils.data import Dataset
from PIL import Image
import torch
import random

class PlantClefDataset(torch.utils.data.Dataset):
    """Single-plant image dataset backed by a dataframe of cached image paths.

    Every item is an ``(image, one_hot_label)`` pair. The image is read from the
    row's ``cached_path`` (relative to the Kaggle dataset directory) and the
    row's ``label`` is the index set to 1 in the one-hot vector.

    Args:
        df: Dataframe providing the ``cached_path`` and ``label`` columns.
            Defaults to the module-level ``df`` as bound when the class is defined.
        transform: Optional callable applied to the loaded PIL image.
        label_to_idx: Mapping of family name -> class index; its length is used
            as ``num_classes``. Defaults to the module-level ``label_to_idx``.
    """

    def __init__(self, df=df, transform=None, label_to_idx=label_to_idx):
        # Re-index from 0 so that dataset positions can be looked up with .loc.
        self.df = df.reset_index(drop=True)
        self.transform = transform

        # Forward and reverse label mappings for the families.
        self.label_to_idx = label_to_idx
        self.idx_to_label = {index: family for family, index in label_to_idx.items()}
        self.num_classes = len(self.label_to_idx)

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, one_hot_label)`` for row ``idx``.

        If the image cannot be opened, the failure is printed and a white
        224x224 placeholder image is used instead.
        """
        relative_path = self.df.loc[idx, 'cached_path']
        class_index = self.df.loc[idx, 'label']
        full_path = os.path.join('/kaggle/input/datasetforfivefamiliesplant/', relative_path)

        try:
            image = Image.open(full_path).convert('RGB')
        except Exception as e:
            print(f"⚠️ Failed to open image: {full_path} — {e}")
            # Best-effort fallback so one unreadable file does not abort the run.
            image = Image.new('RGB', (224, 224), (255, 255, 255))

        image = self.transform(image) if self.transform else image

        # One-hot encode the class index.
        label_onehot = torch.zeros(self.num_classes)
        label_onehot[class_index] = 1.0
        return image, label_onehot


#--------------------------------------------------------------------------------------------------
from torchvision import transforms

# Base transform for single-plant images: force a fixed 400x400 size so the patches
# can be tiled into a mosaic, then convert to a tensor (ToTensor scales to [0, 1]).
plant_transform = transforms.Compose(
    [transforms.Resize((400, 400)), transforms.ToTensor()]
)


In [None]:
# Checks - check if Class work properly

# Build the single-plant datasets and print the same summary for each split.
plant_train = PlantClefDataset(train_df, transform=plant_transform)
plant_val = PlantClefDataset(val_df, transform=plant_transform)

for split_name, split_ds in (("Training", plant_train), ("Validation", plant_val)):
    # Class count and the family -> index mapping the dataset was built with.
    print(f"Number of classes in {split_name}: {split_ds.num_classes}")
    print("Classes:", split_ds.label_to_idx)

    # Size of the split and how its samples are spread over the families.
    print(f"Total samples after filtering: {len(split_ds.df)}")
    print(f"Unique families: {split_ds.df['family'].unique()}")
    print("Number of samples per family:")
    print(split_ds.df['family'].value_counts())


# check if image can be loaded
import matplotlib.pyplot as plt

# Visualize N images from dataset (no normalization involved)
def visualize_dataset(dataset, title='Sample Images', num_images=5):
    """Plot the first few images of ``dataset`` in one row, titled with their labels.

    Args:
        dataset: Indexable dataset whose items are ``(image_tensor, one_hot_label)``
            and which exposes an ``idx_to_label`` mapping. The image tensor is
            treated as channels-first with values in [0, 1] (no normalization).
        title: Figure-level title.
        num_images: Maximum number of images to show. Fewer are shown when the
            dataset is smaller, instead of indexing past its end.
    """
    count = min(num_images, len(dataset))

    plt.figure(figsize=(15, 3))
    for i in range(count):
        img_tensor, label_onehot = dataset[i]

        # CHW -> HWC, then scale [0, 1] floats to 8-bit pixel values for imshow.
        img = (img_tensor.permute(1, 2, 0).numpy() * 255).astype('uint8')

        # Recover the class name from the position of the 1 in the one-hot vector.
        label_name = dataset.idx_to_label[label_onehot.argmax().item()]

        plt.subplot(1, count, i + 1)
        plt.imshow(img)
        plt.title(label_name)
        plt.axis('off')
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# Show a row of sample images from each split as a visual check of the loading code.
visualize_dataset(dataset=plant_train, title="Training Samples")
visualize_dataset(dataset=plant_val, title="Validation Samples")


### **Preprocess Dataset - Prepare Quadrat Images for Training and Validation**

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torch
import random
from torch.utils.data import DataLoader

class QuadratDataset(Dataset):
    """Synthetic multi-label dataset that tiles random single-plant images into a mosaic.

    Each item draws ``quad_size`` random items from ``base_dataset``, pastes
    them row by row onto a near-square grid, and returns the mosaic together
    with a multi-hot label that is the union of the patches' one-hot labels.

    Args:
        base_dataset: Dataset yielding ``(image_tensor, one_hot_label)`` and
            exposing ``num_classes``. All images are assumed to have the same size.
        quad_size: Number of patches per mosaic (4 gives the 2x2 layout).
            Other values are laid out on the smallest grid that holds them;
            unused cells stay black.
        transform: Optional callable applied to the stitched PIL mosaic.
        seed: If given, item ``idx`` always draws the same patches (useful for
            a reproducible validation set). If ``None`` (default), patches are
            drawn from the global ``random`` state on every access.

    Raises:
        ValueError: If ``quad_size`` is smaller than 1.
    """

    def __init__(self, base_dataset, quad_size=4, transform=None, seed=None):
        if quad_size < 1:
            raise ValueError(f"quad_size must be >= 1, got {quad_size}")

        self.base = base_dataset
        self.quad_size = quad_size  # e.g. 4 for a 2x2 grid
        self.transform = transform
        self.num_classes = self.base.num_classes
        self.seed = seed

        # Smallest grid that fits quad_size patches: cols = ceil(sqrt(n)),
        # rows = ceil(n / cols). For quad_size=4 this is the 2x2 layout.
        cols = 1
        while cols * cols < quad_size:
            cols += 1
        self._grid_cols = cols
        self._grid_rows = -(-quad_size // cols)

    def __len__(self):
        """Return the number of mosaics per epoch (base length // quad_size)."""
        return len(self.base) // self.quad_size

    def __getitem__(self, idx):
        """Return ``(mosaic, multi_hot_label)``; patches are sampled at random."""
        # A per-index generator makes the sample reproducible when a seed is set.
        rng = random if self.seed is None else random.Random(self.seed + idx)
        last_index = len(self.base) - 1
        to_pil = transforms.ToPILImage()

        patches, labels = [], []
        for _ in range(self.quad_size):
            img, lbl = self.base[rng.randint(0, last_index)]
            patches.append(to_pil(img))  # back to PIL for stitching
            labels.append(lbl)

        # All patches are assumed to share the size of the first one.
        w, h = patches[0].size
        quad = Image.new('RGB', (w * self._grid_cols, h * self._grid_rows))
        for position, patch in enumerate(patches):
            row, col = divmod(position, self._grid_cols)
            quad.paste(patch, (col * w, row * h))

        if self.transform:
            quad = self.transform(quad)

        # Multi-label target: a class is present if any patch carries it.
        ml_label = torch.stack(labels).bool().any(dim=0).float()

        return quad, ml_label


# Transform applied to each stitched mosaic: shrink to the model input size, convert
# to a tensor (ToTensor scales to [0, 1]), then normalize each channel.
quadrat_transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


In [None]:
# Single-plant datasets (each item: resized image tensor + one-hot label)
train_base = PlantClefDataset(train_df, transform=plant_transform)
val_base = PlantClefDataset(val_df, transform=plant_transform)

# Quadrat mosaics (4 patches each) built on top of the single-plant datasets
quadrat_train = QuadratDataset(base_dataset=train_base, quad_size=4, transform=quadrat_transform)
quadrat_val = QuadratDataset(base_dataset=val_base, quad_size=4, transform=quadrat_transform)

# Batch loaders; only the training loader shuffles
train_loader = DataLoader(
    dataset=quadrat_train, batch_size=batch_size, shuffle=True, num_workers=4
)
val_loader = DataLoader(
    dataset=quadrat_val, batch_size=batch_size, shuffle=False, num_workers=4
)

In [None]:
# Checks
# Number of mosaics per epoch in each split (base dataset length // quad_size).
print(f"Total samples in quadrat_train: {len(quadrat_train)}")
print(f"Total samples in quadrat_val: {len(quadrat_val)}")

In [None]:
# Checks
import matplotlib.pyplot as plt
import torchvision
import numpy as np

# Get one batch from the training loader for a visual check.
# train_loader shuffles, so a different batch comes back each time this cell runs.
images, labels = next(iter(train_loader))  # or val_loader

# De-normalize helper (reverse normalization)
def denormalize(tensor):
    """Reverse the per-channel normalization on a channels-last (H, W, 3) image.

    Uses the same mean/std values as the ``Normalize`` step in
    ``quadrat_transform`` and returns ``tensor * std + mean``.
    """
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3)
    rescaled = tensor * channel_std
    return rescaled + channel_mean

# Plot up to 8 images from the batch (2 rows x 4 columns) with their label names.
# Capping at the batch length avoids an IndexError on batches smaller than 8.
num_shown = min(8, len(images))

plt.figure(figsize=(16, 8))
for i in range(num_shown):
    img = images[i].permute(1, 2, 0).detach().cpu()  # CHW -> HWC
    img = denormalize(img).numpy()
    img = np.clip(img, 0, 1)  # keep values in the range imshow accepts for floats

    # Positions of the active classes; flatten() always yields a list,
    # even when exactly one class is active.
    label_idxs = torch.nonzero(labels[i]).flatten().tolist()

    # Convert label indices to family names
    label_names = [train_loader.dataset.base.idx_to_label[idx] for idx in label_idxs]

    plt.subplot(2, 4, i + 1)
    plt.imshow(img)
    plt.title(", ".join(label_names), fontsize=9)
    plt.axis('off')

plt.tight_layout()
plt.show()


# **MODEL**

In [None]:
import torch, timm
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
import numpy as np

In [None]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained backbone with a classifier head sized to the number of families.
model = timm.create_model('efficientnet_lite0', pretrained=True, num_classes=quadrat_train.num_classes)

# Freeze every parameter, then re-enable gradients for the classifier head only.
model.requires_grad_(False)
model.get_classifier().requires_grad_(True)

model = model.to(device)

# Multi-label objective: independent sigmoid + binary cross-entropy per class.
loss_fn = nn.BCEWithLogitsLoss()
# Only the parameters left trainable above are handed to the optimizer.
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)


In [None]:
# Checks

from torchinfo import summary

# Layer-by-layer summary, including which parameters are trainable. The input size
# follows img_size so the summary stays correct if the mosaic resolution changes.
summary(model=model,
        input_size=(1, 3, img_size, img_size),  # match the size produced by quadrat_transform
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

# **TRAINING**

In [None]:
import time 
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import copy

def train_model(model, train_loader, val_loader, optimizer, loss_fn, device, epochs=10, save_path='best_model.pth'):
    """Train ``model``, validate after every epoch, and keep the best weights.

    After each epoch the model is scored with ``evaluate_model``. Whenever the
    validation F1 improves, the weights are saved to
    ``best_model_epoch_<n>.pth`` and to ``save_path``, and a full checkpoint is
    written to ``checkpoint_epoch_<n>.pth``. The last state is always saved to
    ``final_checkpoint.pth``. A train/validation loss plot is shown at the end.

    Args:
        model: Network to train; updated in place.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        optimizer: Optimizer already bound to the trainable parameters.
        loss_fn: Loss called as ``loss_fn(outputs, labels)``.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        save_path: File that always holds the best weights seen so far. Pass a
            falsy value to skip writing it.

    Returns:
        None. On return ``model`` holds the weights of the best-F1 epoch, or
        its initial weights if the validation F1 never rose above 0.
    """
    train_losses = []
    val_losses = []
    best_f1 = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    start_time = time.time()  # total wall-clock time

    for epoch in range(epochs):
        epoch_start = time.time()  # per-epoch wall-clock time
        model.train()
        total_loss = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Evaluate on the validation loader
        val_loss, val_f1 = evaluate_model(model, val_loader, loss_fn, device)
        val_losses.append(val_loss)

        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{epochs} | Time: {epoch_time:.2f}s | Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f}")

        # Keep and persist the weights whenever validation F1 improves
        if val_f1 > best_f1:
            best_f1 = val_f1
            # Deep copy: state_dict() returns references to the live tensors.
            best_model_wts = copy.deepcopy(model.state_dict())

            # Epoch-stamped copy of the weights
            best_model_filename = f"best_model_epoch_{epoch+1}.pth"
            torch.save(best_model_wts, best_model_filename)

            # Stable "best so far" copy at the caller-supplied path
            if save_path:
                torch.save(best_model_wts, save_path)

            # Full checkpoint (model + optimizer + epoch) for resuming
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': best_model_wts,
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_f1': val_f1,
            }, f"checkpoint_epoch_{epoch+1}.pth")

            print(f"✅ Best model saved as '{best_model_filename}'")

    # Save final checkpoint (last state, before the best weights are restored)
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'final_checkpoint.pth')

    print("✅ Final model checkpoint saved as 'final_checkpoint.pth'")

    total_time = time.time() - start_time
    print(f"⏱️ Total training time: {total_time / 60:.2f} minutes")

    # Plot loss history
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training vs Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Load best model weights into model for further use
    model.load_state_dict(best_model_wts)



def evaluate_model(model, val_loader, loss_fn, device, threshold=0.5):
    """Compute the mean validation loss and sample-averaged multi-label F1.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        val_loader: Iterable of ``(images, labels)`` batches with multi-hot labels.
        loss_fn: Loss called as ``loss_fn(outputs, labels)``.
        device: Device the batches are moved to.
        threshold: Sigmoid probability above which a class counts as predicted.

    Returns:
        ``(mean_loss, f1)``: the loss averaged over batches and the F1 score
        computed with ``average='samples'``.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0
    all_preds, all_trues = [], []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)  # already multi-hot from the dataset

            outputs = model(images)
            loss = loss_fn(outputs, labels)
            total_loss += loss.item()

            probs = torch.sigmoid(outputs).cpu().numpy()
            preds = (probs > threshold).astype(int)

            all_preds.append(preds)
            all_trues.append(labels.cpu().numpy())

    if not all_preds:
        raise ValueError("val_loader yielded no batches; nothing to evaluate")

    y_pred = np.vstack(all_preds)
    y_true = np.vstack(all_trues)
    # zero_division=0 matches evaluate_quadrats: a sample with no predicted class
    # scores 0 without emitting a warning on every epoch.
    f1 = f1_score(y_true, y_pred, average='samples', zero_division=0)
    return total_loss / len(val_loader), f1



In [None]:
# Run the training loop for 25 epochs; the model is updated in place.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
    epochs=25,
)

In [None]:
# train_model() reloads the best-validation weights before returning, so the state
# saved here is the best model, not necessarily the last epoch's.
torch.save(model.state_dict(), 'final_model.pth')
print("💾 Final model saved as 'final_model.pth'.")

# NOTE: this overwrites the 'final_checkpoint.pth' that train_model() wrote.
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 25,  # number of epochs actually trained; keep in sync with the train_model() call
    # add any other training info here
}, 'final_checkpoint.pth')


# **EVALUATION**

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np
import torch

def evaluate_quadrats(model, quad_loader, idx_to_label, threshold=0.5, device='cuda'):
    """Run ``model`` over ``quad_loader`` and report sample-averaged metrics.

    Args:
        model: Network to evaluate; switched to eval mode.
        quad_loader: Iterable of ``(images, labels)`` batches with multi-hot labels.
        idx_to_label: Mapping of class index -> class name.
        threshold: Sigmoid probability above which a class counts as predicted.
        device: Device the batches are moved to.

    Returns:
        ``(pred_class_names, y_pred, y_true, true_class_names)``: the name lists
        hold one list of class names per sample, and the arrays hold the stacked
        binary predictions and targets. F1, precision and recall are printed.
    """
    def names_for(vector):
        # Names of the classes whose entry in the binary vector equals 1.
        return [idx_to_label[pos] for pos, flag in enumerate(vector) if flag == 1]

    model.eval()
    pred_bin_all, true_bin_all = [], []
    pred_class_names_all, true_class_names_all = [], []

    with torch.no_grad():
        for images, labels in quad_loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            batch_pred = (torch.sigmoid(logits).cpu().numpy() > threshold).astype(int)
            batch_true = labels.cpu().numpy()

            # Collect one row per sample
            pred_bin_all.extend(batch_pred)
            true_bin_all.extend(batch_true)

            # Human-readable class names for predictions and ground truth
            pred_class_names_all.extend(names_for(row) for row in batch_pred)
            true_class_names_all.extend(names_for(row) for row in batch_true)

    # Metrics
    y_pred = np.vstack(pred_bin_all)
    y_true = np.vstack(true_bin_all)

    shared = {'average': 'samples', 'zero_division': 0}
    f1 = f1_score(y_true, y_pred, **shared)
    precision = precision_score(y_true, y_pred, **shared)
    recall = recall_score(y_true, y_pred, **shared)

    print(f"\n📊 Evaluation Metrics:")
    print(f"✅ F1 Score   : {f1:.4f}")
    print(f"✅ Precision  : {precision:.4f}")
    print(f"✅ Recall     : {recall:.4f}\n")

    return pred_class_names_all, y_pred, y_true, true_class_names_all


In [None]:
# Evaluate on the validation quadrats, keeping both class names and binary vectors.
pred_class_names, pred_bin, true_bin, true_class_names = evaluate_quadrats(
    model=model,
    quad_loader=val_loader,
    idx_to_label=plant_val.idx_to_label,
    threshold=0.5,
    device=device,
)

In [None]:
# Print predicted vs. ground-truth class names for the first N evaluated samples
N = 5
for i, (predicted, actual) in enumerate(zip(pred_class_names[:N], true_class_names)):
    print(f"🟩 Sample {i+1}")
    print("✅ Predicted :", predicted)
    print("🎯 Ground Truth:", actual)
    print("-" * 50)


In [None]:
import matplotlib.pyplot as plt
import torch

def denormalize(tensor):
    """Map a normalized channels-last (H, W, 3) image back via ``tensor * std + mean``."""
    # Row 0 holds the per-channel means, row 1 the per-channel standard deviations.
    stats = torch.tensor([[0.485, 0.456, 0.406],
                          [0.229, 0.224, 0.225]])
    mean, std = stats[0].view(1, 1, 3), stats[1].view(1, 1, 3)
    return tensor * std + mean

# Show N validation mosaics with the model's prediction and the ground truth.
#
# QuadratDataset draws random patches on every access, so quadrat_val[i] is NOT the
# mosaic that produced pred_class_names[i] / true_class_names[i] during evaluation.
# To keep image, prediction and ground truth consistent, each mosaic is sampled once
# here, its own label is used as ground truth, and the model is run on that image.
N = 5
idx_to_label = plant_val.idx_to_label
model.eval()

for i in range(min(N, len(quadrat_val))):
    # One sample: normalized CHW image and its multi-hot label
    img_tensor, true_vec = quadrat_val[i]

    # Predict on exactly this image (same 0.5 threshold as the evaluation above)
    with torch.no_grad():
        probs = torch.sigmoid(model(img_tensor.unsqueeze(0).to(device)))[0].cpu()
    pred_names = [idx_to_label[j] for j, p in enumerate(probs.tolist()) if p > 0.5]
    true_names = [idx_to_label[j] for j, flag in enumerate(true_vec.tolist()) if flag == 1]

    # Convert to HWC and denormalize for display
    img = denormalize(img_tensor.permute(1, 2, 0))

    # Clamp values to [0, 1] in case of slight numerical overshoot
    img = img.clamp(0, 1).numpy()

    # Plot
    plt.figure(figsize=(5, 5))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Pred: {', '.join(pred_names)}\n True: {', '.join(true_names)}")
    plt.show()

    print(f"🟩 Sample {i+1}")
    print("✅ Predicted :", pred_names)
    print("🎯 Ground Truth:", true_names)
    print("-" * 50)
