In [2]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

In [160]:
class Attn(nn.Module):
    def __init__(self, emb_dim, q_dim, k_dim):
        super().__init__()
        self.Wq = nn.Linear(emb_dim, q_dim)
        self.Wk = nn.Linear(emb_dim, k_dim)
        self.Wv = nn.Linear(emb_dim, k_dim)
        
    def forward(self, x, z):
        Q = self.Wq(x)
        K = self.Wk(z)
        V = self.Wv(z)
        print(K.shape)
        S = Q.bmm(K.transpose(1,2))
        # uhh masking here
#         print(S.shape)
        sm = F.softmax(torch.div(S,torch.sqrt(torch.tensor(x.shape[-1]))),dim=-1)
#         print(sm.shape)
#         print(V.shape)
        return sm.bmm(V)
    
class MHAttn(nn.Module):
    def __init__(self, num_heads, emb_dim, q_dim, k_dim):
        super().__init__()
        self.heads = nn.ModuleList([Attn(emb_dim, q_dim, k_dim) for _ in range(num_heads)])
        self.Wo = nn.Linear(num_heads * k_dim, emb_dim)
        
    def forward(self, x, z):
        subAttns = torch.cat([h(x,z) for h in self.heads], dim=-1)
        return self.Wo(subAttns)
        
# encoder is definitionally self-attn
class EncoderBlock(nn.Module):
    def __init__(self, emb_dim, z_dim, heads):
        super().__init__()
        self.attn = MHAttn(heads, emb_dim, z_dim, z_dim)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.ff1 = nn.Linear(emb_dim,emb_dim)
        self.ff2 = nn.Linear(emb_dim,emb_dim)
        
    def forward(self, z):
        z = z + self.attn(z,z)
        z = self.ln1(z)
        z = z + self.ff2(F.relu(self.ff1(z)))
        return self.ln2(z)
    
class DecoderBlock(nn.Module):
    def __init__(self, emb_dim, x_dim, z_dim, heads):
        super().__init__()
        self.attn = MHAttn(heads, emb_dim, x_dim, z_dim)
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.ln3 = nn.LayerNorm(emb_dim)
        self.ff1 = nn.Linear(emb_dim,emb_dim)
        self.ff2 = nn.Linear(emb_dim,emb_dim)
        
    def forward(self, x, z):
        x = x + self.attn(x,x)
        x = self.ln1(x)
        x = x + self.attn(x,z)
        x = self.ln2(x)
        x = x + self.ff2(F.relu(self.ff1(x)))
        return self.ln3(x)
        
# review this

class EDTransformer(nn.Module):
    def __init__(self, embs, pos, emb_dim, x_dim, z_dim, heads, enc_blocks, dec_blocks, out_dim):
        # needs embedding matrix
        # garbage for now
        # positional embedding scheme
        # softmax FF at the end
        super().__init__()
        self.enc_blocks = enc_blocks
        self.dec_blocks = dec_blocks
        self.encoderBlocks = nn.ModuleList(
            [EncoderBlock(emb_dim, x_dim, heads) for _ in range(enc_blocks)]
        )
        self.decoderBlocks = nn.ModuleList(
            [DecoderBlock(emb_dim, x_dim, z_dim, heads) for _ in range(dec_blocks)]
        )
        self.ff = nn.Linear(emb_dim, out_dim)

    def forward(self, x,z):
        # embed + pos
        # loop through encoder blocks
        for i in range(self.enc_blocks):
            z = self.encoderBlocks[i](z)
        # loop through decoder blocks
        for i in range(self.dec_blocks):
            x = self.decoderBlocks[i](x, z)
        # FF and softmax
        return F.softmax(self.ff(x), dim=-1)


In [93]:
exAttn = Attn(256, 32, 32)

In [112]:
# sequence of 8 words
x = torch.randn(1,8,256)
# seq of 10
z = torch.randn(1,10,256)

In [116]:
# three seqs of 8 words
x = torch.randn(3,8,256)
# three seqs of 10
z = torch.randn(3,10,256)

In [117]:
res = exAttn(x,x)
res.shape

torch.Size([3, 8, 32])
torch.Size([3, 8, 8])
torch.Size([3, 8, 8])
torch.Size([3, 8, 32])


torch.Size([3, 8, 32])

In [118]:
exMHA = MHAttn(8, 256, 32, 32)

In [119]:
mhres = exMHA(x,x)
mhres.shape

torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])


torch.Size([3, 8, 256])

In [120]:
exEB = EncoderBlock(256, 8)

In [121]:
ebres = exEB(x)

torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])


In [122]:
ebres.shape

torch.Size([3, 8, 256])

In [123]:
exDB = DecoderBlock(256, 256, 8)

In [124]:
dbres = exDB(x, z)

torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 8, 256])
torch.Size([3, 10, 256])
torch.Size([3, 10, 256])
torch.Size([3, 10, 256])
torch.Size([3, 10, 256])
torch.Size([3, 10, 256])
torch.Size([3, 10, 256])
torch.Size([3, 10, 256])
torch.Size([3, 10, 256])


In [115]:
dbres.shape

torch.Size([1, 8, 256])

In [161]:
exEDT = EDTransformer(None, None, 256, 32, 32, 8, 3,3,10)

In [162]:
edtres = exEDT(x,z)

torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 10, 32])
torch.Size([3, 8, 32])
torch.Size([3, 8, 32])
to

In [163]:
edtres.shape

torch.Size([3, 8, 10])