### Dask Client and PyTorch

Experiments on how to parallelize computations on PyTorch tensors using the Dask distributed client.

#### Dask client (multi-process)

In [None]:
from dask.distributed import Client, fire_and_forget
from dask_cuda import LocalCUDACluster

# Start a local Dask cluster and connect a client to it. `processes=True`
# requests worker processes (the "multi-process" setup named in the heading
# above): 2 workers with 2 threads each, a 4GB memory limit, and the
# dashboard served at address ':8791'.
client = Client(memory_limit='4GB', n_workers=2, processes=True, threads_per_worker=2, dashboard_address=':8791')

# Disabled alternative: a GPU-backed cluster created with dask_cuda's
# LocalCUDACluster and wrapped in a Client.
# cluster = LocalCUDACluster(n_workers=1, threads_per_worker=1, dashboard_address=':8791',
#                               memory_limit="auto",
#                               device_memory_limit="auto", # memory spilling
#                               #rmm_pool_size="5GB",
#                               #rmm_managed_memory=True,
#                               #silence_logs=False,
#                               local_directory="/tmp/", 
#                               #enable_nvlink=True,
#                               ) # See https://docs.rapids.ai/api/dask-cuda/nightly/api.html
# client = Client(cluster)

# Render the client summary in the notebook output (`display` is provided by
# the IPython/Jupyter environment, not imported here).
display(client)

In [None]:
import time
import numpy as np
import torch
N = 1000

class A:
    """Container holding a single random tensor of shape (N, N) in ``x``."""

    def __init__(self):
        shape = (N, N)  # N is the module-level size constant
        self.x = torch.rand(shape)

class Kernel:
    """Callable workload that counts how many times it has been run.

    ``count`` starts at 0 and is incremented on every ``kernel`` call made
    on this instance.
    """

    def __init__(self):
        self.count = 0

    def kernel(self, obj):
        """Build a random (N, N) tensor and add ten more random tensors to it.

        ``obj`` is accepted but not used by the active code path (only by
        the disabled ``matmul`` variant below).
        """
        self.count = self.count + 1
        acc = torch.rand(N, N)
        remaining = 10
        while remaining > 0:
            acc = acc + torch.rand_like(acc)
            remaining -= 1
        # return torch.matmul(obj.x, acc)  # this cannot be parallelized by Dask client.
        return acc

In [None]:
# op = Kernel()
# for _ in range(500):
#     obj = A()
#     obj_ = client.scatter(obj)
#     future = client.submit(op.kernel, obj_)
#     fire_and_forget(future)
# # op.kernel(A())
# # op.count

#### Future: Tensors

Observations:
- There is an issue with transferring a tensor's PyTorch autograd graph (its history) when using the multi-process Dask client.
- The multi-threaded client is fine.

In [None]:
def kernel(x, y):
    """Return the element-wise sine of ``x + y``."""
    # Disabled heavier variant, kept for reference:
    # r = torch.zeros(x.shape)
    # for _ in range(1000):
    #     if x.max() > y.max():
    #         r = r + torch.sin(x+y)
    #     else:
    #         r = r + torch.cos(x-y)
    # return r
    total = x + y
    return total.sin()

def gradient(y, x, grad_outputs=None):
    """Return the gradient of ``y`` with respect to ``x``.

    ``grad_outputs`` is the weighting passed to ``torch.autograd.grad``;
    when omitted, a tensor of ones shaped like ``y`` is used. The graph of
    the derivative is kept (``create_graph=True``).
    """
    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (grad,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)
    return grad

In [None]:
size=100

# Inputs of `kernel`: x tracks gradients, y does not.
x = torch.rand(size, requires_grad=True)
y = torch.rand(size, requires_grad=False)
tensors = [x, y]

# Submit the same kernel call 10 times. The inputs are scattered again
# (with broadcast=True) before every submission.
futures = []
for _ in range(10):
    scattered_tensors = client.scatter(tensors, broadcast=True)
    future = client.submit(kernel, *scattered_tensors)
    futures.append(future)
