In [None]:
from pathlib import Path
from typing import Any, cast

from DeepPurpose import utils, CompoundPred
from tdc.single_pred import ADME
from tdc.utils import retrieve_dataset_names

from utils.printing import IPythonPrinter
from multi_model import MultiModel, MultiModelSinglePredictionModelArgument

In [None]:
# Markdown printer used by the dataset-listing cell; output is presumably
# buffered until printer.flush() is called — confirm against IPythonPrinter.
printer: IPythonPrinter = IPythonPrinter()

In [None]:
# All ADME dataset names TDC knows about; cast() only informs the type checker.
adme_datasets: list[str] = cast(list[str], retrieve_dataset_names("ADME"))
dataset_count = len(adme_datasets)
print(f"Number of ADME datasets available from TDC: {dataset_count}")

printer.print_markdown("**Datasets available:**")
# Every item uses the literal "1." prefix; Markdown renumbers ordered lists,
# presumably rendering these as one numbered list once flushed.
for ds_name in adme_datasets:
    printer.print_markdown(f"1. {ds_name}")
printer.flush()

In [None]:
# One MultiModel argument per ADME dataset, each with its own TDC dataloader.
# The model is left as None here; presumably MultiModel.train() fills it in — confirm.
augmented_models = []
for dataset_name in adme_datasets:
    loader = ADME(name=dataset_name)
    augmented_models.append(
        MultiModelSinglePredictionModelArgument(
            name=dataset_name,
            model=None,
            dataloader=loader,
        )
    )
augmented_models

In [None]:
multi_model = MultiModel(augmented_models=augmented_models)
# Echo what MultiModel stored so the registered datasets can be checked by eye.
registered_models = multi_model.augmented_models
print(registered_models)

In [None]:
# Train the MultiModel built above. NOTE(review): presumably this fits one model
# per ADME dataset and may be slow; no random seed is set in this notebook, so
# results may not be reproducible across runs unless MultiModel seeds internally — confirm.
multi_model.train()

In [None]:
def save(multi: "MultiModel", model_dir: "str | Path") -> None:
    """
    Save every sub-model of a MultiModel into a directory.

    Each augmented model is written with its own ``save_model`` method to
    ``model_dir / <augmented_model.name>``.

    :param multi: The MultiModel whose ``augmented_models`` should be saved.
    :param model_dir: The directory to save the models to (``str`` or ``Path``).
        It is created, including any missing parents, if it does not exist.
    :raises RuntimeError: If any augmented model has no model yet (its ``model``
        is ``None``, e.g. because ``train()`` has not been run). Nothing is
        written to disk in that case.
    """
    model_dir = Path(model_dir)  # Accept both str and Path inputs.
    # Snapshot once so validation and saving see the same sequence.
    augmented_models = list(multi.augmented_models)

    # Validate before touching the filesystem so a partially-trained MultiModel
    # does not leave a half-written model directory behind.
    untrained = [str(m.name) for m in augmented_models if m.model is None]
    if untrained:
        raise RuntimeError(
            "Cannot save models that have not been trained (run train() first): "
            + ", ".join(untrained)
        )

    model_dir.mkdir(parents=True, exist_ok=True)
    for augmented_model in augmented_models:
        # The name is used verbatim as a file name. NOTE(review): the original code
        # states it is already sanitized in MultiModel.__init__ — confirm.
        augmented_model.model.save_model(str(model_dir / augmented_model.name))

In [None]:
from pathlib import Path

# NOTE(review): this calls a `save` method on MultiModel, not the standalone
# `save(multi, model_dir)` helper defined in the previous cell — confirm which is intended.
adme_model_dir = Path("../../Data/property_predictors/multi_models/adme")
multi_model.save(model_dir=adme_model_dir)