In [1]:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score
import pandas as pd
import pickle

# Load iris as pandas objects: `data` is a DataFrame of feature columns and
# `target` is a Series of class labels (see the printed output below the cell).
data, target = load_iris(return_X_y=True, as_frame=True)

print(len(data), type(data), len(target), type(target))
print(data)

# One seed shared by the train/test split and the random forest, so the
# reported metrics are reproducible from run to run.
RANDOM_STATE = 17

# Candidate classifiers, keyed by the short name used both as the metrics
# table index and as the pickle file stem.
models = {
    "rf": RandomForestClassifier(
        n_estimators=100, max_depth=10, random_state=RANDOM_STATE
    ),
    # The default max_iter=100 can stop lbfgs before convergence on these
    # unscaled features (ConvergenceWarning); allow more iterations.
    "lr": LogisticRegression(max_iter=1000),
    "knn": KNeighborsClassifier(n_neighbors=3),
}

train_x, test_x, train_y, test_y = train_test_split(
    data, target, test_size=0.33, random_state=RANDOM_STATE
)

# Fit every model on the training split and score it on the held-out split.
list_metrics = []
for name, model in models.items():
    model.fit(train_x, train_y)
    pred = model.predict(test_x)
    list_metrics.append({
        'accuracy': accuracy_score(test_y, pred),
        # Macro average: unweighted mean of the per-class F1 scores.
        'f1': f1_score(test_y, pred, average='macro'),
    })

# One row per model, in the same order as `models`.
metrics_df = pd.DataFrame(list_metrics, index=list(models))
print(metrics_df)

# Export the full (unsplit) feature matrix and labels for downstream tasks.
SHARED_DIR = '/opt/airflow/shared'
data.to_csv(f'{SHARED_DIR}/iris_x.csv', index=False)
target.to_csv(f'{SHARED_DIR}/iris_y.csv', index=False)

# Serialize each fitted model to `<name>.pickle` in the current working
# directory.
# NOTE(review): unlike the CSVs these are not written under SHARED_DIR —
# confirm whether downstream consumers expect the pickles there.
for name, model in models.items():
    with open(f'{name}.pickle', 'wb') as f:
        pickle.dump(model, f)

150 <class 'pandas.core.frame.DataFrame'> 150 <class 'pandas.core.series.Series'>
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
0                  5.1               3.5                1.4               0.2
1                  4.9               3.0                1.4               0.2
2                  4.7               3.2                1.3               0.2
3                  4.6               3.1                1.5               0.2
4                  5.0               3.6                1.4               0.2
..                 ...               ...                ...               ...
145                6.7               3.0                5.2               2.3
146                6.3               2.5                5.0               1.9
147                6.5               3.0                5.2               2.0
148                6.2               3.4                5.4               2.3
149                5.9               3.0                5.1 