In [1]:
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

In [2]:
from tqdm.notebook import tqdm

In [3]:
# from models.model import CustomModelNoDownsize as CustomModel
# from models.model import CRFModelV2 as CRFModel
from models.model import CustomModel
from models.model import CRFModel

In [4]:
from datasets import load_from_disk

In [5]:
import tokenmonster

In [6]:
tokenizer_file = "english-8000-balanced-v1"

In [7]:
vocab = tokenmonster.load_multiprocess_safe(tokenizer_file)

In [8]:
initial_vocab_size = len(vocab)

In [9]:
# initial_vocab_size is 8000, so largest valid index is 7999

In [10]:
# Two special ids are appended directly after the tokenizer's own ids
# (which end at initial_vocab_size - 1).
pad_idx = initial_vocab_size # padding id; max valid index is now initial_vocab_size
eos_idx = initial_vocab_size + 1 # end-of-sequence id; max valid index is now initial_vocab_size + 1

In [11]:
vocab_size = initial_vocab_size + 2

In [12]:
# Fixed token lengths used by collate_fn: the question+context input is
# padded/cropped to question_len + context_len, the answer to answer_len.
question_len = 256
context_len = 512
answer_len = 100
# NOTE(review): why the answer must fit inside the context length is not
# visible in this notebook — presumably a model constraint; confirm.
assert context_len >= answer_len

In [13]:
# Hyperparameters used by try_loading() when no checkpoint exists; a resumed
# checkpoint is rebuilt from the constructor kwargs stored inside it instead.
# Passed to CustomModel:
embedding_dim = 512
num_heads = 8
num_helix_layers = 1
num_single_strand_layers = 1
phm_factor = 4
lm_head_phm_factor = 2
# Passed to CRFModel:
beam = 32
low_rank = 16
# Passed to the DataLoaders:
batch_size = 25
# Prefer the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [14]:
print(f"Using device: {device}")

Using device: cuda


In [15]:
checkpoint_dir = "checkpoints"

In [16]:
path = f"{checkpoint_dir}/crf_model_reflex.pt"

In [17]:
squad_train_file = "data/squad_train"
dolly_test_file = "data/closed"

In [18]:
# Load the previously saved datasets from disk.
# NOTE(review): collate_fn applies torch.cat and slicing to the question,
# context and answer fields, so these are presumably stored pre-tokenized
# with a torch output format — confirm against the preprocessing script.
squad = load_from_disk(squad_train_file)
dolly = load_from_disk(dolly_test_file)

In [19]:
def collate_fn(batch, question_key, context_key, answer_key):
    """
    Turn a list of dataset examples into two fixed-width tensors.

    For every example the question and context are concatenated (question
    first) and right-padded with pad_idx to context_len + question_len.
    The answer is cut to at most answer_len - 2 tokens, followed by two
    eos_idx tokens, and then right-padded with pad_idx to answer_len.

    Args:
        batch: iterable of examples indexable by the three keys below.
        question_key: key of the question/instruction field.
        context_key: key of the context field.
        answer_key: key of the answer/response field.

    Returns:
        (contexts, answers): the stacked, padded inputs and targets.
    """
    input_width = context_len + question_len
    eos_tail = torch.tensor([eos_idx, eos_idx])

    padded_inputs = []
    padded_targets = []
    for example in batch:
        # Question goes in front of the context.
        prompt = torch.cat((example[question_key], example[context_key]))
        padded_inputs.append(pad(prompt, (0, input_width - len(prompt)), value=pad_idx))

        # Leave room for the two trailing eos tokens before padding.
        target = torch.cat((example[answer_key][:answer_len - 2], eos_tail))
        padded_targets.append(pad(target, (0, answer_len - len(target)), value=pad_idx))

    return torch.stack(padded_inputs), torch.stack(padded_targets)

In [20]:
def dolly_collate_fn(batch):
    """Collate a batch using the dolly field names (instruction / context / response)."""
    return collate_fn(
        batch,
        question_key="instruction",
        context_key="context",
        answer_key="response",
    )

