<img align="left" src="https://panoptes-uploads.zooniverse.org/project_avatar/86c23ca7-bbaa-4e84-8d8a-876819551431.png" type="image/png" height="100" width="100">
<h1 align="right">KSO Tutorials #5: Train machine learning models</h1>
<h3 align="right">Written by @jannesgg and @vykanton</h3>
<h5 align="right">Last updated: April 5, 2022</h5>

# Set up and requirements

### Import Python packages

In [None]:
# Set the directory of the libraries
import sys, os
from pathlib import Path
# Make the parent directory importable so that kso_utils and src (imported
# below) can be found — assumes the notebook is run one level below the repo
# root; confirm. Must stay before those imports.
sys.path.append('..')

# Set to display dataframes as interactive tables
from itables import init_notebook_mode
init_notebook_mode(all_interactive=True)
# Interactive file/folder picker used by several cells below
from ipyfilechooser import FileChooser

# Import required modules
# NOTE(review): s_utils and t3 are not used in the visible cells — confirm
# they are still needed.
import kso_utils.tutorials_utils as t_utils
import kso_utils.server_utils as s_utils
import kso_utils.project_utils as p_utils
import kso_utils.t3_utils as t3
import kso_utils.t4_utils as t4
import kso_utils.t5_utils as t5
import kso_utils.t8_utils as t8
from src.prepare_zooniverse import frame_aggregation
# NOTE(review): retrieve_zoo_info and populate_subjects are not used in the
# visible cells — confirm they are still needed.
from kso_utils.zooniverse_utils import retrieve_zoo_info, populate_subjects, populate_agg_annotations

# Model-specific imports
# NOTE(review): these modules must be importable from sys.path (presumably the
# YOLO training / testing / detection scripts) — confirm.
import yolo_train as train
import yolo_test as test
import yolo_detect as detect

print("Packages loaded successfully")

### Choose your project

In [None]:
# Show a project selector; the chosen name is read later via project_name.value
project_name = t_utils.choose_project()

In [None]:
# Resolve the selected project name into the project object that is passed
# to almost every later cell
project = p_utils.find_project(project_name=project_name.value)

### Initialise the SQL database and populate sites, movies and species

In [None]:
# Initiate db
# Returns a dict describing the database; later cells read its "db_path" entry
db_info_dict = t_utils.initiate_db(project)

In [None]:
# Connect to Zooniverse project
# The returned handle is passed to the Zooniverse retrieval step below
zoo_project = t_utils.connect_zoo_project(project)

### Retrieve Zooniverse information

In [None]:
# Fetch the Zooniverse subjects, workflows and classifications for this
# project. Later cells read zoo_info_dict["workflows"] and
# zoo_info_dict["classifications"]. (The helper's name suggests it also
# populates the local database — confirm in tutorials_utils.)
zoo_info_dict = t_utils.retrieve__populate_zoo_info(
    project=project,
    db_info_dict=db_info_dict,
    zoo_project=zoo_project,
    zoo_info=["subjects", "workflows", "classifications"],
)

# Prepare the labelled frames

### Select the species of interest and the folder to store the data

In [None]:
# Choose species of interest for model training
# Presumably lists the species stored in the project database at db_path;
# the selection is read next via species_i.value
species_i = t4.choose_species(db_info_dict["db_path"])

In [None]:
# Freeze the chosen classes of interest into a plain list for the
# frame-preparation step
cl = [*species_i.value]

In [None]:
# Specify path to store the labelled frames and annotations
# Chooser rooted at the current directory; "output" is presumably the
# suggested folder name — confirm against t_utils.choose_folder
fc = t_utils.choose_folder(".", "output")

In [None]:
# Store selected output path
output_folder = fc.selected
# Fail early if no folder was picked (assumes fc behaves like an
# ipyfilechooser.FileChooser, whose .selected is None until a path is chosen).
# Otherwise the problem surfaces much later and confusingly: for example,
# os.listdir(None) silently lists the current working directory when the
# dataset .yaml is looked up.
if output_folder is None:
    raise ValueError(
        "No output folder selected: pick a folder in the chooser above "
        "and re-run this cell."
    )

### Aggregate classifications from Zooniverse

In [None]:
# Display a selectable list of workflow names and a list of versions of the workflow of interest
workflows_df = zoo_info_dict["workflows"]
wm = t8.WidgetMaker(workflows_df)
# Bare expression as the cell's last line so Jupyter renders the widget;
# the ticked selections are read in the next cell via wm.checks
wm

In [None]:
# Pull the raw classifications for the workflow(s)/version(s) ticked in the
# widget above. 'frame' presumably restricts them to frame (still image)
# subjects rather than clips — confirm in t8_utils.
class_df = t8.get_classifications(
    wm.checks,
    workflows_df,
    'frame',
    zoo_info_dict["classifications"],
    db_info_dict["db_path"],
    project,
)

In [None]:
# Specify the agreement threshold required among citizen scientists
# (parameters for 'frame' subjects; agg_params is passed to the aggregation
# step in the next cell)
agg_params = t8.choose_agg_parameters("frame")

In [None]:
# Combine the individual volunteer classifications using the agreement
# parameters chosen above. Yields the aggregated table (stored in the db
# next) and, judging by its name, the raw per-classification table.
agg_class_df, raw_class_df = t8.aggregrate_classifications(
    class_df,
    'frame',
    project,
    agg_params,
)

In [None]:
# Add annotations to db
# Stores the aggregated 'frame' annotations for this project
populate_agg_annotations(agg_class_df, 'frame', project)

### Download frames and aggregated annotations

