# Train GPT-2 on OpenWebText w/ Distributionally Robust Optimization

See python file of the same name as this notebook for the production version.

# Train Model

In [1]:
import huggingface_hub as hf_hub

def read_key(filename):
    with open(filename) as f:
        key = f.read().strip()
    return key

hf_hub.login(token=read_key('huggingface.key'), write_permission=True, add_to_git_credential=True)



Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /home/kdbanman/.cache/huggingface/token
Login successful


In [2]:
from tqdm.notebook import tqdm

In [3]:
from datasets import load_from_disk
from torch.utils.data.dataloader import DataLoader

batch_size = 32

tokenized_datasets = load_from_disk('tokenized-openwebtext')

tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(tokenized_datasets["test"], batch_size=batch_size)

In [4]:
from torch.nn import CrossEntropyLoss
import torch


def dro_loss(inputs, logits, alpha=0.8):
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    
    # Calculate per-token loss
    loss_fct = CrossEntropyLoss(reduce=False)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    
    # Resize and average loss per sample
    loss_per_sample = loss.view(shift_logits.size(0), shift_logits.size(1)).mean(axis=1)

    if alpha < 1.0:
        # Keep only largest alpha-fraction of losses by reweighting smallest (1-alpha)-fraction to zero
        num_samples = len(loss_per_sample)
        num_to_ignore = num_samples - int(num_samples * alpha)

        if num_to_ignore >= 1 or num_to_ignore < num_samples:
            cutoff_value, _cutoff_index = torch.kthvalue(loss_per_sample, num_to_ignore, dim=0)
            loss_per_sample[loss_per_sample < cutoff_value] = 0
        else:
            print("ERROR: crazy reweighting request from the following.  Skipping DRO reweighting.")
            priunt(f'alpha: {alpha}')
            priunt(f'num_samples: {num_samples}')
            priunt(f'num_to_ignore: {num_to_ignore}')
            print(f'losses: {loss_per_sample}')
    
    return loss_per_sample.mean()

In [5]:
weight_decay = 0.1


def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
    params_with_wd, params_without_wd = [], []
    for n, p in model.named_parameters():
        if any(nd in n for nd in no_decay):
            params_without_wd.append(p)
        else:
            params_with_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]

In [6]:
def evaluate(max_eval_batches=None):
    model.eval()
    losses = []

    if max_eval_batches is None:
        max_eval_batches = len(eval_dataloader)
        
    for step, batch in tqdm(
        enumerate(eval_dataloader), total=max_eval_batches
    ):
        with torch.no_grad():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])

        losses.append(accelerator.gather(outputs.loss))
        
        if step >= max_eval_batches:
            break

    if len(losses[0].shape) == 0:
        loss = torch.mean(torch.stack(losses))
    else:
        loss = torch.mean(torch.cat(losses))
        
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()

In [7]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

# For now, the dataset is tokenized in advance using context_length, so this value is fixed
# at tokenization time.  In the future, tokenization should really be streamed.  Then 
# context_length can be varied.
context_length = 256

tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

In [8]:
model = GPT2LMHeadModel(config)

In [9]:
from torch.optim import AdamW

optimizer = AdamW(get_grouped_params(model), lr=5e-4)

In [10]:
from accelerate import Accelerator

accelerator = Accelerator(cpu=False)

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

In [11]:
from transformers import get_scheduler

num_train_epochs = 1
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=1_000,
    num_training_steps=num_training_steps,
)

In [12]:
from huggingface_hub import Repository, get_full_repo_name

model_name = "gpt2-openwebtext-dro-test"
repo_name = get_full_repo_name(model_name)
repo_name

'kdbanman/gpt2-openwebtext-dro-test'

In [13]:
import huggingface_hub

try:
    huggingface_hub.model_info(repo_name)
except huggingface_hub.utils.RepositoryNotFoundError:
    huggingface_hub.create_repo(repo_name)

