wrap fp16 to ROCm-supported dtype in amdgpu (#22)
msftsw committed Oct 10, 2021
1 parent b1468bd commit ec6823b
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md

@@ -3,7 +3,7 @@
 Tutel MoE: An Optimized Mixture-of-Experts Implementation.
 
 - Supported Framework: Pytorch
-- Supported GPUs: CUDA(fp32 + fp16), ROCm(fp32)
+- Supported GPUs: CUDA(fp32 + fp16), ROCm(fp32 + fp16)
 
 How to setup Tutel MoE for Pytorch:
 ```
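The README change records that fp16 is now accepted on ROCm as well as CUDA. As a hedged illustration (not part of this commit), a PyTorch installation can be probed at runtime to tell a ROCm/HIP build from a plain CUDA build; a flag such as IS_HIP_EXTENSION could plausibly be derived from a signal like this:

```python
import torch

def is_rocm_build() -> bool:
    # In ROCm builds of PyTorch, torch.version.hip is a version string;
    # in CUDA builds it is None. This mirrors, but is not necessarily
    # identical to, how a flag such as IS_HIP_EXTENSION might be set.
    return getattr(torch.version, 'hip', None) is not None

if __name__ == '__main__':
    backend = 'ROCm' if is_rocm_build() else 'CUDA'
    print('PyTorch GPU backend:', backend)
```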
2 changes: 1 addition & 1 deletion tutel/custom/custom_kernel.cpp

@@ -20,7 +20,7 @@
 
 #define CHECK_EQ(x, y) AT_ASSERTM((x) == (y), "CHECK_EQ fails.")
 #define CHECK_NE(x, y) AT_ASSERTM((x) != (y), "CHECK_NE fails.")
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 
 static std::string file_read(const char *path) {
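The macro change swaps the deprecated `x.type().is_cuda()` call for the direct `x.is_cuda()` accessor. As a hedged Python-side sketch (not code from the repository), the same preconditions can be expressed on the tensors before they reach a custom kernel; note that on ROCm builds of PyTorch, HIP tensors are exposed through the 'cuda' device type, so `is_cuda` holds for them as well:

```python
import torch

def check_kernel_input(x: torch.Tensor) -> torch.Tensor:
    # Mirrors CHECK_CUDA / CHECK_CONTIGUOUS from custom_kernel.cpp:
    # the tensor must live on the GPU ('cuda' also covers HIP on ROCm builds)
    # and be contiguous so the kernel can index it with flat offsets.
    assert x.is_cuda, 'input must be a CUDA (or ROCm/HIP) tensor'
    assert x.is_contiguous(), 'input must be contiguous'
    return x
```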
8 changes: 5 additions & 3 deletions tutel/impls/fast_dispatch.py

@@ -6,6 +6,7 @@
 import torch
 from torch import Tensor
 
+from .jit_compiler import IS_HIP_EXTENSION
 from ..jit_kernels import sparse as jit_kernel
 
 class GatingEncoder(torch.autograd.Function):
@@ -65,7 +66,8 @@ def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype):
         self.capacity = capacity
         self.model_dim = model_dim
         self.kernel_pool = dict()
-        self.dtype = dispatch_dtype
+        self.dtype = dispatch_dtype if not IS_HIP_EXTENSION else torch.float32
+        self.original_dtype = dispatch_dtype
         self.aligned_dim = model_dim // (2 if self.dtype == torch.float16 else 1)
 
     def update(self, indices_, locations_, gates_, capacity=None):
@@ -87,9 +89,9 @@ def update(self, indices_, locations_, gates_, capacity=None):
         self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper = self.kernel_pool[tuple((sample_size, capacity))]
 
     def encode(self, data):
-        return GatingEncoder.apply(self, data)
+        return GatingEncoder.apply(self, data.to(self.dtype)).to(self.original_dtype)
 
     def decode(self, data):
-        return GatingDecoder.apply(self, data, *self.gates_)
+        return GatingDecoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype)
 
 fast_dispatcher = TutelMoeFastDispatcher
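This dispatcher change is the core of the commit: on HIP builds the internal kernels keep working in fp32, and fp16 activations are cast on the way in and cast back on the way out. A minimal, hypothetical sketch of that wrapping pattern (the function and kernel names below are placeholders, not Tutel APIs):

```python
import torch

# Assumption: on ROCm/HIP builds the underlying kernel only supports fp32,
# so fp16 inputs are upcast before the call and the result is downcast after.
KERNEL_DTYPE = torch.float32

def run_kernel_fp32(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a compute kernel that only accepts fp32 tensors.
    assert x.dtype == torch.float32
    return x * 2.0

def dispatch(x: torch.Tensor) -> torch.Tensor:
    original_dtype = x.dtype                   # e.g. torch.float16
    y = run_kernel_fp32(x.to(KERNEL_DTYPE))    # compute in the supported dtype
    return y.to(original_dtype)                # restore the caller's dtype
```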
2 changes: 1 addition & 1 deletion tutel/impls/moe_layer.py

@@ -275,7 +275,7 @@ def forward(self, input: Tensor, **kwargs: Any):
            if reshaped_input.size(0) > self.expected_sample_size:
                raise Exception('MoE JIT is designed to work on sample size = %s, while receiving sample size = %s (> %s)' % (self.expected_sample_size, reshaped_input.size(0), self.expected_sample_size))
            else:
-               if get_world_rank(expert_group) == 0:
+               if get_world_rank(self.expert_group) == 0:
                    print('[WARN] MoE is initialized to keep working on sample size = %s, while receiving sample size = %s (will slow down this forward step)' % (self.expected_sample_size, reshaped_input.size(0)))
                pad_input = torch.zeros([self.expected_sample_size, self.model_dim], dtype=reshaped_input.dtype, layout=reshaped_input.layout, device=reshaped_input.device)
                pad_input[:reshaped_input.size(0)] = reshaped_input
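This one-token change fixes a latent NameError: `expert_group` is an attribute of the layer, so it must be read as `self.expert_group` inside `forward`. The surrounding code pads an undersized batch up to the fixed sample size the JIT kernels were compiled for; a small, self-contained sketch of that padding idea (names are illustrative, not Tutel's):

```python
import torch

def pad_to_expected(reshaped_input: torch.Tensor, expected_sample_size: int) -> torch.Tensor:
    # If the incoming batch is smaller than the compiled-for sample size,
    # copy it into a zero-filled buffer of the expected shape; otherwise
    # return it unchanged.
    if reshaped_input.size(0) >= expected_sample_size:
        return reshaped_input
    pad_input = torch.zeros(
        [expected_sample_size, reshaped_input.size(1)],
        dtype=reshaped_input.dtype, device=reshaped_input.device)
    pad_input[:reshaped_input.size(0)] = reshaped_input
    return pad_input
```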

0 comments on commit ec6823b

Please sign in to comment.