In [9]:
import wandb
import pandas as pd

from pathlib import Path

In [10]:
# Experiment root directory. Keep this identical to the working directory used by the previous notebook.
working_dir = "D:/SLEAP/20250102_generalizability_experiment/primary/sorghum"

In [11]:
# Wrap the experiment directory in a Path object.
# This does NOT change the process's working directory (there is no os.chdir).
# `cwd` is just the experiment root that CSV_PATH is built from below.
cwd = Path(working_dir)
print("Current working directory: " + str(cwd))

Current working directory: D:\SLEAP\20250102_generalizability_experiment\primary\sorghum


In [12]:
# --- W&B configuration ---
ENTITY_NAME = "eberrigan-salk-institute-for-biological-studies"  # W&B entity (user or team)
PROJECT_NAME = "sleap-roots"  # W&B project the training artifacts are fetched from
EXPERIMENT_NAME = "sorghum-primary-pilot01-2025-01-03"  # Unique name for the experiment
CSV_PATH = cwd.joinpath("train_test_splits.csv")  # CSV listing the train/test splits
REGISTRY = "wandb-registry-model"  # Registry the model artifacts are linked into
COLLECTION_NAME = EXPERIMENT_NAME  # Registry collection is named after the experiment

# Tags for the model artifact
MODEL_TAGS = ["sorghum", "primary", "5-12DAG", "pilot01", "2025-01-03"]

In [13]:
def load_training_data(csv_path):
    """Read the train/test split table from a CSV file.

    Args:
        csv_path (Path): Location of the CSV file holding the training data.

    Returns:
        pandas.DataFrame: The parsed contents of the CSV file.
    """
    training_df = pd.read_csv(csv_path)
    return training_df

def get_training_groups(df):
    """Split the training table into one group per distinct ``version`` value.

    Args:
        df (pandas.DataFrame): Training data; must contain a ``version`` column.

    Returns:
        pandas.core.groupby.DataFrameGroupBy: Rows of ``df`` grouped by ``version``.
    """
    grouped_by_version = df.groupby(by="version")
    return grouped_by_version

def fetch_model_artifact_from_experiment(project_name, entity_name, artifact_name, experiment_version=None):
    """Fetch and download a specific version of a model artifact from a W&B project.

    A short-lived W&B run with ``job_type="fetch_artifact"`` is opened to consume the
    artifact. The run is finished even if fetching or downloading raises.

    Args:
        project_name (str): Name of the W&B project.
        entity_name (str): Name of the W&B entity.
        artifact_name (str): Name of the artifact to fetch, without a ``:version`` suffix.
        experiment_version (str | int, optional): Version/alias appended as
            ``artifact_name:experiment_version``. ``None`` (the default) fetches
            ``:latest``. Falsy values such as ``0`` are used as given, not treated as "latest".

    Returns:
        wandb.Artifact: The fetched artifact. Its files have been downloaded to the
        directory reported in the printed message.
    """
    # Compare against None explicitly: the old truthiness test mapped version 0 to ":latest".
    artifact_version = f":{experiment_version}" if experiment_version is not None else ":latest"
    artifact_ref = f"{artifact_name}{artifact_version}"

    # The context manager finishes the run on exit, including when use_artifact/download
    # raises, so a failed fetch does not leave an open run in the notebook kernel.
    with wandb.init(project=project_name, entity=entity_name, job_type="fetch_artifact") as run:
        artifact = run.use_artifact(artifact_ref)
        artifact_dir = artifact.download()
        print(f"Fetched artifact '{artifact_ref}' to directory '{artifact_dir}'.")
    return artifact



def fetch_model_artifact_and_link_to_registry(project_name, entity_name, artifact_name, registry_name, collection_name, experiment_version=None):
    """Fetch a specific version of a model artifact from a W&B project and link it to a registry collection.

    A short-lived W&B run with ``job_type="fetch_artifact"`` is opened to consume and link
    the artifact. The run is finished even if fetching or linking raises.

    Args:
        project_name (str): Name of the W&B project.
        entity_name (str): Name of the W&B entity.
        artifact_name (str): Name of the artifact to fetch, without a ``:version`` suffix.
        registry_name (str): Name of the registry to link the artifact to
            (e.g. ``"wandb-registry-model"``).
        collection_name (str): Name of the registry collection that stores the model artifact.
        experiment_version (str | int, optional): Version/alias appended as
            ``artifact_name:experiment_version``. ``None`` (the default) fetches
            ``:latest``. Falsy values such as ``0`` are used as given, not treated as "latest".
    """
    # Compare against None explicitly: the old truthiness test mapped version 0 to ":latest".
    artifact_version = f":{experiment_version}" if experiment_version is not None else ":latest"
    artifact_ref = f"{artifact_name}{artifact_version}"
    full_registry_name = f"{entity_name}/{registry_name}/{collection_name}"

    # The context manager finishes the run on exit, including when use_artifact or
    # link_artifact raises, so no open run is left behind in the notebook kernel.
    with wandb.init(project=project_name, entity=entity_name, job_type="fetch_artifact") as run:
        artifact = run.use_artifact(artifact_ref)
        # Run.link_artifact takes the destination as its second (`target_path`) argument,
        # in the form "{entity}/{project}/{collection}"; it has no `registry` keyword,
        # so the previous `registry=` call raised TypeError.
        run.link_artifact(artifact, full_registry_name)
        print(f"Linked artifact '{artifact_ref}' to registry '{full_registry_name}'.")

In [14]:
def main(csv_path):
    """Link every trained model artifact listed in the train/test split CSV to the W&B registry.

    For each distinct ``version`` in the CSV, the matching training artifact is fetched
    and linked into ``COLLECTION_NAME`` of ``REGISTRY``.

    Args:
        csv_path (Path): Path to the CSV file containing train-test splits paths.
    """
    splits_by_version = get_training_groups(load_training_data(csv_path))

    for version, group in splits_by_version:
        print(f"Processing version {version}...")
        print(f"Group: {group}")

        # Artifact names are derived from the experiment name plus the split version.
        # NOTE(review): the literal "v00" prefix gives "v0010" for version 10 rather than
        # "v010"; confirm this matches how the training notebook names its artifacts.
        artifact_name = f"{EXPERIMENT_NAME}_training_v00{version}"

        # The split version is also passed through as the artifact version/alias to fetch.
        fetch_model_artifact_and_link_to_registry(
            PROJECT_NAME,
            ENTITY_NAME,
            artifact_name,
            REGISTRY,
            COLLECTION_NAME,
            experiment_version=version,
        )