In [None]:
##setup
#!pip install -Iq kubernetes==10.0.1
#!curl -LO "https://dl.k8s.io/release/v1.21.5/bin/darwin/amd64/kubectl"; chmod u+x kubectl
#!./kubectl version
#!./kubectl apply -k "github.com/kubeflow/training-operator/manifests/overlays/standalone?ref=v1.3.0"
#!./kubectl apply -f https://raw.githubusercontent.com/kubernetes/dashboard/v2.6.0/aio/deploy/recommended.yaml
!cp ~/.kube/config ./include/.kube/config


In [None]:
# Collect Snowflake credentials interactively and persist the pipeline configuration to
# ./include/state.json, which later cells read back.
import getpass
import json
import os

password = getpass.getpass('Enter password: ')
account = getpass.getpass('Enter account: ')

state_dict = {
    "connection_parameters": {
        "password": password,
        "user": 'jack',
        "account": account,
        "role": 'PUBLIC',
        "database": 'CITIBIKEML_jack',
        "schema": 'DEMO',
    },
    "compute_parameters": {"default_warehouse": "XSMALL_WH"},
    'feature_table_name': 'FEATURE_03A08400_EE3C_11EC_A5EE_ACDE48001122',
    'pred_table_name': 'PRED_03A08400_EE3C_11EC_A5EE_ACDE48001122',
    'model_file_name': 'forecast_model.zip',
    'le_file_name': 'label_encoders.pkl',
    'cat_cols': ['STATION_ID', 'HOLIDAY'],
    'k8s_namespace': 'citibike',
    'train_image': 'docker.io/mpgregor/airkube:latest',
    'train_job_name': 'citibike-train',
}

STATE_FILE = './include/state.json'
os.makedirs(os.path.dirname(STATE_FILE), exist_ok=True)

