## **Dataset**

In [1]:
from datasets import load_dataset

ds = load_dataset("datablations/c4-filter-small", split="train")
ds = ds.select_columns(["text"])
ds = ds.train_test_split(test_size=0.1)

In [2]:
ds


DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 90000
    })
    test: Dataset({
        features: ['text'],
        num_rows: 10000
    })
})

## **Tokenizer**

In [3]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.normalizers import NFKC
from tokenizers.decoders import ByteLevel as ByteLevelDecoder

# Initialize BPE tokenizer
tokenizer = Tokenizer(BPE())
tokenizer.pre_tokenizer = ByteLevel()
tokenizer.normalizer = NFKC()
tokenizer.decoder = ByteLevelDecoder()

trainer = BpeTrainer(
    vocab_size=100_000,
    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
)

tokenizer.train_from_iterator(ds["train"]["text"], trainer)
tokenizer.save("gpt_tokenizer.json")

In [4]:
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast(tokenizer_file="gpt_tokenizer.json")
tokenizer.add_special_tokens({
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
    "mask_token": "<mask>",
})

tokenizer.save_pretrained("gpt-tokenizer")



('gpt-tokenizer\\tokenizer_config.json',
 'gpt-tokenizer\\special_tokens_map.json',
 'gpt-tokenizer\\tokenizer.json')

In [5]:
len(tokenizer)


100000

In [6]:
tokenizer.pad_token_id, tokenizer.eos_token_id, tokenizer.bos_token_id


(1, 2, 0)

In [7]:
def tokenize(example):
    return tokenizer(example["text"])

tokenized_ds = ds.map(
    tokenize, remove_columns=["text"], batched=True
)


Map:   0%|          | 0/90000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [8]:
tokenized_ds


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 90000
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 10000
    })
})

In [9]:
block_size = 512

def group_texts(examples):
    # concat input_ids
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated["input_ids"])
    total_length = (total_length // block_size) * block_size

    # split block_size
    result = {
        k: [concatenated[k][i : i + block_size] for i in range(0, total_length, block_size)]
        for k in concatenated
    }

    # prepare labels
    result["labels"] = result["input_ids"].copy()
    return result

lm_ds = tokenized_ds.map(group_texts, batched=True)


Map:   0%|          | 0/90000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [10]:
lm_ds

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 77014
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 9032
    })
})

In [11]:
import torch

torch.tensor(lm_ds["train"]["input_ids"][:5])

tensor([[  257,   365, 30640,  ...,  1174,  3747,   237],
        [  743, 12350,   239,  ...,   233,  4343,   581],
        [  333,  3153,   266,  ...,   213,  1166,   240],
        [  213,  2050,   308,  ...,  3307,   240,   213],
        [16820,   272,   213,  ...,   240,  2852,  5791]])

In [12]:
lm_ds


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 77014
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 9032
    })
})

## **Model**

In [52]:
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=block_size,
    n_embd=512,
    n_layer=6,
    n_head=8,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

model = GPT2LMHeadModel(config)

In [14]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)

