# ColumnStore Bulk Data Adapters - Demo

## Import all necessary libraries and set necessary configurations

In [None]:
from pyspark import SparkContext
from pyspark.sql import Row, SQLContext
from pyspark.ml.classification import RandomForestClassificationModel
from matplotlib import pyplot as plt
import mysql.connector as mariadb
import sys, decimal, datetime
import numpy as np
import time

# JDBC endpoint of the ColumnStore host and the connection properties
# (credentials + driver class) handed to Spark's JDBC writer
url = 'jdbc:mysql://columnstore_host_nm:3306'
properties = dict(
    user='jupiter_user',
    password='jupiter_pass',
    driver='org.mariadb.jdbc.Driver',
)

# local Spark context and the SQL context built on top of it
sc = SparkContext(master="local", appName="MariaDB Spark ColumnStore Demo")
sqlContext = SQLContext(sc)

## Load the images for classification

We apply a trained random forest classification model [\[1\]](./Model_Training.ipynb) to images from the MNIST database of handwritten digits to determine which digit each image shows, i.e. to digitize the handwritten digits.

In [None]:
# read the handwritten digits to classify from ./mnist.t (libsvm format,
# 784 features per image)
test = (
    sqlContext.read
    .format("libsvm")
    .option("numFeatures", "784")
    .load("./mnist.t")
)

# report how many images were loaded and show the dataframe
print("We have %d test images." % (test.count(),))
display(test)

In [None]:
# visualizes the features vector
def visualizeFeatures(features):
    """Plot a feature vector as a 28x28 grayscale image.

    features -- sequence of 784 values; it is converted to a float array and
                reshaped to 28 rows by 28 columns before being drawn.
    """
    pixels = np.array(features, dtype='float').reshape((28, 28))
    plt.imshow(pixels, cmap='gray')
    plt.show()
    
# draw the first three test images, each followed by its true label
for row in test.head(3):
    visualizeFeatures(row.features)
    print("visualization of image with label %d" % row.label)

## Predict the handwritten numbers

In [None]:
# load the previously trained random forest model from the path "mnist-model-random-forest"
model = RandomForestClassificationModel.load("mnist-model-random-forest")

# classify every test image; later cells read the `prediction` and
# `probability` fields of the resulting rows
predictions = model.transform(test)
display(predictions)

In [None]:
# draw the first three images again, this time with the predicted digit,
# the probability the model assigned to that digit, and the true label
for row in predictions.head(3):
    visualizeFeatures(row.features)
    print("prediction: %d\tconfidence: %f\tlabel: %d" % (
        row.prediction,
        row.probability[int(row.prediction)],
        row.label,
    ))

## Restructure the dataframe for storage

In [None]:
def extract(row):
    """Flatten one prediction row into a plain tuple.

    Returns (label, prediction, p_0, p_1, ...), where the p_i are the entries
    of row.probability in their original order.
    """
    flattened = [row.label, row.prediction]
    flattened.extend(row.probability.toArray().tolist())
    return tuple(flattened)

# flatten every prediction row and name the columns: the true label, the
# predicted digit, and one probability column per digit 0-9
output = predictions.rdd.map(extract).toDF([
    "label", "prediction",
    "prob_0", "prob_1", "prob_2", "prob_3", "prob_4",
    "prob_5", "prob_6", "prob_7", "prob_8", "prob_9",
])
print("Number of predictions: %d" % (output.count(),))
output.printSchema()

## Ingest dataframe through JDBC

In [None]:
# time how long Spark's JDBC writer needs for the first `limit` rows
limit = 1000
t = time.time()

# (re)create test.jdbc as a ColumnStore table and write through a single partition
(
    output.limit(limit)
    .write
    .mode("overwrite")
    .option("numPartitions", 1)
    .option("createTableOptions", "ENGINE=columnstore")
    .option("createTableColumnTypes", "label double, prediction double, prob_0 double, prob_1 double, prob_2 double, prob_3 double, prob_4 double, prob_5 double, prob_6 double, prob_7 double, prob_8 double, prob_9 double")
    .jdbc(url, "test.jdbc", properties=properties)
)

print("%d rows ingested in %.3fs" % (limit, time.time() - t))

## Ingest dataframe through Bulk Data Adapter API

In [None]:
# create table function
def createTable(name):
    """Create the ColumnStore table `name` in database `test` if it does not exist.

    The table has twelve double columns: label, prediction and prob_0..prob_9.
    Database errors are printed, not raised.  The cursor and the connection
    are closed in every case.

    name -- table name; it is interpolated into the SQL statement verbatim,
            so it must come from trusted code, never from user input.
    """
    # Pre-bind both handles: if connect() or cursor() raises, the cleanup in
    # `finally` must not fail on an unbound name and hide the real error.
    conn = None
    cursor = None
    try:
        conn = mariadb.connect(user='jupiter_user', password='jupiter_pass', host='columnstore_host_nm', database='test')
        cursor = conn.cursor()
        cursor.execute("CREATE TABLE IF NOT EXISTS %s \
                       (label double, prediction double, prob_0 double, prob_1 double, prob_2 double, prob_3 double, prob_4 double, prob_5 double, prob_6 double, prob_7 double, prob_8 double, prob_9 double)\
                       engine=columnstore" %(name,))

    except mariadb.Error as err:
        print("Error while creating table %s. %s" %(name,err,))

    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()

In [None]:
import pymcsapi            

