In [None]:
from pathlib import Path
from pprint import pprint

import dagshub
import pandas as pd
from mlflow.client import MlflowClient
from mlflow.entities import ViewType

In [None]:
# DagsHub repository coordinates used by every cell below.
USER_NAME = 'matejfric'
REPO_NAME = 'driver-tae'

# Initialise DagsHub with its MLflow integration enabled, then create the
# MLflow client that all later queries go through.
# NOTE(review): presumably `mlflow=True` points MLflow at the DagsHub-hosted
# tracking server for this repo -- confirm against the dagshub docs.
dagshub.init(REPO_NAME, USER_NAME, mlflow=True)  # type: ignore
client = MlflowClient()

## Get all experiment runs

In [None]:
# Fetch every experiment whose name is not 'Default' and list their names.
experiments = client.search_experiments(filter_string="name!='Default'")
experiment_names = [exp.name for exp in experiments]
pprint(experiment_names)

In [None]:
# Get all active runs from the experiments.
# `search_runs` returns a single page of results per call, so keep following
# the page token until it is exhausted; otherwise experiments with more runs
# than the page limit would be silently truncated.
all_runs = []
for experiment in experiments:
    page_token = None
    while True:
        runs = client.search_runs(
            experiment_ids=[experiment.experiment_id],
            filter_string='',
            run_view_type=ViewType.ACTIVE_ONLY,
            page_token=page_token,
        )
        all_runs.extend(runs)
        # A paged result exposes the next page's token; a missing or empty
        # token means this was the last page.
        page_token = getattr(runs, 'token', None)
        if not page_token:
            break

# Experiment names are already available from `experiments`; look them up
# locally instead of issuing one `get_experiment` request per run.
experiment_name_by_id = {exp.experiment_id: exp.name for exp in experiments}

# Create a DataFrame from the runs: one row per run, with run metadata,
# every logged parameter as its own column, and every metric as a
# "metric."-prefixed column.
runs_df = pd.DataFrame(
    [
        {
            'run_id': r.info.run_id,
            'experiment_id': r.info.experiment_id,
            'experiment_name': experiment_name_by_id[r.info.experiment_id],
            'status': r.info.status,
            'start_time': pd.to_datetime(r.info.start_time, unit='ms'),
            # Runs without an end time get None.
            'end_time': pd.to_datetime(r.info.end_time, unit='ms')
            if r.info.end_time
            else None,
            'artifact_uri': r.info.artifact_uri,
            **r.data.params,  # Add all parameters
            **{
                f'metric.{k}': v for k, v in r.data.metrics.items()
            },  # Add all metrics with "metric." prefix
        }
        for r in all_runs
    ]
)

# Display the DataFrame
print(f'Total runs: {len(runs_df)}')
runs_df.head()

## Filtering

In [None]:
# Cast the integer-valued parameter columns to an integer dtype.
# The nullable `Int64` dtype is used instead of plain `int` so that runs which
# did not log one of these parameters (NaN in the column) become <NA> rather
# than making the whole cast raise. Values that are not valid integers still
# raise, so bad data is not silently hidden.
integer_columns = ['image_size', 'latent_dim', 'batch_size', 'early_stopping']
runs_df[integer_columns] = (
    runs_df[integer_columns].apply(pd.to_numeric).astype('Int64')
)

In [None]:
# For every (driver, source_type, latent_dim, image_size) combination, keep the
# run with the highest ROC AUC.
# Runs without a logged `metric.roc_auc` are excluded before grouping: a group
# made up only of such runs would make `idxmax` produce NaN (or raise,
# depending on the pandas version), which then breaks the `.loc` lookup.
scored_runs_df = runs_df.dropna(subset=['metric.roc_auc'])
idx = scored_runs_df.groupby(
    ['driver', 'source_type', 'latent_dim', 'image_size']
)['metric.roc_auc'].idxmax()
# `dropna` preserves the original index labels, so `idx` still addresses rows
# of `runs_df`.
best_runs_df = runs_df.loc[idx]

In [None]:
# Narrow the best runs down to the 64-pixel / 128-dimensional-latent setup.
has_image_size_64 = best_runs_df['image_size'] == 64
has_latent_dim_128 = best_runs_df['latent_dim'] == 128
df = best_runs_df[has_image_size_64 & has_latent_dim_128]
df

## Download predictions

In [None]:
# Download `predictions.json` for every best run.
# - The run identifier column is 'run_id' (as created when `runs_df` was
#   built); there is no 'Run ID' column.
# - Each run gets its own sub-directory under 'outputs': every run's artifact
#   has the same file name, so a shared destination directory would make each
#   download overwrite the previous one.
# - The destination directory is created up front because it is expected to
#   exist before downloading.
prediction_paths = {}  # run_id -> local path of the downloaded predictions
for run_id in best_runs_df['run_id']:
    run_output_dir = Path('outputs') / run_id
    run_output_dir.mkdir(parents=True, exist_ok=True)
    local_path = client.download_artifacts(
        run_id, 'predictions.json', str(run_output_dir)
    )
    prediction_paths[run_id] = local_path