multi_device_kernel.py
#!/usr/bin/env python3

import torch
from torch.nn.parallel import DataParallel

from .kernel import Kernel
from ..lazy import CatLazyTensor, lazify
from .. import settings


class MultiDeviceKernel(DataParallel, Kernel):
    r"""
    Allocates the covariance matrix on distributed devices, e.g. multiple GPUs.

    Args:
        - :attr:`base_kernel`: Base kernel to distribute
        - :attr:`device_ids`: list of `torch.device` objects to place kernel chunks on
        - :attr:`output_device`: Device where outputs will be placed
    """

    def __init__(self, base_kernel, device_ids, output_device=None, **kwargs):
        DataParallel.__init__(
            self,
            module=base_kernel,
            device_ids=device_ids,
            output_device=output_device,
            dim=-2,
        )

        self.output_device = output_device if output_device else device_ids[0]

        self.__cached_x1 = torch.empty(1)
        self.__cached_x2 = torch.empty(1)

    @property
    def base_kernel(self):
        # Expose the wrapped module under the name used by `size` below
        return self.module

    def forward(self, x1, x2, diag=False, **kwargs):
        if diag:
            return self.module.forward(x1, x2, diag=True, **kwargs).to(self.output_device)

        # Re-scatter x1 across devices only if it has changed since the last call
        if not x1.device == self.__cached_x1.device or not torch.equal(x1, self.__cached_x1):
            self._x1_scattered, self._kwargs = self.scatter((x1,), kwargs, self.device_ids)
            self.__cached_x1 = x1

        # Each device gets its chunk of x1 together with a full copy of x2
        if not x2.device == self.__cached_x2.device or not torch.equal(x2, self.__cached_x2):
            self._x2_subs = [x2.to(x1_[0].device) for x1_ in self._x1_scattered]
            self.__cached_x2 = x2

        inputs = tuple((x1_[0], x2_) for x1_, x2_ in zip(self._x1_scattered, self._x2_subs))

        if not self.device_ids:
            return self.module.forward(*inputs, **self._kwargs)

        if len(self.device_ids) == 1:
            return self.module.forward(*inputs[0], **self._kwargs[0])

        # Can't cache the replication because the base kernel module can change every time (e.g. param updates)
        replicas = self.replicate(self.module, self.device_ids[: len(inputs)])

        # TODO: parallel_apply might be too heavyweight in some cases?
        with settings.lazily_evaluate_kernels(False):
            outputs = self.parallel_apply(replicas, inputs, self._kwargs)

        return self.gather(outputs, self.output_device)

    def gather(self, outputs, output_device):
        # Concatenate the per-device covariance chunks along the scattered dimension
        return CatLazyTensor(*[lazify(o) for o in outputs], dim=self.dim, output_device=self.output_device)

    def size(self, x1, x2):
        return self.base_kernel.size(x1, x2)
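

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file). It
# assumes a machine with two CUDA devices and uses GPyTorch's RBFKernel as the
# base kernel; variable names such as `train_x` are hypothetical.
#
#     import torch
#     import gpytorch
#
#     devices = [torch.device("cuda:0"), torch.device("cuda:1")]
#     base_kernel = gpytorch.kernels.RBFKernel()
#     kernel = gpytorch.kernels.MultiDeviceKernel(
#         base_kernel, device_ids=devices, output_device=devices[0]
#     )
#
#     train_x = torch.randn(1000, 3, device=devices[0])
#     # Lazy covariance; when evaluated, row-chunks are computed on separate
#     # devices and gathered back on devices[0].
#     covar = kernel(train_x, train_x)
#     print(covar.shape)  # torch.Size([1000, 1000])
# ---------------------------------------------------------------------------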