# Imports

In [1]:
import os

## Configure dataset file paths and classes


Edit and run the code block below to set the dataset paths and classes (the cell writes them to `custom_data.yaml`).

In [None]:
%%writefile custom_data.yaml
# Dataset configuration written to custom_data.yaml by this cell.
# Image and label directories should be relative to train.py.
TRAIN_DIR_IMAGES: '../dataset/train'
TRAIN_DIR_LABELS: '../dataset/train'
VALID_DIR_IMAGES: '../dataset/valid'
VALID_DIR_LABELS: '../dataset/valid'

# Class names. The first entry is the background class.
CLASSES: [
    '__background__',
    'button',
    'input',
    'checkbox',
    'dropdown',
    'label',
    'icon',
    'radio',
    'switch'
]

# Number of classes (object classes + 1 for background class in Faster RCNN).
# 8 object classes listed above + '__background__' = 9.
NC: 9

# Whether to save the predictions of the validation set while training.
SAVE_VALID_PREDICTION_IMAGES: True

## Train Config Setup
Run the code block below to set up the training configuration. Then either run the cells in the sections that follow to start training, or copy the generic command and run it in a terminal.

In [None]:
# --- Pipeline scripts invoked by the command-building cells ---
train_file, eval_file = 'train.py', 'eval.py'
export_file, inference_file = 'export.py', 'onnx_inference_image.py'

# --- Model and run identification ---
# `model` is passed as --model; `model_name` is passed as --name for training
# and appears in the outputs/training/<model_name>/ weights path.
model = 'resnet101'
model_name = 'output_resnet101'

# --- Data locations ---
data = 'custom_data.yaml'
inference_dataset = '../dataset/test/'

# --- Training hyper-parameters ---
epoch_num = 40
batch_size = 4
image_size = 1024

# Width and height both default to the square training image size.
width = height = image_size

# --- Inference ---
inference_score_threshold = 0.5

## Training

In [None]:
# Assemble the training invocation piece by piece from the values defined in
# the train-config cell; the fragments are joined with single spaces.
train_args = [
    f'python {train_file}',
    f'--model {model}',
    f'--data {data}',
    f'--epochs {epoch_num}',
    f'--batch {batch_size}',
    f'--imgsz {image_size}',
    f'--name {model_name}',
    '-st',
]
train_command = ' '.join(train_args)
print(f"Either run the following command in terminal or run the cell below:\n{train_command}")

In [None]:
# Execute the training command assembled in the previous cell via the system
# shell. The value returned by os.system is the cell's result.
os.system(train_command)

Terminal Command to Run Training

In [None]:
# Generic template: replace every <placeholder> with a real value before running.
# NOTE(review): unlike the generated train_command, this template passes neither
# --name nor -st — confirm whether that is intended.
!python train.py --model <model_name> --data custom_data.yaml --epochs <epoch> --batch <batch> --imgsz <size>

## Evaluation
Evaluate the model on the distribution of classes in the validation set

In [None]:
# Assemble the evaluation invocation; the weights path points at the
# best_model.pth stored under outputs/training/<model_name>/.
eval_weights = f'outputs/training/{model_name}/best_model.pth'
eval_args = [
    f'python {eval_file}',
    f'--weights {eval_weights}',
    f'--data {data}',
    f'--model {model}',
    '--verbose',
]
eval_command = ' '.join(eval_args)
print(f"Either run the following command in terminal or run the cell below:\n{eval_command}")

In [None]:
# Execute the evaluation command assembled in the previous cell via the system
# shell. The value returned by os.system is the cell's result.
os.system(eval_command)

Terminal Command to Run Evaluation

In [None]:
# Generic template: replace every <placeholder> with a real value before running.
# NOTE(review): <model_name> is used here both in the weights path and as the
# --model value, whereas the generated eval_command uses two different variables
# (model_name for the path, model for --model) — confirm which is meant.
!python eval.py --weights outputs/training/<model_name>/best_model.pth --data custom_data.yaml --model <model_name> --verbose

## Export model
Export model for deployment/inference.
Exported models are saved as model.onnx files in a folder structured as weights/model_name/number/model.onnx

In [None]:
# Assemble the export invocation from the values defined in the train-config
# cell. --width/--height use the dedicated `width`/`height` settings (which
# default to `image_size`) so that changing them in the config cell takes effect.
export_weights = f'outputs/training/{model_name}/best_model.pth'
export_args = [
    f'python {export_file}',
    f'--model {model}',
    f'-w {export_weights}',
    f'--data {data}',
    f'--file_name {model_name}',
    f'--width {width}',
    f'--height {height}',
]
export_command = ' '.join(export_args)
print(f"Either run the following command in terminal or run the cell below:\n{export_command}")

