In [1]:
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

In [2]:
from tqdm.notebook import tqdm

In [3]:
from models.model import CustomModelNoDownsize as CustomModel
from models.model import CRFModelV2 as CRFModel

In [4]:
from datasets import load_from_disk

In [5]:
import tokenmonster

In [6]:
tokenizer_file = "english-8000-balanced-v1"

In [7]:
vocab = tokenmonster.load_multiprocess_safe(tokenizer_file)

In [8]:
initial_vocab_size = len(vocab)

In [9]:
# initial_vocab_size is 8000, so largest valid index is 7999

In [10]:
# Two special ids are placed directly after the tokenizer's own id range
# (initial_vocab_size ids, i.e. 0 .. initial_vocab_size - 1).
pad_idx = initial_vocab_size # padding id; max valid index is now 8000
eos_idx = initial_vocab_size + 1 # end-of-sequence id; max valid index is now 8001

In [11]:
vocab_size = initial_vocab_size + 2

In [12]:
# Length budgets used when batching:
#  - collate_fn pads the concatenated question + context to context_len + question_len
#  - collate_fn truncates/pads every answer to answer_len
question_len = 256
context_len = 512
answer_len = 100
# Sanity check on the configuration: the answer budget must not exceed the context budget.
assert context_len >= answer_len

In [13]:
# Hyperparameters handed to CustomModel by try_loading() when no checkpoint exists.
# When a checkpoint is found, load_checkpoint() uses the kwargs stored in it instead.
embedding_dim = 512
num_heads = 8
num_helix_layers = 1
num_single_strand_layers = 1
phm_factor = 4
lm_head_phm_factor = 2
# Passed to CRFModel (together with vocab_size and pad_idx) by try_loading().
beam = 32
low_rank = 16
# Batch size shared by both DataLoaders.
batch_size = 52
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [14]:
print(f"Using device: {device}")

Using device: cuda


In [15]:
checkpoint_dir = "checkpoints"

In [16]:
path = f"{checkpoint_dir}/crf_model_test.pt"

In [17]:
squad_train_file = "data/squad_train"
dolly_test_file = "data/closed"

In [18]:
squad = load_from_disk(squad_train_file)
dolly = load_from_disk(dolly_test_file)

In [19]:
def collate_fn(batch, question_key, context_key, answer_key):
    """
    Turn a list of dataset rows into a (contexts, answers) pair of stacked tensors.

    For every row the question tensor is concatenated in front of the context
    tensor and the result is padded with pad_idx to context_len + question_len.
    Every answer is cut to answer_len - 2 entries, followed by two eos_idx
    entries, and then padded with pad_idx to answer_len.

    Assumes each of the three fields is a 1-D tensor of token ids — TODO confirm
    against the dataset's format.

    Args:
        batch: iterable of rows supporting row[key] lookup.
        question_key: key of the question / instruction field.
        context_key: key of the context field.
        answer_key: key of the answer / response field.

    Returns:
        Tuple (contexts, answers), each produced with torch.stack.
    """
    full_context_len = context_len + question_len
    # Two eos tokens terminate every answer before padding starts.
    double_eos = torch.tensor([eos_idx, eos_idx])

    padded_contexts = []
    padded_answers = []
    for row in batch:
        question_and_context = torch.cat((row[question_key], row[context_key]))
        padded_contexts.append(
            pad(question_and_context, (0, full_context_len - len(question_and_context)), value=pad_idx)
        )

        terminated_answer = torch.cat((row[answer_key][:answer_len-2], double_eos))
        padded_answers.append(
            pad(terminated_answer, (0, answer_len - len(terminated_answer)), value=pad_idx)
        )

    return torch.stack(padded_contexts), torch.stack(padded_answers)

In [20]:
def dolly_collate_fn(batch):
    """Collate a batch using the "instruction" / "context" / "response" fields."""
    return collate_fn(
        batch,
        question_key="instruction",
        context_key="context",
        answer_key="response",
    )

In [21]:
def squad_collate_fn(batch):
    """Collate a batch using the "question" / "context" / "answers" fields."""
    return collate_fn(
        batch,
        question_key="question",
        context_key="context",
        answer_key="answers",
    )

In [22]:
dollyDataloader = DataLoader(dolly, batch_size=batch_size, shuffle=True, collate_fn=dolly_collate_fn)

In [23]:
squadDataloader = DataLoader(squad, batch_size=batch_size, shuffle=True, collate_fn=squad_collate_fn)

In [24]:
# Pull a single batch out of the SQuAD loader: the loop exits after its first
# iteration and leaves that batch bound to `batch`.
for batch in squadDataloader:
    break

In [25]:
from torch.optim import Adam
import datetime
import os

