# PySpark PyTorch Inference

### Image Classification
Based on: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

In [1]:
import torch

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
torch.__version__

'1.11.0+cpu'

In [3]:
# Training split of FashionMNIST, stored under ./data; every image is passed through ToTensor().
training_data = datasets.FashionMNIST(root="data", train=True, transform=ToTensor(), download=True)

# Held-out test split, configured identically apart from train=False.
test_data = datasets.FashionMNIST(root="data", train=False, transform=ToTensor(), download=True)

In [4]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [5]:
batch_size = 64  # number of samples per batch produced by both loaders

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at the first test batch only (note the break) to report tensor shapes and dtypes.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape} {X.dtype}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28]) torch.float32
Shape of y: torch.Size([64]) torch.int64


### Create model

In [6]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten, two 512-unit ReLU layers, then 10 output logits."""

    def __init__(self):
        super().__init__()
        hidden = 512
        self.flatten = nn.Flatten()
        # Layer order (and therefore the state-dict keys 0/2/4) must stay as is.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Flatten each sample of ``x`` and return the raw class scores (logits)."""
        return self.linear_relu_stack(self.flatten(x))

model = NeuralNetwork().to(device)
print(model)

Using cpu device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


### Train Model

In [7]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [8]:
def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one optimisation pass (epoch) over ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing a sized ``.dataset``.
        model: ``nn.Module`` to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device each batch is moved to. When omitted, the device of the
            model's first parameter is used (CPU if the model has no
            parameters), so the function does not rely on a notebook-global
            ``device`` variable.

    Prints the loss and the number of samples processed every 100 batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Keep the float in its own name rather than rebinding the loss tensor.
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

In [9]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [10]:
epochs = 5
# Alternate one training pass with one evaluation pass per epoch.
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.305333  [    0/60000]
loss: 2.293172  [ 6400/60000]
loss: 2.271651  [12800/60000]
loss: 2.265887  [19200/60000]
loss: 2.253572  [25600/60000]
loss: 2.213814  [32000/60000]
loss: 2.226953  [38400/60000]
loss: 2.188729  [44800/60000]
loss: 2.190818  [51200/60000]
loss: 2.163593  [57600/60000]
Test Error: 
 Accuracy: 41.1%, Avg loss: 2.150287 

Epoch 2
-------------------------------
loss: 2.165365  [    0/60000]
loss: 2.156008  [ 6400/60000]
loss: 2.096980  [12800/60000]
loss: 2.116343  [19200/60000]
loss: 2.065670  [25600/60000]
loss: 2.000599  [32000/60000]
loss: 2.033970  [38400/60000]
loss: 1.950576  [44800/60000]
loss: 1.960650  [51200/60000]
loss: 1.901740  [57600/60000]
Test Error: 
 Accuracy: 56.4%, Avg loss: 1.886987 

Epoch 3
-------------------------------
loss: 1.922720  [    0/60000]
loss: 1.892854  [ 6400/60000]
loss: 1.774740  [12800/60000]
loss: 1.823813  [19200/60000]
loss: 1.711796  [25600/60000]
loss: 1.663581  [32000/600

### Save Model State Dict
This is the [currently recommended save format](https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-model-for-inference).

In [11]:
torch.save(model.state_dict(), "model_weights.pt")
print("Saved PyTorch Model State to model_weights.pt")

Saved PyTorch Model State to model_weights.pt


### Save Entire Model
This saves the entire model using Python's pickle module, but has the [following disadvantage](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model):
> The serialized data is bound to the specific classes and the exact directory structure used when the model is saved... Because of this, your code can break in various ways when used in other projects or after refactors.

In [12]:
torch.save(model, "model.pt")

### Save Model as TorchScript
This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even Python). However, this currently doesn't work with Spark, which uses pickle serialization.

In [13]:
scripted = torch.jit.script(model)

In [14]:
scripted.save("model.ts")