In [None]:
# Execute the export command assembled in the previous cell via the system
# shell. The value returned by os.system is the cell's result.
os.system(export_command)

Terminal Command to Run Export

In [None]:
# Generic template: replace every <placeholder> with a real value before running.
# NOTE(review): this template passes --out and no -w weights path, while the
# generated export_command passes -w and no --out — confirm against export.py.
!python export.py --model <model_name> --data custom_data.yaml --out <name> --width <size> --height <size> --file_name <f_name>

## Inference

In [None]:
# Assemble the ONNX inference invocation from the values defined in the
# train-config cell.
# NOTE(review): the weights path is weights/<model>/<model_name>/model.onnx,
# while the Export section describes weights/model_name/number/model.onnx —
# confirm which layout export.py actually produces.
onnx_weights = f'weights/{model}/{model_name}/model.onnx'
inference_args = [
    f'python {inference_file}',
    f'-i {inference_dataset}',
    f'--data {data}',
    f'-w {onnx_weights}',
    f'-th {inference_score_threshold}',
    '-nlb',
    '-ncsv',
]
inference_command = ' '.join(inference_args)
print(f"Either run the following command in terminal or run the cell below:\n{inference_command}")

In [None]:
# Execute the inference command assembled in the previous cell via the system
# shell. The value returned by os.system is the cell's result.
os.system(inference_command)

Terminal Command to Run Inference

In [None]:
# Generic template: replace every <placeholder> with a real value before running.
# NOTE(review): the trailing --image/--batch/--epoch options are not part of the
# generated inference_command — confirm that onnx_inference_image.py accepts them.
!python onnx_inference_image.py -i datasets/<file> --data custom_data.yaml -w weights/<model_name>/<number>/model.onnx -th 0.5 -nlb -ncsv --image <size> --batch <size> --epoch <count>

## Inference viewer

In [None]:
def find_latest_inference_dir(root='outputs/inference'):
    """Find the highest-numbered result folder directly under ``root``.

    Folder names are expected to look like ``<prefix>_<number>``; the number is
    taken from the second ``_``-separated field. Entries that are not
    directories, or whose name does not match that shape, are skipped.

    Args:
        root: Directory that holds the per-run inference folders.
            NOTE(review): the default is relative to the working directory,
            matching the relative ``outputs/training/...`` paths used by the
            other cells — confirm this is where the inference script writes.

    Returns:
        A ``(number, path)`` tuple for the newest folder, where ``path`` is
        ``root`` joined with the folder name. Returns ``(0, '')`` when ``root``
        does not exist or contains no folder numbered above 0.
    """
    latest_num, latest_path = 0, ''
    if not os.path.isdir(root):
        return latest_num, latest_path

    for entry in os.listdir(root):
        entry_path = os.path.join(root, entry)
        if not os.path.isdir(entry_path):
            continue
        try:
            entry_num = int(entry.split('_')[1])
        except (IndexError, ValueError):
            # Name has no "_<number>" part; not a numbered result folder.
            continue
        if entry_num > latest_num:
            latest_num, latest_path = entry_num, entry_path

    return latest_num, latest_path


# `infer_path` is consumed by the viewer cell below; `old_num` keeps the
# number of the selected folder.
old_num, infer_path = find_latest_inference_dir()

In [None]:
from cv2 import imshow as cv2_imshow
import glob as glob
import os
import cv2

# Collect the .jpg files inside the inference folder held in `infer_path`.
# An empty `infer_path` means no folder was selected, so skip the glob instead
# of accidentally searching '/*.jpg' at the filesystem root.
images = glob.glob(infer_path + '/*.jpg') if infer_path else []

# glob returns an empty list rather than raising, so the "nothing found" case
# has to be checked explicitly.
if not images:
  print('No images found in inference folder')

for image_path in images:
  image_name = image_path.split(os.path.sep)[-1].split('.')[0]
  image = cv2.imread(image_path)
  if image is None:
    # cv2.imread signals an unreadable file by returning None.
    print(f'Could not read image: {image_path}')
    continue

# Display is disabled in this notebook; the original snippet is kept for reference:
# if image.shape[0] > 640:
#   cv2_imshow('',image)