# STEP 3: Model Training

### Task 1: Preprocessing your dataset 
Here we split the labelled dataset into training and validation sets, so that we can fine-tune the model and check how well it performs on images it was not trained on.

In [None]:
!pip install detecto

In [None]:
from detecto import core, utils
from torchvision import transforms
import matplotlib.pyplot as plt
from pathlib import Path

import logging
import warnings

# Silence library logging and warnings so the notebook output stays readable.
logging.basicConfig(level=logging.CRITICAL)
warnings.filterwarnings('ignore')

# Input/output locations, relative to the notebook's working directory.
images_path = "./images"
training_filepath = "./resources/train_labels.csv"
validation_filepath = "./resources/val_labels.csv"
classes_filepath = "./resources/predefined_classes.txt"

# One class name per line. Surrounding whitespace is stripped and blank lines are
# skipped, so a stray empty line does not turn into an empty-string class.
with open(classes_filepath, encoding="utf-8") as f:
    classes_list = [name for name in (line.strip() for line in f) if name]

# Convert the XML annotations ONCE, then split them *by image* into training and
# validation tables. Converting the same folder twice (once per CSV) would make the
# validation CSV an exact copy of the training CSV, so the validation loss would
# say nothing about how well the model generalizes.
import random


def split_labels(labels_df, val_fraction=0.2, seed=0):
    """Split a label table into ``(train_df, val_df)`` by image.

    All rows sharing an ``image_id`` (every bounding box of one image) land in
    the same split. About ``val_fraction`` of the images go to validation: at
    least one when there are two or more images, but never all of them. The
    split is reproducible for a given ``seed``.

    Each returned table has its ``image_id`` column renumbered to 0..k-1 in order
    of first appearance. NOTE(review): detecto's ``Dataset`` appears to look rows
    up by ``image_id == idx`` for ``idx`` in ``range(len(dataset))``, so ids must
    be contiguous. This presumably also closes gaps left by XML files without any
    objects; confirm against the installed detecto version.

    Raises:
        ValueError: if ``val_fraction`` is outside ``[0, 1)``.
    """
    if not 0 <= val_fraction < 1:
        raise ValueError("val_fraction must be in [0, 1)")

    image_ids = list(dict.fromkeys(labels_df['image_id']))
    random.Random(seed).shuffle(image_ids)

    n_images = len(image_ids)
    n_val = round(n_images * val_fraction)
    if val_fraction > 0 and n_images > 1:
        n_val = max(n_val, 1)                   # keep at least one validation image
    n_val = min(n_val, max(n_images - 1, 0))    # keep at least one training image

    is_val = labels_df['image_id'].isin(set(image_ids[:n_val]))
    return _renumber_images(labels_df[~is_val]), _renumber_images(labels_df[is_val])


def _renumber_images(labels_df):
    """Return a copy of *labels_df* whose ``image_id`` values are 0..k-1."""
    new_ids = {old: new for new, old in enumerate(dict.fromkeys(labels_df['image_id']))}
    renumbered = labels_df.copy()
    renumbered['image_id'] = renumbered['image_id'].map(new_ids)
    return renumbered.reset_index(drop=True)


all_labels = utils.xml_to_csv(images_path)
train_labels, val_labels = split_labels(all_labels, val_fraction=0.2)
train_labels.to_csv(training_filepath, index=False)
val_labels.to_csv(validation_filepath, index=False)
print('Split {} labelled images into {} training and {} validation images.'.format(
    all_labels['image_id'].nunique(),
    train_labels['image_id'].nunique(),
    val_labels['image_id'].nunique(),
))

# Training-time preprocessing/augmentation, applied in order:
# PIL conversion -> resize -> random saturation jitter -> tensor -> detecto normalization.
augmentation_steps = [
    transforms.ToPILImage(),
    transforms.Resize(800),
    transforms.ColorJitter(saturation=0.3),
    transforms.ToTensor(),
    utils.normalize_transform(),
]
custom_transforms = transforms.Compose(augmentation_steps)

# Build the datasets from the CSV label files (faster to initialize than XML).
# Only the training set gets the custom augmentation pipeline; the validation set
# is built without a custom transform.
dataset = core.Dataset(training_filepath, images_path, transform=custom_transforms)
print('Number of samples in training dataset = {}.'.format(len(dataset)))

val_dataset = core.Dataset(validation_filepath, images_path)  # Validation dataset for training
print('Number of samples in validation dataset = {}.'.format(len(val_dataset)))

# Stop here rather than print and carry on: the training cell needs `loader`, which
# would otherwise be undefined and fail later with a confusing NameError.
if len(dataset) == 0:
    raise RuntimeError(
        "Your Dataset is empty! Are you sure you did step 2? Did you save your labelled pictures?"
    )

# With no validation images there is no validation loss to compute; pass None so
# model.fit() trains without validation instead of iterating an empty set.
if len(val_dataset) == 0:
    print('Validation dataset is empty; training will run without validation.')
    val_dataset = None

# Create your own DataLoader with custom options
loader = core.DataLoader(dataset, batch_size=2, shuffle=True)

# Sanity check: load every training sample (applying the transforms) and show its labels.
print([dataset[i][1]['labels'] for i in range(len(dataset))])

In [None]:
# for i in range(len(dataset)):
#     print(i)
#     if i in [15,36,93,237,247,282]:
#         continue
#     print(dataset[i][1])

### Task 2: Training session. Decide how many epochs to train the model for
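One full pass of the model over the training data is called an "epoch". The more epochs you choose, the longer training takes, but the higher the accuracy the model might reach (too many epochs can make the model overfit the training images).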

In [None]:
# Ask how many epochs (full passes over the training data) to train for.
# Re-prompt on non-numeric input instead of crashing with a ValueError.
while True:
    try:
        num_epochs = int(input('The number of models training sessions:'))
        break
    except ValueError:
        print('Please enter a whole number, e.g. 10.')

# Values below 1 are clamped to 1 (the minimum meaningful training run).
if num_epochs < 1:
    print('The number of epochs has to be at least 1; using 1.')
    num_epochs = 1

In [None]:
import logging
import warnings

# Repeat the quiet-logging setup so this cell can also be re-run on its own.
logging.basicConfig(level=logging.CRITICAL)
warnings.filterwarnings('ignore')

# Fine-tune a Faster R-CNN (MobileNetV3-Large FPN backbone) on our own classes.
model = core.Model(classes_list, model_name='fasterrcnn_mobilenet_v3_large_fpn')
losses = model.fit(loader, val_dataset, epochs=num_epochs, learning_rate=0.001, verbose=True)

# Save first, so the trained weights are kept even if plotting fails.
save_custom_model_filepath = "./resources/custom_model_weights.pth"
model.save(save_custom_model_filepath)  # Save model to a file
print('The model is saved, you can move to next step!')

# Per detecto's docs, fit() returns the per-epoch validation losses only when a
# validation set was given (otherwise None). Plot them to see whether the loss is
# still going down or has levelled off.
if losses:
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Validation loss')
    plt.show()