In [None]:
from fiftyone import plugins

# Fetch the FiftyOne MLflow plugin from its GitHub repository.
plugins.download_plugin(url_or_gh_repo="https://github.com/voxel51/fiftyone_mlflow_plugin/")

# Install the requirements of the plugin that was just downloaded.
plugins.install_plugin_requirements(plugin_name="@voxel51/mlflow")

## Prepping for Training

Let's kick things off by importing all of our required libraries. While we are at it, we will point MLflow at our tracking server by specifying our `tracking_uri` through the `MLFLOW_TRACKING_URI` environment variable.

In [None]:
import os
import fiftyone as fo
import fiftyone.utils.random as four

# Point MLflow at the tracking server running locally on port 5000.
os.environ.update({"MLFLOW_TRACKING_URI": "http://127.0.0.1:5000"})

### Load dataset

In [None]:
# Load the stored training dataset and work on a clone of it, so the rest of
# the notebook operates on a copy rather than on the loaded dataset itself.
dataset = fo.load_dataset("lecture_dataset_train").clone()

### Stratified sample

In [None]:
from fiftyone import ViewField as F

sample_size = 0.051  # fraction of each class's label count to sample; adjust as needed

# Number of ground-truth detections per class label.
class_counts = dataset.count_values("ground_truth.detections.label")

# Per-class sample budget. int() truncates, so very rare classes can end up with 0.
class_samples = {cls: int(count * sample_size) for cls, count in class_counts.items()}

stratified_sample = fo.Dataset()  # instantiate an empty dataset
stratified_sample.default_classes = dataset.default_classes  # copy the default classes from the original dataset

# IDs (as they exist in `dataset`) of every sample already picked.
# We track these ourselves: samples copied into `stratified_sample` receive new
# IDs, so looking up `stratified_sample.values("id")` can never match the IDs in
# `dataset`, and images containing several classes would be added repeatedly.
selected_ids = set()

for label, sample_count in class_samples.items():
    filter_expression = F("label") == label
    # Samples that contain at least one ground-truth detection of this class.
    class_view = dataset.match_labels(filter=filter_expression, fields="ground_truth", bool=True)
    # Seeded random draw so the selection is reproducible.
    subset_view = class_view.take(sample_count, seed=51)
    drawn_ids = subset_view.values("id")
    # Only add the samples that were not already picked for an earlier class.
    stratified_sample.add_samples(subset_view.exclude(list(selected_ids)))
    selected_ids.update(drawn_ids)

# NOTE: the previous per-iteration `stratified_sample.shuffle(seed=51)` call was
# dropped. `shuffle()` returns a shuffled view instead of reordering the dataset
# in place, and its result was discarded, so it had no effect.

### Filter dataset based on image quality

In [None]:
import fiftyone.operators as foo

# Look up the brightness operator by its plugin URI.
compute_brightness = foo.get_operator("@jacobmarks/image_issues/compute_brightness")

# Run the operator on the stratified sample. The following cells read a
# "brightness" field from this dataset (presumably populated by this call).
compute_brightness(stratified_sample)

In [None]:
# Cutoff for "too bright": three standard deviations above the mean brightness.
brightness_mean = stratified_sample.mean("brightness")
brightness_std = stratified_sample.std("brightness")
too_bright_value = brightness_mean + 3 * brightness_std

# Keep only the samples whose brightness is below the cutoff.
brightness_filter = F("brightness") < too_bright_value
brightness_filtered_view = stratified_sample.match(brightness_filter)

# Materialize the filtered view as its own dataset for the export/training steps.
brightness_filtered_dataset = brightness_filtered_view.clone()

In [None]:
# Open the FiftyOne App on the brightness-filtered dataset for visual inspection.
fo.launch_app(dataset=brightness_filtered_dataset)

### Export FiftyOne Dataset to YOLO format

In [None]:
# load the training config
import os

import yaml

