## 2. Define a TestSuite

In the second phase of SDMT, we define a `TestSuite` that represents the tests the completed model must pass in order to be acceptable for use in the system into which it will be integrated.

#### Initialize MLTE Context

MLTE keeps a global context that manages the currently active _session_. Initializing it tells MLTE which model and version the artifacts belong to (`set_context`) and where to store all of the artifacts it produces (`set_store`).
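A "global context" can be pictured as process-wide state that later calls read implicitly, so you do not have to pass the model name or store location to every function. The following is a hypothetical plain-Python illustration of that idea, not MLTE's actual implementation; the names `set_active_context`, `set_active_store`, and `current_context` are invented so they do not shadow the real `mlte.session` functions if run in the same notebook.

```python
# Hypothetical sketch of a process-wide "global context"; NOT MLTE's implementation.
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Context:
    """Holds the active model name, version, and artifact store URI."""

    model: Optional[str] = None
    version: Optional[str] = None
    store_uri: Optional[str] = None


# Module-level singleton: every caller in the process sees the same object.
_CONTEXT = _Context()


def set_active_context(model: str, version: str) -> None:
    """Record which model and version subsequent artifacts belong to."""
    _CONTEXT.model = model
    _CONTEXT.version = version


def set_active_store(store_uri: str) -> None:
    """Record where subsequent artifacts should be written."""
    _CONTEXT.store_uri = store_uri


def current_context() -> _Context:
    """Return the active context, failing loudly if it was never initialized."""
    if _CONTEXT.model is None or _CONTEXT.store_uri is None:
        raise RuntimeError("Call set_active_context() and set_active_store() first.")
    return _CONTEXT


set_active_context("IrisClassifier", "0.0.1")
set_active_store("local:///tmp/store")
ctx = current_context()
print(ctx.model, ctx.version, ctx.store_uri)
```

The design choice this illustrates is convenience over explicitness: once the context is set, any later save operation can look it up without extra arguments, which is why it must be initialized before any artifacts are created.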

```python
import os
from mlte.session import set_context, set_store

# Create the folder that will hold the artifact store, if it is not there yet.
store_path = os.path.join(os.getcwd(), "store")
os.makedirs(store_path, exist_ok=True)

# Set the model name and version, then point MLTE at the local store.
set_context("IrisClassifier", "0.0.1")
set_store(f"local://{store_path}")
```
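The store URI is simply the `local://` prefix followed by the absolute path of the store folder. The standard-library check below shows what that string looks like to a generic URL parser and that `exist_ok=True` lets the cell be re-run safely. It assumes a POSIX-style absolute path (on Windows the drive letter would land in the URI's network-location part) and uses a temporary directory and the illustrative names `demo_store_path` and `demo_store_uri` so it does not overwrite the notebook's own `store_path`; it makes no claim about how MLTE itself parses the URI.

```python
import os
import tempfile
from urllib.parse import urlparse

# Use a throwaway directory; the notebook itself uses os.getcwd().
demo_store_path = os.path.join(tempfile.mkdtemp(), "store")

# exist_ok=True makes the call idempotent: the second call does not raise.
os.makedirs(demo_store_path, exist_ok=True)
os.makedirs(demo_store_path, exist_ok=True)

# Same construction as in the cell above: scheme prefix + absolute path.
demo_store_uri = f"local://{demo_store_path}"
parsed = urlparse(demo_store_uri)

print(demo_store_uri)
print(parsed.scheme)  # local
print(parsed.path == demo_store_path)  # True for POSIX absolute paths
```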

#### Build a `TestSuite`

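In MLTE, the tests required to check each of the requirements are collected in a `TestSuite`. Each `TestCase` in the suite is tied to one or more quality attribute scenarios (`qas_list`) and to a `Validator` that decides whether the collected evidence passes. Note that a new `Evidence` type, `ConfusionMatrix`, was created for this example (in the local `confusion_matrix` module) to simplify the definition of the `Validator` for the confusion matrix test case.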
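The custom `ConfusionMatrix` type itself is not shown in this section. As a rough illustration of what a check like `misclassification_count_less_than(2)` plausibly computes, here is a hypothetical plain-Python sketch: the misclassification count is taken to be the sum of the off-diagonal cells of the matrix. The functions below and the example matrix are invented for illustration and are not the notebook's actual implementation.

```python
# Hypothetical stand-in for the custom ConfusionMatrix check; the real class
# lives in the notebook's local confusion_matrix module and may differ.
from typing import Callable, List

Matrix = List[List[int]]


def misclassification_count(matrix: Matrix) -> int:
    """Sum of off-diagonal cells, i.e. samples whose predicted class is wrong."""
    return sum(
        count
        for i, row in enumerate(matrix)
        for j, count in enumerate(row)
        if i != j
    )


def misclassification_count_less_than(threshold: int) -> Callable[[Matrix], bool]:
    """Build a pass/fail predicate over a confusion matrix."""

    def check(matrix: Matrix) -> bool:
        return misclassification_count(matrix) < threshold

    return check


# Invented 3-class matrix (Iris has three species); one sample is misclassified.
example = [
    [10, 0, 0],
    [0, 9, 1],
    [0, 0, 10],
]
passes_check = misclassification_count_less_than(2)
print(misclassification_count(example))  # 1
print(passes_check(example))  # True
```

Returning a predicate from a factory mirrors the pattern in the suite, where the threshold is fixed when the `TestCase` is defined and the evidence is only supplied later.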

```python
from mlte.tests.test_case import TestCase
from mlte.tests.test_suite import TestSuite

from mlte.measurement.storage import LocalObjectSize
from mlte.measurement.cpu import LocalProcessCPUUtilization
from mlte.measurement.memory import LocalProcessMemoryConsumption
from mlte.evidence.types.real import Real
from mlte.evidence.types.image import Image

# Custom Evidence type defined in a local module (confusion_matrix.py) for this example.
from confusion_matrix import ConfusionMatrix

spec = TestSuite(
    test_cases=[
        # Accuracy must be at least 0.98.
        TestCase(
            identifier="accuracy",
            goal="Understand if the model is useful for this case",
            qas_list=["qas1"],
            validator=Real.greater_or_equal_to(0.98),
        ),
        # Fewer than 2 misclassifications, using the custom ConfusionMatrix type.
        TestCase(
            identifier="confusion matrix",
            goal="Understand if the model is useful for this case",
            qas_list=["qas2"],
            validator=ConfusionMatrix.misclassification_count_less_than(2),
        ),
        # An image with no automatic pass/fail: the info is recorded for manual inspection.
        TestCase(
            identifier="class distribution",
            goal="Understand if the model is useful for this case",
            qas_list=["qas3"],
            validator=Image.register_info("Inspect the image."),
        ),
        # The resource checks below take their validators from the evidence type
        # that each measurement produces (get_output_type()).
        TestCase(
            identifier="model size",
            goal="Check resource consumption",
            qas_list=["qas4"],
            validator=LocalObjectSize.get_output_type().less_than(3000),
        ),
        TestCase(
            identifier="training memory",
            goal="Check resource consumption",
            qas_list=["qas4"],
            validator=LocalProcessMemoryConsumption.get_output_type().average_consumption_less_than(
                60000
            ),
        ),
        TestCase(
            identifier="training cpu",
            goal="Check resource consumption",
            qas_list=["qas4"],
            validator=LocalProcessCPUUtilization.get_output_type().max_utilization_less_than(
                5.0
            ),
        ),
    ]
)

# Persist the suite to the store configured above; force=True allows
# overwriting a previously saved copy.
spec.save(parents=True, force=True)
```
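The three `qas4` resource checks differ in what they compare against their thresholds: the model size is a single value, the memory check looks at the *average* consumption, and the CPU check looks at the *maximum* utilization. The plain-Python analogue below makes that difference concrete with invented sample values; it is only an illustration and does not call MLTE. It also assumes, for illustration, that a scenario counts as satisfied only when all of its test cases pass.

```python
# Hypothetical plain-Python analogue of the three "qas4" resource checks.
# All measurement values are invented; in MLTE they come from the measurements.
from statistics import mean
from typing import Dict, List

model_size = 2500.0  # invented single value
memory_samples: List[float] = [41000.0, 52000.0, 63000.0, 47000.0]  # one spike above 60000
cpu_samples: List[float] = [1.5, 3.2, 4.8, 2.0]

results: Dict[str, bool] = {
    # "model size": a single value compared against the threshold.
    "model size": model_size < 3000,
    # "training memory": only the average must stay below 60000, so the spike is tolerated.
    "training memory": mean(memory_samples) < 60000,
    # "training cpu": the peak must stay below 5.0, so one high sample would fail it.
    "training cpu": max(cpu_samples) < 5.0,
}

for identifier, passed in results.items():
    print(f"{identifier}: {'pass' if passed else 'fail'}")

# Assumption: qas4 is satisfied only if every test case linked to it passes.
qas4_satisfied = all(results.values())
print(f"qas4 satisfied: {qas4_satisfied}")
```

Here the memory samples average 50750, so the memory check passes even though one sample exceeds 60000; a single CPU sample of 5.0 or more would fail the CPU check.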