In [None]:
# --- DagsHub repository ---------------------------------------------------------
# Owner (user or organization) of the repository; repository paths/URIs are keyed by it.
DAGSHUB_REPO_OWNER = "chenhaitao405"
# Account used to authenticate against DagsHub (can differ from the repo owner).
DAGSHUB_USER = "chenhaitao405"
DAGSHUB_REPO = "test_dagshub"
DAGSHUB_REPO_BRANCH = "main"
DAGSHUB_FULL_REPO = f"{DAGSHUB_REPO_OWNER}/{DAGSHUB_REPO}"

# --- Model ----------------------------------------------------------------------
# Registered-model name in the MLflow registry (loaded as models:/<MODEL_NAME>/<version>)
# and the name handed to the custom training-callback factory.
MODEL_NAME = "cocotest"
# Pretrained Ultralytics weights used as the starting point for training.
MODEL_BASE = "yolo11m-seg.pt"

# --- MLflow ---------------------------------------------------------------------
# Passed to Ultralytics `model.train(project=...)`.
MLFLOW_PROJECT = "ActiveLearning"
# The tracking server belongs to the repository, so the URI is built from the repo
# owner rather than from the authenticating user.
# NOTE(review): not referenced by the visible cells; dagshub.init(..., mlflow=True)
# configures tracking on its own.
MLFLOW_TRACKING_URI = f"https://dagshub.com/{DAGSHUB_REPO_OWNER}/{DAGSHUB_REPO}.mlflow"

# --- Active learning ------------------------------------------------------------
# NOTE(review): not referenced by the visible cells; presumably limits for a later
# sample-selection step (max images to pick, confidence threshold) — confirm.
MAX_IMAGES = 10
MAX_SCORE = 0.75

In [None]:
from ultralytics import YOLO
import ultralytics
import dagshub
import mlflow
import re

In [None]:
# Connect the notebook to the DagsHub repository; mlflow=True asks the DagsHub client
# to set up MLflow tracking for this repo. The repo is taken from the config constants
# so it is defined in exactly one place.
dagshub.init(repo_owner=DAGSHUB_REPO_OWNER, repo_name=DAGSHUB_REPO, mlflow=True)

In [None]:
def load_latest_model(name):
    """Load the highest-numbered registered version of an MLflow model.

    ``MlflowClient.get_latest_versions`` returns one version *per stage*, so
    indexing its result with ``[0]`` does not reliably give the newest version
    (and that API is deprecated together with model stages). Instead, list all
    versions of the model and pick the largest version number.

    Args:
        name: Registered-model name in the MLflow model registry.

    Returns:
        The ``model`` attribute of the unwrapped pyfunc python model
        (presumably the Ultralytics YOLO model — confirm against the logger).

    Raises:
        ValueError: if no version of ``name`` is registered.
    """
    client = mlflow.MlflowClient()
    versions = client.search_model_versions(f"name='{name}'")
    if not versions:
        raise ValueError(f"No registered versions found for model {name!r}")

    # Versions come back as strings; compare numerically so that "10" > "9".
    latest_version = max(int(mv.version) for mv in versions)
    return load_model(name, latest_version)

def load_model(name, version):
    """Load a specific version of a registered MLflow model.

    Args:
        name: Registered-model name in the MLflow model registry.
        version: Version number (int or str) to load.

    Returns:
        The ``model`` attribute of the unwrapped pyfunc python model.
    """
    # Loading by registry URI needs no explicit MlflowClient.
    model_uri = f"models:/{name}/{version}"
    pyfunc_model = mlflow.pyfunc.load_model(model_uri)
    return pyfunc_model.unwrap_python_model().model

In [None]:
def log_test_set(model_name, dataset, device):
    # Load the latest model from the MLflow registry. This will be the best model from the previous training
    model = load_latest_model(model_name)

    # Run the test set on the model and log the metrics to MLflow
    metrics = model.val(data=dataset, device=device, split='test')
    metrics_dict = {f"test/{re.sub('[()]', '', k)}": float(v) for k, v in metrics.results_dict.items()}
    mlflow.log_metrics(metrics=metrics_dict)

In [None]:
# Swap Ultralytics' integration-callback hook for the project's own callbacks, built
# for MODEL_NAME by a project-local factory.
# NOTE(review): presumably Ultralytics looks up this attribute on
# `ultralytics.utils.callbacks` when a model/trainer is set up, so the patch must run
# before `model.train(...)` below — confirm for the installed Ultralytics version.
# NOTE(review): what the generated callbacks log or register lives in
# data_utils.dagshub_yolo_cb (not visible here). Presumably it registers the trained
# weights under MODEL_NAME, which log_test_set later loads — confirm.
from data_utils.dagshub_yolo_cb import generate_callbacks_fn
ultralytics.utils.callbacks.add_integration_callbacks = generate_callbacks_fn(MODEL_NAME)

In [None]:
# Load the pretrained base weights configured in MODEL_BASE (recommended for training)
# instead of repeating the file name here.
model = YOLO(MODEL_BASE)

device = '0'                  # Ultralytics device selector (CUDA device 0)
dataset = 'custom_yolo.yaml'  # Ultralytics dataset config
EPOCHS = 1                    # NOTE(review): a single epoch looks like a smoke-test value
IMG_SIZE = 640                # training image size passed as `imgsz`

with mlflow.start_run():

    # Train the model; Ultralytics writes outputs under the MLFLOW_PROJECT directory.
    model.train(data=dataset, epochs=EPOCHS, imgsz=IMG_SIZE, device=device, project=MLFLOW_PROJECT)

    # Evaluate the newest registered version of MODEL_NAME on the test split and log
    # its metrics into this same MLflow run.
    log_test_set(MODEL_NAME, dataset, device)