#Ball Classification 

###The Problem:
Given an image of a ball (baseball, basketball, etc.), predict which type of ball it shows.

###The Input:
A 3-channel RGB image of a ball.
###The Output: 
Classification of that ball into one of 30 classes. 


This notebook will use the data found here:
https://www.kaggle.com/datasets/gpiosenka/balls-image-classification

The data was uploaded as-is to Google Drive and used to train, validate, and test the model in this notebook.

The code in this notebook was written by:
Conrad Testagrose


Let's install the dependencies needed for this classification problem.

PyTorch Pretrained ViT is a library written by Lukemelas that provides pretrained Vision Transformers (ViTs). ViTs provide exceptional accuracy but require an immense amount of data to train from scratch. Therefore, we will use this library to load a pretrained vision transformer.

In [3]:
!pip install monai==0.8.1
!pip install pytorch_pretrained_vit 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting monai==0.8.1
  Downloading monai-0.8.1-202202162213-py3-none-any.whl (721 kB)
     721.9/721.9 KB 34.7 MB/s eta 0:00:00
Installing collected packages: monai
Successfully installed monai-0.8.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Now we will import the required libraries/dependencies.


In [4]:
# import dependencies
import torch 
import torchvision
from torchvision import models  # required by ModelDefinition.inception_v3
import pandas as pd
from monai.data import Dataset
import cv2
import json
import random
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from pytorch_pretrained_vit import ViT
import torch.nn as nn

We will read the CSV containing the file paths and the labels. Each row in the CSV file is added to the list under the matching key ("Train", "Valid", or "Test") of the data dictionary, with its class id converted to a one-hot label.

We could also save this data dictionary as a JSON file if needed.

We will print the first entry of the "Train" key in the dictionary to visualize the structure.

In [5]:
#Create data dictionary, this makes the process easier.

data_dict = {"Train":[], "Valid":[], "Test":[]}
labels = []
for i in range(30):
  temp = [0]*30
  temp[i] = 1
  labels.append(temp)

data = pd.read_csv("./drive/MyDrive/Ball_Data/balls.csv")

for index, item in data.iterrows():
  data_set = item["data set"].capitalize()
  temp_dict = {"image":"./drive/MyDrive/Ball_Data/"+item["filepaths"],
               "label": labels[int(item["class id"])]}
  data_dict[data_set].append(temp_dict)

print(data_dict["Train"][0])

{'image': './drive/MyDrive/Ball_Data/train/baseball/001.jpg', 'label': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}
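The data dictionary built above can be written to disk as JSON so the CSV does not have to be parsed again. The following is a minimal sketch, assuming a small stand-in dictionary with the same structure; the file name `data_dict.json`, the temporary directory, and the 3-class labels are illustrative choices, not part of the original notebook.

```python
import json
import os
import tempfile

# Stand-in for the notebook's data_dict: same keys and entry structure,
# but with a 3-class one-hot label to keep the example short.
example_dict = {
    "Train": [{"image": "./train/baseball/001.jpg", "label": [1, 0, 0]}],
    "Valid": [{"image": "./valid/football/4.jpg", "label": [0, 1, 0]}],
    "Test": [],
}

# Hypothetical output location; a temporary directory keeps the sketch self-contained.
out_path = os.path.join(tempfile.mkdtemp(), "data_dict.json")

# Lists, dicts, strings, and ints are all JSON-serializable as-is.
with open(out_path, "w") as f:
    json.dump(example_dict, f, indent=2)

# Reading the file back returns an equal dictionary.
with open(out_path) as f:
    restored = json.load(f)

# A one-hot label can be turned back into a class index with list.index.
class_index = restored["Valid"][0]["label"].index(1)
print(restored == example_dict, class_index)
```

Because the labels are plain Python lists, no custom encoder is needed.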


We will now create separate lists for the Train, Validation, and Test sets from the data dictionary we created earlier and shuffle each one, so that the data is not in a learnable order.

We will also print the image shapes. This helps decide which basic transforms to apply when creating our batches, for example whether the image needs more channels or must be resized to meet the needs of the algorithm we are using.

In [6]:
#Get individual lists of data
train_data = data_dict["Train"]
valid_data = data_dict["Valid"]
test_data = data_dict["Test"]

#Random shuffle the data
random.Random(42).shuffle(train_data)
random.Random(42).shuffle(valid_data)
random.Random(42).shuffle(test_data)
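Seeding a dedicated `random.Random` instance, as done above, makes the shuffle reproducible across runs. A small sketch of that behavior, using a made-up list of integers in place of the data lists:

```python
import random

items = list(range(10))

# Two independent generators with the same seed produce the same permutation.
first = items.copy()
second = items.copy()
random.Random(42).shuffle(first)
random.Random(42).shuffle(second)

# shuffle works in place; the original list is untouched only because we copied it.
print(first == second)         # same seed, same order
print(sorted(first) == items)  # still the same elements
```

One consequence is that lists of equal length shuffled with the same seed receive the same permutation, which is harmless here because the three sets are independent.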

In [7]:
img = cv2.imread(train_data[0]["image"])
print(train_data[0]["image"], train_data[0]["label"])
print(img.shape)

img = cv2.imread(valid_data[0]["image"])
print(valid_data[0]["image"], valid_data[0]["label"])
print(img.shape)

./drive/MyDrive/Ball_Data/train/tennis ball/008.jpg [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
(224, 224, 3)
./drive/MyDrive/Ball_Data/valid/football/4.jpg [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
(224, 224, 3)


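We now need to create our dataset classes. Despite their names, `dataloader` and `val_dataloader` below are datasets: they load and transform one image at a time, and the DataLoader objects created later use them to feed the model with batches.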

In [8]:
#This is the data loader used to load the data to the model. 
class dataloader(Dataset):
    def __init__(self, dict, transforms):
        self.dict = dict
        self.transforms = transforms

    def __len__(self):
        return len(self.dict)

    def __getitem__(self, index):
        image = cv2.imread(self.dict[index]['image'])
        image = self.transforms(image)
        label = self.dict[index]['label']
        label = torch.FloatTensor(label)
        return image, label

class val_dataloader(Dataset):
    def __init__(self, dict, transforms):
        self.dict = dict
        self.transforms = transforms

    def __len__(self):
        return len(self.dict)

    def __getitem__(self, index):
        image = cv2.imread(self.dict[index]['image'])
        image = self.transforms(image)
        label = self.dict[index]['label']
        label = torch.FloatTensor(label)
        return image, label
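The classes above return a single (image, label) pair per index; grouping samples into batches is left to `torch.utils.data.DataLoader`. A rough pure-Python sketch of that grouping, assuming no shuffling and that the last partial batch is kept (the DataLoader defaults). The helper name `batch_indices` and the sample count of 150 are hypothetical.

```python
import math


def batch_indices(num_samples, batch_size):
    """Yield lists of consecutive sample indices, keeping the last partial batch."""
    for start in range(0, num_samples, batch_size):
        yield list(range(start, min(start + batch_size, num_samples)))


# Example: 150 samples with a batch size of 64 (the sample count is illustrative).
batches = list(batch_indices(150, 64))
batch_sizes = [len(b) for b in batches]
print(len(batches), batch_sizes)

# The number of batches is the sample count divided by the batch size, rounded up.
assert len(batches) == math.ceil(150 / 64)
```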

Now we get to the fun part, the transforms. Typically for image classification the transforms will follow a general path:
1. Load image
2. Adjust the shape of the image if needed
3. Resize
4. Apply linear transforms or augmentations
5. Make it a tensor

For validation sets we would typically exclude the linear or augmentation transforms. 

In this code I have chosen to apply AutoAugment to the training images. torchvision includes this transform as a way to automatically augment the images in your dataset according to a predefined policy. I have selected the ImageNet policy. The following webpage provides a good overview of what data augmentation is:

https://research.aimultiple.com/data-augmentation/

The key takeaway is that it helps to synthetically "create" new data samples from our existing data to diversify the data we train our model with. Without data augmentation we may be able to achieve high accuracy quickly on images similar to our training set, but we may perform poorly on real-world samples that are not visually similar to our training data. Data augmentation adds variety and should improve the generalizability of our algorithm.

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


pytorch_Train = T.Compose([
    T.ToPILImage(),
    T.Resize((int(299), int(299))),
    T.AutoAugment(T.AutoAugmentPolicy.IMAGENET),
    T.ToTensor()]
)

pytorch_Validate = torchvision.transforms.Compose([
    T.ToPILImage(),
    T.Resize((int(299), int(299))),
    T.ToTensor()]
)

Now we will create our model. Rather than create a model from scratch, I will use transfer learning to transfer the weights of a pretrained model to the model used in this project. 

Transfer learning helps us train the model and achieve higher levels of accuracy quickly. We transfer the weights of a model trained on a large and general dataset such as ImageNet or Microsoft COCO to our model. We add an additional layer to the end of this model which will help us finetune this model on our dataset of ball images. 



In [10]:
#Create our model class
class ModelDefinition():
    def __init__(self, num_class: int, pretrained_flag=True, dropout_ratio=0.5, fc_nodes=1024, patch_size=32, img_size: int = 299):
        self.num_class = num_class
        self.pretrain_flag = pretrained_flag
        self.dropout_ratio = dropout_ratio
        self.fc_nodes = fc_nodes
        self.patch_size=patch_size
        self.img_size=img_size

    def inception_v3(self, device='cpu'):
        model = models.inception_v3(pretrained=self.pretrain_flag)
        # Replace the final layer with a new head sized for our classes.
        model.fc = nn.Sequential(
            nn.Linear(model.fc.in_features, self.fc_nodes),
            nn.ReLU(),
            nn.Dropout(self.dropout_ratio),
            nn.Linear(self.fc_nodes, self.num_class))
        model.aux_logits = False
        model.to(device)
        return model

    def ViT_Pretrained(self, device='cpu'):
        # Map the patch size to the corresponding ViT-Base variant.
        if self.patch_size == 32:
            name = 'B_32'
        elif self.patch_size == 16:
            name = 'B_16'
        else:
            raise ValueError(f"Unsupported patch size: {self.patch_size}")
        model = ViT(name, pretrained=True, num_classes=self.num_class, image_size=self.img_size)
        model.to(device)
        return model

We will now set the parameters that we will use to train our model such as the learning rate, batch size, the number of classes, etc.

We will also create our dataloader objects that we will use when training the model. 

In [11]:
#hyperparameters
num_classes = 30
learning_rate = 0.001
batch_size = 64
dropout_ratio = 0.50
patch_size = 32
epochs = 100
vit_patch_size = 32
model_path = "./drive/MyDrive/Ball_Data/models"
model_name = "ViT_Pretrained"

train_ds = dataloader(train_data, transforms=pytorch_Train)
validation_ds = val_dataloader(valid_data, transforms=pytorch_Validate)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=int(batch_size), num_workers=8)
validation_loader = torch.utils.data.DataLoader(validation_ds, batch_size=int(batch_size), num_workers=8)




Let's initialize our model and select a loss function and optimizer. In my experience, SGD provides better results for the ViT we will be using. If we use Inception V3, either optimizer will work well, but Adam has given me great results in the past.
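For reference, the momentum update performed by `torch.optim.SGD` (with its default zero dampening and no weight decay) keeps a running buffer of gradients and steps along it. The sketch below reproduces that rule for a single scalar parameter; the function name `sgd_momentum_step` and the constant gradient are made up for illustration.

```python
def sgd_momentum_step(param, grad, buf, lr=0.001, momentum=0.9):
    """One SGD step with momentum: buf <- momentum * buf + grad; param <- param - lr * buf."""
    buf = momentum * buf + grad
    param = param - lr * buf
    return param, buf


param, buf = 0.0, 0.0
for _ in range(2):
    # Assume a constant gradient of 1.0 to make the arithmetic easy to follow.
    param, buf = sgd_momentum_step(param, grad=1.0, buf=buf)

print(param, buf)
```

After two steps the buffer has grown beyond the raw gradient, which is why momentum lets the optimizer take larger effective steps along a consistent direction.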

In [12]:
mod = ModelDefinition(num_classes, pretrained_flag=True, img_size=299)

if model_name == 'InceptionV3':
    model = mod.inception_v3()
    #model = nn.DataParallel(model)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 0.001)
elif model_name == 'ViT_Pretrained':
    model = mod.ViT_Pretrained()
    #model = nn.DataParallel(model)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr = 0.001, momentum=0.9)

model.to(device)

Downloading: "https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_32.pth" to /root/.cache/torch/hub/checkpoints/B_32.pth


  0%|          | 0.00/398M [00:00<?, ?B/s]

Resized positional embeddings from torch.Size([1, 50, 768]) to torch.Size([1, 82, 768])
Loaded pretrained weights.


ViT(
  (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
  (positional_embedding): PositionalEmbedding1D()
  (transformer): Transformer(
    (blocks): ModuleList(
      (0): Block(
        (attn): MultiHeadedSelfAttention(
          (proj_q): Linear(in_features=768, out_features=768, bias=True)
          (proj_k): Linear(in_features=768, out_features=768, bias=True)
          (proj_v): Linear(in_features=768, out_features=768, bias=True)
          (drop): Dropout(p=0.1, inplace=False)
        )
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (pwff): PositionWiseFeedForward(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (drop): Dropout(p=0.1, inplace=False)
      )
      (1): Block(
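
The "Resized positional embeddings" message in the output above follows from the patch arithmetic. The sketch below assumes the patch embedding is a stride-32 convolution without padding (as the printed `Conv2d` suggests) and that one class token is prepended to the patch sequence; `num_tokens` is a hypothetical helper, and 224 is assumed as the resolution that yields the checkpoint's 50 positions.

```python
def num_tokens(image_size, patch_size):
    """Number of sequence positions: non-overlapping patches plus one class token."""
    patches_per_side = image_size // patch_size  # partial patches at the border are dropped
    return patches_per_side ** 2 + 1


tokens_pretrained = num_tokens(224, 32)  # 7 x 7 patches + 1 class token
tokens_notebook = num_tokens(299, 32)    # 9 x 9 patches + 1 class token
print(tokens_pretrained, tokens_notebook)
```

The two values match the sizes 50 and 82 reported in the log, so the library has to interpolate the pretrained positional embeddings to the longer sequence.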
 

Finally, we can train. We will train for 100 epochs and save the model based on its performance on the validation set. Let this run; it can take some time.

Validation accuracy is a good sign that our model is training. However, saving the model based on it can result in a saved model that overfits the data. Other metrics we could monitor when deciding whether to save include:
- Validation Loss
- F1 Scores
- Kappa

We will use validation loss. Validation loss will decrease over time if we are training correctly. It will eventually start to increase, which is a sign the model is overfitting. We want to save the model when validation loss is at its lowest, before it starts increasing.

This is not to say that we cannot use the accuracy. Choosing the best metric to save by usually takes some trial and error.
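
The checkpointing rule described above can be sketched independently of PyTorch. This is a minimal illustration with a made-up loss curve; the function name `best_epoch_by_val_loss` is hypothetical, and in the real loop the comparison would trigger `torch.save` instead of just recording the epoch.

```python
def best_epoch_by_val_loss(val_losses):
    """Return (1-based epoch, loss) of the lowest validation loss in the sequence."""
    best_loss = float("inf")
    best_epoch = -1
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # In the real loop this is where the model weights would be saved.
            best_loss, best_epoch = loss, epoch
    return best_epoch, best_loss


# Made-up loss curve: decreases, then rises as the model starts to overfit.
losses = [9.9, 9.5, 9.1, 8.6, 8.8, 9.0]
best = best_epoch_by_val_loss(losses)
print(best)
```

Later epochs with a higher loss leave the saved checkpoint untouched, so the file on disk always corresponds to the lowest validation loss seen so far.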

In [13]:
#training loop
best_metric = -1
best_metric2 = 1000
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
val_interval = 1

print("Training...")
for epoch in range(epochs):
    print("-" * 10)
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].type(torch.float).to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss: .4f}")

    if (epoch + 1) % int(val_interval) == 0:
        model.eval()
        num_correct = 0.0
        metric_count = 0
        val_epoch_loss = 0
        val_step = 0
        with torch.no_grad():
            y_pred = torch.tensor([], dtype=torch.float32, device=device)
            y = torch.tensor([], dtype=torch.long, device=device)
            for val_data in validation_loader:
                val_step += 1
                val_images, val_labels = val_data[0].to(device), val_data[1].type(torch.float).to(device)
                val_output = model(val_images)
                value = torch.eq(val_output.argmax(dim=1), val_labels.argmax(dim=1))
                val_loss = loss_function(val_output, val_labels)
                val_epoch_loss += val_loss.item()
                metric_count += len(value)
                num_correct += value.sum().item()
                val_epoch_len = len(validation_ds) // validation_loader.batch_size
            metric = num_correct / metric_count
            metric2 = val_epoch_loss  # summed (not averaged) over validation batches
            metric_values.append(metric)
            if metric2 < best_metric2:
                best_metric = metric
                best_metric2 = metric2
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), model_path + "/best_model_vit.pth")
                print('Saved new model')
            print("Current Epoch: {} current accuracy: {:.4f}"
                  " Best accuracy: {:.4f} at epoch {}".format(epoch + 1, metric, best_metric, best_metric_epoch))
            print("validation_accuracy: " + str(metric) + " Epoch Number: " + str(epoch + 1))
            print("Validation Loss: ", val_epoch_loss)
print(f"training completed, best_metric: {best_metric: .4f}"
      f" at epoch: {best_metric_epoch}")

Training...
----------
epoch 1 average loss:  3.3619
Saved new model
Current Epoch: 1 current accuracy: 0.5867 Best accuracy: 0.5867 at epoch 1
validation_accuracy: 0.5866666666666667 Epoch Number: 1
Validation Loss:  9.893285274505615
----------
epoch 2 average loss:  3.2489
Saved new model
Current Epoch: 2 current accuracy: 0.6267 Best accuracy: 0.6267 at epoch 2
validation_accuracy: 0.6266666666666667 Epoch Number: 2
Validation Loss:  9.50794506072998
----------
epoch 3 average loss:  3.1199
Saved new model
Current Epoch: 3 current accuracy: 0.6267 Best accuracy: 0.6267 at epoch 3
validation_accuracy: 0.6266666666666667 Epoch Number: 3
Validation Loss:  9.076879978179932
----------
epoch 4 average loss:  2.9752
Saved new model
Current Epoch: 4 current accuracy: 0.6267 Best accuracy: 0.6267 at epoch 4
validation_accuracy: 0.6266666666666667 Epoch Number: 4
Validation Loss:  8.599461317062378
----------
epoch 5 average loss:  2.8133
Saved new model
Current Epoch: 5 current accuracy: 0
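
Two quick sanity checks on the log above. A classifier that assigns equal probability to all 30 classes has a cross-entropy of ln(30), so a first-epoch training loss near that value suggests the new classification head starts out close to uniform. The printed validation loss, in contrast, is summed over batches rather than averaged. The sketch below assumes 150 validation images, which is inferred from the reported accuracy (88/150 ≈ 0.5867) and is not stated in the notebook.

```python
import math

num_classes = 30
# Cross-entropy of a uniform prediction over 30 classes.
uniform_loss = math.log(num_classes)
print(round(uniform_loss, 4))

# Assumed validation set size; inferred from the accuracy values, not stated in the notebook.
num_val_images = 150
batch_size = 64
num_batches = math.ceil(num_val_images / batch_size)

# Epoch 1 validation loss from the log, converted to a per-batch average.
summed_val_loss = 9.893285274505615
per_batch_val_loss = summed_val_loss / num_batches
print(num_batches, round(per_batch_val_loss, 4))
```

Under that assumption the per-batch validation loss is on the same scale as the training loss, which makes the two curves easier to compare.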

Now we will evaluate the performance of our algorithm on the test set. 

In [14]:
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import sklearn.metrics as metrics
from itertools import cycle
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from monai.transforms import (
    Compose,
    Activations,
    EnsureType)
import progressbar

test_data = data_dict["Test"]

random.Random(42).shuffle(test_data)

pytorch_Test = torchvision.transforms.Compose([
  torchvision.transforms.ToPILImage(),
  torchvision.transforms.Resize((int(299), int(299))),
  torchvision.transforms.ToTensor()]
)

test_ds = dataloader(test_data, transforms=pytorch_Test)
test_loader = torch.utils.data.DataLoader(test_ds, num_workers=4)

model_filename = model_path + '/best_model_vit.pth'
model = mod.ViT_Pretrained()
model.to(device)
model.load_state_dict(torch.load(model_filename))
model.eval()

y_pred_trans = Compose([EnsureType(), Activations(sigmoid=True)])

count = 0
Image_FN_LIST = []
Y_PRED_NP = np.zeros([len(test_data), num_classes])
y_predicted = []
y_true = []
y_test = []
y_score = []
GT_LIST = []

pb = progressbar.ProgressBar(maxval=len(test_data)+1)

pb.start()
with torch.no_grad():
  for _testImage in test_loader:
    image = torch.as_tensor(_testImage[0]).to(device)
    y_test.append(np.array(_testImage[1]))
    gt = _testImage[1].cpu().detach().numpy()
    GT_LIST.append(int(np.argmax(gt[0])))
    y_true.append(int(np.argmax(gt[0])))
    y = model(image)
    y_pred = y_pred_trans(y).to('cpu')
    Y_PRED_NP[count] = y_pred.numpy().ravel()  # store per-class scores so y_score is not all zeros
    y_predicted.append(torch.argmax(y_pred).item())
    pb.update(count)
    count += 1 

y_test = np.array([t.ravel() for t in y_test])
y_score = np.array([t.ravel() for t in Y_PRED_NP]) 

print(classification_report(y_true, y_predicted))



Resized positional embeddings from torch.Size([1, 50, 768]) to torch.Size([1, 82, 768])
Loaded pretrained weights.


 98% (148 of 151) |##################### | Elapsed Time: 0:00:09 ETA:   0:00:00

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         5
           1       0.83      1.00      0.91         5
           2       1.00      1.00      1.00         5
           3       1.00      1.00      1.00         5
           4       1.00      1.00      1.00         5
           5       1.00      1.00      1.00         5
           6       1.00      1.00      1.00         5
           7       1.00      1.00      1.00         5
           8       0.83      1.00      0.91         5
           9       1.00      1.00      1.00         5
          10       1.00      1.00      1.00         5
          11       1.00      1.00      1.00         5
          12       1.00      1.00      1.00         5
          13       1.00      1.00      1.00         5
          14       0.83      1.00      0.91         5
          15       1.00      0.80      0.89         5
          16       1.00      1.00      1.00         5
          17       1.00    