In [None]:
import numpy as np
import torch
from hyperopt import fmin, tpe, hp, Trials, space_eval
from IPython.display import clear_output  # clears the current cell's output
from PIL import Image
from ultralytics import YOLO

# Report the torch version and the compute device that will be used (first CUDA GPU, else CPU).
if torch.cuda.is_available():
    device_label = torch.cuda.get_device_properties(0).name
else:
    device_label = 'CPU'
print(f"Setup complete. Using torch {torch.__version__} ({device_label})")

In [None]:
def objective(
    params,
    data="/tf/data/Mammographies/yolo_data/abnormality_classification",
    weights="yolov8n-cls.pt",
    epochs=10,
    batch=2,
    imgsz=320,
    optimizer="SGD",
):
    """Hyperopt objective: fine-tune a YOLOv8 classifier with ``params`` and score it.

    Args:
        params: hyperparameter sample drawn by hyperopt; must contain ``"dropout"``
            and ``"lr"`` (forwarded to Ultralytics as ``dropout`` and ``lr0``).
        data: dataset root passed to ``model.train``.
        weights: pretrained checkpoint every trial starts from.
        epochs: training epochs per trial (kept short so the search stays affordable).
        batch: training batch size.
        imgsz: input image size.
        optimizer: optimizer name passed to Ultralytics. It must be explicit:
            with Ultralytics' default ``optimizer="auto"`` the trainer ignores ``lr0``
            and chooses its own learning rate, so the ``"lr"`` being tuned would have
            no effect.

    Returns:
        Negated top-1 validation accuracy, because ``fmin`` minimises its objective.
    """
    # Start every trial from the same pretrained checkpoint so trials are comparable.
    model = YOLO(weights)

    model.train(
        data=data,
        task="classify",
        epochs=epochs,
        optimizer=optimizer,
        dropout=params["dropout"],
        lr0=params["lr"],
        batch=batch,
        # NOTE(review): Ultralytics documents `augment` as a prediction-time (TTA) flag;
        # training augmentation is driven by hsv_*/fliplr/... settings. Confirm intent.
        augment=True,
        imgsz=imgsz,
    )

    # Evaluate the trained model; score is top-1 accuracy on the validation split.
    metrics = model.val()
    return -metrics.results_dict["metrics/accuracy_top1"]

# Hyperparameter search space:
#   - dropout: one of three fixed values (hp.choice -> fmin reports the chosen *index*)
#   - lr: log-uniform in [e^-6, e^-1], i.e. roughly [0.0025, 0.37]
space = {
    'dropout': hp.choice('dropout', [0.0, 0.2, 0.4]),
    'lr': hp.loguniform('lr', -6, -1),
}

trials = Trials()

# Run the TPE search. rstate seeds the sampler so the sequence of sampled
# hyperparameters is reproducible across runs.
best = fmin(
    objective,
    space,
    algo=tpe.suggest,
    max_evals=15,
    trials=trials,
    rstate=np.random.default_rng(42),
)

In [None]:
# Compact per-trial summary: decoded hyperparameters (hp.choice indices mapped back to
# values) plus the loss (= negated top-1 validation accuracy), instead of dumping
# hyperopt's full internal trial records.
[
    {
        **space_eval(space, {name: vals[0] for name, vals in trial["misc"]["vals"].items() if vals}),
        "loss": trial["result"].get("loss"),
    }
    for trial in trials.trials
]

In [None]:
# Decode the best trial into concrete hyperparameter values (fmin reports hp.choice
# entries as indices, e.g. dropout=1 instead of 0.2). Decoding from trials.argmin,
# the same value fmin returned, keeps this cell idempotent: re-running it does not
# try to decode an already-decoded dict.
best = space_eval(space, trials.argmin)

In [None]:
# Final training: start again from the pretrained checkpoint and train for the full
# schedule with the best hyperparameters found by the search.
model = YOLO('yolov8n-cls.pt')

# optimizer is set explicitly: with Ultralytics' default optimizer="auto" the trainer
# ignores lr0 and picks its own learning rate, which would discard the tuned best["lr"].
# NOTE(review): `augment` is documented by Ultralytics as a prediction-time (TTA) flag.
model.train(data="/tf/data/Mammographies/yolo_data/abnormality_classification",
            task="classify",
            epochs=100,
            optimizer="SGD",
            dropout=best["dropout"],
            lr0=best["lr"],
            batch=2,
            augment=True,
            imgsz=320)

In [None]:
# Evaluate the final model. No `data` argument is given, so this presumably reuses the
# dataset from the training call above (Ultralytics behaviour — confirm).
metrics = model.val()

In [None]:
# Display the validation metrics object returned by model.val().
metrics

In [None]:
# Export the trained model. No `format` is given, so Ultralytics' default export
# format is used (TorchScript in current releases — confirm for the installed version).
model.export()

In [None]:
# Reload the best checkpoint written by the final training run above. A hard-coded run
# directory such as runs/classify/train3 is fragile: every hyperopt trial also creates
# a runs/classify/trainN folder, so the index of the final run varies.
model_load = YOLO(str(model.trainer.best))

# NOTE(review): this image comes from the *train* split, so this is only a smoke test
# of the inference pipeline, not an estimate of generalisation.
sample_path = "/tf/data/Mammographies/yolo_data/abnormality_classification/train/Distortion/2.png"
with Image.open(sample_path) as raw_image:
    image = raw_image.convert("RGB")

# One Results object per input image; take the only one. Nothing is written to disk
# here (predictions are not saved unless save=True is passed).
results = model_load(image)[0]

# Pair each class name with its predicted probability (probs are indexed by class id,
# the same keys used by results.names).
{results.names[idx]: prob for idx, prob in enumerate(results.probs.data.tolist())}