In [1]:
import numpy as np

In [3]:
import kensu_dam.pandas as pd
from kensu_dam.sklearn.linear_model import LogisticRegression

In [5]:
# Location of the raw Titanic CSV that the next cell downloads.
data_url = (
    'https://raw.githubusercontent.com/'
    'datasciencedojo/datasets/master/titanic.csv'
)

In [6]:
from kensu_dam.sklearn.model_selection import train_test_split
# Load the Titanic CSV from data_url. Note that `pd` is kensu_dam.pandas
# (see the import cell), not plain pandas.
df = pd.read_csv(data_url)

In [7]:
# Hold out 33% of the rows for testing; the fixed random_state keeps the split reproducible.
# The whole of df is passed as X, so X_train/X_test still carry the 'Survived' column here.
X_train, X_test, y_train, y_test = train_test_split(
    df,
    df['Survived'],
    test_size=0.33,
    random_state=42,
)

In [8]:
# Peek at the first training rows; at this point X_train still has every column of df.
X_train.head()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
6,7,0,1,"McCarthy, Mr. Timothy J",male,54.0,0,0,17463,51.8625,E46,S
718,719,0,3,"McEvoy, Mr. Michael",male,,0,0,36568,15.5,,Q
685,686,0,2,"Laroche, Mr. Joseph Philippe Lemercier",male,25.0,1,2,SC/Paris 2123,41.5792,,C
73,74,0,3,"Chronopoulos, Mr. Apostolos",male,26.0,1,0,2680,14.4542,,C
882,883,0,3,"Dahlberg, Miss. Gerda Ulrika",female,22.0,0,0,7552,10.5167,,S


- Pclass: passenger class
- Parch: number of parents and children travelling with the passenger

In [9]:
# Predictor columns kept for the model.
feature_cols = [
    'Pclass',
    'Fare',
]

In [10]:
# Keep every row but only the feature_cols columns.
# Note: this rebinds X_train, so the full-width training frame produced by the
# split is no longer reachable under this name afterwards.
X_train = X_train.loc[:, feature_cols]

In [11]:
# Sanity check: the column count should now equal len(feature_cols).
X_train.shape

(596, 2)

In [12]:
# Sanity check: the label vector should have as many entries as X_train has rows.
y_train.shape

(596,)

Instantiate the model

In [13]:
# Build the estimator with no explicit hyperparameters. LogisticRegression here is
# the kensu_dam wrapper imported at the top of the notebook.
logreg = LogisticRegression()

Training it

In [14]:
# Fit on the selected feature columns against the 'Survived' labels from the split.
logreg.fit(X_train, y_train)

LogisticRegression()

We can branch out and save the training split set

In [15]:
# Persist the narrowed training features as CSV. The file is named 'X_train.set'
# and the path is relative to the current working directory.
X_train.to_csv('X_train.set')

## Looking in the logs

In [16]:
# NOTE(review): read_scikit_logs is not defined or imported in any visible cell.
# The gaps in the execution counts suggest its defining cell was removed; restore
# it so the notebook survives Restart Kernel -> Run All.
logs = read_scikit_logs()

In [17]:
# Number of log records that were read.
len(logs)

46

In [18]:
# Keep only the records describing data sources.
datasources = [record for record in logs if record["entity"] == "DATA_SOURCE"]

In [19]:
# Location of the first data source whose location mentions "model".
[
    source["jsonPayload"]["pk"]["location"]
    for source in datasources
    if "model" in source["jsonPayload"]["pk"]["location"]
][0]

'in-mem://model/c4131dc5506138bb7d076ba0321bd307eb81f8dec733a3214030e5c111688f08/in-mem-transformation'

In [20]:
# Payload of the first MODEL record.
[record for record in logs if record["entity"] == "MODEL"][0]["jsonPayload"]

{'pk': {'name': 'SkLearn.LogisticRegression'}}

In [21]:
# Payload of the first MODEL_TRAINING record.
[record for record in logs if record["entity"] == "MODEL_TRAINING"][0]["jsonPayload"]

{'pk': {'modelRef': {'byGUID': 'k-184f2b402bbfef845b1f202c6936de6af75aea7063a405ed844f8d30b474240f'},
  'processLineageRef': {'byGUID': 'k-f07a86f5d60eb791cb6b5669dab0daaa03378a6970d528cb0a830f7efecca011'}}}

In [22]:
# Payload of the first MODEL_METRICS record; kept in `metrics` for the cells that follow.
metrics = [record for record in logs if record["entity"] == "MODEL_METRICS"][0]["jsonPayload"]
metrics

{'pk': {'modelTrainingRef': {'byGUID': 'k-ee00c846f0701a45c690852914ad6b9e27bda2ec47fcd3df0532ff4646ac7acf'},
  'lineageRunRef': {'byGUID': 'k-76675ae5d108bac0a4461f976b6724cbc5a23276e30bbfa4df060c6c4186d240'},
  'storedInSchemaRef': {'byGUID': 'k-1d68981ba39d8010f6ff78050dffa5dacbcd0fa2f63dc0b8136890abb62be2a2'}},
 'metrics': {'train.score': 0.6694630872483222,
  'train.explained_variance': -0.7649714790712623,
  'train.neg_mean_absolute_error': 0.33053691275167785,
  'train.neg_mean_squared_error': 0.33053691275167785,
  'train.neg_mean_squared_log_error': 0.1588074559427612,
  'train.neg_median_absolute_error': 0.0,
  'train.r2': -0.8865911464609948},
 'hyperParamsAsJson': '{"C": 1.0, "class_weight": null, "dual": false, "fit_intercept": true, "intercept_scaling": 1, "l1_ratio": null, "max_iter": 100, "multi_class": "auto", "n_jobs": null, "penalty": "l2", "random_state": null, "solver": "lbfgs", "tol": 0.0001, "verbose": 0, "warm_start": false}'}

In [None]:
# Only the metric values from the MODEL_METRICS payload.
metrics["metrics"]

In [None]:
# `json` is not imported by any earlier visible cell, so import it here to keep
# this cell runnable on a fresh kernel.
import json

# hyperParamsAsJson holds the hyperparameters as a JSON string; decode it for inspection.
json.loads(metrics["hyperParamsAsJson"])

## Strategy

Same as per `pandas`, using interceptors:
* https://github.com/kensuio/dam-client-python/blob/ft%2Fscikit-learn/kensu_dam/sklearn/linear_model.py#L10
* https://github.com/kensuio/dam-client-python/blob/ft%2Fscikit-learn/kensu_dam/sklearn/extractor.py#L80

## Other strategies

* Using API (e.g. Tableau Lineage API)
* Using OpenTracing/OpenTelemetry (rather than graph-based entity logging):
    * JAX-RS: https://github.com/kensuio-oss/jaxrs-sample/blob/master/src/main/java/io/kensu/collector/interceptors/KensuTracingInterceptorFeature.java#L74
    * JDBC:
        * driver: https://github.com/kensuio-oss/java-jdbc/blob/cleaning/src/main/java/io/opentracing/contrib/jdbc/TracingDriver.java#L37
        * stats: https://github.com/kensuio-oss/java-jdbc/blob/cleaning/src/main/java/io/opentracing/contrib/jdbc/TracingResultSet.java#L288