# PySpark PyTorch Training

Based on:
- https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
- https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

## Distributed Training on executors

In [None]:
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

# On Windows, the torch.distributed package only
# supports the Gloo backend, FileStore and TcpStore.
# For FileStore, set the init_method parameter in init_process_group
# to a local file, as in the following example:
# init_method="file:///f:/libtmp/some_file"
# dist.init_process_group(
#    "gloo",
#    rank=rank,
#    init_method=init_method,
#    world_size=world_size)
# For TcpStore, same way as on Linux.

def setup(rank, world_size, *, master_addr='localhost', master_port='12355', backend="gloo"):
    """Initialize the default ``torch.distributed`` process group for this process.

    Args:
        rank: Rank of this process within the group.
        world_size: Total number of processes in the group.
        master_addr: Host of the rank-0 rendezvous. The default ``'localhost'``
            only works when every rank runs on the same machine.
        master_port: Port of the rank-0 rendezvous (str or int).
        backend: ``torch.distributed`` backend name.
    """
    # init_process_group's default env:// initialization reads these variables.
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)

    # initialize the process group
    dist.init_process_group(backend, rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group that ``setup`` created."""
    torch.distributed.destroy_process_group()

In [None]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier: 28*28 inputs -> 512 -> 512 -> 10 logits."""

    def __init__(self):
        super().__init__()
        # Collapse every non-batch dimension into one feature axis.
        self.flatten = nn.Flatten()
        hidden = 512
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Return unnormalized class scores for the batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, *, log_interval=100):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing ``.dataset``
            (its length is the progress denominator).
        model: Module (or DDP wrapper) mapping ``X`` to predictions.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        log_interval: Print the loss every ``log_interval`` batches.
    """
    size = len(dataloader.dataset)
    # Make sure mode-dependent layers (dropout, batch norm) are in training mode.
    model.train()
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Samples processed so far, including this batch; summing len(X)
        # stays correct when the final batch is smaller than the rest.
        seen += len(X)
        if batch % log_interval == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

In [None]:
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [None]:
def train_fn(partition):
    """Barrier-task entry point: train ``NeuralNetwork`` on FashionMNIST with DDP.

    Meant for ``rdd.barrier().mapPartitions(train_fn)``: each task in the
    barrier stage becomes one DDP rank, with rank equal to its partition id.

    Args:
        partition: Partition iterator supplied by ``mapPartitions``. It is not
            read; it is returned unchanged so ``collect()`` yields its elements.

    Returns:
        ``partition``.
    """
    import os

    import torch
    import torch.distributed as dist
    from pyspark import BarrierTaskContext
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data.distributed import DistributedSampler
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    # data
    # NOTE(review): every task downloads into the same relative "data"
    # directory; tasks sharing a host may race on the download -- confirm.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    # get list of participating nodes and my rank
    context = BarrierTaskContext.get()
    # getTaskInfos() is ordered by partition id, so workers[0] hosts rank 0.
    workers = [t.address.rsplit(':', 1)[0] for t in context.getTaskInfos()]
    rank = context.partitionId()
    world_size = len(workers)

    # Give each rank a disjoint shard of the training set; without a sampler
    # every rank iterates over the full dataset. shuffle=False keeps the
    # original unshuffled order within each shard.
    train_sampler = DistributedSampler(
        training_data, num_replicas=world_size, rank=rank, shuffle=False
    )
    train_dataloader = torch.utils.data.DataLoader(
        training_data, batch_size=64, sampler=train_sampler
    )
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=64)

    print(f"Running basic DDP example on rank {rank}.")
    # All ranks must rendezvous at rank 0's host; 'localhost' only works when
    # every task runs on the same machine.
    os.environ['MASTER_ADDR'] = workers[0]
    os.environ['MASTER_PORT'] = '12355'

    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    try:
        model = NeuralNetwork()
        ddp_model = DDP(model)
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        for t in range(5):
            print(f"Epoch {t+1}\n-------------------------------")
            # Go through the DDP wrapper so gradients are all-reduced across
            # ranks; calling the bare model skips the synchronization.
            train_loop(train_dataloader, ddp_model, loss_fn, optimizer)
            test_loop(test_dataloader, ddp_model, loss_fn)
    finally:
        # Release the process group even if training fails.
        dist.destroy_process_group()
    print("Done!")

    return partition


In [None]:
# One element per barrier task: two partitions -> two DDP ranks.
nodeRDD = sc.parallelize(range(2), numSlices=2)

In [None]:
# Run train_fn once per partition inside a barrier stage and gather results.
(
    nodeRDD
    .barrier()
    .mapPartitions(train_fn)
    .collect()
)

## Refactored

### Read from disk

In [1]:
from abc import ABC
from pyspark import BarrierTaskContext

class FrameworkPlugin(ABC):
    """Hooks a deep-learning framework into a Spark barrier task.

    Subclasses override the static methods below; they are called on the
    class itself, without creating an instance.
    """

    @staticmethod
    def setup(context: "BarrierTaskContext"):
        """Initialize the framework for the current barrier task.

        Args:
            context: The running task's ``BarrierTaskContext``.

        Returns:
            ``(rank, world_size)`` for the current task.

        Raises:
            NotImplementedError: Always in the base class. Previously this
                returned None, which fails later and obscurely when a caller
                unpacks it as ``(rank, world_size)``.
        """
        raise NotImplementedError("FrameworkPlugin subclasses must implement setup(context)")

    @staticmethod
    def teardown():
        """Release framework resources. The base implementation does nothing."""

In [2]:
import torch.distributed as dist

class PyTorchPlugin(FrameworkPlugin):
    """FrameworkPlugin that initializes a gloo ``torch.distributed`` process group.

    Rank 0 reserves a free TCP port on its own host and shares ``host:port``
    with the other barrier tasks through ``allGather``. Every task then joins
    the group via the ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables.
    """

    @staticmethod
    def setup(context: BarrierTaskContext):
        """Join the process group for the current barrier task.

        Args:
            context: The running task's ``BarrierTaskContext``.

        Returns:
            ``(rank, world_size)``: the task's partition id and the number of
            tasks in the barrier stage.
        """
        # Import locally so this works without a module-level ``import os``.
        import os
        import socket

        task_infos = context.getTaskInfos()
        # Task addresses look like "host:port"; drop only the trailing port.
        workers = [t.address.rsplit(':', 1)[0] for t in task_infos]
        rank = context.partitionId()
        world_size = len(workers)
        my_addr = workers[rank]

        # find available port for master using allGather as a proxy for broadcast
        if rank == 0:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.bind(('', 0))
                port = sock.getsockname()[1]
                # The socket stays bound until allGather returns, so the port
                # cannot be taken by another bind in the meantime. It is
                # released just before init_process_group uses it, leaving a
                # small race window.
                master_candidates = context.allGather(f"{my_addr}:{port}")
        else:
            # all tasks must invoke allGather
            master_candidates = context.allGather("")

        # NOTE(review): assumes allGather orders messages by partition id, so
        # index 0 is rank 0's "host:port" -- confirm.
        addr, port = master_candidates[0].rsplit(':', 1)
        print(f"Assigning master to: {addr}:{port} on rank {rank}")
        print(f"Running basic DDP example on rank {rank}.")
        os.environ['MASTER_ADDR'] = addr
        os.environ['MASTER_PORT'] = port

        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        return rank, world_size

    @staticmethod
    def teardown():
        """Destroy the default process group created by ``setup``."""
        dist.destroy_process_group()
        

In [3]:
import functools

def distribute(framework_plugin):
    """Build a decorator that runs a partition function inside a framework group.

    Args:
        framework_plugin: Class with static ``setup(context)`` (returning
            ``(rank, world_size)``) and ``teardown()`` methods.

    Returns:
        A decorator. It turns ``train_fn(df_iter)`` into a function for barrier
        ``mapPartitions`` / ``mapInPandas`` that calls ``setup`` before and
        ``teardown`` after ``train_fn``. Because teardown runs as soon as
        ``train_fn`` returns, ``train_fn`` must do its work eagerly rather than
        return a lazy generator.
    """
    def decorator_distribute(train_fn):
        @functools.wraps(train_fn)
        def _wrapper(df_iter):
            from pyspark import BarrierTaskContext

            # get list of participating nodes and my local address
            context = BarrierTaskContext.get()
            framework_plugin.setup(context)
            try:
                return train_fn(df_iter)
            finally:
                # Tear down even when train_fn raises. Otherwise resources such
                # as a torch process group could outlive the failed task.
                framework_plugin.teardown()
        return _wrapper
    return decorator_distribute

In [4]:
import torch.nn as nn

class NeuralNetwork(nn.Module):
    """Three-layer MLP mapping a 28x28 image (or 784-vector) to 10 logits."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        widths = [28 * 28, 512, 512, 10]
        layers = []
        # Linear layers separated by ReLU, with no activation after the last.
        for i, (d_in, d_out) in enumerate(zip(widths, widths[1:])):
            if i:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` and return its unnormalized class scores."""
        features = self.flatten(x)
        return self.linear_relu_stack(features)

In [5]:
def train_loop(dataloader, model, loss_fn, optimizer, *, log_interval=100):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Any iterable of ``(X, y)`` batches works; ``.dataset`` is not needed, so
    loaders over per-rank shards are fine.

    Args:
        dataloader: Iterable of ``(X, y)`` batches.
        model: Module (or DDP wrapper) mapping ``X`` to predictions.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        log_interval: Print the loss every ``log_interval`` batches.
    """
    # Make sure mode-dependent layers (dropout, batch norm) are in training mode.
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % log_interval == 0:
            print(f"batch: {batch:>5}, loss: {loss.item():>7f}")

In [10]:
import torch

def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [11]:
@distribute(PyTorchPlugin)
def train_fn(partition):
    """Train ``NeuralNetwork`` on FashionMNIST with DDP inside a barrier task.

    The ``distribute(PyTorchPlugin)`` decorator initializes the process group
    before this body runs and tears it down afterwards.

    Args:
        partition: Partition iterator supplied by ``mapPartitions``; it is not
            read and is returned unchanged.

    Returns:
        ``partition``.
    """
    import os

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data.distributed import DistributedSampler
    from torchvision import datasets
    from torchvision.transforms import ToTensor

    # Dataset directory on each executor. Set FASHION_MNIST_ROOT in the
    # executor environment to override the user-specific default.
    data_root = os.environ.get("FASHION_MNIST_ROOT", "/home/leey/data")

    # data
    training_data = datasets.FashionMNIST(
        root=data_root,
        train=True,
        download=True,
        transform=ToTensor()
    )

    test_data = datasets.FashionMNIST(
        root=data_root,
        train=False,
        download=True,
        transform=ToTensor()
    )

    # The process group already exists, so the sampler reads rank and world
    # size from it. Each rank trains on its own shard instead of the whole
    # dataset; shuffle=False keeps the original unshuffled order.
    train_sampler = DistributedSampler(training_data, shuffle=False)
    train_dataloader = torch.utils.data.DataLoader(
        training_data, batch_size=64, sampler=train_sampler
    )
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=64)

    model = NeuralNetwork()
    ddp_model = DDP(model)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for t in range(5):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(train_dataloader, ddp_model, loss_fn, optimizer)
        test_loop(test_dataloader, ddp_model, loss_fn)

    print("Done!")

    return partition


In [12]:
# Two single-element partitions -> two barrier tasks (DDP ranks).
nodeRDD = sc.parallelize(range(2), numSlices=2)

In [13]:
# Launch distributed training: one train_fn call per partition, gang-scheduled.
(
    nodeRDD
    .barrier()
    .mapPartitions(train_fn)
    .collect()
)

                                                                                

[0, 1]

### Read from Spark DataFrame

In [14]:
import pandas as pd
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# data
training_data = datasets.FashionMNIST(
    root="/home/leey/data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="/home/leey/data",
    train=False,
    download=True,
    transform=ToTensor()
)

In [None]:
# Raw training images as a NumPy array (presumably uint8 pixels with shape
# (60000, 28, 28) for the FashionMNIST train split -- confirm).
images = training_data.data.numpy()
images.shape

In [None]:
# Class id for each training image as a NumPy array (presumably 0-9).
labels = training_data.targets.numpy()
labels.shape

In [None]:
# One row per image, 28*28 = 784 pixel columns, rescaled by 1/255
# (into [0, 1] assuming 0-255 pixel values).
pdf784 = pd.DataFrame(data=images.reshape(-1, 28 * 28) / 255.0)
pdf784

In [None]:
# Two-column frame: 'image' holds each row's 784 pixels as a Python list.
pdf1 = pd.DataFrame({
    'image': pdf784.values.tolist(),
    'label': labels,
})
pdf1

In [None]:
from pyspark.sql.types import ArrayType, FloatType, IntegerType, StructField, StructType

# Declare FloatType explicitly; Spark would otherwise infer DoubleType.
schema = (
    StructType()
    .add("image", ArrayType(FloatType()), True)
    .add("label", IntegerType(), True)
)

df = spark.createDataFrame(pdf1, schema=schema)
df.show()

In [None]:
# NOTE(review): relative path -- resolved against Spark's default filesystem
# / working directory, which may differ from the absolute path read back later.
df.write.parquet("fashion_mnist_1", mode="overwrite")

In [15]:
df = spark.read.format("parquet").load("/home/leey/dev/nvsparkdl/examples/fashion_mnist_1")

In [16]:
df.show(n=20, truncate=True)

[Stage 3:>                                                          (0 + 1) / 1]

+--------------------+-----+
|               image|label|
+--------------------+-----+
|[0.0, 0.0, 0.0, 0...|    2|
|[0.0, 0.0, 0.0, 0...|    0|
|[0.0, 0.0, 0.0, 0...|    8|
|[0.0, 0.0, 0.0, 0...|    6|
|[0.0, 0.0, 0.0039...|    0|
|[0.0, 0.0, 0.0039...|    2|
|[0.0, 0.0, 0.0, 0...|    1|
|[0.0, 0.0, 0.0, 0...|    1|
|[0.0, 0.0, 0.0, 0...|    2|
|[0.0, 0.0, 0.0, 0...|    2|
|[0.0, 0.0, 0.0, 0...|    9|
|[0.0, 0.0, 0.0, 0...|    2|
|[0.0, 0.0, 0.0, 0...|    8|
|[0.0, 0.0, 0.0, 0...|    6|
|[0.0, 0.0, 0.0, 0...|    7|
|[0.0, 0.0, 0.0, 0...|    7|
|[0.0, 0.0, 0.0, 0...|    5|
|[0.0, 0.0, 0.0, 0...|    2|
|[0.0, 0.0, 0.0, 0...|    6|
|[0.0, 0.0, 0.0, 0...|    1|
+--------------------+-----+
only showing top 20 rows



                                                                                

In [17]:
@distribute(PyTorchPlugin)
def train_fn(partition):
    """Train ``NeuralNetwork`` with DDP on the rows of this Spark partition.

    Used with ``mapInPandas``: ``partition`` is an iterator of pandas
    DataFrames. It presumably has an ``image`` column (one pixel array per
    row) and an integer ``label`` column -- confirm against the input schema.
    The ``distribute(PyTorchPlugin)`` decorator creates and destroys the
    process group around this body.

    Returns:
        An empty iterator. Training happens as a side effect and no rows are
        emitted, but ``mapInPandas`` requires an iterator of DataFrames.

    Raises:
        ValueError: If the partition contains no rows.
    """
    import numpy as np
    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, TensorDataset

    # Receive data from Spark: stack each batch's per-row image arrays into a
    # dense (rows, pixels) block instead of building one tuple per row.
    image_chunks, label_chunks = [], []
    for pdf in partition:
        if len(pdf) == 0:
            continue
        image_chunks.append(
            np.stack(pdf["image"].to_numpy()).astype(np.float32, copy=False)
        )
        label_chunks.append(pdf["label"].to_numpy())

    if not image_chunks:
        # A rank with no data cannot take part in DDP's gradient all-reduce;
        # fail loudly instead of leaving the other ranks waiting.
        raise ValueError("train_fn received an empty partition; nothing to train on")

    features = torch.from_numpy(np.concatenate(image_chunks))
    # CrossEntropyLoss expects int64 class indices.
    targets = torch.from_numpy(np.concatenate(label_chunks).astype(np.int64))
    train_dataloader = DataLoader(TensorDataset(features, targets), batch_size=64)

    model = NeuralNetwork()
    ddp_model = DDP(model)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Partitions from repartition() need not have equal row counts. join()
    # lets a rank that runs out of batches shadow the collectives of ranks
    # that still have data, instead of deadlocking.
    with ddp_model.join():
        for t in range(5):
            print(f"Epoch {t+1}\n-------------------------------")
            train_loop(train_dataloader, ddp_model, loss_fn, optimizer)

    print("Done!")

    return iter([])


In [18]:
# train_fn trains as a side effect and emits no rows, so the collected result
# is empty. The declared output schema matches the real column types
# (image: array<float>, label: int); it previously had them swapped.
# NOTE(review): .rdd.barrier() is what runs the pipelined mapInPandas in a
# barrier stage so BarrierTaskContext works inside train_fn -- confirm on the
# target Spark version (3.5+ also offers mapInPandas(..., barrier=True)).
rdd_out = (
    df
    .repartition(2)
    .mapInPandas(train_fn, schema="image array<float>, label int")
    .rdd
    .barrier()
    .mapPartitions(lambda x: x)
    .collect()
)

                                                                                

## Scratch