# Fine-Tuning a PyTorch Model on Our Fruits Dataset
This notebook demonstrates how to use PyTorch to fine-tune a pre-existing object detection model on our fruits dataset.

We will begin by creating our own `Dataset` class, which the training code will use to batch the data.

Then we will load a pre-trained Faster R-CNN model using the PyTorch library and modify it so that it can detect and classify the various fruits in our dataset.

Afterwards, we will write the training, evaluation, and testing loops for our model.
We will set up training so that, if it is interrupted, it can be resumed from where it stopped.
We will also see how to save the "best" model found during the training loop.
Once training is complete, we will use our test dataset to evaluate the precision, recall, and accuracy of our model.

Finally, we will discuss next steps, including making our model work with the PyTorch Lightning framework.

## Building a Custom Dataset Class for Our Fruits Data

PyTorch provides a `Dataset` class that we can use for batching during model training.
We will begin by creating our own custom dataset class that inherits from the base PyTorch `Dataset` class.
Once we have our custom dataset class, we can use the PyTorch `DataLoader` class to iterate over our data in batches for training, evaluation, and testing.
The official [PyTorch Tutorial](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) contains detailed information and links to the official documentation for both the `Dataset` and `DataLoader` classes.

We will start by defining our Dataset class and then explain what each function does.
Afterwards, we'll create some `DataLoader` instances using our custom dataset class and then load a few examples to verify that our code is working correctly.
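
Because each image can contain a different number of boxes, the default batching behavior, which tries to stack samples into a single tensor, will generally not work for detection targets. Below is a minimal sketch of one common workaround: a custom `collate_fn` that keeps each batch as tuples of images and targets. `ToyDetectionDataset` and `collate_detection` are hypothetical names used only for illustration, and the toy tensors stand in for the real fruits data.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDetectionDataset(Dataset):
    """Hypothetical stand-in that returns (image, target) pairs like a detection dataset."""

    def __init__(self, boxes_per_image):
        # boxes_per_image[i] is the number of boxes in image i
        self.boxes_per_image = boxes_per_image

    def __len__(self):
        return len(self.boxes_per_image)

    def __getitem__(self, index):
        n = self.boxes_per_image[index]
        image = torch.zeros(3, 32, 32)  # placeholder image
        target = {
            "boxes": torch.zeros(n, 4, dtype=torch.float32),
            "labels": torch.ones(n, dtype=torch.int64),
        }
        return image, target


def collate_detection(batch):
    # Keep images and targets as tuples instead of stacking them,
    # since every image may have a different number of boxes
    return tuple(zip(*batch))


loader = DataLoader(
    ToyDetectionDataset([1, 3, 2, 2]),
    batch_size=2,
    shuffle=False,
    collate_fn=collate_detection,
)
images, targets = next(iter(loader))
print(len(images), [t["boxes"].shape for t in targets])
```

Each batch is then a pair of tuples, one holding the images and one holding the per-image target dictionaries.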


### The Custom Fruits Dataset Class

In [42]:
import torch
from torchvision.io import read_image
from torch.utils.data import Dataset
import pandas as pd

In [43]:
# The dataset class requires us to implement at least the __getitem__ and __len__ functions
class FruitsDataset(Dataset):
    def __init__(self, annotations_file: str, img_dir: str, transform=None, target_transform=None):
        self.data_df = pd.read_csv(annotations_file, encoding='utf-8', engine='python')
        self.img_dir = img_dir
        self.img_labels = self.data_df['label']
        self.img_ids = self.data_df['filename'].unique()  # Use the image file name as the ID for each image
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.img_ids.shape[0]

    def __getitem__(self, index: int):
        img_id = self.img_ids[index]

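        # The full file path to the image. Strip any trailing slash from img_dir first
        # so the path is valid whether or not img_dir ends with a separator
        img_path = f"{self.img_dir.rstrip('/')}/{img_id}"

        # Select all rows in our data frame that belong to this image (one row per bounding box)
        img_annotations = self.data_df.loc[self.data_df['filename'] == img_id]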

        # Get the boxes and labels for our image
        # Convert them to torch tensors so we can use them in our dataloader
        boxes = img_annotations[['x1', 'y1', 'x2', 'y2']].values
        # Compute the area of every box at once. boxes[:, 3] selects the whole column at index 3 (y2),
        # so subtracting and multiplying columns works element-wise across all boxes.
        # Area is height * width: (y2 - y1) * (x2 - x1). The result is an n-element vector whose
        # nth entry is the area of the bounding box for the nth object
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # torchvision detection models expect boxes as float tensors and labels as int64 tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        area = torch.as_tensor(area, dtype=torch.float32)
        labels = img_annotations['label'].values
        labels = [self.__convert_labels__(x) for x in labels]
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # Read the image using torchvision so we can return it
        image = read_image(img_path)

        # Apply any transforms if they were supplied
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            labels = self.target_transform(labels)

        # Add the results for our image to a dictionary. This dictionary holds the
        # boxes, areas, labels, and file name (ID) for the image
        target = {'boxes': boxes,
                  'area': area,
                  'labels': labels,
                  'img_id': img_id}

        return image, target

    def __convert_labels__(self, x):
        # This function converts our string labels into integer class indices
        # (a plain label encoding, not a one-hot vector).
        # Torch will not allow creating a tensor from strings, so our workaround is
        # to use this encoding.
        # Remember, torch reserves 0 for the "background class" so we start at 1
        # TODO: This may not be true for Faster-RCNN, just for Mask-RCNN
        encoding = {'apple': 1, 'banana': 2, 'orange': 3, 'mixed': 4}
        converted_label = encoding[x]
        return converted_label


### Explanation of Our Custom FruitsDataset Class

### `__init__()` function
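
The constructor reads the annotations file into a DataFrame and collects the unique file names. This matters because the file holds one row per bounding box, so an image containing several fruits appears on several rows. Here is a minimal sketch of that step, assuming a hypothetical in-memory CSV with the columns the class reads (`filename`, `x1`, `y1`, `x2`, `y2`, `label`). The rows are made-up values, not the real annotations.

```python
import io

import pandas as pd

# Hypothetical annotations: one row per bounding box (values are made up)
csv_text = """filename,x1,y1,x2,y2,label
apple_1.jpg,10,20,110,140,apple
mixed_1.jpg,5,5,60,80,banana
mixed_1.jpg,70,10,150,90,orange
"""

data_df = pd.read_csv(io.StringIO(csv_text))

# unique() keeps each file name once, in order of first appearance
img_ids = data_df["filename"].unique()

print(len(data_df))      # number of annotated boxes
print(img_ids.shape[0])  # number of distinct images
```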


### `__len__()` function
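
`__len__()` returns the number of unique images rather than the number of annotation rows, so each index corresponds to exactly one image. A `DataLoader` relies on this value to decide how many batches make up one pass over the data. A small sketch with a hypothetical five-item dataset (`FiveItems` is an illustrative name, not part of this notebook):

```python
import math

import torch
from torch.utils.data import DataLoader, Dataset


class FiveItems(Dataset):
    """Hypothetical dataset with a fixed length of five."""

    def __len__(self):
        return 5

    def __getitem__(self, index):
        return torch.tensor(index)


dataset = FiveItems()
loader = DataLoader(dataset, batch_size=2)

# With the default drop_last=False, the last (smaller) batch is kept
num_batches = len(loader)
print(num_batches, math.ceil(len(dataset) / 2))
```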

### `__getitem__()` function
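
`__getitem__()` looks up every annotation row for one image, converts the box corners to tensors, and computes each box's area as width times height. Below is a minimal sketch of the area step alone, using made-up corner coordinates in `(x1, y1, x2, y2)` order.

```python
import numpy as np
import torch

# Two hypothetical boxes in (x1, y1, x2, y2) order
boxes = np.array([[10, 20, 110, 140],
                  [5, 5, 60, 80]])

# Column slices operate on every box at once
widths = boxes[:, 2] - boxes[:, 0]   # x2 - x1
heights = boxes[:, 3] - boxes[:, 1]  # y2 - y1

area = torch.as_tensor(widths * heights, dtype=torch.float32)
boxes_t = torch.as_tensor(boxes, dtype=torch.float32)

print(area)
print(boxes_t.shape)
```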

### `__convert_labels__()` function
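
This helper maps each string label to an integer class index, because tensors cannot hold strings. Index 0 is left unused so that it can represent the background class. The sketch below applies the same mapping as a standalone function. `encode_labels` is a hypothetical name chosen for illustration, and the explicit `int64` dtype is an assumption based on what torchvision's detection models expect for labels.

```python
import torch

# Same mapping as in the class above; 0 is left free for the background class
ENCODING = {"apple": 1, "banana": 2, "orange": 3, "mixed": 4}


def encode_labels(labels):
    # Raises KeyError for a label that is not in the mapping
    return torch.as_tensor([ENCODING[name] for name in labels], dtype=torch.int64)


encoded = encode_labels(["banana", "orange", "apple"])
print(encoded)
```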