In [1]:
from datetime import datetime
from itertools import product
from functools import partial
import operator
from dataclasses import dataclass

import pandas as pd
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier 
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, f1_score
import plotly
import plotly.figure_factory as ff

from flytekit import task, map_task, ImageSpec, Deck, workflow
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote


In [2]:
# Client used further down to submit the workflow and to build its console URL.
# The endpoint, project and domain are hard-coded for the hosted demo cluster.
remote = FlyteRemote(
    Config.for_endpoint("demo.hosted.unionai.cloud"),
    default_project="flytesnacks", 
    default_domain="development",
    # NOTE(review): presumably required so tasks defined inside notebook cells
    # can be shipped to the cluster -- confirm against the flytekit docs.
    interactive_mode_enabled=True,
)

# Container image shared by every task in this notebook (passed as
# `container_image=image` in each @task decorator).
image = ImageSpec(
    name="byoc-sandbox",
    registry="ghcr.io/granthamtaylor",
    builder="default",
    # git is presumably needed so pip can resolve the git+https flytekit
    # requirement listed in `packages`.
    apt_packages=["git"],
    packages=[
        "scikit-learn==1.5.1",
        "pandas==2.2.2",
        "pyarrow",
        "fastparquet",
        "matplotlib",
        "plotly",
        # flytekit pinned to a specific commit rather than a release.
        "git+https://github.com/flyteorg/flytekit.git@60499483afd43585279066389357dfe7cfb544e5",
    ],
)

In [3]:
@dataclass
class SearchSpace:
    """
    Candidate values for each random-forest hyperparameter.

    Every field holds the list of values to try for the identically named
    field of ``Hyperparameters``; ``create_search_grid`` expands these lists
    into their Cartesian product. Field names must therefore stay in sync
    with ``Hyperparameters``.
    """

    # Candidate values for the forest's ``max_depth`` argument.
    max_depth: list[int]
    # Candidate values for ``max_features``; ``None`` is an allowed entry.
    max_features: list[str | None]
    # Candidate values for ``n_estimators``.
    n_estimators: list[int]


@dataclass
class Hyperparameters:
    """
    A single point of the search grid.

    ``train_model`` unpacks ``vars()`` of an instance directly into the
    ``RandomForestClassifier`` constructor, so every field name must be a
    valid keyword argument of that constructor.
    """

    # Passed as ``max_depth`` to the classifier.
    max_depth: int
    # Passed as ``max_features`` to the classifier; may be ``None``.
    max_features: str | None
    # Passed as ``n_estimators`` to the classifier.
    n_estimators: int


In [4]:
@task(container_image=image)
def get_dataframe() -> pd.DataFrame:
    """
    Load scikit-learn's wine dataset and return it as one pandas DataFrame.

    The frame is what ``load_wine(as_frame=True)`` exposes as ``.frame``;
    downstream code expects it to contain a ``target`` column.
    """

    wine = load_wine(as_frame=True)
    return wine.frame


In [5]:
@task(container_image=image)
def create_search_grid(searchspace: SearchSpace) -> list[Hyperparameters]:
    """
    Expand a SearchSpace into every combination of its candidate values.

    Args:
        searchspace: Lists of candidate values, one list per hyperparameter.

    Returns:
        One ``Hyperparameters`` instance per element of the Cartesian product
        of the lists, in ``itertools.product`` order (last field varies
        fastest).
    """

    # Field name -> list of candidates, in declaration order.
    axes = vars(searchspace)
    names = list(axes)

    return [
        Hyperparameters(**dict(zip(names, choice)))
        for choice in product(*axes.values())
    ]


In [6]:
def split(dataframe: pd.DataFrame, test_size: float=0.25) -> tuple[pd.DataFrame, ...]:
    """
    Partition a labelled dataframe into train and test portions.

    The ``target`` column is separated from the remaining feature columns and
    used both as the label and as the stratification key. The random state is
    fixed, so repeated calls on the same dataframe yield the same partition;
    the tasks in this notebook rely on that to agree on the held-out rows.

    Args:
        dataframe: Frame containing feature columns and a ``target`` column.
        test_size: Fraction of rows assigned to the test portion.

    Returns:
        ``X_train, X_test, y_train, y_test`` as produced by
        ``train_test_split``.
    """

    labels = dataframe["target"]
    features = dataframe.drop(columns=["target"])

    return train_test_split(
        features,
        labels,
        test_size=test_size,
        stratify=labels,
        random_state=42,
    )