In [21]:
def squad_collate_fn(batch):
    """Collate a batch using the SQuAD field names (question / context / answers)."""
    return collate_fn(
        batch,
        question_key="question",
        context_key="context",
        answer_key="answers",
    )

In [22]:
# Shuffled batches over the dolly set. It is not consumed by the training
# loop in the visible cells of this notebook.
dollyDataloader = DataLoader(dolly, batch_size=batch_size, shuffle=True, collate_fn=dolly_collate_fn)

In [23]:
# Shuffled SQuAD batches, collated into (contexts, answers) tensors; this is
# the loader the training loop draws from. drop_last is not set, so the final
# batch of a pass may be smaller than batch_size.
squadDataloader = DataLoader(squad, batch_size=batch_size, shuffle=True, collate_fn=squad_collate_fn)

In [24]:
from torch.optim import Adam
import datetime
import os

In [25]:
def save_checkpoint(losses, model, crf, optm, tensorboard_log_dir):
    """
    Write the full training state to `path`.

    The checkpoint is first written to a temporary file next to `path` and
    then moved into place with os.replace, so an interruption during the
    periodic save cannot leave a truncated checkpoint behind.

    Args:
        losses: list of per-step loss values recorded so far (may be empty).
        model: the base model; its state_dict() and .kwargs are stored.
        crf: the CRF wrapper; its state_dict() and .kwargs are stored.
        optm: the optimizer; its state_dict() is stored.
        tensorboard_log_dir: log directory to reuse when resuming.
    """
    # Creating the directory is a no-op when it already exists.
    os.makedirs(checkpoint_dir, exist_ok=True)

    state = {
        'epoch': len(losses),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
        # No loss has been recorded yet when saving before the first step.
        'loss': losses[-1] if losses else None,
        'losses': losses,
        'crf_state_dict': crf.state_dict(),
        'model_kwargs': model.kwargs,
        'crf_kwargs': crf.kwargs,
        'tensorboard_log_dir': tensorboard_log_dir,
    }

    # Write-then-rename keeps the previous checkpoint intact until the new
    # one is complete.
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)
def load_checkpoint(map_location=None):
    """
    Rebuild the model, CRF wrapper and optimizer from the checkpoint at `path`.

    Args:
        map_location: forwarded to torch.load to remap stored tensors.

    Returns:
        (losses, model, crf, optm, log_dir) as stored by save_checkpoint, with
        both modules moved to the notebook-level `device`.

    Raises:
        FileNotFoundError: if no checkpoint exists at `path`.
    """
    state = torch.load(path, map_location=map_location)

    # Re-create both modules from the constructor kwargs kept in the checkpoint.
    model = CustomModel(**state['model_kwargs'])
    crf = CRFModel(model=model, **state['crf_kwargs'])

    # Restore the weights, then move everything to the configured device.
    model.load_state_dict(state['model_state_dict'])
    crf.load_state_dict(state['crf_state_dict'])
    model = model.to(device)
    crf = crf.to(device)

    # The optimizer is created after the modules have been moved to `device`.
    optm = torch.optim.Adam(crf.parameters())
    optm.load_state_dict(state['optimizer_state_dict'])

    return state['losses'], model, crf, optm, state['tensorboard_log_dir']
def try_loading():
    """
    Resume from the checkpoint if one exists, otherwise build a fresh model
    from the hyperparameters defined at the top of the notebook.

    The checkpoint is loaded with map_location=device, so a checkpoint saved
    on a GPU loads on a CPU-only machine without a second attempt. Any other
    RuntimeError (for example a state-dict mismatch) now propagates unchanged
    instead of being reported as a GPU/CPU problem.

    Returns:
        (losses, model, crf, optm, log_dir)
    """
    try:
        losses, model, crf, optm, log_dir = load_checkpoint(map_location=device)
    except FileNotFoundError:
        # couldn't find model, probably because it doesn't exist
        print("Couldn't find model, creating new one")
        model = CustomModel(
            embedding_dim,
            num_heads,
            vocab_size,
            num_helix_layers=num_helix_layers,
            num_single_strand_layers=num_single_strand_layers,
            phm_factor=phm_factor,
            lm_head_phm_factor=lm_head_phm_factor,
        )
        model = model.to(device)
        crf = CRFModel(model, vocab_size, beam, low_rank, pad_idx)
        crf = crf.to(device)
        optm = Adam(crf.parameters(), lr=0.001)
        losses = []
        # create a string to identify this model for tensorboard logging
        now = datetime.datetime.now()
        log_dir = f"runs/run_at_{now.strftime('%Y-%m-%d_%H-%M-%S')}"
    else:
        print(f"Resuming, have seen {len(losses)} epochs")

    # Shared report for both the resumed and the freshly created model.
    print(f"Have {sum(p.numel() for p in crf.parameters() if p.requires_grad)} trainable parameters")
    print(f"Logging to {log_dir}")
    return losses, model, crf, optm, log_dir

