In [None]:
import pyarrow.parquet as pq

# Write `table` to the file 'example.parquet' in the current working directory.
# NOTE(review): `table` is not defined in any visible cell - presumably a
# pyarrow Table built earlier; confirm this cell survives Restart & Run All.
pq.write_table(table, 'example.parquet')


In [None]:
%%bash
# Pick and install the pyspark release that matches the interpreter version.
#
# Reduce e.g. "Python 3.8.10" to "38". 2>&1 also captures interpreters that
# print their version on stderr, so they reach the "unsupported" branch below
# instead of leaving $version empty.
export version=$(python --version 2>&1 | awk '{print $2}' | awk -F"." '{print $1$2}')

echo "$version"

# install_stack PYSPARK_VERSION
# Installs the pinned pyspark/wget/pyspark2pmml set. pip output is collected in
# install.log and only printed when the installation fails.
install_stack() {
    echo 'Starting installation...'
    # "2>&1" (instead of a second "2> install.log") makes stdout and stderr
    # share one open file, so neither stream overwrites the other's output.
    if pip3 install "pyspark==$1" wget==3.2 pyspark2pmml==0.5.1 > install.log 2>&1; then
        echo 'Please <<RESTART YOUR KERNEL>> (Kernel->Restart Kernel and Clear All Outputs)'
    else
        echo 'Installation failed, please check log:'
        cat install.log
    fi
}

# Quoted "$version" keeps the match well-formed even if the string is empty.
case "$version" in
    36|37)
        install_stack 2.4.8
        ;;
    38|39)
        install_stack 3.1.2
        ;;
    *)
        echo 'Currently only python 3.6, 3.7 , 3.8 and 3.9 are supported, in case you need a different version please open an issue at https://github.com/IBM/claimed/issues'
        # Exit statuses are 0-255; use a plain non-zero status instead of -1.
        exit 1
        ;;
esac


In [None]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
import os
import shutil
import glob
from pyspark import SparkContext, SparkConf, SQLContext
import os
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark2pmml import PMMLBuilder
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import MinMaxScaler
import logging
import shutil
import site
import sys
import wget
import re

In [None]:
# Runtime parameters. Each one is read from the environment variable of the
# same name and falls back to the default given here.
data_dir = os.environ.get('data_dir', '../../data/')            # prefix prepended to both file names
data_parquet = os.environ.get('data_parquet', 'data.parquet')   # parquet input file name
data_csv = os.environ.get('data_csv', 'data.csv')               # csv output file name
master = os.environ.get('master', 'local[*]')                   # value handed to SparkConf.setMaster

In [None]:
# Fixed file names for the "trends" data set.
# NOTE(review): these literals replace whatever the environment-driven
# parameter cell assigned to the same two names - confirm that is intended.
data_parquet, data_csv = 'trends.parquet', 'trends.csv'

In [None]:
# Skip the conversion only when the CSV is already present as a regular file.
# Spark first writes a *directory* at this path; testing with isfile() keeps a
# directory left behind by an interrupted run from being mistaken for a
# finished CSV (the export cell removes such a directory before writing).
skip = os.path.isfile(data_dir + data_csv)

In [None]:
# Spark is only started when the conversion is not being skipped.
if not skip:
    spark_conf = SparkConf().setMaster(master)
    sc = SparkContext.getOrCreate(spark_conf)
    spark = SparkSession.builder.getOrCreate()

In [None]:
# Load the parquet input into a Spark DataFrame.
if not skip:
    parquet_path = data_dir + data_parquet
    df = spark.read.parquet(parquet_path)

In [None]:
# Export the DataFrame as one CSV file located at data_dir + data_csv.
if not skip:
    csv_path = data_dir + data_csv
    tmp_path = csv_path + '.tmp'

    # Remove output left at the destination by an earlier run.
    if os.path.exists(csv_path):
        shutil.rmtree(csv_path)

    # coalesce(1) collapses the data to a single partition; Spark writes it as
    # a part-* file inside a directory named csv_path.
    df.coalesce(1).write.option("header", "true").csv(csv_path)

    # Swap the directory for its first part file: park the file under a
    # temporary name, drop the directory, then move the file into place.
    part_files = glob.glob(csv_path + '/part-*')
    shutil.move(part_files[0], tmp_path)
    shutil.rmtree(csv_path)
    shutil.move(tmp_path, csv_path)

In [None]:
def env_flag(name, default=False, environ=None):
    """Read a boolean switch from an environment-style mapping.

    Args:
        name: key to look up.
        default: value used when the key is absent.
        environ: mapping to read from; defaults to ``os.environ``.

    Returns:
        True when the stored text is one of 1/true/yes/y/on (case-insensitive,
        surrounding whitespace ignored), False for any other text, and
        ``bool(default)`` when the key is missing.
    """
    source = os.environ if environ is None else environ
    raw = source.get(name)
    if raw is None:
        return bool(default)
    # bool("False") is True in Python, so the text has to be compared explicitly.
    return str(raw).strip().lower() in ('1', 'true', 'yes', 'y', 'on')


# Training parameters, each overridable through an environment variable.
image_shape = os.environ.get('image_shape', '400,400')  # comma-separated integers
model_zip = os.environ.get('model_zip', 'model.zip')
data_zip = os.environ.get('data_zip', 'data.zip')
model_folder = os.environ.get('model', 'model')
data = os.environ.get('data', 'data')
epochs = int(os.environ.get('epochs', 1))