### Load Model State

In [15]:
model_from_state = NeuralNetwork()
model_from_state.load_state_dict(torch.load("model_weights.pt"))

<All keys matched successfully>

In [16]:
# Sanity-check the restored weights on the first test sample.
model_from_state.eval()
x, y = test_data[0]
with torch.no_grad():
    pred = model_from_state(x)
predicted = classes[pred[0].argmax(0)]
actual = classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "Ankle boot", Actual: "Ankle boot"


### Load Model

In [17]:
# Load onto the CPU so the CPU-resident test sample used in the next cell can be
# fed to the model even if it was trained and pickled on a GPU.
# NOTE: torch.load unpickles arbitrary Python objects; only load files you trust.
new_model = torch.load("model.pt", map_location="cpu")

In [18]:
# Classify the first test sample with the model restored from the pickle file.
x, y = test_data[0]
with torch.no_grad():
    pred = new_model(x)
predicted = classes[pred[0].argmax(0)]
actual = classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "Ankle boot", Actual: "Ankle boot"


### Load TorchScript Model

In [19]:
# Load onto the CPU so the CPU-resident test sample used in the next cell can be
# fed to the scripted model even if it was saved from a GPU.
ts_model = torch.jit.load("model.ts", map_location="cpu")

In [20]:
# Classify the first test sample with the TorchScript model.
x, y = test_data[0]
with torch.no_grad():
    pred = ts_model(x)
predicted = classes[pred[0].argmax(0)]
actual = classes[y]
print(f'Predicted: "{predicted}", Actual: "{actual}"')

Predicted: "Ankle boot", Actual: "Ankle boot"


## MLFlow Model

### Save MLFlow Model

In [21]:
import mlflow
import numpy as np
import subprocess
import torch

In [22]:
from mlflow.models.signature import infer_signature, ModelSignature
from mlflow.types.schema import Schema, TensorSpec

In [23]:
# Load onto the CPU: the signature cell below feeds a CPU tensor and calls
# .numpy() on the output, both of which fail for a model living on a GPU.
ts_model = torch.jit.load("model.ts", map_location="cpu")

#### Inferred signature

In [24]:
sample = test_data[0][0].reshape(1,784)
type(sample), sample.shape

(torch.Tensor, torch.Size([1, 784]))

In [25]:
signature = infer_signature(sample.numpy(), ts_model(sample).detach().numpy())
signature

inputs: 
  [Tensor('float32', (-1, 784))]
outputs: 
  [Tensor('float32', (-1, 10))]

In [26]:
import shutil

# Clear any previous export first (presumably save_model will not overwrite an
# existing directory -- confirm against the MLflow docs). shutil.rmtree is
# portable, unlike shelling out to `rm -rf`, and ignore_errors=True keeps the
# "missing directory is fine" behaviour.
shutil.rmtree("model_infer", ignore_errors=True)
mlflow.pytorch.save_model(pytorch_model=ts_model,
                         signature=signature,
                         path="model_infer")



#### Manual signature

In [27]:
# PyTorch flavor doesn't like named inputs
input_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 784), "dense_input")])
output_schema = Schema([TensorSpec(np.dtype(np.float32), (-1, 10), "dense_1")])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
signature

inputs: 
  ['dense_input': Tensor('float32', (-1, 784))]
outputs: 
  ['dense_1': Tensor('float32', (-1, 10))]

In [28]:
import shutil

# Clear any previous export first (presumably save_model will not overwrite an
# existing directory -- confirm against the MLflow docs). shutil.rmtree is
# portable, unlike shelling out to `rm -rf`, and ignore_errors=True keeps the
# "missing directory is fine" behaviour.
shutil.rmtree("model_manual", ignore_errors=True)
mlflow.pytorch.save_model(pytorch_model=ts_model,
                         signature=signature,
                         path="model_manual")



### Load data as pandas.DataFrame of 784 floats