training_args = TrainingArguments(
    output_dir="gpt-small-c4",
    logging_dir="logs",
    per_device_train_batch_size=12,
    per_device_eval_batch_size=12,
    num_train_epochs=20,
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    logging_steps=1000,
    save_total_limit=1,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
    fp16=True,
    gradient_checkpointing=True,
    gradient_accumulation_steps=2
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_ds["train"],
    eval_dataset=lm_ds["test"],
    tokenizer=tokenizer,
    data_collator=data_collator
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


## **Training**

In [15]:
trainer.train()

  0%|          | 0/64180 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  return fn(*args, **kwargs)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 7.1748, 'grad_norm': 1.340687870979309, 'learning_rate': 4.92209411031474e-05, 'epoch': 0.31}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 6.556571960449219, 'eval_runtime': 49.7847, 'eval_samples_per_second': 181.421, 'eval_steps_per_second': 15.125, 'epoch': 0.31}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 6.3822, 'grad_norm': 1.4861959218978882, 'learning_rate': 4.84418822062948e-05, 'epoch': 0.62}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 6.196662425994873, 'eval_runtime': 52.4877, 'eval_samples_per_second': 172.078, 'eval_steps_per_second': 14.346, 'epoch': 0.62}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 6.1053, 'grad_norm': 1.5324902534484863, 'learning_rate': 4.766282330944219e-05, 'epoch': 0.93}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.962182521820068, 'eval_runtime': 38.0159, 'eval_samples_per_second': 237.585, 'eval_steps_per_second': 19.807, 'epoch': 0.93}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.892, 'grad_norm': 1.6041563749313354, 'learning_rate': 4.6883764412589594e-05, 'epoch': 1.25}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.788722038269043, 'eval_runtime': 52.6251, 'eval_samples_per_second': 171.629, 'eval_steps_per_second': 14.309, 'epoch': 1.25}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.7427, 'grad_norm': 1.727060317993164, 'learning_rate': 4.610470551573699e-05, 'epoch': 1.56}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.65200662612915, 'eval_runtime': 37.8739, 'eval_samples_per_second': 238.476, 'eval_steps_per_second': 19.882, 'epoch': 1.56}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.6236, 'grad_norm': 1.6213436126708984, 'learning_rate': 4.532564661888439e-05, 'epoch': 1.87}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.5418829917907715, 'eval_runtime': 49.5891, 'eval_samples_per_second': 182.137, 'eval_steps_per_second': 15.185, 'epoch': 1.87}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.4991, 'grad_norm': 1.7167649269104004, 'learning_rate': 4.4546587722031785e-05, 'epoch': 2.18}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.448551654815674, 'eval_runtime': 53.1791, 'eval_samples_per_second': 169.841, 'eval_steps_per_second': 14.16, 'epoch': 2.18}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.4177, 'grad_norm': 1.6319338083267212, 'learning_rate': 4.3767528825179186e-05, 'epoch': 2.49}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.371700763702393, 'eval_runtime': 53.9646, 'eval_samples_per_second': 167.369, 'eval_steps_per_second': 13.954, 'epoch': 2.49}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.3543, 'grad_norm': 1.625144362449646, 'learning_rate': 4.2989248987223435e-05, 'epoch': 2.8}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.305250644683838, 'eval_runtime': 37.871, 'eval_samples_per_second': 238.494, 'eval_steps_per_second': 19.883, 'epoch': 2.8}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.2758, 'grad_norm': 1.592323660850525, 'learning_rate': 4.221019009037083e-05, 'epoch': 3.12}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.2413554191589355, 'eval_runtime': 51.5903, 'eval_samples_per_second': 175.072, 'eval_steps_per_second': 14.596, 'epoch': 3.12}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.2012, 'grad_norm': 1.7870405912399292, 'learning_rate': 4.1431910252415086e-05, 'epoch': 3.43}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.186042785644531, 'eval_runtime': 37.9651, 'eval_samples_per_second': 237.903, 'eval_steps_per_second': 19.834, 'epoch': 3.43}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.1618, 'grad_norm': 1.6234861612319946, 'learning_rate': 4.065285135556248e-05, 'epoch': 3.74}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.133570194244385, 'eval_runtime': 37.6084, 'eval_samples_per_second': 240.159, 'eval_steps_per_second': 20.022, 'epoch': 3.74}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.1123, 'grad_norm': 1.6437910795211792, 'learning_rate': 3.9874571517606737e-05, 'epoch': 4.05}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.085620880126953, 'eval_runtime': 37.0658, 'eval_samples_per_second': 243.674, 'eval_steps_per_second': 20.315, 'epoch': 4.05}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.0395, 'grad_norm': 1.7911509275436401, 'learning_rate': 3.909551262075413e-05, 'epoch': 4.36}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.043959140777588, 'eval_runtime': 37.0386, 'eval_samples_per_second': 243.854, 'eval_steps_per_second': 20.33, 'epoch': 4.36}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 5.0134, 'grad_norm': 1.6576942205429077, 'learning_rate': 3.831723278279838e-05, 'epoch': 4.67}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 5.003215312957764, 'eval_runtime': 37.0944, 'eval_samples_per_second': 243.487, 'eval_steps_per_second': 20.3, 'epoch': 4.67}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.983, 'grad_norm': 1.7939783334732056, 'learning_rate': 3.7538173885945775e-05, 'epoch': 4.99}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.964324951171875, 'eval_runtime': 37.3421, 'eval_samples_per_second': 241.872, 'eval_steps_per_second': 20.165, 'epoch': 4.99}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.9113, 'grad_norm': 1.7928218841552734, 'learning_rate': 3.6759114989093176e-05, 'epoch': 5.3}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.932944297790527, 'eval_runtime': 37.0806, 'eval_samples_per_second': 243.577, 'eval_steps_per_second': 20.307, 'epoch': 5.3}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.8946, 'grad_norm': 1.8746986389160156, 'learning_rate': 3.5980835151137426e-05, 'epoch': 5.61}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.900229454040527, 'eval_runtime': 37.1532, 'eval_samples_per_second': 243.102, 'eval_steps_per_second': 20.267, 'epoch': 5.61}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.8683, 'grad_norm': 1.783697485923767, 'learning_rate': 3.520177625428483e-05, 'epoch': 5.92}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.874930381774902, 'eval_runtime': 36.9787, 'eval_samples_per_second': 244.248, 'eval_steps_per_second': 20.363, 'epoch': 5.92}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.8187, 'grad_norm': 1.6635746955871582, 'learning_rate': 3.4423496416329076e-05, 'epoch': 6.23}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.846935749053955, 'eval_runtime': 37.1108, 'eval_samples_per_second': 243.379, 'eval_steps_per_second': 20.291, 'epoch': 6.23}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.7957, 'grad_norm': 1.7984720468521118, 'learning_rate': 3.364443751947647e-05, 'epoch': 6.54}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.824095726013184, 'eval_runtime': 37.0194, 'eval_samples_per_second': 243.98, 'eval_steps_per_second': 20.341, 'epoch': 6.54}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.7783, 'grad_norm': 1.7740974426269531, 'learning_rate': 3.286537862262387e-05, 'epoch': 6.86}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.801535606384277, 'eval_runtime': 37.4415, 'eval_samples_per_second': 241.23, 'eval_steps_per_second': 20.111, 'epoch': 6.86}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.7395, 'grad_norm': 1.8004045486450195, 'learning_rate': 3.208709878466812e-05, 'epoch': 7.17}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.782341003417969, 'eval_runtime': 37.0707, 'eval_samples_per_second': 243.642, 'eval_steps_per_second': 20.313, 'epoch': 7.17}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.7114, 'grad_norm': 2.062527894973755, 'learning_rate': 3.130803988781552e-05, 'epoch': 7.48}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.76706600189209, 'eval_runtime': 36.9379, 'eval_samples_per_second': 244.519, 'eval_steps_per_second': 20.386, 'epoch': 7.48}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.7065, 'grad_norm': 1.9962882995605469, 'learning_rate': 3.052976004985977e-05, 'epoch': 7.79}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.749912261962891, 'eval_runtime': 37.0551, 'eval_samples_per_second': 243.745, 'eval_steps_per_second': 20.321, 'epoch': 7.79}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.6786, 'grad_norm': 1.8524738550186157, 'learning_rate': 2.975070115300717e-05, 'epoch': 8.1}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.733800411224365, 'eval_runtime': 37.0556, 'eval_samples_per_second': 243.742, 'eval_steps_per_second': 20.321, 'epoch': 8.1}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.6408, 'grad_norm': 1.8569316864013672, 'learning_rate': 2.8971642256154568e-05, 'epoch': 8.41}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.719481468200684, 'eval_runtime': 37.0569, 'eval_samples_per_second': 243.733, 'eval_steps_per_second': 20.32, 'epoch': 8.41}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.639, 'grad_norm': 1.78729248046875, 'learning_rate': 2.8193362418198814e-05, 'epoch': 8.73}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.704639434814453, 'eval_runtime': 36.9461, 'eval_samples_per_second': 244.464, 'eval_steps_per_second': 20.381, 'epoch': 8.73}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.6306, 'grad_norm': 1.821032166481018, 'learning_rate': 2.7414303521346212e-05, 'epoch': 9.04}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.694730758666992, 'eval_runtime': 37.042, 'eval_samples_per_second': 243.832, 'eval_steps_per_second': 20.328, 'epoch': 9.04}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.5925, 'grad_norm': 1.9462559223175049, 'learning_rate': 2.6636023683390464e-05, 'epoch': 9.35}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.680616855621338, 'eval_runtime': 36.9092, 'eval_samples_per_second': 244.708, 'eval_steps_per_second': 20.401, 'epoch': 9.35}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.5793, 'grad_norm': 1.925039529800415, 'learning_rate': 2.5856964786537862e-05, 'epoch': 9.66}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.673486232757568, 'eval_runtime': 37.0587, 'eval_samples_per_second': 243.722, 'eval_steps_per_second': 20.319, 'epoch': 9.66}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.5797, 'grad_norm': 1.9093952178955078, 'learning_rate': 2.507790588968526e-05, 'epoch': 9.97}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.660247325897217, 'eval_runtime': 37.1518, 'eval_samples_per_second': 243.11, 'eval_steps_per_second': 20.268, 'epoch': 9.97}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.5404, 'grad_norm': 1.9301847219467163, 'learning_rate': 2.4299626051729513e-05, 'epoch': 10.28}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.655567169189453, 'eval_runtime': 37.0851, 'eval_samples_per_second': 243.548, 'eval_steps_per_second': 20.305, 'epoch': 10.28}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.5374, 'grad_norm': 1.9753071069717407, 'learning_rate': 2.352056715487691e-05, 'epoch': 10.6}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.646126747131348, 'eval_runtime': 37.0222, 'eval_samples_per_second': 243.962, 'eval_steps_per_second': 20.339, 'epoch': 10.6}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.5373, 'grad_norm': 1.9087965488433838, 'learning_rate': 2.274228731692116e-05, 'epoch': 10.91}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.637481212615967, 'eval_runtime': 37.0866, 'eval_samples_per_second': 243.538, 'eval_steps_per_second': 20.304, 'epoch': 10.91}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4964, 'grad_norm': 1.9045668840408325, 'learning_rate': 2.1963228420068558e-05, 'epoch': 11.22}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.630731105804443, 'eval_runtime': 37.0057, 'eval_samples_per_second': 244.071, 'eval_steps_per_second': 20.348, 'epoch': 11.22}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.501, 'grad_norm': 1.861090898513794, 'learning_rate': 2.1184169523215956e-05, 'epoch': 11.53}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.623802185058594, 'eval_runtime': 36.9562, 'eval_samples_per_second': 244.397, 'eval_steps_per_second': 20.375, 'epoch': 11.53}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4999, 'grad_norm': 2.027089834213257, 'learning_rate': 2.0405110626363354e-05, 'epoch': 11.84}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.6139817237854, 'eval_runtime': 37.0312, 'eval_samples_per_second': 243.902, 'eval_steps_per_second': 20.334, 'epoch': 11.84}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4756, 'grad_norm': 1.9271458387374878, 'learning_rate': 1.9626830788407603e-05, 'epoch': 12.15}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.608966827392578, 'eval_runtime': 37.0398, 'eval_samples_per_second': 243.846, 'eval_steps_per_second': 20.329, 'epoch': 12.15}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4597, 'grad_norm': 1.887660026550293, 'learning_rate': 1.8847771891555e-05, 'epoch': 12.46}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.605347156524658, 'eval_runtime': 37.1146, 'eval_samples_per_second': 243.354, 'eval_steps_per_second': 20.288, 'epoch': 12.46}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4628, 'grad_norm': 1.9143131971359253, 'learning_rate': 1.80687129947024e-05, 'epoch': 12.78}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.597010135650635, 'eval_runtime': 37.0835, 'eval_samples_per_second': 243.558, 'eval_steps_per_second': 20.306, 'epoch': 12.78}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4539, 'grad_norm': 1.9500900506973267, 'learning_rate': 1.7289654097849797e-05, 'epoch': 13.09}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.591020584106445, 'eval_runtime': 37.0738, 'eval_samples_per_second': 243.622, 'eval_steps_per_second': 20.311, 'epoch': 13.09}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4265, 'grad_norm': 1.9350168704986572, 'learning_rate': 1.65121533187909e-05, 'epoch': 13.4}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.590110778808594, 'eval_runtime': 37.0802, 'eval_samples_per_second': 243.58, 'eval_steps_per_second': 20.307, 'epoch': 13.4}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4316, 'grad_norm': 2.02030086517334, 'learning_rate': 1.57330944219383e-05, 'epoch': 13.71}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.5829081535339355, 'eval_runtime': 37.0392, 'eval_samples_per_second': 243.849, 'eval_steps_per_second': 20.33, 'epoch': 13.71}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4333, 'grad_norm': 2.0084433555603027, 'learning_rate': 1.4954035525085697e-05, 'epoch': 14.02}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.577147483825684, 'eval_runtime': 36.9794, 'eval_samples_per_second': 244.244, 'eval_steps_per_second': 20.363, 'epoch': 14.02}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4019, 'grad_norm': 2.002607583999634, 'learning_rate': 1.4174976628233093e-05, 'epoch': 14.33}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.572101593017578, 'eval_runtime': 37.1056, 'eval_samples_per_second': 243.414, 'eval_steps_per_second': 20.293, 'epoch': 14.33}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.412, 'grad_norm': 2.082624912261963, 'learning_rate': 1.3396696790277346e-05, 'epoch': 14.65}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.5708513259887695, 'eval_runtime': 37.0564, 'eval_samples_per_second': 243.736, 'eval_steps_per_second': 20.32, 'epoch': 14.65}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.4036, 'grad_norm': 1.9881033897399902, 'learning_rate': 1.2617637893424744e-05, 'epoch': 14.96}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.570649147033691, 'eval_runtime': 37.0535, 'eval_samples_per_second': 243.756, 'eval_steps_per_second': 20.322, 'epoch': 14.96}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3814, 'grad_norm': 1.9801015853881836, 'learning_rate': 1.1838578996572142e-05, 'epoch': 15.27}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.563081741333008, 'eval_runtime': 37.0876, 'eval_samples_per_second': 243.532, 'eval_steps_per_second': 20.303, 'epoch': 15.27}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3843, 'grad_norm': 1.921825885772705, 'learning_rate': 1.105952009971954e-05, 'epoch': 15.58}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.558789253234863, 'eval_runtime': 36.997, 'eval_samples_per_second': 244.128, 'eval_steps_per_second': 20.353, 'epoch': 15.58}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3849, 'grad_norm': 2.0296237468719482, 'learning_rate': 1.0282019320660642e-05, 'epoch': 15.89}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.557994842529297, 'eval_runtime': 36.9159, 'eval_samples_per_second': 244.664, 'eval_steps_per_second': 20.398, 'epoch': 15.89}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3716, 'grad_norm': 2.08455228805542, 'learning_rate': 9.50296042380804e-06, 'epoch': 16.2}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.553534030914307, 'eval_runtime': 37.081, 'eval_samples_per_second': 243.575, 'eval_steps_per_second': 20.307, 'epoch': 16.2}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3648, 'grad_norm': 2.0631179809570312, 'learning_rate': 8.723901526955438e-06, 'epoch': 16.52}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.553905487060547, 'eval_runtime': 37.0772, 'eval_samples_per_second': 243.6, 'eval_steps_per_second': 20.309, 'epoch': 16.52}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3672, 'grad_norm': 1.9733415842056274, 'learning_rate': 7.944842630102836e-06, 'epoch': 16.83}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.548783302307129, 'eval_runtime': 37.0364, 'eval_samples_per_second': 243.868, 'eval_steps_per_second': 20.331, 'epoch': 16.83}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3582, 'grad_norm': 1.9005111455917358, 'learning_rate': 7.166562792147087e-06, 'epoch': 17.14}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.546670436859131, 'eval_runtime': 37.0415, 'eval_samples_per_second': 243.835, 'eval_steps_per_second': 20.329, 'epoch': 17.14}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.348, 'grad_norm': 1.9467021226882935, 'learning_rate': 6.388282954191336e-06, 'epoch': 17.45}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.545351028442383, 'eval_runtime': 36.987, 'eval_samples_per_second': 244.194, 'eval_steps_per_second': 20.359, 'epoch': 17.45}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.351, 'grad_norm': 2.0583221912384033, 'learning_rate': 5.609224057338735e-06, 'epoch': 17.76}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.541982173919678, 'eval_runtime': 37.0109, 'eval_samples_per_second': 244.036, 'eval_steps_per_second': 20.345, 'epoch': 17.76}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3499, 'grad_norm': 2.02461838722229, 'learning_rate': 4.830165160486133e-06, 'epoch': 18.07}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.541180610656738, 'eval_runtime': 37.0994, 'eval_samples_per_second': 243.454, 'eval_steps_per_second': 20.297, 'epoch': 18.07}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3387, 'grad_norm': 1.984906554222107, 'learning_rate': 4.05110626363353e-06, 'epoch': 18.39}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.541054725646973, 'eval_runtime': 36.8656, 'eval_samples_per_second': 244.998, 'eval_steps_per_second': 20.426, 'epoch': 18.39}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3347, 'grad_norm': 1.96657133102417, 'learning_rate': 3.2728264256777816e-06, 'epoch': 18.7}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.539795398712158, 'eval_runtime': 37.023, 'eval_samples_per_second': 243.957, 'eval_steps_per_second': 20.339, 'epoch': 18.7}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.338, 'grad_norm': 1.9721380472183228, 'learning_rate': 2.4937675288251796e-06, 'epoch': 19.01}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.538328170776367, 'eval_runtime': 37.1705, 'eval_samples_per_second': 242.988, 'eval_steps_per_second': 20.258, 'epoch': 19.01}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3294, 'grad_norm': 2.0633718967437744, 'learning_rate': 1.7154876908694297e-06, 'epoch': 19.32}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.537507057189941, 'eval_runtime': 37.0175, 'eval_samples_per_second': 243.993, 'eval_steps_per_second': 20.342, 'epoch': 19.32}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3277, 'grad_norm': 2.0319652557373047, 'learning_rate': 9.364287940168278e-07, 'epoch': 19.63}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.535121440887451, 'eval_runtime': 37.0201, 'eval_samples_per_second': 243.975, 'eval_steps_per_second': 20.34, 'epoch': 19.63}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


