From eca05da451390cd1e5a8058f808db8848da13321 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 7 Apr 2020 12:28:28 -0700 Subject: [PATCH 01/18] Implement stochastic rounding features --- .../ATen/native/cuda/StochasticRounding.cu | 75 ++++++++++ .../native/cuda/StochasticRoundingAdam.cu | 129 ++++++++++++++++++ .../ATen/native/cuda/StochasticRoundingSGD.cu | 99 ++++++++++++++ .../ATen/native/cuda/stochastic_rounding.cuh | 67 +++++++++ aten/src/ATen/native/native_functions.yaml | 13 ++ 5 files changed, 383 insertions(+) create mode 100644 aten/src/ATen/native/cuda/StochasticRounding.cu create mode 100644 aten/src/ATen/native/cuda/StochasticRoundingAdam.cu create mode 100644 aten/src/ATen/native/cuda/StochasticRoundingSGD.cu create mode 100644 aten/src/ATen/native/cuda/stochastic_rounding.cuh diff --git a/aten/src/ATen/native/cuda/StochasticRounding.cu b/aten/src/ATen/native/cuda/StochasticRounding.cu new file mode 100644 index 0000000000000..44bcab3e22c4e --- /dev/null +++ b/aten/src/ATen/native/cuda/StochasticRounding.cu @@ -0,0 +1,75 @@ +#include +#include + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +namespace at { +namespace native { + +template +__global__ void stochastic_rounding_kernel( + const input_t* input, + output_t* output, + const int64_t numel, + std::pair seed_and_offset) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state); + + for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) { + float inp = static_cast(input[i]); + output[i] = round_stochastically(inp, curand_uniform(&state)); + } +} + +Tensor stochastic_rounding_cuda(const Tensor& input, Tensor& output, Generator gen_) { + TORCH_CHECK(input.numel() > 0 && input.numel() == output.numel()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(output.is_contiguous()); + + const int64_t numel = input.numel(); + const int block = 256; + const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block; + unsigned int grid = (numel + block - 1) / block; + grid = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid); + + auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); + std::pair rng_engine_inputs; + { + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_engine_inputs((numel + block * grid - 1) / (block * grid)); + } + + DISPATCH_FLOAT_AND_HALF( + input.scalar_type(), 0, "round_stochastically_input", + DISPATCH_FLOAT_AND_HALF( + output.scalar_type(), 1, "round_stochastically_output", + stochastic_rounding_kernel<<>>( + input.data_ptr(), + output.data_ptr(), + numel, rng_engine_inputs); + )); + + return output; +} + +} // namespace native +} // namespace at +#undef DISPATCH_FLOAT_AND_HALF diff --git a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu new file mode 100644 index 0000000000000..5bcc3fecb6a73 --- /dev/null +++ b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu @@ -0,0 +1,129 @@ +#include +#include + + +namespace at { +namespace native { + +template 
+__global__ void stochastic_rounding_adam_step_kernel( + scalar_t *weights, scalar_t *gradients, + scalar_t *exp_avg, scalar_t *exp_avg_sq, scalar_t *max_exp_avg_sq, + float *inv_scale, float *found_inf, + float lr, float beta1, float beta2, + float weight_decay, float eps, int step, + bool is_decoupled, bool is_amsgrad, + int numel, std::pair seeds) { + + if (*found_inf) return; + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seeds.first, tid, seeds.second, &state); + + float m_correction = 1.0 - powf(beta1, step); + float v_correction = 1.0 - powf(beta2, step); + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + float weight = static_cast(weights[i]); + float gradient = static_cast(gradients[i]) * (*inv_scale); + float m = static_cast(exp_avg[i]); + // Stochastic Rounding Adam tracks square root of the exponential average of squared gradient. + float v = static_cast(exp_avg_sq[i]); + v = v * v; + float4 random_values = curand_uniform4(&state); + + if (weight_decay != 0.0f) { + if (is_decoupled) + weight *= (1 - lr * weight_decay); + else + gradient += weight_decay * weight; + } + + // Update m and v. + m = beta1 * m + (1.0 - beta1) * gradient; + v = beta2 * v + (1.0 - beta2) * (gradient * gradient); + + // Unbias v + float max_v = v; + if (is_amsgrad) { + float prev_max_v = static_cast(max_exp_avg_sq[i]); + prev_max_v = prev_max_v * prev_max_v; + max_v = fmaxf(prev_max_v, v); + } + + // Update parameter + weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps); + + // Rounding + // `maybe_square` must not be used in this section. + weights[i] = round_stochastically(weight, random_values.x); + exp_avg[i] = round_stochastically(m, random_values.y); + exp_avg_sq[i] = round_stochastically(sqrtf(v), random_values.z); + if (is_amsgrad) { + max_exp_avg_sq[i] = round_stochastically(sqrtf(max_v), random_values.w); + } + } +} + + +Tensor stochastic_rounding_adam_step_cuda( + Tensor& param, + const Tensor& grad, + Tensor& exp_avg, + Tensor& exp_avg_sq, + Tensor& max_exp_avg_sq, + const Tensor& inv_scale, + const Tensor& found_inf, + double lr, double beta1, double beta2, + double weight_decay, double eps, int64_t step, + bool is_decoupled, bool is_amsgrad, Generator gen_) { + + if (param.numel() == 0) return param; + + TORCH_CHECK(param.is_contiguous()); + TORCH_CHECK(grad.is_contiguous()); + TORCH_CHECK(exp_avg.is_contiguous()); + TORCH_CHECK(exp_avg_sq.is_contiguous()); + TORCH_CHECK(max_exp_avg_sq.is_contiguous()); + + // Based on ATen/native/cuda/Dropout.cu + // link: https://github.com/pytorch/pytorch/blob/c21896327094637cfb83bc0f536e6a442b9877a1/aten/src/ATen/native/cuda/Dropout.cu + const int64_t numel = param.numel(); + const int block_size = 256; + const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + dim3 dim_block(block_size); + dim3 grid((numel + block_size - 1) / block_size); + grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); + + auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); + + uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (block_size * grid.x)) * 4; + std::pair rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_engine_inputs(counter_offset); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + 
param.scalar_type(), "stochastic_rounding_adam_step_kernel", [&] { + stochastic_rounding_adam_step_kernel<<>>( + param.data_ptr(), + grad.data_ptr(), + exp_avg.data_ptr(), + exp_avg_sq.data_ptr(), + max_exp_avg_sq.data_ptr(), + inv_scale.data_ptr(), + found_inf.data_ptr(), + lr, beta1, beta2, weight_decay, eps, step, + is_decoupled, is_amsgrad, + numel, rng_engine_inputs); + } + ); + AT_CUDA_CHECK(cudaGetLastError()); + return param; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu new file mode 100644 index 0000000000000..0c2c9d0b19cfb --- /dev/null +++ b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu @@ -0,0 +1,99 @@ +#include +#include + + +namespace at { +namespace native { + +// SGD update math with Stochastic Rounding +template +__global__ void stochastic_rounding_sgd_step_kernel( + scalar_t *weights, scalar_t *gradients, scalar_t *momentum_buffer, + float* inv_scale, float* found_inf, + float weight_decay, float momentum, float dampening, float lr, + bool nesterov, bool first_run, int numel, std::pair seeds) +{ + + // 1.0 indicates that any gradients contain inf or nan. + // See below about `found_inf`: + // - https://github.com/mcarilli/pytorch/blob/382d02f01d104049179f4f056cc9258caad029af/aten/src/ATen/native/cuda/AmpKernels.cu#L40-L41 + // - https://github.com/mcarilli/pytorch/blob/382d02f01d104049179f4f056cc9258caad029af/aten/src/ATen/native/cuda/AmpKernels.cu#L116-L117 + if (*found_inf) return; + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + curandStatePhilox4_32_10_t state; + curand_init(seeds.first, tid, seeds.second, &state); + + for (int i = tid; i < numel; i += blockDim.x * gridDim.x) { + float weight = static_cast(weights[i]); + float gradient = static_cast(gradients[i]) * (*inv_scale); + float velocity = static_cast(momentum_buffer[i]); + float4 random_values = curand_uniform4(&state); + + if (weight_decay != 0.0f) + gradient += weight_decay * weight; + + if (momentum != 0.0f) { + if (!first_run) + velocity = velocity * momentum + (1.0f - dampening) * gradient; + else + velocity = gradient; + + if (nesterov) + gradient += momentum * velocity; + else + gradient = velocity; + } + + weight -= lr * gradient; + + // Rounding. 
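+    // `curand_uniform4` yields four independent uniform samples; distinct components
+    // (.x for the weight, .y for the momentum buffer) are used below so the two
+    // quantities are rounded with independent randomness.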
+ weights[i] = round_stochastically(weight, random_values.x); + if (momentum != 0.0f) + momentum_buffer[i] = round_stochastically(velocity, random_values.y); + } +} + +Tensor stochastic_rounding_sgd_step_cuda( + Tensor& param, const Tensor& grad, Tensor& momentum_buffer, + const Tensor& inv_scale, const Tensor& found_inf, + double lr, double momentum, double weight_decay, double dampening, + bool nesterov, bool first_run, Generator gen_) { + + if (param.numel() == 0) return param; + + TORCH_CHECK(param.is_contiguous()); + TORCH_CHECK(grad.is_contiguous()); + TORCH_CHECK(momentum_buffer.is_contiguous()); + + const int64_t numel = param.numel(); + const int block_size = 256; + const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; + dim3 dim_block(block_size); + dim3 grid((numel + block_size - 1) / block_size); + grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); + + auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); + uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4; + std::pair rng_engine_inputs; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen->mutex_); + rng_engine_inputs = gen->philox_engine_inputs(counter_offset); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + param.scalar_type(), "stochastic_rounding_sgd_step_cuda", [&] { + stochastic_rounding_sgd_step_kernel<<>>( + param.data_ptr(), + grad.data_ptr(), + momentum_buffer.data_ptr(), + inv_scale.data_ptr(), found_inf.data_ptr(), + static_cast(weight_decay), static_cast(momentum), static_cast(dampening), static_cast(lr), + nesterov, first_run, numel, rng_engine_inputs); + }); + return param; +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/cuda/stochastic_rounding.cuh b/aten/src/ATen/native/cuda/stochastic_rounding.cuh new file mode 100644 index 0000000000000..a7e1987144179 --- /dev/null +++ b/aten/src/ATen/native/cuda/stochastic_rounding.cuh @@ -0,0 +1,67 @@ +// Ref: https://gitlab.com/riship11/stochastic-rounding +#ifndef _STOCHASTIC_ROUNDING_CUH_ +#define _STOCHASTIC_ROUNDING_CUH_ + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// 2^-10 is the step for normal FP16 numbers. +// 2^-24 is the unit in the last place (ULP)/precision limitation. +// 24 is **NOT** related to the number of mantissa bits of single precision format. 
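+// Concretely, a normal FP16 value in [2^e, 2^(e+1)) with e >= -14 has neighbouring
+// representable values spaced 2^(e-10) apart (this is what get_delta_fp16 below returns);
+// for subnormals (e < -14) the spacing saturates at 2^-24.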
+// ref: +// - https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_decimal_values_in_[0,_1] +// - https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_decimal_values_in_[1,_2048] +__device__ const float TWO_10 = 0.0009765625; +__device__ const float TWO_24 = 0.000000059604644775390625; + + +template +__device__ __forceinline__ T maybe_upcast(__half x){ + return T(__half2float(x)); +} + +template<> +__device__ __forceinline__ __half maybe_upcast<__half>(__half x){ + return x; +} + +__device__ __forceinline__ float get_delta_fp16(float x) { + int exponent; + frexpf(x, &exponent); + exponent -= 1; + if (exponent >= -14) + return TWO_10 * std::pow(2, exponent); + else + return TWO_24; +} + +// Natalia magic +template +__device__ __forceinline__ scalar_t round_stochastically(float x, float random_value) { + if (x == 0.0) { + return scalar_t(0.0); + } + float delta = get_delta_fp16(x); + float val; + if (x < 0.0) { + val = x - random_value * delta; + } else { + val = x + random_value * delta; + } + return maybe_upcast(__float2half_rz(val)); +} + +#endif // _STOCHASTIC_ROUNDING_CUH_ diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a37b0af4d7723..e3c5f5f5b96d4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6719,3 +6719,16 @@ # It is undocumented and should not be used outside of tests. - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor use_c10_dispatcher: full + +- func: stochastic_rounding(Tensor input, Tensor(a!) output, Generator? gen_=None) -> Tensor(a!) + dispatch: + CUDA: stochastic_rounding_cuda + +- func: stochastic_rounding_adam_step(Tensor(a!) param, Tensor grad, Tensor(b!) exp_avg, Tensor(c!) exp_avg_sq, Tensor(d!) max_exp_avg_sq, Tensor inv_scale, Tensor found_inf, float lr, float beta1, float beta2, float weight_decay, float eps, int step, bool is_decoupled, bool is_amsgrad, Generator? gen_=None) -> Tensor(a!) + dispatch: + CUDA: stochastic_rounding_adam_step_cuda + +- func: stochastic_rounding_sgd_step(Tensor(a!) param, Tensor grad, Tensor(b!) momentum_buffer, Tensor inv_scale, Tensor found_inf, float lr, float momentum, float weight_decay, float dampening, bool nesterov, bool first_run, Generator? gen_=None) -> Tensor(a!) + dispatch: + CUDA: stochastic_rounding_sgd_step_cuda +>>>>>>> 02db424932... 
Implement stochastic rounding features From 71a0c2954c36820b5d69f2d271f47535f52db126 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 9 Apr 2020 09:10:36 -0700 Subject: [PATCH 02/18] add PyTorch API & test --- test/test_optim.py | 69 +++++++++++++++++++++++ test/test_stochastic_rounding.py | 25 +++++++++ torch/optim/__init__.py | 6 ++ torch/optim/sradam.py | 91 ++++++++++++++++++++++++++++++ torch/optim/sradamw.py | 96 ++++++++++++++++++++++++++++++++ torch/optim/srsgd.py | 72 ++++++++++++++++++++++++ 6 files changed, 359 insertions(+) create mode 100644 test/test_stochastic_rounding.py create mode 100644 torch/optim/sradam.py create mode 100644 torch/optim/sradamw.py create mode 100644 torch/optim/srsgd.py diff --git a/test/test_optim.py b/test/test_optim.py index b0d502b5ef63e..1fc08169685c3 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1547,5 +1547,74 @@ def test_cosine_then_cyclic(self): self.assertLessEqual(last_lr, max_lr) +class TestStochasticRoundingOptim(TestCase): + + exact_dtype = True + + def _test_basic_cases_template( + self, weight, bias, input, constructor, grad_scaler=None): + + def _calc_loss(weight, bias, input): + y = weight.mv(input) + if y.get_device() != bias.get_device(): + y = y.cuda(bias.get_device()) + return (y + bias).pow(2).sum() + + optimizer = constructor(weight, bias) + initial_value = None + for _i in range(200): + optimizer.zero_grad() + loss = _calc_loss(weight, bias, input) + if grad_scaler is None: + loss.backward() + optimizer.step() + else: + grad_scaler.scale(loss).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + if initial_value is None: + initial_value = loss.item() + # self.assertLess(_calc_loss(weight, bias, input).item(), initial_value) + + # Check whether weight and bias can be represented in 16 bits. 
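+            # A parameter is exactly representable in FP16 iff a round trip through
+            # torch.half is lossless, which is what the comparison below checks.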
+ with torch.no_grad(): + for param_group in optimizer.param_groups: + for p in param_group['params']: + half_p = p.clone().detach().to(torch.half).to(weight.dtype) + diff = (p - half_p).abs() + self.assertTrue(torch.equal(diff, torch.zeros_like(diff))) + + def _test_basic_cases(self, dtype, constructor, grad_scaler=None): + self._test_basic_cases_template( + torch.nn.Parameter(torch.randn(10, 5).cuda().to(dtype)), + torch.nn.Parameter(torch.randn(10).cuda().to(dtype)), + torch.randn(5, requires_grad=True).cuda().to(dtype), + constructor, grad_scaler) + + if torch.cuda.device_count() > 1: + self._test_basic_cases_template( + torch.nn.Parameter(torch.randn(10, 5).cuda(0).to(dtype)), + torch.nn.Parameter(torch.randn(10).cuda(1).to(dtype)), + torch.randn(5).cuda(0).to(dtype), + constructor, grad_scaler) + + def test_without_GradScaler(self): + if not torch.cuda.is_available(): + return + for opt in (optim.SRSGD, optim.SRAdam, optim.SRAdamW): + for dtype in (torch.float16, torch.float32, torch.float64): + self._test_basic_cases( + dtype, lambda weight, bias: opt([weight, bias], lr=1e-3), None) + + def test_with_GradScaler(self): + if not torch.cuda.is_available(): + return + for opt in (optim.SRSGD, optim.SRAdam, optim.SRAdamW): + for dtype in (torch.float16, torch.float32): + self._test_basic_cases( + dtype, lambda weight, bias: opt([weight, bias], lr=1e-3), + torch.cuda.amp.GradScaler()) + + if __name__ == '__main__': run_tests() diff --git a/test/test_stochastic_rounding.py b/test/test_stochastic_rounding.py new file mode 100644 index 0000000000000..42affee22a48d --- /dev/null +++ b/test/test_stochastic_rounding.py @@ -0,0 +1,25 @@ +import math + +import torch +import pytest + + +N = 2 ** 14 + + +@pytest.mark.parametrize('scale', tuple(range(-18, 11))) +def test_rs(scale): + + base = math.pow(2, scale) + original_value = (base + math.pow(2, scale + 1)) / 2.0 + .5 * base + x = torch.tensor([original_value] * N).cuda() + _, exponent = math.frexp(original_value) + exponent -= 1 + rounded = torch.zeros((N,)).cuda() + rounded = torch.stochastic_rounding(x, rounded) + + mean = torch.mean(rounded).item() + delta_fp16 = math.pow(2, -10 + exponent if exponent >= -14 else -24) + threshold = 1e-6 + diff = math.fabs(original_value - mean) + assert diff < threshold or diff < delta_fp16 / 2.0 diff --git a/torch/optim/__init__.py b/torch/optim/__init__.py index 20fb9406412e9..714a34b54cda0 100644 --- a/torch/optim/__init__.py +++ b/torch/optim/__init__.py @@ -13,6 +13,9 @@ from .adamax import Adamax from .asgd import ASGD from .sgd import SGD +from .sradam import SRAdam +from .sradamw import SRAdamW +from .srsgd import SRSGD from .rprop import Rprop from .rmsprop import RMSprop from .optimizer import Optimizer @@ -27,6 +30,9 @@ del adamax del asgd del sgd +del sradam +del sradamw +del srsgd del rprop del rmsprop del optimizer diff --git a/torch/optim/sradam.py b/torch/optim/sradam.py new file mode 100644 index 0000000000000..6fd8626374fa9 --- /dev/null +++ b/torch/optim/sradam.py @@ -0,0 +1,91 @@ +import torch +from .adam import Adam + + +# TODO(crcrpar): Decide whether to override `state_dict` method as +# this optimizer tracks square root of `exp_avg_sq` and `max_exp_avg_sq` +# that contradicts with its parent optimizer :class:`torch.optim.Adam`. +class SRAdam(Adam): + r"""Implements Adam algorithm with Stochastic Rounding. + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
+ + With Stochastic Rounding, param, `exp_avg`, `exp_avg_sq`, and optionally `max_exp_avg_sq` + can be represented with 16 bits. + This optimizer requires CUDA. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + _step_supports_amp_scaling = True + + @torch.no_grad() + def step(self,closure=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.no_grad(): + loss = closure() + + if grad_scaler is not None: + found_inf = grad_scaler._check_inf_per_device( + self)[torch.device(torch.cuda.current_device())] + scale = grad_scaler._get_scale_async() + inv_scale = scale.double().reciprocal().float() + else: + found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) + inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + + for group in self.param_groups: + for param in group['params']: + if param.grad is None: + continue + grad = param.grad + if grad.is_sparse: + raise RuntimeError('SRAdam does not support sparse gradients') + + state = self.state[param] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + beta1, beta2 = group['betas'] + + state['step'] += 1 + + torch.stochastic_rounding_adam_step( + param, grad, + state['exp_avg'], state['exp_avg_sq'], state['max_exp_avg_sq'], + inv_scale, found_inf, + group['lr'], beta1, beta2, + group['weight_decay'], group['eps'], state['step'], + False, group['amsgrad']) + + return loss diff --git a/torch/optim/sradamw.py b/torch/optim/sradamw.py new file mode 100644 index 0000000000000..ee8b459632856 --- /dev/null +++ b/torch/optim/sradamw.py @@ -0,0 +1,96 @@ +import torch +from .adamw import AdamW + + +# TODO(crcrpar): Decide whether to override `state_dict` method as +# this optimizer tracks square root of `exp_avg_sq` and `max_exp_avg_sq` +# that contradicts with its parent optimizer :class:`torch.optim.Adam`. +class SRAdamW(AdamW): + r"""Implements AdamW algorithm with Stochastic Rounding. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 
+ + With Stochastic Rounding, param, `exp_avg`, `exp_avg_sq`, and optionally `max_exp_avg_sq` + can be represented with 16 bits. + This optimizer requires CUDA. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + _step_supports_amp_scaling = True + + @torch.no_grad() + def step(self, closure=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if grad_scaler is not None: + found_inf = grad_scaler._check_inf_per_device( + self)[torch.device(torch.cuda.current_device())] + scale = grad_scaler._get_scale_async() + inv_scale = scale.double().reciprocal().float() + else: + found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) + inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + + + for group in self.param_groups: + for param in group['params']: + if param.grad is None: + continue + + grad = param.grad + if grad.is_sparse: + raise RuntimeError('SRAdamW does not support sparse gradients') + + state = self.state[param] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + beta1, beta2 = group['betas'] + + state['step'] += 1 + + torch.stochastic_rounding_adam_step( + param, grad, + state['exp_avg'], state['exp_avg_sq'], state['max_exp_avg_sq'], + inv_scale, found_inf, + group['lr'], beta1, beta2, + group['weight_decay'], group['eps'], state['step'], + True, group['amsgrad']) + + return loss diff --git a/torch/optim/srsgd.py b/torch/optim/srsgd.py new file mode 100644 index 0000000000000..51c254ab25f6b --- /dev/null +++ b/torch/optim/srsgd.py @@ -0,0 +1,72 @@ +import torch +from .sgd import SGD + + +class SRSGD(SGD): + r"""Implements stochastic gradient descent with Stochastic Rounding. + + With Stochastic Rounding, param and `momentum_buffer` can be represented with 16 bits. + This optimizer requires CUDA. 
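+
+    A minimal usage sketch (illustrative values only; assumes a CUDA device is available)::
+
+        model = torch.nn.Linear(5, 10).cuda().half()
+        optimizer = torch.optim.SRSGD(model.parameters(), lr=1e-2, momentum=0.9)
+        optimizer.zero_grad()
+        model(torch.randn(1, 5, device='cuda', dtype=torch.half)).sum().backward()
+        optimizer.step()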
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + """ + + _step_supports_amp_scaling = True + + @torch.no_grad() + def step(self, closure=None, grad_scaler=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + grad_scaler (:class:`torch.cuda.amp.GradScaler`, optional): + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if grad_scaler is not None: + found_inf = grad_scaler._check_inf_per_device( + self)[torch.device(torch.cuda.current_device())] + scale = grad_scaler._get_scale_async() + inv_scale = scale.double().reciprocal().float() + else: + found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) + inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for param in group['params']: + if param.grad is None: + continue + grad = param.grad + if grad.is_sparse: + raise RuntimeError('SRSGD does not support sparse gradients') + + first_run = False + param_state = self.state[param] + if 'momentum_buffer' not in param_state: + first_run = True + param_state['momentum_buffer'] = torch.zeros_like(param) + momentum_buffer = param_state['momentum_buffer'] + + torch.stochastic_rounding_sgd_step( + param, param.grad, momentum_buffer, + inv_scale, found_inf, + weight_decay, momentum, dampening, group['lr'], + nesterov, first_run) + + return loss From e2112f0a69254176fe86b2a7fe915ab6380af052 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 9 Apr 2020 13:14:09 -0700 Subject: [PATCH 03/18] fix order of arguments --- torch/optim/srsgd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/optim/srsgd.py b/torch/optim/srsgd.py index 51c254ab25f6b..a2ea4e07ae3a9 100644 --- a/torch/optim/srsgd.py +++ b/torch/optim/srsgd.py @@ -64,9 +64,9 @@ def step(self, closure=None, grad_scaler=None): momentum_buffer = param_state['momentum_buffer'] torch.stochastic_rounding_sgd_step( - param, param.grad, momentum_buffer, + param, grad, momentum_buffer, inv_scale, found_inf, - weight_decay, momentum, dampening, group['lr'], + group['lr'], momentum, weight_decay, dampening, nesterov, first_run) return loss From 88b9850abdb0ff3355cad5ef3accb6e7994991c4 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 9 Apr 2020 13:14:26 -0700 Subject: [PATCH 04/18] make tests redundant --- test/test_optim.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/test/test_optim.py b/test/test_optim.py index 1fc08169685c3..c9b3bcf39271e 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1562,7 +1562,7 @@ def _calc_loss(weight, bias, input): optimizer = constructor(weight, bias) initial_value = None - for _i in range(200): + for _i in range(10): optimizer.zero_grad() loss = _calc_loss(weight, bias, input) if grad_scaler is None: @@ -1574,7 +1574,7 @@ def _calc_loss(weight, bias, input): grad_scaler.update() if initial_value 
is None: initial_value = loss.item() - # self.assertLess(_calc_loss(weight, bias, input).item(), initial_value) + self.assertLess(_calc_loss(weight, bias, input).item(), initial_value) # Check whether weight and bias can be represented in 16 bits. with torch.no_grad(): @@ -1598,22 +1598,32 @@ def _test_basic_cases(self, dtype, constructor, grad_scaler=None): torch.randn(5).cuda(0).to(dtype), constructor, grad_scaler) - def test_without_GradScaler(self): + def _test_without_GradScaler(self, opt): if not torch.cuda.is_available(): return - for opt in (optim.SRSGD, optim.SRAdam, optim.SRAdamW): - for dtype in (torch.float16, torch.float32, torch.float64): - self._test_basic_cases( - dtype, lambda weight, bias: opt([weight, bias], lr=1e-3), None) + for dtype in (torch.float16, torch.float32, torch.float64): + self._test_basic_cases( + dtype, lambda weight, bias: opt([weight, bias], lr=1e-2), None) - def test_with_GradScaler(self): + def _test_with_GradScaler(self, opt): if not torch.cuda.is_available(): return - for opt in (optim.SRSGD, optim.SRAdam, optim.SRAdamW): - for dtype in (torch.float16, torch.float32): - self._test_basic_cases( - dtype, lambda weight, bias: opt([weight, bias], lr=1e-3), - torch.cuda.amp.GradScaler()) + for dtype in (torch.float16, torch.float32): + self._test_basic_cases( + dtype, lambda weight, bias: opt([weight, bias], lr=1e-2), + torch.cuda.amp.GradScaler()) + + def test_SRAdam(self): + self._test_without_GradScaler(optim.SRAdam) + self._test_with_GradScaler(optim.SRAdam) + + def test_SRAdamW(self): + self._test_without_GradScaler(optim.SRAdamW) + self._test_with_GradScaler(optim.SRAdamW) + + def test_SRSGD(self): + self._test_without_GradScaler(optim.SRSGD) + self._test_with_GradScaler(optim.SRSGD) if __name__ == '__main__': From 50cbbe4a6adc83a58a765c1159ad71c7723622b5 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 9 Apr 2020 14:04:01 -0700 Subject: [PATCH 05/18] flake8 --- torch/optim/sradam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/optim/sradam.py b/torch/optim/sradam.py index 6fd8626374fa9..3e38a6a25007a 100644 --- a/torch/optim/sradam.py +++ b/torch/optim/sradam.py @@ -36,7 +36,7 @@ class SRAdam(Adam): _step_supports_amp_scaling = True @torch.no_grad() - def step(self,closure=None, grad_scaler=None): + def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. 
Arguments: From 46750eec50cb11ab270cdf725ccb6ddd3ad8c003 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 9 Apr 2020 19:16:21 -0700 Subject: [PATCH 06/18] add type stub file --- torch/optim/__init__.pyi | 3 +++ torch/optim/sradam.pyi | 7 +++++++ torch/optim/sradamw.pyi | 7 +++++++ torch/optim/srsgd.pyi | 7 +++++++ 4 files changed, 24 insertions(+) create mode 100644 torch/optim/sradam.pyi create mode 100644 torch/optim/sradamw.pyi create mode 100644 torch/optim/srsgd.pyi diff --git a/torch/optim/__init__.pyi b/torch/optim/__init__.pyi index e82b5821e5ce3..4ead24b2e9089 100644 --- a/torch/optim/__init__.pyi +++ b/torch/optim/__init__.pyi @@ -11,3 +11,6 @@ from .rmsprop import RMSprop from .rprop import Rprop from .sgd import SGD as SGD from .sparse_adam import SparseAdam +from .sradam import SRAdam +from .sradamw import SRAdamW +from .srsgd import SRSGD diff --git a/torch/optim/sradam.pyi b/torch/optim/sradam.pyi new file mode 100644 index 0000000000000..03d6c7d98d4c5 --- /dev/null +++ b/torch/optim/sradam.pyi @@ -0,0 +1,7 @@ +from typing import Callable, Optional, List +from torch.cuda.amp import GradScaler +from .adam import Adam + + +class SRAdam(Adam): + def step(self, closure: Optional[Callable[[], float]]=..., grad_scaler: GradScaler=...) -> Optional[float]: ... diff --git a/torch/optim/sradamw.pyi b/torch/optim/sradamw.pyi new file mode 100644 index 0000000000000..4b9cd350dac1e --- /dev/null +++ b/torch/optim/sradamw.pyi @@ -0,0 +1,7 @@ +from typing import Callable, Optional, List +from torch.cuda.amp import GradScaler +from .adamw import AdamW + + +class SRAdamW(AdamW): + def step(self, closure: Optional[Callable[[], float]]=..., grad_scaler: GradScaler=...) -> Optional[float]: ... diff --git a/torch/optim/srsgd.pyi b/torch/optim/srsgd.pyi new file mode 100644 index 0000000000000..03ea939b0ccb0 --- /dev/null +++ b/torch/optim/srsgd.pyi @@ -0,0 +1,7 @@ +from typing import Callable, Optional, List +from torch.cuda.amp import GradScaler +from .sgd import SGD + + +class SRSGD(SGD): + def step(self, closure: Optional[Callable[[], float]]=..., grad_scaler: GradScaler=...) -> Optional[float]: ... From 17696a2e299b757f2d015a0a3d74c5e0d68c600b Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 15 Apr 2020 09:42:21 -0700 Subject: [PATCH 07/18] simplify stochastic rounding for tensor --- .../ATen/native/cuda/StochasticRounding.cu | 40 +++++-------------- .../native/cuda/StochasticRoundingAdam.cu | 2 +- .../ATen/native/cuda/StochasticRoundingSGD.cu | 1 + aten/src/ATen/native/native_functions.yaml | 2 +- test/test_stochastic_rounding.py | 3 +- 5 files changed, 13 insertions(+), 35 deletions(-) diff --git a/aten/src/ATen/native/cuda/StochasticRounding.cu b/aten/src/ATen/native/cuda/StochasticRounding.cu index 44bcab3e22c4e..cc9442ab64797 100644 --- a/aten/src/ATen/native/cuda/StochasticRounding.cu +++ b/aten/src/ATen/native/cuda/StochasticRounding.cu @@ -1,24 +1,6 @@ #include #include -#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)
\ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } namespace at { namespace native { @@ -39,10 +21,9 @@ __global__ void stochastic_rounding_kernel( } } -Tensor stochastic_rounding_cuda(const Tensor& input, Tensor& output, Generator gen_) { - TORCH_CHECK(input.numel() > 0 && input.numel() == output.numel()); - TORCH_CHECK(input.is_contiguous()); - TORCH_CHECK(output.is_contiguous()); +Tensor stochastic_rounding_cuda(const Tensor& input, Generator gen_) { + + Tensor output = at::empty_like(input, input.options().dtype(kHalf), input.suggest_memory_format()); const int64_t numel = input.numel(); const int block = 256; @@ -57,19 +38,16 @@ Tensor stochastic_rounding_cuda(const Tensor& input, Tensor& output, Generator g rng_engine_inputs = gen->philox_engine_inputs((numel + block * grid - 1) / (block * grid)); } - DISPATCH_FLOAT_AND_HALF( - input.scalar_type(), 0, "round_stochastically_input", - DISPATCH_FLOAT_AND_HALF( - output.scalar_type(), 1, "round_stochastically_output", - stochastic_rounding_kernel<<>>( - input.data_ptr(), - output.data_ptr(), + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "stochastic_rounding_cuda", [&] { + stochastic_rounding_kernel<<>>( + input.data_ptr(), + output.data_ptr(), numel, rng_engine_inputs); - )); + }); return output; } } // namespace native } // namespace at -#undef DISPATCH_FLOAT_AND_HALF diff --git a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu index 5bcc3fecb6a73..f42e90ee48500 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu @@ -107,7 +107,7 @@ Tensor stochastic_rounding_adam_step_cuda( } AT_DISPATCH_FLOATING_TYPES_AND_HALF( - param.scalar_type(), "stochastic_rounding_adam_step_kernel", [&] { + param.scalar_type(), "stochastic_rounding_adam_step_cuda", [&] { stochastic_rounding_adam_step_kernel<<>>( param.data_ptr(), grad.data_ptr(), diff --git a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu index 0c2c9d0b19cfb..b11c91644ae0e 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu @@ -92,6 +92,7 @@ Tensor stochastic_rounding_sgd_step_cuda( static_cast(weight_decay), static_cast(momentum), static_cast(dampening), static_cast(lr), nesterov, first_run, numel, rng_engine_inputs); }); + AT_CUDA_CHECK(cudaGetLastError()); return param; } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e3c5f5f5b96d4..1f23ba3f4c415 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6720,7 +6720,7 @@ - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor use_c10_dispatcher: full -- func: stochastic_rounding(Tensor input, Tensor(a!) output, Generator? gen_=None) -> Tensor(a!) +- func: stochastic_rounding(Tensor input, Generator? 
gen_=None) -> Tensor dispatch: CUDA: stochastic_rounding_cuda diff --git a/test/test_stochastic_rounding.py b/test/test_stochastic_rounding.py index 42affee22a48d..2ed03d06998ee 100644 --- a/test/test_stochastic_rounding.py +++ b/test/test_stochastic_rounding.py @@ -15,8 +15,7 @@ def test_rs(scale): x = torch.tensor([original_value] * N).cuda() _, exponent = math.frexp(original_value) exponent -= 1 - rounded = torch.zeros((N,)).cuda() - rounded = torch.stochastic_rounding(x, rounded) + rounded = torch.stochastic_rounding(x) mean = torch.mean(rounded).item() delta_fp16 = math.pow(2, -10 + exponent if exponent >= -14 else -24) From 2e130ebee3f3fdf20ebdb8db7035cde7667f9c1b Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 15 Apr 2020 22:30:45 -0700 Subject: [PATCH 08/18] make SRAdam(W) compatible with Adam(W) --- test/test_optim.py | 48 +++++++++++++++++++++++++++++++++--------- torch/optim/sradam.py | 30 +++++++++++++++++++++++--- torch/optim/sradamw.py | 28 +++++++++++++++++++++--- 3 files changed, 90 insertions(+), 16 deletions(-) diff --git a/test/test_optim.py b/test/test_optim.py index c9b3bcf39271e..656b8e48c000c 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1547,6 +1547,15 @@ def test_cosine_then_cyclic(self): self.assertLessEqual(last_lr, max_lr) + +def _calc_loss(weight, bias, input): + y = weight.mv(input) + if y.get_device() != bias.get_device(): + y = y.cuda(bias.get_device()) + return (y + bias).pow(2).sum() + + +@unittest.skipIf(not torch.cuda.is_available(), 'No CUDA') class TestStochasticRoundingOptim(TestCase): exact_dtype = True @@ -1554,12 +1563,6 @@ class TestStochasticRoundingOptim(TestCase): def _test_basic_cases_template( self, weight, bias, input, constructor, grad_scaler=None): - def _calc_loss(weight, bias, input): - y = weight.mv(input) - if y.get_device() != bias.get_device(): - y = y.cuda(bias.get_device()) - return (y + bias).pow(2).sum() - optimizer = constructor(weight, bias) initial_value = None for _i in range(10): @@ -1599,15 +1602,11 @@ def _test_basic_cases(self, dtype, constructor, grad_scaler=None): constructor, grad_scaler) def _test_without_GradScaler(self, opt): - if not torch.cuda.is_available(): - return for dtype in (torch.float16, torch.float32, torch.float64): self._test_basic_cases( dtype, lambda weight, bias: opt([weight, bias], lr=1e-2), None) def _test_with_GradScaler(self, opt): - if not torch.cuda.is_available(): - return for dtype in (torch.float16, torch.float32): self._test_basic_cases( dtype, lambda weight, bias: opt([weight, bias], lr=1e-2), @@ -1625,6 +1624,35 @@ def test_SRSGD(self): self._test_without_GradScaler(optim.SRSGD) self._test_with_GradScaler(optim.SRSGD) + def _prepare_optimizer(self, opt, update=False): + weight = torch.nn.Parameter(torch.randn(10, 5).cuda()) + bias = torch.nn.Parameter(torch.randn(10).cuda()) + optimizer = opt([weight, bias], lr=1e-2) + if not update: + return optimizer + input = torch.randn(5).cuda() + + optimizer.zero_grad() + _calc_loss(weight, bias, input).backward() + optimizer.step() + optimizer.zero_grad() + + return optimizer + + def _test_state_dict(self, opt_1, opt_2): + optimizer = self._prepare_optimizer(opt_1) + optimizer.load_state_dict(optimizer.state_dict()) + optimizer2 = self._prepare_optimizer(opt_2, False) + optimizer2.load_state_dict(optimizer.state_dict()) + optimizer2 = self._prepare_optimizer(opt_2, False) + optimizer.load_state_dict(optimizer2.state_dict()) + self._prepare_optimizer(opt_1, False).load_state_dict(optimizer2.state_dict()) + + def 
test_state_dict_compatibility(self): + self._test_state_dict(optim.SRSGD, optim.SGD) + self._test_state_dict(optim.SRAdam, optim.Adam) + self._test_state_dict(optim.SRAdamW, optim.AdamW) + if __name__ == '__main__': run_tests() diff --git a/torch/optim/sradam.py b/torch/optim/sradam.py index 3e38a6a25007a..08c5161201322 100644 --- a/torch/optim/sradam.py +++ b/torch/optim/sradam.py @@ -1,10 +1,28 @@ +import copy + import torch from .adam import Adam -# TODO(crcrpar): Decide whether to override `state_dict` method as -# this optimizer tracks square root of `exp_avg_sq` and `max_exp_avg_sq` -# that contradicts with its parent optimizer :class:`torch.optim.Adam`. +def _apply_square_to_state_dict(state_dict): + with torch.no_grad(): + for state_per_param in state_dict['state'].values(): + state_per_param['exp_avg_sq'].square_() + state_per_param['max_exp_avg_sq'].square() + return state_dict + + +def _apply_sqrt_to_state_dict(state_dict): + with torch.no_grad(): + for state_per_param in state_dict['state'].values(): + state_per_param['exp_avg_sq'].sqrt_() + if 'max_exp_avg_sq' not in state_per_param: + state_per_param['max_exp_avg_sq'] = torch.zeros_like(state_per_param['exp_avg_sq']) + else: + state_per_param['max_exp_avg_sq'].sqrt_() + return state_dict + + class SRAdam(Adam): r"""Implements Adam algorithm with Stochastic Rounding. @@ -35,6 +53,12 @@ class SRAdam(Adam): _step_supports_amp_scaling = True + def state_dict(self): + return _apply_square_to_state_dict(super().state_dict()) + + def load_state_dict(self, state_dict): + super().load_state_dict(_apply_sqrt_to_state_dict(state_dict)) + @torch.no_grad() def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. diff --git a/torch/optim/sradamw.py b/torch/optim/sradamw.py index ee8b459632856..1a7d67e71657b 100644 --- a/torch/optim/sradamw.py +++ b/torch/optim/sradamw.py @@ -2,9 +2,25 @@ from .adamw import AdamW -# TODO(crcrpar): Decide whether to override `state_dict` method as -# this optimizer tracks square root of `exp_avg_sq` and `max_exp_avg_sq` -# that contradicts with its parent optimizer :class:`torch.optim.Adam`. +def _apply_square_to_state_dict(state_dict): + with torch.no_grad(): + for state_per_param in state_dict['state'].values(): + state_per_param['exp_avg_sq'].square_() + state_per_param['max_exp_avg_sq'].square() + return state_dict + + +def _apply_sqrt_to_state_dict(state_dict): + with torch.no_grad(): + for state_per_param in state_dict['state'].values(): + state_per_param['exp_avg_sq'].sqrt_() + if 'max_exp_avg_sq' not in state_per_param: + state_per_param['max_exp_avg_sq'] = torch.zeros_like(state_per_param['exp_avg_sq']) + else: + state_per_param['max_exp_avg_sq'].sqrt_() + return state_dict + + class SRAdamW(AdamW): r"""Implements AdamW algorithm with Stochastic Rounding. @@ -38,6 +54,12 @@ class SRAdamW(AdamW): _step_supports_amp_scaling = True + def state_dict(self): + return _apply_square_to_state_dict(super().state_dict()) + + def load_state_dict(self, state_dict): + super().load_state_dict(_apply_sqrt_to_state_dict(state_dict)) + @torch.no_grad() def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. 
From 907f568259f1b352a804330bc679d55acc2f01cc Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 15 Apr 2020 22:37:06 -0700 Subject: [PATCH 09/18] cosmetic --- aten/src/ATen/native/cuda/StochasticRoundingAdam.cu | 6 ------ aten/src/ATen/native/cuda/StochasticRoundingSGD.cu | 9 +-------- aten/src/ATen/native/cuda/stochastic_rounding.cuh | 4 ---- 3 files changed, 1 insertion(+), 18 deletions(-) diff --git a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu index f42e90ee48500..19ade8bd65cf1 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu @@ -52,11 +52,8 @@ __global__ void stochastic_rounding_adam_step_kernel( max_v = fmaxf(prev_max_v, v); } - // Update parameter weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps); - // Rounding - // `maybe_square` must not be used in this section. weights[i] = round_stochastically(weight, random_values.x); exp_avg[i] = round_stochastically(m, random_values.y); exp_avg_sq[i] = round_stochastically(sqrtf(v), random_values.z); @@ -87,8 +84,6 @@ Tensor stochastic_rounding_adam_step_cuda( TORCH_CHECK(exp_avg_sq.is_contiguous()); TORCH_CHECK(max_exp_avg_sq.is_contiguous()); - // Based on ATen/native/cuda/Dropout.cu - // link: https://github.com/pytorch/pytorch/blob/c21896327094637cfb83bc0f536e6a442b9877a1/aten/src/ATen/native/cuda/Dropout.cu const int64_t numel = param.numel(); const int block_size = 256; const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block_size; @@ -101,7 +96,6 @@ Tensor stochastic_rounding_adam_step_cuda( uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (block_size * grid.x)) * 4; std::pair rng_engine_inputs; { - // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); rng_engine_inputs = gen->philox_engine_inputs(counter_offset); } diff --git a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu index b11c91644ae0e..aa3a3adb9c536 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu @@ -11,13 +11,8 @@ __global__ void stochastic_rounding_sgd_step_kernel( scalar_t *weights, scalar_t *gradients, scalar_t *momentum_buffer, float* inv_scale, float* found_inf, float weight_decay, float momentum, float dampening, float lr, - bool nesterov, bool first_run, int numel, std::pair seeds) -{ + bool nesterov, bool first_run, int numel, std::pair seeds) { - // 1.0 indicates that any gradients contain inf or nan. - // See below about `found_inf`: - // - https://github.com/mcarilli/pytorch/blob/382d02f01d104049179f4f056cc9258caad029af/aten/src/ATen/native/cuda/AmpKernels.cu#L40-L41 - // - https://github.com/mcarilli/pytorch/blob/382d02f01d104049179f4f056cc9258caad029af/aten/src/ATen/native/cuda/AmpKernels.cu#L116-L117 if (*found_inf) return; int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -47,7 +42,6 @@ __global__ void stochastic_rounding_sgd_step_kernel( weight -= lr * gradient; - // Rounding. 
weights[i] = round_stochastically(weight, random_values.x); if (momentum != 0.0f) momentum_buffer[i] = round_stochastically(velocity, random_values.y); @@ -77,7 +71,6 @@ Tensor stochastic_rounding_sgd_step_cuda( uint64_t counter_offset = ((numel + dim_block.x * grid.x - 1) / (dim_block.x * grid.x)) * 4; std::pair rng_engine_inputs; { - // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); rng_engine_inputs = gen->philox_engine_inputs(counter_offset); } diff --git a/aten/src/ATen/native/cuda/stochastic_rounding.cuh b/aten/src/ATen/native/cuda/stochastic_rounding.cuh index a7e1987144179..520ed4e07a400 100644 --- a/aten/src/ATen/native/cuda/stochastic_rounding.cuh +++ b/aten/src/ATen/native/cuda/stochastic_rounding.cuh @@ -1,4 +1,3 @@ -// Ref: https://gitlab.com/riship11/stochastic-rounding #ifndef _STOCHASTIC_ROUNDING_CUH_ #define _STOCHASTIC_ROUNDING_CUH_ @@ -21,9 +20,6 @@ // 2^-10 is the step for normal FP16 numbers. // 2^-24 is the unit in the last place (ULP)/precision limitation. // 24 is **NOT** related to the number of mantissa bits of single precision format. -// ref: -// - https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_decimal_values_in_[0,_1] -// - https://en.wikipedia.org/wiki/Half-precision_floating-point_format#Precision_limitations_on_decimal_values_in_[1,_2048] __device__ const float TWO_10 = 0.0009765625; __device__ const float TWO_24 = 0.000000059604644775390625; From 31bfb75c1151c2d3950007b87af093ec510f6779 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 16 Apr 2020 09:42:02 -0700 Subject: [PATCH 10/18] Add comment about test condition --- test/test_stochastic_rounding.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_stochastic_rounding.py b/test/test_stochastic_rounding.py index 2ed03d06998ee..321a60666e645 100644 --- a/test/test_stochastic_rounding.py +++ b/test/test_stochastic_rounding.py @@ -21,4 +21,8 @@ def test_rs(scale): delta_fp16 = math.pow(2, -10 + exponent if exponent >= -14 else -24) threshold = 1e-6 diff = math.fabs(original_value - mean) + + # The right condition of `diff < delta_fp16 / 2.0` is for larger `original_value`. + # The larger `original_value` is, the larger `delta_fp16` is. So, no matter how many elements + # we prepare, it's difficult to guarantee that `mean` is close enough the original value. 
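+    # (`delta_fp16` mirrors `get_delta_fp16` in the CUDA kernel: 2^(exponent - 10) for
+    # normal FP16 values, 2^-24 for subnormals.)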
assert diff < threshold or diff < delta_fp16 / 2.0 From ae652ea6d9ad54ab49de77a05af88ef6eb0bb86b Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 17 Apr 2020 10:33:45 -0700 Subject: [PATCH 11/18] use pragma once --- aten/src/ATen/native/cuda/stochastic_rounding.cuh | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/stochastic_rounding.cuh b/aten/src/ATen/native/cuda/stochastic_rounding.cuh index 520ed4e07a400..b59ca5ec29790 100644 --- a/aten/src/ATen/native/cuda/stochastic_rounding.cuh +++ b/aten/src/ATen/native/cuda/stochastic_rounding.cuh @@ -1,5 +1,4 @@ -#ifndef _STOCHASTIC_ROUNDING_CUH_ -#define _STOCHASTIC_ROUNDING_CUH_ +#pragma once #include #include @@ -59,5 +58,3 @@ __device__ __forceinline__ scalar_t round_stochastically(float x, float random_v } return maybe_upcast(__float2half_rz(val)); } - -#endif // _STOCHASTIC_ROUNDING_CUH_ From b765059d1d43c54769b09f281c12c28ef785957d Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 17 Apr 2020 10:35:33 -0700 Subject: [PATCH 12/18] register stochastic rounding test --- test/run_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_test.py b/test/run_test.py index a90381134f91d..e35956e49e73f 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -79,6 +79,7 @@ 'test_overrides', 'test_jit_fuser_te', 'test_tensorexpr', + 'test_stochastic_rounding', ] # skip < 3.3 because mock is added in 3.3 and is used in rpc_spawn From 6f7a93a4a6134174539ac5ec20d5dffd13e7d099 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 17 Apr 2020 10:36:22 -0700 Subject: [PATCH 13/18] return input when dtype is Half --- aten/src/ATen/native/cuda/StochasticRounding.cu | 13 +++++++++++-- test/test_stochastic_rounding.py | 8 +++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cuda/StochasticRounding.cu b/aten/src/ATen/native/cuda/StochasticRounding.cu index cc9442ab64797..97993787e5aeb 100644 --- a/aten/src/ATen/native/cuda/StochasticRounding.cu +++ b/aten/src/ATen/native/cuda/StochasticRounding.cu @@ -23,9 +23,18 @@ __global__ void stochastic_rounding_kernel( Tensor stochastic_rounding_cuda(const Tensor& input, Generator gen_) { - Tensor output = at::empty_like(input, input.options().dtype(kHalf), input.suggest_memory_format()); + TORCH_CHECK(input.is_contiguous()); + + if (input.scalar_type() == kHalf) { + return input; + } + Tensor output = at::empty_like(input, input.options().dtype(kHalf), input.suggest_memory_format()); const int64_t numel = input.numel(); + if (numel == 0) { + return output; + } + const int block = 256; const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / block; unsigned int grid = (numel + block - 1) / block; @@ -38,7 +47,7 @@ Tensor stochastic_rounding_cuda(const Tensor& input, Generator gen_) { rng_engine_inputs = gen->philox_engine_inputs((numel + block * grid - 1) / (block * grid)); } - AT_DISPATCH_FLOATING_TYPES_AND_HALF( + AT_DISPATCH_FLOATING_TYPES( input.scalar_type(), "stochastic_rounding_cuda", [&] { stochastic_rounding_kernel<<>>( input.data_ptr(), diff --git a/test/test_stochastic_rounding.py b/test/test_stochastic_rounding.py index 321a60666e645..f2c43f98f78d3 100644 --- a/test/test_stochastic_rounding.py +++ b/test/test_stochastic_rounding.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize('scale', tuple(range(-18, 11))) -def test_rs(scale): +def test_stochastic_rounding(scale): base = math.pow(2, scale) original_value = (base + math.pow(2, scale + 1)) / 2.0 + .5 * base @@ -26,3 
+26,9 @@ def test_rs(scale): # The larger `original_value` is, the larger `delta_fp16` is. So, no matter how many elements # we prepare, it's difficult to guarantee that `mean` is close enough the original value. assert diff < threshold or diff < delta_fp16 / 2.0 + + +def test_stochastic_rounding_half(): + x = torch.randn((32, 32)).cuda().half() + y = torch.stochastic_rounding(x) + assert torch.eq(x, y).all() From 9159b031a528dc7dacea54a95814796bf4628569 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 17 Apr 2020 10:52:13 -0700 Subject: [PATCH 14/18] Update documentation - add docstring to `torch.stochastic_rounding` - refer `torch.stochastic_rounding` from stochastic rounding optimizers for details --- docs/source/optim.rst | 6 ++++++ docs/source/torch.rst | 1 + torch/_torch_docs.py | 23 +++++++++++++++++++++++ torch/optim/sradam.py | 2 +- torch/optim/sradamw.py | 2 +- torch/optim/srsgd.py | 2 +- 6 files changed, 33 insertions(+), 3 deletions(-) diff --git a/docs/source/optim.rst b/docs/source/optim.rst index f09685bd53d5c..dfa556b24bcad 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -129,6 +129,12 @@ Algorithms :members: .. autoclass:: SGD :members: +.. autoclass:: SRAdam + :members: +.. autoclass:: SRAdamW + :members: +.. autoclass:: SRSGD + :members: How to adjust learning rate --------------------------- diff --git a/docs/source/torch.rst b/docs/source/torch.rst index d5be6aee2207b..6a515b0c3b71f 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -249,6 +249,7 @@ Pointwise Ops .. autofunction:: sinh .. autofunction:: sqrt .. autofunction:: square +.. autofunction:: stochastic_rounding .. autofunction:: tan .. autofunction:: tanh .. autofunction:: true_divide diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 5da13158a1a38..83baf1e9d94be 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7544,6 +7544,29 @@ def merge_dicts(*dicts): [100, 200]], dtype=torch.uint8) """) +add_docstr(torch.stochastic_rounding, + r""" +stochastic_rounding(input, generator=None) -> Tensor + +Rounds a tensor to half stochastically. If the dtype of :attr:`input` is Half, +this is equivalent to noop. This function supports only CUDA tensor. +For a floating-point number :attr:`x` and there are two close half values :attr:`y` and :attr:`z`. +Then :attr:`x` is rounded to :attr:`y` (:attr:`z`) with the probability of +:math:`\dfrac{| x - z |}{| y - z |}` (:math:`\dfrac{| x - y |}{| y - z |}`). + +See `Deep learning with limited numerical precision`_ for further details. + +.. _Deep learning with limited numerical precision: https://dl.acm.org/doi/10.5555/3045118.3045303 + +Args: + input (Tensor): float tensor to round stochastically + generator (Generator, optional): A torch.Generator object + +Returns: + Tensor: A stochastically rounded half tensor + +""") + add_docstr(torch._C.Generator, r""" Generator(device='cpu') -> Generator diff --git a/torch/optim/sradam.py b/torch/optim/sradam.py index 08c5161201322..e69b09e067830 100644 --- a/torch/optim/sradam.py +++ b/torch/optim/sradam.py @@ -29,7 +29,7 @@ class SRAdam(Adam): It has been proposed in `Adam: A Method for Stochastic Optimization`_. With Stochastic Rounding, param, `exp_avg`, `exp_avg_sq`, and optionally `max_exp_avg_sq` - can be represented with 16 bits. + can be represented with 16 bits. See :func:`torch.stochastic_rounding` for details. This optimizer requires CUDA. 
Arguments: diff --git a/torch/optim/sradamw.py b/torch/optim/sradamw.py index 1a7d67e71657b..6dbe6daf6bb07 100644 --- a/torch/optim/sradamw.py +++ b/torch/optim/sradamw.py @@ -28,7 +28,7 @@ class SRAdamW(AdamW): The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. With Stochastic Rounding, param, `exp_avg`, `exp_avg_sq`, and optionally `max_exp_avg_sq` - can be represented with 16 bits. + can be represented with 16 bits. See :func:`torch.stochastic_rounding` for details. This optimizer requires CUDA. Arguments: diff --git a/torch/optim/srsgd.py b/torch/optim/srsgd.py index a2ea4e07ae3a9..9093898ac3829 100644 --- a/torch/optim/srsgd.py +++ b/torch/optim/srsgd.py @@ -6,7 +6,7 @@ class SRSGD(SGD): r"""Implements stochastic gradient descent with Stochastic Rounding. With Stochastic Rounding, param and `momentum_buffer` can be represented with 16 bits. - This optimizer requires CUDA. + See :func:`torch.stochastic_rounding` for details. This optimizer requires CUDA. Args: params (iterable): iterable of parameters to optimize or dicts defining From 59523c8906dfc1c4a7208cadab3614055505819c Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 17 Apr 2020 12:18:49 -0700 Subject: [PATCH 15/18] handle non-current devices --- torch/optim/sradam.py | 11 ++++++----- torch/optim/sradamw.py | 8 +++++--- torch/optim/srsgd.py | 9 ++++++--- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/torch/optim/sradam.py b/torch/optim/sradam.py index e69b09e067830..6976f3ad032ad 100644 --- a/torch/optim/sradam.py +++ b/torch/optim/sradam.py @@ -1,6 +1,5 @@ -import copy - import torch +from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator from .adam import Adam @@ -73,14 +72,16 @@ def step(self, closure=None, grad_scaler=None): loss = closure() if grad_scaler is not None: - found_inf = grad_scaler._check_inf_per_device( - self)[torch.device(torch.cuda.current_device())] + found_inf = grad_scaler._check_inf_per_device(self) scale = grad_scaler._get_scale_async() inv_scale = scale.double().reciprocal().float() else: found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) + found_inf = _MultiDeviceReplicator(found_inf) inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + inv_scale = _MultiDeviceReplicator(inv_scale) + for group in self.param_groups: for param in group['params']: if param.grad is None: @@ -107,7 +108,7 @@ def step(self, closure=None, grad_scaler=None): torch.stochastic_rounding_adam_step( param, grad, state['exp_avg'], state['exp_avg_sq'], state['max_exp_avg_sq'], - inv_scale, found_inf, + inv_scale.get(param.device), found_inf.get(param.device), group['lr'], beta1, beta2, group['weight_decay'], group['eps'], state['step'], False, group['amsgrad']) diff --git a/torch/optim/sradamw.py b/torch/optim/sradamw.py index 6dbe6daf6bb07..e6f4ba4bc26f9 100644 --- a/torch/optim/sradamw.py +++ b/torch/optim/sradamw.py @@ -1,4 +1,5 @@ import torch +from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator from .adamw import AdamW @@ -74,14 +75,15 @@ def step(self, closure=None, grad_scaler=None): loss = closure() if grad_scaler is not None: - found_inf = grad_scaler._check_inf_per_device( - self)[torch.device(torch.cuda.current_device())] + found_inf = grad_scaler._check_inf_per_device(self) scale = grad_scaler._get_scale_async() inv_scale = scale.double().reciprocal().float() else: found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) + found_inf = 
_MultiDeviceReplicator(found_inf) inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + inv_scale = _MultiDeviceReplicator(inv_scale) for group in self.param_groups: for param in group['params']: @@ -110,7 +112,7 @@ def step(self, closure=None, grad_scaler=None): torch.stochastic_rounding_adam_step( param, grad, state['exp_avg'], state['exp_avg_sq'], state['max_exp_avg_sq'], - inv_scale, found_inf, + inv_scale.get(param.device), found_inf.get(param.device), group['lr'], beta1, beta2, group['weight_decay'], group['eps'], state['step'], True, group['amsgrad']) diff --git a/torch/optim/srsgd.py b/torch/optim/srsgd.py index 9093898ac3829..ead2267c56f86 100644 --- a/torch/optim/srsgd.py +++ b/torch/optim/srsgd.py @@ -1,4 +1,5 @@ import torch +from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator from .sgd import SGD @@ -35,14 +36,16 @@ def step(self, closure=None, grad_scaler=None): loss = closure() if grad_scaler is not None: - found_inf = grad_scaler._check_inf_per_device( - self)[torch.device(torch.cuda.current_device())] + found_inf = grad_scaler._check_inf_per_device(self) scale = grad_scaler._get_scale_async() inv_scale = scale.double().reciprocal().float() else: found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) + found_inf = _MultiDeviceReplicator(found_inf) inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + inv_scale = _MultiDeviceReplicator(inv_scale) + for group in self.param_groups: weight_decay = group['weight_decay'] momentum = group['momentum'] @@ -65,7 +68,7 @@ def step(self, closure=None, grad_scaler=None): torch.stochastic_rounding_sgd_step( param, grad, momentum_buffer, - inv_scale, found_inf, + inv_scale.get(param.device), found_inf.get(param.device), group['lr'], momentum, weight_decay, dampening, nesterov, first_run) From 8ea3245f88544a127f5e20603db39d56d7ec8a78 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 18 Apr 2020 16:28:53 -0700 Subject: [PATCH 16/18] Generator -> c10::optional --- aten/src/ATen/native/cuda/StochasticRounding.cu | 2 +- aten/src/ATen/native/cuda/StochasticRoundingAdam.cu | 2 +- aten/src/ATen/native/cuda/StochasticRoundingSGD.cu | 2 +- aten/src/ATen/native/native_functions.yaml | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/cuda/StochasticRounding.cu b/aten/src/ATen/native/cuda/StochasticRounding.cu index 97993787e5aeb..4305b95eb2d7d 100644 --- a/aten/src/ATen/native/cuda/StochasticRounding.cu +++ b/aten/src/ATen/native/cuda/StochasticRounding.cu @@ -21,7 +21,7 @@ __global__ void stochastic_rounding_kernel( } } -Tensor stochastic_rounding_cuda(const Tensor& input, Generator gen_) { +Tensor stochastic_rounding_cuda(const Tensor& input, c10::optional gen_) { TORCH_CHECK(input.is_contiguous()); diff --git a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu index 19ade8bd65cf1..74e6b9bd471da 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu @@ -74,7 +74,7 @@ Tensor stochastic_rounding_adam_step_cuda( const Tensor& found_inf, double lr, double beta1, double beta2, double weight_decay, double eps, int64_t step, - bool is_decoupled, bool is_amsgrad, Generator gen_) { + bool is_decoupled, bool is_amsgrad, c10::optional gen_) { if (param.numel() == 0) return param; diff --git a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu 
b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu index aa3a3adb9c536..15a584cc0b170 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu @@ -52,7 +52,7 @@ Tensor stochastic_rounding_sgd_step_cuda( Tensor& param, const Tensor& grad, Tensor& momentum_buffer, const Tensor& inv_scale, const Tensor& found_inf, double lr, double momentum, double weight_decay, double dampening, - bool nesterov, bool first_run, Generator gen_) { + bool nesterov, bool first_run, c10::optional gen_) { if (param.numel() == 0) return param; diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1f23ba3f4c415..59c64ca90b28f 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6731,4 +6731,3 @@ - func: stochastic_rounding_sgd_step(Tensor(a!) param, Tensor grad, Tensor(b!) momentum_buffer, Tensor inv_scale, Tensor found_inf, float lr, float momentum, float weight_decay, float dampening, bool nesterov, bool first_run, Generator? gen_=None) -> Tensor(a!) dispatch: CUDA: stochastic_rounding_sgd_step_cuda ->>>>>>> 02db424932... Implement stochastic rounding features From ccab44623ccbca523bb5a14a4ee241f9d48460e6 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 27 Apr 2020 18:49:07 -0700 Subject: [PATCH 17/18] combine `found_inf` tensors from all devices --- torch/optim/_amp_helper.py | 18 ++++++++++++++++++ torch/optim/_amp_helper.pyi | 8 ++++++++ torch/optim/sradam.py | 10 +++++----- torch/optim/sradamw.py | 10 +++++----- torch/optim/srsgd.py | 10 +++++----- 5 files changed, 41 insertions(+), 15 deletions(-) create mode 100644 torch/optim/_amp_helper.py create mode 100644 torch/optim/_amp_helper.pyi diff --git a/torch/optim/_amp_helper.py b/torch/optim/_amp_helper.py new file mode 100644 index 0000000000000..606a879ee8153 --- /dev/null +++ b/torch/optim/_amp_helper.py @@ -0,0 +1,18 @@ +import torch +from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator + + +def _combined_found_inf_helper(optimizer, grad_scaler, device): + + found_inf_dict = grad_scaler._check_inf_per_device(optimizer) + # Combines found_inf tensors from all devices. As in GradScaler.update(), + # tensors are combined on the scale's device, which is an arbitrary but + # reasonable choice that avoids new context creation. + found_infs = [f.to(device, non_blocking=True) for f in found_inf_dict.values()] + assert len(found_infs) > 0, "No inf checks were recorded in _check_inf_per_device." 
+ found_inf_combined = found_infs[0] + if len(found_infs) > 1: + with torch.no_grad(): + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + return _MultiDeviceReplicator(found_inf_combined) diff --git a/torch/optim/_amp_helper.pyi b/torch/optim/_amp_helper.pyi new file mode 100644 index 0000000000000..65c0ca5ab0a8c --- /dev/null +++ b/torch/optim/_amp_helper.pyi @@ -0,0 +1,8 @@ +import torch +from torch.optim.optimizer import Optimizer +from torch.cuda.amp.grad_scaler import GradScaler +from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator + + +def _combined_found_inf_helper( + optimizer: Optimizer, grad_scaler: GradScaler, device: torch.device) -> _MultiDeviceReplicator: ... diff --git a/torch/optim/sradam.py b/torch/optim/sradam.py index 6976f3ad032ad..7bed8af84e029 100644 --- a/torch/optim/sradam.py +++ b/torch/optim/sradam.py @@ -1,6 +1,7 @@ import torch from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator from .adam import Adam +from ._amp_helper import _combined_found_inf_helper def _apply_square_to_state_dict(state_dict): @@ -72,13 +73,12 @@ def step(self, closure=None, grad_scaler=None): loss = closure() if grad_scaler is not None: - found_inf = grad_scaler._check_inf_per_device(self) - scale = grad_scaler._get_scale_async() - inv_scale = scale.double().reciprocal().float() + inv_scale = grad_scaler._get_scale_async().double().reciprocal().float() + found_inf = _combined_found_inf_helper(self, grad_scaler, inv_scale.device) else: - found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) - found_inf = _MultiDeviceReplicator(found_inf) inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + found_inf = _MultiDeviceReplicator( + torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device())) inv_scale = _MultiDeviceReplicator(inv_scale) diff --git a/torch/optim/sradamw.py b/torch/optim/sradamw.py index e6f4ba4bc26f9..0a3e4279ea595 100644 --- a/torch/optim/sradamw.py +++ b/torch/optim/sradamw.py @@ -1,6 +1,7 @@ import torch from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator from .adamw import AdamW +from ._amp_helper import _combined_found_inf_helper def _apply_square_to_state_dict(state_dict): @@ -75,13 +76,12 @@ def step(self, closure=None, grad_scaler=None): loss = closure() if grad_scaler is not None: - found_inf = grad_scaler._check_inf_per_device(self) - scale = grad_scaler._get_scale_async() - inv_scale = scale.double().reciprocal().float() + inv_scale = grad_scaler._get_scale_async().double().reciprocal().float() + found_inf = _combined_found_inf_helper(self, grad_scaler, inv_scale.device) else: - found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) - found_inf = _MultiDeviceReplicator(found_inf) inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + found_inf = _MultiDeviceReplicator( + torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device())) inv_scale = _MultiDeviceReplicator(inv_scale) diff --git a/torch/optim/srsgd.py b/torch/optim/srsgd.py index ead2267c56f86..16d6588c153f4 100644 --- a/torch/optim/srsgd.py +++ b/torch/optim/srsgd.py @@ -1,6 +1,7 @@ import torch from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator from .sgd import SGD +from ._amp_helper import _combined_found_inf_helper class SRSGD(SGD): @@ -36,13 +37,12 @@ def step(self, closure=None, grad_scaler=None): loss = closure() if grad_scaler is not None: - found_inf = grad_scaler._check_inf_per_device(self) -
scale = grad_scaler._get_scale_async() - inv_scale = scale.double().reciprocal().float() + inv_scale = grad_scaler._get_scale_async().double().reciprocal().float() + found_inf = _combined_found_inf_helper(self, grad_scaler, inv_scale.device) else: - found_inf = torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device()) - found_inf = _MultiDeviceReplicator(found_inf) inv_scale = torch.ones((1,), dtype=torch.float, device=torch.cuda.current_device()) + found_inf = _MultiDeviceReplicator( + torch.zeros((1,), dtype=torch.float, device=torch.cuda.current_device())) inv_scale = _MultiDeviceReplicator(inv_scale) From 31bd5733560235c5710155284d3e10462355f832 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 5 May 2020 23:00:10 -0700 Subject: [PATCH 18/18] Make round_stochastically flexible for future another dtype support --- .../ATen/native/cuda/StochasticRounding.cu | 5 +-- .../native/cuda/StochasticRoundingAdam.cu | 10 +++--- .../ATen/native/cuda/StochasticRoundingSGD.cu | 6 ++-- .../ATen/native/cuda/stochastic_rounding.cuh | 33 +++++++++++-------- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/aten/src/ATen/native/cuda/StochasticRounding.cu b/aten/src/ATen/native/cuda/StochasticRounding.cu index 4305b95eb2d7d..8df7b20f0560d 100644 --- a/aten/src/ATen/native/cuda/StochasticRounding.cu +++ b/aten/src/ATen/native/cuda/StochasticRounding.cu @@ -15,9 +15,10 @@ __global__ void stochastic_rounding_kernel( curandStatePhilox4_32_10_t state; curand_init(seed_and_offset.first, tid, seed_and_offset.second, &state); + round_stochastically rounder; + for (int64_t i = tid; i < numel; i += blockDim.x * gridDim.x) { - float inp = static_cast(input[i]); - output[i] = round_stochastically(inp, curand_uniform(&state)); + output[i] = rounder(input[i], curand_uniform(&state)); } } diff --git a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu index 74e6b9bd471da..c06b2bafa4ab1 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingAdam.cu @@ -21,6 +21,8 @@ __global__ void stochastic_rounding_adam_step_kernel( curandStatePhilox4_32_10_t state; curand_init(seeds.first, tid, seeds.second, &state); + round_stochastically rounder; + float m_correction = 1.0 - powf(beta1, step); float v_correction = 1.0 - powf(beta2, step); @@ -54,11 +56,11 @@ __global__ void stochastic_rounding_adam_step_kernel( weight -= (lr / m_correction) * m / (sqrtf(max_v / v_correction) + eps); - weights[i] = round_stochastically(weight, random_values.x); - exp_avg[i] = round_stochastically(m, random_values.y); - exp_avg_sq[i] = round_stochastically(sqrtf(v), random_values.z); + weights[i] = rounder(weight, random_values.x); + exp_avg[i] = rounder(m, random_values.y); + exp_avg_sq[i] = rounder(sqrtf(v), random_values.z); if (is_amsgrad) { - max_exp_avg_sq[i] = round_stochastically(sqrtf(max_v), random_values.w); + max_exp_avg_sq[i] = rounder(sqrtf(max_v), random_values.w); } } } diff --git a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu index 15a584cc0b170..86b92136f2e8f 100644 --- a/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu +++ b/aten/src/ATen/native/cuda/StochasticRoundingSGD.cu @@ -19,6 +19,8 @@ __global__ void stochastic_rounding_sgd_step_kernel( curandStatePhilox4_32_10_t state; curand_init(seeds.first, tid, seeds.second, &state); + round_stochastically rounder; + for (int i = tid; i < numel; i += blockDim.x * gridDim.x) { 
float weight = static_cast(weights[i]); float gradient = static_cast(gradients[i]) * (*inv_scale); @@ -42,9 +44,9 @@ __global__ void stochastic_rounding_sgd_step_kernel( weight -= lr * gradient; - weights[i] = round_stochastically(weight, random_values.x); + weights[i] = rounder(weight, random_values.x); if (momentum != 0.0f) - momentum_buffer[i] = round_stochastically(velocity, random_values.y); + momentum_buffer[i] = rounder(velocity, random_values.y); } } diff --git a/aten/src/ATen/native/cuda/stochastic_rounding.cuh b/aten/src/ATen/native/cuda/stochastic_rounding.cuh index b59ca5ec29790..4c24a3b40b2d6 100644 --- a/aten/src/ATen/native/cuda/stochastic_rounding.cuh +++ b/aten/src/ATen/native/cuda/stochastic_rounding.cuh @@ -44,17 +44,24 @@ __device__ __forceinline__ float get_delta_fp16(float x) { } // Natalia magic -template -__device__ __forceinline__ scalar_t round_stochastically(float x, float random_value) { - if (x == 0.0) { - return scalar_t(0.0); - } - float delta = get_delta_fp16(x); - float val; - if (x < 0.0) { - val = x - random_value * delta; - } else { - val = x + random_value * delta; +template +struct round_stochastically { + static_assert(std::is_same::value, "round_stochastically only supports round_to_prec=at::Half"); +}; + +template +struct round_stochastically { + __device__ __forceinline__ out_type operator()(in_type x, float random_value) { + if (x == 0.0) { + return out_type(0.0); + } + float delta = get_delta_fp16(static_cast(x)); + float val; + if (x < 0.0) { + val = x - random_value * delta; + } else { + val = x + random_value * delta; + } + return maybe_upcast(__float2half_rz(val)); } - return maybe_upcast(__float2half_rz(val)); -} +};
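
A minimal end-to-end usage sketch of the API added by this series. This is illustrative only and not part of the patch set; it assumes a CUDA build with all 18 patches applied, that `torch.stochastic_rounding` and `torch.optim.SRAdam` are exposed as the documentation patch suggests, and that `SRAdam` lazily initializes its per-parameter state (`exp_avg`, `exp_avg_sq`) the way `Adam` does, which is not shown in the hunks above.

    import torch

    # Stochastic rounding of an fp32 tensor to fp16; Half inputs are returned unchanged (patch 13).
    x = torch.randn(1024, device="cuda")
    y = torch.stochastic_rounding(x)
    assert y.dtype == torch.half

    # Rounding is unbiased in expectation: a value between two fp16 grid points is rounded
    # up or down with probabilities proportional to its distances to them, so the mean over
    # many rounded copies stays close to the original value.
    v = torch.full((100000,), 1.0 + 2.0 ** -12, device="cuda")  # lies between 1.0 and 1.0 + 2**-10
    mean = torch.stochastic_rounding(v).float().mean().item()

    # One training step with the stochastic-rounding Adam variant. Without a GradScaler,
    # step() falls back to inv_scale == 1 and found_inf == 0 on the current device (patches 15/17).
    w = torch.randn(64, 64, device="cuda", dtype=torch.half, requires_grad=True)
    opt = torch.optim.SRAdam([w], lr=1e-3)
    loss = w.float().pow(2).sum()
    loss.backward()
    opt.step()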