# 🍡🍱 Day 3 – Fruit Classification

### **Goal**

Build a **Convolutional Neural Network (CNN)** model to classify fruits such as apples, bananas, oranges, and others using the **Fruits 360 dataset**.

* **Target Accuracy:** **85%–90%** or higher on the test dataset.

---

### **Steps**

1. **Dataset Preparation**

   * Download the **Fruits 360 dataset**.
   * Arrange the dataset into **training** and **testing** folders, where each fruit has its own subfolder (e.g., `apple`, `banana`, `orange`).

2. **Data Augmentation**

   * Preprocess all images, and augment the training images to increase dataset diversity:

     * Resize all images to the same size.
     * Apply random rotations, flips, and shifts (training set only).
     * Normalize pixel values.

3. **Data Loading**

   * Load training and testing data in batches for efficient processing.
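   * Shuffle the training data every epoch so that batches mix fruit classes instead of following the class-by-class folder order; the test data does not need shuffling.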

4. **Model Building (CNN)**

   * Design a CNN with layers such as:

     * Convolutional layers to extract fruit features.
     * Pooling layers to reduce spatial dimensions.
     * Fully connected layers for classification.
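   * Output one score (logit) per fruit class; **softmax** turns these scores into class probabilities. In PyTorch, `nn.CrossEntropyLoss` applies log-softmax internally, so the model should return raw logits rather than end with a softmax layer.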

5. **Loss Function and Optimizer**

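   * Use **categorical cross-entropy loss**, since this is a multi-class task (`nn.CrossEntropyLoss` in PyTorch, which accepts integer class indices as targets).
   * Use an optimizer such as **Adam**, which usually converges faster than plain SGD with little tuning.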

6. **Training**

   * Train the CNN for several epochs.
   * Hold out part of the training set for validation, and track **training loss** and **validation accuracy** after each epoch.
   * Stop once validation accuracy reaches at least **85%**, or after a fixed maximum number of epochs.

7. **Model Evaluation**

   * Evaluate the CNN using the test dataset.
   * Measure **accuracy** and analyze which fruit classes are most often confused (for example, with a confusion matrix).

8. **Save the Model**

   * Save the trained model so it can be reused later without retraining.

9. **Improvements**

   * If accuracy is below target:

     * Add more convolutional layers.
     * Adjust learning rate or batch size.
     * Try stronger augmentation.
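
Steps 4 and 5 can be sketched in PyTorch as below. This is a minimal sketch, not the notebook's final model: the class name `FruitCNN`, the channel widths, the dropout rate, the learning rate, the placeholder `num_classes=10`, and the 100×100 input size are all assumptions. The model returns raw logits; `nn.CrossEntropyLoss` applies log-softmax internally, and `torch.softmax` is applied explicitly only when probabilities are needed.

```python
import torch
from torch import nn


class FruitCNN(nn.Module):
    """Hypothetical small CNN: three conv/pool blocks followed by a fully connected head."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        # Convolutional blocks extract features; each MaxPool2d halves height and width.
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            # Fix the spatial size so the head does not depend on the input resolution.
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        # Fully connected layers map the features to one logit per class.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),  # raw logits, no softmax here
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))


model = FruitCNN(num_classes=10)  # placeholder; use the number of fruit folders in practice
# CrossEntropyLoss combines log-softmax and negative log-likelihood.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Usage on a dummy batch of four 100x100 RGB images.
dummy = torch.randn(4, 3, 100, 100)
logits = model(dummy)
probs = torch.softmax(logits, dim=1)  # probabilities only when they are needed
loss = loss_fn(logits, torch.tensor([0, 1, 2, 3]))
```

The `AdaptiveAvgPool2d` layer is a design choice that keeps the fully connected head valid if the resize target changes; with a fixed input size, a plain `Flatten` after the last pooling layer would work too.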

```python
import torch
from torch import nn

# for computer vision
import torchvision
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import seaborn as sns
```

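```python
# Setting up constants
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
SEED = 42
torch.manual_seed(SEED)  # seed PyTorch's RNGs so initialization and shuffling are reproducible
```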

```python
# import kagglehub

# path = kagglehub.dataset_download("data/fruits")
```
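
Steps 6 to 8 can be sketched as a training loop that stops once validation accuracy reaches a target, an evaluation function that also builds a confusion matrix, and a `state_dict` save. The function names `train_one_epoch`, `evaluate`, and `fit`, the `max_epochs` limit, and the `fruit_cnn.pt` file name are hypothetical. The loop works with any model that returns logits and any `DataLoader` yielding `(images, labels)` pairs; the usage at the bottom runs on random tensors only to show the call pattern and says nothing about accuracy on Fruits 360.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one pass over the training data and return the mean training loss."""
    model.train()
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        total_samples += labels.size(0)
    return total_loss / total_samples


@torch.no_grad()
def evaluate(model, loader, num_classes, device):
    """Return (accuracy, confusion matrix); rows are true classes, columns are predictions."""
    model.eval()
    confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        for true, pred in zip(labels.tolist(), preds.tolist()):
            confusion[true, pred] += 1
    accuracy = confusion.trace().item() / confusion.sum().item()
    return accuracy, confusion


def fit(model, train_loader, val_loader, num_classes, device,
        target_acc=0.85, max_epochs=20, lr=1e-3):
    """Train until validation accuracy reaches target_acc or max_epochs is hit."""
    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(1, max_epochs + 1):
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer, device)
        val_acc, _ = evaluate(model, val_loader, num_classes, device)
        history.append((train_loss, val_acc))
        print(f"epoch {epoch}: train loss {train_loss:.4f}, val acc {val_acc:.3f}")
        if val_acc >= target_acc:
            break  # target reached, stop early
    return history


# Usage on random data (call pattern only): 60 samples of 3x8x8 inputs, 3 classes.
torch.manual_seed(42)
device = "cpu"
X, y = torch.randn(60, 3, 8, 8), torch.randint(0, 3, (60,))
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 3))
history = fit(toy_model, loader, loader, num_classes=3, device=device, max_epochs=3)
test_acc, confusion = evaluate(toy_model, loader, num_classes=3, device=device)

# Save only the weights; choose a persistent path in practice.
save_path = Path(tempfile.gettempdir()) / "fruit_cnn.pt"
torch.save(toy_model.state_dict(), save_path)
```

To reuse the saved weights, construct the same architecture and call `model.load_state_dict(torch.load(save_path))`. The largest off-diagonal entries of `confusion` point to the fruit classes that are most often mistaken for each other, and `seaborn.heatmap` can display the matrix.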