<a href="https://colab.research.google.com/github/cicureton/LeafLens/blob/ai-model/LeafLens.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torchvision import models

In [None]:

# Build MobileNetV3-Large initialised from the "IMAGENET1K_V2" pretrained weights
mobilenet_v3_large = models.mobilenet_v3_large(weights="IMAGENET1K_V2")

# Drop the ImageNet classification head by swapping it for a pass-through module
mobilenet_v3_large.classifier = nn.Identity()

# Freeze every feature block except the last ten, so that only the later
# convolutional blocks receive gradient updates during fine-tuning
frozen_blocks = mobilenet_v3_large.features[:-10]
for weight in frozen_blocks.parameters():
    weight.requires_grad_(False)


Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth


100%|██████████| 21.1M/21.1M [00:00<00:00, 160MB/s]


In [None]:
import pandas as pd
from google.colab import drive

# Attach Google Drive; the dataset and all artefacts of this notebook live there
drive.mount('/content/drive', force_remount=False)

# Load the PlantCLEF metadata table that describes every plant image.
# sep=None together with engine='python' lets pandas detect the delimiter itself.
metadata_csv = '/content/drive/MyDrive/PlantCLEF2024_single_plant_training_metadata.csv'
df = pd.read_csv(metadata_csv, sep=None, engine='python')
# Show the first rows as a quick sanity check
print(df.head())


Mounted at /content/drive
                                     image_name organ  species_id      obs_id  \
0  59feabe1c98f06e7f819f73c8246bd8f1a89556b.jpg  leaf     1396710  1008726402   
1  dc273995a89827437d447f29a52ccac86f65476e.jpg  leaf     1396710  1008724195   
2  416235e7023a4bd1513edf036b6097efc693a304.jpg  leaf     1396710  1008721908   
3  cbd18fade82c46a5c725f1f3d982174895158afc.jpg  leaf     1396710  1008699177   
4  f82c8c6d570287ebed8407cefcfcb2a51eaaf56e.jpg  leaf     1396710  1008683100   

    license partner          author  altitude   latitude  longitude  \
0  cc-by-sa     NaN   Gulyás Bálint  205.9261  47.592160  19.362895   
1  cc-by-sa     NaN    vadim sigaud  323.7520  47.906703   7.201746   
2  cc-by-sa     NaN     fil escande  101.3160  48.826774   2.352774   
3  cc-by-sa     NaN  Desiree Verver    5.1070  52.190427   6.009677   
4  cc-by-sa     NaN      branebrane  165.3390  45.794739  15.965862   

   gbif_species_id           species  genus    family   data

In [None]:
# List the columns (features) available in the dataset
column_index = df.columns
print(column_index)

Index(['image_name', 'organ', 'species_id', 'obs_id', 'license', 'partner',
       'author', 'altitude', 'latitude', 'longitude', 'gbif_species_id',
       'species', 'genus', 'family', 'dataset', 'publisher', 'references',
       'url', 'learn_tag', 'image_backup_url'],
      dtype='object')


In [None]:
# Work on a manageable subset: the first 20 distinct species,
# taken in the order in which they first appear in the 'species' column
unique_species = df['species'].unique()
species_list = unique_species[:20]
for species in species_list:
    print(species)

Taxus baccata L.
Dryopteris filix-mas (L.) Schott
Roemeria hybrida (L.) DC.
Rosa corymbifera Borkh.
Potentilla reptans L.
Saponaria ocymoides L.
Rumex crispus L.
Salix alba L.
Salvia microphylla Kunth
Prunus cerasus L.
Ranunculus repens L.
Salix caprea L.
Potentilla indica (Andrews) Th.Wolf
Melilotus albus Medik.
Rhus typhina L.
Solanum nigrum L.
Primula vulgaris Huds.
Schoenoplectus lacustris (L.) Palla
Podranea ricasoliana (Tanfani) Sprague
Quercus cerris L.


In [None]:
### Build subset_metadata, the slice of the metadata this notebook works with
# Keep only the rows of the original frame (df) whose 'species' value
# is one of the species collected in species_list

in_selected_species = df['species'].isin(species_list)
subset_metadata = df[in_selected_species]

In [None]:
import os

# Destination of the subset metadata on Drive
save_path = '/content/drive/MyDrive/subset_metadata.csv'