In [29]:
import numpy as np
import pandas as pd

In [30]:
data = test_data.data.numpy()
data.shape, data.dtype

((10000, 28, 28), dtype('uint8'))

In [31]:
# Start again from the raw uint8 pixels so that re-running this cell cannot
# rescale values that were already divided by 255.
data = test_data.data.numpy()
# Flatten each image to one row (row count taken from the data rather than
# hard-coded), scale by 255, then store as float32.
data = (data.reshape(len(data), -1) / 255.0).astype(np.float32)
data.shape, data.dtype

((10000, 784), dtype('float32'))

In [32]:
test_pdf = pd.DataFrame(data)
test_pdf

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,774,775,776,777,778,779,780,781,782,783
0,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0
1,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,...,0.007843,0.011765,0.0,0.011765,0.682353,0.741176,0.262745,0.0,0.0,0.0
2,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.003922,0.000000,...,0.643137,0.227451,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0
3,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.082353,...,0.003922,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0
4,0.0,0.0,0.0,0.007843,0.0,0.003922,0.003922,0.0,0.000000,0.000000,...,0.278431,0.047059,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0
9996,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.121569,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0
9997,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,...,0.105882,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0
9998,0.0,0.0,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.000000,0.000000,...,0.000000,0.000000,0.0,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0


### Load data as pandas.DataFrame of 1 array of 784 floats

In [33]:
# Single column whose cells each hold one full row of `test_pdf` as a Python list.
test_pdf1 = pd.DataFrame({"dense_input": test_pdf.values.tolist()})
test_pdf1

Unnamed: 0,dense_input
0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003..."
3,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
4,"[0.0, 0.0, 0.0, 0.007843137718737125, 0.0, 0.0..."
...,...
9995,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9996,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9997,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9998,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


### Infer using MLFlow PyFuncModel (inferred signature w/o input names)

In [34]:
import mlflow

In [35]:
# Note: PyTorch "flavor" is defined in the MLModel file.
model_infer = mlflow.pyfunc.load_model("model_infer")



In [36]:
print(model_infer.metadata)  # contents of MLModel file

flavors:
  python_function:
    data: data
    env: conda.yaml
    loader_module: mlflow.pytorch
    pickle_module_name: mlflow.pytorch.pickle_module
    python_version: 3.9.10
  pytorch:
    code: null
    model_data: data
    pytorch_version: 1.11.0+cpu
mlflow_version: 1.25.2.dev0
model_uuid: eccae151243f4216a4640fb4e60d2d1c
signature:
  inputs: '[{"type": "tensor", "tensor-spec": {"dtype": "float32", "shape": [-1, 784]}}]'
  outputs: '[{"type": "tensor", "tensor-spec": {"dtype": "float32", "shape": [-1,
    10]}}]'
utc_time_created: '2022-04-26 23:15:05.415817'



#### Infer using pandas.DataFrame (784 floats)

In [37]:
preds = model_infer.predict(test_pdf)

In [38]:
type(preds)

numpy.ndarray

In [39]:
pd.DataFrame(preds)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,-2.042467,-2.246295,-0.688900,-2.178937,-0.917537,2.271457,-0.909829,2.509097,1.686130,3.034637
1,0.940828,-3.285853,3.581051,-1.234425,3.221895,-2.022241,2.715813,-3.501479,1.552690,-2.119475
2,1.881065,4.977236,-0.511262,3.587086,0.533625,-2.545371,0.409574,-3.335326,-2.377713,-2.922745
3,1.305144,3.794120,-0.423444,2.669347,0.324053,-1.769337,0.232410,-2.441168,-1.861111,-2.070968
4,1.026658,-1.369347,1.488516,-0.220768,1.353701,-1.123980,1.471480,-2.058816,0.558626,-1.172350
...,...,...,...,...,...,...,...,...,...,...
9995,-2.476522,-3.402988,-0.709699,-3.013192,-1.013795,2.752758,-1.020825,2.582841,2.396552,4.753345
9996,0.954735,2.415779,-0.235035,1.798817,0.224770,-1.150019,0.216863,-1.612515,-1.213792,-1.535196
9997,1.011031,-0.345896,-0.034898,0.637689,0.247803,-0.683495,0.544897,-1.331556,0.591085,-0.740176
9998,1.317076,3.869592,-0.527877,2.937577,0.342804,-1.840980,0.269499,-2.493871,-1.901167,-2.141911