In [14]:
output_dir = model_name
repo = Repository(output_dir, clone_from=repo_name)

/home/kdbanman/tmp/gpt2-openwebtext-dro-test is already a clone of https://huggingface.co/kdbanman/gpt2-openwebtext-dro-test. Make sure you pull the latest changes with `repo.git_pull()`.


In [15]:
dro_alpha = 0.8

epoch_step_logging_interval = 8

gradient_accumulation_steps = 8
gradient_step_eval_interval = 30
gradient_steps_since_eval = 0
gradient_steps = 0

num_eval_batches = 50

model.train()
for epoch in range(1, num_train_epochs + 1):
    for epoch_step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=len(train_dataloader)
    ):
        logits = model(batch["input_ids"]).logits
        loss = dro_loss(batch["input_ids"], logits, alpha=dro_alpha)
        if epoch_step % epoch_step_logging_interval == 0:
            accelerator.print(
                {
                    "lr": lr_scheduler.get_lr(),
                    "samples/contexts": epoch * epoch_step * batch_size,
                    "samples/tokens": epoch * epoch_step * batch_size * context_length,
                    "gradient_steps": gradient_steps,
                    "loss/train": loss.item() * gradient_accumulation_steps,
                }
            )
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)
        if epoch_step % gradient_accumulation_steps == 0:
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            gradient_steps += 1
            gradient_steps_since_eval += 1
            
        if gradient_steps_since_eval >= gradient_step_eval_interval:
            gradient_steps_since_eval = 0
            eval_loss, perplexity = evaluate(num_eval_batches)
            accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
            model.train()
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
            if accelerator.is_main_process:
                tokenizer.save_pretrained(output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch step {epoch * epoch_step}", blocking=False
                )

accelerator.end_training()

  0%|          | 0/878614 [00:00<?, ?it/s]



{'lr': [0.0, 0.0], 'samples/contexts': 256, 'samples/tokens': 65536, 'gradient_steps': 0, 'loss/train': 71.26820373535156}
{'lr': [5e-07, 5e-07], 'samples/contexts': 512, 'samples/tokens': 131072, 'gradient_steps': 1, 'loss/train': 71.38358306884766}
{'lr': [1e-06, 1e-06], 'samples/contexts': 768, 'samples/tokens': 196608, 'gradient_steps': 2, 'loss/train': 71.17308044433594}
{'lr': [1.5e-06, 1.5e-06], 'samples/contexts': 1024, 'samples/tokens': 262144, 'gradient_steps': 3, 'loss/train': 70.84748840332031}
{'lr': [2e-06, 2e-06], 'samples/contexts': 1280, 'samples/tokens': 327680, 'gradient_steps': 4, 'loss/train': 70.48038482666016}
{'lr': [2.5e-06, 2.5e-06], 'samples/contexts': 1536, 'samples/tokens': 393216, 'gradient_steps': 5, 'loss/train': 69.88656616210938}
{'lr': [3e-06, 3e-06], 'samples/contexts': 1792, 'samples/tokens': 458752, 'gradient_steps': 6, 'loss/train': 69.14768981933594}
{'lr': [3.5e-06, 3.5e-06], 'samples/contexts': 2048, 'samples/tokens': 524288, 'gradient_steps': 

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 9.393847465515137, 'perplexity': 12014.2353515625}




