Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 12, 2019
1 parent 8932647 commit 51b3b91
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 3 deletions.
81 changes: 81 additions & 0 deletions kernels/lltm_cuda.cpp
@@ -0,0 +1,81 @@
#include <torch/extension.h>

#include <vector>

// CUDA forward declarations

std::vector<torch::Tensor> lltm_cuda_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell);

std::vector<torch::Tensor> lltm_cuda_backward(
torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor weights);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> lltm_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {
CHECK_INPUT(input);
CHECK_INPUT(weights);
CHECK_INPUT(bias);
CHECK_INPUT(old_h);
CHECK_INPUT(old_cell);

return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}

std::vector<torch::Tensor> lltm_backward(
torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor weights) {
CHECK_INPUT(grad_h);
CHECK_INPUT(grad_cell);
CHECK_INPUT(input_gate);
CHECK_INPUT(output_gate);
CHECK_INPUT(candidate_cell);
CHECK_INPUT(X);
CHECK_INPUT(gate_weights);
CHECK_INPUT(weights);

return lltm_cuda_backward(
grad_h,
grad_cell,
new_cell,
input_gate,
output_gate,
candidate_cell,
X,
gate_weights,
weights);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}
174 changes: 174 additions & 0 deletions kernels/lltm_cuda_kernel.cu
@@ -0,0 +1,174 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {
template <typename scalar_t>
__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
return 1.0 / (1.0 + exp(-z));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
const auto s = sigmoid(z);
return (1.0 - s) * s;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
const auto t = tanh(z);
return 1 - (t * t);
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
return fmaxf(0.0, z) + fminf(0.0, alpha * (exp(z) - 1.0));
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
const auto e = exp(z);
const auto d_relu = z < 0.0 ? 0.0 : 1.0;
return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
}

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < gates.size(2)){
input_gate[n][c] = sigmoid(gates[n][0][c]);
output_gate[n][c] = sigmoid(gates[n][1][c]);
candidate_cell[n][c] = elu(gates[n][2][c]);
new_cell[n][c] =
old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
}
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < d_gates.size(2)){
const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
const auto d_new_cell =
d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];


d_old_cell[n][c] = d_new_cell;
const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
const auto d_input_gate = candidate_cell[n][c] * d_new_cell;

d_gates[n][0][c] =
d_input_gate * d_sigmoid(gate_weights[n][0][c]);
d_gates[n][1][c] =
d_output_gate * d_sigmoid(gate_weights[n][1][c]);
d_gates[n][2][c] =
d_candidate_cell * d_elu(gate_weights[n][2][c]);
}
}
} // namespace

std::vector<torch::Tensor> lltm_cuda_forward(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
torch::Tensor old_h,
torch::Tensor old_cell) {
auto X = torch::cat({old_h, input}, /*dim=*/1);
auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));

const auto batch_size = old_cell.size(0);
const auto state_size = old_cell.size(1);

auto gates = gate_weights.reshape({batch_size, 3, state_size});
auto new_h = torch::zeros_like(old_cell);
auto new_cell = torch::zeros_like(old_cell);
auto input_gate = torch::zeros_like(old_cell);
auto output_gate = torch::zeros_like(old_cell);
auto candidate_cell = torch::zeros_like(old_cell);

const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);

AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));

return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
}

std::vector<torch::Tensor> lltm_cuda_backward(
torch::Tensor grad_h,
torch::Tensor grad_cell,
torch::Tensor new_cell,
torch::Tensor input_gate,
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gates,
torch::Tensor weights) {
auto d_old_cell = torch::zeros_like(new_cell);
auto d_gates = torch::zeros_like(gates);

const auto batch_size = new_cell.size(0);
const auto state_size = new_cell.size(1);

const int threads = 1024;
const dim3 blocks((state_size + threads - 1) / threads, batch_size);

AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
}));

auto d_gate_weights = d_gates.flatten(1, 2);
auto d_weights = d_gate_weights.t().mm(X);
auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);

auto d_X = d_gate_weights.mm(weights);
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);

return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates};
}
14 changes: 14 additions & 0 deletions kernels/setup.py
@@ -0,0 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
name='lltm_cuda',
ext_modules=[
CUDAExtension('lltm_cuda', [
'lltm_cuda.cpp',
'lltm_cuda_kernel.cu',
]),
],
cmdclass={
'build_ext': BuildExtension
})
6 changes: 3 additions & 3 deletions torch_struct/semirings/semirings.py
Expand Up @@ -136,7 +136,7 @@ class LogSemiring(_BaseLog):
def sum(xs, dim=-1):
return torch.logsumexp(xs, dim=dim)


from pykeops.torch import LazyTensor
class LogMemSemiring(_BaseLog):
"""
Implements the log-space semiring (logsumexp, +, -inf, 0).
Expand All @@ -150,8 +150,8 @@ def sum(xs, dim=-1):

@classmethod
def matmul(cls, a, b, dims=1):
max_a = a.max()
max_b = b.max()
max_a = a.max(dim=-1, keepdim=True)[0]
max_b = b.max(dim=-2, keepdim=True)[0]
exp_a, exp_b = a - max_a, b - max_b
c = torch.matmul(exp_a.exp(), exp_b.exp())
c = (c.log() + max_a + max_b)
Expand Down

0 comments on commit 51b3b91

Please sign in to comment.