# Showcasing some Pretrain API functionality

In [3]:
import warnings

# Install the filter BEFORE importing transformers: a filter only affects
# warnings raised after it is registered, and the tqdm notice is presumably
# raised while transformers pulls in tqdm at import time.
# `message` is a regular expression matched against the start of the warning
# TEXT, not against the category name, so the previous pattern "TqdmWarning"
# could never match anything.
# NOTE(review): "IProgress not found" is the text of the notice tqdm emits
# when ipywidgets is missing — confirm against the installed tqdm version.
warnings.filterwarnings("ignore", message="IProgress not found")

from transformers import GPT2Config, GPT2LMHeadModel


# Custom GPT-2 shape: n_layer=12, n_head=10, n_embd=760 (76 dims per head).
# Only these three hyperparameters are overridden; everything else is left at
# whatever GPT2Config defaults to.
config = GPT2Config(n_layer=12, n_head=10, n_embd=760)

# The model is built from the config alone (no from_pretrained call here).
model = GPT2LMHeadModel(config)

In [6]:
import random
import pretrain as pt

# Fixed seed so that Restart & Run All draws the same page number(s) each time.
PAGE_SEED = 0
NUM_PAGES = 1

# A private RNG instance keeps the draw reproducible without touching the
# global `random` state that other cells may rely on.
page_rng = random.Random(PAGE_SEED)
pages = [
    page_rng.randint(1, pt.dataset.SubsetFalconLoader.max_pages)
    for _ in range(NUM_PAGES)
]

# Materialise the loader into a list so later cells can iterate it repeatedly.
batches = list(
    pt.dataset.SubsetFalconLoader(
        batch_size=1,
        sequence_length=10,
        pages=pages,
    )
)

# Guard the peek: if the loader yielded nothing, `batches[0]` would raise
# IndexError and hide the real problem (no data for the sampled pages).
if batches:
    print(f"Exploring first batch: {batches[0][0]}")
else:
    print(f"No batches were produced for pages {pages}")

Exploring first batch: tensor([ 8053,   286, 13614, 42544,  9932,   947,   287,  1615,   258,   263])


In [7]:
import math

import torch

model.to("cpu")
# eval() only switches layers such as dropout to inference behaviour; it does
# not turn off gradient tracking, hence the no_grad() context below.
model.eval()

# Compute one loss per batch using the loss the model returns when `labels`
# is supplied (the inputs double as the labels).
losses = []
with torch.no_grad():  # inference only: skip building the autograd graph
    for batch in batches:
        try:
            inputs = batch.to("cpu")
            outputs = model(inputs, labels=inputs)
            loss = outputs.loss.item()  # Extract scalar loss value
            losses.append(loss)
        except Exception as e:
            # Deliberate best-effort: report the failure and keep going so
            # that `losses` stays aligned one-to-one with `batches`.
            print(f"Exception occurred: {e}")
            losses.append(math.inf)  # Use infinity to indicate failure

print(losses)

[10.917858123779297, 11.129181861877441, 10.8306245803833, 10.906065940856934, 10.802318572998047, 10.959781646728516, 10.906094551086426, 11.058467864990234, 10.992799758911133, 10.780656814575195, 10.81259536743164, 11.092170715332031, 10.952272415161133, 11.19336986541748, 11.114277839660645, 11.009113311767578, 10.721773147583008, 11.304011344909668, 11.267609596252441, 10.8449068069458, 10.983033180236816, 10.781233787536621, 10.844575881958008, 10.71689224243164, 11.021096229553223, 10.861352920532227, 10.966983795166016, 10.782625198364258, 10.6871919631958, 11.139700889587402, 11.294319152832031, 10.71467113494873, 10.85884952545166, 11.54813289642334, 10.573038101196289, 11.125123977661133, 10.886366844177246, 11.064266204833984, 11.219145774841309, 10.750901222229004, 11.125911712646484, 11.080912590026855, 10.746673583984375, 10.922796249389648, 11.089228630065918, 11.0670804977417, 10.859468460083008, 11.034564971923828, 10.864897727966309, 10.742331504821777, 10.9028482437

In [None]:
import math
import torch

model.to("cpu")
# eval() does not disable gradient tracking; the no_grad() context below does.
model.eval()

# The criterion is loop-invariant, so build it once instead of once per batch.
loss_fct = torch.nn.CrossEntropyLoss()

# Compute the next-token loss by hand from the raw logits, one value per batch.
losses = []
with torch.no_grad():  # inference only: skip building the autograd graph
    for batch in batches:
        try:
            inputs = batch.to("cpu")
            logits = model(inputs).logits
            print(f"Batch logits: {logits}")

            # Align predictions with targets: the logits at position t are
            # scored against the token at position t + 1, so drop the last
            # logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = inputs[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss = loss_fct(shift_logits, shift_labels).item()

            losses.append(loss)
            print(f"Batch loss: {loss}")
        except Exception as e:
            # Deliberate best-effort: report the failure and keep going so
            # that `losses` stays aligned one-to-one with `batches`.
            print(f"Exception occurred: {e}")
            losses.append(math.inf)  # Use infinity to indicate failure

print(losses)