{'lr': [1.5e-05, 1.5e-05], 'samples/contexts': 7936, 'samples/tokens': 2031616, 'gradient_steps': 30, 'loss/train': 61.213165283203125}
{'lr': [1.55e-05, 1.55e-05], 'samples/contexts': 8192, 'samples/tokens': 2097152, 'gradient_steps': 31, 'loss/train': 61.22166061401367}
{'lr': [1.6e-05, 1.6e-05], 'samples/contexts': 8448, 'samples/tokens': 2162688, 'gradient_steps': 32, 'loss/train': 61.35581588745117}
{'lr': [1.65e-05, 1.65e-05], 'samples/contexts': 8704, 'samples/tokens': 2228224, 'gradient_steps': 33, 'loss/train': 61.499610900878906}
{'lr': [1.7000000000000003e-05, 1.7000000000000003e-05], 'samples/contexts': 8960, 'samples/tokens': 2293760, 'gradient_steps': 34, 'loss/train': 61.47796630859375}
{'lr': [1.7500000000000002e-05, 1.7500000000000002e-05], 'samples/contexts': 9216, 'samples/tokens': 2359296, 'gradient_steps': 35, 'loss/train': 60.99376678466797}
{'lr': [1.8e-05, 1.8e-05], 'samples/contexts': 9472, 'samples/tokens': 2424832, 'gradient_steps': 36, 'loss/train': 60.40749

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 8.549030303955078, 'perplexity': 5161.74658203125}




