In [1]:
from model import Transformer, ModelConfig
from trainer import Trainer, TrainerConfig, DataLoader

from transformers import AutoTokenizer
import torch

torch.set_float32_matmul_precision('high')
torch.cuda.empty_cache()

tokenizer_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.pad_token = tokenizer.eos_token

checkpoint_path = './model_testing'
continue_train = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_config = TrainerConfig(
    vocab_size = 50368,
    num_epochs = 1,

    use_ddp = True,
    use_moe = True,
    use_lossfreebalance = False,
    clean_cuda_cache = True,
    use_compile = True,
    use_dtype = "bfloat16",

    seed = 42,
    max_seq_len = 128,
    batch_size = , 
    accumulation_steps = 0,
    
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    learning_rate = 4e-4,
    betas = (0.90, 0.97),
    update_rate = 5e-6,

    val_ratio = 0.005,
    steps_for_eval = 20,
    eval_interval = 20,

    mlm_probability = 0.30,

    checkpoints_frequency = 3500,
    path_to_checkpoints = "./model_testing",

    tokenized_dataset_path = "",
    hf_dataset_name = "allenai/c4",
    hf_dataset_config = "en",
    hf_dataset_split = "train",
    hf_text_field = "text",
    hf_add_eos = True,
    hf_cache_dir = "./.cache/hf",
    hf_tokenized_path = "./.cache/tokenized",
    hf_num_proc = 64,
    eval_log_file = "log/eval.txt",
    use_wandb = True,
    wandb_project = "forschungsprojekt",
    wandb_run_name = "bert-moe",
)

In [5]:
config = ModelConfig(
        vocab_size = 50368,

        num_dims = 768,
        num_heads = 12,
        num_kv_heads = 12,
        num_layers = 22,
        ffn_hidden_dims = 512 * 2,

        layernorm_eps = 1e-6,

        attention_probs_dropout_prob = 0.0,
        attn_qkv_bias = False,
        attn_out_bias = False,
        attn_out_dropout_prob = 0.0,
        global_attn_every_n_layers = 3,
        sliding_window = 128,
        rotary_emb_base = 10000,
        #local_attn_rotary_emb_base = 10000,
    
        context_len = 128,
        
        use_cache = False,
        use_flash = True,
        use_moe = True,

        moe_num_experts = 4,
        moe_routed_experts = 2,
        moe_eps = 1e-6,
        moe_aux_loss_coef = 0.01,
        moe_shared_experts = 1,
        use_lossfreebalance = True,
)

In [6]:
model = Transformer(config)
if continue_train:
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    state_dict = checkpoint['model']
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith("_orig_mod."):
            new_state_dict[k[len("_orig_mod."):]] = v 
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict, strict=False)

model.to(device)

Unable to import flash_attn; defaulting FlexBERT attention implementation to PyTorch's SDPA kernel. This requires padding and unpadding inputs, which will add some overhead.
SDPA attention is being used without an attention mask. Including padding in the  attention calculation may cause differences from the Flash Attention implementation.


ValueError: Sliding window is not implemented for the PyTorch SDPA path. Use the FA2 backend.

In [None]:
data_loader = DataLoader(train_config, tokenizer=tokenizer)
trainer = Trainer(train_config, model, tokenizer)
trainer.train(data_loader)

Total tokens loaded: 953,856,000
Device: cuda:0
Model's trainable params: 236.93M
Tokens per step: 262144
use torch.compile(): True
Use MoE: Yes 
Number of experts: 4
Number of used experts during inference: 2
Method of aux_loss: default
Number of parameters will be used during inference: 170.85M


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ubuntu/.netrc


Epoch: 0 | Step: 0 | loss: 11.0879 | acc: 0.0000 | norm: 10.0949 | lr: 2.1046831955922866e-05 | tok/s: 8589.258888476856
Epoch: 0 | Step: 1 | loss: 10.5767 | acc: 0.0434 | norm: 6.7429 | lr: 2.2093663911845734e-05 | tok/s: 43202.74503844035
Epoch: 0 | Step: 2 | loss: 10.3104 | acc: 0.0514 | norm: 5.3676 | lr: 2.3140495867768598e-05 | tok/s: 43762.14520914456
Epoch: 0 | Step: 3 | loss: 10.1071 | acc: 0.0494 | norm: 4.3395 | lr: 2.4187327823691462e-05 | tok/s: 45143.02149718191
Epoch: 0 | Step: 4 | loss: 9.9509 | acc: 0.0504 | norm: 3.7357 | lr: 2.523415977961433e-05 | tok/s: 41992.96584693851
Epoch: 0 | Step: 5 | loss: 9.8359 | acc: 0.0516 | norm: 3.3117 | lr: 2.628099173553719e-05 | tok/s: 43084.130623821016
Epoch: 0 | Step: 6 | loss: 9.7416 | acc: 0.0523 | norm: 2.8601 | lr: 2.7327823691460058e-05 | tok/s: 44413.80966560781
Epoch: 0 | Step: 7 | loss: 9.6563 | acc: 0.0509 | norm: 2.5794 | lr: 2.837465564738292e-05 | tok/s: 43724.41909069658
Epoch: 0 | Step: 8 | loss: 9.6114 | acc: 0.05