# **Multiclass Classification**

**Goal**: Use images of crop diseases from Uganda to build a convolutional neural network and train it to classify each image into one of five classes.

**Objectives**
- Convert grayscale (and other non-RGB) images to RGB
- Resize images to a common size
- Normalize the pixel values
- Create a transformation pipeline to standardize images for training
- Create a convolutional neural network
- Train the network to do multiclass classification
- Identify overfitting

In [1]:
import os

import pandas as pd 
import matplotlib 
import matplotlib.pyplot as plt 
import numpy as np 
import PIL 
import torch 
import torch.nn as nn 
import torch.optim as optim 
import torchinfo 
import torchvision 
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from torch.utils.data import DataLoader, random_split 
from torchinfo import summary 
from torchvision import datasets, transforms 
from tqdm import tqdm

In [2]:
print("torch version : ", torch.__version__)
print("torchvision version : ", torchvision.__version__)
print("torchinfo version : ", torchinfo.__version__)
print("numpy version : ", np.__version__)
print("matplotlib version : ", matplotlib.__version__)
print("PIL version : ", PIL.__version__)

!python --version

torch version :  2.5.1+cpu
torchvision version :  0.20.1+cpu
torchinfo version :  1.8.0
numpy version :  1.26.3
matplotlib version :  3.8.3
PIL version :  10.2.0
Python 3.12.1


In [3]:
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device.")

Using cpu device.


**1. Exploring Data**

In [4]:
data_dir = os.path.join("data_p2", "data_undersampled")
train_dir = os.path.join(data_dir, "train")

print("Data Directory:", data_dir)
print("Training Data Directory:", train_dir)

Data Directory: data_p2\data_undersampled
Training Data Directory: data_p2\data_undersampled\train


List the class names. Each subdirectory of the training directory holds the images of one class:

In [5]:
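# Each subdirectory of the training directory is one class; sort the names
# so they appear in the same order ImageFolder uses for its labels
classes = sorted(os.listdir(train_dir))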

print("List of classes:", classes)

List of classes: ['cassava-bacterial-blight-cbb', 'cassava-brown-streak-disease-cbsd', 'cassava-green-mottle-cgm', 'cassava-healthy', 'cassava-mosaic-disease-cmd']
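`datasets.ImageFolder` derives its integer labels from the sorted subdirectory names (exposed as `dataset.classes` and `dataset.class_to_idx`). Because this folder is named `data_undersampled`, it can also be worth checking how many images each class holds. The standard-library sketch below illustrates both ideas on a throwaway folder tree; `find_classes`, `count_images_per_class`, and the short `IMAGE_EXTENSIONS` tuple are hypothetical helpers for illustration, not the torchvision implementation.

```python
import os
import tempfile

# Hypothetical subset of extensions for this sketch; torchvision's ImageFolder
# checks its own, longer list (it also accepts e.g. .bmp and .tiff files).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def find_classes(root):
    """Return the sorted class-directory names under root and a name -> label map."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


def count_images_per_class(root):
    """Count the files with an image extension inside each class directory."""
    classes, _ = find_classes(root)
    return {
        name: sum(
            f.lower().endswith(IMAGE_EXTENSIONS)
            for f in os.listdir(os.path.join(root, name))
        )
        for name in classes
    }


# Demo on a throwaway folder tree whose class directories are created out of
# alphabetical order: the labels still follow sorted order.
with tempfile.TemporaryDirectory() as demo_root:
    for name, n_images in [("cassava-healthy", 2), ("cassava-bacterial-blight-cbb", 3)]:
        os.makedirs(os.path.join(demo_root, name))
        for i in range(n_images):
            open(os.path.join(demo_root, name, f"img_{i}.jpg"), "w").close()
    print(find_classes(demo_root)[1])
    # {'cassava-bacterial-blight-cbb': 0, 'cassava-healthy': 1}
    print(count_images_per_class(demo_root))
    # {'cassava-bacterial-blight-cbb': 3, 'cassava-healthy': 2}
```

On the real data, `count_images_per_class(train_dir)` would show whether the undersampled classes are actually balanced.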


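Custom transformation: not every image is stored in RGB mode, so `ConvertToRGB` converts any non-RGB PIL image (such as a grayscale one) to RGB, giving every image three channels: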

In [6]:
class ConvertToRGB:
    def __call__(self, img):
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img
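For a grayscale (mode `"L"`) image, converting to RGB copies the single gray value into all three channels, so the picture looks the same but now has three channels. The pure-Python sketch below uses a hypothetical `gray_to_rgb` helper as a stand-in for PIL's `convert("RGB")` on a tiny 2 × 2 image stored as nested lists. Other non-RGB modes that `ConvertToRGB` also handles (such as `"RGBA"` or palette `"P"` images) are converted differently and are not covered here.

```python
def gray_to_rgb(pixels):
    """Turn rows of 0-255 gray levels into rows of (R, G, B) tuples.

    Each gray level is copied into all three channels, which is what an
    "L" -> "RGB" conversion does.
    """
    return [[(v, v, v) for v in row] for row in pixels]


gray = [[0, 128], [255, 64]]
rgb = gray_to_rgb(gray)
print(rgb)
# [[(0, 0, 0), (128, 128, 128)], [(255, 255, 255), (64, 64, 64)]]
```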

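Normalization: the full preprocessing pipeline converts each image to RGB, resizes it to 224 × 224 pixels, converts it to a tensor with values in [0, 1], and normalizes each channel with the per-channel mean and standard deviation computed in the previous lesson: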

In [7]:
# Define the transformations to apply to each image
transform_normalized = transforms.Compose(
    [
        # Make sure every image has three channels
        ConvertToRGB(),
        # Resize every image to exactly 224 x 224 pixels
        transforms.Resize((224, 224)),
        # Convert images to tensors with values in [0, 1]
        transforms.ToTensor(),
        # Normalize each channel with the mean and std computed in the previous lesson
        transforms.Normalize(
            mean=[0.4326, 0.4953, 0.3120], std=[0.2178, 0.2214, 0.2091]
        ),
    ]
)

print(type(transform_normalized))
print("-----------------")
print(transform_normalized)

<class 'torchvision.transforms.transforms.Compose'>
-----------------
Compose(
    <__main__.ConvertToRGB object at 0x00000181FBA3B3E0>
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
    ToTensor()
    Normalize(mean=[0.4326, 0.4953, 0.312], std=[0.2178, 0.2214, 0.2091])
)
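To see what the pipeline does numerically, here is a pure-Python sketch (no torch needed) of two of its steps. `resized_output_size` is a hypothetical helper that re-implements the output-size rule documented for `transforms.Resize`: it shows why a `(224, 224)` tuple is used, since every image then has the same shape (which stacking images into a batch requires), at the cost of stretching non-square images. `normalize_pixel` and `denormalize_pixel` are hypothetical helpers that apply the `ToTensor` scaling and the `Normalize` formula to a single pixel and then invert it; the inverse is what you would need to display a normalized image with matplotlib.

```python
# Per-channel statistics copied from transform_normalized above (R, G, B).
MEAN = [0.4326, 0.4953, 0.3120]
STD = [0.2178, 0.2214, 0.2091]


def resized_output_size(height, width, size):
    """Return (new_height, new_width) following Resize's documented size rule.

    A (h, w) sequence is used as-is, so the aspect ratio is not preserved.
    An int is matched to the smaller edge and the other edge is scaled
    proportionally; truncating with int() is this sketch's assumption.
    """
    if isinstance(size, int):
        if height > width:
            return int(size * height / width), size
        return size, int(size * width / height)
    return size[0], size[1]


def normalize_pixel(rgb):
    """Mimic ToTensor (divide by 255) followed by Normalize ((x - mean) / std)."""
    return [(value / 255 - m) / s for value, m, s in zip(rgb, MEAN, STD)]


def denormalize_pixel(normalized):
    """Undo normalize_pixel, returning values on the 0-255 scale (e.g. for plotting)."""
    return [(z * s + m) * 255 for z, m, s in zip(normalized, MEAN, STD)]


print(resized_output_size(480, 640, (224, 224)))  # (224, 224): aspect ratio ignored
print(resized_output_size(480, 640, 224))         # (224, 298): aspect ratio kept
print(normalize_pixel((128, 128, 128)))           # a mid-gray pixel after normalization
```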


In [8]:
dataset = datasets.ImageFolder(root=train_dir, transform=transform_normalized)

print("Length of dataset:", len(dataset))

Length of dataset: 8180


**2. Splitting into training and validation sets**

In [9]:
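# Seed a dedicated generator so the split is reproducible
g = torch.Generator()
g.manual_seed(42)

# Randomly assign 80% of the images to training and 20% to validation
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=g)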

print("Length of training dataset:", len(train_dataset))
print("Length of validation dataset:", len(val_dataset))

Length of training dataset: 6544
Length of validation dataset: 1636
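PyTorch's documentation for `random_split` states that fractional lengths are turned into counts as `floor(frac * len(dataset))`, with any remainder distributed one item at a time, round-robin, across the splits. The pure-Python sketch below re-implements that rule in a hypothetical `split_lengths` helper, which reproduces the 6544/1636 split above. A second hypothetical helper, `split_indices`, shuffles indices with Python's `random.Random` rather than `torch.randperm`, so it will not pick the same images as `random_split(..., generator=g)`; it only illustrates that a seeded shuffle yields disjoint, reproducible splits that together cover every index.

```python
import math
import random


def split_lengths(n, fractions):
    """Turn fractions summing to 1 into integer split sizes.

    Each split gets floor(n * fraction); leftover items are handed out one
    at a time, round-robin, starting from the first split.
    """
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    lengths = [math.floor(n * f) for f in fractions]
    for i in range(n - sum(lengths)):
        lengths[i % len(lengths)] += 1
    return lengths


def split_indices(n, fractions, seed):
    """Shuffle indices 0..n-1 with a seeded RNG and cut them into disjoint splits."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for length in split_lengths(n, fractions):
        splits.append(indices[start:start + length])
        start += length
    return splits


print(split_lengths(8180, [0.8, 0.2]))  # [6544, 1636]
print(split_lengths(5, [0.5, 0.5]))     # [3, 2]: the leftover item goes to the first split
train_idx, val_idx = split_indices(8180, [0.8, 0.2], seed=42)
print(len(train_idx), len(val_idx))     # 6544 1636
```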


In [10]:
length_dataset = len(dataset)
length_train = len(train_dataset)
length_val = len(val_dataset)

percent_train = np.round(100 * length_train / length_dataset, 2)
percent_val = np.round(100 * length_val / length_dataset, 2)

print(f"Train data is {percent_train}% of full data")
print(f"Validation data is {percent_val}% of full data")

Train data is 80.0% of full data
Validation data is 20.0% of full data