In [26]:
from loss_functions.unlikelihood_loss import unlikelihood_loss
from loss_functions.nag_bert_loss import custom_loss

In [27]:
# NOTE(review): k is not referenced in any visible cell of this notebook —
# presumably left over from the unlikelihood loss; confirm or remove.
k = 3
# Passed to custom_loss as the NLL weighting argument in the training loop.
nll_loss_weight = 0.5

In [28]:
losses, model, crf, optm, log_dir = try_loading()

Resuming, have seen 12484 epochs
Have 10783810 trainable parameters
Logging to runs/run_at_2023-11-26_08-06-41


In [29]:
writer = SummaryWriter(log_dir=log_dir)

In [30]:
def inference(c_sample, model_input_sample=None):
    """
    Generate answer tokens for a batch of contexts with the CRF model.

    The model is switched to eval mode for the duration of the call, so
    layers that behave differently in training (e.g. dropout) do not affect
    generation, and its previous mode is restored afterwards.

    Args:
        c_sample: batch of context token ids; its first dimension is the
            batch size. It must already be on the model's device.
        model_input_sample: optional initial answer tokens to refine. When
            None, a (batch, answer_len) tensor filled with pad_idx is used.

    Returns:
        The tokens returned by crf.inference (the scores are discarded).
    """
    was_training = crf.training
    crf.eval()
    try:
        with torch.no_grad():
            if model_input_sample is None:
                # All-padding placeholder, created directly as a long tensor
                # on the same device as the contexts.
                model_input_sample = torch.full(
                    (c_sample.shape[0], answer_len),
                    pad_idx,
                    dtype=torch.long,
                    device=c_sample.device,
                )
            _, tokens = crf.inference(c_sample, model_input_sample)
    finally:
        # Hand the model back in the mode the caller left it in.
        crf.train(was_training)
    return tokens

In [31]:
def special_decode(tokens):
    """
    Decode token ids up to, but not including, the first eos_idx or pad_idx.

    Unlike vocab.decode, everything from the first special token onwards is
    dropped before decoding.
    """
    stop_ids = (eos_idx, pad_idx)
    kept = []
    for token_id in tokens:
        if token_id in stop_ids:
            break
        kept.append(token_id)
    return vocab.decode(kept)

In [32]:
# The iterator lives outside the training cell, so re-running that cell keeps
# drawing from the current pass instead of starting a new one.
squadIter = iter(squadDataloader)

In [33]:
# Target length of the loss history. Every entry in `losses` corresponds to
# one iteration of the training loop (one batch), so "epochs" here counts
# optimizer steps rather than full passes over the dataset.
total_epochs = 22000
seen = len(losses)
# Remaining steps; if the target is already reached, range(epochs) is empty.
epochs = total_epochs - seen
print(f"Have seen {seen} epochs, training for {epochs} more for a total of {total_epochs}")

Have seen 12484 epochs, training for 9516 more for a total of 22000


