In [None]:
import torch
import torch.nn.functional as F

def softmax(v, dim=None):
    """Numerically stable softmax.

    Args:
        v: input tensor (not modified).
        dim: dimension to normalize over. ``None`` (the default, matching the
            original behaviour) normalizes over *all* elements of ``v``.

    Returns:
        Tensor shaped like ``v`` whose entries are non-negative and sum to 1
        over ``dim`` (or over the whole tensor when ``dim`` is None).
    """
    if v.numel() == 0:
        # max() of an empty tensor raises; an empty input has an empty softmax.
        return torch.exp(v)
    if dim is None:
        # Subtracting the max cancels in the ratio but keeps exp() from
        # overflowing to inf (inf / inf = NaN). Done out-of-place so the
        # caller's tensor is not mutated (the old `v -= v.max()` would have).
        exps = torch.exp(v - v.max())
        return exps / exps.sum()
    exps = torch.exp(v - v.amax(dim=dim, keepdim=True))
    return exps / exps.sum(dim=dim, keepdim=True)

def sigmoid(v):
    """Element-wise logistic function: 1 / (1 + exp(-v))."""
    denominator = 1 + torch.exp(-v)
    return 1 / denominator

def silu(v):
    """Element-wise SiLU (a.k.a. swish): v gated by sigmoid(v)."""
    gate = sigmoid(v)
    return gate * v

In [None]:
# Check each hand-written activation against PyTorch's implementation.
torch.manual_seed(0)  # fixed seed so the printed values are reproducible
v = torch.randn([1, 10])
print(softmax(v), F.softmax(v, dim=1))
print(sigmoid(v), torch.sigmoid(v))  # F.sigmoid is deprecated; torch.sigmoid replaces it
print(silu(v), F.silu(v))  # F.silu is NOT deprecated
assert torch.allclose(softmax(v), F.softmax(v, dim=1))
assert torch.allclose(sigmoid(v), torch.sigmoid(v))
assert torch.allclose(silu(v), F.silu(v))

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None, is_causal=False):
    """Scaled dot-product attention returning both output and weights.

    Computes softmax(q @ k^T / sqrt(E) + bias) @ v.

    Args:
        q: query tensor, (..., L, E).
        k: key tensor, (..., S, E).
        v: value tensor, (..., S, Ev).
        mask: optional mask broadcastable to (..., L, S). A bool mask marks
            positions that may be attended (True = keep, False -> -inf);
            any other dtype is added to the logits as an additive bias.
        is_causal: if True, query i may only attend to keys j <= i.
            Mutually exclusive with ``mask``.

    Returns:
        (output, attn_weights): output is (..., L, Ev); attn_weights is the
        post-softmax (..., L, S) matrix. A query row whose keys are all
        masked yields NaN (softmax over all -inf).
    """
    L, S, E = q.size(-2), k.size(-2), q.size(-1)
    # Bias is queries x keys, i.e. (L, S); allocate on q's device so the
    # function also works for non-CPU inputs.
    attn_bias = torch.zeros(L, S, dtype=q.dtype, device=q.device)
    if is_causal:
        assert mask is None
        allowed = torch.ones(L, S, dtype=torch.bool, device=q.device).tril(diagonal=0)
        # Block the strictly-upper triangle (future keys), not the allowed one.
        attn_bias = attn_bias.masked_fill(~allowed, float("-inf"))

    if mask is not None:
        if mask.dtype == torch.bool:
            # Fold the bool mask into the bias; out-of-place masked_fill
            # broadcasts, so a batched (..., L, S) mask works too.
            attn_bias = attn_bias.masked_fill(~mask, float("-inf"))
        else:
            # Out-of-place so a batched float mask can broadcast the bias up.
            attn_bias = attn_bias + mask

    attn_score = q @ k.transpose(-2, -1) / (E ** 0.5)
    attn_score = attn_score + attn_bias
    attn_score = torch.softmax(attn_score, dim=-1)
    output = attn_score @ v
    return output, attn_score

In [None]:
class myBN(torch.nn.Module):
    """Minimal from-scratch batch normalization (per channel, like nn.BatchNorm{1,2,3}d).

    Expects input of shape (B, C, *spatial) with C == num_features. Statistics
    are computed per channel over the batch and all spatial dims.

    Notes on related norms (they differ only in which dims are reduced):
      * layer norm: over the feature dims of each sample;
      * instance norm: over the spatial dims (2, 3 for (B, C, H, W));
      * group norm: reshape to (B, G, C//G, H, W) and reduce over dims 2, 3, 4.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        # Subclassing nn.Module is required: register_buffer and Parameter
        # registration come from it (a plain class raises AttributeError).
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable per-channel affine transform.
        self.gamma = torch.nn.Parameter(torch.ones(num_features))
        self.beta = torch.nn.Parameter(torch.zeros(num_features))

        # Running statistics used when train=False.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x, train=True):
        """Normalize ``x`` of shape (B, C, *spatial).

        Args:
            x: input tensor; dim 1 must have size ``num_features``.
            train: if True, normalize with batch statistics and update the
                running estimates; otherwise use the running estimates.
                (This flag, not ``self.training``, selects the mode.)

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Reduce over every dim except the channel dim 1.
        reduce_dims = (0,) + tuple(range(2, x.dim()))
        # Shape that broadcasts a per-channel (C,) vector against x.
        bshape = (1, self.num_features) + (1,) * (x.dim() - 2)

        if train:
            mean = x.mean(dim=reduce_dims)
            # Normalization uses the biased (population) variance ...
            var = x.var(dim=reduce_dims, unbiased=False)
            n = x.numel() // self.num_features  # samples per channel
            with torch.no_grad():
                # ... while the running estimate tracks the unbiased variance,
                # matching torch.nn.BatchNorm. No autograd history is kept.
                unbiased_var = var * n / max(n - 1, 1)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (x - mean.view(bshape)) / torch.sqrt(var.view(bshape) + self.eps)
        return self.gamma.view(bshape) * x_hat + self.beta.view(bshape)

