In [1]:
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader, Dataset

In [2]:
from models.model import CustomModel
from models.model import CRFModel

In [3]:
from datasets import load_from_disk

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import tokenmonster

In [5]:
tokenizer_file = "english-8000-balanced-v1"

In [6]:
vocab = tokenmonster.load_multiprocess_safe(tokenizer_file)

In [7]:
initial_vocab_size = len(vocab)

In [8]:
# initial_vocab_size is 8000, so largest valid index is 7999

In [9]:
pad_idx = initial_vocab_size # max valid index is now 8000
eos_idx = initial_vocab_size + 1 # max valid index is now 8001

In [10]:
vocab_size = initial_vocab_size + 2

In [11]:
question_len = 256
context_len = 512
answer_len = 256
assert context_len >= answer_len

In [12]:
embedding_dim = 768
num_heads = 12
target_len = answer_len
num_helix_layers = 3
num_single_strand_layers = 2
phm_factor = 4
lm_head_phm_factor = 2
beam = 64
low_rank = 32

In [13]:
squad_train_file = "data/squad_train"
dolly_test_file = "data/closed"

In [14]:
squad = load_from_disk(squad_train_file)
dolly = load_from_disk(dolly_test_file)

In [15]:
def _as_token_tensor(seq):
    """Return ``seq`` as a tensor; plain Python sequences become int64 tensors.

    Tensors are returned unchanged (same dtype and device), so datasets already
    formatted as torch tensors behave exactly as before.
    """
    if isinstance(seq, torch.Tensor):
        return seq
    return torch.as_tensor(seq, dtype=torch.long)


def _fit_length(seq, length, pad_value):
    """Truncate ``seq`` to at most ``length`` tokens, then right-pad it with ``pad_value``."""
    # Explicit slicing rather than relying on F.pad cropping when given a
    # negative pad amount for over-long inputs.
    seq = seq[:length]
    return pad(seq, (0, length - len(seq)), value=pad_value)


def collate_fn(batch, question_key, context_key, answer_key, *,
               max_question_len=None, max_context_len=None, max_answer_len=None,
               pad_value=None, eos_value=None):
    """
    Collate a list of examples into three fixed-length token-id tensors.

    Each element of ``batch`` is a mapping holding three token sequences under
    ``question_key``, ``context_key`` and ``answer_key``. Because their lengths
    vary, every sequence is truncated / right-padded with ``pad_value``:

    * questions -> ``max_question_len`` (default: notebook-level ``question_len``)
    * contexts  -> ``max_context_len``  (default: ``context_len``)
    * answers   -> ``max_answer_len``   (default: ``answer_len``). Each answer is
      first cut to ``max_answer_len - 1`` tokens and terminated with
      ``eos_value`` (default: ``eos_idx``), then padded.

    ``pad_value`` defaults to ``pad_idx``. Sequences may be tensors or plain
    lists of ints (lists are converted to int64 tensors).

    Returns:
        ``(questions, contexts, answers)`` — for 1-D token sequences, each is a
        stacked tensor of shape ``(len(batch), max_*_len)``.

    Raises:
        ValueError: if the answer length is < 1 (no room for the EOS token).
    """
    # Resolve defaults at call time so edits to the notebook-level constants
    # made after this cell ran are still honoured, as in the original.
    q_len = question_len if max_question_len is None else max_question_len
    c_len = context_len if max_context_len is None else max_context_len
    a_len = answer_len if max_answer_len is None else max_answer_len
    pad_id = pad_idx if pad_value is None else pad_value
    eos_id = eos_idx if eos_value is None else eos_value

    if a_len < 1:
        raise ValueError(f"answer length must be >= 1 to fit the EOS token, got {a_len}")

    questions = [_fit_length(_as_token_tensor(elem[question_key]), q_len, pad_id) for elem in batch]
    contexts = [_fit_length(_as_token_tensor(elem[context_key]), c_len, pad_id) for elem in batch]

    answers = []
    for elem in batch:
        ans = _as_token_tensor(elem[answer_key])
        # EOS is created on the answer's device so torch.cat works for GPU
        # tensors too; it stays int64 so dtype promotion matches the original.
        eos = torch.tensor([eos_id], device=ans.device)
        answers.append(_fit_length(torch.cat((ans[:a_len - 1], eos)), a_len, pad_id))

    return torch.stack(questions), torch.stack(contexts), torch.stack(answers)

In [16]:
def dolly_collate_fn(batch):
    """DataLoader collate hook for Dolly rows (``instruction`` / ``context`` / ``response``)."""
    return collate_fn(
        batch,
        question_key="instruction",
        context_key="context",
        answer_key="response",
    )

In [17]:
def squad_collate_fn(batch):
    """DataLoader collate hook for SQuAD rows (``question`` / ``context`` / ``answers``)."""
    return collate_fn(
        batch,
        question_key="question",
        context_key="context",
        answer_key="answers",
    )

In [18]:
dollyDataloader = DataLoader(dolly, batch_size=2, shuffle=True, collate_fn=dolly_collate_fn)

In [19]:
squadDataloader = DataLoader(squad, batch_size=2, shuffle=True, collate_fn=squad_collate_fn)

In [20]:
model = CustomModel(embedding_dim, num_heads, target_len, vocab_size, num_helix_layers=num_helix_layers, num_single_strand_layers=num_single_strand_layers, phm_factor=phm_factor, lm_head_phm_factor=lm_head_phm_factor)

In [21]:
crf = CRFModel(model, vocab_size, beam, low_rank, pad_idx)

In [22]:
from torch.optim import Adam

In [23]:
optm = Adam(crf.parameters(), lr=0.0001)

In [24]:
for batch in squadDataloader:
    break

In [25]:
q, c, a = batch

In [26]:
logits, crf_losses = crf(q, c, a)

In [27]:
from loss_functions.unlikelihood_loss import unlikelihood_loss

In [28]:
unk_loss = unlikelihood_loss(logits, a, 3, allow_self_repeats_idx=pad_idx)

In [29]:
logits = logits.reshape(-1, vocab_size)

In [30]:
logits_loss = torch.nn.functional.cross_entropy(logits, a.view(-1), ignore_index=pad_idx)

In [31]:
total_loss = logits_loss + torch.sum(crf_losses) + unk_loss

In [32]:
total_loss

tensor(157.7595, grad_fn=<AddBackward0>)

In [33]:
optm.zero_grad()
total_loss.backward()
optm.step()