# YOLOv5 모델 학습하기 (배경 이미지 없음)

## 모델 학습 및 평가

In [None]:
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import mlflow
from PIL import Image
from ultralytics import YOLO

# Environment configuration: WORKING_DIRECTORY is the root that training
# outputs are resolved against; MLFLOW_URI is passed to MLflow as the
# tracking URI.
working_dir_env = os.environ.get("WORKING_DIRECTORY")
if working_dir_env is None:
    # Path(None) would fail with an opaque TypeError; name what is missing.
    raise RuntimeError("WORKING_DIRECTORY environment variable is not set")
working_dir = Path(working_dir_env)
mlflow_uri = os.environ.get("MLFLOW_URI")
mlflow.set_tracking_uri(mlflow_uri)

# Reuse the "train_base_model" experiment, creating it on first use.
experiment_name = "train_base_model"
experiment = mlflow.get_experiment_by_name(experiment_name)
if experiment is None:
    experiment_id = mlflow.create_experiment(experiment_name)
    print(f"Experiment {experiment_name} is created")
else:
    experiment_id = experiment.experiment_id

# Training configuration. Keeping these values in one place guarantees that
# what is logged to MLflow matches what is actually used for training.
data_name = 'data/filtered_data_nobg/data.yaml'
model_name = "yolov5su.pt"
epochs = 3
# Artifact sub-directory inside the run; the registered model URI points here.
model_artifact_path = "yolov5_model"

# The `with` block starts the run and ends it on exit, so no explicit
# mlflow.end_run() call is needed.
with mlflow.start_run(experiment_id=experiment_id) as run:
    # Log parameters first so they are recorded even if training fails.
    mlflow.log_params({"epochs": epochs, "model": model_name, "data": data_name})

    # Load the pretrained YOLOv5 weights and train on the dataset.
    model = YOLO(model_name)
    results = model.train(data=data_name, epochs=epochs)

    # Strip the "(B)" suffix from the metric names reported by ultralytics.
    metrics = {key.replace("(B)", ""): value for key, value in results.results_dict.items()}
    mlflow.log_metrics(metrics)

    # NOTE: if results.save_dir is absolute, pathlib's `/` ignores working_dir.
    best_weights = working_dir / results.save_dir / "weights" / "best.pt"
    # Upload the weights under `model_artifact_path` so that the runs:/ URI
    # registered below refers to an artifact that actually exists in this run.
    mlflow.log_artifact(local_path=str(best_weights), artifact_path=model_artifact_path)

    # Register the uploaded weights in the model registry.
    # NOTE(review): this artifact folder holds a raw .pt file, not an
    # MLflow-flavored model (no MLmodel file); wrap it with an MLflow model
    # flavor if the registered version must be loadable via mlflow APIs.
    mlflow.register_model(
        f"runs:/{run.info.run_id}/{model_artifact_path}", "yolov5_model"
    )

## 학습된 모델을 활용하여 예측

In [None]:
%matplotlib inline

save_dir = "../runs/mlflow/detect/train8" # 학습 로그를 확인하여 입력
image_path = "data/filtered_data200/test/images/02_021_02011027_160655908976565_0_jpeg.rf.1468ccfe1571b554f3aeee8fd4914f89.jpg" # 이미지 경로
# image_path = "data/filtered_data200/test/images/Img_058_0191_jpg.rf.cff970fc226e826de067ce5aea4e61fa.jpg" # 이미지 경로
# image_path = "data/KoreanFOOD_Detecting/test/images/Img_119_0004_jpg.rf.f0015227fa22eb225471725f1ff18f91.jpg" # 이미지 경로

model = YOLO(working_dir/save_dir/"weights"/"best.pt")
image = Image.open(image_path)
results = model(image)
image_bgr = cv2.cvtColor(results[0].plot(), cv2.COLOR_RGB2BGR)
plt.imshow(image_bgr)
plt.axis("off")
plt.show()
