# PyTorch and Dask

Creating a cluster and performing some computation

In [1]:
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster
from dask import delayed
import dask

cluster = SLURMCluster(
    memory="64g", processes=1, cores=2
)
num_nodes = 4

cluster.scale(num_nodes)
# cluster = LocalCluster(processes=False)
client = Client(cluster)
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 42939 instead


0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://192.168.0.213:42939/status,

0,1
Dashboard: http://192.168.0.213:42939/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://192.168.0.213:42643,Workers: 0
Dashboard: http://192.168.0.213:42939/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [2]:
# Note how dask spins our jobs up in anticipation for work
!squeue

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
              1312     batch Jupyter  mhar0048  R      20:59      1 mlerp-node05
              1215     batch Jupyter    yiliao  R 3-21:21:03      1 mlerp-node09
              1214     batch Jupyter    yiliao  R 3-21:24:39      1 mlerp-node05
              1336     batch dask-wor mhar0048  R       0:03      1 mlerp-node05
              1337     batch dask-wor mhar0048  R       0:03      1 mlerp-node05
              1338     batch dask-wor mhar0048  R       0:03      1 mlerp-node05
              1339     batch dask-wor mhar0048  R       0:03      1 mlerp-node09


In [3]:
# The adapt method will let us scale out as we need the compute
# ...and scale back when we're idle
cluster.adapt(minimum=0, maximum=num_nodes)

<distributed.deploy.adaptive.Adaptive at 0x7efc80c60790>

In [4]:
!squeue

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
              1312     batch Jupyter  mhar0048  R      21:02      1 mlerp-node05
              1215     batch Jupyter    yiliao  R 3-21:21:06      1 mlerp-node09
              1214     batch Jupyter    yiliao  R 3-21:24:42      1 mlerp-node05
              1336     batch dask-wor mhar0048  R       0:06      1 mlerp-node05
              1337     batch dask-wor mhar0048  R       0:06      1 mlerp-node05
              1338     batch dask-wor mhar0048  R       0:06      1 mlerp-node05
              1339     batch dask-wor mhar0048  R       0:06      1 mlerp-node09


In [5]:
# Dask has a lovely UI that will let you see how the tasks are being computed
# VSCode has an extension for you to connect to this: http://127.0.0.1:8787 (Adjust the port if needed)

In [6]:
# da lets us scale out to the cluster more efficiently than npy
import dask.array as da
x = da.random.random((1000, 1000, 1000))
x

Unnamed: 0,Array,Chunk
Bytes,7.45 GiB,119.21 MiB
Shape,"(1000, 1000, 1000)","(250, 250, 250)"
Count,64 Tasks,64 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 7.45 GiB 119.21 MiB Shape (1000, 1000, 1000) (250, 250, 250) Count 64 Tasks 64 Chunks Type float64 numpy.ndarray",1000  1000  1000,

Unnamed: 0,Array,Chunk
Bytes,7.45 GiB,119.21 MiB
Shape,"(1000, 1000, 1000)","(250, 250, 250)"
Count,64 Tasks,64 Chunks
Type,float64,numpy.ndarray


In [7]:
# dask evaluates lazily, retuning 'futures'
# they can then be computed later for its value
x.compute()

