In [4]:
import torch
import torch.nn as nn

class MyDropout(nn.Module):
    """Hand-written inverted dropout layer for demonstration purposes.

    While the module is in training mode, each element of the input is
    zeroed independently with probability ``p`` and the result is divided
    by ``1 - p``, so the expected value of every element is unchanged.
    In eval mode, or when ``p == 0``, the input tensor is returned as-is.

    The training-path ``forward`` prints ``p``, the keep probability and
    the sampled mask to stdout so the mechanics can be inspected.

    Args:
        p: Probability of dropping an element; must satisfy ``0 <= p < 1``.

    Raises:
        ValueError: If ``p`` lies outside ``[0, 1)``.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()
        # p == 1 is rejected because forward divides by (1 - p).
        p_is_valid = 0.0 <= p < 1.0
        if not p_is_valid:
            raise ValueError("p must be in [0, 1).")
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout to ``x`` and return a tensor of the same shape."""
        # Identity paths: eval mode, or nothing to drop.
        if not self.training:
            return x
        if self.p == 0.0:
            return x

        print("p_dropout:", self.p)

        retain_prob = 1.0 - self.p
        print("keep_prob:", retain_prob)

        # Bernoulli mask: 1 means keep the element, 0 means drop it.
        noise = torch.rand_like(x)
        keep_mask = noise.lt(retain_prob).to(dtype=x.dtype)
        print("mask:", keep_mask)

        # Elementwise product: the mask has the same shape as x, so the
        # shape of the result is unchanged.
        masked = x * keep_mask

        # Inverted dropout: divide by the keep probability so the
        # expectation matches the un-dropped input.
        return masked / retain_prob


# demo: run the layer on an all-ones tensor, first in training mode, then in eval mode
x = torch.ones(2, 5)
drop = MyDropout(p=0.1)

drop.train()
print("train:", drop(x))  # dropped entries become 0; kept entries are scaled up to 1/0.9 ≈ 1.1111

drop.eval()
print("eval :", drop(x))  # all ones: in eval mode the input is returned unchanged


p_dropout: 0.1
keep_prob: 0.9
mask: tensor([[1., 0., 0., 1., 1.],
        [1., 1., 1., 1., 1.]])
train: tensor([[1.1111, 0.0000, 0.0000, 1.1111, 1.1111],
        [1.1111, 1.1111, 1.1111, 1.1111, 1.1111]])
eval : tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
