<a href="https://colab.research.google.com/github/bhaktichowkwale/SparkTorch_Demo/blob/main/SparkTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
## Set Python - Spark environment.
import os
import sys


In [2]:
!pip install sparktorch

Collecting sparktorch
  Downloading sparktorch-0.1.2.tar.gz (21 kB)
Building wheels for collected packages: sparktorch
  Building wheel for sparktorch (setup.py) ... [?25l[?25hdone
  Created wheel for sparktorch: filename=sparktorch-0.1.2-py3-none-any.whl size=24597 sha256=32b5b102e372987425a84526669c9b357f932496c99b4e48032668494bbfeee1
  Stored in directory: /root/.cache/pip/wheels/67/d4/a2/a288b918877e28698fc8ff8cb1d8290713bd84abcb80715d47
Successfully built sparktorch
Installing collected packages: sparktorch
Successfully installed sparktorch-0.1.2


In [4]:
!pip install pyspark

[31mERROR: Operation cancelled by user[0m


In [5]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://downloads.apache.org/spark/spark-2.4.8/spark-2.4.8-bin-hadoop2.7.tgz
!tar xf spark-2.4.8-bin-hadoop2.7.tgz
!pip install -q findspark

In [7]:
!java -version

openjdk version "11.0.11" 2021-04-20
OpenJDK Runtime Environment (build 11.0.11+9-Ubuntu-0ubuntu2.18.04)
OpenJDK 64-Bit Server VM (build 11.0.11+9-Ubuntu-0ubuntu2.18.04, mixed mode, sharing)


In [11]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-2.4.8-bin-hadoop2.7"

In [13]:
!update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java
!java -version

openjdk version "1.8.0_292"
OpenJDK Runtime Environment (build 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10)
OpenJDK 64-Bit Server VM (build 25.292-b10, mixed mode)


In [14]:
import findspark
findspark.init()

In [15]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

In [21]:

from sparktorch import serialize_torch_obj, SparkTorch
import torch
import torch.nn as nn
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark.ml.pipeline import Pipeline

In [22]:
df = spark.read.csv("/content/sample_data/mnist_train_small.csv")

In [28]:
#Typecast features into double
for col_name in df.columns:
    df = df.withColumn(col_name, col(col_name).cast('Double'))

In [29]:
network = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.Softmax(dim=1)
)

In [30]:
# Build the pytorch object
torch_obj = serialize_torch_obj(
    model=network,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    lr=0.0001
)


In [31]:
# Setup features
vector_assembler = VectorAssembler(inputCols=df.columns[1:785], outputCol='features')

In [32]:
# Create a SparkTorch Model with torch distributed. Barrier execution is on by default for this mode.
spark_model = SparkTorch(
    inputCol='features',
    labelCol='_c0',
    predictionCol='predictions',
    torchObj=torch_obj,
    iters=50,
    verbose=1
)

In [33]:
# Can be used in a pipeline and saved.
p = Pipeline(stages=[vector_assembler, spark_model]).fit(df)
p.save('simple_dnn')

In [34]:
# Run predictions and evaluation
predictions = p.transform(df).persist()
evaluator = MulticlassClassificationEvaluator(labelCol="_c0", predictionCol="predictions", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Train accuracy = %g" % accuracy)


Train accuracy = 0.5966