# Checkpointing of the trained model archive to an object store.
checkpoint = env_flag('checkpoint', False)
checkpoint_ip = os.environ.get('checkpoint_ip')
# NOTE(review): default credentials are hardcoded here; consider requiring
# the environment variables instead of shipping fallbacks.
checkpoint_user = os.environ.get('checkpoint_user', 'minio')
checkpoint_pass = os.environ.get('checkpoint_pass', 'minio123')
checkpoint_bucket = os.environ.get('checkpoint_bucket', 'checkpoint')

In [None]:
# Set to True only when the model archive is found in the checkpoint bucket
# and downloaded.
exists = False

if checkpoint:
    client = Minio(checkpoint_ip, checkpoint_user, checkpoint_pass, secure=False)

    # Walk the bucket listing looking for an object named exactly like the
    # model archive.
    for candidate in client.list_objects(checkpoint_bucket):
        if candidate.object_name != model_zip:
            continue
        exists = True
        # Download it to a local file of the same name.
        client.fget_object(checkpoint_bucket, model_zip, model_zip)
        break

In [None]:
# Unpack the data archive when no cached model was downloaded.
# NOTE(review): `unzip` is not defined in any visible cell; presumably it
# extracts `data_zip` into the directory given first ('.') - confirm against
# the helper's definition.
if not exists:
    unzip('.', data_zip)

In [None]:
# Derive the number of classes from the data directory.
if not exists:
    # Keras infers one class per sub-directory, so count directories only;
    # a stray file next to the class folders must not add an extra class.
    folder = [entry for entry in glob.glob(data + "/*") if os.path.isdir(entry)]
    num_classes = len(folder)

In [None]:
# Build the training and validation datasets from the image directory.
if not exists:
    batch_size = 32

    # Parse the comma-separated `image_shape` text into a tuple of ints.
    # Explicit parsing replaces exec() so that text coming from the
    # environment is never executed as code. Optional surrounding parentheses
    # and a trailing comma are tolerated.
    input_shape = tuple(
        int(dim)
        for dim in image_shape.strip().strip('()').split(',')
        if dim.strip()
    )

    # Options shared by both splits. Keras requires the same seed and
    # validation_split in both calls so the two subsets do not overlap.
    split_options = dict(
        validation_split=0.2,
        seed=123,
        image_size=input_shape,
        batch_size=batch_size,
    )

    # Read from the configurable `data` directory (default 'data'), the same
    # location that is used to count the classes.
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data, subset="training", **split_options)

    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data, subset="validation", **split_options)

    def one_hot_labels(images, labels):
        """Replace integer labels with one-hot vectors of length num_classes."""
        return images, tf.one_hot(labels, depth=num_classes)

    train_ds = train_ds.map(one_hot_labels)
    val_ds = val_ds.map(one_hot_labels)

In [None]:
def my_net(model, num_classes, freeze_layers=10, full_freeze='N'):
    """Put a softmax classification head on top of `model`.

    The head is: global average pooling, two rounds of Dense(512, relu)
    followed by Dropout(0.5), and a final Dense(num_classes, softmax).

    Args:
        model: base network; its `input`, `output` and `layers` are used.
        num_classes: width of the final softmax layer.
        freeze_layers: how many leading layers of `model` to freeze.
        full_freeze: freezing is applied for any value other than 'N'.

    Returns:
        A Model mapping `model.input` to the new softmax output.
    """
    features = GlobalAveragePooling2D()(model.output)
    # Two identical hidden stages.
    for _ in range(2):
        features = Dense(512, activation='relu')(features)
        features = Dropout(0.5)(features)
    predictions = Dense(num_classes, activation='softmax')(features)
    model_final = Model(model.input, predictions)

    if full_freeze == 'N':
        return model_final

    # Freeze the first `freeze_layers` layers of the base network.
    for base_layer in model.layers[:freeze_layers]:
        base_layer.trainable = False
    return model_final

In [None]:
# Build the network: MobileNetV2 base plus the head added by my_net.
if not exists:
    # Parse the comma-separated `image_shape` text into ints and append a
    # trailing dimension of 3. Explicit parsing replaces exec() so that text
    # coming from the environment is never executed as code.
    input_shape = tuple(
        int(dim)
        for dim in image_shape.strip().strip('()').split(',')
        if dim.strip()
    ) + (3,)

    model = tf.keras.applications.MobileNetV2(
        input_shape=input_shape, alpha=1.0, include_top=False,
        input_tensor=None, pooling=None, classes=num_classes,
        classifier_activation='softmax'
    )
    # include_top=False leaves the base without a classifier; my_net adds one.
    model = my_net(model, num_classes=num_classes)

In [None]:
# Configure training: Adam optimizer, categorical cross-entropy loss, and
# accuracy reported as the metric.
if not exists:
    model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Train, save and archive the model unless a cached archive was downloaded.
if not exists:
    # No batch_size here: train_ds / val_ds are already batched datasets, and
    # Keras documents that batch_size must not be given for dataset inputs.
    model.fit(
        train_ds,
        epochs=epochs,
        validation_data=val_ds
    )
    model.save(model_folder)
    # NOTE(review): `zipdir` is not defined in any visible cell; presumably it
    # archives `model_folder` into `model_zip` - confirm against its definition.
    zipdir(model_zip, model_folder)
else:
    print('Model cached, skipping training')

In [None]:
# Upload the freshly built model archive to the checkpoint bucket.
# `client` is only created when checkpointing is enabled, so the upload has to
# be guarded by `checkpoint` as well; otherwise this cell fails with a
# NameError in the default (checkpoint disabled) configuration.
if checkpoint and not exists:
    size = os.path.getsize(model_zip)
    with open(model_zip, 'rb') as fh:
        # Stream directly from the open file instead of first copying the
        # whole archive into an in-memory buffer.
        result = client.put_object(
            checkpoint_bucket, model_zip, fh, length=size
        )