{'loss': 4.3303, 'grad_norm': 2.007161855697632, 'learning_rate': 1.5814895606107822e-07, 'epoch': 19.94}


  0%|          | 0/753 [00:00<?, ?it/s]

{'eval_loss': 4.535489559173584, 'eval_runtime': 37.0434, 'eval_samples_per_second': 243.822, 'eval_steps_per_second': 20.327, 'epoch': 19.94}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


{'train_runtime': 20930.0957, 'train_samples_per_second': 73.592, 'train_steps_per_second': 3.066, 'train_loss': 4.774919954361147, 'epoch': 20.0}


TrainOutput(global_step=64180, training_loss=4.774919954361147, metrics={'train_runtime': 20930.0957, 'train_samples_per_second': 73.592, 'train_steps_per_second': 3.066, 'total_flos': 8.950241713717248e+16, 'train_loss': 4.774919954361147, 'epoch': 20.0})

In [16]:
torch.cuda.empty_cache()

In [17]:
# push to hub
trainer.push_to_hub("binhphap5/gpt-small-c4")

model.safetensors:   0%|          | 0.00/282M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/binhphap5/gpt-small-c4/commit/e8d5bec2572a149fe86eb9a6fe8e3d205a0eec3a', commit_message='binhphap5/gpt-small-c4', commit_description='', oid='e8d5bec2572a149fe86eb9a6fe8e3d205a0eec3a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/binhphap5/gpt-small-c4', endpoint='https://huggingface.co', repo_type='model', repo_id='binhphap5/gpt-small-c4'), pr_revision=None, pr_num=None)

## **Inference**

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "binhphap5/gpt-small-c4"

model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [74]:
prompt = "We gotta get at least ten thousand"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

In [75]:
output = model.generate(
    **inputs,
    max_new_tokens=50, # This is just for testing purpose, can be any value but must not exceed (512 - len(prompt))
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    temperature=0.9,
)

In [76]:
print(tokenizer.decode(output[0], skip_special_tokens=True))

 We gotta get at least ten thousand dollars to get a book. The Book of Mormon has had more of a new book, and more of it was available on a library of books for readers. And it was a book read. A book I read recently. Most readers had been reading


In [97]:
import math
import torch
# Shift for labels (causal LM setting: predict token t+1 from token t)
labels = output[:, 1:].clone()
inputs = output[:, :-1].clone()

with torch.no_grad():
    outputs = model(inputs)
    logits = outputs.logits

# Compute log softmax over vocabulary
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
# Gather log-probabilities corresponding to the labels
selected_log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)

# Sum negative log probs → total NLL
nll = -selected_log_probs.sum().item()
num_tokens = labels.numel()
perplexity = math.exp(nll / num_tokens)
perplexity

21.237168770906834