#### Infer using pandas.DataFrame (array of 784 floats)

In [40]:
preds = model_infer.predict(test_pdf1)

MlflowException: Shape of input (10000, 1) does not match expected shape (-1, 784).

#### Infer using dict

In [41]:
preds = model_infer.predict({"dense_input": data})

MlflowException: This model contains a tensor-based model signature with no input names, which suggests a numpy array input, but an input of type <class 'dict'> was found.

### Infer using MLFlow PyFuncModel (manual signature w/ input names)

In [42]:
# Note: PyTorch "flavor" is defined in the MLModel file.
model_manual = mlflow.pyfunc.load_model("model_manual")



In [43]:
print(model_manual.metadata)  # contents of MLModel file

flavors:
  python_function:
    data: data
    env: conda.yaml
    loader_module: mlflow.pytorch
    pickle_module_name: mlflow.pytorch.pickle_module
    python_version: 3.9.10
  pytorch:
    code: null
    model_data: data
    pytorch_version: 1.11.0+cpu
mlflow_version: 1.25.2.dev0
model_uuid: b718a8c147044ad6a7d635dcf9185532
signature:
  inputs: '[{"name": "dense_input", "type": "tensor", "tensor-spec": {"dtype": "float32",
    "shape": [-1, 784]}}]'
  outputs: '[{"name": "dense_1", "type": "tensor", "tensor-spec": {"dtype": "float32",
    "shape": [-1, 10]}}]'
utc_time_created: '2022-04-26 23:15:07.156337'



#### Infer using pandas.DataFrame (784 floats)

In [44]:
preds = model_manual.predict(test_pdf)

MlflowException: Model is missing inputs ['dense_input']. Note that there were extra inputs: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 
403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783]

#### Infer using pandas.DataFrame (array of 784 floats)

In [45]:
preds = model_manual.predict(test_pdf1)

ValueError: setting an array element with a sequence.

#### Infer using dict

In [46]:
preds = model_manual.predict({"dense_input": data})

TypeError: The PyTorch flavor does not support List or Dict input types. Please use a pandas.DataFrame or a numpy.ndarray

## PySpark

### Convert numpy dataset to Spark DataFrame (via Pandas DataFrame)

In [47]:
import numpy as np
import pandas as pd
from pyspark.sql.types import StructType, StructField, ArrayType, FloatType

In [48]:
data = test_data.data.numpy()
data.shape, data.dtype

((10000, 28, 28), dtype('uint8'))

In [49]:
# Start again from the raw uint8 pixels so that re-running this cell cannot
# rescale values that were already divided by 255.
data = test_data.data.numpy()
# Flatten each image to one row (row count taken from the data rather than
# hard-coded), scale by 255, then store as float32.
data = (data.reshape(len(data), -1) / 255.0).astype(np.float32)
data.shape, data.dtype

((10000, 784), dtype('float32'))

In [50]:
test_pdf = pd.DataFrame(data)

### Save as 784 columns of float

In [51]:
%%time
df = spark.createDataFrame(test_pdf)
df.write.mode("overwrite").parquet("fashion_mnist_784")

22/04/26 16:17:29 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
22/04/26 16:17:29 WARN TaskSetManager: Stage 0 contains a task of very large size (4313 KiB). The maximum recommended task size is 1000 KiB.
[Stage 0:>                                                        (0 + 16) / 16]