In [7]:
@task(container_image=image)
def train_model(dataframe: pd.DataFrame, hyperparameters: Hyperparameters) -> RandomForestClassifier:
    """
    Fit a random forest on the training partition of ``dataframe``.

    Args:
        dataframe: Labelled data; partitioned with ``split`` so that the
            held-out rows match those used by the evaluation tasks.
        hyperparameters: Keyword arguments for ``RandomForestClassifier``,
            taken from the instance's attributes.

    Returns:
        The fitted classifier.
    """

    # Only the training partition is used here; the test partition is
    # reserved for compare_model_results / analyze_model.
    X_train, _, y_train, _ = split(dataframe)

    # Seed the forest so that a given grid point always produces the same
    # model and score. The seed is a default only: a ``random_state`` field
    # on the hyperparameters, should one ever be added, takes precedence.
    params = {"random_state": 42, **vars(hyperparameters)}

    model = RandomForestClassifier(**params)
    model.fit(X_train, y_train)

    return model


In [8]:
def plot_confusion_matrix(y_true, y_pred):
    """
    Build an annotated plotly heatmap of the confusion matrix.

    Args:
        y_true: Ground-truth labels (any 1-D sequence accepted by pandas).
        y_pred: Predicted labels, same length as ``y_true``.

    Returns:
        A plotly figure with one annotated cell per (actual, predicted) pair.
    """

    # Use the sorted union of both label sets and hand it to confusion_matrix
    # explicitly, so that the tick labels are guaranteed to line up with the
    # matrix rows/columns. (Order-of-appearance labels from y_true alone can
    # be permuted relative to the matrix, or too few when a class only shows
    # up in the predictions.)
    seen = set(pd.Series(y_true).tolist()) | set(pd.Series(y_pred).tolist())
    labels = sorted(seen)

    array = confusion_matrix(y_true, y_pred, labels=labels)

    # change each element of z to type string for annotations
    z_text = [[str(count) for count in row] for row in array.tolist()]

    # set up figure
    fig = ff.create_annotated_heatmap(array, x=labels, y=labels, annotation_text=z_text, colorscale='Viridis')

    # add title
    fig.update_layout(
        title_text='Confusion Matrix',
        xaxis = dict(title='Predicted Label'),
        yaxis = dict(title='Actual Label')
    )

    # add custom xaxis title
    fig.add_annotation(dict(
        font=dict(color="black",size=14),
        x=0.5,
        y=-0.15,
        showarrow=False,
        text="Predicted value",
        xref="paper",
        yref="paper",
    ))

    # add custom yaxis title
    fig.add_annotation(dict(
        font=dict(color="black",size=14),
        x=-0.35,
        y=0.5,
        showarrow=False,
        text="Real value",
        textangle=-90,
        xref="paper",
        yref="paper"
    ))

    # adjust margins to make room for yaxis title
    fig.update_layout(margin=dict(t=50, l=200))

    # add colorbar
    fig['data'][0]['showscale'] = True

    return fig



In [9]:
@task(container_image=image, enable_deck=True)
def compare_model_results(
    dataframe: pd.DataFrame,
    models: list[RandomForestClassifier],
    hyperparameters: list[Hyperparameters],
) -> RandomForestClassifier:
    """
    Score every candidate model on the held-out partition and pick the best.

    Args:
        dataframe: Labelled data; re-partitioned with ``split`` to recover
            the test rows.
        models: Fitted classifiers, one per grid point.
        hyperparameters: Grid points, positionally aligned with ``models``.

    Returns:
        The model with the highest macro-averaged F1 score on the test
        partition; on ties, the earliest one in ``models``.

    Side effects:
        Prints the score table and publishes it as a "Model Results" deck.
    """

    _, X_test, _, y_test = split(dataframe)

    scores: list[float] = [
        f1_score(y_pred=candidate.predict(X_test), y_true=y_test, average="macro")
        for candidate in models
    ]

    # One row per grid point, best score first.
    df = (
        pd.DataFrame.from_records([vars(point) for point in hyperparameters])
        .assign(scores=scores)
        .sort_values(by='scores', ascending=False)
    )

    print(df)

    Deck("Model Results", df.to_html())

    # max() keeps the first maximal element, so ties go to the lowest index.
    which_best = max(range(len(scores)), key=scores.__getitem__)

    return models[which_best]


In [10]:

@task(container_image=image, enable_deck=True)
def analyze_model(
    model: RandomForestClassifier,
    dataframe: pd.DataFrame,
) -> None:
    """
    Render the confusion matrix of ``model`` on the held-out partition.

    The figure is displayed with ``show()`` and also published as a
    "Confusion Matrix" deck.

    Args:
        model: Fitted classifier to evaluate.
        dataframe: Labelled data; re-partitioned with ``split`` to recover
            the test rows.
    """

    _, X_test, _, y_test = split(dataframe)

    predictions = model.predict(X_test)
    figure = plot_confusion_matrix(y_true=y_test, y_pred=predictions)

    # Displayed through plotly's default renderer.
    figure.show()

    Deck("Confusion Matrix", plotly.io.to_html(figure))