# Write the file only once so that re-running the notebook does not overwrite it.
# The frame written is subset_metadata (the 20-species slice), not the full df.
# NOTE: an existing file is never refreshed; delete it by hand if it was produced
# from a different frame or a different species selection.
if not os.path.exists(save_path):
    subset_metadata.to_csv(save_path, index=False)
    print("File saved")
else:
    print("File already exists.")

File already exists.


In [None]:
# Inspect the subset: (rows, columns) first, then the number of distinct values
# per column (expected: 20 species, 8450 unique images)
subset_shape = subset_metadata.shape
print(subset_shape)
distinct_per_column = subset_metadata.nunique()
print(distinct_per_column)

(8450, 20)
image_name          8450
organ                  7
species_id            20
obs_id              8037
license                4
partner                3
author              6294
altitude            3744
latitude            5302
longitude           5306
gbif_species_id       20
species               20
genus                 18
family                16
dataset                2
publisher              2
references          8017
url                 8450
learn_tag              3
image_backup_url    8450
dtype: int64


In [None]:
import os
import requests
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# The images are downloaded once and stored on Drive, so later cells can read
# local files instead of fetching every image URL again

# Folder that receives the subset images (no error if it already exists)
subset_image_dir = '/content/drive/MyDrive/subset_images'
os.makedirs(subset_image_dir, exist_ok=True)

# Function to download images
def download_image(url, dest_dir="/content/drive/MyDrive/subset_images", timeout=10):
    """Download one image into ``dest_dir`` unless it is already there.

    The local file is named after the last path component of ``url``
    (e.g. ``https://host/a/b/photo-123`` is stored as ``photo-123``).

    Args:
        url: Address of the image to fetch.
        dest_dir: Folder that receives the file.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True if the file exists locally after the call (already present or
        freshly downloaded), False if the download failed.
    """
    filename = os.path.join(dest_dir, os.path.basename(url))

    # Already downloaded on a previous run: nothing to do
    if os.path.exists(filename):
        return True

    # A single request, always with a timeout, so one stalled server cannot
    # block a worker thread forever
    try:
        r = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Report and carry on; one bad URL must not abort the whole batch
        print(f"Failed to download {url}: {exc}")
        return False

    if r.status_code != 200:
        print(f"Failed to download {url}: HTTP {r.status_code}")
        return False

    # Write to a temporary name first and rename afterwards, so an interrupted
    # write never leaves a partial file that the exists-check above would
    # mistake for a finished download
    partial = filename + ".part"
    with open(partial, "wb") as f:
        f.write(r.content)
    os.replace(partial, filename)
    return True

# Collect every image URL of the subset
urls = subset_metadata["url"].tolist()

# Run download_image over all URLs on a pool of 26 worker threads.
# tqdm wraps the lazy result iterator to draw a progress bar, and list()
# drains it so that every download has finished when the cell ends.
with ThreadPoolExecutor(max_workers=26) as executor:
    pending_results = executor.map(download_image, urls)
    list(tqdm(pending_results, total=len(urls)))



100%|██████████| 8450/8450 [00:01<00:00, 4495.03it/s]


In [None]:
from re import X
from sklearn.model_selection import train_test_split
# Split the subset into train / validation / test = 60% / 20% / 20%,
# stratified by species so every split keeps the same class proportions

image_urls = subset_metadata["url"]
species_labels = subset_metadata["species"]

# Step 1: hold out 20% as the test set; the remaining 80% is train + validation
X_temp, X_test, y_temp, y_test = train_test_split(
    image_urls,
    species_labels,
    test_size=0.2,
    random_state=42,
    stratify=species_labels,
)
# Step 2: take 25% of that 80% (= 20% of the total) as the validation set
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=0.25,
    random_state=42,
    stratify=y_temp,
)


for split_name, split_urls in (("Train", X_train), ("Validation", X_val), ("Test", X_test)):
    print(f"{split_name} size: {len(split_urls)}")

Train size: 5070
Validation size: 1690
Test size: 1690


In [None]:
import os

image_folder = "/content/drive/MyDrive/subset_images"


def _to_local_paths(urls):
    """Map every URL to the file of the same basename inside image_folder."""
    return [os.path.join(image_folder, os.path.basename(u)) for u in urls]