# create the target table
createTable("bulk_api_1")

# initialize the driver and open a bulk insert into test.bulk_api_1
driver = pymcsapi.ColumnStoreDriver()
bulk = driver.createBulkInsert('test', 'bulk_api_1', 0, 0)

# insert the dataframe row by row into ColumnStore; the dataframe columns
# (label, prediction, prob_0..prob_9) are in the same order as the table
# columns, so each value goes to the column with the same position
try:
    for row in output.collect():
        for columnId, value in enumerate(row):
            bulk.setColumn(columnId, value)
        bulk.writeRow()

    # commit the changes
    bulk.commit()
except Exception:
    # do not leave a half-written bulk insert open: discard it, then let the
    # original error surface
    bulk.rollback()
    raise

# show a summary
summary = bulk.getSummary()
print("Execution time: %s" % (summary.getExecutionTime(),))
print("Rows inserted: %s" % (summary.getRowsInsertedCount(),))
print("Truncation count: %s" % (summary.getTruncationCount(),))
print("Saturated count: %s" % (summary.getSaturatedCount(),))
print("Invalid count: %s" % (summary.getInvalidCount(),))

## Ingest through ColumnStoreExporter / SparkConnector

In [None]:
import columnStoreExporter

# create the target table, then hand the whole dataframe to the packaged
# exporter, which writes it into test.bulk_api_2
createTable("bulk_api_2")
columnStoreExporter.export("test","bulk_api_2",output)

## SparkConnector in detail

In [None]:
def export(database, table, df):
    """Write every row of a DataFrame into a ColumnStore table via the bulk insert API.

    The DataFrame is collected and inserted row by row.  Row values at
    positions beyond the table's column count are ignored.  Values are
    converted before insertion:

    * bool               -> 1 / 0
    * datetime.date      -> string formatted as '%Y-%m-%d %H:%M:%S'
    * decimal.Decimal    -> pymcsapi.ColumnStoreDecimal for decimal, float and
                            double columns, an integer for any other column
    * unicode (Python 2) -> UTF-8 encoded string
    * anything else      -> passed through unchanged

    On any exception the bulk insert is rolled back and the error is printed
    (it is not re-raised).  A short summary is printed in either case.

    database -- name of the target database
    table    -- name of the target table
    df       -- DataFrame to export
    """
    # Python 3 has no `long`; alias it to int.  The alias is kept at module
    # level so the function's side effects stay as they always were.
    global long
    python2 = True

    if sys.version_info[0] == 3:
        long = int
        python2 = False

    # numeric ids of DATA_TYPE_DECIMAL, DATA_TYPE_UDECIMAL, DATA_TYPE_FLOAT,
    # DATA_TYPE_UFLOAT, DATA_TYPE_DOUBLE, DATA_TYPE_UDOUBLE
    decimalCompatibleTypes = (4, 18, 7, 21, 10, 23)

    rows = df.collect()
    driver = pymcsapi.ColumnStoreDriver()
    bulkInsert = driver.createBulkInsert(database, table, 0, 0)

    # get the column count of table
    dbCatalog = driver.getSystemCatalog()
    dbTable = dbCatalog.getTable(database, table)
    dbTableColumnCount = dbTable.getColumnCount()

    # Pre-bound so the error report below cannot fail with a NameError when
    # the failure happens before any value was processed (e.g. empty input).
    row = None
    columnId = None

    # insert row by row into table
    try:
        for row in rows:
            # only the leading columns that exist in the table are written
            for columnId in range(min(len(row), dbTableColumnCount)):
                value = row[columnId]

                # bool must be tested first and is stored as 1 / 0
                if isinstance(value, bool):
                    bulkInsert.setColumn(columnId, 1 if value else 0)

                elif isinstance(value, datetime.date):
                    bulkInsert.setColumn(columnId, value.strftime('%Y-%m-%d %H:%M:%S'))

                elif isinstance(value, decimal.Decimal):
                    if dbTable.getColumn(columnId).getType() in decimalCompatibleTypes:
                        s = '{0:f}'.format(value)
                        bulkInsert.setColumn(columnId, pymcsapi.ColumnStoreDecimal(s))
                    #ANY OTHER DATA TYPE
                    else:
                        bulkInsert.setColumn(columnId, long(value))

                #handle python2 unicode strings
                elif python2 and isinstance(value, unicode):
                    bulkInsert.setColumn(columnId, value.encode('utf-8'))

                #any other datatype is inserted without parsing
                else:
                    bulkInsert.setColumn(columnId, value)
            bulkInsert.writeRow()
        bulkInsert.commit()
    except Exception as e:
        bulkInsert.rollback()
        # report the value that was being processed, if there was one
        if row is not None and columnId is not None:
            print(row[columnId], type(row[columnId]))
        print(type(e))
        print(e)

    #print a short summary of the insertion process
    summary = bulkInsert.getSummary()
    print("Execution time: %s" % (summary.getExecutionTime(),))
    print("Rows inserted: %s" % (summary.getRowsInsertedCount(),))
    print("Truncation count: %s" %(summary.getTruncationCount(),))
    print("Saturated count: %s" %(summary.getSaturatedCount(),))
    print("Invalid count: %s" %(summary.getInvalidCount(),))

In [None]:
# create the target table, then write the dataframe into test.bulk_api_3
# with the export() function defined in this notebook
createTable("bulk_api_3")
export("test","bulk_api_3",output)