{'lr': [3e-05, 3e-05], 'samples/contexts': 15616, 'samples/tokens': 3997696, 'gradient_steps': 60, 'loss/train': 55.753021240234375}
{'lr': [3.05e-05, 3.05e-05], 'samples/contexts': 15872, 'samples/tokens': 4063232, 'gradient_steps': 61, 'loss/train': 55.97893524169922}
{'lr': [3.1e-05, 3.1e-05], 'samples/contexts': 16128, 'samples/tokens': 4128768, 'gradient_steps': 62, 'loss/train': 55.91229248046875}
{'lr': [3.15e-05, 3.15e-05], 'samples/contexts': 16384, 'samples/tokens': 4194304, 'gradient_steps': 63, 'loss/train': 55.927772521972656}
{'lr': [3.2e-05, 3.2e-05], 'samples/contexts': 16640, 'samples/tokens': 4259840, 'gradient_steps': 64, 'loss/train': 55.96785354614258}
{'lr': [3.2500000000000004e-05, 3.2500000000000004e-05], 'samples/contexts': 16896, 'samples/tokens': 4325376, 'gradient_steps': 65, 'loss/train': 54.23414993286133}
{'lr': [3.3e-05, 3.3e-05], 'samples/contexts': 17152, 'samples/tokens': 4390912, 'gradient_steps': 66, 'loss/train': 54.26689147949219}
{'lr': [3.35e-05

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 7.727158069610596, 'perplexity': 2269.144287109375}




{'lr': [4.4999999999999996e-05, 4.4999999999999996e-05], 'samples/contexts': 23296, 'samples/tokens': 5963776, 'gradient_steps': 90, 'loss/train': 49.69676208496094}
{'lr': [4.55e-05, 4.55e-05], 'samples/contexts': 23552, 'samples/tokens': 6029312, 'gradient_steps': 91, 'loss/train': 50.529052734375}
{'lr': [4.6e-05, 4.6e-05], 'samples/contexts': 23808, 'samples/tokens': 6094848, 'gradient_steps': 92, 'loss/train': 50.180179595947266}
{'lr': [4.65e-05, 4.65e-05], 'samples/contexts': 24064, 'samples/tokens': 6160384, 'gradient_steps': 93, 'loss/train': 50.291297912597656}
{'lr': [4.7000000000000004e-05, 4.7000000000000004e-05], 'samples/contexts': 24320, 'samples/tokens': 6225920, 'gradient_steps': 94, 'loss/train': 49.10736846923828}
{'lr': [4.75e-05, 4.75e-05], 'samples/contexts': 24576, 'samples/tokens': 6291456, 'gradient_steps': 95, 'loss/train': 50.06428527832031}
{'lr': [4.8e-05, 4.8e-05], 'samples/contexts': 24832, 'samples/tokens': 6356992, 'gradient_steps': 96, 'loss/train': 4

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 7.161857604980469, 'perplexity': 1289.3037109375}




{'lr': [6e-05, 6e-05], 'samples/contexts': 30976, 'samples/tokens': 7929856, 'gradient_steps': 120, 'loss/train': 47.732521057128906}
{'lr': [6.05e-05, 6.05e-05], 'samples/contexts': 31232, 'samples/tokens': 7995392, 'gradient_steps': 121, 'loss/train': 48.0483512878418}
{'lr': [6.1e-05, 6.1e-05], 'samples/contexts': 31488, 'samples/tokens': 8060928, 'gradient_steps': 122, 'loss/train': 46.980018615722656}
{'lr': [6.15e-05, 6.15e-05], 'samples/contexts': 31744, 'samples/tokens': 8126464, 'gradient_steps': 123, 'loss/train': 47.76753234863281}
{'lr': [6.2e-05, 6.2e-05], 'samples/contexts': 32000, 'samples/tokens': 8192000, 'gradient_steps': 124, 'loss/train': 47.18766784667969}
{'lr': [6.25e-05, 6.25e-05], 'samples/contexts': 32256, 'samples/tokens': 8257536, 'gradient_steps': 125, 'loss/train': 46.76532745361328}
{'lr': [6.3e-05, 6.3e-05], 'samples/contexts': 32512, 'samples/tokens': 8323072, 'gradient_steps': 126, 'loss/train': 47.926002502441406}
{'lr': [6.35e-05, 6.35e-05], 'samples

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.897471904754639, 'perplexity': 989.7693481445312}




{'lr': [7.5e-05, 7.5e-05], 'samples/contexts': 38656, 'samples/tokens': 9895936, 'gradient_steps': 150, 'loss/train': 44.45376205444336}
{'lr': [7.55e-05, 7.55e-05], 'samples/contexts': 38912, 'samples/tokens': 9961472, 'gradient_steps': 151, 'loss/train': 44.903011322021484}
{'lr': [7.6e-05, 7.6e-05], 'samples/contexts': 39168, 'samples/tokens': 10027008, 'gradient_steps': 152, 'loss/train': 46.237728118896484}
{'lr': [7.65e-05, 7.65e-05], 'samples/contexts': 39424, 'samples/tokens': 10092544, 'gradient_steps': 153, 'loss/train': 45.43748474121094}
{'lr': [7.7e-05, 7.7e-05], 'samples/contexts': 39680, 'samples/tokens': 10158080, 'gradient_steps': 154, 'loss/train': 45.20257568359375}
{'lr': [7.75e-05, 7.75e-05], 'samples/contexts': 39936, 'samples/tokens': 10223616, 'gradient_steps': 155, 'loss/train': 45.0079231262207}
{'lr': [7.8e-05, 7.8e-05], 'samples/contexts': 40192, 'samples/tokens': 10289152, 'gradient_steps': 156, 'loss/train': 43.668128967285156}
{'lr': [7.85e-05, 7.85e-05],

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.719207763671875, 'perplexity': 828.1611328125}




{'lr': [8.999999999999999e-05, 8.999999999999999e-05], 'samples/contexts': 46336, 'samples/tokens': 11862016, 'gradient_steps': 180, 'loss/train': 44.53252029418945}
{'lr': [9.05e-05, 9.05e-05], 'samples/contexts': 46592, 'samples/tokens': 11927552, 'gradient_steps': 181, 'loss/train': 44.493221282958984}
{'lr': [9.1e-05, 9.1e-05], 'samples/contexts': 46848, 'samples/tokens': 11993088, 'gradient_steps': 182, 'loss/train': 44.08295440673828}
{'lr': [9.15e-05, 9.15e-05], 'samples/contexts': 47104, 'samples/tokens': 12058624, 'gradient_steps': 183, 'loss/train': 43.55826187133789}
{'lr': [9.2e-05, 9.2e-05], 'samples/contexts': 47360, 'samples/tokens': 12124160, 'gradient_steps': 184, 'loss/train': 44.5527229309082}
{'lr': [9.25e-05, 9.25e-05], 'samples/contexts': 47616, 'samples/tokens': 12189696, 'gradient_steps': 185, 'loss/train': 44.33708953857422}
{'lr': [9.3e-05, 9.3e-05], 'samples/contexts': 47872, 'samples/tokens': 12255232, 'gradient_steps': 186, 'loss/train': 43.898189544677734}

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.57998514175415, 'perplexity': 720.5286254882812}