# Collect the results of all submitted tasks back into this process.
results = client.gather(futures)

# Differentiate the first gathered result with respect to the local `x`.
# This is the step that probes whether the tensor's autograd history is
# still attached after going through the Dask client.
gradient(results[0], x)

#### Future: structures

Observations:
- Dask cannot directly parallelize a kernel that takes a generic input class such as Structure.
- As a solution, the kernel's inputs have to be translated into arrays or tensors.
- Also, defining a Kernel class that takes care of unnecessary inputs is very useful and makes for an elegant design.

In [None]:
import sys
# Add the parent directory to the import search path
# (presumably where the torchip package lives — confirm).
sys.path.append('../')

import torchip as tp
from torchip import logger
from torchip.datasets import RunnerStructureDataset, ToStructure
from torchip.potentials import NeuralNetworkPotential

# Set torchip's device setting to CPU for this notebook.
tp.device.DEVICE = "cpu"

import torch
import time
from pathlib import Path

In [None]:
# Directory with the example files, given relative to this notebook.
potdir = Path("../examples/LJ")

# Structure dataset read from "input.data"; one entry is picked by index.
structures = RunnerStructureDataset(Path(potdir, "input.data"), persist=True) 
structure0 = structures[4]

# Potential built from "input.nn"; keep the descriptor and scaler registered
# under the "Ne" key.
nnp = NeuralNetworkPotential(Path(potdir, "input.nn"))
descriptor = nnp.descriptor["Ne"]
scaler = nnp.scaler["Ne"]

In [None]:
# structure0.calculate_distance(aid=0, neighbors=1, detach=False, return_diff=True)

In [None]:
from torch import Tensor

class Box:
    """Simulation box described by a lattice matrix."""

    def __init__(self, lattice):
        self.lattice = lattice

    @staticmethod
    def _apply_pbc(dx, lat):
        """Wrap difference vectors into the box along the three axes.

        Only the diagonal entries ``lat[i, i]`` are used as box lengths.
        Components larger than half a box length are shifted down by one
        length, then components smaller than minus half are shifted up.
        ``dx`` is modified in place and also returned.
        """
        for axis in range(3):
            length = lat[axis, axis]
            half = 0.5E0 * length
            # Upper bound first, then lower bound, re-reading the updated column.
            column = dx[..., axis]
            dx[..., axis] = torch.where(column > half, column - length, column)
            column = dx[..., axis]
            dx[..., axis] = torch.where(column < -half, column + length, column)
        return dx

    def apply_pbc(self, dx):
        """Apply periodic wrapping to ``dx`` using this box's lattice."""
        lattice = self.lattice
        return Box._apply_pbc(dx, lattice)
        

class Structure:
    """Namespace for distance calculations on plain position tensors."""

    @staticmethod
    def _calculate_distance(
            pos: Tensor,
            aid: int, 
            lat: Tensor = None,
            detach: bool = False, 
            neighbors = None, 
            difference: bool = False
        ) -> Tensor: # TODO: also tuple?
        """
        Calculate the distances between atom ``aid`` and other atoms.

        Args:
            pos: atom positions (assumed to be shaped (n_atoms, 3) — TODO confirm).
            aid: index of the reference atom in ``pos``.
            lat: lattice matrix; when given, periodic boundary conditions are
                applied to the difference vectors through ``Box._apply_pbc``.
            detach: if True, the neighbor positions are taken from
                ``pos.detach()``. The reference position ``pos[aid]`` is not
                detached (see FIXME below).
            neighbors: index (int, sequence, or tensor) selecting the atoms to
                measure against; ``None`` selects all atoms.
            difference: if True, return ``(distance, dx)`` instead of only
                ``distance``.

        Returns:
            The distance tensor, or a ``(distance, dx)`` tuple when
            ``difference`` is True.

        TODO: input pbc flag, using default pbc from global configuration
        TODO: also see torch.cdist
        """
        x = pos.detach() if detach else pos
        # Compare against None explicitly: a truthiness test would treat the
        # valid index 0 (or an empty index list) as "no selection", and it
        # raises for index tensors with more than one element.
        if neighbors is not None:
            x = x[neighbors]
        # A single integer index yields a 1-D row; restore the atom axis.
        if x.ndim == 1:
            x = torch.unsqueeze(x, dim=0)
        dx = pos[aid] - x  # FIXME: detach?

        # Apply PBC along x, y, and z directions if lattice info is provided
        if lat is not None:
            dx = Box._apply_pbc(dx, lat)  # using broadcasting

        # Calculate distance from dx tensor
        distance = torch.linalg.vector_norm(dx, dim=1)

        return (distance, dx) if difference else distance
    
    
