In [3]:
import torch
from torch import nn
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [75]:
def pos_encoding(ntours, embed_size):
    t = torch.arange(0, ntours, 1, dtype=dtype).unsqueeze(1)
    fi = torch.exp(
        -torch.arange(0, embed_size, 2) * torch.log(torch.tensor(10_000)) / embed_size
    )

    pe = torch.zeros(ntours, embed_size)
    pe[:, 0::2] = torch.sin(fi*t)
    pe[:, 1::2] = torch.cos(fi*t)
    return pe



In [272]:
class TSP(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int,
        nnodes:int,
        xxx:int,
    ) -> None:
        super().__init__()
    
        # ENCODER
        self.embed = nn.Linear(2, embed_size)
        self.encoder = Encoder(nheads=nheads, embed_size=embed_size, ff_size=ff_size, nlayers=nlayers)

        # DECODER
        self.decoder = Decoder()
        self.Wk = nn.Linear(embed_size, xxx * embed_size, bias=False)
        self.Wv = nn.Linear(embed_size, xxx * embed_size, bias=False)

    def forward(self, x):
        h = self.embed(x)
        h = self.encoder(h)

        K, V = self.Wk(h), self.Wv(h)


In [664]:
def pos_encoding(ntours, embed_size):
    t = torch.arange(0, ntours, 1, dtype=dtype).unsqueeze(1)
    fi = torch.exp(
        -torch.arange(0, embed_size, 2) * torch.log(torch.tensor(10_000)) / embed_size
    )

    pe = torch.zeros(ntours, embed_size)
    pe[:, 0::2] = torch.sin(fi*t)
    pe[:, 1::2] = torch.cos(fi*t)
    return pe

class Encoder(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int  
    ) -> None:
        super(Encoder, self).__init__()
        
        assert not embed_size % nheads, "The embedding size has to be a multiple of the number of heads."

        self.nlayers = nlayers

        self.mha = nn.ModuleList(nn.MultiheadAttention(embed_size, nheads, batch_first=True, bias=False) for _ in range(self.nlayers))

        self.ff1 = nn.ModuleList(nn.Linear(embed_size, ff_size) for _ in range(self.nlayers))
        self.ff2 = nn.ModuleList(nn.Linear(ff_size, embed_size) for _ in range(self.nlayers))

        self.bn = nn.ModuleList(nn.BatchNorm1d(embed_size) for _ in range(self.nlayers))
        
    def forward(self, h):
        hprev = torch.empty_like(h)
        for i in range(self.nlayers):
            hprev.copy_(h)
            h, _ = self.mha[i](h,h,h)
            h = h + hprev

            h = h.permute(0,2,1).contiguous()
            h = self.bn[i](h)
            h = h.permute(0,2,1).contiguous()

            hprev.copy_(h)
            h = self.ff2[i](torch.relu(self.ff1[i](h)))
            h = h + hprev
            
            h = h.permute(0,2,1).contiguous()
            h = self.bn[i](h)
            h = h.permute(0,2,1).contiguous()
        return h

class MHA_(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int,
    ) -> None:
        super(MHA_, self).__init__()
        assert not embed_size % nheads, "The embedding size has to be a divisible by the number of heads."
        self.d_k = embed_size // nheads
        self.nheads = nheads
        self.embed_size = embed_size
        self.ff_size = ff_size
        self.nlayers = nlayers
    
    def _attention(self, query, key, value, mask=None, clip=None):
        attn = torch.bmm(query, key.transpose(1,2)) / self.d_k ** .5
        if mask is not None:
            if self.nheads > 1:
                mask = torch.repeat_interleave(mask, repeats=self.nheads, dim=0)

            mask = mask.unsqueeze(1)
            attn = attn.masked_fill(mask, float("-1e9"))
        
        if clip is not None:
            attn = clip * torch.tanh(attn)
        
        p_attn = attn.softmax(dim=-1)
        print(p_attn.size(), value.size())
        return torch.bmm(p_attn, value), p_attn

    def forward(self, query, key, value, mask=None, clip=None):
        nbatchs = query.size(0)
        if self.nheads > 1:
            query, key, value = [
                x.transpose(1,2).contiguous().view(nbatchs*self.nheads, self.d_k, nnd).transpose(1,2).contiguous()
                for nnd, x in zip((query.size(1), key.size(1), value.size(1)), (query, key, value))
            ]

        x, _ = self._attention(query, key, value, mask=mask)

        x = x.view(nbatchs, -1, self.d_k * self.nheads) if self.nheads > 1 else x
        return x

