In [None]:
import torch
import torch.nn.functional as F

def softmax(v, dim=None):
    """Numerically stable softmax.

    Args:
        v: Input tensor of logits.
        dim: Dimension to normalise over. ``None`` (the default) normalises
            over *all* elements of ``v``, so the whole output sums to 1.

    Returns:
        A new tensor with the same shape as ``v``; ``v`` itself is not modified.
    """
    if v.numel() == 0:
        # Nothing to normalise; max() would raise on an empty tensor.
        return torch.exp(v)

    # Softmax is invariant to adding a constant to every logit, so subtracting
    # the maximum keeps exp() from overflowing to inf (which would give NaN).
    # The shift is done out-of-place so the caller's tensor is left untouched.
    if dim is None:
        exp_v = torch.exp(v - v.max())
        return exp_v / exp_v.sum()

    exp_v = torch.exp(v - v.max(dim=dim, keepdim=True)[0])
    return exp_v / exp_v.sum(dim=dim, keepdim=True)

def sigmoid(v):
    """Element-wise logistic function, 1 / (1 + e^(-v))."""
    denominator = torch.exp(-v) + 1
    return 1 / denominator

def silu(v):
    """SiLU activation: the input multiplied element-wise by its own sigmoid."""
    gate = sigmoid(v)
    return v * gate

In [None]:
# Compare each hand-written activation with its PyTorch reference on a random input.
v = torch.randn([1, 10])
comparisons = {
    # v has a single row, so normalising over all elements equals dim=1.
    "softmax": (softmax(v), F.softmax(v, dim=1)),
    # torch.sigmoid is the canonical spelling; F.sigmoid has emitted a
    # deprecation warning in some PyTorch releases.
    "sigmoid": (sigmoid(v), torch.sigmoid(v)),
    # F.silu is a current, supported API (it is not deprecated).
    "silu": (silu(v), F.silu(v)),
}
for name, (ours, reference) in comparisons.items():
    print(name, ours, reference, sep="\n")
    # Fail loudly instead of relying on eyeballing the printed tensors.
    assert torch.allclose(ours, reference, atol=1e-6), f"{name} deviates from the PyTorch reference"

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None, is_causal=False):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(D) + mask) @ v.

    Args:
        q: Query tensor of shape (..., L, D).
        k: Key tensor of shape (..., S, D).
        v: Value tensor of shape (..., S, Dv).
        mask: Optional mask broadcastable to (..., L, S). A boolean mask keeps
            positions that are True and blocks positions that are False; any
            other dtype is treated as an additive bias on the scores.
        is_causal: If True, query position i may only attend to key positions
            j <= i. Cannot be combined with ``mask``.

    Returns:
        Tuple ``(output, attn_weights)`` with shapes (..., L, Dv) and (..., L, S).
        A row whose positions are all masked out yields NaN weights, because
        softmax over all -inf is undefined.
    """
    # Read sizes from the trailing dims so extra leading (batch/head) dims work.
    L, S, D = q.size(-2), k.size(-2), q.size(-1)
    attn_score = q @ k.transpose(-2, -1) / (D ** 0.5)

    if is_causal:
        assert mask is None
        # Lower triangle (incl. diagonal) marks the positions that stay visible;
        # everything above the diagonal is a "future" key and must be blocked.
        visible = torch.ones(L, S, dtype=torch.bool, device=q.device).tril(diagonal=0)
        attn_score = attn_score.masked_fill(visible.logical_not(), float("-inf"))

    if mask is not None:
        if mask.dtype == torch.bool:
            attn_score = attn_score.masked_fill(mask.logical_not(), float("-inf"))
        else:
            # Additive mask: must actually be applied to the scores before softmax.
            attn_score = attn_score + mask

    attn_score = F.softmax(attn_score, dim=-1)
    output = attn_score @ v
    return output, attn_score

In [None]:
class myBN(torch.nn.Module):
    """Minimal batch normalisation layer with learnable scale and shift.

    Channels are expected on dimension 1, i.e. inputs of shape (N, C) or
    (N, C, *). Statistics are taken over every dimension except the channel
    one (for a plain (N, C) input that is just the batch dimension, dim=0;
    layer norm would instead reduce over the trailing feature dimension).

    Args:
        num_features: Number of channels C.
        eps: Constant added to the variance for numerical stability.
        momentum: Weight given to the current batch when updating the running
            statistics.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        # Module.__init__ must run first: it sets up the parameter/buffer
        # registries that nn.Parameter assignment and register_buffer rely on.
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        self.gamma = torch.nn.Parameter(torch.ones(num_features))
        self.beta = torch.nn.Parameter(torch.zeros(num_features))

        # Buffers are saved in state_dict and moved by .to(), but not trained.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x, train=True):
        """Normalise ``x`` per channel.

        Args:
            x: Tensor of shape (N, C) or (N, C, *) with C == num_features.
            train: If True, normalise with the statistics of ``x`` and update
                the running statistics; otherwise use the running statistics.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        if x.dim() < 2:
            raise ValueError(f"expected input with at least 2 dims (N, C, ...), got {x.dim()}")

        # Shape that broadcasts a per-channel vector against x: (1, C, 1, ...).
        stat_shape = [1, self.num_features] + [1] * (x.dim() - 2)

        if train:
            reduce_dims = [d for d in range(x.dim()) if d != 1]
            mean = x.mean(dim=reduce_dims)
            # The batch itself is normalised with the biased variance.
            var = x.var(dim=reduce_dims, unbiased=False)

            # Running stats are bookkeeping, not part of the autograd graph;
            # updating them under no_grad avoids retaining every batch's graph.
            with torch.no_grad():
                n = x.numel() // x.size(1)
                # The running variance tracks the unbiased estimate; guard n == 1
                # so a single sample does not poison the buffer with NaN/inf.
                unbiased_var = var * (n / max(n - 1, 1))
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean = self.running_mean
            var = self.running_var

        x_ = (x - mean.view(stat_shape)) / torch.sqrt(var.view(stat_shape) + self.eps)
        x_ = self.gamma.view(stat_shape) * x_ + self.beta.view(stat_shape)
        return x_

In [1]:
import torch
# Random standard-normal sample tensor of shape (4, 2, 3).
x = torch.randn(4, 2, 3)

In [4]:
# Square root of the variance taken over every element of x.
torch.sqrt(x.var())

tensor(0.9568)