# PySpark PyTorch Inference

### Image Classification
Based on: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

In [None]:
import torch

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
torch.__version__

In [None]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [None]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape} {X.dtype}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

### Create model

In [None]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 class logits.

    Every sample is flattened to 28 * 28 = 784 features and passed through
    two hidden layers of 512 ReLU units followed by a 10-way linear output.

    The submodule attribute names (``flatten``, ``linear_relu_stack``) and the
    layer order inside the ``nn.Sequential`` determine the state-dict keys, so
    they must stay unchanged for saved weights / pickled models to load.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden = 512
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Return the raw (unnormalized) logits for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

model = NeuralNetwork().to(device)
print(model)

### Train Model

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch of ``model`` over ``dataloader``.

    Each batch is moved to the module-level ``device``, scored with
    ``loss_fn``, and ``optimizer`` takes one step on the resulting gradients.
    On batches 0, 100, 200, ... the current batch loss and the number of
    samples processed so far (including that batch) are printed.

    Args:
        dataloader: iterable of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used as the total sample count in the progress line.
        model: module to train; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
    """
    size = len(dataloader.dataset)
    model.train()
    # Running count of samples consumed so far. It stays correct even when the
    # final batch is shorter than the others.
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        if batch % 100 == 0:
            # Report samples seen *including* this batch, so the first line
            # reads e.g. [64/60000] instead of [0/60000].
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

In [None]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    The printed loss is the mean of the per-batch values returned by
    ``loss_fn``. Accuracy is the percentage of samples whose arg-max logit
    equals the label.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()
    loss_sum = 0
    hits = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            loss_sum += loss_fn(logits, labels).item()
            hits += (logits.argmax(1) == labels).type(torch.float).sum().item()
    mean_loss = loss_sum / n_batches
    accuracy = hits / n_samples
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {mean_loss:>8f} \n")

In [None]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

### Save Model State Dict
This is the [currently recommended save format](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference).

In [None]:
torch.save(model.state_dict(), "model_weights.pt")
print("Saved PyTorch Model State to model_weights.pt")

### Save Entire Model
This saves the entire model using python pickle, but has the [following disadvantage](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model):
> The serialized data is bound to the specific classes and the exact directory structure used when the model is saved... Because of this, your code can break in various ways when used in other projects or after refactors.

In [None]:
torch.save(model, "model.pt")

### Save Model as TorchScript
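This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even Python). However, TorchScript models don't serialize well through Spark's pickle-based serialization, so they cannot simply be shipped from the Spark driver to the executors. The Spark examples below therefore load them on the executors via a `model_loader` function.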

In [None]:
scripted = torch.jit.script(model)

In [None]:
scripted.save("model.ts")

### Load Model State

In [None]:
model_from_state = NeuralNetwork()
model_from_state.load_state_dict(torch.load("model_weights.pt"))

In [None]:
model_from_state.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model_from_state(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

### Load Model

In [None]:
# NOTE(review): on PyTorch >= 2.6, torch.load defaults to weights_only=True,
# which refuses to unpickle a whole nn.Module like this one. If this line
# fails there, pass weights_only=False, and only do that for trusted files.
new_model = torch.load("model.pt")

In [None]:
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    # new_model was pickled from `model`, which lives on `device`, and
    # torch.load restores its weights there. Move the CPU input to match so
    # this also works on a CUDA machine.
    pred = new_model(x.to(device))
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

### Load Torchscript Model

In [None]:
ts_model = torch.jit.load("model.ts")

In [None]:
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    # The TorchScript module was scripted from `model` on `device`, and
    # torch.jit.load moves it back there. Send the CPU input to the same device.
    pred = ts_model(x.to(device))
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

## PySpark

### Convert numpy dataset to Spark DataFrame (via Pandas DataFrame)

In [None]:
import pandas as pd
from pyspark.sql.types import StructType, StructField, ArrayType, FloatType

In [None]:
data = test_data.data.numpy()
data.shape, data.dtype

In [None]:
# Flatten each image into a single row and scale by 1/255. The row count is
# taken from the data instead of being hard-coded to 10000.
data = data.reshape(data.shape[0], -1) / 255.0
data.shape, data.dtype

In [None]:
test_pdf = pd.DataFrame(data)

In [None]:
%%time
# 1 column of array<float>
test_pdf['data'] = test_pdf.values.tolist()
pdf = test_pdf[['data']]
pdf.shape

In [None]:
%%time
# force FloatType since Pandas uses double
schema = StructType([StructField("data",ArrayType(FloatType()), True)])
df = spark.createDataFrame(pdf, schema)

In [None]:
df.schema

### Save the test dataset as parquet files

In [None]:
%%time
df.write.mode("overwrite").parquet("fashion_mnist_test")

### Check arrow memory configuration

In [None]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "128")
# This line will fail if the vectorized reader runs out of memory
assert len(df.head()) > 0, "`df` should not be empty"

## Inference using Spark ML Model
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import sparkext
import torch

from pathlib import Path
from torch import nn

In [None]:
df = spark.read.parquet("fashion_mnist_test")

In [None]:
df.show()

### Using TorchScript Model
TorchScript models do not require the model definition prior to loading, but they don't serialize well from Spark driver to executors, so we must use a `model_loader` function that is invoked on the executor.

In [None]:
def model_loader(path: str):
    """Load and return the TorchScript module saved at ``path``.

    Meant to be invoked on the Spark executor, so the scripted model never
    has to be serialized on the driver and shipped over.
    """
    scripted_module = torch.jit.load(path)
    return scripted_module

In [None]:
model_path = Path.cwd() / "model.ts"
model_path

In [None]:
model = sparkext.torch.Model(str(model_path), model_loader) \
            .setInputCol("data") \
            .setOutputCol("preds") \
            .setInputShape((-1,28,28))

In [None]:
predictions = model.transform(df)

In [None]:
predictions.write.mode("overwrite").parquet("mnist_predictions")

In [None]:
predictions.take(1)

In [None]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
img = np.array(df.take(1)[0].data)

In [None]:
plt.figure()
plt.imshow(img.reshape(28,28))
plt.show()

### Using Saved Model

Since the model is pickled, the model class must be defined before loading.

In [None]:
# model.pt holds a pickled instance of this class, so the class must be
# defined under the same name before the model can be unpickled.
class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 class logits.

    Every sample is flattened to 28 * 28 = 784 features and passed through
    two hidden layers of 512 ReLU units followed by a 10-way linear output.

    The submodule attribute names and the layer order must match the
    definition used when the model was saved.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden = 512
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Return the raw (unnormalized) logits for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

In [None]:
model = sparkext.torch.Model("model.pt") \
            .setInputCol("data") \
            .setOutputCol("preds") \
            .setInputShape((-1,28,28))

In [None]:
predictions = model.transform(df)

In [None]:
predictions.write.mode("overwrite").parquet("mnist_predictions")

In [None]:
predictions.take(1)

In [None]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
img = np.array(df.take(1)[0].data)

In [None]:
plt.figure()
plt.imshow(img.reshape(28,28))
plt.show()

## Inference using Spark DL UDF
Note: you can restart the kernel and run from this point to simulate running in a different node or environment.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from pathlib import Path
from pyspark.sql.functions import col
from sparkext.torch import model_udf

In [None]:
df = spark.read.parquet("fashion_mnist_test")

### Using TorchScript Model
TorchScript models do not require the model definition prior to loading, but they don't serialize well from Spark driver to executors, so we must use a `model_loader` function that is invoked on the executor.

In [None]:
def model_loader(path: str):
    """Load and return the TorchScript module saved at ``path``.

    ``torch`` is imported inside the function so the loader does not depend
    on module-level imports when it runs on a Spark executor.
    """
    import torch

    loaded = torch.jit.load(path)
    return loaded

In [None]:
model_path = Path.cwd() / "model.ts"
model_path

In [None]:
classify = model_udf(str(model_path), model_loader=model_loader)

In [None]:
predictions = df.withColumn("preds", classify(col("data")))

In [None]:
%%time
preds = predictions.collect()

In [None]:
pred = predictions.take(1)

In [None]:
img = np.array(pred[0].data)

In [None]:
plt.figure()
plt.imshow(img.reshape(28,28))
plt.show()

In [None]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
pred[0].preds

### Using Saved Model

Since the model is pickled, the model class must be defined before loading.

In [None]:
from torch import nn

# Unpickling "model.pt" requires this class to be defined first, under the
# same name and with the same submodule layout.
class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 class logits.

    Every sample is flattened to 28 * 28 = 784 features and passed through
    two hidden layers of 512 ReLU units followed by a 10-way linear output.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden = 512
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Return the raw (unnormalized) logits for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

In [None]:
classify = model_udf("model.pt")

In [None]:
predictions = df.withColumn("preds", classify(col("data")))

In [None]:
# Keep the row so the cells below inspect *this* model's prediction rather
# than the TorchScript result that was stored in `pred` earlier.
pred = predictions.take(1)
pred

In [None]:
img = np.array(df.take(1)[0].data)

In [None]:
plt.figure()
plt.imshow(img.reshape(28,28))
plt.show()

In [None]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
pred[0].preds