{'lr': [0.000105, 0.000105], 'samples/contexts': 54016, 'samples/tokens': 13828096, 'gradient_steps': 210, 'loss/train': 44.20814895629883}
{'lr': [0.0001055, 0.0001055], 'samples/contexts': 54272, 'samples/tokens': 13893632, 'gradient_steps': 211, 'loss/train': 44.350982666015625}
{'lr': [0.000106, 0.000106], 'samples/contexts': 54528, 'samples/tokens': 13959168, 'gradient_steps': 212, 'loss/train': 43.619354248046875}
{'lr': [0.0001065, 0.0001065], 'samples/contexts': 54784, 'samples/tokens': 14024704, 'gradient_steps': 213, 'loss/train': 43.702110290527344}
{'lr': [0.000107, 0.000107], 'samples/contexts': 55040, 'samples/tokens': 14090240, 'gradient_steps': 214, 'loss/train': 41.55384063720703}
{'lr': [0.0001075, 0.0001075], 'samples/contexts': 55296, 'samples/tokens': 14155776, 'gradient_steps': 215, 'loss/train': 43.743412017822266}
{'lr': [0.000108, 0.000108], 'samples/contexts': 55552, 'samples/tokens': 14221312, 'gradient_steps': 216, 'loss/train': 43.105079650878906}
{'lr': [0

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.442998886108398, 'perplexity': 628.2881469726562}




{'lr': [0.00012, 0.00012], 'samples/contexts': 61696, 'samples/tokens': 15794176, 'gradient_steps': 240, 'loss/train': 43.46315002441406}
{'lr': [0.0001205, 0.0001205], 'samples/contexts': 61952, 'samples/tokens': 15859712, 'gradient_steps': 241, 'loss/train': 43.30318832397461}
{'lr': [0.000121, 0.000121], 'samples/contexts': 62208, 'samples/tokens': 15925248, 'gradient_steps': 242, 'loss/train': 41.85728454589844}
{'lr': [0.0001215, 0.0001215], 'samples/contexts': 62464, 'samples/tokens': 15990784, 'gradient_steps': 243, 'loss/train': 42.016998291015625}
{'lr': [0.000122, 0.000122], 'samples/contexts': 62720, 'samples/tokens': 16056320, 'gradient_steps': 244, 'loss/train': 42.71946334838867}
{'lr': [0.0001225, 0.0001225], 'samples/contexts': 62976, 'samples/tokens': 16121856, 'gradient_steps': 245, 'loss/train': 43.09562683105469}
{'lr': [0.000123, 0.000123], 'samples/contexts': 63232, 'samples/tokens': 16187392, 'gradient_steps': 246, 'loss/train': 42.587913513183594}
{'lr': [0.0001

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.312440395355225, 'perplexity': 551.388916015625}




{'lr': [0.000135, 0.000135], 'samples/contexts': 69376, 'samples/tokens': 17760256, 'gradient_steps': 270, 'loss/train': 41.61447525024414}
{'lr': [0.00013550000000000001, 0.00013550000000000001], 'samples/contexts': 69632, 'samples/tokens': 17825792, 'gradient_steps': 271, 'loss/train': 42.551666259765625}
{'lr': [0.00013600000000000003, 0.00013600000000000003], 'samples/contexts': 69888, 'samples/tokens': 17891328, 'gradient_steps': 272, 'loss/train': 41.69309616088867}
{'lr': [0.0001365, 0.0001365], 'samples/contexts': 70144, 'samples/tokens': 17956864, 'gradient_steps': 273, 'loss/train': 42.94802474975586}
{'lr': [0.00013700000000000002, 0.00013700000000000002], 'samples/contexts': 70400, 'samples/tokens': 18022400, 'gradient_steps': 274, 'loss/train': 41.05244064331055}
{'lr': [0.0001375, 0.0001375], 'samples/contexts': 70656, 'samples/tokens': 18087936, 'gradient_steps': 275, 'loss/train': 41.59552764892578}
{'lr': [0.00013800000000000002, 0.00013800000000000002], 'samples/conte

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.202305316925049, 'perplexity': 493.8863220214844}




{'lr': [0.00015, 0.00015], 'samples/contexts': 77056, 'samples/tokens': 19726336, 'gradient_steps': 300, 'loss/train': 40.61223602294922}
{'lr': [0.0001505, 0.0001505], 'samples/contexts': 77312, 'samples/tokens': 19791872, 'gradient_steps': 301, 'loss/train': 42.06130599975586}
{'lr': [0.000151, 0.000151], 'samples/contexts': 77568, 'samples/tokens': 19857408, 'gradient_steps': 302, 'loss/train': 42.35646057128906}
{'lr': [0.0001515, 0.0001515], 'samples/contexts': 77824, 'samples/tokens': 19922944, 'gradient_steps': 303, 'loss/train': 40.227012634277344}
{'lr': [0.000152, 0.000152], 'samples/contexts': 78080, 'samples/tokens': 19988480, 'gradient_steps': 304, 'loss/train': 41.29930877685547}
{'lr': [0.0001525, 0.0001525], 'samples/contexts': 78336, 'samples/tokens': 20054016, 'gradient_steps': 305, 'loss/train': 40.63396453857422}
{'lr': [0.000153, 0.000153], 'samples/contexts': 78592, 'samples/tokens': 20119552, 'gradient_steps': 306, 'loss/train': 41.05976867675781}
{'lr': [0.00015

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.115078926086426, 'perplexity': 452.6317443847656}




{'lr': [0.000165, 0.000165], 'samples/contexts': 84736, 'samples/tokens': 21692416, 'gradient_steps': 330, 'loss/train': 39.905372619628906}
{'lr': [0.0001655, 0.0001655], 'samples/contexts': 84992, 'samples/tokens': 21757952, 'gradient_steps': 331, 'loss/train': 39.95137023925781}
{'lr': [0.00016600000000000002, 0.00016600000000000002], 'samples/contexts': 85248, 'samples/tokens': 21823488, 'gradient_steps': 332, 'loss/train': 40.25444793701172}
{'lr': [0.0001665, 0.0001665], 'samples/contexts': 85504, 'samples/tokens': 21889024, 'gradient_steps': 333, 'loss/train': 39.74456024169922}
{'lr': [0.00016700000000000002, 0.00016700000000000002], 'samples/contexts': 85760, 'samples/tokens': 21954560, 'gradient_steps': 334, 'loss/train': 41.021400451660156}
{'lr': [0.0001675, 0.0001675], 'samples/contexts': 86016, 'samples/tokens': 22020096, 'gradient_steps': 335, 'loss/train': 40.270545959472656}
{'lr': [0.00016800000000000002, 0.00016800000000000002], 'samples/contexts': 86272, 'samples/to

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 6.023135662078857, 'perplexity': 412.8711853027344}




{'lr': [0.00017999999999999998, 0.00017999999999999998], 'samples/contexts': 92416, 'samples/tokens': 23658496, 'gradient_steps': 360, 'loss/train': 40.13685607910156}
{'lr': [0.0001805, 0.0001805], 'samples/contexts': 92672, 'samples/tokens': 23724032, 'gradient_steps': 361, 'loss/train': 39.8623161315918}
{'lr': [0.000181, 0.000181], 'samples/contexts': 92928, 'samples/tokens': 23789568, 'gradient_steps': 362, 'loss/train': 40.77336502075195}
{'lr': [0.0001815, 0.0001815], 'samples/contexts': 93184, 'samples/tokens': 23855104, 'gradient_steps': 363, 'loss/train': 39.20796203613281}
{'lr': [0.000182, 0.000182], 'samples/contexts': 93440, 'samples/tokens': 23920640, 'gradient_steps': 364, 'loss/train': 39.72557067871094}
{'lr': [0.0001825, 0.0001825], 'samples/contexts': 93696, 'samples/tokens': 23986176, 'gradient_steps': 365, 'loss/train': 39.38744354248047}
{'lr': [0.000183, 0.000183], 'samples/contexts': 93952, 'samples/tokens': 24051712, 'gradient_steps': 366, 'loss/train': 39.762

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 5.941521167755127, 'perplexity': 380.5133056640625}




{'lr': [0.00019500000000000002, 0.00019500000000000002], 'samples/contexts': 100096, 'samples/tokens': 25624576, 'gradient_steps': 390, 'loss/train': 39.82495880126953}
{'lr': [0.0001955, 0.0001955], 'samples/contexts': 100352, 'samples/tokens': 25690112, 'gradient_steps': 391, 'loss/train': 39.27484893798828}
{'lr': [0.00019600000000000002, 0.00019600000000000002], 'samples/contexts': 100608, 'samples/tokens': 25755648, 'gradient_steps': 392, 'loss/train': 39.47488021850586}
{'lr': [0.0001965, 0.0001965], 'samples/contexts': 100864, 'samples/tokens': 25821184, 'gradient_steps': 393, 'loss/train': 40.34943389892578}
{'lr': [0.00019700000000000002, 0.00019700000000000002], 'samples/contexts': 101120, 'samples/tokens': 25886720, 'gradient_steps': 394, 'loss/train': 38.55377960205078}
{'lr': [0.0001975, 0.0001975], 'samples/contexts': 101376, 'samples/tokens': 25952256, 'gradient_steps': 395, 'loss/train': 40.489173889160156}
{'lr': [0.00019800000000000002, 0.00019800000000000002], 'sampl

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 5.853573322296143, 'perplexity': 348.4773864746094}




{'lr': [0.00021, 0.00021], 'samples/contexts': 107776, 'samples/tokens': 27590656, 'gradient_steps': 420, 'loss/train': 38.40015411376953}
{'lr': [0.0002105, 0.0002105], 'samples/contexts': 108032, 'samples/tokens': 27656192, 'gradient_steps': 421, 'loss/train': 38.650978088378906}
{'lr': [0.000211, 0.000211], 'samples/contexts': 108288, 'samples/tokens': 27721728, 'gradient_steps': 422, 'loss/train': 38.469512939453125}
{'lr': [0.0002115, 0.0002115], 'samples/contexts': 108544, 'samples/tokens': 27787264, 'gradient_steps': 423, 'loss/train': 38.78327178955078}
{'lr': [0.000212, 0.000212], 'samples/contexts': 108800, 'samples/tokens': 27852800, 'gradient_steps': 424, 'loss/train': 39.07328414916992}
{'lr': [0.0002125, 0.0002125], 'samples/contexts': 109056, 'samples/tokens': 27918336, 'gradient_steps': 425, 'loss/train': 38.0362548828125}
{'lr': [0.000213, 0.000213], 'samples/contexts': 109312, 'samples/tokens': 27983872, 'gradient_steps': 426, 'loss/train': 38.538177490234375}
{'lr': 

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 5.778227806091309, 'perplexity': 323.1859436035156}




{'lr': [0.00022500000000000002, 0.00022500000000000002], 'samples/contexts': 115456, 'samples/tokens': 29556736, 'gradient_steps': 450, 'loss/train': 40.11278533935547}
{'lr': [0.0002255, 0.0002255], 'samples/contexts': 115712, 'samples/tokens': 29622272, 'gradient_steps': 451, 'loss/train': 37.819244384765625}
{'lr': [0.00022600000000000002, 0.00022600000000000002], 'samples/contexts': 115968, 'samples/tokens': 29687808, 'gradient_steps': 452, 'loss/train': 38.31676483154297}
{'lr': [0.0002265, 0.0002265], 'samples/contexts': 116224, 'samples/tokens': 29753344, 'gradient_steps': 453, 'loss/train': 39.00597381591797}
{'lr': [0.00022700000000000002, 0.00022700000000000002], 'samples/contexts': 116480, 'samples/tokens': 29818880, 'gradient_steps': 454, 'loss/train': 38.56154251098633}
{'lr': [0.0002275, 0.0002275], 'samples/contexts': 116736, 'samples/tokens': 29884416, 'gradient_steps': 455, 'loss/train': 38.082969665527344}
{'lr': [0.000228, 0.000228], 'samples/contexts': 116992, 'samp

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 5.69384765625, 'perplexity': 297.0343322753906}




{'lr': [0.00024, 0.00024], 'samples/contexts': 123136, 'samples/tokens': 31522816, 'gradient_steps': 480, 'loss/train': 37.82728576660156}
{'lr': [0.0002405, 0.0002405], 'samples/contexts': 123392, 'samples/tokens': 31588352, 'gradient_steps': 481, 'loss/train': 37.62151336669922}
{'lr': [0.000241, 0.000241], 'samples/contexts': 123648, 'samples/tokens': 31653888, 'gradient_steps': 482, 'loss/train': 38.9891471862793}
{'lr': [0.0002415, 0.0002415], 'samples/contexts': 123904, 'samples/tokens': 31719424, 'gradient_steps': 483, 'loss/train': 37.23066329956055}
{'lr': [0.000242, 0.000242], 'samples/contexts': 124160, 'samples/tokens': 31784960, 'gradient_steps': 484, 'loss/train': 37.507835388183594}
{'lr': [0.00024249999999999999, 0.00024249999999999999], 'samples/contexts': 124416, 'samples/tokens': 31850496, 'gradient_steps': 485, 'loss/train': 36.293312072753906}
{'lr': [0.000243, 0.000243], 'samples/contexts': 124672, 'samples/tokens': 31916032, 'gradient_steps': 486, 'loss/train': 3

  0%|          | 0/50 [00:00<?, ?it/s]

{'loss/eval': 5.631961345672607, 'perplexity': 279.209228515625}




{'lr': [0.000255, 0.000255], 'samples/contexts': 130816, 'samples/tokens': 33488896, 'gradient_steps': 510, 'loss/train': 36.41244125366211}
{'lr': [0.00025550000000000003, 0.00025550000000000003], 'samples/contexts': 131072, 'samples/tokens': 33554432, 'gradient_steps': 511, 'loss/train': 36.874427795410156}
{'lr': [0.000256, 0.000256], 'samples/contexts': 131328, 'samples/tokens': 33619968, 'gradient_steps': 512, 'loss/train': 36.71207046508789}
{'lr': [0.0002565, 0.0002565], 'samples/contexts': 131584, 'samples/tokens': 33685504, 'gradient_steps': 513, 'loss/train': 38.00460433959961}
{'lr': [0.000257, 0.000257], 'samples/contexts': 131840, 'samples/tokens': 33751040, 'gradient_steps': 514, 'loss/train': 37.89382553100586}
{'lr': [0.0002575, 0.0002575], 'samples/contexts': 132096, 'samples/tokens': 33816576, 'gradient_steps': 515, 'loss/train': 37.34954833984375}
{'lr': [0.00025800000000000004, 0.00025800000000000004], 'samples/contexts': 132352, 'samples/tokens': 33882112, 'gradien

KeyboardInterrupt: 