<a href="https://colab.research.google.com/github/bhaktichowkwale/SparkTorch_Demo/blob/main/SparkTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import sys


In [2]:
!pip install sparktorch

Collecting sparktorch
  Downloading sparktorch-0.1.2.tar.gz (21 kB)
Building wheels for collected packages: sparktorch
  Building wheel for sparktorch (setup.py) ... [?25l[?25hdone
  Created wheel for sparktorch: filename=sparktorch-0.1.2-py3-none-any.whl size=24597 sha256=394afa884becedc40cb3185ff05fc15286d27e001ba39f927e47e14dd0225e5c
  Stored in directory: /root/.cache/pip/wheels/67/d4/a2/a288b918877e28698fc8ff8cb1d8290713bd84abcb80715d47
Successfully built sparktorch
Installing collected packages: sparktorch
Successfully installed sparktorch-0.1.2


In [3]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.2.0.tar.gz (281.3 MB)
[K     |████████████████████████████████| 281.3 MB 8.7 kB/s 
[?25hCollecting py4j==0.10.9.2
  Downloading py4j-0.10.9.2-py2.py3-none-any.whl (198 kB)
[K     |████████████████████████████████| 198 kB 49.2 MB/s 
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.2.0-py2.py3-none-any.whl size=281805912 sha256=075ad5cfabb580cf5fa0cfa7c40b2dc786d5489419334ca1971f9d66e3c1eaa5
  Stored in directory: /root/.cache/pip/wheels/0b/de/d2/9be5d59d7331c6c2a7c1b6d1a4f463ce107332b1ecd4e80718
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.2 pyspark-3.2.0


In [4]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://downloads.apache.org/spark/spark-2.4.8/spark-2.4.8-bin-hadoop2.7.tgz
!tar xf spark-2.4.8-bin-hadoop2.7.tgz
!pip install -q findspark

In [5]:
!java -version

openjdk version "11.0.11" 2021-04-20
OpenJDK Runtime Environment (build 11.0.11+9-Ubuntu-0ubuntu2.18.04)
OpenJDK 64-Bit Server VM (build 11.0.11+9-Ubuntu-0ubuntu2.18.04, mixed mode, sharing)


In [6]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-2.4.8-bin-hadoop2.7"

In [7]:
!update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java
!java -version

update-alternatives: using /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java to provide /usr/bin/java (java) in manual mode
openjdk version "1.8.0_292"
OpenJDK Runtime Environment (build 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10)
OpenJDK 64-Bit Server VM (build 25.292-b10, mixed mode)


In [8]:
import findspark
findspark.init()

In [10]:
from sparktorch import serialize_torch_obj, SparkTorch
import torch
import torch.nn as nn
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark.ml.pipeline import Pipeline

In [11]:
#df = spark.read.csv("/mnist_train.csv")
df = spark.read.csv('/content/sample_data/mnist_train_small.csv')

In [15]:
#Typecast features into double
from pyspark.sql.functions import col, udf, column
for col_name in df.columns:
    df = df.withColumn(col_name, col(col_name).cast('Double'))

In [16]:
network = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.Softmax(dim=1)
)

In [17]:
# Build the pytorch object
torch_obj = serialize_torch_obj(
    model=network,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    lr=0.0001
)


In [18]:
# Setup features
vector_assembler = VectorAssembler(inputCols=df.columns[1:785], outputCol='features')

In [19]:
# Create a SparkTorch Model with torch distributed. Barrier execution is on by default for this mode.
spark_model = SparkTorch(
    inputCol='features',
    labelCol='_c0',
    predictionCol='predictions',
    torchObj=torch_obj,
    iters=1000,
    miniBatch=256,
    verbose=1
)



In [20]:
# Can be used in a pipeline and saved.
p = Pipeline(stages=[vector_assembler, spark_model]).fit(df)
p.save('simple_dnn_2')

In [22]:
# Run predictions and evaluation
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
predictions = p.transform(df).persist()
evaluator = MulticlassClassificationEvaluator(labelCol="_c0", predictionCol="predictions", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Train accuracy = %g" % accuracy)


Train accuracy = 0.88425