# The state file contains the plaintext password: create it readable/writable by the
# owner only, and tighten the mode too in case the file already existed.
state_fd = os.open(STATE_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(state_fd, 'w') as sdf:
    json.dump(state_dict, sdf)
os.chmod(STATE_FILE, 0o600)

In [None]:
%%writefile training/load_train.py
import argparse

def load_and_encode(state_dict):
    """Load the feature table from Snowflake and label-encode its categorical columns.

    Args:
        state_dict: pipeline state. Reads 'connection_parameters',
            'compute_parameters'['default_warehouse'], 'feature_table_name',
            'cat_cols' and 'le_file_name'. Mutated in place: 'num_cols' is set to
            the list of non-categorical column names.

    Returns:
        (state_dict, feature_df): feature_df is indexed by DATE (parsed with
        pd.to_datetime) and its categorical columns are replaced by integer codes.

    Side effects:
        If the encoders stored at state_dict['le_file_name'] cannot be loaded or
        applied, new encoders are fitted on this data and pickled to that path,
        overwriting any existing file.
    """
    from snowflake import snowpark as snp
    import pandas as pd
    from sklearn.preprocessing import LabelEncoder
    from collections import defaultdict
    import pickle

    session = snp.Session.builder.configs(state_dict['connection_parameters']).create()
    try:
        session.use_warehouse(state_dict['compute_parameters']['default_warehouse'])
        feature_df = session.table(state_dict['feature_table_name']).to_pandas()
    finally:
        # Close even when the query fails so the connection is not leaked.
        session.close()

    feature_df['DATE'] = pd.to_datetime(feature_df['DATE'])
    feature_df = feature_df.set_index('DATE')

    cat_cols = state_dict['cat_cols']
    # A plain list of column names (the previous value was a list wrapping a set, which
    # is neither a usable column selector nor JSON-serializable).
    state_dict['num_cols'] = [col for col in feature_df.columns if col not in cat_cols]

    try:
        # NOTE: pickle.load can execute arbitrary code; only load encoder files written
        # by this pipeline.
        with open(state_dict['le_file_name'], 'rb') as fh:
            encoders = pickle.load(fh)
        feature_df[cat_cols] = feature_df[cat_cols].apply(lambda x: encoders[x.name].transform(x))
    except (OSError, EOFError, pickle.UnpicklingError, ImportError,
            AttributeError, KeyError, TypeError, ValueError):
        # Stored encoders are missing, unreadable or cannot encode this data (e.g. labels
        # they have not seen): fit fresh ones. Unlike a bare `except:`, this does not
        # swallow KeyboardInterrupt/SystemExit.
        encoders = defaultdict(LabelEncoder)
        feature_df[cat_cols] = feature_df[cat_cols].apply(lambda x: encoders[x.name].fit_transform(x))

        with open(state_dict['le_file_name'], 'wb') as fh:
            pickle.dump(encoders, fh)

    return state_dict, feature_df

def train_and_save(state_dict, feature_df):
    """Train a TabNet regressor on COUNT and save it to state_dict['model_file_name'].

    Args:
        state_dict: pipeline state. Reads 'cat_cols' and 'model_file_name', plus the
            optional 'max_epochs' (default 1) and 'patience' (default 100). Mutated in
            place: 'cat_idxs' and 'cat_dims' are set (plain ints, JSON-serializable).
        feature_df: encoded features with a 'COUNT' target, a 'STATION_ID' column and
            'DATE' as a column or index level. NOTE: sorted by DATE *in place*, so the
            caller sees the new row order afterwards.

    Returns:
        state_dict.
    """
    import os
    from pytorch_tabnet.tab_model import TabNetRegressor

    feature_df.sort_values(by='DATE', ascending=True, inplace=True)

    # Per station, the last 365 rows (after the DATE sort) are held out for validation.
    # NOTE(review): negative n in GroupBy.head needs a recent pandas (1.4+, verify);
    # older versions return an empty frame here.
    by_station = feature_df.groupby('STATION_ID')
    train_df = by_station.head(-365)
    valid_df = by_station.tail(365)

    feature_cols = feature_df.drop(columns=['COUNT'])
    state_dict['cat_idxs'] = [int(feature_cols.columns.get_loc(col)) for col in state_dict['cat_cols']]
    # An embedding needs one slot per code value, i.e. max code + 1. This equals nunique()
    # when the encoders were fitted on exactly this data, but not when they were loaded
    # from an earlier fit whose codes are not all present here.
    state_dict['cat_dims'] = [int(feature_cols[col].max()) + 1 for col in state_dict['cat_cols']]

    y_train = train_df['COUNT'].values.reshape(-1, 1)
    X_train = train_df.drop(columns='COUNT').values

    y_valid = valid_df['COUNT'].values.reshape(-1, 1)
    X_valid = valid_df.drop(columns='COUNT').values

    model = TabNetRegressor(cat_idxs=state_dict['cat_idxs'], cat_dims=state_dict['cat_dims'])

    model.fit(
        X_train, y_train,
        eval_set=[(X_valid, y_valid)],
        max_epochs=state_dict.get('max_epochs', 1),
        patience=state_dict.get('patience', 100),
        batch_size=2048,
        virtual_batch_size=256,
        num_workers=0,
        drop_last=True)

    # save_model appends '.zip' itself, so pass the file name without its extension.
    # splitext (unlike split('.')[0]) keeps directories or extra dots in the path intact.
    model.save_model(os.path.splitext(state_dict['model_file_name'])[0])

    return state_dict

def pred(state_dict, feature_df):
    """Predict COUNT for every row of feature_df with the saved TabNet model.

    Args:
        state_dict: pipeline state. Reads 'cat_idxs', 'cat_dims' and
            'model_file_name' (path of the saved model archive).
        feature_df: encoded features including the 'COUNT' column, which is
            dropped before prediction. Not modified.

    Returns:
        (state_dict, pred_df): pred_df is a deep copy of feature_df with an extra
        integer 'PRED' column (predictions rounded to the nearest integer).
    """
    from pytorch_tabnet.tab_model import TabNetRegressor

    model = TabNetRegressor(cat_idxs=state_dict['cat_idxs'], cat_dims=state_dict['cat_dims'])
    model.load_model(state_dict['model_file_name'])

    pred_df = feature_df.copy(deep=True)

    # TabNet's predict() takes a numpy array; no torch tensor conversion is needed.
    features = feature_df.drop(columns=['COUNT']).values
    # The regressor returns shape (n_rows, 1); flatten it so it fits a single column.
    pred_df['PRED'] = model.predict(features).ravel().round().astype('int')

    return state_dict, pred_df

def forecast(state_dict, feature_df, forecast_df):
    """Draft of a recursive multi-step forecast; not called anywhere in this module.

    Intended loop: for each of state_dict['forecast_steps'] steps, build a feature row
    for the day after the last row (lagged COUNT values for each lag in
    state_dict['lag_values'], plus that day's row from forecast_df), predict it, and
    append the prediction so later steps can use it as a lag.

    NOTE(review): the body uses names that are neither parameters, locals nor imports
    of this module: `df`, `timedelta`, `model`, `np` and `pred_df`. `feature_df` is
    never used. Unless module-level globals with those names happen to exist (the
    `__main__` block binds `pred_df`), every call raises NameError. It raises KeyError
    if 'lag_values' or 'forecast_steps' is missing from state_dict. Presumably `df` is
    meant to be a frame with DATE and COUNT columns and `model` a loaded
    TabNetRegressor; TODO confirm before enabling.
    """

    if len(state_dict['lag_values']) > 0:
        for step in range(state_dict['forecast_steps']):
            #station_id = df.iloc[-1]['STATION_ID']
            # Forecast the day after the last row currently in df.
            future_date = df.iloc[-1]['DATE']+timedelta(days=1)
            # COUNT from (lag - 1) rows before the last row; that is `lag` days before
            # future_date only if df has exactly one row per day (assumption).
            lags=[df.shift(lag-1).iloc[-1]['COUNT'] for lag in state_dict['lag_values']]
            # forecast_df rows whose DATE equals future_date formatted as 'YYYY-MM-DD'.
            # The local name shadows this function's own name inside the loop.
            forecast=forecast_df.loc[forecast_df['DATE']==future_date.strftime('%Y-%m-%d')]
            # First matching row without DATE, as a list; IndexError if there is no match.
            forecast=forecast.drop(labels='DATE', axis=1).values.tolist()[0]
            features=[*lags, *forecast]
            # Single-row prediction rounded to an int; the local name shadows pred().
            pred=round(model.predict(np.array([features]))[0][0])
            # Layout [DATE, COUNT, *features, PRED]: the prediction is also written as
            # COUNT so the next step can lag on it.
            row=[future_date, pred, *features, pred]
            # Appends a row only if df has a default 0..n-1 index (assumption).
            df.loc[len(df)]=row

    return state_dict, pred_df

def decode_and_write(state_dict, pred_df):
    """Map encoded categorical columns back to labels and write predictions to Snowflake.

    Args:
        state_dict: pipeline state. Reads 'le_file_name' (pickled encoders),
            'cat_cols', 'connection_parameters',
            'compute_parameters'['default_warehouse'] and 'pred_table_name'.
        pred_df: predictions with label-encoded categorical columns. Not modified:
            a decoded copy is written, so calling this twice on the same frame is safe.

    Returns:
        state_dict.

    Side effects:
        Overwrites the table state_dict['pred_table_name'].
    """
    from snowflake import snowpark as snp
    import pickle

    cat_cols = state_dict['cat_cols']

    # NOTE: pickle.load can execute arbitrary code; only load encoder files written by
    # this pipeline.
    with open(state_dict['le_file_name'], 'rb') as fh:
        encoders = pickle.load(fh)

    # Decode a copy: the original decoded the caller's frame in place, so a second call
    # would try to inverse-transform labels that were already decoded.
    decoded_df = pred_df.copy()
    decoded_df[cat_cols] = decoded_df[cat_cols].apply(lambda x: encoders[x.name].inverse_transform(x))

    session = snp.Session.builder.configs(state_dict['connection_parameters']).create()
    try:
        session.use_warehouse(state_dict['compute_parameters']['default_warehouse'])
        # NOTE(review): frames produced by pred() are indexed by DATE (set in
        # load_and_encode). Confirm whether create_dataframe writes the index; if it
        # does not, call reset_index() first so DATE reaches the output table.
        session.create_dataframe(decoded_df).write.mode('overwrite').save_as_table(state_dict['pred_table_name'])
    finally:
        # Close even if the write fails so the connection is not leaked.
        session.close()

    return state_dict

if __name__ == '__main__':
    import os

    # Command-line interface used by the training container.
    parser = argparse.ArgumentParser(description='airkube training')
    # The password may also come from SNOWFLAKE_PASSWORD, so it need not appear on the
    # command line (where process listings and job specs expose it).
    parser.add_argument('--password', type=str, default=os.environ.get('SNOWFLAKE_PASSWORD'))
    parser.add_argument('--account', type=str)
    parser.add_argument('--username', type=str)
    parser.add_argument('--role', type=str)
    parser.add_argument('--database', type=str)
    parser.add_argument('--schema', type=str)
    parser.add_argument('--feature_table_name', type=str)
    parser.add_argument('--pred_table_name', type=str)
    parser.add_argument('--warehouse', type=str, default='XSMALL_WH')

    args = parser.parse_args()

    state_dict = {
        "connection_parameters": {
            "password": args.password,
            "user": args.username,
            "account": args.account,
            "role": args.role,
            "database": args.database,
            "schema": args.schema,
        },
        "compute_parameters": {"default_warehouse": args.warehouse},
        'feature_table_name': args.feature_table_name,
        'pred_table_name': args.pred_table_name,
        'model_file_name': 'forecast_model.zip',
        'le_file_name': 'label_encoders.pkl',
        "cat_cols": ['STATION_ID', 'HOLIDAY'],
    }

    # Thread the returned state through every stage; each stage adds keys that later
    # stages read (e.g. cat_idxs/cat_dims from training are needed by pred).
    state_dict, feature_df = load_and_encode(state_dict)
    state_dict = train_and_save(state_dict, feature_df)
    state_dict, pred_df = pred(state_dict, feature_df)
    state_dict = decode_and_write(state_dict, pred_df)



In [None]:
# Run the training pipeline in-process. The %%writefile cell saved the module as
# training/load_train.py, so import it from the `training` directory (the notebook's
# working directory is on sys.path).
import importlib
import json

import training.load_train as load_train

# Pick up edits made by re-running the %%writefile cell without restarting the kernel.
load_train = importlib.reload(load_train)

with open('./include/state.json') as sdf:
    state_dict = json.load(sdf)

# Each stage returns the (updated) state; pass it on so later stages see keys such as
# cat_idxs/cat_dims added during training.
state_dict, feature_df = load_train.load_and_encode(state_dict)
state_dict = load_train.train_and_save(state_dict, feature_df)
state_dict, pred_df = load_train.pred(state_dict, feature_df)
state_dict = load_train.decode_and_write(state_dict, pred_df)

In [None]:
# Check the output: preview the prediction table written by the pipeline.
from snowflake import snowpark as snp

session = snp.Session.builder.configs(state_dict['connection_parameters']).create()
try:
    session.use_warehouse(state_dict['compute_parameters']['default_warehouse'])
    session.table(state_dict['pred_table_name']).show()
finally:
    # Close even if the query fails so the connection is not leaked.
    session.close()

In [None]:
# %%writefile ./include/train.yaml
# apiVersion: "kubeflow.org/v1"
# kind: PyTorchJob
# metadata:
#   name: citibike-train
#   namespace: citibike
# spec:
#   pytorchReplicaSpecs:
#     Master:
#       replicas: 1
#       restartPolicy: Never
#       template:
#         spec:
#           containers:
#             - name: pytorch
#               image: docker.io/mpgregor/airkube:latest
#               imagePullPolicy: Always
#               command:
#                 - "python"
#                 - "/pipeline/load_train.py"
#                 - "--account="
#                 - "--password="
#                 - "--username="
#                 - "--role="
#                 - "--database="
#                 - "--schema="
#                 - "--feature_table_name="
#                 - "--pred_table_name="
                
# !kubectl create namespace citibike
# !kubectl create -f ./include/train.yaml 
# !kubectl delete pytorchjob citibike-train -n citibike
# !kubectl delete namespace citibike                

In [None]:
from kubernetes.client import V1PodTemplateSpec
from kubernetes.client import V1ObjectMeta
from kubernetes.client import V1PodSpec
from kubernetes.client import V1Container
from kubernetes.client import V1ResourceRequirements

from kubeflow.pytorchjob import constants
from kubeflow.pytorchjob import utils
from kubeflow.pytorchjob import V1ReplicaSpec
from kubeflow.pytorchjob import V1PyTorchJob
from kubeflow.pytorchjob import V1PyTorchJobSpec
from kubeflow.pytorchjob import PyTorchJobClient

# Submit the training container as a Kubeflow PyTorchJob, wait for it to succeed,
# then delete it.
import json
with open('./include/state.json') as sdf:
    state_dict = json.load(sdf)

from kubernetes import client, config

# Use ONE kubeconfig for both the namespace check and the PyTorchJob client. Without
# config_file, PyTorchJobClient() loads the default kubeconfig and may talk to a
# different cluster than the one where the namespace was just created.
# NOTE(review): the setup cell copies the kubeconfig to ./include/.kube/config;
# confirm which path is intended here.
KUBE_CONFIG_FILE = './include/kube_config.yaml'
config.load_kube_config(KUBE_CONFIG_FILE)

namespace = state_dict['k8s_namespace']
k8s_client = client.CoreV1Api()
existing_namespaces = {item.metadata.name for item in k8s_client.list_namespace().items}
if namespace not in existing_namespaces:
    k8s_client.create_namespace(client.V1Namespace(metadata=client.V1ObjectMeta(name=namespace)))

conn = state_dict['connection_parameters']
# NOTE(review): the password is passed as a container argument, so it is stored in plain
# text in the PyTorchJob spec. Consider a Kubernetes Secret exposed as an env var instead.
container = V1Container(
    name="pytorch",
    image=state_dict['train_image'],
    image_pull_policy="Always",
    command=["python",
             "/pipeline/load_train.py",
             "--account=" + conn['account'],
             "--password=" + conn['password'],
             "--username=" + conn['user'],
             "--role=" + conn['role'],
             "--database=" + conn['database'],
             "--schema=" + conn['schema'],
             "--feature_table_name=" + state_dict['feature_table_name'],
             "--pred_table_name=" + state_dict['pred_table_name'],
             ]
)

# A single Master replica runs the whole pipeline; no Worker replicas are used.
master = V1ReplicaSpec(
    replicas=1,
    restart_policy="OnFailure",
    template=V1PodTemplateSpec(
        spec=V1PodSpec(
            containers=[container]
        )
    )
)

pytorchjob = V1PyTorchJob(
    api_version="kubeflow.org/v1",
    kind="PyTorchJob",
    metadata=V1ObjectMeta(name=state_dict['train_job_name'], namespace=namespace),
    spec=V1PyTorchJobSpec(
        clean_pod_policy="None",
        pytorch_replica_specs={"Master": master}
    )
)

pytorch_client = PyTorchJobClient(config_file=KUBE_CONFIG_FILE)
resp = pytorch_client.create(pytorchjob)

job_name = resp['metadata']['name']
job_namespace = resp['metadata']['namespace']

# If waiting raises, the delete below is not reached and the job stays in the cluster
# for inspection.
pytorch_client.wait_for_condition(name=job_name,
                                  namespace=job_namespace,
                                  expected_condition='Succeeded')

pytorch_client.delete(name=job_name, namespace=job_namespace)

In [None]:
%%writefile ./dags/k8s_test.py

from datetime import datetime, timedelta

from airflow import DAG
from airflow.configuration import conf
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator

# Smoke-test DAG: runs a single `hello-world` pod once through KubernetesPodOperator.
default_args = dict(
    owner='airflow',
    depends_on_past=False,
    start_date=datetime(2019, 1, 1),
    email_on_failure=False,
    email_on_retry=False,
    retries=1,
    retry_delay=timedelta(minutes=5),
)

namespace = conf.get('kubernetes', 'NAMESPACE')

# Locally the configured namespace is 'default': reach the cluster through the
# kubeconfig file under include/. In any other namespace (e.g. when deployed to
# Astronomer) use the in-cluster configuration and no config file.
in_cluster = namespace != 'default'
config_file = None if in_cluster else '/usr/local/airflow/include/.kube/config'

with DAG('example_kubernetes_pod', schedule_interval='@once', default_args=default_args) as dag:
    KubernetesPodOperator(
        task_id="task-one",
        name="airflow-test-pod",
        namespace=namespace,
        image="hello-world",
        labels={"foo": "bar"},
        in_cluster=in_cluster,
        cluster_context="docker-desktop",  # ignored when in_cluster is True
        config_file=config_file,
        get_logs=True,
        is_delete_operator_pod=True,
    )