In [None]:
import torch
import torch.nn as nn
# Example 1: target given as class indices in [0, C).
loss = nn.CrossEntropyLoss()
logits = torch.randn(3, 5, requires_grad=True)  # (N=3 samples, C=5 classes); avoids shadowing builtin `input`
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(logits, target)
output.backward()

# Example 2: target given as per-class probabilities (each row sums to 1).
logits_prob = torch.randn(3, 5, requires_grad=True)
target_prob = torch.randn(3, 5).softmax(dim=1)
output_prob = loss(logits_prob, target_prob)
output_prob.backward()

In [None]:
import torch
import numpy as np
# Continuous relative-coordinate table for a (Wh, Ww) attention window,
# log-spaced so that small offsets get finer resolution than large ones.
win_h, win_w = 3, 3  # window size (Wh, Ww)
TABLE_SCALE = 4  # coordinates are normalized to [-1, 1], then scaled to [-TABLE_SCALE, TABLE_SCALE]
relative_coords_h = torch.arange(-(win_h - 1), win_h, dtype=torch.float32)
relative_coords_w = torch.arange(-(win_w - 1), win_w, dtype=torch.float32)
# indexing="ij" matches the historical default; passing it explicitly silences
# torch.meshgrid's warning and makes channel 0 = h offset, channel 1 = w offset.
relative_coords_table = torch.stack(
    torch.meshgrid(relative_coords_h, relative_coords_w, indexing="ij")
).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
# Normalize each axis to [-1, 1]; max(..., 1) avoids 0/0 for a 1-wide window.
relative_coords_table[:, :, :, 0] /= max(win_h - 1, 1)
relative_coords_table[:, :, :, 1] /= max(win_w - 1, 1)
relative_coords_table *= TABLE_SCALE  # range is now [-TABLE_SCALE, TABLE_SCALE] (previously mislabeled "-8, 8")
# Signed log encoding: sign(x) * log2(|x| + 1) / log2(8). The divisor is kept at
# log2(8) as before; with TABLE_SCALE = 4 outputs lie within +-log2(5)/3 ~= 0.774.
# NOTE(review): confirm whether TABLE_SCALE = 8 (output range ~ +-1) was intended.
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
    torch.abs(relative_coords_table) + 1.0) / np.log2(8)

In [None]:
# Encoded (h, w) pair at grid cell (2, 3); for a 3x3 window this is the raw offset (0, +1).
relative_coords_table[0, 2, 3, :]

In [None]:
tmp = torch.tensor([-1.0, 1.0])
# Signed log encoding sign(x) * log2(|x| + 1) / log2(8) applied to -1 and +1.
tmp.sign() * (tmp.abs() + 1.0).log2() / np.log2(8)

In [None]:
tmp = torch.tensor([1.0, -1.0])
# Same encoding with the inputs swapped: the transform is odd (f(-x) = -f(x)),
# so the outputs swap sign positions too.
tmp.sign() * (tmp.abs() + 1.0).log2() / np.log2(8)

In [None]:
# Relative position index: maps every (query, key) token pair inside a
# (Wh, Ww) window to a unique integer in [0, (2*Wh-1)*(2*Ww-1)).
win_h, win_w = 3, 3  # window size (Wh, Ww)
coords_h = torch.arange(win_h)
coords_w = torch.arange(win_w)
# indexing="ij" matches the historical default and silences torch.meshgrid's warning.
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_h - 1  # shift dh from [-(Wh-1), Wh-1] to [0, 2*Wh-2]
relative_coords[:, :, 1] += win_w - 1  # shift dw from [-(Ww-1), Ww-1] to [0, 2*Ww-2]
relative_coords[:, :, 0] *= 2 * win_w - 1  # row-major flatten: dh * (2*Ww-1) + dw
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

In [None]:
import math
import torch
def scale_dot_product_attention(query, key, value, attn_mask=None, is_causal=False, dropout_p=0.0, scale=None):
    """Reference scaled dot-product attention: dropout(softmax(Q K^T * scale + bias)) V.

    Args:
        query: (..., L, E) tensor.
        key: (..., S, E) tensor.
        value: (..., S, Ev) tensor.
        attn_mask: optional mask broadcastable to (..., L, S). A bool mask
            marks positions that may be attended (True = keep, False -> -inf);
            any other dtype is added to the attention logits.
        is_causal: if True, query i may only attend to keys j <= i.
            Mutually exclusive with ``attn_mask``.
        dropout_p: dropout probability on the attention weights. Dropout is
            always applied (``train=True``), so pass 0.0 for inference.
        scale: logit scale; defaults to 1 / sqrt(E).

    Returns:
        (..., L, Ev) tensor.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Write -inf into the bias (not into the caller's mask), using the
            # real `logical_not`; out-of-place so a batched mask broadcasts.
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            # Out-of-place so a batched float mask can broadcast the bias up.
            attn_bias = attn_bias + attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value