# The splits were made on the 'url' column, so translate each URL into the path
# of the downloaded file; training on raw URLs would fail with file-not-found errors
train_images = _to_local_paths(X_train)
val_images = _to_local_paths(X_val)
test_images = _to_local_paths(X_test)

# The splits keep the row labels of the original frame, which are no longer
# sequential; renumber the label Series from 0 so they line up with the path lists
y_train, y_val, y_test = (
    split_labels.reset_index(drop=True) for split_labels in (y_train, y_val, y_test)
)


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os
# Create a custom dataset to house our training data (pytorch models needs tensors not csv/url files which the images are currently in)
class PlantDatasetLocal(Dataset):
    """Dataset of locally stored plant images paired with species labels.

    Args:
        image_paths: Sequence of image paths or URLs. Only the basename of each
            entry is used; the file is looked up inside ``image_directory``.
        labels: Species label of every image (pandas Series or any sequence),
            aligned by position with ``image_paths``.
        transform: Optional callable applied to the RGB PIL image.
        classes: Optional explicit list of class names. Pass the training
            dataset's ``classes`` when building validation/test datasets so
            that all splits share one label-to-index mapping. When omitted,
            the sorted unique labels of this dataset are used.
        image_directory: Folder that contains the image files.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length, or if
            ``classes`` is given and some label is not part of it.
    """

    def __init__(self, image_paths, labels, transform=None, classes=None,
                 image_directory="/content/drive/MyDrive/subset_images"):
        self.image_paths = image_paths
        # Accept plain lists as well as Series, and renumber from 0 so that
        # positions and index labels coincide
        self.labels = pd.Series(labels).reset_index(drop=True)
        if len(self.image_paths) != len(self.labels):
            raise ValueError(
                f"Got {len(self.image_paths)} image paths but {len(self.labels)} labels"
            )
        self.transform = transform

        if classes is None:
            self.classes = sorted(self.labels.unique())
        else:
            # Keep the caller's order: it defines the index of every class
            self.classes = list(classes)
            unknown = set(self.labels.unique()) - set(self.classes)
            if unknown:
                raise ValueError(f"Labels missing from classes: {sorted(unknown)}")
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.image_directory = image_directory

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, class_index)`` for the sample at position ``idx``."""
        # Resolve the file by basename inside the local image directory
        image_name = os.path.basename(self.image_paths[idx])
        image_path = os.path.join(self.image_directory, image_name)

        # The context manager closes the file handle once the pixels have been
        # copied into the converted image
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Positional lookup, independent of the Series' index labels
        label = self.class_to_idx[self.labels.iloc[idx]]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError
# Clean the training split: keep only images that exist and that Pillow can open
# (some downloads turned out to be missing or corrupted)
cleaned_train_images = []
cleaned_y_train = []

# Walk through all training images with a tqdm progress bar
for path, label in tqdm(zip(train_images, y_train), total=len(train_images), desc="Checking train images"):
    try:
        # The context manager closes the file handle again, so thousands of
        # descriptors are not left open. verify() is Pillow's integrity check;
        # it does not decode the pixel data, so a truncated file may still pass.
        with Image.open(path) as img:
            img.verify()
    except (UnidentifiedImageError, OSError, SyntaxError) as e:
        # OSError also covers FileNotFoundError, so missing files are skipped
        # instead of aborting the loop
        print(f"Error opening image {path}: {e} + {label}")
        continue
    cleaned_train_images.append(path)
    cleaned_y_train.append(label)

# Continue with the cleaned lists only
train_images = cleaned_train_images
y_train = cleaned_y_train


Checking train images:  19%|█▉        | 962/5070 [11:45<50:11,  1.36it/s]  


KeyboardInterrupt: 

In [None]:
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError

# Clean the validation split: keep only images that exist and that Pillow can open
cleaned_val_images = []
cleaned_y_val = []

# Try to open every image; only readable files make it into the new lists
for path, label in tqdm(zip(val_images, y_val), total=len(val_images), desc="Checking validation images"):
    try:
        # The context manager closes the file handle again. verify() is
        # Pillow's integrity check; it does not decode the pixel data, so a
        # truncated file may still pass.
        with Image.open(path) as img:
            img.verify()
    except (UnidentifiedImageError, OSError, SyntaxError) as e:
        # OSError also covers FileNotFoundError, so missing files are skipped
        # instead of aborting the loop
        print(f"Error opening image {path}: {e} + {label}")
        continue
    cleaned_val_images.append(path)
    cleaned_y_val.append(label)

# Replace original lists with cleaned ones
# NOTE(review): in the cells visible here, test_images never goes through an
# equivalent check - confirm whether the test split needs cleaning as well
val_images = cleaned_val_images
y_val = cleaned_y_val


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing shared by all splits: the model expects 224x224 inputs given as
# tensors scaled to 0-1 (instead of 0-255 pixel values) and normalized per channel
IMAGE_SIZE = (224, 224)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training transform: resize plus random augmentation (flips, small rotation,
# colour jitter) to make the model more robust
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# Evaluation transform: the same preprocessing WITHOUT random augmentation, so
# validation/test metrics are deterministic and measured on unmodified images
eval_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# The cleaning cells turned the labels into plain Python lists; convert them
# back to pandas Series with a fresh 0..n-1 index
y_train = pd.Series(y_train).reset_index(drop=True)
y_val = pd.Series(y_val).reset_index(drop=True)
y_test = pd.Series(y_test).reset_index(drop=True)


# Wrap each split in the custom dataset, which turns image files into tensors
# the model can consume (it cannot read CSV rows or image URLs directly)
train_dataset = PlantDatasetLocal(train_images, y_train, transform=transform)
val_dataset = PlantDatasetLocal(val_images, y_val, transform=eval_transform)
test_dataset = PlantDatasetLocal(test_images, y_test, transform=eval_transform)

# Every dataset derives its class-to-index mapping from its own labels; the
# indices are only comparable across splits if the class lists are identical
assert val_dataset.classes == train_dataset.classes, "validation classes differ from training classes"
assert test_dataset.classes == train_dataset.classes, "test classes differ from training classes"

# Data loaders feed the model batches of 32 images; only the training data is shuffled.
# num_workers=4 loads batches in parallel worker processes to speed things up
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [None]:
# Determine the number of unique species
num_species = len(train_dataset.classes)
print(f"Number of unique species: {num_species}")

# Give the model a classification head with one output per species
current_head = mobilenet_v3_large.classifier
if isinstance(current_head, nn.Sequential):
    # The head is still a Sequential: replace only its final Linear layer
    current_head[-1] = nn.Linear(current_head[-1].in_features, num_species)
else:
    # The head has been swapped out (the model-loading cell sets it to
    # nn.Identity, which cannot be indexed), so build a new head on top of the
    # channels produced by the last feature block
    feature_dim = mobilenet_v3_large.features[-1].out_channels
    mobilenet_v3_large.classifier = nn.Sequential(
        nn.Linear(feature_dim, 1280),
        nn.Hardswish(),
        nn.Dropout(p=0.2),
        nn.Linear(1280, num_species),
    )

In [None]:
import torch.optim as optim
from tqdm import tqdm

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mobilenet_v3_large.parameters(), lr=0.001)

# Train for 7 epochs, i.e. seven full passes over the training data
num_epochs = 7
mobilenet_v3_large.train()
for epoch in range(num_epochs):
    progress_label = f"Epoch {epoch+1}/{num_epochs}"
    running_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=progress_label):
        # Clear the gradients left over from the previous batch
        optimizer.zero_grad()

        # Forward pass: run the batch through the model and score the
        # predictions against the true labels
        loss = criterion(mobilenet_v3_large(inputs), labels)

        # Backward pass, then one optimizer step
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss of this epoch
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

In [None]:
# Persist the model's state_dict to Drive
weights_path = "/content/drive/MyDrive/mobilenet_v3_large_plants.pth"
torch.save(mobilenet_v3_large.state_dict(), weights_path)
print("Model saved")


In [None]:
from tqdm import tqdm

# Switch the model to evaluation mode
mobilenet_v3_large.eval()
correct = 0
total = 0
# Gradients are not needed for evaluation
with torch.no_grad():
    for inputs, labels in tqdm(val_loader, desc="Validating", total=len(val_loader)):
        outputs = mobilenet_v3_large(inputs)
        # The predicted class is the index of the highest score in each row
        predicted = torch.max(outputs, dim=1).indices
        total += labels.shape[0]
        correct += int((predicted == labels).sum())

print(f"Validation Accuracy: {100 * correct / total:.2f}%")