In [26]:
def save_checkpoint(losses, model, crf, optm, tensorboard_log_dir):
    """
    Write the current training state to `path`.

    The stored dictionary holds the loss history, the state dicts of the model,
    the CRF wrapper and the optimizer, the constructor kwargs of both modules
    (read from their `.kwargs` attributes) and the tensorboard log directory.

    Args:
        losses: list of loss values recorded so far; its length is stored as 'epoch'.
        model: module providing state_dict() and a `kwargs` attribute.
        crf: module providing state_dict() and a `kwargs` attribute.
        optm: optimizer providing state_dict().
        tensorboard_log_dir: log directory to store alongside the weights.
    """
    # exist_ok avoids the check-then-create race of exists() followed by makedirs()
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        'epoch': len(losses),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        # an empty history has no "latest" loss; store None instead of raising IndexError
        'loss': losses[-1] if losses else None,
        'losses': losses,
        'crf_state_dict': crf.state_dict(),
        'model_kwargs': model.kwargs,
        'crf_kwargs': crf.kwargs,
        'tensorboard_log_dir': tensorboard_log_dir,
    }

    # Write to a sibling temp file first and swap it in afterwards, so that an
    # interrupt in the middle of torch.save cannot leave a truncated file at `path`.
    tmp_path = f"{path}.tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)
def load_checkpoint(map_location=None):
    """
    Rebuild model, CRF wrapper and optimizer from the checkpoint stored at `path`.

    Args:
        map_location: forwarded unchanged to torch.load.

    Returns:
        Tuple (losses, model, crf, optm, log_dir) where both modules have been
        moved to `device` and the optimizer state has been restored.
    """
    state = torch.load(path, map_location=map_location)

    # Re-create both modules from the constructor kwargs saved with the weights.
    model = CustomModel(**state['model_kwargs'])
    crf = CRFModel(model=model, **state['crf_kwargs'])
    model.load_state_dict(state['model_state_dict'])
    crf.load_state_dict(state['crf_state_dict'])
    model, crf = model.to(device), crf.to(device)

    # The optimizer is built over the CRF's parameters before its state is restored.
    optm = torch.optim.Adam(crf.parameters())
    optm.load_state_dict(state['optimizer_state_dict'])

    return state['losses'], model, crf, optm, state['tensorboard_log_dir']