# Location of the training config. Set the TRAINING_CONFIG_PATH environment
# variable to run this notebook on another machine; when it is unset, the
# original hard-coded location is used.
config_path = os.environ.get(
    "TRAINING_CONFIG_PATH",
    '/home/harpreet/workspace/Hands-on-Data-Centric-Visual-AI/training_helpers/training_config.yaml',
)
with open(config_path, 'r') as file:
    training_config = yaml.safe_load(file)

# Fail fast with a clear message if a key used by the later cells is missing
# (an empty YAML file makes safe_load return None).
missing_keys = {"train_split", "val_split", "train_params"} - set(training_config or {})
if missing_keys:
    raise KeyError(f"{config_path} is missing required keys: {sorted(missing_keys)}")

In [None]:
# random_split encodes the splits as sample tags. Clear any "train"/"val" tags
# first so that re-running this cell does not leave samples tagged with both.
brightness_filtered_dataset.untag_samples(["train", "val"])

# Seeded (like the sampling step) so the train/val assignment is reproducible.
four.random_split(
    brightness_filtered_dataset,
    {"train": training_config['train_split'], "val": training_config['val_split']},
    seed=51,
)

In [None]:
# Export only the samples tagged "train" by the random split. The `split`
# argument merely names the output split; it does not select samples, so
# exporting the whole dataset here would put every sample in the train split.
brightness_filtered_dataset.match_tags("train").export(
    export_dir="./model_training/data",
    dataset_type=fo.types.YOLOv5Dataset,
    label_field="ground_truth",
    classes=brightness_filtered_dataset.default_classes,
    split='train'
)

In [None]:
# Export only the samples tagged "val" by the random split. The `split`
# argument merely names the output split; it does not select samples, so
# exporting the whole dataset here would leak all training images into the
# validation split.
brightness_filtered_dataset.match_tags("val").export(
    export_dir="./model_training/data",
    dataset_type=fo.types.YOLOv5Dataset,
    label_field="ground_truth",
    classes=brightness_filtered_dataset.default_classes,
    split='val'
)

### Start the MLflow Server
Before we begin training, we will start our MLflow server locally to serve as the backend for the demo. Open a terminal and run the following command from the same project directory:

```
mlflow server --backend-store-uri model_training/runs/mlflow
```

In [None]:
import fiftyone.operators as foo
from ultralytics import YOLO, settings

# Names used for the training run and for the run logged to MLflow.
EXPERIMENT_NAME = "model_training/brightness_filtered"
RUN_NAME = "run-1"
LABEL_FIELD = "predictions"

# Turn on the MLflow integration in the Ultralytics settings.
settings.update({"mlflow": True})

# Operator from the FiftyOne MLflow plugin used to log the run after training.
log_mlflow_run = foo.get_operator("@voxel51/mlflow/log_mlflow_run")

# Model to train, created from the "yolov8m.pt" weights file.
model = YOLO("yolov8m.pt")

In [None]:
# Train on the exported YOLO dataset. Remaining training arguments come from
# the `train_params` section of the training config.
results = model.train(
    data="./model_training/data/dataset.yaml",
    project=EXPERIMENT_NAME,
    name=RUN_NAME,
    **training_config['train_params'],
)

# Log the run against the dataset it was trained on, using the same
# experiment and run names as the training call above.
log_mlflow_run(
    brightness_filtered_dataset,
    EXPERIMENT_NAME,
    run_name=RUN_NAME,
    predictions_field=LABEL_FIELD,
)

### Evaluate model

In [None]:
# Evaluate the "predictions" field against "ground_truth" on `dataset`, storing
# the per-sample results under the "eval" key and computing mAP.
# Note that this rebinds `results`, replacing the training results from above.
# NOTE(review): no cell visible in this notebook writes a "predictions" field
# to `dataset` -- confirm it is populated before running this cell.
results = dataset.evaluate_detections(
    pred_field="predictions",
    gt_field="ground_truth",
    eval_key="eval",
    compute_mAP=True,
)