CPU times: user 1min 15s, sys: 167 ms, total: 1min 15s
Wall time: 1min 23s


                                                                                

In [52]:
len(df.columns)

784

### Save as 1 column of 784 floats

In [53]:
%%time
# 1 column of array<float>
# Build a fresh frame instead of adding a column to `test_pdf`. Mutating it in
# place made this cell non-idempotent (a re-run would fold the new list column
# into `.values`) and would add an extra column if the 784-column export cell
# were re-run afterwards.
pdf = pd.DataFrame({"data": test_pdf.values.tolist()})
pdf.shape

CPU times: user 224 ms, sys: 51.8 ms, total: 276 ms
Wall time: 274 ms


(10000, 1)

In [54]:
%%time
# force FloatType since Pandas uses double
schema = StructType([StructField("data",ArrayType(FloatType()), True)])
df1 = spark.createDataFrame(pdf, schema)

CPU times: user 3.25 s, sys: 44 ms, total: 3.3 s
Wall time: 3.37 s


In [55]:
len(df1.columns)

1

In [56]:
df1.schema

StructType(List(StructField(data,ArrayType(FloatType,true),true)))

In [57]:
%%time
df1.write.mode("overwrite").parquet("fashion_mnist_test")

22/04/26 16:17:39 WARN TaskSetManager: Stage 1 contains a task of very large size (4315 KiB). The maximum recommended task size is 1000 KiB.
[Stage 1:>                                                        (0 + 16) / 16]

CPU times: user 14 ms, sys: 3.86 ms, total: 17.9 ms
Wall time: 1.36 s


                                                                                

### Check arrow memory configuration

In [58]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "128")
# This line will fail if the vectorized reader runs out of memory
# head() returns the first Row, or None for an empty DataFrame. len() of a Row
# counts its fields, so it cannot detect "no rows" and raises TypeError on None;
# compare against None instead so the assertion message is actually reachable.
first_row = df.head()
assert first_row is not None, "`df` should not be empty"

22/04/26 16:17:40 WARN TaskSetManager: Stage 2 contains a task of very large size (4313 KiB). The maximum recommended task size is 1000 KiB.


## Inference using MLFlow pyfunc.spark_udf

In [59]:
import mlflow
from pyspark.sql.functions import struct

In [60]:
df = spark.read.parquet("fashion_mnist_784")

In [61]:
columns = df.columns
len(columns)

784

#### Inferred signature

In [62]:
mnist_infer = mlflow.pyfunc.spark_udf(spark, model_uri="model_infer", result_type="array<float>")



In [63]:
mnist_infer.metadata.signature

inputs: 
  [Tensor('float32', (-1, 784))]
outputs: 
  [Tensor('float32', (-1, 10))]

In [65]:
%%time
preds = df.withColumn("preds", mnist_infer(*columns)).toPandas()