In [None]:
# Determine your training parameters
# Widget for the test-set proportion; its .value is passed to
# frame_aggregation in the next cell
percentage_test = t5.choose_test_prop()

In [None]:
# Run the preparation script: write the labelled frames and annotations for
# the selected classes (cl) into output_folder, using the test proportion
# chosen above.
frame_aggregation(
    project,
    db_info_dict,
    output_folder,
    percentage_test.value,
    cl,
    (720, 540),  # presumably the target frame size (width, height) — confirm
    remove_nulls=True,
    track_frames=False,
    n_tracked_frames=0,
)

# Train and evaluate the ML model

In [None]:
# Widgets for batch size, number of epochs and confidence threshold; their
# .value attributes are read by the training and evaluation cells below
batch_size, epochs, conf_thres = t5.choose_train_params()

In [None]:
# Fix important paths
# The dataset description is the .yaml written by the preparation step (any
# .yaml whose name does not contain "hyp"). os.listdir() returns entries in
# arbitrary order, so sort the candidates to make the pick deterministic, and
# fail with a clear message instead of a bare IndexError when none exists.
data_yamls = sorted(
    str(Path(output_folder, fname))
    for fname in os.listdir(output_folder)
    if fname.endswith(".yaml") and "hyp" not in fname
)
if not data_yamls:
    raise FileNotFoundError(
        f"No dataset .yaml file (excluding hyp*.yaml) found in {output_folder}"
    )
data_path = data_yamls[-1]
hyps_path = str(Path(output_folder, "hyp.yaml"))
weights = "/usr/src/app/data_dir/weights/yolov5m.pt"

# Choose folder that will contain the different model runs
project_path = FileChooser('/cephyr/NOBACKUP/groups/snic2021-6-9/models/koster-ml')

# Project-specific information
entity = "koster"
exp_name = "test"
display(project_path)

### Train model with given configuration

In [None]:
# Train the model with the chosen configuration.
# NOTE(review): img_size=[1080, 720] differs from the frame size used when
# preparing the data ((720, 540)) and from the evaluation size (640) —
# confirm this is intended.
if project_path.selected is None:
    # Without a runs folder the training output location is undefined
    raise ValueError(
        "No folder selected for the model runs: pick one in the chooser above."
    )
train.run(
    entity=entity,
    data=data_path,
    hyp=hyps_path,
    weights=weights,
    project=project_path.selected,
    name=exp_name,
    img_size=[1080, 720],
    # Cast both widget values so train.run always receives integers,
    # whatever widget type choose_train_params uses for them
    batch=int(batch_size.value),
    epochs=int(epochs.value),
    workers=4,
    single_cls=False,
    cache_images=False,
)

### Evaluate model performance on test set

In [None]:
# Choose model
# Browse the runs under the runs folder chosen earlier and pick the run to
# evaluate. NOTE(review): project_path.selected must already be set here.
eval_model = FileChooser(project_path.selected)
display(eval_model)

In [None]:
# Locate the best checkpoint of the chosen run: <run>/weights/best.pt.
# If eval_model.selected is an absolute path, pathlib discards the
# project_path prefix and the run path is used as-is.
tuned_weights = str(Path(project_path.selected, eval_model.selected) / "weights" / "best.pt")

In [None]:
# Evaluate the YOLO model on unseen test data (mAP metric)

In [None]:
# Evaluate the tuned weights on the dataset described by data_path, using the
# confidence threshold chosen in the training-parameters widget.
# NOTE(review): imgsz=640 differs from the training size ([1080, 720]) —
# confirm this is intended.
test.run(
    data=data_path,
    weights=tuned_weights,
    conf_thres=conf_thres.value,
    imgsz=640,
)

### Transfer model to web app server (for API use)

In [None]:
import getpass

In [None]:
# Prompt for the server credentials; getpass does not echo the typed input
# (this applies to the username prompt as well)
server_user = getpass.getpass('Enter your server user')
server_pass = getpass.getpass('Enter your server password')

In [None]:
# Upload the chosen model run to the web-app server.
# NOTE(review): the model name "medins_sp_720_test13" and the target
# "koster/medins" are hard-coded — update them for each new model.
t5.transfer_model(
    "medins_sp_720_test13",
    eval_model.selected,
    "koster/medins",
    server_user,
    server_pass,
)

# (Experimental): Enhance annotations using the trained model

In [None]:
# Run the trained model over the prepared frames to propose extra annotations.
# save_txt=True: the run's predicted labels are saved as text files; a later
# cell moves the run's "labels" folder into output_folder.
# NOTE(review): conf_thres is fixed at 0.7 here rather than using
# conf_thres.value — confirm this is intended.
# The source path is built with pathlib rather than string concatenation so
# a trailing separator (or a Path-typed output_folder) is handled correctly.
detect.run(weights=tuned_weights,
           source=str(Path(output_folder, "images")),
           imgsz=720, conf_thres=0.7, save_txt=True,
           project=project_path.selected)

In [None]:
# Choose runs
runs = FileChooser(".")
display(runs)

In [None]:
# Keep the existing labels as labels_org and replace them with the labels
# produced by the selected detection run.
# This is done in Python rather than with `!mv` so that paths containing
# spaces or shell metacharacters are handled safely (no shell is involved).
# shutil.move has the same semantics as mv: it renames, or moves into the
# destination if that is an existing directory.
import shutil

shutil.move(str(Path(output_folder, "labels")), str(Path(output_folder, "labels_org")))
shutil.move(str(Path(runs.selected, "labels")), str(Path(output_folder, "labels")))

In [None]:
#END