In [1]:
import os
import shutil
from pathlib import Path 
from typing import Tuple, Dict, List, Union
from zipfile import ZipFile

import torch # type: ignore
import torchvision # type: ignore
from PIL import Image # type: ignore
from torch.utils.data import Dataset # type: ignore
from torchvision import transforms # type: ignore

# Seed torch's global RNG so torch-based randomness (e.g. RandomHorizontalFlip and
# DataLoader shuffling) repeats from one Restart & Run All to the next.
SEED = 42
torch.manual_seed(SEED)
torch.__version__

'2.4.0'

In [2]:
DIR_PATH = Path("data")
IMAGE_PATH = DIR_PATH / "spoiled-fresh"
ZIP_PATH = DIR_PATH / "spoiled-fresh.zip"

# Only a fully extracted dataset ever appears at IMAGE_PATH. The archive is unpacked
# into a temporary sibling folder that is renamed into place once extraction finishes.
# Without this, an interrupted extractall() would leave a partial IMAGE_PATH behind.
# The is_dir() guard would then skip re-extraction on every later run.
if IMAGE_PATH.is_dir():
    print("data already exist!")
else: 
    print("extracting data..")
    partial_path = IMAGE_PATH.with_name(IMAGE_PATH.name + ".partial")
    shutil.rmtree(partial_path, ignore_errors=True)  # leftovers of an interrupted run
    with ZipFile(file=ZIP_PATH, mode="r") as zip_ref:
        zip_ref.extractall(partial_path)
    partial_path.rename(IMAGE_PATH)
    # delete the archive only once the extracted folder is in place
    os.remove(ZIP_PATH)
    print("[INFO] done unzipping the file")

data already exist!


In [3]:
# NOTE(review): debugging check of the concrete Path class; safe to delete
IMAGE_PATH.__class__

pathlib.WindowsPath

In [4]:
# Rebind IMAGE_PATH to the folder that directly holds the class sub-folders,
# then list its entries alphabetically.
IMAGE_PATH = DIR_PATH.joinpath("spoiled-fresh", "FRUIT-16K")
sorted(os.listdir(IMAGE_PATH))

['F_Banana',
 'F_Lemon',
 'F_Lulo',
 'F_Mango',
 'F_Orange',
 'F_Strawberry',
 'F_Tamarillo',
 'F_Tomato',
 'S_Banana',
 'S_Lemon',
 'S_Lulo',
 'S_Mango',
 'S_Orange',
 'S_Strawberry',
 'S_Tamarillo',
 'S_Tomato']

In [5]:
# create custom dataset
def find_classes(directory: Union[str, os.PathLike]) -> Tuple[List[str], Dict[str, int]]:
    """Find the class folder names in ``directory`` and assign each an integer index.

    Args:
        directory: folder whose immediate sub-directories are the classes.
            Either a ``str`` or a ``pathlib.Path``; the notebook passes a Path.

    Returns:
        ``(classes, class_to_idx)``. ``classes`` holds the sub-directory names sorted
        alphabetically. ``class_to_idx`` maps each name to its position in that list.

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directories (or does not exist).
    """
    # 1. get the sorted names of sub-directories only; plain files are skipped.
    #    The with-block closes the scandir iterator deterministically.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())

    # 2. raise an error if no class folders could be found
    if not classes:
        raise FileNotFoundError(f"couldn't find any classes in {directory}")

    # 3. label each class by its index in the sorted list
    class_to_idx = {class_name: i for i, class_name in enumerate(classes)}
    return classes, class_to_idx

In [6]:
# integer label for every class folder, in alphabetical order
classes, class_to_idx = find_classes(IMAGE_PATH)
class_to_idx

{'F_Banana': 0,
 'F_Lemon': 1,
 'F_Lulo': 2,
 'F_Mango': 3,
 'F_Orange': 4,
 'F_Strawberry': 5,
 'F_Tamarillo': 6,
 'F_Tomato': 7,
 'S_Banana': 8,
 'S_Lemon': 9,
 'S_Lulo': 10,
 'S_Mango': 11,
 'S_Orange': 12,
 'S_Strawberry': 13,
 'S_Tamarillo': 14,
 'S_Tomato': 15}

In [7]:
# folder of the first (alphabetically) class
IMAGE_PATH.joinpath(classes[0])

WindowsPath('data/spoiled-fresh/FRUIT-16K/F_Banana')

In [8]:
# count the .jpg images in the first class folder
sum(1 for _ in (IMAGE_PATH / classes[0]).glob("*.jpg"))

1000

In [9]:
# NOTE(review): scratch cell illustrating list.extend; `a` is not used below. Consider deleting.
a = [1, 2, 3]
a += [4, 5, 6]  # in-place, same as a.extend([4, 5, 6])
a

[1, 2, 3, 4, 5, 6]

In [10]:
# Custom Dataset: one sub-folder per class; each folder is split into train/test portions.
class ImageFolderCustom(Dataset):
    """Image-classification dataset with a per-class train/test split.

    Expects the layout ``targ_dir/<class_name>/<image>.jpg``. For every class folder
    named in ``heads``, the ``.jpg`` files are sorted by path and then split. The first
    ``train_split`` fraction goes to the training portion and the remainder to the test
    portion. ``is_training`` selects which portion this instance serves. A train
    instance and a test instance built with the same arguments (over unchanged folders)
    therefore never share an image.

    Args:
        targ_dir: root folder containing one sub-folder per class (``str`` or Path).
        heads: class folder names to include; each must be a sub-folder of ``targ_dir``.
        transform: optional callable applied to every loaded PIL image.
        is_training: serve the training portion if True, else the test portion.
        train_split: fraction in [0, 1] of each class's images used for training.

    Raises:
        ValueError: if ``train_split`` is outside [0, 1], or a name in ``heads`` is not
            a class folder of ``targ_dir``.
        FileNotFoundError: if ``targ_dir`` has no class folders (raised by find_classes).
    """

    def __init__(self, targ_dir: Union[str, os.PathLike], heads: List[str], transform=None,
                 is_training: bool = True, train_split: float = 0.8):
        if not 0.0 <= train_split <= 1.0:
            raise ValueError(f"train_split must be in [0, 1], got {train_split}")

        # Accept str as well as Path; the old `targ_dir / tag` raised TypeError for str.
        root = Path(targ_dir)
        self.classes, self.class_to_idx = find_classes(root)

        # Previously an unknown name silently contributed zero images.
        unknown = [tag for tag in heads if tag not in self.class_to_idx]
        if unknown:
            raise ValueError(f"{unknown} are not class folders of {root}")

        self.training: List[Path] = []
        self.testing: List[Path] = []
        for tag in heads:
            # glob() order is filesystem-dependent, and the train and test instances
            # each glob independently. Sorting makes both partition the same ordered list.
            img_list = sorted((root / tag).glob("*.jpg"))
            train_length = int(len(img_list) * train_split)
            self.training.extend(img_list[:train_length])
            self.testing.extend(img_list[train_length:])

        self.paths = self.training if is_training else self.testing
        self.transform = transform

    def load_image(self, index: int) -> Image.Image:
        """Open ``self.paths[index]`` and return it as an RGB PIL image."""
        # The with-block closes the file handle, which lazy Image.open would otherwise
        # keep open. convert("RGB") returns a loaded copy with exactly three channels,
        # which the 3-value Normalize used in this notebook requires.
        with Image.open(self.paths[index]) as img:
            return img.convert("RGB")

    def __len__(self) -> int:
        """Number of images in the selected (train or test) portion."""
        return len(self.paths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image, label)`` for sample ``index``.

        The label is looked up from the image's parent folder name, matching the layout
        ``targ_dir/<class_name>/<image>.jpg``. The image is the transform's output when
        a transform was given, otherwise the RGB PIL image.
        """
        img = self.load_image(index)
        class_idx = self.class_to_idx[self.paths[index].parent.name]
        if self.transform is not None:
            img = self.transform(img)
        return img, class_idx

In [11]:
# Commonly used ImageNet per-channel (R, G, B) statistics, applied after ToTensor().
normalize = transforms.Normalize(
    [0.485, 0.456, 0.406],  # mean
    [0.229, 0.224, 0.225],  # std
)

In [12]:
# Both pipelines resize to the same square input; only training adds random flips.
IMG_SIZE = 224  # one constant keeps train and test input sizes in sync

train_transform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),  # augmentation: training only
    transforms.ToTensor(),
    normalize,
])

test_transform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    normalize,
])

In [13]:
# Build the train and test views of the same class folders with ImageFolderCustom.
train_data_custom = ImageFolderCustom(targ_dir=IMAGE_PATH, heads=classes,
                                      transform=train_transform, is_training=True)

test_data_custom = ImageFolderCustom(targ_dir=IMAGE_PATH, heads=classes,
                                     transform=test_transform, is_training=False)

# Sanity checks: the two splits must not share any image, and neither may be empty.
assert set(train_data_custom.paths).isdisjoint(test_data_custom.paths), "train/test overlap"
assert len(train_data_custom) > 0 and len(test_data_custom) > 0, "empty split"

In [14]:
# NOTE(review): repeats an earlier debugging check; safe to delete
IMAGE_PATH.__class__

pathlib.WindowsPath

In [15]:
from data_setup import create_dataloaders

BATCH_SIZE = 32

# train/test loaders over the same class folders and transforms defined above
train_dataloader, test_dataloader, class_names = create_dataloaders(
    image_dir=IMAGE_PATH,
    heads=classes,
    train_transform=train_transform,
    test_transform=test_transform,
    batch_size=BATCH_SIZE,
    num_workers=0,
)

In [16]:
# loader lengths (batches, going by the output: 400 * 32 = 12,800 train images) and label names
(len(train_dataloader), len(test_dataloader), class_names)

(400,
 100,
 ['F_Banana',
  'F_Lemon',
  'F_Lulo',
  'F_Mango',
  'F_Orange',
  'F_Strawberry',
  'F_Tamarillo',
  'F_Tomato',
  'S_Banana',
  'S_Lemon',
  'S_Lulo',
  'S_Mango',
  'S_Orange',
  'S_Strawberry',
  'S_Tamarillo',
  'S_Tomato'])

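## Experiments

Train the `effnetb0` and `effnetb2` models for 5 epochs each. Each run logs to a TensorBoard `SummaryWriter` under `runs/` and saves its weights under `models/`.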

In [17]:
# train on the GPU when one is visible to torch, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [20]:
# experiment settings; judging by the log below, run_experiment tries each model in turn
models = ["effnetb0", "effnetb2"]
num_epochs = [5]
train_dataloaders = dict(fruitsvegs0=train_dataloader)

In [21]:
from experiments import run_experiment as rex 

# Train and evaluate every configuration set up above. Per the log below, each run
# writes TensorBoard logs under runs/ and saves its weights under models/.
rex(
    train_dataloaders=train_dataloaders,
    test_dataloader=test_dataloader,
    num_epochs=num_epochs,
    models=models,
    class_names=class_names,
    device=device,
)

[INFO] experiment number: 1
[INFO] model: effnetb0
[INFO] dataloader: fruitsvegs0
[INFO] number of epochs: 5
[INFO] created a model effnetb0
[INFO] Created SummaryWriter(), saving to: runs\2024-08-23\fruitsvegs0\effnetb0\5_epochs


  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.5014 | train_acc: 0.9327 | test_loss: 0.3278 | test_acc: 0.9109


 20%|██        | 1/5 [00:55<03:40, 55.19s/it]

Epoch: 2 | train_loss: 0.1112 | train_acc: 0.9835 | test_loss: 0.2340 | test_acc: 0.9306


 40%|████      | 2/5 [01:52<02:49, 56.60s/it]

Epoch: 3 | train_loss: 0.0721 | train_acc: 0.9866 | test_loss: 0.2082 | test_acc: 0.9303


 60%|██████    | 3/5 [02:51<01:55, 57.56s/it]

Epoch: 4 | train_loss: 0.0543 | train_acc: 0.9894 | test_loss: 0.1383 | test_acc: 0.9550


 80%|████████  | 4/5 [03:49<00:57, 57.72s/it]

Epoch: 5 | train_loss: 0.0456 | train_acc: 0.9912 | test_loss: 0.1576 | test_acc: 0.9450


100%|██████████| 5/5 [04:47<00:00, 57.51s/it]


[INFO] Saving model to: models\07_effnetb0_fruitsvegs0_5_epochs.pt
--------------------------------------------------

[INFO] experiment number: 2
[INFO] model: effnetb2
[INFO] dataloader: fruitsvegs0
[INFO] number of epochs: 5
[INFO] created a model effnetb2
[INFO] Created SummaryWriter(), saving to: runs\2024-08-23\fruitsvegs0\effnetb2\5_epochs


  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.5092 | train_acc: 0.9251 | test_loss: 0.2748 | test_acc: 0.9453


 20%|██        | 1/5 [01:15<05:00, 75.25s/it]

Epoch: 2 | train_loss: 0.1184 | train_acc: 0.9812 | test_loss: 0.2028 | test_acc: 0.9463


 40%|████      | 2/5 [02:35<03:55, 78.37s/it]

Epoch: 3 | train_loss: 0.0800 | train_acc: 0.9845 | test_loss: 0.1591 | test_acc: 0.9547


 60%|██████    | 3/5 [03:56<02:38, 79.22s/it]

Epoch: 4 | train_loss: 0.0613 | train_acc: 0.9883 | test_loss: 0.1514 | test_acc: 0.9519


 80%|████████  | 4/5 [05:18<01:20, 80.31s/it]

Epoch: 5 | train_loss: 0.0499 | train_acc: 0.9894 | test_loss: 0.1541 | test_acc: 0.9513


100%|██████████| 5/5 [06:40<00:00, 80.08s/it]

[INFO] Saving model to: models\07_effnetb2_fruitsvegs0_5_epochs.pt
--------------------------------------------------




