In [None]:
import torch
from torch import nn

# MaxPool

In [None]:
# 1-D max pooling: window of 4 elements, advancing 2 elements per step.
mp = nn.MaxPool1d(kernel_size=4, stride=2)

In [None]:
# The values 1..10 as a single float row, i.e. a tensor of shape (1, 10).
x = torch.arange(1, 11, dtype=torch.float).unsqueeze(0)
mp(x)

# AsymBlock

In [None]:
class AsymBlock(nn.Module):
    """Cross-attention block that maps an input sequence onto learned queries.

    The input is layer-normalised and projected to keys and values, while the
    queries are a learned parameter shared by every batch element. The
    attention output is layer-normalised again and passed through a single
    linear layer, so the output always has ``out_tokens`` positions.

    Args:
        config: Object exposing ``n_embd`` (embedding width) and ``n_head``
            (number of attention heads).
        out_tokens: Number of learned query vectors, which is also the length
            of the output sequence.
    """

    def __init__(self, config, out_tokens):
        super().__init__()
        self.key = nn.Linear(config.n_embd, config.n_embd)
        # Learned queries with a singleton batch dimension; forward() broadcasts
        # them across the actual batch.
        self.query = nn.Parameter(torch.rand(1, out_tokens, config.n_embd))
        self.value = nn.Linear(config.n_embd, config.n_embd)

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attention = nn.MultiheadAttention(config.n_embd, config.n_head, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, config.n_embd),
        )

    def forward(self, x):
        """Attend from the learned queries to the input sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, n_embd)``; the attention
                layer is built with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, out_tokens, n_embd)``.
        """
        x = self.ln1(x)
        key = self.key(x)
        value = self.value(x)
        # expand() broadcasts the shared queries over the batch as a view,
        # whereas repeat() would allocate one copy of them per batch element.
        query = self.query.expand(x.shape[0], -1, -1)
        # The attention weights are never used, so do not ask for them.
        attn_output, _ = self.attention(query, key, value, need_weights=False)
        return self.mlp(self.ln2(attn_output))

# nn.MultiheadAttention

In [None]:
def main():
    """Run one attention call on all-ones inputs and print the result.

    The attribute listing of the attention module (``dir``) is printed first,
    followed by the attention output for 2 query positions attending over
    4 key/value positions.
    """
    embed_dim, num_heads = 16, 4
    mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    # batch_first=True, so every tensor is laid out as (batch, sequence, embedding).
    q = torch.ones(1, 2, embed_dim)
    k = torch.ones(1, 4, embed_dim)
    v = torch.ones(1, 4, embed_dim)
    print(dir(mha))
    output, _ = mha(q, k, v)
    print(output)


main()

In [None]:
def init_weights(m):
    """Fill an ``nn.Linear`` with constant parameters; ignore any other module.

    Weights are set to 1.0 and the bias, when the layer has one, to 0.1.
    The single-module signature makes this usable as a callback for
    ``nn.Module.apply``.

    Args:
        m: The module to (possibly) re-initialise in place.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.constant_(m.weight, 1.0)
    # Layers created with bias=False carry bias=None, which constant_ cannot fill.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.1)

In [None]:
def main():
    """Apply ``init_weights`` to a small attention module and print ``in_proj_weight``."""
    mha = nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)
    mha.apply(init_weights)
    # NOTE(review): in_proj_weight is read straight off the attention module
    # rather than off an nn.Linear child, so init_weights presumably leaves it
    # untouched -- confirm against the printed values.
    print(mha.in_proj_weight)
    # Follow-up experiment, currently disabled:
    # query = torch.ones(1, 2, 4)
    # key = torch.ones(1, 2, 4)
    # value = torch.ones(1, 2, 4)
    # att, att_weights = attention(query, key, value)
    # print(att)


main()

# Tensor.repeat vs. torch.repeat_interleave

In [67]:
def main():
    """Show ``Tensor.repeat`` tiling a (1, 2, 3) tensor twice along its last dimension.

    Prints the source tensor and its shape, then the tiled tensor and its shape.
    """
    base = torch.arange(6).view(1, 2, 3)
    # repeat() copies the whole block: [0, 1, 2] becomes [0, 1, 2, 0, 1, 2].
    tiled = base.repeat(1, 1, 2)
    for tensor in (base, tiled):
        print(tensor)
        print(tensor.shape)


main()

tensor([[[0, 1, 2],
         [3, 4, 5]]])
torch.Size([1, 2, 3])
tensor([[[0, 1, 2, 0, 1, 2],
         [3, 4, 5, 3, 4, 5]]])
torch.Size([1, 2, 6])


In [68]:
def main():
    """Show ``torch.repeat_interleave`` duplicating each element along the last dimension.

    Prints the source tensor and its shape, then the interleaved tensor and its shape.
    """
    base = torch.arange(6).view(1, 2, 3)
    # Each element is repeated in place: [0, 1, 2] becomes [0, 0, 1, 1, 2, 2].
    interleaved = torch.repeat_interleave(base, repeats=2, dim=2)
    for tensor in (base, interleaved):
        print(tensor)
        print(tensor.shape)


main()

tensor([[[0, 1, 2],
         [3, 4, 5]]])
torch.Size([1, 2, 3])
tensor([[[0, 0, 1, 1, 2, 2],
         [3, 3, 4, 4, 5, 5]]])
torch.Size([1, 2, 6])
