In [None]:
# Reload edited project modules (e.g. main.tooling.*) automatically before each cell is executed,
# so changes to the helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from main.tooling.Logger import logging_setup

# Project logger used by all following cells.
# NOTE(review): "setup" is presumably the logger name/label — verify in main.tooling.Logger.
logger = logging_setup("setup")

## Check MLflow Connection

In [None]:
import os

import mlflow
import requests

# Seconds to wait when pinging the tracking server; without a timeout an unreachable
# (e.g. firewalled) host would make this cell hang indefinitely.
MLFLOW_PING_TIMEOUT_S = 10

# Fail fast if the notebook was started without MLflow configuration in the environment.
if not os.getenv("MLFLOW_TRACKING_URI"):
    raise RuntimeError("Mlflow not configured")

tracking_uri = str(mlflow.get_tracking_uri())
try:
    # Only reachability matters here; the response status itself is not checked.
    requests.get(tracking_uri, timeout=MLFLOW_PING_TIMEOUT_S)
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as err:
    # Chain the original error so the underlying network cause stays visible in the traceback.
    raise ConnectionError("mlflow not reachable, please run mlflow server") from err

logger.info("mlflow availiable at %s", tracking_uri)


## Configure Experiment Parameters

In [None]:
# Names used by the MLflow cells below.
experimentName, runName, modelName = (
    "RelevanceClassifier",  # MLflow experiment that is searched (and created if missing)
    "AllDatasets",          # value matched against the logged `project.run_name` param
    "Iteration_1_model",    # artifact sub-path in the run and local model directory name
)

## Check existing MLflow runs

In [None]:
from mlflow.entities import ViewType

# Create the experiment on first use; on later runs it is simply looked up.
experiment = mlflow.get_experiment_by_name(experimentName)
if experiment is None:
    mlflow.create_experiment(experimentName)

# Select the single finished, active run for runName with the highest logged accuracy.
run_filter = (
    "attributes.status = 'FINISHED' "
    f"AND params.project.run_name = '{runName}'"
)
mlflow_runs = mlflow.search_runs(
    experiment_names=[experimentName],
    filter_string=run_filter,
    max_results=1,
    run_view_type=ViewType.ACTIVE_ONLY,
    order_by=["metrics.accuracy DESC"],
)

run_found = not mlflow_runs.empty
logger.info("mlflow run found: %s", run_found)


## Model Loading via a REST request to the remote MLflow server

#### There are problems downloading the fine-tuned model from the remote MLflow server. When the model is downloaded the normal way (see the "## Model Loading via mlflow.download_artifacts() (the normal way)" code cell), only part of the model is downloaded (in our case &sim;130 MB out of 438 MB). Therefore, we download the model files via HTTP requests to the remote MLflow server's artifact endpoint instead.

In [None]:
import requests
from pathlib import Path
from main.tooling.FileManager import getModelPath
from main.tooling.FileManager import cleanup

# Stop with a clear message (instead of a bare IndexError from .iloc[0]) when the search above
# found no matching run, and do so before calling cleanup().
# NOTE(review): cleanup() presumably clears previously downloaded local model files — verify in
# main.tooling.FileManager.
if mlflow_runs.empty:
    raise LookupError(
        f"No finished MLflow run with project.run_name '{runName}' "
        f"found in experiment '{experimentName}'"
    )
runID = mlflow_runs.iloc[0].run_id

cleanup()

def downloadMLflowArtifacts(
    mlflowURLS: list[str],
    artifactNames: list[str],
    timeout: float = 60.0,
    chunkSize: int = 1024 * 1024,
) -> None:
    """
        Description:
            Requests artifacts from the remote MLflow server and saves them in the
            getModelPath(modelName) directory.
            Each response is streamed to disk in chunks, so large files such as
            model.safetensors are never held in memory as a whole. Each file is first written
            to "<name>.part" and only renamed to its final name after the download completed,
            so a failed download does not leave a truncated artifact behind.
        Args:
            mlflowURLS (list[str]): The artifact URLs to request.
            artifactNames (list[str]): File name to save each URL's content under
                (same order and length as mlflowURLS).
            timeout (float): requests timeout in seconds (connect / wait between received bytes).
            chunkSize (int): Number of bytes read from the response per write.
        Returns:
            None: Saves the artifacts from the remote MLflow server in the
            getModelPath(modelName) directory.
        Raises:
            ValueError: If mlflowURLS and artifactNames differ in length.
            requests.HTTPError: If the server answers with an error status.
    """
    if len(mlflowURLS) != len(artifactNames):
        raise ValueError(
            f"Got {len(mlflowURLS)} URLs but {len(artifactNames)} artifact names"
        )
    if not mlflowURLS:
        return

    modelDir = Path(getModelPath(modelName))
    # parents=True also creates missing parent directories; exist_ok=True keeps re-runs working.
    modelDir.mkdir(parents=True, exist_ok=True)

    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    # Only send Basic auth when credentials are configured; otherwise requests would
    # stringify the missing values and send the literal "None" as user/password.
    auth = (username or "", password or "") if (username or password) else None

    for mlflowURL, artifactName in zip(mlflowURLS, artifactNames):
        save_path = os.path.join(modelDir, artifactName)
        tmp_path = f"{save_path}.part"
        try:
            with requests.get(mlflowURL, auth=auth, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as file:
                    for chunk in response.iter_content(chunk_size=chunkSize):
                        file.write(chunk)
            # Atomic rename: the final file only appears once it is complete.
            os.replace(tmp_path, save_path)
        except BaseException:
            # Remove the partial file (also on KeyboardInterrupt), then propagate the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        logger.info("Downloaded %s to %s", artifactName, save_path)

from urllib.parse import urlencode

# Files that make up the fine-tuned model and its tokenizer, stored under `{modelName}/` in the run.
artifactNames = [
    "config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "training_args.bin",
    "vocab.txt",
    "model.safetensors"
]

# One /get-artifact URL per file, generated from artifactNames so both lists always match.
# rstrip("/") avoids a double slash when MLFLOW_TRACKING_URI ends with "/"; urlencode escapes
# the query values (safe="/" keeps the artifact path readable, e.g. "Iteration_1_model/vocab.txt").
trackingURI = os.environ["MLFLOW_TRACKING_URI"].rstrip("/")
mlflowURLS = [
    f"{trackingURI}/get-artifact?"
    + urlencode({"path": f"{modelName}/{artifactName}", "run_uuid": runID}, safe="/")
    for artifactName in artifactNames
]

downloadMLflowArtifacts(mlflowURLS, artifactNames)


## Model Loading via mlflow.download_artifacts() (the normal way)

In [None]:
# NOTE: Intentionally disabled. This is the standard mlflow.artifacts.download_artifacts() path,
# which only downloaded part of the model from the remote server (see the markdown note in the
# "Model Loading via REST request" section). Kept for reference until that issue is resolved;
# uncomment it (and skip the REST download cell) to use it instead.
# from main.tooling.FileManager import getModelPath
# from main.tooling.FileManager import cleanup

# cleanup()

# if not mlflow_runs.empty:
    # Extract run ID
#     default_run_id = mlflow_runs.iloc[0]['run_id']
    
    # Load run with the run ID
#     default_run = mlflow.get_run(default_run_id)
    
#     logger.info("Download Model from mlflow to: %s", getModelPath(""))    

#     mlflow.artifacts.download_artifacts(artifact_uri=
#         f"{default_run.info.artifact_uri}/{modelName}",
#         dst_path=getModelPath(""),
# )

# else:
#     logger.info("The requested model does not exist!")
