In [1]:
from models.crf import CRF
from models.cosformer import CosformerAttention as Attention
from models.fnet import FNet
from models.phm import phm

In [2]:
import torch

In [3]:
# The model is a question-answering model that takes a context and a question.
# Usually you would concatenate the context and the question and pass them to the model as one sequence.
# This one is a bit different: it accepts the context and the question separately.

In [4]:
# Seed PyTorch's global RNG before any module is built or any random tensor
# is drawn, so that Restart Kernel -> Run All reproduces the same weights,
# test inputs and loss values.
SEED = 0
torch.manual_seed(SEED)

# Sizes shared by the model components and the random test batch below.
embedding_dim = 512  # feature size handed to every Attention / FNet module
num_heads = 8        # second argument of every Attention module
vocab_size = 5       # width of the final linear layer; also passed to the CRF
batch_size = 2
context_len = 10
question_len = 3

In [5]:
# Attention module for the context on its own: forward() calls it with a
# single argument and adds the result back onto its input.
context_input = Attention(embedding_dim, num_heads)

In [6]:
# FNet block applied to the context right after context_input in forward(),
# again with a residual add.
context_fnet_mix = FNet(embedding_dim)

In [7]:
# Attention module for the question on its own: forward() calls it with a
# single argument and adds the result back onto its input.
question_input = Attention(embedding_dim, num_heads)

In [8]:
# FNet block applied to the question right after question_input in forward(),
# again with a residual add.
question_fnet_mix = FNet(embedding_dim)

In [9]:
# Cross-attention pair. forward() calls question_to_context_cross as (c, q, q)
# and context_to_question_cross as (q, c, c).
# NOTE(review): presumably the argument order is (query, key, value) — confirm
# against CosformerAttention's signature.
context_to_question_cross = Attention(embedding_dim, num_heads)
question_to_context_cross = Attention(embedding_dim, num_heads)

In [10]:
# Attention over the context after the cross-attention step
# (single-argument call with a residual add in forward()).
after_cross_context = Attention(embedding_dim, num_heads)

In [11]:
# Attention over the question after the cross-attention step
# (single-argument call with a residual add in forward()).
after_cross_question = Attention(embedding_dim, num_heads)

In [12]:
# Last cross-attention step: called as (c, q, q) in forward(), immediately
# before the context is transposed and projected by the output layer.
final_question_to_context = Attention(embedding_dim, num_heads)

In [13]:
# Linear projection from embedding_dim features to vocab_size scores per
# position; forward() passes its result to the CRF as logits.
output = torch.nn.Linear(embedding_dim, vocab_size)

In [14]:
# Extra constructor arguments for the CRF layer, used as CRF(vocab_size, beam, low_rank).
# NOTE(review): their exact meaning is defined by models.crf.CRF — confirm there.
low_rank = 2
beam = 2

In [15]:
# CRF layer; forward() calls it as crf(logits, targets, mask).
crf = CRF(vocab_size, beam, low_rank)

In [16]:
def forward(context, question, targets):
    """Run the two-tower QA model and score the context positions with the CRF.

    Context and question are first encoded separately (attention, then FNet
    mixing, each wrapped in a residual add). They then exchange information
    through cross-attention, are refined once more, and the context attends
    to the question a final time before the linear output layer.

    Args:
        context: context embeddings; this notebook feeds them sequence-first,
            i.e. (len, batch, dim).
        question: question embeddings in the same layout.
        targets: label ids per context position; entries equal to -1 are
            masked out before the CRF call.

    Returns:
        Tuple ``(logits, crf_loss, context_states, question_states)``.
        ``logits`` and ``context_states`` are batch-first because the context
        is transposed before the output layer; ``question_states`` is
        returned without that transpose.
    """
    # Context tower: attention followed by FNet mixing, residual around each.
    ctx = context + context_input(context)
    ctx = ctx + context_fnet_mix(ctx)

    # Question tower, same recipe.
    qst = question + question_input(question)
    qst = qst + question_fnet_mix(qst)

    # Cross-attention in both directions. The question side is updated second
    # and therefore sees the already-updated context.
    ctx = ctx + question_to_context_cross(ctx, qst, qst)
    qst = qst + context_to_question_cross(qst, ctx, ctx)

    # One more round of per-sequence attention.
    ctx = ctx + after_cross_context(ctx)
    qst = qst + after_cross_question(qst)

    # The context looks at the question one last time.
    ctx = ctx + final_question_to_context(ctx, qst, qst)

    # The attention modules work on (len, batch, dim); the output layer and
    # the CRF are fed (batch, len, dim), so swap the first two axes.
    ctx = ctx.transpose(0, 1)

    logits = output(ctx)
    valid = targets.ne(-1)  # True wherever the target is a real label
    crf_loss = crf(logits, targets, valid)

    return logits, crf_loss, ctx, qst