22/04/26 16:18:32 WARN TaskSetManager: Lost task 15.0 in stage 4.0 (TID 49) (192.168.86.223 executor 1): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1273, in udf
    os.kill(scoring_server_proc.pid, signal.SIGTERM)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1115, in _predict_row_batch
    result = predict_fn(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1255, in batch_predict_fn
    return loaded_model.predict(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 630, in predict
    data = _enforce_schema(data, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 585, in _enforce_schema
    _enforce_tensor_schema(pfInput, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 509, in _enforce_tensor_schema
    new_pfInput = _enforce_tensor_spec(pfInput.to_numpy(), input_

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1273, in udf
    os.kill(scoring_server_proc.pid, signal.SIGTERM)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1115, in _predict_row_batch
    result = predict_fn(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1255, in batch_predict_fn
    return loaded_model.predict(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 630, in predict
    data = _enforce_schema(data, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 585, in _enforce_schema
    _enforce_tensor_schema(pfInput, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 509, in _enforce_tensor_schema
    new_pfInput = _enforce_tensor_spec(pfInput.to_numpy(), input_schema.inputs[0])
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 453, in _enforce_tensor_spec
    raise MlflowException(
mlflow.exceptions.MlflowException: Shape of input (128, 1) does not match expected shape (-1, 784).


#### Manual signature

In [66]:
mnist_manual = mlflow.pyfunc.spark_udf(spark, model_uri="model_manual", result_type="array<float>")



In [67]:
mnist_manual.metadata.signature

inputs: 
  ['dense_input': Tensor('float32', (-1, 784))]
outputs: 
  ['dense_1': Tensor('float32', (-1, 10))]

In [68]:
%%time
preds = df.withColumn("preds", mnist_manual(*columns)).toPandas()

22/04/26 16:18:50 WARN TaskSetManager: Lost task 2.0 in stage 5.0 (TID 52) (192.168.86.223 executor 1): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1273, in udf
    os.kill(scoring_server_proc.pid, signal.SIGTERM)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1115, in _predict_row_batch
    result = predict_fn(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1255, in batch_predict_fn
    return loaded_model.predict(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 630, in predict
    data = _enforce_schema(data, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 585, in _enforce_schema
    _enforce_tensor_schema(pfInput, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 498, in _enforce_tensor_schema
    new_pfInput[col_name] = _enforce_tensor_spec(
  File "/home/le

PythonException: 
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1273, in udf
    os.kill(scoring_server_proc.pid, signal.SIGTERM)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1115, in _predict_row_batch
    result = predict_fn(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 1255, in batch_predict_fn
    return loaded_model.predict(pdf)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 630, in predict
    data = _enforce_schema(data, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 585, in _enforce_schema
    _enforce_tensor_schema(pfInput, input_schema)
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 498, in _enforce_tensor_schema
    new_pfInput[col_name] = _enforce_tensor_spec(
  File "/home/leey/devpub/mlflow/mlflow/pyfunc/__init__.py", line 444, in _enforce_tensor_spec
    raise MlflowException(
mlflow.exceptions.MlflowException: Shape of input (128,) does not match expected shape (-1, 784).


22/04/26 16:18:50 WARN TaskSetManager: Lost task 4.0 in stage 5.0 (TID 54) (192.168.86.223 executor 1): TaskKilled (Stage cancelled)
22/04/26 16:18:50 WARN TaskSetManager: Lost task 0.0 in stage 5.0 (TID 50) (192.168.86.223 executor 1): TaskKilled (Stage cancelled)
22/04/26 16:18:50 WARN TaskSetManager: Lost task 3.0 in stage 5.0 (TID 53) (192.168.86.223 executor 0): TaskKilled (Stage cancelled)
22/04/26 16:18:50 WARN TaskSetManager: Lost task 5.0 in stage 5.0 (TID 55) (192.168.86.223 executor 0): TaskKilled (Stage cancelled)
22/04/26 16:18:50 WARN TaskSetManager: Lost task 12.0 in stage 5.0 (TID 62) (192.168.86.223 executor 1): TaskKilled (Stage cancelled)
22/04/26 16:18:50 WARN TaskSetManager: Lost task 10.0 in stage 5.0 (TID 60) (192.168.86.223 executor 1): TaskKilled (Stage cancelled)
22/04/26 16:18:50 WARN TaskSetManager: Lost task 14.0 in stage 5.0 (TID 64) (192.168.86.223 executor 1): TaskKilled (Stage cancelled)
22/04/26 16:18:50 WARN TaskSetManager: Lost task 6.0 in stage 5.0 

### Check predictions

In [None]:
import numpy as np
from matplotlib import pyplot as plt

In [None]:
img = preds.drop(columns=['preds']).values[0]

In [None]:
plt.figure()
plt.imshow(img.reshape(28,28))
plt.show()

In [None]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

In [None]:
preds['preds'][0]