In [41]:
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader, Dataset

In [42]:
from models.model import CustomModel
from models.model import CRFLayer

In [43]:
from datasets import load_from_disk

In [44]:
import tokenmonster

In [45]:
# Name of the tokenmonster vocabulary to load (passed to load_multiprocess_safe below).
tokenizer_file = "english-8000-balanced-v1"

In [46]:
# Load the tokenizer vocabulary using tokenmonster's multiprocess-safe loader
# (presumably chosen so the vocab can be shared with worker processes — TODO confirm).
vocab = tokenmonster.load_multiprocess_safe(tokenizer_file)

In [47]:
# Size of the tokenizer's own vocabulary; the extra special-token ids below are numbered relative to it.
initial_vocab_size = len(vocab)

In [48]:
# Special token ids are placed directly after the tokenizer's own ids, which are
# assumed to occupy 0 .. initial_vocab_size - 1. With two extra tokens, every id
# then fits in a vocabulary of initial_vocab_size + 2 entries (0-based indexing).
pad_idx = initial_vocab_size
eos_idx = initial_vocab_size + 1

In [49]:
# Total vocabulary size including the special tokens. Ids are 0-based, so the
# size must exceed the largest special id; deriving it from pad_idx/eos_idx keeps
# the two definitions from drifting apart.
vocab_size = max(pad_idx, eos_idx) + 1

In [50]:
# Fixed lengths (in token ids) that collate_fn truncates/pads each field to:
# question, context and answer respectively.
question_len, context_len, answer_len = 256, 512, 256

In [51]:
# Dataset directories passed to datasets.load_from_disk below.
squad_train_file, dolly_test_file = "data/squad_train", "data/closed"

In [52]:
# Load both datasets from their on-disk directories (SQuAD first, then Dolly).
squad, dolly = map(load_from_disk, (squad_train_file, dolly_test_file))

In [53]:
def _to_fixed_length(seq, length, value):
    """Truncate 1-D tensor ``seq`` to ``length`` items, then right-pad it with ``value`` up to ``length``."""
    seq = seq[:length]
    return pad(seq, (0, length - len(seq)), value=value)


def collate_fn(batch, question_key, context_key, answer_key, *,
               max_lengths=None, pad_id=None, eos_id=None):
    """Collate dataset rows into three fixed-length token-id tensors.

    Each row of ``batch`` is a mapping holding a question, a context and an
    answer, each a 1-D sequence of token ids (a tensor or a plain list).
    Questions and contexts are truncated/right-padded to a fixed length.
    Answers are truncated to leave room for a single EOS token, which is
    appended before padding, so every answer row contains exactly one EOS.

    Args:
        batch: list of rows as yielded by the dataset.
        question_key, context_key, answer_key: column names to read from each row.
        max_lengths: optional ``(question, context, answer)`` lengths; defaults to
            the notebook-level ``question_len``, ``context_len``, ``answer_len``.
        pad_id: padding token id; defaults to the notebook-level ``pad_idx``.
        eos_id: end-of-sequence token id; defaults to the notebook-level ``eos_idx``.

    Returns:
        Tuple ``(questions, contexts, answers)`` of ``torch.long`` tensors with
        shapes ``(len(batch), question_len)``, ``(len(batch), context_len)`` and
        ``(len(batch), answer_len)``.
    """
    if max_lengths is None:
        max_lengths = (question_len, context_len, answer_len)
    if pad_id is None:
        pad_id = pad_idx
    if eos_id is None:
        eos_id = eos_idx
    q_len, c_len, a_len = max_lengths

    questions, contexts, answers = [], [], []
    for row in batch:
        # as_tensor is a no-op for long tensors and also accepts plain lists of ids.
        question = torch.as_tensor(row[question_key], dtype=torch.long)
        context = torch.as_tensor(row[context_key], dtype=torch.long)
        answer = torch.as_tensor(row[answer_key], dtype=torch.long)

        questions.append(_to_fixed_length(question, q_len, pad_id))
        contexts.append(_to_fixed_length(context, c_len, pad_id))

        # Truncate to a_len - 1 first so the EOS token survives even for long answers.
        eos = torch.tensor([eos_id], dtype=answer.dtype, device=answer.device)
        answers.append(_to_fixed_length(torch.cat((answer[:a_len - 1], eos)), a_len, pad_id))

    return torch.stack(questions), torch.stack(contexts), torch.stack(answers)

In [54]:
def dolly_collate_fn(batch):
    """Collate a Dolly batch, whose columns are instruction / context / response."""
    return collate_fn(
        batch,
        question_key="instruction",
        context_key="context",
        answer_key="response",
    )

In [55]:
def squad_collate_fn(batch):
    """Collate a SQuAD batch, whose columns are question / context / answers."""
    return collate_fn(
        batch,
        question_key="question",
        context_key="context",
        answer_key="answers",
    )

In [56]:
# Shuffled Dolly loader; the seeded generator makes the shuffle order
# reproducible across Restart & Run All (seed 0 is arbitrary).
dollyDataloader = DataLoader(
    dolly,
    batch_size=2,
    shuffle=True,
    collate_fn=dolly_collate_fn,
    generator=torch.Generator().manual_seed(0),
)

In [57]:
# Shuffled SQuAD loader; the seeded generator makes the shuffle order — and so
# the sample batch inspected below — reproducible across Restart & Run All
# (seed 0 is arbitrary).
squadDataloader = DataLoader(
    squad,
    batch_size=2,
    shuffle=True,
    collate_fn=squad_collate_fn,
    generator=torch.Generator().manual_seed(0),
)

In [58]:
# Grab a single batch from the SQuAD loader for inspection.
batch = next(iter(squadDataloader))

In [59]:
# Unpack the collated batch into (questions, contexts, answers) padded token-id tensors,
# each shaped (batch_size, field_length) as produced by collate_fn.
q, c, a = batch

In [72]:
# Decode the first question back to text. The padding/EOS ids added by collate_fn
# are dropped first rather than relying on the tokenizer to ignore ids outside
# its own vocabulary.
vocab.decode([tok for tok in q[0].tolist() if tok not in (pad_idx, eos_idx)])

'The search for mysticism in the Dominican Order goes back to what?'