# Triton 教程（2.3.1）dropout 章节的课后作业
- 目标：在教程中一维 dropout 的基础上，实现一个二维（每个 program 处理一行）的 dropout kernel

## 整个 kernel 分为四大部分
- 1、定义pid
- 2、定义地址以及offset
- 3、输入加载以及算法主体
- 4、结果store

In [3]:
import torch

import triton
import triton.language as tl

In [4]:
@triton.jit
def dropout_2D_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, p, seed, n_cols, BLOCK_SIZE: tl.constexpr):
    """Apply dropout to one row of a 2D tensor per program instance.

    ``tl.program_id(axis=0)`` is used as the row index, so the kernel is meant
    to be launched with one program per row (launcher not shown here).
    Each element is kept when its random draw is greater than ``p`` and is then
    scaled by ``1 / (1 - p)``; dropped elements are written as 0.

    Args:
        output_ptr: Pointer to the output tensor.
        input_ptr: Pointer to the input tensor.
        input_row_stride: Number of elements between consecutive input rows.
        output_row_stride: Number of elements between consecutive output rows.
        p: Drop probability.
        seed: Scalar seed for ``tl.rand``; the same seed reproduces the same mask.
        n_cols: Number of valid columns in each row.
        BLOCK_SIZE: Compile-time block width. Only columns in
            ``[0, BLOCK_SIZE)`` are processed, so it must be >= ``n_cols``
            (presumably the next power of two; the launcher is not shown).

    Note:
        Columns are addressed as ``row_ptr + col_offsets``, i.e. this assumes a
        unit column stride (contiguous rows) for both input and output.
    """
    # 1) Program id -> row: the row stride is how far a pointer must advance to
    #    move down one row.
    row_idx = tl.program_id(axis=0)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_start_ptr = output_ptr + row_idx * output_row_stride

    # 2) Column offsets for the whole row in a single block; mask off the tail
    #    when BLOCK_SIZE > n_cols.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < n_cols

    # 3) Load. Masked-out lanes are never stored, so 0.0 is just a neutral filler.
    x = tl.load(row_start_ptr + col_offsets, mask=col_mask, other=0.0)

    # Offset the RNG by each element's flat (row-major) index. Using only
    # col_offsets would make every row draw identical numbers, so all rows
    # would share one dropout mask (whole columns kept/dropped together).
    rand_offsets = row_idx * n_cols + col_offsets
    rand_vals = tl.rand(seed, rand_offsets)
    x_keep = rand_vals > p
    # Scale survivors by 1/(1-p) so the expected value matches the input.
    output = tl.where(x_keep, x / (1 - p), 0.0)

    # 4) Store only the valid columns of this row.
    tl.store(output_start_ptr + col_offsets, output, mask=col_mask)