def _count_trainable_parameters(module):
    """Return the number of elements in `module`'s parameters that require grad."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)
def try_loading():
    """
    Resume from the checkpoint at `path` if it exists, otherwise create a fresh
    model, CRF wrapper and optimizer from the hyperparameters defined above.

    The checkpoint is loaded with map_location=device, so a checkpoint written
    on a GPU machine loads directly on a CPU-only machine. This replaces the
    previous approach of catching RuntimeError and retrying on the CPU, which
    also caught unrelated RuntimeErrors and reported them as a device mismatch.

    Returns:
        Tuple (losses, model, crf, optm, log_dir).
    """
    try:
        losses, model, crf, optm, log_dir = load_checkpoint(map_location=device)
        print(f"Resuming, have seen {len(losses)} epochs")
    except FileNotFoundError:
        # couldn't find model, probably because it doesn't exist
        print("Couldn't find model, creating new one")
        model = CustomModel(
            embedding_dim,
            num_heads,
            vocab_size,
            num_helix_layers=num_helix_layers,
            num_single_strand_layers=num_single_strand_layers,
            phm_factor=phm_factor,
            lm_head_phm_factor=lm_head_phm_factor,
        )
        model = model.to(device)
        crf = CRFModel(model, vocab_size, beam, low_rank, pad_idx)
        crf = crf.to(device)
        optm = Adam(crf.parameters(), lr=0.001)
        losses = []
        # create a string to identify this model for tensorboard logging
        now = datetime.datetime.now()
        log_dir = f"runs/run_at_{now.strftime('%Y-%m-%d_%H-%M-%S')}"

    # Reporting is identical for the resumed and the freshly created case.
    print(f"Have {_count_trainable_parameters(crf)} trainable parameters")
    print(f"Logging to {log_dir}")
    return losses, model, crf, optm, log_dir

In [27]:
from loss_functions.unlikelihood_loss import unlikelihood_loss
from loss_functions.nag_bert_loss import nag_bert_loss

In [28]:
# NOTE(review): `k` is not referenced by any cell visible in this notebook —
# presumably intended for one of the imported loss functions; confirm or remove.
k = 3
# Weight applied to the summed token-level cross-entropy term inside custom_loss.
nll_loss_weight = 0.5

In [29]:
losses, model, crf, optm, log_dir = try_loading()

Resuming, have seen 15000 epochs
Have 10783810 trainable parameters
Logging to runs/run_at_2023-11-25_15-44-28


In [30]:
writer = SummaryWriter(log_dir=log_dir)

In [31]:
def custom_loss(logits, targets, crf_losses):
    """
    Combine the CRF loss with a weighted token-level cross-entropy.

    Per-token cross-entropy is computed with pad positions ignored, summed per
    sequence, scaled by nll_loss_weight and added to `crf_losses`. The total is
    divided by the number of non-pad target tokens in the batch.

    Args:
        logits: tensor that reshapes to (-1, vocab_size), one row per target token.
        targets: 2-D tensor of target token ids, padded with pad_idx.
        crf_losses: CRF loss term added to the per-sequence cross-entropy sums
            (assumed to be one value per sequence — TODO confirm against CRFModel).

    Returns:
        Scalar tensor with the normalized training loss.
    """
    # Use the targets' own shape rather than the global answer_len, and reshape
    # instead of view so non-contiguous targets are accepted as well.
    token_nll = torch.nn.functional.cross_entropy(
        logits.reshape(-1, vocab_size),
        targets.reshape(-1),
        ignore_index=pad_idx,
        reduction='none',
    ).view(targets.shape)

    non_pad_mask = targets.ne(pad_idx).to(token_nll.dtype)
    token_nll = token_nll * non_pad_mask
    summed_loss = token_nll.sum(dim=-1)
    one_step_train_loss = crf_losses + nll_loss_weight * summed_loss

    # The count is integer-valued, so clamping to 1 only affects an all-padding
    # batch, where it prevents a division by zero (NaN/inf loss).
    non_pad_count = torch.sum(non_pad_mask).clamp(min=1)
    return torch.sum(one_step_train_loss) / non_pad_count

In [32]:
def inference(c_sample):
    """
    Run CRF decoding for a batch of contexts and return the predicted token ids.

    The decoder input is an all-pad tensor with one row of answer_len entries per
    context row, created on the same device as `c_sample`. It is built from
    `c_sample` itself instead of the notebook-global `r_sample`, so the function
    works before any sample cell has been run and for batch sizes other than 1.

    Args:
        c_sample: 2-D tensor of context token ids, one row per example.

    Returns:
        The token tensor returned by crf.inference (the scores are discarded).
    """
    with torch.no_grad():
        model_input_sample = torch.full(
            (c_sample.shape[0], answer_len),
            pad_idx,
            dtype=torch.long,
            device=c_sample.device,
        )
        _, tokens = crf.inference(c_sample, model_input_sample)
    return tokens

In [33]:
def special_decode(tokens):
    """
    Decode `tokens` with the tokenizer, stopping at the first eos_idx or pad_idx.

    Unlike vocab.decode, everything from the first end-of-sequence or padding
    id onwards is dropped before decoding.
    """
    kept = []
    for token_id in tokens:
        reached_end = token_id == eos_idx or token_id == pad_idx
        if reached_end:
            break
        kept.append(token_id)
    return vocab.decode(kept)

In [34]:
# import adam optimizer
from torch.optim import Adam
# `optm` is intentionally not re-created here. try_loading() already returned an
# optimizer: either a fresh Adam(lr=0.001) or one whose state was restored from
# the checkpoint. Building a new Adam at this point would throw away the restored
# optimizer state every time training is resumed.

In [35]:
squadIter = iter(squadDataloader)

In [36]:
# Training budget. Note that an "epoch" here is a single optimizer step on one
# batch: the training loop appends one entry to `losses` per batch.
total_epochs = 20000
# Steps already completed according to the restored loss history.
seen = len(losses)
# Remaining steps to run in this session.
epochs = total_epochs - seen
print(f"Have seen {seen} epochs, training for {epochs} more for a total of {total_epochs}")

Have seen 15000 epochs, training for 2000 more for a total of 17000


In [37]:
# Training loop: one optimizer step per iteration ("epoch" == one batch).
pbar = tqdm(range(epochs))
for epoch in pbar:
    # Restart the iterator once the dataloader has been exhausted.
    try:
        batch = next(squadIter)
    except StopIteration:
        squadIter = iter(squadDataloader)
        batch = next(squadIter)
    c, r = batch
    c = c.to(device)
    r = r.to(device)
    # The decoder input is all padding; the targets `r` are passed separately.
    model_input = torch.full_like(r, pad_idx)
    logits, crf_losses = crf(c, model_input, r)
    loss = custom_loss(logits, r, crf_losses)
    optm.zero_grad()
    loss.backward()
    optm.step()
    # Read the scalar once and reuse it for the history, tensorboard and the
    # progress bar instead of calling .item() repeatedly.
    loss_value = loss.item()
    global_step = epoch + seen
    losses.append(loss_value)
    writer.add_scalar("Loss/train", loss_value, global_step)
    pbar.set_description(f"Epoch {global_step} of {total_epochs}, loss: {loss_value}")
    # Checkpoint every 20 steps of this session (including the first one).
    if epoch % 20 == 0:
        save_checkpoint(losses, model, crf, optm, log_dir)

  0%|          | 0/2000 [00:00<?, ?it/s]

In [53]:
save_checkpoint(losses, model, crf, optm, log_dir)

In [110]:
# take sample of c and r to test model generation
sample_i = 23
# Check against the batch actually held in `c`, not the configured batch_size:
# the dataloader does not drop the last batch, so `c` can hold fewer rows, in
# which case the slices below would silently be empty.
assert sample_i < c.shape[0]
c_sample = c[sample_i:sample_i+1]
r_sample = r[sample_i:sample_i+1]

In [111]:
tokens = inference(c_sample).tolist()[0]

In [112]:
tokens

[332,
 2259,
 45,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 1263,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 461,
 37,
 719,
 37,
 719,
 37,
 719,
 37,
 719,
 37,
 719,
 37,
 648,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 8001,
 6252,
 1055,
 394,
 37,
 1765,
 454,
 394,
 6252,
 253,
 514,
 6252,
 253,
 514,
 6252,
 253,
 454,
 110,
 505,
 1985,
 514,
 3778,
 8001,
 1996,
 4653,
 1890,
 2004,
 5088,
 2189,
 6164,
 1922,
 71,
 7620]

In [113]:
vocab.decode(c_sample.tolist()[0])

'What country did kids play football in public schools in the 19th century?Association football in itself does not have a classical history. Notwithstanding any similarities to other ball games played around the world FIFA have recognised that no historical connection exists with any game played in antiquity outside Europe. The modern rules of association football are based on the mid-19th century efforts to standardise the widely varying forms of football played in the public schools of England. The history of football in England dates back to at least the eighth century AD.'

In [114]:
special_decode(tokens)

'China'

In [115]:
vocab.decode(r_sample.tolist()[0])

'England'

In [116]:
custom_input = 'On what date is the Feat of Transfiguration celebrated?Ecclesiam suam was given at St. Peter\'s, Rome, on the Feast of the Transfiguration, 6 August 1964, the second year of his Pontificate. It is considered an important document, identifying the Catholic Church with the Body of Christ. A later Council document Lumen Gentium stated that the Church subsists in the Body of Christ, raising questions as to the difference between "is" and "subsists in". Paul VI appealed to "all people of good will" and discussed necessary dialogues within the Church and between the Churches and with atheism.'

In [117]:
encoded = vocab.tokenize(custom_input)

In [118]:
encoded

array([ 332,  772, 2777, 2291, 4165,   36,  233,   37, 1413,  769,   36,
       3604,   37,  705,   37,  716, 3804, 6446,   48,   34,  332,  688,
         37, 1351,  500,  348,  806,  348, 1835, 3242,  655, 3761, 1653,
        387,  275,  290,  791,  445,   15, 4287,   36,  704,   37, 1278,
       4280,   36, 3604,   37,  705,   37,  716, 3804,   15,  216, 5497,
       1234,   23,   15, 7222, 2796, 4277,   36, 1665,  509,  392,  992,
       1003, 3312, 7015, 7582, 5762,   15, 5870, 4694, 6871, 5502, 6232,
         36, 2232,  769, 5501, 1914, 3325, 6294, 5762, 1950, 1086,   36,
       1473,  509,  517, 4441, 6133, 5502, 2705, 2072, 4153,   36, 2232,
        769, 5501,   15, 1683, 2110, 6683, 3073, 1788, 7950,  580,  729,
       2806,  580, 2705, 2072,  727,  850,  777,  516,   38,  828, 3881,
       2933,  580, 1256, 4308,  769, 2391, 2783, 2806, 6497, 6631, 1388,
         37, 1569,  515,   63, 7258, 5502, 1263, 7345, 5502,  388, 5654,
        655,   37, 1498,   37,  800,   17], dtype=u

In [119]:
# encoded from uint16 to int32
encoded = encoded.astype('int32')

In [120]:
encoded = torch.tensor(encoded).unsqueeze(0)

In [121]:
custom_output = inference(encoded.to(device))

In [122]:
custom_output

tensor([[ 332, 1412,  443,  929, 1854,  929,  362, 4310,  929,  362,  929,  362,
          929,  362,  929,  362,  929,  362,  929,  362,  929,  362,  929,  362,
         8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001,
         8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001,
         8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001, 8001,
          727,  485,  485,   70,   37,  773,   37,  773,   37,  773,   37, 1423,
          485, 1423,  485,  773,  672, 3788, 1472, 3788,  749, 3788, 1472, 3788,
         1472, 3788, 3183, 3788, 3183, 7190, 7651,  846, 5711, 6216, 5711, 6216,
         5711, 2141, 3030, 7569]], device='cuda:0')

In [123]:
special_decode(custom_output.tolist()[0])

'Early B – BCE period BCE BCE BCE BCE BCE BCE BCE BCE'