## Task Description

Write a custom fused add-ReLU class for 2D tensors in PyTorch that implements the **forward** pass of **elementwise (vector) addition followed by ReLU**. Then write a fused Triton kernel for 2D tensors that performs the same **addition and ReLU** in the **forward** pass, together with its matching **backward** pass. Verify that both implementations produce the same outputs and gradients on random inputs, and write a benchmark with `triton.testing` that compares the efficiency of Triton against PyTorch. Finally, look into **PyTorch's JIT** to make the PyTorch implementation more efficient.

## Grading scheme
Total: 5 points
1. **Implementation of the forward pass for the fused ReLU with PyTorch** (0.5 points)
2. **Implementation of the forward & backward pass for the fused ReLU with Triton** (4 points)
3. **Verify outputs & benchmark test** (0.5 points)

## Code

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton.language as tl
import triton
import triton.testing

In [16]:
DEVICE = triton.runtime.driver.active.get_active_torch_device()

RuntimeError: 0 active drivers ([]). There should only be one.
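The call above raises when no GPU driver is active, which is what happened on this machine. A minimal sketch, assuming the PyTorch-only cells should still run on CPU while the Triton cells need a GPU: choose the device explicitly and fall back to CPU. `HAS_GPU` is a hypothetical helper flag, not part of the original notebook.

```python
import torch

# torch.cuda.is_available() is also True on ROCm builds of PyTorch.
# Fall back to CPU so the PyTorch-only cells keep working without a GPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HAS_GPU = DEVICE.type == "cuda"  # hypothetical flag: gate Triton cells on this
print(DEVICE)
```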

### 1. Implementation of the forward pass for the fused ReLU with PyTorch (0.5 points)

In [9]:
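class TorchAddReLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 2 and y.dim() == 2, "expected two 2D tensors"

        added = x + y
        mask = added > 0
        ctx.save_for_backward(mask)
        # Multiplying by the boolean mask turns negative sums into -0.,
        # which compares equal to 0.
        return added * mask

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (mask,) = ctx.saved_tensors
        # relu(x + y) has the same derivative w.r.t. x and y: 1 where x + y > 0, else 0
        grad = grad_output * mask
        return grad, grad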
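For the JIT part of the task, one option is to script the plain expression with `torch.jit.script`. A minimal sketch, assuming the fused behaviour comes from TorchScript's fuser rather than from custom code: TorchScript does not support calling a custom `torch.autograd.Function` inside scripted code, so the sketch scripts `relu(x + y)` directly and leaves the backward pass to autograd. Whether the add and the ReLU actually end up in one kernel depends on the PyTorch version, the device (CUDA in practice) and a few warm-up calls. `add_relu_scripted` is a hypothetical name.

```python
import torch


@torch.jit.script
def add_relu_scripted(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Scripting exposes add + relu as one graph; on CUDA the active fuser
    # may merge them into a single elementwise kernel after warm-up calls.
    return torch.relu(x + y)


x = torch.randn(4, 8, requires_grad=True)
y = torch.randn(4, 8, requires_grad=True)
out = add_relu_scripted(x, y)
out.sum().backward()  # the backward pass comes from autograd, not custom code
print(torch.allclose(out, torch.relu(x + y)))  # True
```

On PyTorch 2.x, `torch.compile` is the newer alternative to TorchScript; on GPU it generates Triton kernels for fused elementwise code like this.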

### 2. Implementation of the forward & backward pass for the fused ReLU with Triton (4 points)

In [19]:
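@triton.jit
def add_relu_kernel(x_ptr,
                    y_ptr,
                    output_ptr,
                    n_elements,
                    BLOCK_SIZE: tl.constexpr,
                    ):
    # Each program handles one BLOCK_SIZE-wide chunk of the flattened tensor
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the last, partially filled block

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # Fused add + ReLU in registers: no intermediate tensor in global memory
    summed = x + y
    relued = tl.where(summed > 0, summed, 0)

    tl.store(output_ptr + offsets, relued, mask=mask)

def add_relu(x: torch.Tensor, y: torch.Tensor, BLOCK_SIZE=1024):
    assert x.shape == y.shape and x.dim() == 2, "expected two 2D tensors of equal shape"
    assert x.dtype == y.dtype, "expected tensors of the same dtype"
    # The kernel indexes a flat buffer, so the inputs must be contiguous
    x, y = x.contiguous(), y.contiguous()
    output = torch.empty_like(x)
    n_elements = output.numel()

    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_relu_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output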
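The grading scheme also asks for a Triton backward pass. A minimal sketch, assuming equal-shape, same-dtype 2D GPU tensors and a fixed tile width of 1024: the forward kernel stores `relu(x + y)`, and the autograd wrapper saves only that output, because `out > 0` holds exactly where `x + y > 0`. A second kernel can therefore rebuild the ReLU mask from the saved output and route the upstream gradient to both inputs. The cell repeats the forward kernel so it runs on its own; the names `_add_relu_fwd_kernel`, `_add_relu_bwd_kernel` and `TritonAddReLU` are hypothetical.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _add_relu_fwd_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    s = x + y
    tl.store(out_ptr + offsets, tl.where(s > 0, s, 0.0), mask=mask)


@triton.jit
def _add_relu_bwd_kernel(out_ptr, grad_out_ptr, grad_in_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    out = tl.load(out_ptr + offsets, mask=mask)
    g = tl.load(grad_out_ptr + offsets, mask=mask)
    # out > 0 exactly where x + y > 0, so the saved output rebuilds the ReLU mask
    tl.store(grad_in_ptr + offsets, tl.where(out > 0, g, 0.0), mask=mask)


class TritonAddReLU(torch.autograd.Function):
    BLOCK_SIZE = 1024  # assumed tile width; tune per GPU

    @staticmethod
    def forward(ctx, x, y):
        assert x.shape == y.shape and x.dim() == 2, "expected two 2D tensors of equal shape"
        assert x.dtype == y.dtype and x.is_cuda and y.is_cuda, "expected same-dtype GPU tensors"
        x, y = x.contiguous(), y.contiguous()
        out = torch.empty_like(x)
        n = out.numel()
        grid = (triton.cdiv(n, TritonAddReLU.BLOCK_SIZE),)
        _add_relu_fwd_kernel[grid](x, y, out, n, BLOCK_SIZE=TritonAddReLU.BLOCK_SIZE)
        ctx.save_for_backward(out)  # one saved tensor instead of x and y
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (out,) = ctx.saved_tensors
        grad_out = grad_out.contiguous()
        grad_in = torch.empty_like(out)
        n = out.numel()
        grid = (triton.cdiv(n, TritonAddReLU.BLOCK_SIZE),)
        _add_relu_bwd_kernel[grid](out, grad_out, grad_in, n, BLOCK_SIZE=TritonAddReLU.BLOCK_SIZE)
        # d/dx and d/dy of relu(x + y) are identical
        return grad_in, grad_in
```

Saving the output rather than both inputs keeps one tensor alive for the backward pass instead of two, and the backward kernel reads two buffers instead of three.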

### 3. Verify outputs & benchmark test (0.5 points)

In [12]:
x = torch.arange(5, dtype=torch.float32).unsqueeze(0)
y = -2 * torch.ones(5).unsqueeze(0)
F.relu(x+y)

tensor([[0., 0., 0., 1., 2.]])

In [13]:
TorchAddReLU.apply(x,y)

tensor([[-0., -0., 0., 1., 2.]])
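The printouts above check one hand-picked input and differ only in the sign of zero (`-0.` compares equal to `0.`). The task asks for outputs and gradients on random inputs, so a small helper can compare any candidate `fn(x, y)` against plain autograd on `relu(x + y)`. A minimal sketch: `check_add_relu` and `where_add_relu` are hypothetical names, and the tolerance is an assumption suitable for fp32.

```python
import torch
import torch.nn.functional as F


def check_add_relu(fn, shape=(64, 128), device="cpu", atol=1e-6):
    """Compare fn(x, y) against relu(x + y): output and both input gradients."""
    torch.manual_seed(0)
    x = torch.randn(shape, device=device, requires_grad=True)
    y = torch.randn(shape, device=device, requires_grad=True)
    # Reference path: detached copies fed through plain autograd
    x_ref = x.detach().clone().requires_grad_(True)
    y_ref = y.detach().clone().requires_grad_(True)

    out = fn(x, y)
    ref = F.relu(x_ref + y_ref)

    # Push the same random upstream gradient through both paths
    grad_out = torch.randn_like(ref)
    out.backward(grad_out)
    ref.backward(grad_out)

    return (
        torch.allclose(out, ref, atol=atol)
        and torch.allclose(x.grad, x_ref.grad, atol=atol)
        and torch.allclose(y.grad, y_ref.grad, atol=atol)
    )


# Hypothetical eager candidate built from torch.where, used as a demo
def where_add_relu(a, b):
    s = a + b
    return torch.where(s > 0, s, torch.zeros_like(s))


print(check_add_relu(where_add_relu))  # True
```

On a GPU the same helper can be pointed at a Triton-backed autograd function with `device="cuda"`. Passing `TorchAddReLU.apply` exercises its gradient half only if the class defines a `backward`.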

In [20]:
add_relu(x,y)

RuntimeError: 0 active drivers ([]). There should only be one.
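For the benchmark, `triton.testing` provides `perf_report`, `Benchmark` and `do_bench`. A minimal sketch, assuming square fp32 `N x N` inputs on a CUDA device and reporting effective bandwidth as if the fused op reads `x` and `y` once and writes the output once (eager PyTorch also writes and re-reads the intermediate `x + y`, which is exactly the traffic fusion removes). `make_add_relu_benchmark` and the provider labels are hypothetical; running the report needs a GPU plus `pandas` and `matplotlib`.

```python
import torch
import triton
import triton.testing


def make_add_relu_benchmark(providers, device="cuda"):
    """Build a triton.testing benchmark over square 2D fp32 inputs.

    `providers` maps a line label to a callable fn(x, y) -> relu(x + y).
    """
    names = list(providers)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["N"],  # inputs are N x N
            x_vals=[2**i for i in range(7, 14)],
            x_log=True,
            line_arg="provider",
            line_vals=names,
            line_names=names,
            ylabel="GB/s",
            plot_name="add-relu-2d",
            args={},
        )
    )
    def benchmark(N, provider):
        x = torch.randn(N, N, device=device, dtype=torch.float32)
        y = torch.randn(N, N, device=device, dtype=torch.float32)
        fn = providers[provider]
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fn(x, y), quantiles=[0.5, 0.2, 0.8]
        )
        # Effective bandwidth: read x and y once, write the output once
        gbps = lambda t: 3 * x.numel() * x.element_size() * 1e-9 / (t * 1e-3)
        return gbps(ms), gbps(max_ms), gbps(min_ms)

    return benchmark


providers = {
    "torch": lambda x, y: torch.relu(x + y),
    # add the Triton wrapper and the scripted variant here, e.g. "triton": add_relu
}
if torch.cuda.is_available():
    make_add_relu_benchmark(providers).run(print_data=True, show_plots=False)
```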