Getting nan values for training and validation loss #620

Closed
chimezie opened this issue Mar 26, 2024 · 34 comments

@chimezie
Contributor

chimezie commented Mar 26, 2024

I'm getting nan values for training and validation loss with recent git versions of mlx and mlx_lm (see commit hashes below).

This is the configuration file I'm using:

config.yaml:
model: "/path/to/mlx/model"

train: true
lora_layers: 16
batch_size: 4
iters: 2596
steps_per_report: 25
steps_per_eval: 259
val_batches: 32
learning_rate: 7e-5
seed: 4
lora_parameters:
  keys: ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"]
  rank: 64
  alpha: 128
  dropout: 0.0
  scale: 10.0

This is how I'm launching the LoRA training:

% python -m mlx_lm.lora --data mlx-examples/lora/data -c config.yaml                                  
Loading configuration file config.yaml
Loading pretrained model
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Trainable parameters: 0.376% (27.263M/7241.748M)
Loading datasets
Training
Starting training..., iters: 2596
Iter 1: Val loss 2.381, Val took 16.942s
Iter 25: Train loss 8.108, Learning Rate 7.000e-05, It/sec 0.984, Tokens/sec 384.079, Trained Tokens 9761, Peak mem 6.906 GB
Iter 50: Train loss 6.509, Learning Rate 7.000e-05, It/sec 0.934, Tokens/sec 364.859, Trained Tokens 19531, Peak mem 7.058 GB
Iter 75: Train loss 6.027, Learning Rate 7.000e-05, It/sec 0.908, Tokens/sec 359.912, Trained Tokens 29444, Peak mem 8.382 GB
Iter 100: Train loss 5.575, Learning Rate 7.000e-05, It/sec 0.925, Tokens/sec 366.800, Trained Tokens 39358, Peak mem 8.906 GB
Iter 100: Saved adapter weights to checkpoints/100_MrGrammaticaOntology-corpus-pretraining.npz.
Iter 125: Train loss 5.393, Learning Rate 7.000e-05, It/sec 0.952, Tokens/sec 364.491, Trained Tokens 48927, Peak mem 8.906 GB
Iter 150: Train loss 5.236, Learning Rate 7.000e-05, It/sec 0.880, Tokens/sec 354.642, Trained Tokens 59006, Peak mem 8.906 GB
Iter 175: Train loss 5.164, Learning Rate 7.000e-05, It/sec 0.963, Tokens/sec 363.904, Trained Tokens 68453, Peak mem 8.906 GB
Iter 200: Train loss 5.453, Learning Rate 7.000e-05, It/sec 0.969, Tokens/sec 365.781, Trained Tokens 77893, Peak mem 8.906 GB
Iter 200: Saved adapter weights to checkpoints/200_MrGrammaticaOntology-corpus-pretraining.npz.
Iter 225: Train loss 6.007, Learning Rate 7.000e-05, It/sec 0.905, Tokens/sec 364.433, Trained Tokens 87963, Peak mem 8.906 GB
Iter 250: Train loss 6.085, Learning Rate 7.000e-05, It/sec 0.934, Tokens/sec 365.032, Trained Tokens 97732, Peak mem 8.906 GB
Iter 259: Val loss 6.359, Val took 17.220s
Iter 275: Train loss 6.106, Learning Rate 7.000e-05, It/sec 1.399, Tokens/sec 537.295, Trained Tokens 107330, Peak mem 8.906 GB
Iter 300: Train loss 6.236, Learning Rate 7.000e-05, It/sec 0.974, Tokens/sec 364.823, Trained Tokens 116697, Peak mem 8.906 GB
Iter 300: Saved adapter weights to checkpoints/300_adapter.npz.
Iter 325: Train loss nan, Learning Rate 7.000e-05, It/sec 0.926, Tokens/sec 357.395, Trained Tokens 126344, Peak mem 8.906 GB
Iter 350: Train loss nan, Learning Rate 7.000e-05, It/sec 0.903, Tokens/sec 368.899, Trained Tokens 136561, Peak mem 8.906 GB
Iter 375: Train loss nan, Learning Rate 7.000e-05, It/sec 0.924, Tokens/sec 370.353, Trained Tokens 146581, Peak mem 8.906 GB
Iter 400: Train loss nan, Learning Rate 7.000e-05, It/sec 0.893, Tokens/sec 368.682, Trained Tokens 156901, Peak mem 8.986 GB
Iter 400: Saved adapter weights to checkpoints/400_adapter.npz.

The two times I ran it, the NaN values started being reported at different iterations: first at iteration 375, then at 325.

The model is an mlx-quantized HF download of OpenHermes 2.5 Mistral 7B (teknium/OpenHermes-2.5-Mistral-7B)

The git commit hashes are:

mlx % git rev-parse HEAD
28fcd2b519f0fabcd681f2c33e14d71983cad819

mlx-examples % git rev-parse HEAD              
0ab01b4626cfca974ea8616370da2d0e3254a205
@angeloskath
Member

Hmm interesting. I will be investigating (and hopefully fixing) ml-explore/mlx#896 (comment) tonight. But this could be unrelated. It seems the loss jumps really high really quickly, so there is definitely something going on there. Were you using the exact same configuration before? What was the loss evolution before these commits?

This could also be related to #613, but if that's the case, I think it points towards the training being in an unstable regime anyway. Maybe reducing alpha or scale would help?

@chimezie
Contributor Author

I started off using a configuration of my own on a dataset I've been having difficulty fitting with low training/validation loss for about a week or so. Then, today, I started noticing the NaN values and wanted to rule out whether they came from recent changes or from the data I was using, so I tried to mock up the most easily reproducible example here.

I'll try lower values of alpha. They were high mostly because I had been having issues fitting the data with the default LoRA values; I was looking for values that would fit the data without obliterating the base model's capabilities, and it seemed higher values were needed.

@angeloskath
Member

Oh that's super nice, I didn't see this was on the example data. I can play with it as well. Thanks!

@chimezie
Contributor Author

I lowered the alpha and rank to their default values (16 and 8, respectively) and, 800 iterations in, I'm not seeing NaN values. The adjusted settings and training log follow:
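As a sketch (assuming the keys list is unchanged from the config above), the adjusted lora_parameters block becomes:

lora_parameters:
  keys: ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"]
  rank: 8
  alpha: 16
  dropout: 0.0
  scale: 10.0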

Training
Starting training..., iters: 2596
Iter 1: Val loss 2.086, Val took 17.177s
Iter 25: Train loss 1.593, Learning Rate 7.000e-05, It/sec 0.999, Tokens/sec 387.697, Trained Tokens 9699, Peak mem 8.104 GB
Iter 50: Train loss 1.406, Learning Rate 7.000e-05, It/sec 0.942, Tokens/sec 369.778, Trained Tokens 19515, Peak mem 8.104 GB
Iter 75: Train loss 3.316, Learning Rate 7.000e-05, It/sec 0.933, Tokens/sec 366.098, Trained Tokens 29327, Peak mem 8.104 GB
Iter 100: Train loss 6.437, Learning Rate 7.000e-05, It/sec 0.921, Tokens/sec 371.347, Trained Tokens 39407, Peak mem 8.626 GB
Iter 100: Saved adapter weights to checkpoints/100_adapters.npz.
Iter 125: Train loss 5.478, Learning Rate 7.000e-05, It/sec 0.952, Tokens/sec 379.485, Trained Tokens 49374, Peak mem 8.626 GB
Iter 150: Train loss 4.857, Learning Rate 7.000e-05, It/sec 0.973, Tokens/sec 367.105, Trained Tokens 58811, Peak mem 8.626 GB
Iter 175: Train loss 4.874, Learning Rate 7.000e-05, It/sec 0.959, Tokens/sec 371.460, Trained Tokens 68497, Peak mem 8.626 GB
Iter 200: Train loss 4.604, Learning Rate 7.000e-05, It/sec 0.943, Tokens/sec 363.135, Trained Tokens 78122, Peak mem 8.626 GB
Iter 200: Saved adapter weights to checkpoints/200_adapters.npz.
Iter 225: Train loss 4.430, Learning Rate 7.000e-05, It/sec 0.913, Tokens/sec 370.260, Trained Tokens 88265, Peak mem 8.626 GB
Iter 250: Train loss 3.892, Learning Rate 7.000e-05, It/sec 0.954, Tokens/sec 361.173, Trained Tokens 97727, Peak mem 8.626 GB
Iter 259: Val loss 4.187, Val took 17.317s
Iter 275: Train loss 3.681, Learning Rate 7.000e-05, It/sec 1.523, Tokens/sec 573.098, Trained Tokens 107132, Peak mem 8.626 GB
Iter 300: Train loss 3.763, Learning Rate 7.000e-05, It/sec 0.952, Tokens/sec 371.709, Trained Tokens 116893, Peak mem 8.626 GB
Iter 300: Saved adapter weights to checkpoints/300_adapters.npz.
Iter 325: Train loss 3.751, Learning Rate 7.000e-05, It/sec 0.920, Tokens/sec 369.209, Trained Tokens 126930, Peak mem 8.707 GB
Iter 350: Train loss 3.546, Learning Rate 7.000e-05, It/sec 0.924, Tokens/sec 368.282, Trained Tokens 136898, Peak mem 8.707 GB
Iter 375: Train loss 3.306, Learning Rate 7.000e-05, It/sec 0.948, Tokens/sec 374.389, Trained Tokens 146775, Peak mem 8.707 GB
Iter 400: Train loss 3.552, Learning Rate 7.000e-05, It/sec 0.916, Tokens/sec 370.725, Trained Tokens 156896, Peak mem 8.707 GB
Iter 400: Saved adapter weights to checkpoints/400_adapters.npz.
Iter 425: Train loss 3.236, Learning Rate 7.000e-05, It/sec 0.966, Tokens/sec 364.650, Trained Tokens 166331, Peak mem 8.707 GB
Iter 450: Train loss 3.056, Learning Rate 7.000e-05, It/sec 0.956, Tokens/sec 364.666, Trained Tokens 175865, Peak mem 8.707 GB
Iter 475: Train loss 2.656, Learning Rate 7.000e-05, It/sec 0.907, Tokens/sec 369.480, Trained Tokens 186046, Peak mem 8.707 GB
Iter 500: Train loss 2.088, Learning Rate 7.000e-05, It/sec 0.970, Tokens/sec 364.845, Trained Tokens 195454, Peak mem 8.707 GB
Iter 500: Saved adapter weights to checkpoints/500_adapters.npz.
Iter 518: Val loss 1.685, Val took 17.378s
Iter 525: Train loss 1.645, Learning Rate 7.000e-05, It/sec 3.284, Tokens/sec 1274.637, Trained Tokens 205156, Peak mem 8.707 GB
Iter 550: Train loss 1.382, Learning Rate 7.000e-05, It/sec 0.943, Tokens/sec 363.792, Trained Tokens 214797, Peak mem 8.707 GB
Iter 575: Train loss 1.415, Learning Rate 7.000e-05, It/sec 0.924, Tokens/sec 372.226, Trained Tokens 224868, Peak mem 8.707 GB
Iter 600: Train loss 1.271, Learning Rate 7.000e-05, It/sec 0.943, Tokens/sec 367.870, Trained Tokens 234621, Peak mem 8.707 GB
Iter 600: Saved adapter weights to checkpoints/600_adapters.npz.
Iter 625: Train loss 1.350, Learning Rate 7.000e-05, It/sec 0.920, Tokens/sec 371.598, Trained Tokens 244716, Peak mem 8.707 GB
Iter 650: Train loss 1.149, Learning Rate 7.000e-05, It/sec 0.927, Tokens/sec 366.723, Trained Tokens 254601, Peak mem 8.707 GB
Iter 675: Train loss 1.151, Learning Rate 7.000e-05, It/sec 0.967, Tokens/sec 373.507, Trained Tokens 264253, Peak mem 8.707 GB
Iter 700: Train loss 1.127, Learning Rate 7.000e-05, It/sec 0.951, Tokens/sec 372.192, Trained Tokens 274039, Peak mem 8.707 GB
Iter 700: Saved adapter weights to checkpoints/700_adapters.npz.
Iter 725: Train loss 1.008, Learning Rate 7.000e-05, It/sec 0.987, Tokens/sec 372.797, Trained Tokens 283484, Peak mem 8.707 GB
Iter 750: Train loss 0.962, Learning Rate 7.000e-05, It/sec 0.946, Tokens/sec 366.860, Trained Tokens 293181, Peak mem 8.707 GB
Iter 775: Train loss 1.021, Learning Rate 7.000e-05, It/sec 0.931, Tokens/sec 357.233, Trained Tokens 302772, Peak mem 8.707 GB
Iter 777: Val loss 1.158, Val took 17.297s
Iter 800: Train loss 0.988, Learning Rate 7.000e-05, It/sec 1.018, Tokens/sec 413.309, Trained Tokens 312918, Peak mem 8.707 GB

@awni
Member

awni commented Mar 26, 2024

I'm running this right now:

python -m mlx_lm.lora --model  mlx-community/Nous-Hermes-2-Mistral-7B-DPO-4bit-MLX --data ../lora/data --train 

So far no NaNs or loss spikes (which is also odd).

Are we using the same model?

@awni
Member

awni commented Mar 26, 2024

I was able to reproduce the loss spike when using the same layer keys and learning rate. A smaller learning rate and/or not using all the layers works fine.

It looks like the learning rate may just be on the high side. On the other hand, the behavior shouldn't be too different from old versions. The main thing I can think of is that we changed where we cast from low to high precision.
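For example, a more conservative setup might look like this (values are illustrative, not from a specific run):

learning_rate: 1e-5
lora_parameters:
  keys: ["self_attn.q_proj", "self_attn.v_proj"]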

@chimezie do you happen to know which versions it was working for you with that setup?

@awni
Member

awni commented Mar 26, 2024

I went back to mlx==0.6.0 and mlx-lm==0.2.0 and I still see the loss spiking with the high learning rate.

I think it might be worth tuning the learning rate with the latest MLX given the changes to casting. I don't know the cause of the NaNs precisely, but the fact that the loss spikes means the run is likely already lost, and it wouldn't surprise me if the low precision is overflowing soon after.

Another thing you could do for stability (though it will slow things down) is to quantize the model but use fp32 as the precision for all the weights.
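For example, converting with --dtype float32 keeps the non-quantized weights in fp32 (the model path is a placeholder):

% python -m mlx_lm.convert --hf-path <hf-model-or-path> -q --dtype float32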

@JKwon0331

I have a similar problem with NaN values for training and validation loss.

I trained my own model and saved it. After loading the weights, it showed NaN values for the losses, which did not appear during the training phase. Could you help me?

@awni
Member

awni commented Mar 26, 2024

  • Please provide some steps to reproduce
  • Please provide information on your versions / environment.

@chimezie
Contributor Author

@awni I don't remember which version I used the last time I successfully did mlx_lm tuning on the same dataset, but I believe it was around Feb 15th. I'm going to see how much I can get from boosting alpha and rank values above the defaults (which seem too low for this domain) to fit the data at reasonable loss values without needing such a high learning rate. I'll also quantize with fp32, which may help towards the same end.

@awni
Member

awni commented Mar 26, 2024

Ok let's close this for now as I believe this is related to overflowing fp16.

@JKwon0331 if you have a NaN issue please open a new issue and provide some more details, we can help debug
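As a debugging aid for the fp16-overflow hypothesis, here is a minimal sketch of how one might scan a model's parameters for values that would overflow in half precision (the model object is assumed to be an already-loaded mlx_lm model, and the threshold is arbitrary):

import mlx.core as mx
from mlx.utils import tree_flatten

FP16_MAX = 65504.0  # largest finite float16 value

def report_overflow(model):
    # Walk every parameter; flag NaNs and magnitudes near the fp16 limit,
    # which would overflow when cast down to half precision.
    for name, param in tree_flatten(model.parameters()):
        p32 = param.astype(mx.float32)
        if mx.isnan(p32).any().item():
            print(f"{name}: contains NaN")
        elif mx.abs(p32).max().item() > 0.9 * FP16_MAX:
            print(f"{name}: near the fp16 limit")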

awni closed this as completed Mar 26, 2024
@chimezie
Contributor Author

chimezie commented Apr 9, 2024

I'm getting this problem again. This time I'm training a Qwen1.5-14B model quantized with float32 and converted this way:

% python -m mlx_lm.convert --hf-path Qwen/Qwen1.5-14B -q --dtype float32

The git hashes for mlx and mlx-examples are below:

 mlx-examples % git rev-parse HEAD                  
c386dd5f5a1c1d40f94d6f3bd7b5bd25929e05aa
 mlx % git rev-parse HEAD
bddf23f175726a57f0e443cd45518c0757daa166

This is the output of the training that resulted in NaN values:

Iter 1: Val loss 2.199, Val took 1281.939s
Iter 151: Train loss 2.200, Learning Rate 2.975e-06, It/sec 0.063, Tokens/sec 209.210, Trained Tokens 503331, Peak mem 102.075 GB
Iter 302: Train loss 2.117, Learning Rate 5.961e-06, It/sec 0.059, Tokens/sec 200.795, Trained Tokens 1012934, Peak mem 143.826 GB
Iter 453: Train loss 2.098, Learning Rate 8.946e-06, It/sec 0.061, Tokens/sec 208.397, Trained Tokens 1527324, Peak mem 143.826 GB
Iter 604: Train loss 2.080, Learning Rate 1.193e-05, It/sec 0.057, Tokens/sec 198.680, Trained Tokens 2054545, Peak mem 143.826 GB
Iter 755: Train loss 2.077, Learning Rate 1.492e-05, It/sec 0.062, Tokens/sec 202.874, Trained Tokens 2546743, Peak mem 143.826 GB
Iter 906: Train loss 2.068, Learning Rate 1.790e-05, It/sec 0.058, Tokens/sec 204.168, Trained Tokens 3074828, Peak mem 143.826 GB
Iter 1057: Train loss 2.043, Learning Rate 2.089e-05, It/sec 0.056, Tokens/sec 195.615, Trained Tokens 3605623, Peak mem 143.826 GB
Iter 1208: Train loss 2.055, Learning Rate 2.387e-05, It/sec 0.062, Tokens/sec 206.878, Trained Tokens 4110642, Peak mem 143.826 GB
Iter 1359: Train loss 2.052, Learning Rate 2.686e-05, It/sec 0.067, Tokens/sec 206.440, Trained Tokens 4575392, Peak mem 143.826 GB
Iter 1510: Train loss 2.029, Learning Rate 2.984e-05, It/sec 0.062, Tokens/sec 207.524, Trained Tokens 5080397, Peak mem 143.826 GB
Iter 1517: Val loss 1.984, Val took 1277.913s
Iter 1661: Train loss 2.026, Learning Rate 2.999e-05, It/sec 0.066, Tokens/sec 218.319, Trained Tokens 5578308, Peak mem 143.826 GB
Iter 1812: Train loss 2.021, Learning Rate 2.997e-05, It/sec 0.064, Tokens/sec 205.973, Trained Tokens 6067475, Peak mem 143.826 GB
Iter 1963: Train loss 2.020, Learning Rate 2.994e-05, It/sec 0.061, Tokens/sec 200.265, Trained Tokens 6560695, Peak mem 143.966 GB
Iter 2114: Train loss 2.038, Learning Rate 2.989e-05, It/sec 0.060, Tokens/sec 207.632, Trained Tokens 7083937, Peak mem 143.966 GB
Iter 2265: Train loss nan, Learning Rate 2.982e-05, It/sec 0.056, Tokens/sec 194.322, Trained Tokens 7609702, Peak mem 160.765 GB
Iter 2416: Train loss nan, Learning Rate 2.974e-05, It/sec 0.062, Tokens/sec 208.390, Trained Tokens 8114928, Peak mem 160.765 GB
Iter 2567: Train loss nan, Learning Rate 2.965e-05, It/sec 0.059, Tokens/sec 207.165, Trained Tokens 8645401, Peak mem 160.765 GB
Iter 2718: Train loss nan, Learning Rate 2.954e-05, It/sec 0.067, Tokens/sec 206.968, Trained Tokens 9110687, Peak mem 160.765 GB
Iter 2869: Train loss nan, Learning Rate 2.942e-05, It/sec 0.062, Tokens/sec 208.907, Trained Tokens 9620573, Peak mem 160.765 GB
Iter 3020: Train loss nan, Learning Rate 2.928e-05, It/sec 0.060, Tokens/sec 208.873, Trained Tokens 10148813, Peak mem 160.765 GB
Iter 3034: Val loss nan, Val took 1295.306s
[..snip..]

The lora parameters were

  • alpha: 128
  • dropout: 0.3205
  • rank: 64
  • scale: 10.0

As I indicated earlier in this ticket, I have had to use a learning rate and alpha/rank values a little higher than normal because I was not otherwise getting convergence during training. Still, this run was using a warmup period with cosine annealing:

lr_schedule:
  name: "cosine_decay"
  warmup: 100
  warmup_init: 1e-8
  arguments: [3e-5, 900, 7e-6]
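For reference, that block corresponds roughly to the following sketch, assuming mlx.optimizers' linear_schedule, cosine_decay, and join_schedules helpers:

import mlx.optimizers as optim

# 100 warmup steps from 1e-8 up to 3e-5, then cosine decay
# from 3e-5 down to 7e-6 over 900 steps
warmup = optim.linear_schedule(1e-8, 3e-5, steps=100)
decay = optim.cosine_decay(3e-5, 900, 7e-6)
lr = optim.join_schedules([warmup, decay], [100])
optimizer = optim.Adam(learning_rate=lr)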

I'm having trouble isolating a reproducible run with the LoRA example data in git, but I will provide a follow-up if/when I can.

@awni
Member

awni commented Apr 9, 2024

@chimezie I don't see a NaN in that log you shared. What am I missing?

@chimezie
Contributor Author

chimezie commented Apr 9, 2024

Sorry. I have updated the comment with the training/validation losses, including the NaN values.

@chimezie
Contributor Author

Running on the following git hashes, I'm still getting NaN values training Qwen1.5-14B.

mlx % git rev-parse HEAD
99abb9eff4779700741c3faa92d7fdcb259e2022
mlx-examples % git rev-parse HEAD
eff6690952847386aa3cc375b4ac83decc886868

I tried lowering the learning rate, alpha, and rank as well:

learning_rate: 1e-5
lora_layers: 20
lora_parameters:
  alpha: 64
  dropout: 0.3205
  rank: 32
  scale: 10.0

lr_schedule:
  name: cosine_decay
  warmup: 1000
  warmup_init: 1e-8
  arguments: [1e-5, 15175, 7e-6]
Iter 1: Val loss 2.199, Val took 1283.062s
Iter 151: Train loss 2.171, Learning Rate 1.987e-06, It/sec 0.062, Tokens/sec 209.019, Trained Tokens 507685, Peak mem 99.752 GB
Iter 302: Train loss 2.123, Learning Rate 3.976e-06, It/sec 0.059, Tokens/sec 196.263, Trained Tokens 1006928, Peak mem 143.892 GB
Iter 453: Train loss 2.081, Learning Rate 5.966e-06, It/sec 0.057, Tokens/sec 199.289, Trained Tokens 1530602, Peak mem 143.892 GB
Iter 604: Train loss 2.050, Learning Rate 7.956e-06, It/sec 0.059, Tokens/sec 194.431, Trained Tokens 2025091, Peak mem 143.892 GB
Iter 755: Train loss 2.071, Learning Rate 9.946e-06, It/sec 0.061, Tokens/sec 203.686, Trained Tokens 2532112, Peak mem 143.892 GB
Iter 906: Train loss 2.070, Learning Rate 1.194e-05, It/sec 0.065, Tokens/sec 202.642, Trained Tokens 3005750, Peak mem 143.892 GB
Iter 1057: Train loss 2.071, Learning Rate 1.393e-05, It/sec 0.062, Tokens/sec 203.865, Trained Tokens 3500267, Peak mem 143.892 GB
Iter 1208: Train loss 2.061, Learning Rate 1.592e-05, It/sec 0.060, Tokens/sec 206.005, Trained Tokens 4022766, Peak mem 143.892 GB
Iter 1359: Train loss nan, Learning Rate 1.790e-05, It/sec 0.055, Tokens/sec 193.330, Trained Tokens 4551018, Peak mem 158.316 GB
Iter 1510: Train loss nan, Learning Rate 1.989e-05, It/sec 0.063, Tokens/sec 211.451, Trained Tokens 5054493, Peak mem 158.316 GB
Iter 1517: Val loss nan, Val took 1278.666s

@awni
Member

awni commented Apr 11, 2024

Sorry @chimezie haven't had a chance to debug yet. Is this with the default data set? Can you share the exact training command you are using so I can repro?

@chimezie
Contributor Author

@awni I just tried it again using this command line (against my own data):

% python -m mlx_lm.lora -c train.yaml
[..snip..]
Iter 1: Val loss 2.199, Val took 1358.465s
Iter 151: Train loss 2.221, Learning Rate 5.893e-07, It/sec 0.055, Tokens/sec 178.991, Trained Tokens 495465, Peak mem 192.693 GB
Iter 302: Train loss 2.211, Learning Rate 1.082e-06, It/sec 0.055, Tokens/sec 185.651, Trained Tokens 1006881, Peak mem 192.693 GB
Iter 453: Train loss 2.145, Learning Rate 1.574e-06, It/sec 0.056, Tokens/sec 189.687, Trained Tokens 1521827, Peak mem 192.693 GB
Iter 604: Train loss 2.142, Learning Rate 2.067e-06, It/sec 0.059, Tokens/sec 199.163, Trained Tokens 2027825, Peak mem 192.693 GB
Iter 755: Train loss 2.096, Learning Rate 2.560e-06, It/sec 0.056, Tokens/sec 188.073, Trained Tokens 2536171, Peak mem 192.693 GB
Iter 906: Train loss 2.097, Learning Rate 3.052e-06, It/sec 0.052, Tokens/sec 186.432, Trained Tokens 3073332, Peak mem 192.693 GB
Iter 1057: Train loss 2.081, Learning Rate 3.545e-06, It/sec 0.055, Tokens/sec 192.338, Trained Tokens 3602905, Peak mem 192.693 GB
Iter 1208: Train loss 2.065, Learning Rate 4.037e-06, It/sec 0.059, Tokens/sec 188.558, Trained Tokens 4084952, Peak mem 192.693 GB
Iter 1359: Train loss 2.077, Learning Rate 4.530e-06, It/sec 0.058, Tokens/sec 191.755, Trained Tokens 4581628, Peak mem 192.693 GB
Iter 1510: Train loss 2.073, Learning Rate 5.022e-06, It/sec 0.054, Tokens/sec 176.646, Trained Tokens 5079093, Peak mem 192.693 GB
Iter 1517: Val loss 2.021, Val took 1346.512s
Iter 1661: Train loss 2.051, Learning Rate 5.515e-06, It/sec 0.064, Tokens/sec 199.175, Trained Tokens 5546092, Peak mem 192.693 GB
Iter 1812: Train loss 2.051, Learning Rate 6.007e-06, It/sec 0.053, Tokens/sec 185.251, Trained Tokens 6075599, Peak mem 192.693 GB
Iter 1963: Train loss nan, Learning Rate 6.500e-06, It/sec 0.048, Tokens/sec 166.686, Trained Tokens 6596895, Peak mem 192.693 GB

Below is the configuration that was used (the model reference is to a 4 bit float32 quantized local copy of Qwen1.5-14B):

model: "/path/to/raw_models/mlx/Qwen1.5-14B"
train: true
data: "/path/to/corpus/"
seed: 4
batch_size: 8
learning_rate: 1e-5
lora_layers: 20

iters: 15175
val_batches: 189
steps_per_report: 151
steps_per_eval: 1517
save_every: 5000

lora_parameters:
  keys: ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"]
  rank: 64
  alpha: 32
  dropout: 0.3205
  scale: 10.0

lr_schedule:
  name: cosine_decay
  warmup: 3035
  warmup_init: 1e-7
  arguments: [1e-5, 30351, 1e-6]

I have yet to reproduce this with the default data set, which is significantly simpler than the proprietary data I'm training with.

@awni
Member

awni commented Apr 12, 2024

@chimezie without a good way to reproduce this it will be hard to help debug.

I can suggest a couple things that would be really helpful if you are up for it:

  • Try to reduce the time to reproduce the NaN. If you can get it to show up with a small fp32 model, for example, that would be really useful.
  • Try to reproduce on an open dataset that you are comfortable sharing so we can reproduce the issue ourselves and debug it.

@chimezie
Contributor Author

chimezie commented Apr 22, 2024

@awni I ended up publishing the dataset that seems to most consistently reproduce the NaN loss values.

The training data can be saved/downloaded this way:

import json
from datasets import load_dataset

# Take the first 1,100 training records and hold out 10% for validation
data = load_dataset('cogbuji/medqa_corpus_en', None, split='train[:1100]')
split = data.train_test_split(test_size=.1)

# Write each split as a JSON Lines file for mlx_lm.lora
with open('/tmp/train.jsonl', 'w') as f:
    for entry in split['train']:
        json.dump(entry, f)
        f.write('\n')

with open('/tmp/valid.jsonl', 'w') as f:
    for entry in split['test']:
        json.dump(entry, f)
        f.write('\n')

The model this was run against was downloaded/quantized this way:

% mlx_lm.convert --hf-path mistralai/Mixtral-8x7B-Instruct-v0.1 -q --dtype float32 \
                 --mlx-path /path/to/Mixtral-8x7B-Instruct-v0.1  --q-group-size 32 

Then using the following YAML configuration:

model: "/path/to/Mixtral-8x7B-Instruct-v0.1"
train: true
data: "/tmp"
lora_layers: 16
batch_size: 8
learning_rate: 1e-5

lr_schedule:
  name: "cosine_decay"
  warmup: 100
  warmup_init: 1e-7
  arguments: [3e-5, 1000, 7e-6]

The training then resulted in NaN values:

% mlx_lm.lora -c mlx_error.yaml
Loading configuration file mlx_error.yaml
Loading pretrained model
Trainable parameters: 0.038% (17.834M/46596.297M)
Loading datasets
Training
Starting training..., iters: 1000
Iter 1: Val loss nan, Val took 129.093s
Iter 10: Train loss nan, Learning Rate 2.791e-06, It/sec 0.077, Tokens/sec 193.032, Trained Tokens 25128, Peak mem 74.917 GB

Below are the git hashes for mlx and mlx_lm:

% git rev-parse HEAD
b0012cdd0f3af3b5643e63da1c6da39610fe63e6
% git rev-parse HEAD
f20e68fcc0eab129911828c00cbeb1c2a5246156

Oddly enough, when switching to a smaller model (Qwen/Qwen1.5-0.5B) converted/quantized the same way, I can train without any NaN values.

@awni
Member

awni commented Apr 23, 2024

Thanks for the detailed repro. I am looking now.

@awni
Member

awni commented Apr 23, 2024

That specific case should be fixed shortly in ml-explore/mlx#1028

@chimezie
Contributor Author

I tried with the latest changes (including #1028), but I still get NaNs right away.

@awni
Member

awni commented Apr 24, 2024

Did you requantize the model?

@chimezie
Contributor Author

I just did and it runs without any NaNs. Thanks.

@mukundsayeeganesh

I am trying to fine-tune Gemma 2B using the summarization dataset from Hugging Face (https://huggingface.co/datasets/pszemraj/govreport-summarization-8192) and always get NaN values for training and validation loss right from the first iteration.

Loading configuration file config/lora_config_default.yml
Loading pretrained model
Trainable parameters: 0.033% (0.819M/2506.172M)
Loading datasets
Training
Starting training..., iters: 2
Iter 1: Val loss nan, Val took 281.075s
Iter 2: Train loss nan, Learning Rate 1.000e-06, It/sec 0.100, Tokens/sec 449.192, Trained Tokens 44928, Peak mem 214.779 GB
Iter 2: Val loss nan, Val took 59.269s
Saved final adapter weights to adapters/adapters.safetensors.
Testing
[WARNING] Some sequences are longer than 2048 tokens. The longest sentence 6517 will be truncated to 2048. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 2048 tokens. The longest sentence 3402 will be truncated to 2048. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 2048 tokens. The longest sentence 6839 will be truncated to 2048. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 2048 tokens. The longest sentence 8160 will be truncated to 2048. Consider pre-splitting your data to save memory.
[WARNING] Some sequences are longer than 2048 tokens. The longest sentence 5607 will be truncated to 2048. Consider pre-splitting your data to save memory.
Test loss 2.797, Test ppl 16.388.

YAML configuration used -

# The path to the local model directory or Hugging Face repo.
model: "google/gemma-2b-it"
data: "/Users/aifocal/Documents/Workspace/Finetuning/llm_finetuning/data"
lora_layers: 16
batch_size: 4

# Iterations to train for.
iters: 2

# Adam learning rate.
learning_rate: 1e-5

# Load path to resume training with the given adapter weights.
resume_adapter_file: null

# Save/load path for the trained adapter weights.
adapter_file: "adapters.npz"

max_seq_length: 8192
grad_checkpoint: false

# LoRA parameters can only be specified in a config file
lora_parameters:
  keys: ["self_attn.q_proj", "self_attn.v_proj"]
  rank: 8
  alpha: 16.0
  scale: 10.0
  dropout: 0.0

I am new to the MLX framework and would like your insights on how to rectify this issue.
Also, I get max-token warnings during testing even though max_seq_length is set to 8192 in the configuration. Can you help me resolve these issues?

@awni
Member

awni commented May 1, 2024

The fix for the test sequence length is in #743.

As for the NaN, could you share how you preprocessed the data so I can reproduce it?

@mukundsayeeganesh

I preprocessed the data by converting it into the standard Gemma format for fine-tuning, as follows:

<bos><start_of_turn>user
## Instruction
You're a proficient summarizing tool, designed to condense information efficiently while maintaining politeness. Please summarize the text clearly and succinctly.
## User
There are some similarities in how Medicare pays ASCs and hospital outpatient departments for the procedures they perform. However, the methods used by CMS to calculate the payment rates in each system, as well as the mechanisms used to revise the Medicare payment rates, differ. In 1980, legislation was enacted that enabled ASCs to bill Medicare for certain surgical procedures provided to Medicare beneficiaries. Under the ASC payment system, Medicare pays a predetermined, and generally all- inclusive, amount per procedure to the facility. The approximately 2,500 surgical procedures that ASCs may bill for under Medicare are assigned to one of nine payment groups that contain procedures with similar costs, but not necessarily clinical similarities. All procedures assigned to one payment group are paid at the same rate. Under the Medicare payment system, when more than one procedure is performed at the same time, the ASC receives a payment for each of the procedures. However, the procedure that has the highest payment rate receives 100 percent of the applicable payment, and each additional procedure receives 50 percent of the applicable payment. The Medicare payment for a procedure performed at an ASC is intended to cover the direct costs for a procedure, such as nursing and technician services, drugs, medical and surgical supplies and equipment, anesthesia materials, and diagnostic services (including imaging services), and the indirect costs associated with the procedure, including use of the facility and related administrative services. The ASC payment for a procedure does not include payment for implantable devices or prosthetics related to the procedure; ASCs may bill separately for those items. In addition, the payment to the ASC does not include payment for professional services associated with the procedure; the physician who performs the procedure and the anesthesiologist or anesthetist bill Medicare directly for their services. Finally, the ASC payment does not include payment for certain other services that are not directly related to performing the procedure and do not occur during the time that the procedure takes place, such as some laboratory, X-ray, and other diagnostic tests. Because these additional services are not ASC procedures, they may be performed by another provider. In those cases, Medicare makes payments to those providers for the additional services. For example, a laboratory service needed to evaluate a tissue sample removed during an ASC procedure is not included in the ASC payment. The provider that evaluated the tissue sample would bill and receive payment from Medicare for that service. Because ASCs receive one inclusive payment for the procedure performed and its associated services, such as drugs, they generally include on their Medicare claim only the procedure performed. In 1997, legislation was enacted that required the implementation of a prospective payment system for hospital outpatient departments; the OPPS was implemented in August 2000. Although ASCs perform only procedures, hospital outpatient departments provide a much broader array of services, including diagnostic services, such as X-rays and laboratory tests, and emergency room and clinic visits. Each of the approximately 5,500 services, including procedures, that hospital outpatient departments perform is assigned to one of over 800 APC groups with other services with clinical and cost similarities for payment under the OPPS. All services assigned to one APC group are paid the same rate. 
Similar to ASCs, when hospitals perform multiple procedures at the same time, they receive 100 percent of the applicable payment for the procedure that has the highest payment rate, and 50 percent of the applicable payment for each additional procedure, subject to certain exceptions. Like payments to ASCs, payment for a procedure under the OPPS is intended to cover the costs of the use of the facility, nursing and technician services, most drugs, medical and surgical supplies and equipment, anesthesia materials, and administrative costs. Medicare payment to a hospital for a procedure does not include professional services for physicians or other nonphysician practitioners. These services are paid for separately by Medicare. However, there are some differences between ASC and OPPS payments for procedures. Under the OPPS, hospital outpatient departments generally may not bill separately for implantable devices related to the procedure, but they may bill separately for additional services that are directly related to the procedure, such as certain drugs and diagnostic services, including X-rays. Hospital outpatient departments also may bill separately for additional services that are not directly related to the procedure and do not occur during the procedure, such as laboratory services to evaluate a tissue sample. Because they provide a broader array of services, and because CMS has encouraged hospitals to report all services provided during a procedure on their Medicare claims for rate-setting purposes, hospital claims may provide more detail about the services delivered during a procedure than ASC claims do. CMS set the initial 1982 ASC payment rates based on cost and charge data from 40 ASCs. At that time, there were about 125 ASCs in operation. Procedures were placed into four payment groups, and all procedures in a group were paid the same rate. When the ASC payment system was first established, federal law required CMS to review the payment rates periodically. In 1986, CMS conducted an ASC survey to gather cost and charge data. In 1990, using these data, CMS revised the payment rates and increased the number of payment groups to eight. A ninth payment group was established in 1991. These groups are still in use, although some procedures have been added to or deleted from the ASC-approved list. Although payments have not been revised using ASC cost data since 1990, the payment rates have been periodically updated for inflation. In 1994, Congress required that CMS conduct a survey of ASC costs no later than January 1, 1995, and thereafter every 5 years, to revise ASC payment rates. CMS conducted a survey in 1994 to collect ASC cost data. In 1998, CMS proposed revising ASC payment rates based on the 1994 survey data and assigned procedures performed at ASCs into payment groups that were comparable to the payment groups it was developing for the same procedures under the OPPS. However, CMS did not implement the proposal, and, as a result, the ASC payment system was not revised using the 1994 data. In 2003, MMA eliminated the requirement to conduct ASC surveys every 5 years and required CMS to implement a revised ASC payment system no later than January 1, 2008. During the course of our work, in August 2006, CMS published a proposed rule that would revise the ASC payment system effective January 1, 2008. In this proposed rule, CMS bases the revised ASC payment rates on the OPPS APC groups. However, the payment rates would be lower for ASCs. 
The initial OPPS payment rates, implemented in August 2000, were based on hospitals' 1996 costs. To determine the OPPS payment rates, CMS first calculates each hospital's cost for each service by multiplying the charge for that service by a cost-to-charge ratio computed from the hospital's most recently reported data. After calculating the cost of each service for each hospital, the services are grouped by their APC assignment, and a median cost for each APC group is calculated from the median costs of all services assigned to it. Using the median cost, CMS assigns each APC group a weight based on its median cost relative to the median cost of all other APCs. To obtain a payment rate for each APC group, CMS multiplies the relative weight by a factor that converts it to a dollar amount. Beginning in 2002, as required by law, the APC group payment rates have been revised annually based on the latest charge and cost data. In addition, the payment rates for services paid under the OPPS receive an annual inflation update. We found many similarities in the additional services provided by ASCs and hospital outpatient departments with the top 20 procedures. Of the additional services billed with a procedure, few resulted in an additional payment in one setting but not the other. Hospitals were paid for some of the related additional services they billed with the procedures. In the ASC setting, other providers billed Medicare for these services and received payment for them. In our analysis of Medicare claims, we found many similarities in the additional services billed in the ASC or hospital outpatient department setting with the top 20 procedures. The similar additional services are illustrated in the following four categories of services: additional procedures, laboratory services, radiology services, and anesthesia services. First, one or more additional procedures was billed with a procedure performed in either the ASC or hospital outpatient department setting for 14 of the top 20 procedures. The proportion of time each additional procedure was billed in each setting was similar. For example, when a hammertoe repair procedure was performed, our analysis indicated that another procedure to correct a bunion was billed 11 percent of the time in the ASC setting, and in the hospital outpatient setting, the procedure to correct a bunion was billed 13 percent of the time. Similarly, when a diagnostic colonoscopy was performed, an upper gastrointestinal (GI) endoscopy was billed 11 percent of the time in the ASC setting, and in the hospital setting, the upper GI endoscopy was billed 12 percent of the time. For 11 of these 14 procedures, the proportion of time each additional procedure was billed differed by less than 10 percentage points between the two settings. For the 3 remaining procedures, the percentage of time that an additional procedure was billed did not vary by more than 25 percentage points between the two settings. See appendix III for a complete list of the additional procedures billed and the proportion of time they were billed in each setting. Second, laboratory services were billed with 10 of the top 20 procedures in the hospital outpatient department setting and 7 of the top 20 procedures in the ASC setting. While these services were almost always billed by the hospital in the outpatient setting, they were typically not billed by the ASCs. 
These laboratory services were present in our analysis in the ASC setting because they were performed and billed by another Medicare part B provider. Third, four different radiology services were billed with 8 of the top 20 procedures. Radiology services were billed with 5 procedures in the ASC setting and with 8 procedures in the hospital outpatient department setting. The radiology services generally were included on the hospital outpatient department bills but rarely were included on the ASC bills. Similar to laboratory services, hospital outpatient departments billed for radiology services that they performed in addition to the procedures. When radiology services were billed with procedures in the ASC setting, these services generally were performed and billed by another part B provider. Fourth, anesthesia services were billed with 17 of the top 20 procedures in either the ASC or hospital outpatient settings and with 14 procedures in both settings. In virtually every case in the ASC setting, and most cases in the hospital outpatient department setting, these services were billed by another part B provider. According to our analysis, ASCs did not generally include any services other than the procedures they performed on their bills. However, in the hospital outpatient setting, some additional services were included on the hospitals' bills. We believe this is a result of the structure of the two payment systems. As ASCs generally receive payment from Medicare only for procedures, they typically include only those procedures on their bills. In contrast, hospital outpatient departments' bills often include many of the individual items or services they provide as a part of a procedure because CMS has encouraged them to do so, whether the items or services are included in the OPPS payment or paid separately. With the exception of additional procedures, there were few separate payments that could be made for additional services provided with the top 20 procedures because most of the services in our analysis were included in the Medicare payment to the ASC or hospital. Under both the Medicare ASC and OPPS payment systems, when more than one procedure is performed at the same time, the facility receives 100 percent of the applicable payment for the procedure that has the highest payment rate and 50 percent of the applicable payment for each additional procedure. As this policy is applicable to both settings, for those instances in our analysis when an additional procedure was performed with one of the top 20 procedures in either setting, the ASC or hospital outpatient department received 100 percent of the payment for the procedure with the highest payment rate and 50 percent of the payment for each lesser paid procedure. Individual drugs were billed by hospital outpatient departments for most of the top 20 procedures, although they were not present on the claims from ASCs, likely because ASCs generally cannot receive separate Medicare payments for individual drugs. However, none of the individual drugs billed by the hospital outpatient departments in our analysis resulted in an additional payment to the hospitals. In each case, the cost of the particular drug was included in the Medicare payment for the procedure. In the case of the laboratory services billed with procedures in the ASC and hospital outpatient department settings, those services were not costs included in the payment for the procedure in either setting and were paid separately in each case. 
For both settings, the payment was made to the provider that performed the service. In the case of the hospital outpatient department setting, the payment was generally made to the hospital, while, for procedures performed at ASCs, payment was made to another provider who performed the service. Of the four radiology services in our analysis, three were similar to the laboratory services in that they are not included in the cost of the procedure and are separately paid services under Medicare. Therefore, when hospitals provided these services, they received payment for them. In the ASC setting, these services were typically billed by a provider other than the ASC, and the provider received payment for them. The fourth radiology service is included in the payment for the procedure with which it was associated. Therefore, no separate payment was made to either ASCs or hospital outpatient departments. With regard to anesthesia services, most services were billed by and paid to a provider other than an ASC or hospital. As a group, the costs of procedures performed in ASCs have a relatively consistent relationship with the costs of the APC groups to which they would be assigned under the OPPS. That is, the APC groups accurately reflect the relative costs of procedures performed in ASCs. We found that the ASC-to-APC cost ratios were more tightly distributed around their median cost ratio than the OPPS-to-APC cost ratios were around their median cost ratio. Specifically, 45 percent of all procedures in our analysis fell within 0.10 points of the ASC-to-APC median cost ratio, and 33 percent of procedures fell within 0.10 points of the OPPS-to-APC median cost ratio. However, the costs of procedures in ASCs are substantially lower than costs for the same procedures in the hospital outpatient setting. The APC groups reflect the relative costs of procedures provided by ASCs as well as they reflect the relative costs of procedures provided in the hospital outpatient department setting. In our analysis, we listed the procedures performed at ASCs and calculated the ratio of the cost of each procedure to the cost of the APC group to which it would have been assigned, referred to as the ASC-to-APC cost ratio. We then calculated similar cost ratios for the same procedures exclusively within the OPPS. To determine an OPPS-to-APC cost ratio, we divided individual procedures' median costs, as calculated by CMS for the OPPS, by the median cost of their APC group. Our analysis of the cost ratios showed that the ASC-to-APC cost ratios were more tightly distributed around their median than were the OPPS-to-APC cost ratios; that is, there were more of them closer to the median. Specifically, 45 percent of procedures performed in ASCs fell within a 0.10 point range of the ASC-to-APC median cost ratio, and 33 percent of those procedures fell within a 0.10 point range of the OPPS-to-APC median cost ratio in the hospital outpatient department setting (see figs. 1 and 2). Therefore, there is less variation in the ASC setting between individual procedures' costs and the costs of their assigned APC groups than there is in the hospital outpatient department setting. From this outcome, we determined that the OPPS APC groups could be used to pay for procedures in ASCs. The median costs of procedures performed in ASCs were generally lower than the median costs of their corresponding APC group under the OPPS. Among all procedures in our analysis, the median ASC-to-APC cost ratio was 0.39. 
The ASC-to-APC cost ratios ranged from 0.02 to 3.34. When weighted by Medicare volume based on 2004 claims data, the median ASC- to-APC cost ratio was 0.84. We determined that the median OPPS-to-APC cost ratio was 1.04. This analysis shows that when compared to the median cost of the same APC group, procedures performed in ASCs had substantially lower costs than when those same procedures were performed in hospital outpatient departments. Generally, there are many similarities between the additional services provided in ASCs and hospital outpatient departments with one of the top 20 procedures, and few resulted in an additional Medicare payment to ASCs or hospital outpatient departments. Although costs for individual procedures vary, in general, the median costs for procedures are lower in ASCs, relative to the median costs of their APC groups, than the median costs for the same procedures in the hospital outpatient department setting. The APC groups in the OPPS reflect the relative costs of procedures performed in ASCs in the same way that they reflect the relative costs of the same procedures when they are performed in hospital outpatient departments. Therefore, the APC groups could be applied to procedures performed in ASCs, and the OPPS could be used as the basis for an ASC payment system, eliminating the need for ASC surveys and providing for an annual revision of the ASC payment groups. We recommend that the Administrator of CMS implement a payment system for procedures performed in ASCs based on the OPPS. The Administrator should take into account the lower relative costs of procedures performed in ASCs compared to hospital outpatient departments in determining ASC payment rates. We received written comments on a draft of this report from CMS (see app. IV). We also received oral comments from external reviewers representing two ASC industry organizations, AAASC and FASA. In commenting on a draft of this report, CMS stated that our recommendation is consistent with its August 2006 proposed revisions to the ASC payment system. Industry representatives who reviewed a draft of this report did not agree or disagree with our recommendation for executive action. They did, however, provide several comments on the draft report. The industry representatives noted that we did not analyze the survey results to examine differences in per-procedure costs among single-specialty and multi-specialty ASCs. Regarding this comment, we initially considered developing our survey sample stratified by ASC specialty type. However, because accurate data identifying ASCs' specialties do not exist, we were unable to stratify our survey sample by specialty type. The industry representatives asked us to provide more explanation in our scope and methodology regarding our development of a relative weight scale for Medicare ASC-approved procedures to capture the general variation in resources associated with performing different procedures. We expanded the discussion of how we developed the relative weight scale in our methodology section. Reviewers also made technical comments, which we incorporated where appropriate. We are sending a copy of this report to the Administrator of CMS and appropriate congressional committees. The report is available at no charge on GAO's Web site at http://www.gao.gov. We will also make copies available to others on request. If you or your staff members have any questions about this report, please contact me at (202) 512-7119 or kingk@gao.gov. 
Contact points for our Offices of Congressional Relations and Public Affairs may be found on the last page of this report. GAO staff members who made significant contributions to this report are listed in appendix V. The Medicare payment rates for ambulatory surgical centers (ASC), along with those of other facilities, are adjusted to account for the variation in labor costs across the country. To calculate payment rates for individual ASCs, the Centers for Medicare & Medicaid Services (CMS) calculates the share of total costs that are labor-related and then adjusts ASCs' labor- related share of costs based on a wage index calculated for specific geographic areas across the country. The wage index reflects how the average wage for health care personnel in each geographic area compares to the national average health care personnel wage. The geographic areas are intended to represent the separate labor markets in which health care facilities compete for employees. In setting the initial ASC payment rates for 1982, CMS determined from the first survey of ASCs that one-third of their costs were labor-related. The labor-related costs included employee salaries and fringe benefits, contractual personnel, and owners' compensation for duties performed for the facility. To determine the payment rates for each individual ASC, CMS multiplied one-third of the payment rate for each procedure--the labor- related portion--by the local area wage index. Each ASC received the base payment rate for two-thirds of the payment rate--the nonlabor-related portion--for each procedure. The sum of the labor-related and nonlabor- related portions equaled each ASC's payment rate for each procedure. In 1990, when CMS revised the payment system based on a 1986 ASC survey, CMS found ASCs' average labor-related share of costs to be 34.45 percent and used this percentage as the labor-related portion of the payment rate. In a 1998 proposed rule, CMS noted that ASCs' share of labor-related costs as calculated from the 1994 ASC cost survey had increased to an average of 37.66 percent, slightly higher than the percentage calculated from the 1986 survey. However, CMS did not implement the 1998 proposal. Currently, the labor-related proportion of costs from CMS's 1986 survey, 34.45 percent, is used for calculating ASC payment rates. Using 2004 cost data we received from 290 ASCs that responded to our survey request for information, we determined that the mean labor-related proportion of costs was 50 percent, and the range of the labor-related costs for the middle 50 percent of our ASC facilities was 43 percent to 57 percent of total costs. To compare the delivery of procedures between ASCs and hospital outpatient departments, we analyzed Medicare claims data from 2003. To compare the relative costs of procedures performed in ASCs and hospital outpatient departments, we collected cost and procedure data from 2004 from a sample of Medicare-participating ASCs. We also interviewed officials at CMS and representatives from ASC industry organizations, specifically, the American Association of Ambulatory Surgery Centers (AAASC) and FASA, physician specialty societies, and nine ASCs. To compare the delivery of additional services provided with procedures performed in ASCs and hospital outpatient departments, we identified all additional services frequently billed in each setting when one of the top 20 procedures with the highest Medicare ASC claims volume is performed. 
These procedures represented approximately 75 percent of all Medicare ASC claims in 2003. Using Medicare claims data for 2003, we identified beneficiaries receiving one of the top 20 procedures in either an ASC or hospital outpatient department, then identified any other claims for those beneficiaries from ASCs, hospital outpatient departments, durable medical equipment suppliers, and other Medicare part B providers. We identified claims for the beneficiaries on the day the procedure was performed and the day after. We created a list that included all additional services that were billed at least 10 percent of the time with each of the top 20 procedures when they were performed in ASCs. We created a similar list of additional services for each of the top 20 procedures when they were performed in hospital outpatient departments. We then compared the lists for each of the top 20 procedures between the two settings to determine whether there were similarities in the additional services that were billed to Medicare. To compare the Medicare payments for procedures performed in ASCs and hospital outpatient departments, we identified whether any additional services included in our analysis resulted in an additional payment. We used Medicare claims data from the National Claims History (NCH) files. These data, which are used by the Medicare program to make payments to health care providers, are closely monitored by both CMS and the Medicare contractors that process, review, and pay claims for Medicare services. The data are subject to various internal controls, including checks and edits performed by the contractors before claims are submitted to CMS for payment approval. Although we did not review these internal controls, we did assess the reliability of the NCH data. First, we reviewed all existing information about the data, including the data dictionary and file layouts. We also interviewed experts at CMS who regularly use the data for evaluation and analysis. We found the data to be sufficiently reliable for the purposes of this report. To compare the relative costs of procedures performed in ASCs and hospital outpatient departments, we first compiled information on ASCs' costs and procedures performed. Because there were no recent existing data on ASC costs, we surveyed 600 ASCs, randomly selected from all ASCs, to obtain their 2004 cost and procedure data. We received response data from 397 ASC facilities. We assessed the reliability of these data through several means. We identified incomplete and inconsistent survey responses within individual surveys and placed follow-up calls to respondents to complete or verify their responses. To ensure that survey response data were accurately transferred to electronic files for our analytic purposes, two analysts independently entered all survey responses. Any discrepancies between the two sets of entered responses were resolved. We performed electronic testing for errors in accuracy and completeness, including an analysis of costs per procedure. As a result of our data reliability testing, we determined that data from 290 responding facilities were sufficiently reliable for our purposes. Our nonresponse analysis showed that there was no geographic bias among the facilities responding to our survey. The responding facilities performed more Medicare services than the average for all ASCs in our sample. 
To allocate ASCs' total costs among the individual procedures they perform, we developed a method to allocate the portion of an ASC's costs accounted for by each procedure. We constructed a relative weight scale for Medicare ASC-approved procedures that captures the general variation in resources associated with performing different procedures. The resources we used were the clinical staff time, surgical supplies, and surgical equipment used during the procedures. We used cost and quantity data on these resources from information CMS had collected for the purpose of setting the practice expense component of physician payment rates. For procedures for which CMS had no data on the resources used, we used information we collected from medical specialty societies and physicians who work for CMS. We summed the costs of the resources for each procedure and created a relative weight scale by dividing the total cost of each procedure by the average cost across all of the procedures. We assessed the reliability of these data through several means. We compared electronic CMS data with the original document sources for a large sample of records, performed electronic testing for errors in accuracy and completeness, and reviewed data for reasonableness. Based on these efforts, we determined that data were sufficiently reliable for our purposes. To calculate per-procedure costs with the data from the surveyed ASC facilities, we first deducted costs that Medicare considers unallowable, such as advertising and entertainment costs. (See fig. 3 for our per- procedure cost calculation methodology.) We also deducted costs for services that Medicare pays for separately, such as physician and nonphysician practitioner services. We then separated each facility's total costs into its direct and indirect costs. We defined direct costs as those associated with the clinical staff, equipment, and supplies used during the procedure. Indirect costs included all remaining costs, such as support and administrative staff, building expenses, and outside services purchased. To allocate each facility's direct costs across the procedures it performed, we applied our relative weight scale. We allocated indirect costs equally across all procedures performed by the facility. For each procedure performed by a responding ASC facility, we summed its allocated direct and indirect costs to determine a total cost for the procedure. To obtain a per-procedure cost across all ASCs, we arrayed the calculated costs for all ASCs performing that procedure and identified the median cost. To compare per-procedure costs for ASCs and hospital outpatient departments, we first obtained from CMS the list of ambulatory payment classification (APC) groups used for the outpatient prospective payment system (OPPS) and the procedures assigned to each APC group. We also obtained from CMS the OPPS median cost of each procedure and the median cost of each APC group. We then calculated a ratio between each procedure's ASC median cost, as determined by the survey, and the median cost of each procedure's corresponding APC group under the OPPS, referred to as the ASC-to-APC cost ratio. We also calculated a ratio between each ASC procedure's median cost under the OPPS and the median cost of the procedure's APC group, using the data obtained from CMS, referred to as the OPPS-to-APC cost ratio. To evaluate the difference in procedure costs between the two settings, we compared the ASC-to- APC and OPPS-to-APC cost ratios. 
To assess how well the relative costs of procedures in the OPPS, defined by their assignment to APC groups, reflect the relative costs of procedures in the ASC setting, we evaluated the distribution of the ASC-to-APC and OPPS-to-APC cost ratios. To calculate the percentage of labor-related costs among our sample ASCs, for each ASC, we divided total labor costs by total costs, after deducting costs not covered by Medicare's facility payment. We then determined the range of the percentage of labor-related costs among all of our ASCs and between the 25th percentile and the 75th percentile, as well as the mean and median percentage of labor-related costs. We performed our work from April 2004 through October 2006 in accordance with generally accepted government auditing standards. Appendix III: Additional Procedures Billed with the Top 20 ASC Procedures, 2003. In addition to the contact named above, key contributors to this report were Nancy A. Edwards, Assistant Director; Kevin Dietz; Beth Cameron Feldpush; Marc Feuerberg; and Nora Hoban.<end_of_turn>
<start_of_turn>model
Medicare pays for surgical procedures performed at ambulatory surgical centers (ASC) and hospital outpatient departments through different payment systems. Although they perform a similar set of procedures, no comparison of ASC and hospital outpatient per-procedure costs has been conducted. The Medicare Prescription Drug, Improvement, and Modernization Act of 2003 directed GAO to compare the relative costs of procedures furnished in ASCs to the relative costs of those procedures furnished in hospital outpatient departments, in particular, how accurately the payment groups used in the hospital outpatient prospective payment system (OPPS) reflect the relative costs of procedures performed in ASCs. To do this, GAO collected data from ASCs through a survey. GAO also obtained hospital outpatient data from the Centers for Medicare & Medicaid Services (CMS). GAO determined that the payment groups in the OPPS, known as ambulatory payment classification (APC) groups, accurately reflect the relative cost of procedures performed in ASCs. GAO calculated the ratio between each procedure's ASC median cost, as determined by GAO's survey, and the median cost of each procedure's corresponding APC group under the OPPS, referred to as the ASC-to-APC cost ratio. GAO also compared the OPPS median costs of those same procedures with the median costs of their APC groups, referred to as the OPPS-to-APC cost ratio. GAO's analysis of the ASC-to-APC and OPPS-to-APC cost ratios showed that 45 percent of all procedures in the analysis fell within a 0.10 point range of the ASC-to-APC median cost ratio, and 33 percent of procedures fell within a 0.10 point range of the OPPS-to-APC median cost ratio. These similar patterns of distribution around the median show that the APC groups reflect the relative costs of procedures provided by ASCs as well as they reflect the relative costs of procedures provided in hospital outpatient departments and can be used as the basis for the ASC payment system. GAO's analysis also identified differences in the cost of procedures in the two settings. The median cost ratio among all ASC procedures was 0.39 and when weighted by Medicare claims volume was 0.84. The median cost ratio for OPPS procedures was 1.04. Thus, the cost of procedures in ASCs is substantially lower than the corresponding cost in hospital outpatient departments.<end_of_turn><eos>

This is the formatting code:

import pandas as pd

summarizer_system_prompt = '''You're a proficient summarizing tool, designed to condense information efficiently while maintaining politeness. Please summarize the text clearly and succinctly.'''

def generate_summarizer_prompt(row: pd.Series) -> str:
    """Format a dataset row to Gemma's chat template."""
    # The system prompt and the report text form the user turn; the reference
    # summary becomes the model turn that the LoRA adapter is trained on.
    return """<bos><start_of_turn>user
## Instruction
{}
## User
{}<end_of_turn>
<start_of_turn>model
{}<end_of_turn><eos>""".format(summarizer_system_prompt, row["report"], row["summary"])
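
For completeness, here is a minimal sketch of how the formatter gets applied to build the training file. The file paths here are placeholders, and it assumes mlx_lm.lora reads train.jsonl / valid.jsonl from the --data directory with one {"text": ...} record per line:

import json
import pandas as pd

# Placeholder input: any DataFrame with "report" and "summary" columns.
df = pd.read_json("reports.jsonl", lines=True)

# Write one {"text": ...} record per line in the format mlx_lm.lora consumes.
with open("data/train.jsonl", "w") as f:
    for _, row in df.iterrows():
        f.write(json.dumps({"text": generate_summarizer_prompt(row)}) + "\n")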

@awni
Member

awni commented May 3, 2024

Thanks I will take a look tomorrow!

@mukundsayeeganesh

@awni Do we have any update regarding the issue?

@awni
Member

awni commented May 6, 2024

@mukundsayeeganesh I opened a relevant issue in MLX core ml-explore/mlx#1084. Once that is fixed, the NaNs in your case should be resolved.

@mukundsayeeganesh

Thanks for the update.

@awni
Member

awni commented May 6, 2024

Slightly unrelated: your LoRA settings are using a lot of memory. I would consider decreasing memory consumption by either:

  1. Reducing the batch size
  2. Reducing the max sequence length

You can see the peak mem is over 200 GB, which will cause swapping and be very slow. Also, I think you will work around the NaN issue if you use a batch size of 1 or 2 for now.
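
Something like this in the training config should bring peak memory down (a sketch; max_seq_length is assumed to be the config key mlx_lm.lora uses to truncate long examples, so treat that as an assumption):

batch_size: 1        # fewer sequences per step lowers peak memory and works around the NaN for now
max_seq_length: 1024 # assumed key: truncates long training examples to cap activation memory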

@mukundsayeeganesh

Yeah, understood. I will try it out later today and let you know how it works out.

@mukundsayeeganesh

@awni I tried your suggestion: reducing the batch size to 2 worked well, with no NaN values.
