In [1]:
import model.fruit
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as func
import torchvision
import numpy as np
from PIL import Image

# Importing the Data Set

The first step of training or testing is to import your dataset.
This will make the image set available for use, as well as defining your labels.

Training data must be organized into folders, each of which represents one class.
For example, a folder called 'Apples' that contains only apple pictures, a folder called 'Bananas' that contains only banana pictures, and so on.

The root path passed into the `get_data()` function should contain all of these folders, giving the structure:

```
root/
  Apples/
    img1.jpg
    img2.jpg
    ...
    imgN.jpg
  Bananas/
    img1.jpg
    img2.jpg
    ...
    imgN.jpg
  ...
```
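
To make the expected layout concrete, the sketch below scans a root folder and derives a label-to-id mapping from the subfolder names. It only illustrates the folder convention and is not the implementation of `get_data()`; the helper name `scan_class_folders` and the alphabetical id assignment are assumptions.

```python
from pathlib import Path


def scan_class_folders(root):
    """Map each class subfolder of `root` to an integer id and count its files.

    Hypothetical helper: ids are assigned in alphabetical order of folder name,
    which is an assumption rather than documented behaviour of get_data().
    """
    root = Path(root)
    # Only directories count as classes; stray files in the root are ignored.
    class_dirs = sorted(p for p in root.iterdir() if p.is_dir())
    class_to_idx = {d.name: idx for idx, d in enumerate(class_dirs)}
    # Count the files in each class folder so empty folders are easy to spot.
    counts = {d.name: sum(1 for f in d.iterdir() if f.is_file()) for d in class_dirs}
    return class_to_idx, counts
```

Running a check like this on the dataset root before training is a cheap way to spot an empty or misnamed folder.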

*Note: importing a dataset is NOT NECESSARY for inference, only for testing and training*

In [2]:
# To run this example, replace the path below with the path to your dataset.
# It is recommended to import twice, once for training and once for testing (with different data sets, obviously).

training_set, classes = model.fruit.FruitTrainingModel.get_data(root=r"D:\Documents\School\MAIN_FRUITS\PLAY",
                                                                batch_size=4,
                                                                shuffle=True,
                                                                num_workers=4)
# get_data() returns a dict mapping the label name to its id (an integer).
# For training, we need to convert this into an ordered list of strings.
classes = sorted(classes, key=classes.get)
classes

['Apple', 'Banana', 'Lemon', 'Limes']
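
The ordering matters because, presumably, the position of each label in the list stands in for its id. A small self-contained check with a hand-written mapping (hypothetical values, not taken from the dataset above) shows that sorting by id, rather than by name, preserves that correspondence:

```python
# Hypothetical mapping whose ids do not follow alphabetical order.
label_to_id = {"Lemon": 0, "Apple": 2, "Banana": 1}

# Order label names by their integer id so that position i holds the label with id i.
ordered = sorted(label_to_id, key=label_to_id.get)

# Sorting by name alone would break the position/id correspondence here.
by_name = sorted(label_to_id)

print(ordered)  # ['Lemon', 'Banana', 'Apple']
print(by_name)  # ['Apple', 'Banana', 'Lemon']
```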

# Basic Model Training

Below is the world's simplest training pipeline, intended only as a basic demonstration.
A typical training pipeline also includes things like loss and accuracy measurements on held-out data and a convergence check to decide when to stop.
This example does none of that and simply trains the model blindly.
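
As one example of the convergence checks a fuller pipeline might include, the sketch below tracks per-checkpoint average losses and reports when they have stopped improving. The class name `PlateauDetector` and its `patience` and `min_delta` parameters are hypothetical and not part of `model.fruit`.

```python
class PlateauDetector:
    """Signal when the average loss has not improved for `patience` checkpoints."""

    def __init__(self, patience=3, min_delta=1e-3):
        self.patience = patience      # checkpoints to wait without improvement
        self.min_delta = min_delta    # smallest decrease that counts as improvement
        self.best = float("inf")      # lowest average loss seen so far
        self.stale = 0                # consecutive checkpoints without improvement

    def update(self, avg_loss):
        """Record a checkpoint's average loss; return True once training has plateaued."""
        if avg_loss < self.best - self.min_delta:
            self.best = avg_loss
            self.stale = 0
        else:
            self.stale += 1
        return self.stale >= self.patience
```

In a training loop, `update()` would be called with each checkpoint average, and the loop would stop once it returns `True`.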


In [3]:
# First, we create our model.
m = model.fruit.FruitTrainingModel(labels=classes)

In [4]:
# World's simplest training pipeline
# WARNING: this will take quite a while, especially with a large dataset.
num_epochs = 8  # number of passes over each batch (not over the whole dataset)
loss = []
# Report progress roughly every 10% of the batches; max() avoids a modulo by zero
# when the loader has fewer than 10 batches.
checkpoint = max(1, len(training_set) // 10)
try:
    print("Training model...")
    for i, data in enumerate(training_set, start=1):
        # At each checkpoint, print the average loss since the previous one, then reset.
        if i % checkpoint == 0 and loss:
            print(f"{i} out of {len(training_set)} sets (avg loss so far = {np.array(loss).mean()})")
            loss.clear()
        # Train our model on this batch num_epochs times, recording each loss.
        for _ in range(num_epochs):
            loss.append(m.run_epoch(data))
except KeyboardInterrupt:
    print("Training was interrupted")

# Print the average loss since the last checkpoint (not over the whole run).
print("Done! Final avg loss: ", np.array(loss).mean())

Training model...
16 out of 165 sets (avg loss so far = 0.6747409517566363)
32 out of 165 sets (avg loss so far = 0.33392617444042116)
48 out of 165 sets (avg loss so far = 0.13609429344069213)
64 out of 165 sets (avg loss so far = 0.1566448172670789)
80 out of 165 sets (avg loss so far = 0.14235287410701858)
96 out of 165 sets (avg loss so far = 0.1349896158571937)
112 out of 165 sets (avg loss so far = 0.05407547660070122)
128 out of 165 sets (avg loss so far = 0.07990075941233954)
144 out of 165 sets (avg loss so far = 0.0998560030166118)
160 out of 165 sets (avg loss so far = 0.07949153185836622)
Done! Final avg loss:  0.17901555402204394


# Inference

Finally, we can take our trained model and make predictions with it!

The simplest way is to pass in a PIL image, but you may alternatively pass in a 4D tensor representing the image.

In [5]:
img = Image.open(r"D:\Documents\School\MAIN_FRUITS\banana1.jpg")
# img = model.fruit.FruitModel.transform(img).unsqueeze(0)
m.predict(img)

{'Banana': 0.7035243511199951}
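
Individual predictions can be aggregated into an accuracy figure on a held-out test set. The sketch below assumes only what the cell above shows, namely that `predict` returns a dict mapping label names to confidences; the function name `top1_accuracy` and the `(image, label)` sample format are hypothetical.

```python
def top1_accuracy(predict, samples):
    """Return the fraction of samples whose highest-confidence label is correct.

    predict: callable taking an image and returning {label: confidence}.
    samples: iterable of (image, expected_label) pairs.
    """
    correct = 0
    total = 0
    for image, expected in samples:
        scores = predict(image)
        # Pick the label with the highest confidence.
        predicted = max(scores, key=scores.get)
        correct += predicted == expected
        total += 1
    if total == 0:
        raise ValueError("samples is empty")
    return correct / total
```

With the model above this could be called as `top1_accuracy(m.predict, test_samples)`, where `test_samples` is a list you build yourself from a separate test folder.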

# Importing/Exporting

Once a model has been satisfactorily trained, you can export it to a `.pth` file to be reloaded later.


## Export Model

The model state is saved to the given location.
It is recommended to version exported models and to record the labels they were trained with.

In [6]:
m.export(r"D:\Documents\School\MAIN_FRUITS\model.pth")
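
One way to follow the recommendation above is to write the label list to a small JSON file next to the exported model. This is a sketch, not a feature of `model.fruit`: the helper names, the `.labels.json` suffix, and the `version` field are all assumptions.

```python
import json
from pathlib import Path


def save_labels(model_path, labels, version):
    """Write the ordered label list and a version string next to the model file."""
    # e.g. model.pth -> model.labels.json (hypothetical naming convention)
    sidecar = Path(model_path).with_suffix(".labels.json")
    sidecar.write_text(json.dumps({"version": version, "labels": list(labels)}, indent=2))
    return sidecar


def load_labels(model_path):
    """Read back the label list saved by save_labels(), preserving its order."""
    sidecar = Path(model_path).with_suffix(".labels.json")
    return json.loads(sidecar.read_text())["labels"]
```

The list returned by `load_labels()` can then be passed as the `labels` argument when the model is reloaded, so the order always matches the one used at export time.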

## Import Model

You will need access to the same labels (in the same order) that you exported your model with!
Using different labels will result in odd or broken predictions.

In [8]:
mx = model.fruit.FruitModel.from_file(labels=classes, path=r"D:\Documents\School\MAIN_FRUITS\model.pth")
mx.predict(Image.open(r"D:\Documents\School\MAIN_FRUITS\banana1.jpg"))


{'Banana': 0.7035243511199951}