In [92]:
from models.crf import CRF
from models.cosformer import CosformerAttention as Attention
from models.fnet import FNet
from models.phm import phm

In [93]:
import torch

In [94]:
# The model is a question-answering model that takes a context and a question.
# Usually you would concatenate the context and the question and pass them to the model as one sequence.
# This model is different: it accepts the context and the question as two separate inputs.

In [95]:
# Seed the global RNG before any layer is built or any toy tensor is sampled,
# so a Restart & Run All reproduces the same weights, inputs and outputs.
seed = 0
torch.manual_seed(seed)

# Model and toy-batch dimensions shared by the cells below.
embedding_dim = 512  # size of every token vector; also the input width of the Linear head
num_heads = 8        # passed to every Attention layer
vocab_size = 5       # number of labels produced by the Linear head and scored by the CRF
batch_size = 2       # sequences in the toy batch
context_len = 10     # tokens per context sequence (and per target sequence)
question_len = 3     # tokens per question sequence

In [96]:
# First layer applied to the context in forward(); it is called there with a single
# tensor argument (presumably self-attention — confirm against models.cosformer).
context_input = Attention(embedding_dim, num_heads)

In [97]:
# FNet layer applied to the context stream right after context_input in forward().
context_fnet_mix = FNet()

In [98]:
# First layer applied to the question in forward(); it is called there with a single
# tensor argument (presumably self-attention — confirm against models.cosformer).
question_input = Attention(embedding_dim, num_heads)

In [99]:
# FNet layer applied to the question stream right after question_input in forward().
question_fnet_mix = FNet()

In [100]:
# Cross layers between the two streams. In forward() each is called with three tensors:
#   question_to_context_cross(c, q, q) updates the context stream,
#   context_to_question_cross(q, c, c) updates the question stream.
# NOTE(review): the three arguments are presumably (query, key, value) — confirm
# against models.cosformer.
context_to_question_cross = Attention(embedding_dim, num_heads)
question_to_context_cross = Attention(embedding_dim, num_heads)

In [101]:
# Applied to the context stream after the cross layers in forward(); single-argument call.
after_cross_context = Attention(embedding_dim, num_heads)

In [102]:
# Applied to the question stream after the cross layers in forward(); single-argument call.
after_cross_question = Attention(embedding_dim, num_heads)

In [103]:
# Last layer before the output head: forward() calls it as
# final_question_to_context(c, q, q) to update the context stream one more time.
final_question_to_context = Attention(embedding_dim, num_heads)

In [104]:
# Output head: maps each embedding_dim-sized vector of the context stream to
# vocab_size scores (the `logits` returned by forward()).
output = torch.nn.Linear(in_features=embedding_dim, out_features=vocab_size)

In [105]:
# Hyperparameters handed to the CRF constructor in the next cell.
beam = 2
low_rank = 2

In [106]:
# CRF over the vocab_size labels; forward() calls it as crf(logits, targets, mask).
# NOTE(review): the arguments are positional, in the order (vocab_size, beam, low_rank).
# Both hyperparameters are 2 today, so a swapped order would go unnoticed — confirm the
# parameter order in models.crf before changing either value.
crf = CRF(vocab_size, beam, low_rank)

In [107]:
def forward(context, question, targets):
    """Run the two-stream context/question model and score `targets` with the CRF.

    Args:
        context: context embeddings in the layout the attention layers work on;
            the notebook comments state this is (len, batch, dim).
        question: question embeddings in the same (len, batch, dim) layout.
        targets: integer label ids laid out as (batch, context_len). The value -1
            is treated as padding and excluded through the mask.

    Returns:
        A tuple ``(logits, crfLoss)``: ``logits`` is the Linear head applied to the
        batch-first context stream, and ``crfLoss`` is whatever ``crf`` returns for
        those logits, the targets and the padding mask.
        NOTE(review): the output recorded in this notebook is negative, which
        suggests ``crf`` returns a log-likelihood rather than a loss — confirm in
        models.crf and negate before minimizing it.

    Raises:
        ValueError: if ``targets`` does not match the (batch, len) shape of the logits.
    """
    # Each stream first goes through its own attention and FNet layers,
    # every layer wrapped in a residual connection.
    c = context_input(context) + context
    c = context_fnet_mix(c) + c

    q = question_input(question) + question
    q = question_fnet_mix(q) + q

    # Exchange information between the streams.
    # NOTE(review): the second call already sees the question-updated `c`; confirm
    # that this sequential (rather than symmetric) update is intended.
    c = question_to_context_cross(c, q, q) + c
    q = context_to_question_cross(q, c, c) + q

    c = after_cross_context(c) + c
    q = after_cross_question(q) + q

    c = final_question_to_context(c, q, q) + c

    # The attention layers output (len, batch, dim), but the head and the CRF
    # are fed (batch, len, dim), so swap the first two axes.
    c = c.transpose(0, 1)

    logits = output(c)

    # The inputs are length-first while the targets are batch-first; catch a
    # layout mix-up here instead of deep inside the CRF.
    if targets.shape != logits.shape[:2]:
        raise ValueError(
            f"targets shape {tuple(targets.shape)} does not match "
            f"logits (batch, len) {tuple(logits.shape[:2])}"
        )

    mask = targets.ne(-1)
    # -1 is only a padding sentinel, not a valid label id. Replace it with a valid
    # id (0) at the masked positions so the CRF never indexes with -1; the mask
    # still marks those positions as padding.
    safe_targets = targets.masked_fill(~mask, 0)
    crfLoss = crf(logits, safe_targets, mask)

    return logits, crfLoss

In [108]:
# Random toy batch in batch-first layout: embeddings are (batch, len, dim).
test_context = torch.randn((batch_size, context_len, embedding_dim))
test_question = torch.randn((batch_size, question_len, embedding_dim))
# One label id per context token, drawn from [0, vocab_size); no -1 padding occurs here.
test_targets = torch.randint(low=0, high=vocab_size, size=(batch_size, context_len))

In [109]:
# The cosformer attention expects (len, batch, dim), so swap the first two axes
# of the batch-first tensors.
# NOTE: this cell rebinds its own inputs, so running it a second time swaps the
# axes back; re-run the previous cell first if you need to execute it again.
test_context = torch.transpose(test_context, 0, 1)
test_question = torch.transpose(test_question, 0, 1)

In [110]:
# Smoke-test the whole pipeline on the toy batch.
logits, crfLoss = forward(
    context=test_context,
    question=test_question,
    targets=test_targets,
)

In [112]:
# Display the CRF output; the recorded result has two entries, matching batch_size.
# NOTE(review): the recorded values are negative, so this looks like a log-likelihood
# rather than a loss — confirm in models.crf before using it as a training objective.
crfLoss

tensor([-376.4216, -269.3660], grad_fn=<SubBackward0>)