In [11]:
@task(container_image=image)
def create_searchspace() -> SearchSpace:
    """
    Return the fixed hyperparameter search space used by the ``train`` workflow.

    The space spans 4 depths x 2 feature settings x 6 forest sizes.
    """

    depths = [1, 2, 5, 10]
    feature_settings = [None, 'sqrt']
    forest_sizes = [1, 2, 5, 10, 20, 50]

    return SearchSpace(
        max_depth=depths,
        max_features=feature_settings,
        n_estimators=forest_sizes,
    )

In [12]:
@workflow
def train() -> RandomForestClassifier:
    """
    Grid-search a random forest classifier on the wine dataset.

    Steps: build the search space, load the data, expand the grid, train one
    model per grid point via ``map_task``, select the best model, and report
    on it.

    Returns:
        The model selected by ``compare_model_results``.
    """

    space = create_searchspace()

    wine = get_dataframe()

    grid = create_search_grid(searchspace=space)

    # Bind the shared dataframe so the mapped task varies only over the grid.
    fit_one = partial(train_model, dataframe=wine)
    candidates = map_task(fit_one)(hyperparameters=grid)

    winner = compare_model_results(dataframe=wine, models=candidates, hyperparameters=grid)

    analyze_model(model=winner, dataframe=wine)

    return winner


In [13]:
# Submit the `train` workflow through the FlyteRemote client with no inputs.
# The current timestamp is used as the version string, so each submission
# gets a distinct version.
execution = remote.execute(train, inputs={}, version=str(datetime.now()))
# Print the console URL for the execution that was just launched.
print(remote.generate_console_url(execution))


[34mImage ghcr.io/granthamtaylor/byoc-sandbox:8BPl3b_KZaZz26H_iZlrTA not found. building...[0m
[34mRun command: docker image build --tag ghcr.io/granthamtaylor/byoc-sandbox:8BPl3b_KZaZz26H_iZlrTA --platform linux/amd64 --push /var/folders/mf/0mf3w2f16rz60zsc7zfmlmkm0000gn/T/tmp6k696toe [0m


#0 building with "desktop-linux" instance using docker driver

#1 [internal] load build definition from Dockerfile
#1 transferring dockerfile: 1.68kB done
#1 DONE 0.0s

#2 resolve image config for docker-image://docker.io/docker/dockerfile:1.5
#2 DONE 0.6s

#3 docker-image://docker.io/docker/dockerfile:1.5@sha256:39b85bbfa7536a5feceb7372a0817649ecb2724562a38360f4d6a7782a409b14
#3 CACHED

#4 [internal] load .dockerignore
#4 transferring context: 2B done
#4 DONE 0.0s

#5 [internal] load metadata for ghcr.io/astral-sh/uv:0.2.37
#5 ...

#6 [auth] astral-sh/uv:pull token for ghcr.io
#6 DONE 0.0s

#7 [internal] load metadata for docker.io/mambaorg/micromamba:1.5.8-bookworm-slim
#7 DONE 0.4s

#8 [internal] load metadata for docker.io/library/debian:bookworm-slim
#8 DONE 0.4s

#5 [internal] load metadata for ghcr.io/astral-sh/uv:0.2.37
#5 DONE 0.5s

#9 [stage-2  1/10] FROM docker.io/library/debian:bookworm-slim@sha256:ad86386827b083b3d71139050b47ffb32bbd9559ea9b1345a739b14fec2d9ecf
#9 DONE 0.0

https://demo.hosted.unionai.cloud/console/projects/flytesnacks/domains/development/executions/anp62tkvnrw5zgfgfg77


In [14]:
# Call the workflow directly from the notebook kernel (as opposed to the
# remote submission above); the score table shown in this cell's output is
# what compare_model_results prints.
model = train()

    max_depth max_features  n_estimators    scores
47         10         sqrt            50  1.000000
34          5         sqrt            20  1.000000
29          5         None            50  1.000000
40         10         None            20  1.000000
41         10         None            50  1.000000
39         10         None            10  0.979724
36         10         None             1  0.979724
28          5         None            20  0.979724
21          2         sqrt            10  0.979724
20          2         sqrt             5  0.979724
38         10         None             5  0.979724
46         10         sqrt            20  0.979724
26          5         None             5  0.979497
42         10         sqrt             1  0.977143
35          5         sqrt            50  0.977143
32          5         sqrt             5  0.977143
45         10         sqrt            10  0.976498
14          2         None             5  0.959559
23          2         sqrt     