In [34]:
# Training loop: every iteration processes one batch and performs one
# optimizer step on the sum of two losses (a plain pass and a "reflection" pass).
pbar = tqdm(range(epochs))
for epoch in pbar:
    # Fetch the next batch, restarting the DataLoader iterator once exhausted.
    try:
        batch = next(squadIter)
    except StopIteration:
        squadIter = iter(squadDataloader)
        batch = next(squadIter)
    # c: padded question+context ids, r: padded target answer ids (see collate_fn).
    c, r = batch
    c = c.to(device)
    r = r.to(device)
    # First pass: the model input is entirely padding.
    model_input = torch.full_like(r, pad_idx)
    logits, crf_losses = crf(c, model_input, r)
    loss = custom_loss(logits, r, crf_losses, vocab_size, pad_idx, nll_loss_weight, answer_len)
    # model_input is full of pad_idx, but we also want to train the model
    # to be able to correct itself (reflection)
    # The prediction is produced without gradients, so only the second forward
    # pass below contributes gradients for the reflection loss.
    with torch.no_grad():
        _, model_prediction = crf.inference(c, model_input)
    # Second pass: feed the model's own prediction back in as the input.
    logits_reflection, crf_losses_reflection = crf(c, model_prediction, r)
    loss_reflection = custom_loss(logits_reflection, r, crf_losses_reflection, vocab_size, pad_idx, nll_loss_weight, answer_len)
    loss = loss + loss_reflection
    optm.zero_grad()
    loss.backward()
    optm.step()
    losses.append(loss.item())
    # The global step continues from the number of steps already recorded.
    writer.add_scalar("Loss/train", losses[-1], epoch + seen)
    pbar.set_description(f"Epoch {epoch + seen} of {total_epochs}, loss: {loss.item()}")
    # Checkpoint every 20 iterations of this run (including the very first one).
    if epoch % 20 == 0:
        save_checkpoint(losses, model, crf, optm, log_dir)

  0%|          | 0/9516 [00:00<?, ?it/s]

In [51]:
save_checkpoint(losses, model, crf, optm, log_dir)

In [52]:
# take sample of c and r (the last training batch) to test model generation
sample_i = 9
# The DataLoader does not drop the last batch, so the final batch of a pass can
# hold fewer than batch_size rows; check against the batch that is actually here,
# otherwise the slices below would silently be empty.
assert sample_i < c.shape[0], f"sample_i={sample_i} is out of range for a batch of {c.shape[0]} rows"
# Slice (rather than index) so the batch dimension is kept.
c_sample = c[sample_i:sample_i+1]
r_sample = r[sample_i:sample_i+1]

In [53]:
tokens = inference(c_sample)
tokensl = tokens.tolist()[0]

In [54]:
vocab.decode(c_sample.tolist()[0])

'Were peopel in pre-industrial societies considered to have long or short lifespans?The work of children was important in pre-industrial societies, as children needed to provide their labour for their survival and that of their group. Pre-industrial societies were characterised by low productivity and short life expectancy, preventing children from participating in productive work would be more harmful to their welfare and that of their group in the long run. In pre-industrial societies, there was little need for children to attend school. This is especially the case in non literate societies. Most pre-industrial skill and knowledge were amenable to being passed down through direct mentoring or apprenticing by competent adults.'

In [55]:
special_decode(tokensl)

'functionals'

In [56]:
vocab.decode(r_sample.tolist()[0])

'short life expectancy'

In [57]:
reflected = inference(c_sample, tokens)

In [58]:
reflectedl = reflected.tolist()[0]
special_decode(reflectedl)

'(three'

In [59]:
custom_input = 'On what date is the Feat of Transfiguration celebrated?Ecclesiam suam was given at St. Peter\'s, Rome, on the Feast of the Transfiguration, 6 August 1964, the second year of his Pontificate. It is considered an important document, identifying the Catholic Church with the Body of Christ. A later Council document Lumen Gentium stated that the Church subsists in the Body of Christ, raising questions as to the difference between "is" and "subsists in". Paul VI appealed to "all people of good will" and discussed necessary dialogues within the Church and between the Churches and with atheism.'

In [60]:
encoded = vocab.tokenize(custom_input)

In [62]:
# Cast the token ids from uint16 to int32 before they are wrapped in a tensor.
# NOTE(review): presumably needed because torch.tensor does not accept the
# tokenizer's uint16 array here — confirm.
encoded = encoded.astype('int32')

In [63]:
encoded = torch.tensor(encoded).unsqueeze(0)

In [64]:
custom_output = inference(encoded.to(device))

In [65]:
special_decode(custom_output.tolist()[0])

'146'