class Kernel:
    """Picklable test workload that bundles the callables a task needs.

    Args:
        func: callable applied repeatedly to the two tensor inputs.
        dist: distance callable invoked as ``dist(x, aid=0, neighbors=1)``;
            when a lattice is supplied it is additionally passed
            ``difference=True`` and must return ``(distance, dx)``.
            A falsy value skips the distance step.
        pbc: callable invoked as ``pbc(dx, lat)`` on the difference vectors.
        repeats: how many times ``func`` is applied to each input
            (default keeps the original fixed workload size).
    """

    def __init__(self, func, dist, pbc, repeats=190000):
        self.func = func
        self.dist = dist
        self.pbc = pbc
        self.repeats = repeats

    def __call__(self, x, at, dtype=None, device=None, emap=None, lat=None):
        """Run the workload on ``x`` and ``at``; returns None.

        ``dtype`` and ``device`` are accepted but not used. ``emap`` (if
        non-empty) is looked up with ``int(at[0])``. ``lat`` (if not None)
        triggers the periodic-boundary step.
        """
        # Synthetic load: apply func to both inputs `repeats` times.
        for _ in range(self.repeats):
            self.func(x)
            self.func(at)
        if emap:
            emap[int(at[0])]
        if self.dist:
            if lat is not None:
                # pbc operates on difference vectors, not on the norms, so
                # ask the distance callable for both.
                distance, diff = self.dist(x, aid=0, neighbors=1, difference=True)
            else:
                distance, diff = self.dist(x, aid=0, neighbors=1), None
            print(distance)
            # `lat` may be a multi-element tensor whose truth value is
            # ambiguous, hence the explicit None checks above and here.
            if diff is not None and self.pbc:
                self.pbc(diff, lat)
        # time.sleep(0.1)

In [None]:
# Workload object: torch.max as the repeated op, plus the static distance
# and periodic-boundary helpers, so no Structure/Box instance is shipped.
kernel = Kernel(torch.max, dist=Structure._calculate_distance, pbc=Box._apply_pbc)

for structure in structures:
    # Tensor inputs are scattered to the workers; everything else is passed
    # as keyword arguments of the task.
    tensors = [
        structure.position, 
        structure.atype
    ]
    params = {
        'dtype': torch.double, 
        'device': 'cpu', 
        'emap': structure.element_map.atype_to_element,
        'lat': structure.box.lattice if structure.box else None,
    }
    scattered_tensors = client.scatter(tensors, broadcast=True)  
    future = client.submit(kernel, *scattered_tensors, **params)
    # The future is handed to fire_and_forget and never gathered in this
    # cell, so neither results nor task exceptions are inspected here.
    fire_and_forget(future)

#### Delay

In [None]:
import torch
from dask import delayed

# @torch.jit.script
def fun(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` unchanged; a minimal function for trying out ``dask.delayed``."""
    return x

# Two ways of wrapping `fun` were tried; only the `pure` flag differs.
# NOTE(review): presumably the pure=True failure comes from Dask hashing the
# tensor argument to build a deterministic task key — confirm.
# fn = delayed(fun, pure=False)  # works
fn = delayed(fun, pure=True)  # causes error

In [None]:
# fn(torch.rand(size, requires_grad=True, dtype=dtype)).compute()