array([[[7.95991422e-01, 7.29635940e-01, 2.81710508e-02, ...,
         7.80411871e-01, 6.47691554e-01, 7.09732354e-01],
        [4.68557888e-02, 1.61744555e-01, 9.68807923e-01, ...,
         4.14736245e-01, 3.45848334e-01, 2.94003999e-01],
        [8.68459168e-01, 3.31556693e-02, 7.29167989e-01, ...,
         2.67547397e-01, 3.62516388e-01, 4.47140395e-01],
        ...,
        [2.50422490e-01, 1.16683537e-01, 3.64873815e-01, ...,
         8.11636480e-01, 2.19410748e-01, 8.34222863e-01],
        [5.32683002e-01, 9.89774462e-01, 3.81822225e-01, ...,
         5.22320346e-01, 2.45289960e-01, 4.99485662e-01],
        [5.73752150e-02, 4.16188506e-01, 2.01487615e-01, ...,
         4.78838214e-01, 6.14959854e-01, 7.40204742e-01]],

       [[3.78725383e-01, 3.03743484e-02, 4.81805729e-01, ...,
         1.80843551e-01, 8.89631808e-01, 2.05693056e-01],
        [7.99460633e-01, 9.32727544e-01, 3.78722157e-01, ...,
         7.09053389e-02, 1.75157899e-01, 5.05155611e-01],
        [8.65631434e-01, 

In [8]:
# Let's switch to a localcluster for easier active development
# This will make all code execute locally
# We need to make proccesses=False to allow for multiprocessing inside Dask jobs
# for the local cluster to work with PyTorch
client.shutdown()
cluster = LocalCluster(processes=False)
client = Client(cluster)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 44055 instead


### Let's see how Dask works with a typical PyTorch workflow
Content adapted from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [9]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.multiprocessing as mp

# Define data transformations
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Define dataset and dataloader
batch_size = 1024
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2, multiprocessing_context=mp.get_context("fork"))

Files already downloaded and verified


In [10]:
# Define a simple conv net
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.fc1 = nn.Linear(4 * 4 * 64, 4 * 64)
        self.fc2 = nn.Linear(4 * 64, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [11]:
import torch.optim as optim
from tqdm.notebook import tqdm
criterion = nn.CrossEntropyLoss()

# Train one epoch
def train(loader, path="./model", load=False, test=False, error=False):
    # Initialise model, optimizer and device
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Load state from disk so that we can split up the job
    if load: 
        state = torch.load(path)
        model.load_state_dict(state["model"])
        model.to(device)
        optimizer.load_state_dict(state["optimizer"])
    else:
        model.to(device)
    
    # A typical PyTorch training loop
    running_loss = 0
    for i, (inputs, labels) in enumerate(trainloader):
        # put the inputs on the device
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.detach().item()
        
        # Force an error
        if error:
            assert 0 == 1
        
        # Stop after one batch when testing        
        if test: 
            print("When running in a local cluster you can see print statements")
            break
    
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict()
        }, path)
    
    return running_loss / len(trainloader) if not test else loss.detach().item()

2023-02-06 00:26:58,981 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client


In [12]:
# Test our code locally first
client.submit(train, trainloader, test=True).result()

  [<torch.utils.data.dataloader.DataLoader object at 0x7efc80c60520>]
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good


When running in a local cluster you can see print statements


2.305037021636963

In [15]:
# We need to turn off the nanny to allow for multiprocessing inside Dask jobs for the cluster to work with PyTorch
# We can pass in SLURM requirements to ensure we get a GPU for our jobs
client.shutdown()
cluster = SLURMCluster(
    memory="64g", processes=1, cores=2, job_extra_directives=["--gres=gpu:1"], nanny=False
)
cluster.scale(1)
client = Client(cluster)

In [16]:
# Test our code on the SLURM cluster
# Since this code is executing remotely we won't see our print statements
client.submit(train, trainloader, test=True).result()

2.3079774379730225

In [17]:
# Dask will raise any errors that the process triggers locally, even when executing remotely
client.submit(train, trainloader, error=True).result()

2023-02-06 00:29:12,785 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client


AssertionError: 

In [18]:
# Run the training loop
epochs = 2

with tqdm(total=(epochs)) as pbar:
    for epoch in range(epochs):
        loss = client.submit(train, trainloader, load=epochs).result()
        pbar.update()
        pbar.set_postfix(loss=loss)
        print(f"epoch: {epoch} loss: {loss : .3f}")
client.shutdown()

  0%|          | 0/2 [00:00<?, ?it/s]

epoch: 0 loss:  2.262
epoch: 1 loss:  2.014


2023-02-06 00:30:17,778 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client
