<a href="https://colab.research.google.com/github/cfong32/netflix-prize/blob/main/try_sparktorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install sparktorch pyspark

In [2]:
from sparktorch import serialize_torch_obj, SparkTorch
import torch
import torch.nn as nn
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession
from pyspark.ml.pipeline import Pipeline

In [None]:
spark = SparkSession.builder.appName("examples").master('local[2]').getOrCreate()
df = spark.read.option("inferSchema", "true").option("header", "true").csv('mnist_train.csv').coalesce(2)

In [6]:
df = df.sample(fraction=0.01)

In [7]:
network = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.Softmax(dim=1)
)

# Build the pytorch object
torch_obj = serialize_torch_obj(
    model=network,
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    lr=0.0001
)

# Setup features
vector_assembler = VectorAssembler(inputCols=df.columns[1:785], outputCol='features')

# Create a SparkTorch Model with torch distributed. Barrier execution is on by default for this mode.
spark_model = SparkTorch(
    inputCol='features',
    labelCol='label',
    predictionCol='predictions',
    torchObj=torch_obj,
    iters=50,
    verbose=1
)

In [8]:
# Can be used in a pipeline and saved.
p = Pipeline(stages=[vector_assembler, spark_model]).fit(df)
p.save('simple_dnn')

In [10]:
vars(p.stages[-1])

{'_input_kwargs': {'inputCol': 'features',
  'predictionCol': 'predictions',
  'modStr': 'gASVlQIAAAAAAACMGnRvcmNoLm5uLm1vZHVsZXMuY29udGFpbmVylIwKU2VxdWVudGlhbJSTlCmB\nlH2UKIwIdHJhaW5pbmeUiIwLX3BhcmFtZXRlcnOUjAtjb2xsZWN0aW9uc5SMC09yZGVyZWREaWN0\nlJOUKVKUjAhfYnVmZmVyc5RoCSlSlIwbX25vbl9wZXJzaXN0ZW50X2J1ZmZlcnNfc2V0lI+UjBNf\nYmFja3dhcmRfcHJlX2hvb2tzlGgJKVKUjA9fYmFja3dhcmRfaG9va3OUaAkpUpSMFl9pc19mdWxs\nX2JhY2t3YXJkX2hvb2uUTowOX2ZvcndhcmRfaG9va3OUaAkpUpSMGl9mb3J3YXJkX2hvb2tzX3dp\ndGhfa3dhcmdzlGgJKVKUjBJfZm9yd2FyZF9wcmVfaG9va3OUaAkpUpSMHl9mb3J3YXJkX3ByZV9o\nb29rc193aXRoX2t3YXJnc5RoCSlSlIwRX3N0YXRlX2RpY3RfaG9va3OUaAkpUpSMFV9zdGF0ZV9k\naWN0X3ByZV9ob29rc5RoCSlSlIwaX2xvYWRfc3RhdGVfZGljdF9wcmVfaG9va3OUaAkpUpSMG19s\nb2FkX3N0YXRlX2RpY3RfcG9zdF9ob29rc5RoCSlSlIwIX21vZHVsZXOUaAkpUpQojAEwlIwXdG9y\nY2gubm4ubW9kdWxlcy5saW5lYXKUjAZMaW5lYXKUk5QpgZR9lChoBYhoBmgJKVKUKIwGd2VpZ2h0\nlIwMdG9yY2guX3V0aWxzlIwSX3JlYnVpbGRfcGFyYW1ldGVylJOUaC6MEl9yZWJ1aWxkX3RlbnNv\ncl92MpSTlCiMDXRvcmNoLnN0b3JhZ2WUjBBfbG9hZF9mcm9tX2J5

In [11]:
predictions = p.transform(df)

In [12]:
predictions.show()

+-----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+