# PyTorch and Dask

First, some basics on creating a Dask cluster. MLeRP uses SLURM as its resource manager, which Dask is able to tap into.

In [3]:
from dask_jobqueue import SLURMCluster
from distributed import Client, LocalCluster
from dask import delayed
import dask

# Point Dask to SLURM to use as its back end
cluster = SLURMCluster(
    memory="64g", processes=1, cores=8
)

# Scale out to 4 nodes
num_nodes = 4
cluster.scale(num_nodes)
client = Client(cluster)
client

Connection method: Cluster object
Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://192.168.0.213:8787/status

Dashboard: http://192.168.0.213:8787/status
Workers: 0
Total threads: 0
Total memory: 0 B

Comm: tcp://192.168.0.213:36235
Dashboard: http://192.168.0.213:8787/status
Workers: 0
Total threads: 0
Started: Just now
Total memory: 0 B


Dask will spin up jobs in anticipation of work, at the scale that you specify.

You can check in on these jobs with `squeue`, just as you would with any other SLURM job.

In [2]:
!squeue


             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
              1394     batch dask-wor mhar0048 PD       0:00      1 (None)
              1393     batch dask-wor mhar0048 PD       0:00      1 (None)
              1392     batch dask-wor mhar0048 PD       0:00      1 (None)
              1391     batch dask-wor mhar0048 PD       0:00      1 (None)
              1390     batch Jupyter  mhar0048  R    1:08:50      1 mlerp-node05
              1373     batch Jupyter      bpal  R    2:08:33      1 mlerp-node05


In [3]:
# The adapt method will let us scale out as we need the compute
# ...and scale back when we're idle letting others use the cluster
cluster.adapt(minimum=0, maximum=num_nodes)

<distributed.deploy.adaptive.Adaptive at 0x7f464e0fac50>

In [4]:
# You may need to run this cell a few times while waiting for Dask to clean up
!squeue

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
              1390     batch Jupyter  mhar0048  R    1:08:51      1 mlerp-node05
              1391     batch dask-wor mhar0048  R       0:01      1 mlerp-node05
              1392     batch dask-wor mhar0048  R       0:01      1 mlerp-node09
              1393     batch dask-wor mhar0048  R       0:01      1 mlerp-node09
              1394     batch dask-wor mhar0048  R       0:01      1 mlerp-node09
              1373     batch Jupyter      bpal  R    2:08:34      1 mlerp-node05


Dask has a lovely UI that lets you see how the tasks are being computed.
You won't be able to connect to it directly with your web browser, but VSCode has an extension that can connect to it.

Use the loopback address: http://127.0.0.1:8787 (Adjust the port if needed)

Now let's define a Dask array and perform some computation. Dask arrays are parallelised across your worker nodes, so they can be larger than one worker's memory. Dask evaluates lazily, returning 'futures' that record the tasks needed to complete the compute graph. Their values can be computed later.
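To make the lazy behaviour concrete, here is a minimal sketch on a tiny array. The shape and chunk sizes are arbitrary choices for illustration, and it runs on Dask's default local scheduler rather than the SLURM cluster above.

```python
import dask.array as da

# A small 4x4 array of ones, split into four 2x2 chunks
x = da.ones((4, 4), chunks=(2, 2))

# Nothing is computed yet: this only records tasks in the graph
total = (x + 1).sum()
print(type(total).__name__)  # a Dask Array, not a number

# compute() executes the graph and returns the concrete value
value = total.compute()
print(value)  # 16 elements, each equal to 2 -> 32.0
```

Until `compute()` is called, `total` is only a description of the work to be done.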

In [5]:
# dask.array (da) lets us scale out to the cluster more efficiently than NumPy
import dask.array as da
x = da.random.random((1000, 1000, 1000))
x

         Array                 Chunk
Bytes    7.45 GiB              119.21 MiB
Shape    (1000, 1000, 1000)    (250, 250, 250)
Count    64 Tasks              64 Chunks
Type     float64               numpy.ndarray
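The figures in this summary can be reproduced with a little arithmetic. The sketch below assumes 8 bytes per `float64` element and the shape and chunk sizes reported above.

```python
import math

shape = (1000, 1000, 1000)
chunk = (250, 250, 250)
itemsize = 8  # bytes per float64 element

# Total array size in GiB and the size of one chunk in MiB
array_gib = math.prod(shape) * itemsize / 2**30
chunk_mib = math.prod(chunk) * itemsize / 2**20

# Number of chunks: 4 along each of the 3 axes
num_chunks = math.prod(math.ceil(s / c) for s, c in zip(shape, chunk))

print(f"{array_gib:.2f} GiB, {chunk_mib:.2f} MiB, {num_chunks} chunks")
# prints: 7.45 GiB, 119.21 MiB, 64 chunks
```

The full array is far larger than a single chunk, which is why spreading chunks across workers matters.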


In [6]:
# check squeue while this is running to see the jobs dynamically spinning up
x.compute()

array([[[0.1528269 , 0.73128592, 0.54163584, ..., 0.86250423,
         0.00938134, 0.99052283],
        [0.98192313, 0.18536231, 0.94974781, ..., 0.66356631,
         0.50824672, 0.27509965],
        [0.55184687, 0.91434221, 0.1738722 , ..., 0.11187265,
         0.27047543, 0.21469505],
        ...,
        [0.89691145, 0.58460763, 0.13615294, ..., 0.30249869,
         0.27796905, 0.43230158],
        [0.21621162, 0.26594022, 0.64527125, ..., 0.52795095,
         0.00117765, 0.41975114],
        [0.59384762, 0.09135375, 0.16423744, ..., 0.59283497,
         0.22632899, 0.20654085]],

       [[0.16640887, 0.62154262, 0.43467535, ..., 0.01744579,
         0.10629485, 0.01673153],
        [0.66014438, 0.87871253, 0.01207695, ..., 0.91857453,
         0.45560393, 0.86222286],
        [0.377713  , 0.89378863, 0.00605999, ..., 0.22809737,
         0.26318572, 0.53233163],
        ...,
        [0.62743982, 0.98589463, 0.86501689, ..., 0.92374094,
         0.05474441, 0.98064099],
        [0.0

In [7]:
# We can also accelerate Dask arrays with GPUs using CuPy
# There are similar analogues for the rest of RAPIDS
dask.config.set({"array.backend": "cupy"})
y = da.random.random((1000, 1000, 1000))
y.compute()

array([[[0.34572215, 0.43556403, 0.12558177, ..., 0.49003415,
         0.85630438, 0.79450724],
        [0.66733073, 0.025424  , 0.02184396, ..., 0.21553809,
         0.52054652, 0.28891133],
        [0.3733833 , 0.45080879, 0.92277408, ..., 0.81119711,
         0.08831033, 0.99256316],
        ...,
        [0.9953759 , 0.00840521, 0.85993305, ..., 0.27023196,
         0.48049836, 0.7744583 ],
        [0.81221035, 0.74607981, 0.19407742, ..., 0.73366886,
         0.6999948 , 0.41679576],
        [0.86969061, 0.84317931, 0.62250744, ..., 0.85697147,
         0.82231557, 0.79759314]],

       [[0.76864969, 0.76818799, 0.7999283 , ..., 0.81951779,
         0.03667452, 0.47616253],
        [0.94602241, 0.19928793, 0.15302713, ..., 0.67441548,
         0.83188983, 0.88484947],
        [0.4180656 , 0.85396532, 0.43575906, ..., 0.61590593,
         0.31649307, 0.85529874],
        ...,
        [0.64447301, 0.1578159 , 0.26899086, ..., 0.72044637,
         0.72996813, 0.52013984],
        [0.5

In [4]:
# Shut down the cluster
client.shutdown()


### Let's see how Dask works with a typical PyTorch workflow

Let's switch to a `LocalCluster`, as it's easier for interactive development. This makes all code execute locally, so you can view print statements and debug errors normally rather than dealing with remote code execution before we're ready.

Dask prefers to control all processes so that it can manage them more gracefully if they fail, but we need to give PyTorch the control to use multiprocessing as needed. To do this, set `processes=False` to allow for multiprocessing inside Dask jobs.
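As a quick check of what `processes=False` does, the sketch below submits a task that reports its process ID. The helper `worker_pid` and the small cluster settings (one worker, one thread, dashboard disabled) are choices made for this illustration only.

```python
import os
from distributed import Client, LocalCluster

def worker_pid():
    # Report the ID of the process that actually runs the task
    return os.getpid()

# processes=False runs the worker as threads inside this process
with LocalCluster(processes=False, n_workers=1, threads_per_worker=1,
                  dashboard_address=None) as cluster:
    with Client(cluster) as client:
        remote_pid = client.submit(worker_pid).result()

same_process = remote_pid == os.getpid()
print(same_process)
```

Because the task runs in the notebook's own process, print statements and debuggers behave as they would for ordinary local code.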

In [5]:
# Set up a new local cluster and client
cluster = LocalCluster(processes=False)
client = Client(cluster)

We're going to use CIFAR-10 as our case study for this example.
Content adapted from: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.multiprocessing as mp

# Define data transformations
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Define dataset and dataloader
batch_size = 1024
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
validset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)

# Note that we need to set the multiprocessing context explicitly:
# PyTorch likes to use 'forking' while Dask uses 'spawn'
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=16,
                                          multiprocessing_context=mp.get_context("fork"))
validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size,
                                          shuffle=True, num_workers=16,
                                          multiprocessing_context=mp.get_context("fork"))

Files already downloaded and verified
Files already downloaded and verified
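A multiprocessing context is an object bound to a single start method. The sketch below uses the standard library's `multiprocessing`, which `torch.multiprocessing` wraps, to list the contexts available on the current platform. Which methods appear depends on the operating system; for example, `fork` is not available on Windows.

```python
import multiprocessing as mp

# Build one context per start method supported on this platform
contexts = {name: mp.get_context(name) for name in mp.get_all_start_methods()}

# Each context reports the start method it is bound to
methods = {name: ctx.get_start_method() for name, ctx in contexts.items()}
print(methods)
```

Passing such a context to the `DataLoader` via `multiprocessing_context` selects the start method for that loader's workers only, without changing the default used elsewhere in the process.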


In [7]:
# Define a simple conv net
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv6 = nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.fc1 = nn.Linear(4 * 4 * 64, 4 * 64)
        self.fc2 = nn.Linear(4 * 64, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
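The `4 * 4 * 64` input size of `fc1` follows from the convolution settings. The sketch below applies the usual output-size formula for a convolution, assuming 32x32 CIFAR-10 images and the kernel size, padding and strides used in `Net`. The helper `conv_out` is introduced here for illustration.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    # Spatial output size of a convolution along one dimension
    return (size + 2 * padding - kernel) // stride + 1

strides = [2, 1, 2, 1, 2, 1]  # conv1 .. conv6 in Net
size = 32                     # CIFAR-10 images are 32x32
sizes = []
for stride in strides:
    size = conv_out(size, stride=stride)
    sizes.append(size)

flat_features = 64 * size * size  # 64 channels after conv6
print(sizes, flat_features)
# prints: [16, 16, 8, 8, 4, 4] 1024
```

Each stride-2 layer halves the feature map, so three of them reduce 32x32 to 4x4.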

In [13]:
import torch.optim as optim
from tqdm.notebook import tqdm
criterion = nn.CrossEntropyLoss()

# Train one epoch
def train(loader, path="./model", load=False, test=False, error=False):
    # Initialise model, optimizer and device
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=3e-4)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Load state from disk so that we can split up the job
    if load: 
        state = torch.load(path)
        model.load_state_dict(state["model"])
        model.to(device)
        optimizer.load_state_dict(state["optimizer"])
    else:
        model.to(device)
    
    # A typical PyTorch training loop
    model.train()
    running_loss = 0
    for i, (inputs, labels) in enumerate(loader):
        # put the inputs on the device
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.detach().item()
        
        # Force an error
        if error:
            assert 0 == 1
        
        # Stop after one batch when testing        
        if test: 
            print("When running in a local cluster you can see print statements")
            break
    
    # Save model after each epoch
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict()
        }, path)
    
    return running_loss / len(loader) if not test else loss.detach().item()
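The save-then-load pattern in `train` can be tried in isolation. This is a minimal sketch that swaps in a small `nn.Linear` for `Net` and a temporary directory for `./model`; both are stand-ins chosen to keep the example quick.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# Take one optimisation step so the optimizer has state worth saving
model(torch.randn(8, 4)).sum().backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint")
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict()}, path)
    state = torch.load(path)

# Rebuild fresh objects and restore both pieces of state
restored = nn.Linear(4, 2)
restored.load_state_dict(state["model"])
restored_optimizer = optim.Adam(restored.parameters(), lr=3e-4)
restored_optimizer.load_state_dict(state["optimizer"])

weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(weights_match)
```

Saving the optimizer alongside the model means Adam's running statistics carry over when each epoch is submitted as a separate job.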

In [14]:
# Validate for one epoch
def valid(loader, path="./model"):
    # Initialise model and device
    model = Net()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load state from disk so that we can split up the job
    state = torch.load(path)
    model.load_state_dict(state["model"])
    model.to(device)
    model.eval()
    
    # A typical PyTorch validation loop
    running_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(loader):
            # put the inputs on the device
            inputs, labels = inputs.to(device), labels.to(device)

            # forward
            outputs = model(inputs)
            
            # loss
            loss = criterion(outputs, labels)
            running_loss += loss.detach().item()
            
            # accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            

    return running_loss / len(loader), correct / total


In [15]:
# Test our code locally first
client.submit(train, trainloader, test=True).result()

When running in a local cluster you can see print statements


2.303812265396118

In [16]:
# Shut down the cluster
client.shutdown()


Dask usually uses a 'nanny' that monitors worker processes and gracefully restarts them if they fail or are killed while performing computations. The nanny is not compatible with daemonic processes; that is, Dask workers cannot perform multiprocessing while it is being used. We therefore need to set `nanny=False` to turn off the nanny and allow multiprocessing inside Dask jobs so that the cluster works with PyTorch. (Just like when we set `processes=False` for the local cluster.)
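The underlying Python rule can be seen without Dask: a daemonic process is not allowed to start children of its own. The sketch below assumes a POSIX system where the `fork` start method is available. It illustrates the general `multiprocessing` restriction only, not Dask's internals, and the helper names are made up for this example.

```python
import multiprocessing as mp

# Assumption: "fork" is available (Linux or macOS)
ctx = mp.get_context("fork")

def noop():
    pass

def try_to_start_child(queue):
    # Runs inside the daemonic process and reports what happens
    try:
        child = ctx.Process(target=noop)
        child.start()
        child.join()
        queue.put("child started")
    except AssertionError as err:
        queue.put(str(err))

queue = ctx.Queue()
daemon = ctx.Process(target=try_to_start_child, args=(queue,), daemon=True)
daemon.start()
message = queue.get(timeout=30)
daemon.join()
print(message)
```

A `DataLoader` with `num_workers > 0` needs to start child processes, which is why it cannot run inside a daemonic worker.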

We can pass extra SLURM requirements in `job_extra_directives` to request a GPU for our jobs.


In [17]:
# Switch over to remote execution
cluster = SLURMCluster(
    memory="64g", processes=1, cores=8, job_extra_directives=["--gres=gpu:1"], nanny=False
)

cluster.scale(1)
client = Client(cluster)

In [18]:
# Since this code is executing remotely we won't see our print statements
client.submit(train, trainloader, test=True).result()

2.306060314178467

In [19]:
# Dask will raise any errors that the process triggers locally, even when executing remotely
client.submit(train, trainloader, error=True).result()

AssertionError: 
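The same behaviour can be tried without a SLURM allocation. In this sketch a small `LocalCluster` stands in for the remote cluster, and `failing_task` is a hypothetical function written for the example. The exception raised inside the task comes back out of `.result()` with its original type.

```python
from distributed import Client, LocalCluster

def failing_task():
    # Fail on purpose, like error=True in train()
    assert 0 == 1, "forced failure"

with LocalCluster(processes=False, n_workers=1, threads_per_worker=1,
                  dashboard_address=None) as cluster:
    with Client(cluster) as client:
        future = client.submit(failing_task)
        try:
            future.result()
            caught = None
        except AssertionError as err:
            # The task's exception is re-raised in the calling code
            caught = err

print(repr(caught))
```

Wrapping `.result()` in an ordinary `try`/`except` is therefore enough to handle failures from submitted tasks.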

In [20]:
# Run the training loop
epochs = 2

with tqdm(total=epochs) as pbar:
    for epoch in range(epochs):
        # Start from scratch on the first epoch, then resume from the saved state
        train_loss = client.submit(train, trainloader, load=epoch > 0).result()
        valid_loss, accuracy = client.submit(valid, validloader).result()
        pbar.update()
        pbar.set_postfix(loss=train_loss)
        print(
            f"epoch: {epoch}, train_loss: {train_loss : .3f}, valid_loss: {valid_loss : .3f}, accuracy: {accuracy : .3f}")

  0%|          | 0/2 [00:00<?, ?it/s]

2023-02-09 23:59:29,801 - distributed.client - ERROR - Failed to reconnect to scheduler after 30.00 seconds, closing client


epoch: 0, train_loss:  2.254, valid_loss:  2.102, accuracy:  0.235
epoch: 1, train_loss:  2.030, valid_loss:  1.971, accuracy:  0.280


In [None]:
# Shut down the cluster
client.shutdown()
