In [2]:
import sys
# Append the parent directory to the import search path, presumably so the
# local `modular_unet` package imported in a later cell can be found when the
# notebook is run from its own subdirectory — TODO confirm repo layout.
sys.path.append('..')

In [157]:
import torch
from torch import nn, Tensor
from torch.nn import functional as F

In [158]:
from fastcore.meta import delegates

In [159]:
from modular_unet.blocks import ConvLayer
from modular_unet.utils import test_forward

In [162]:
def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    Any number of leading batch dimensions is accepted (``matmul`` broadcasts
    them), so plain 2-D ``(seq, dim)``, 3-D ``(batch, seq, dim)`` and 4-D
    ``(batch, heads, seq, dim)`` inputs all work.

    Args:
        query: tensor of shape ``(..., seq_q, d_k)``.
        key: tensor of shape ``(..., seq_k, d_k)``.
        value: tensor of shape ``(..., seq_k, d_v)``.

    Returns:
        Tensor of shape ``(..., seq_q, d_v)``. Each query position gets a
        softmax-weighted average of the rows of ``value``.
    """
    # Swap the last two axes (not axes 1 and 2) so the key is transposed
    # correctly no matter how many batch dimensions precede them.
    scores = query.matmul(key.transpose(-2, -1))
    # Standard 1/sqrt(d_k) scaling of the attention logits.
    scores = scores / query.size(-1) ** 0.5
    weights = F.softmax(scores, dim=-1)
    return weights.matmul(value)

In [163]:
class AttentionHead(nn.Module):
    """One attention head: learned linear projections, then scaled dot-product attention.

    Args:
        dim_in: feature size (last dimension) of the incoming query/key/value.
        dim_k: size that query and key are projected to.
        dim_v: size that value is projected to; also the head's output feature size.
    """

    def __init__(self, dim_in: int, dim_k: int, dim_v: int):
        super().__init__()
        # Registration order q -> k -> v is preserved: it fixes both the
        # state_dict key order and the order parameters draw from the RNG.
        self.q = nn.Linear(in_features=dim_in, out_features=dim_k)
        self.k = nn.Linear(in_features=dim_in, out_features=dim_k)
        self.v = nn.Linear(in_features=dim_in, out_features=dim_v)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        """Project each input with its own layer and attend; output last dim is ``dim_v``."""
        projected_query = self.q(query)
        projected_key = self.k(key)
        projected_value = self.v(value)
        return scaled_dot_product_attention(projected_query, projected_key, projected_value)

In [164]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention assembled from independent ``AttentionHead`` modules.

    Every head attends over the same inputs. Their outputs are concatenated
    along the last dimension and linearly projected back to ``dim_in`` features.

    Args:
        num_heads: number of independent heads.
        dim_in: feature size of the inputs and of the final output.
        dim_k: query/key projection size used inside each head.
        dim_v: value projection size of each head; the concatenation therefore
            has ``num_heads * dim_v`` features.
    """

    def __init__(self, num_heads: int, dim_in: int, dim_k: int, dim_v: int):
        super().__init__()
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(AttentionHead(dim_in, dim_k, dim_v))
        self.heads = nn.ModuleList(head_modules)
        self.linear = nn.Linear(num_heads * dim_v, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        """Run all heads on the same inputs, concatenate them, and project to ``dim_in``."""
        per_head_outputs = [head(query, key, value) for head in self.heads]
        concatenated = torch.cat(per_head_outputs, dim=-1)
        return self.linear(concatenated)

In [168]:
# Seed the global RNG so these inputs, and the randomly initialised weights of
# the module built in the next cell, are reproducible on Restart & Run All.
torch.manual_seed(0)

# Random query/key/value inputs of shape (batch=32, seq_len=3, features=64).
# The feature size must match the `dim_in` passed to MultiHeadAttention.
m1 = torch.randn(32, 3, 64)
m2 = torch.randn(32, 3, 64)
m3 = torch.randn(32, 3, 64)

In [169]:
# Smoke test: 6 heads over 64-feature inputs. The output projection maps back
# to dim_in, so the result must have the same shape as the query.
mha = MultiHeadAttention(num_heads=6, dim_in=64, dim_k=64, dim_v=64)
mha_out = mha(m1, m2, m3)
assert mha_out.shape == m1.shape, f"unexpected output shape {tuple(mha_out.shape)}"
mha_out

tensor([[[-0.1360,  0.1602, -0.0319,  ...,  0.1782, -0.3392,  0.2051],
         [-0.1027,  0.0974, -0.0130,  ...,  0.2100, -0.3227,  0.1479],
         [-0.0503,  0.0691, -0.0770,  ...,  0.3561, -0.1989,  0.1192]],

        [[-0.1580, -0.2195,  0.0673,  ..., -0.0643, -0.4210,  0.0279],
         [-0.1611, -0.2485,  0.0882,  ..., -0.0353, -0.4525, -0.0597],
         [-0.1083, -0.1868,  0.0583,  ...,  0.0051, -0.4212, -0.0304]],

        [[-0.1810, -0.0170, -0.0697,  ..., -0.0469,  0.1518,  0.1783],
         [-0.0625, -0.0267, -0.2017,  ...,  0.0484,  0.1649,  0.0388],
         [-0.0907,  0.1069, -0.1266,  ..., -0.0737,  0.1155,  0.0969]],

        ...,

        [[ 0.0271,  0.0183, -0.1683,  ..., -0.1653, -0.3830,  0.4774],
         [-0.1080, -0.0171, -0.1254,  ..., -0.0427, -0.3177,  0.3580],
         [-0.0766, -0.0913, -0.1727,  ..., -0.0790, -0.3399,  0.3344]],

        [[-0.2523, -0.1449,  0.0586,  ...,  0.1386,  0.0855, -0.0376],
         [-0.2535, -0.0744,  0.0601,  ...,  0.0065,  0.