class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

class AutoregressiveDecoderLayer(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int,
    ) -> None:
        super().__init__()

        self.lin = nn.ModuleList(nn.Linear(embed_size, embed_size) for _ in range(11))
        self.bn = nn.ModuleList(nn.LayerNorm(embed_size) for _ in range(3))

        self.kprev = None
        self.vprev = None

        self.mha = MHA_(nheads, embed_size, ff_size, nlayers)

    def forward(self, ht, key, value, mask):
        nbatchs = ht.size(0)
        print("wwhat", ht.size())

        # STEP (2)
        q = self.lin[0](ht)
        k = self.lin[1](ht)
        v = self.lin[2](ht)

        if self.kprev is None:
            self.kprev = k
            self.vprev = v
        else:
            self.kprev = torch.cat([self.kprev, k], dim=1)
            self.vprev = torch.cat([self.vprev, v], dim=1)
        
        ht += self.lin[3](self.mha(q, self.kprev, self.vprev, mask=None))
        ht = self.bn[0](ht.squeeze()).view(nbatchs, 1, -1)

        # STEP (3)
        q = self.lin[4](ht)
        ht += self.lin[5](self.mha(q, key, value, mask))
        ht = self.bn[1](ht.squeeze()).view(nbatchs, 1, -1)

        # STEP (4)
        ht += self.lin[6](torch.relu(self.lin[7](ht)))
        ht = self.bn[2](ht.squeeze(1))
        return ht

nheads = 4
nnodes = 11
embed_size = 8
ff_size = 21
nlayers = 3
nbatchs = 3
xxx = 5

# enc = Encoder(nheads, nnodes, embed_size, ff_size, nlayers)
x = torch.randn(nbatchs, 1, embed_size)
k = torch.randn(nbatchs, nnodes, embed_size)
v = torch.randn(nbatchs, nnodes, embed_size)
mask = torch.ones((nbatchs, nnodes))

tsp = AutoregressiveDecoderLayer(nheads, embed_size, ff_size, nlayers)
print("1", tsp(x, k, v, mask))
# m = MHA_(nheads, embed_size, ff_size, nlayers, nnodes)

wwhat torch.Size([3, 1, 8])
torch.Size([12, 1, 1]) torch.Size([12, 1, 2])
torch.Size([12, 1, 11]) torch.Size([12, 11, 2])
1 tensor([[-0.3766, -0.9537,  0.8644,  0.4186, -0.9731, -1.4053,  0.9279,  1.4978],
        [ 2.0470,  0.2075, -1.1381,  0.1470,  0.1432,  0.3725, -0.2944, -1.4846],
        [ 0.7489, -1.1418, -1.4837,  1.5960,  0.6947,  0.5540, -0.7376, -0.2305]],
       grad_fn=<NativeLayerNormBackward0>)
2 tensor([[ 0.0747, -0.5848,  0.4539,  0.9453, -1.6970, -1.2422,  0.9055,  1.1445],
        [ 1.8162,  0.4295, -1.6982,  0.0605, -0.0894,  0.5995,  0.0051, -1.1233],
        [ 0.8660, -1.3136, -1.5051,  1.3808,  1.0478, -0.5029, -0.0140,  0.0411]],
       grad_fn=<NativeLayerNormBackward0>)


In [562]:
nheads = 8
nnodes = 10
embed_size = 128
ff_size = 21
nlayers = 3
nbatchs = 3
xxx = 5

# enc = Encoder(nheads, nnodes, embed_size, ff_size, nlayers)
q = torch.randn(nbatchs, 1, embed_size)
k = torch.randn(nbatchs, nnodes, embed_size)
v = torch.randn(nbatchs, nnodes, embed_size)
mask = torch.randn(nbatchs, nnodes)

# tsp = AutoregressiveDecoderLayer(nheads, embed_size, ff_size, nlayers)

m = MHA_(nheads, embed_size, ff_size, nheads, nnodes)
m(q, k, v, mask).size()

RuntimeError: Mask tensor can take 0 and 1 values only