In [17]:
# Random stand-ins for already-embedded inputs, created batch-first:
# (batch, len, dim).
test_context = torch.randn((batch_size, context_len, embedding_dim))
test_question = torch.randn((batch_size, question_len, embedding_dim))
# One label id in [0, vocab_size) for every context position.
test_targets = torch.randint(low=0, high=vocab_size, size=(batch_size, context_len))

In [18]:
# CosformerAttention takes sequence-first input, (len, batch, dim), so the
# first two axes of both test tensors are swapped.
# NOTE(review): the names are rebound in place, so running this cell a second
# time swaps the axes back.
test_context, test_question = (
    torch.transpose(t, 0, 1) for t in (test_context, test_question)
)

In [19]:
# Run the forward pass on the sequence-first test inputs; the targets keep
# the (batch_size, context_len) shape they were created with.
logits, crfLoss, c, q = forward(test_context, test_question, test_targets)

In [20]:
# Display the CRF result; the recorded run shows one value per batch element.
crfLoss

tensor([ -79.4506, -125.0226], grad_fn=<SubBackward0>)

In [21]:
# forward() returns `c` after its final transpose(0, 1).
c.shape # batch, len, dim (2, 10, 512)

torch.Size([2, 10, 512])

In [22]:
# Output length requested from the adaptive pooling layer built below.
desired_len = 5

In [23]:
from torch.nn import AdaptiveAvgPool1d

In [24]:
# Adaptive average pooling to desired_len entries; in this notebook it turns
# a (2, 512, 10) input into a (2, 512, 5) output.
pooling = AdaptiveAvgPool1d(desired_len)

In [25]:
# Pool along the sequence axis. The pooling layer reduces the last dimension,
# so it is fed a (batch, dim, len) view of the context states.
# The view gets its own name instead of overwriting `c`: with
# `c = c.transpose(1, 2)` a second run of this cell flipped `c` back to
# (batch, len, dim) and silently pooled over the embedding axis instead.
c_channels_first = c.transpose(1, 2)
pooled = pooling(c_channels_first)

In [26]:
# `pooled` leaves the pooling layer as (batch, dim, len) == (2, 512, 5).
# The logits layer wants (batch, len, dim), so the last two axes are swapped back.
# NOTE(review): `pooled` is rebound in place, so running this cell twice
# undoes the swap.
pooled = torch.transpose(pooled, 1, 2)

In [27]:
# Shape check after the transpose: (batch, len, dim).
pooled.shape

torch.Size([2, 5, 512])

In [28]:
# Scratch: a small 1 x 3 integer tensor.
test2 = torch.tensor([[1,2,3]])

In [29]:
# Plain assignment: `k` is bound to the same tensor object as `test2`, not a copy.
k = test2

In [30]:
# Display `k`.
k

tensor([[1, 2, 3]])

In [31]:
# Display `test2`; it shows the same values as `k`.
test2

tensor([[1, 2, 3]])

In [32]:
# Display `k` again.
k

tensor([[1, 2, 3]])