Getting nan values for training and validation loss #620
Hmm interesting. I will be investigating (and hopefully fixing) ml-explore/mlx#896 (comment) tonight, but this could be unrelated. It seems the loss jumps really high really quickly, so there is definitely something going on there. Were you using the exact same configuration before? What was the loss evolution before these commits? This could also be related to #613, but if that's the case the spike points towards the training being in an unstable regime anyway; a better fix may be to reduce the learning rate or the LoRA alpha.
I started off using a configuration of my own on a dataset I've been having difficulty fitting with low training/validation loss for about a week or so. Then, today, I started noticing the NaN values and wanted to rule out whether it was the recent changes or the data I was using, so I tried to mock up the most easily reproducible example here. I'll try lower values of alpha. They were high mostly because I had been having issues fitting the data with the default LoRA values and was looking for values that would fit but also not obliterate the base model's capability, and it seems higher values were needed.
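For context, in the original LoRA formulation the low-rank update is scaled by alpha over the rank before being added to the frozen weight, so raising alpha at a fixed rank proportionally amplifies every adapter update (mlx_lm also exposes a separate `scale` parameter, so its effective scaling may differ from this reference form):

$$W' = W + \frac{\alpha}{r}\,B A$$

This is why unusually high alpha values can push training into the unstable regime mentioned above.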
Oh that's super nice, I didn't see this was on the example data. I can play with it as well. Thanks!
I lowered the alpha and rank to their default values (16 and 8, respectively), and 800 iterations in I'm not seeing NaN values:
I'm running this right now:
So far no NaNs or loss spikes (which is also odd). Are we using the same model?
I was able to reproduce the loss spike when using the same layer keys and learning rate. A smaller learning rate and/or not using all the layers works fine. It looks like the learning rate may just be on the high side. On the other hand, the behavior shouldn't be too different from old versions. The main thing I can think of is that we changed where we cast from low to high precision. @chimezie do you happen to know which version was working for you with that setup?
I went back to mlx==0.6.0 and mlx-lm==0.2.0 and I still see the loss spiking with the high learning rate. I think it might be worth tuning the learning rate with the latest MLX given the changes to casting. I don't know the cause of the NaNs precisely, but the fact that the loss spikes means the run is likely already lost, and it wouldn't surprise me if the low precision is overflowing soon after. Another thing you could do for stability (though it will slow things down) is to quantize the model but use fp32 as the precision for all the weights.
I have a similar problem with NaN values for training and validation loss. I trained my own model and saved it.
@awni I don't remember which version I used the last time I successfully did mlx_lm tuning on the same dataset, but I believe it was around Feb 15th. I'm going to see how much I can get from boosting the alpha and rank values above the defaults (which seem too low for the domain) to help fit the data at reasonable loss values without having to keep the LR so high. I'll also quantize with fp32, which may help towards the same end.
Ok, let's close this for now as I believe this is related to overflowing fp16. @JKwon0331 if you have a NaN issue, please open a new issue and provide some more details, and we can help debug.
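For anyone hitting this later, a minimal illustration (not from the original thread) of how fp16 overflow turns into NaN: fp16 saturates at roughly 65504, so a value that exceeds it becomes inf, and a later inf-minus-inf (e.g. inside softmax or a loss) is NaN.

```python
import mlx.core as mx

# fp16 saturates at ~65504; doubling 60000 overflows to inf
x = mx.array([60000.0], dtype=mx.float16)
y = x * 2
print(y)       # inf
print(y - y)   # inf - inf -> nan, which then propagates through the loss
```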
I'm getting this problem again, this time training a Qwen1.5-14B model quantized with float32 and converted this way:

```shell
% python -m mlx_lm.convert --hf-path Qwen/Qwen1.5-14B -q --dtype float32
```

The git hashes for mlx and mlx-examples are below:

```
mlx-examples % git rev-parse HEAD
c386dd5f5a1c1d40f94d6f3bd7b5bd25929e05aa
mlx % git rev-parse HEAD
bddf23f175726a57f0e443cd45518c0757daa166
```

This is the output of the training that resulted in NaN values:

```
Iter 1: Val loss 2.199, Val took 1281.939s
Iter 151: Train loss 2.200, Learning Rate 2.975e-06, It/sec 0.063, Tokens/sec 209.210, Trained Tokens 503331, Peak mem 102.075 GB
Iter 302: Train loss 2.117, Learning Rate 5.961e-06, It/sec 0.059, Tokens/sec 200.795, Trained Tokens 1012934, Peak mem 143.826 GB
Iter 453: Train loss 2.098, Learning Rate 8.946e-06, It/sec 0.061, Tokens/sec 208.397, Trained Tokens 1527324, Peak mem 143.826 GB
Iter 604: Train loss 2.080, Learning Rate 1.193e-05, It/sec 0.057, Tokens/sec 198.680, Trained Tokens 2054545, Peak mem 143.826 GB
Iter 755: Train loss 2.077, Learning Rate 1.492e-05, It/sec 0.062, Tokens/sec 202.874, Trained Tokens 2546743, Peak mem 143.826 GB
Iter 906: Train loss 2.068, Learning Rate 1.790e-05, It/sec 0.058, Tokens/sec 204.168, Trained Tokens 3074828, Peak mem 143.826 GB
Iter 1057: Train loss 2.043, Learning Rate 2.089e-05, It/sec 0.056, Tokens/sec 195.615, Trained Tokens 3605623, Peak mem 143.826 GB
Iter 1208: Train loss 2.055, Learning Rate 2.387e-05, It/sec 0.062, Tokens/sec 206.878, Trained Tokens 4110642, Peak mem 143.826 GB
Iter 1359: Train loss 2.052, Learning Rate 2.686e-05, It/sec 0.067, Tokens/sec 206.440, Trained Tokens 4575392, Peak mem 143.826 GB
Iter 1510: Train loss 2.029, Learning Rate 2.984e-05, It/sec 0.062, Tokens/sec 207.524, Trained Tokens 5080397, Peak mem 143.826 GB
Iter 1517: Val loss 1.984, Val took 1277.913s
Iter 1661: Train loss 2.026, Learning Rate 2.999e-05, It/sec 0.066, Tokens/sec 218.319, Trained Tokens 5578308, Peak mem 143.826 GB
Iter 1812: Train loss 2.021, Learning Rate 2.997e-05, It/sec 0.064, Tokens/sec 205.973, Trained Tokens 6067475, Peak mem 143.826 GB
Iter 1963: Train loss 2.020, Learning Rate 2.994e-05, It/sec 0.061, Tokens/sec 200.265, Trained Tokens 6560695, Peak mem 143.966 GB
Iter 2114: Train loss 2.038, Learning Rate 2.989e-05, It/sec 0.060, Tokens/sec 207.632, Trained Tokens 7083937, Peak mem 143.966 GB
Iter 2265: Train loss nan, Learning Rate 2.982e-05, It/sec 0.056, Tokens/sec 194.322, Trained Tokens 7609702, Peak mem 160.765 GB
Iter 2416: Train loss nan, Learning Rate 2.974e-05, It/sec 0.062, Tokens/sec 208.390, Trained Tokens 8114928, Peak mem 160.765 GB
Iter 2567: Train loss nan, Learning Rate 2.965e-05, It/sec 0.059, Tokens/sec 207.165, Trained Tokens 8645401, Peak mem 160.765 GB
Iter 2718: Train loss nan, Learning Rate 2.954e-05, It/sec 0.067, Tokens/sec 206.968, Trained Tokens 9110687, Peak mem 160.765 GB
Iter 2869: Train loss nan, Learning Rate 2.942e-05, It/sec 0.062, Tokens/sec 208.907, Trained Tokens 9620573, Peak mem 160.765 GB
Iter 3020: Train loss nan, Learning Rate 2.928e-05, It/sec 0.060, Tokens/sec 208.873, Trained Tokens 10148813, Peak mem 160.765 GB
Iter 3034: Val loss nan, Val took 1295.306s
[..snip..]
```

The LoRA parameters were:

As I indicated in this ticket earlier, I have had to use a learning rate and alpha/rank values a little higher than normal because I was not getting convergence otherwise. Still, this run was using a warmup period with cosine annealing:

```yaml
lr_schedule:
  name: "cosine_decay"
  warmup: 100
  warmup_init: 1e-8
  arguments: [3e-5, 900, 7e-6]
```

I'm having trouble isolating a reproducible run with the LoRA data in the git repo, but I will provide a follow-up if/when I can.
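For reference, a minimal sketch of how a warmup-plus-cosine schedule like the one above can be assembled with mlx.optimizers (mlx_lm builds this from the YAML automatically; this is only an illustration of what those numbers mean, assuming the current schedule helpers):

```python
import mlx.optimizers as optim

# Linear warmup from 1e-8 to 3e-5 over 100 steps, then cosine decay
# from 3e-5 down to 7e-6 over the next 900 steps.
warmup = optim.linear_schedule(1e-8, 3e-5, steps=100)
cosine = optim.cosine_decay(3e-5, 900, 7e-6)
schedule = optim.join_schedules([warmup, cosine], [100])

optimizer = optim.Adam(learning_rate=schedule)
```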
@chimezie I don't see a NaN in that log you shared. What am I missing?
Sorry. I have updated the comment with the training/validation errors, including the NaN values.
Running on the following git hashes, I'm still getting NaN values training Qwen1.5-14B:

```
mlx % git rev-parse HEAD
99abb9eff4779700741c3faa92d7fdcb259e2022
mlx-examples % git rev-parse HEAD
eff6690952847386aa3cc375b4ac83decc886868
```

I tried lowering the learning rate, alpha, and rank as well:

```yaml
learning_rate: 1e-5
lora_layers: 20
lora_parameters:
  alpha: 64
  dropout: 0.3205
  rank: 32
  scale: 10.0
lr_schedule:
  name: cosine_decay
  warmup: 1000
  warmup_init: 1e-8
  arguments: [1e-5, 15175, 7e-6]
```

```
Iter 1: Val loss 2.199, Val took 1283.062s
Iter 151: Train loss 2.171, Learning Rate 1.987e-06, It/sec 0.062, Tokens/sec 209.019, Trained Tokens 507685, Peak mem 99.752 GB
Iter 302: Train loss 2.123, Learning Rate 3.976e-06, It/sec 0.059, Tokens/sec 196.263, Trained Tokens 1006928, Peak mem 143.892 GB
Iter 453: Train loss 2.081, Learning Rate 5.966e-06, It/sec 0.057, Tokens/sec 199.289, Trained Tokens 1530602, Peak mem 143.892 GB
Iter 604: Train loss 2.050, Learning Rate 7.956e-06, It/sec 0.059, Tokens/sec 194.431, Trained Tokens 2025091, Peak mem 143.892 GB
Iter 755: Train loss 2.071, Learning Rate 9.946e-06, It/sec 0.061, Tokens/sec 203.686, Trained Tokens 2532112, Peak mem 143.892 GB
Iter 906: Train loss 2.070, Learning Rate 1.194e-05, It/sec 0.065, Tokens/sec 202.642, Trained Tokens 3005750, Peak mem 143.892 GB
Iter 1057: Train loss 2.071, Learning Rate 1.393e-05, It/sec 0.062, Tokens/sec 203.865, Trained Tokens 3500267, Peak mem 143.892 GB
Iter 1208: Train loss 2.061, Learning Rate 1.592e-05, It/sec 0.060, Tokens/sec 206.005, Trained Tokens 4022766, Peak mem 143.892 GB
Iter 1359: Train loss nan, Learning Rate 1.790e-05, It/sec 0.055, Tokens/sec 193.330, Trained Tokens 4551018, Peak mem 158.316 GB
Iter 1510: Train loss nan, Learning Rate 1.989e-05, It/sec 0.063, Tokens/sec 211.451, Trained Tokens 5054493, Peak mem 158.316 GB
Iter 1517: Val loss nan, Val took 1278.666s
```
Sorry @chimezie, haven't had a chance to debug yet. Is this with the default data set? Can you share the exact training command you are using so I can repro?
@awni I just tried it again using this command line (against my own data):

```
% python -m mlx_lm.lora -c train.yaml
[..snip..]
Iter 1: Val loss 2.199, Val took 1358.465s
Iter 151: Train loss 2.221, Learning Rate 5.893e-07, It/sec 0.055, Tokens/sec 178.991, Trained Tokens 495465, Peak mem 192.693 GB
Iter 302: Train loss 2.211, Learning Rate 1.082e-06, It/sec 0.055, Tokens/sec 185.651, Trained Tokens 1006881, Peak mem 192.693 GB
Iter 453: Train loss 2.145, Learning Rate 1.574e-06, It/sec 0.056, Tokens/sec 189.687, Trained Tokens 1521827, Peak mem 192.693 GB
Iter 604: Train loss 2.142, Learning Rate 2.067e-06, It/sec 0.059, Tokens/sec 199.163, Trained Tokens 2027825, Peak mem 192.693 GB
Iter 755: Train loss 2.096, Learning Rate 2.560e-06, It/sec 0.056, Tokens/sec 188.073, Trained Tokens 2536171, Peak mem 192.693 GB
Iter 906: Train loss 2.097, Learning Rate 3.052e-06, It/sec 0.052, Tokens/sec 186.432, Trained Tokens 3073332, Peak mem 192.693 GB
Iter 1057: Train loss 2.081, Learning Rate 3.545e-06, It/sec 0.055, Tokens/sec 192.338, Trained Tokens 3602905, Peak mem 192.693 GB
Iter 1208: Train loss 2.065, Learning Rate 4.037e-06, It/sec 0.059, Tokens/sec 188.558, Trained Tokens 4084952, Peak mem 192.693 GB
Iter 1359: Train loss 2.077, Learning Rate 4.530e-06, It/sec 0.058, Tokens/sec 191.755, Trained Tokens 4581628, Peak mem 192.693 GB
Iter 1510: Train loss 2.073, Learning Rate 5.022e-06, It/sec 0.054, Tokens/sec 176.646, Trained Tokens 5079093, Peak mem 192.693 GB
Iter 1517: Val loss 2.021, Val took 1346.512s
Iter 1661: Train loss 2.051, Learning Rate 5.515e-06, It/sec 0.064, Tokens/sec 199.175, Trained Tokens 5546092, Peak mem 192.693 GB
Iter 1812: Train loss 2.051, Learning Rate 6.007e-06, It/sec 0.053, Tokens/sec 185.251, Trained Tokens 6075599, Peak mem 192.693 GB
Iter 1963: Train loss nan, Learning Rate 6.500e-06, It/sec 0.048, Tokens/sec 166.686, Trained Tokens 6596895, Peak mem 192.693 GB
```

Below is the configuration that was used (the model reference is a 4-bit, float32-quantized local copy of Qwen1.5-14B):

```yaml
model: "/path/to/raw_models/mlx/Qwen1.5-14B"
train: true
data: "/path/to/corpus/"
seed: 4
batch_size: 8
learning_rate: 1e-5
lora_layers: 20
iters: 15175
val_batches: 189
steps_per_report: 151
steps_per_eval: 1517
save_every: 5000
lora_parameters:
keys: ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"]
rank: 64
alpha: 32
dropout: 0.3205
scale: 10.0
lr_schedule:
name: cosine_decay
warmup: 3035
warmup_init: 1e-7
arguments: [1e-5, 30351, 1e-6] I have yet to reproduce this with the default data set, which is significantly simpler than the proprietary data I'm training with. |
@chimezie without a good way to reproduce this it will be hard to help debug. I can suggest a couple of things that would be really helpful if you are up for it:
@awni I ended up publishing the dataset that seems to most consistently reproduce the NaN loss values. The training data can be saved/downloaded this way:

```python
import json
from datasets import load_dataset

# Take the first 1100 records and hold out 10% for validation.
data = load_dataset('cogbuji/medqa_corpus_en', None, split='train[:1100]')
split = data.train_test_split(test_size=.1)

with open('/tmp/train.jsonl', 'w') as f:
    for entry in split['train']:
        json.dump(entry, f)
        f.write('\n')

with open('/tmp/valid.jsonl', 'w') as f:
    for entry in split['test']:
        json.dump(entry, f)
        f.write('\n')
```

The model this was run against was downloaded/quantized this way:

```shell
% mlx_lm.convert --hf-path mistralai/Mixtral-8x7B-Instruct-v0.1 -q --dtype float32 \
    --mlx-path /path/to/Mixtral-8x7B-Instruct-v0.1 --q-group-size 32
```

Then using the following YAML configuration:

```yaml
model: "/path/to/Mixtral-8x7B-Instruct-v0.1"
train: true
data: "/tmp"
lora_layers: 16
batch_size: 8
learning_rate: 1e-5
lr_schedule:
name: "cosine_decay"
warmup: 100
warmup_init: 1e-7
arguments: [3e-5, 1000, 7e-6] Then the training resulting in NaN values: % mlx_lm.lora -c mlx_error.yaml
Loading configuration file mlx_error.yaml
Loading pretrained model
Trainable parameters: 0.038% (17.834M/46596.297M)
Loading datasets
Training
Starting training..., iters: 1000
Iter 1: Val loss nan, Val took 129.093s
Iter 10: Train loss nan, Learning Rate 2.791e-06, It/sec 0.077, Tokens/sec 193.032, Trained Tokens 25128, Peak mem 74.917 GB
```

Below are the git hashes for mlx and mlx_lm:

```
% git rev-parse HEAD
b0012cdd0f3af3b5643e63da1c6da39610fe63e6
% git rev-parse HEAD
f20e68fcc0eab129911828c00cbeb1c2a5246156
```

Oddly enough, when switching to a smaller model (Qwen/Qwen1.5-0.5B) converted/quantized the same way, I can train without any NaN values.
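A quick way to check whether the quantized weights already produce non-finite values before any training starts (a rough diagnostic sketch, not part of the original report; the model path is the local one from the config above):

```python
import mlx.core as mx
from mlx_lm import load

# Load the locally quantized model and run a single forward pass.
model, tokenizer = load("/path/to/Mixtral-8x7B-Instruct-v0.1")
tokens = mx.array([tokenizer.encode("A short probe sentence.")])
logits = model(tokens)

# True here would mean the forward pass overflows before the optimizer ever runs.
print("NaN in logits:", mx.isnan(logits).any().item())
```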
Thanks for the detailed repro. I am looking now.
That specific case should be fixed shortly in ml-explore/mlx#1028.
I tried with the latest changes (including #1028), but still get NaNs right away.
Did you requantize the model?
I just did and it runs without any NaNs. Thanks.
I am trying to fine-tune Gemma 2B using a summarization dataset from Hugging Face (https://huggingface.co/datasets/pszemraj/govreport-summarization-8192) and always get NaN values for training and validation loss right from the first iteration.
YAML configuration used -
I am new to the MLX framework and would like your insights on how to rectify this issue.
The fix for the test sequence length is in #743. As for the NaN, could you share how you preprocessed the data so I can reproduce it?
I preprocessed the data by converting it into the standard Gemma format for fine-tuning, as follows.
This is the formatting code -
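The original formatting code was not captured above; below is only a hypothetical sketch of what such preprocessing might look like, assuming the dataset's `report`/`summary` columns and the standard Gemma chat turns, with a single `text` field per JSONL record as mlx_lm expects:

```python
import json
from datasets import load_dataset

# Hypothetical preprocessing sketch; field names and the exact prompt are assumptions.
data = load_dataset("pszemraj/govreport-summarization-8192", split="train[:1000]")

def to_gemma_text(example):
    prompt = f"Summarize the following report:\n\n{example['report']}"
    return {
        "text": (
            f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
            f"<start_of_turn>model\n{example['summary']}<end_of_turn>"
        )
    }

with open("train.jsonl", "w") as f:
    for example in data:
        json.dump(to_gemma_text(example), f)
        f.write("\n")
```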
Thanks, I will take a look tomorrow!
@awni Do we have any update regarding the issue?
@mukundsayeeganesh I opened a relevant issue in MLX core, ml-explore/mlx#1084. Once that is fixed the NaNs for your case should be resolved.
Thanks for the update.
Slightly unrelated: your LoRA settings are using a lot of memory. I would consider trying to decrease memory consumption by either:
You can see the peak memory is over 200 GB, which will cause swapping and be very, very slow. Also, I think you will work around the NaN issue if you use a batch size of 1 or 2 for now.
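For illustration only, these are the kinds of config changes that typically reduce peak memory for mlx_lm LoRA runs (the values are placeholders, not a recommendation specific to this run, and `grad_checkpoint` assumes a recent enough mlx_lm):

```yaml
batch_size: 2          # smaller batches cut activation memory roughly linearly
lora_layers: 8         # adapt fewer layers
grad_checkpoint: true  # recompute activations in the backward pass to save memory
```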
Yeah, understood. I will try it out later today and let you know how it works out.
@awni I tried your suggestion: reducing the batch size to 2 worked well, without any NaN values.
I'm getting NaN values for training and validation loss with recent git versions of mlx and mlx_lm (see commit hashes below).
This is the configuration file I'm using:
This is how I'm launching the LoRA training:
The two times I ran it, the NaN values started being reported at different iterations: first at Iter 375, then at Iter 325.
The model is an mlx-quantized HF download of OpenHermes 2.5 Mistral 7B (teknium/OpenHermes-2.5-Mistral-7B).
The git commit hashes are:
```
mlx % git rev-parse HEAD
28fcd2b519f0fabcd681f2c33e14d71983cad819
mlx-examples % git rev-parse HEAD
0ab01b4626cfca974ea8616370da2d0e3254a205
```