In [3]:
import torch
from torch import nn
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [75]:
def pos_encoding(ntours, embed_size):
    t = torch.arange(0, ntours, 1, dtype=dtype).unsqueeze(1)
    fi = torch.exp(
        -torch.arange(0, embed_size, 2) * torch.log(torch.tensor(10_000)) / embed_size
    )

    pe = torch.zeros(ntours, embed_size)
    pe[:, 0::2] = torch.sin(fi*t)
    pe[:, 1::2] = torch.cos(fi*t)
    return pe



In [272]:
class TSP(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int,
        nnodes:int,
        xxx:int,
    ) -> None:
        super().__init__()
    
        # ENCODER
        self.embed = nn.Linear(2, embed_size)
        self.encoder = Encoder(nheads=nheads, embed_size=embed_size, ff_size=ff_size, nlayers=nlayers)

        # DECODER
        self.decoder = Decoder()
        self.Wk = nn.Linear(embed_size, xxx * embed_size, bias=False)
        self.Wv = nn.Linear(embed_size, xxx * embed_size, bias=False)

    def forward(self, x):
        h = self.embed(x)
        h = self.encoder(h)

        K, V = self.Wk(h), self.Wv(h)


In [394]:
def pos_encoding(ntours, embed_size):
    t = torch.arange(0, ntours, 1, dtype=dtype).unsqueeze(1)
    fi = torch.exp(
        -torch.arange(0, embed_size, 2) * torch.log(torch.tensor(10_000)) / embed_size
    )

    pe = torch.zeros(ntours, embed_size)
    pe[:, 0::2] = torch.sin(fi*t)
    pe[:, 1::2] = torch.cos(fi*t)
    return pe

class Encoder(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int  
    ) -> None:
        super(Encoder, self).__init__()
        
        assert not embed_size % nheads, "The embedding size has to be a multiple of the number of heads."

        self.nlayers = nlayers

        self.mha = nn.ModuleList(nn.MultiheadAttention(embed_size, nheads, batch_first=True, bias=False) for _ in range(self.nlayers))

        self.ff1 = nn.ModuleList(nn.Linear(embed_size, ff_size) for _ in range(self.nlayers))
        self.ff2 = nn.ModuleList(nn.Linear(ff_size, embed_size) for _ in range(self.nlayers))

        self.bn = nn.ModuleList(nn.BatchNorm1d(embed_size) for _ in range(self.nlayers))
        
    def forward(self, h):
        hprev = torch.empty_like(h)
        for i in range(self.nlayers):
            hprev.copy_(h)
            h, _ = self.mha[i](h,h,h)
            h = h + hprev

            h = h.permute(0,2,1).contiguous()
            h = self.bn[i](h)
            h = h.permute(0,2,1).contiguous()

            hprev.copy_(h)
            h = self.ff2[i](torch.relu(self.ff1[i](h)))
            h = h + hprev
            
            h = h.permute(0,2,1).contiguous()
            h = self.bn[i](h)
            h = h.permute(0,2,1).contiguous()
        return h

class MHA_(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int,
    ) -> None:
        super(MHA_, self).__init__()
        assert not embed_size % nheads, "The embedding size has to be a divisible by the number of heads."
        self.d_k = embed_size // nheads
        self.nheads = nheads
        self.embed_size = embed_size
        self.ff_size = ff_size
        self.nlayers = nlayers

        self.ff = nn.ModuleList(nn.Linear(embed_size, embed_size) for _ in range(nlayers))
    
    def _attention(self, query, key, value, mask=None, clip=None):
        attn = torch.matmul(query, key.transpose(-2,-1)) / query.size(-1) ** .5
        if mask is not None:
            if self.nheads > 1:
                mask = torch.repeat_interleave(mask, repeats=self.nheads, dim=0)
            mask = mask.unsqueeze(1)
            attn = attn.masked_fill(mask == 0, -1e-9)

        if clip is not None:
            attn = clip * torch.tanh(attn)
        
        p_attn = attn.softmax(dim=-1)
        return torch.matmul(p_attn, value), p_attn

    def forward(self, query, key, value, mask=None, clip=None):
        nbatchs = query.size(0)
        
        query, key, value = [
            x.view(nbatchs, self.nheads, -1, self.d_k)
            for x in (query, key, value)
        ]

        x, attn = self._attention(query, key, value, mask=mask)
        x = x.view(nbatchs, -1, self.d_k * self.nheads)

        return x
    
class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

class AutoregressiveDecoderLayer(nn.Module):
    def __init__(self,
        nheads:int,
        embed_size:int,
        ff_size:int,
        nlayers:int,
    ) -> None:
        super().__init__()

        self.Wq = nn.ModuleList(nn.Linear(embed_size, embed_size) for _ in range(2))
        self.Wk = nn.ModuleList(nn.Linear(embed_size, embed_size) for _ in range(2))
        self.Wv = nn.ModuleList(nn.Linear(embed_size, embed_size) for _ in range(2))
        self.lin1 = nn.Linear(embed_size, embed_size)
        self.bn1 = nn.BatchNorm1d(embed_size)

        self.kprev = None
        self.vprev = None

        self.mha = MHA_(nheads, embed_size, ff_size, nlayers)
    def forward(self, ht, key, value, mask=None):
        nbatchs = ht.size(0)

        # STEP (2)
        q = self.Wq[0](ht)
        k = self.Wk[0](ht)
        v = self.Wv[0](ht)

        if self.kprev is  None:
            self.kprev = k
            self.vprev = v
        else:
            self.kprev = torch.cat([self.kprev, k], dim=1)
            self.vprev = torch.cat([self.vprev, v], dim=1)
        
        ht += self.lin1(self.mha(q, self.kprev, self.vprev))
        ht = self.bn1(ht.squeeze(1)).view(nbatchs, 1, -1)

        # STEP (3)
        q = self

        print(ht.size())



In [395]:
nheads = 1
nnodes = 10
embed_size = 4
ff_size = 21
nlayers = 3
nbatchs = 3
xxx = 5

# enc = Encoder(nheads, nnodes, embed_size, ff_size, nlayers)
x = torch.randn(nbatchs, 1, embed_size)
k = torch.randn(nbatchs, 20, embed_size)
q = torch.randn(nbatchs, 20, embed_size)
v = torch.randn(nbatchs, 20, embed_size)

tsp = AutoregressiveDecoderLayer(nheads, embed_size, ff_size, nlayers)
tsp(x, k, v)

torch.Size([3, 1, 4])


In [369]:
# x = torch.randn(nbatchs, 1, embed_size)
# myMHA(x, x, x, nheads)[0]

tensor([[[ 0.1133,  0.7578, -0.4857, -0.2076]]])