Enable the Mixtral-like Moe model without the quantized gate layer #394

Open
mzbac opened this issue Jan 31, 2024 · 42 comments


@mzbac
Contributor

mzbac commented Jan 31, 2024

Currently, the community has started experimenting with building more models using a mix of different local experts. In the current implementation of mlx-lm, we have hardcoded the linear_class_predicate with 8, on the assumption that there will always be 8 local experts. It would be great if we could read the number of local experts from the configuration and make it adjustable to support varying numbers of local experts. This way, users could use mlx-lm to fine-tune the gate with LoRA. For example, here I have to copy and paste code from mlx-lm in order to make a LoRA MoE 4x7B model work.
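For reference, roughly what I mean (the current predicate is paraphrased from memory, and the config-driven version is only an illustrative sketch; `num_local_experts` is the Mixtral-style config key):

```python
import mlx.nn as nn

# Today (paraphrased): skip quantizing any nn.Linear whose output dimension
# is 8, on the assumption that it must be a Mixtral gate.
linear_class_predicate = (
    lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)

# Proposed: read the expert count from the model config instead, so other
# expert counts (e.g. 4) are handled too.
def make_predicate(config: dict):
    n_experts = config.get("num_local_experts", 8)
    return lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != n_experts
```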

@awni
Member

awni commented Jan 31, 2024

Actually, if the size is not 8, won't the gating layer just be quantized? I feel like it would be better to simply quantize the gate layers so we don't have to deal with mixed quantization. As far as I understand, the only reason we still have the predicate for 8 is backward compatibility with already-quantized Mixtrals.

@mzbac
Contributor Author

mzbac commented Jan 31, 2024

Yeah, I had thought about just quantizing the gate, but the gate matrix is quite small (4096 × (3–8)). From my experiments with the current Mixtral implementation, I didn't see much effect from applying the LoRA weights to it (maybe there's an issue with the current gate implementation in Mixtral?). However, when I tried fine-tuning the gate layer directly, I got a NaN loss, so I'm not sure whether LoRA on a quantized gate works at all.

@mzbac
Contributor Author

mzbac commented Feb 1, 2024

@awni Would you be able to give me some hints as to why we have to use stop_gradient on the indices in the sparse MoE block? Also, is there anything I can do to help fine-tune the MoE routing and make it work properly? Personally, I believe there is great potential for MoE models to excel on Apple Silicon. Thanks in advance.

@awni
Member

awni commented Feb 1, 2024

@awni Would you be able to give me some hints as to why we have to use stop_gradient on the indices

Well, it's a good question. We could define the gradient of the top-k as 1s at the selected indices and 0s elsewhere. I don't know how well that would work, though, because the experts that are not selected get zero gradient. If the model specializes early on to one expert (or set of experts), then the other experts will effectively go unused.

Training MoEs in general requires a bit of finesse. It's worth taking a look at the literature to see what kinds of techniques are used to balance expert selection. Usually there is some forced hard constraint, like each expert having to be used a certain number of times per batch of tokens.
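For example, the usual Switch-Transformer-style auxiliary load-balancing loss looks roughly like this (just a sketch, not something in mlx-lm today; the shapes and names are illustrative):

```python
import mlx.core as mx

def load_balancing_loss(router_logits, expert_indices, num_experts):
    # router_logits: (tokens, num_experts); expert_indices: (tokens, top_k)
    probs = mx.softmax(router_logits, axis=-1)
    expert_ids = mx.arange(num_experts)
    # fraction of routing slots assigned to each expert
    assigned = (expert_indices[..., None] == expert_ids).astype(mx.float32)
    frac_tokens = assigned.reshape(-1, num_experts).sum(axis=0) / expert_indices.size
    # mean router probability per expert
    frac_probs = probs.mean(axis=0)
    # minimized when both distributions are uniform over the experts
    return num_experts * (frac_tokens * frac_probs).sum()
```

Adding a small multiple of something like this to the LM loss discourages the router from collapsing onto one or two experts.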

Also, is there anything I can do to help fine-tune the MoE routing and make it work properly?

I don't know exactly what you mean by "work properly". I think we need to understand what exactly is not working and diagnose the reason. Is the current behavior that the gates are collapsing to only one or two experts?

I would be happy to dig into this a bit when I have some bandwidth, possibly early next week. In the meantime we can use this thread to chat about it? And if you uncover anything interesting I would definitely love to hear about it / give my input.

@mzbac
Contributor Author

mzbac commented Feb 1, 2024

@awni, I didn't expect you guys to spend too much effort on this right now. I know you are busy improving the mlx framework 🚀, but a few hints or thoughts that could help me investigate the issue would be greatly appreciated.

Here are some of my findings. I am keen to hear your thoughts or if there is anything I missed.

  • From what I remember, there was a backpropagation error thrown from mlx when we didn't use stop_gradient on the indices. However, today I tried installing the latest mlx from the master branch and removing the stop_gradient, and the error is no longer present. Maybe this is something that was fixed on mlx's side. (My understanding is that mlx doesn't have a torch.topk-like op with a defined gradient, so backpropagation through argpartition or argsort shouldn't work, but somehow it works on the new version?)
  • The current mlx 0.0.11 still throws an error when removing stop_gradient.
      File "/Users/anchenli/miniconda3/envs/mlx/lib/python3.10/site-packages/mlx/nn/utils.py", line 30, in wrapped_value_grad_fn
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
    ValueError: [gather] Cannot calculate VJP with respect to indices.
    
  • As suggested in the "makemoe" article (https://huggingface.co/blog/AviSoori1x/makemoe-from-scratch), I added a noise layer to the gate to balance routing and prevent favoring certain experts, but I didn't notice much difference (a rough sketch of this kind of gate is below, after this list). You can find more details at this link: https://github.com/mzbac/mlx-moe/blob/main/mixral.py#L170-L175
  • Currently, we evaluate the indices and convert them to a NumPy array before computing the outputs of the selected experts. I am not sure whether this operation impacts backpropagation or slows down fine-tuning. I don't really understand why we need to evaluate the indices, but my intuition is that it may be because mx.where doesn't behave like np.where? https://github.com/mzbac/mlx-moe/blob/main/mixral.py#L187-L196
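For reference, the noisy top-k gate I'm describing looks roughly like this (an illustrative sketch in the spirit of the "makemoe" post, not the exact code in the linked repo):

```python
import mlx.core as mx
import mlx.nn as nn

class NoisyTopKGate(nn.Module):
    def __init__(self, dims: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.gate = nn.Linear(dims, num_experts, bias=False)
        self.noise = nn.Linear(dims, num_experts, bias=False)
        self.top_k = top_k

    def __call__(self, x, training: bool = True):
        logits = self.gate(x)
        if training:
            # per-token, per-expert noise with a learned scale (softplus keeps
            # it positive), which discourages always picking the same experts
            noise_std = mx.logaddexp(self.noise(x), 0.0)
            logits = logits + mx.random.normal(logits.shape) * noise_std
        # top-k selection; gradients are stopped at the indices themselves
        inds = mx.stop_gradient(
            mx.argpartition(-logits, kth=self.top_k - 1, axis=-1)[..., : self.top_k]
        )
        scores = mx.softmax(mx.take_along_axis(logits, inds, axis=-1), axis=-1)
        return scores, inds
```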

@mzbac
Contributor Author

mzbac commented Feb 2, 2024

Just found that removing the mx.eval on the indices still works, but I didn't notice any improvement in training speed.
FYI, here is my test code -> https://github.com/mzbac/mlx-moe/blob/main/mixral.py#L164-L201

@awni
Member

awni commented Feb 2, 2024

Btw, there were some issues with eval'ing in grad graphs introduced in 0.1.0 😢. (Fixed in ml-explore/mlx#612)

Are you using the latest MLX for these tests? If so, I might go back to 0.0.11 until we do a patch release, or you can use main once my fix PR lands.

@mzbac
Contributor Author

mzbac commented Feb 3, 2024

Btw, there were some issues with eval'ing in grad graphs introduced in 0.1.0 😢. (Fixed in ml-explore/mlx#612)

Are you using the latest MLX for these tests? If so, I might go back to 0.0.11 until we do a patch release, or you can use main once my fix PR lands.

Double-checked with the latest mlx master: backpropagation still doesn't work through argpartition, so I had to add the stop_gradient back. What I don't know is, with the stop_gradient on the argpartition indices, will fine-tuning the experts' MLPs still work? I can't wrap my head around how backpropagation flows through the top-k in this case.

@awni
Member

awni commented Feb 3, 2024

  • Fine-tuning should still work as it did before; we haven't made any changes (after the bug fix) that would affect that.
  • No, we didn't add gradients for argpartition (or argmax) or for the indices of scatter yet.

I think to get backprop to work through expert index selection we would have to do something like:

  • A grad for argpartition (not too bad)
  • Change this to use mlx.core ops: https://github.com/mzbac/mlx-moe/blob/main/mixral.py#L193-L196
    • I think a take might work, but I'm not certain
  • y[idx1, idx2] = expert(x[idx1]) is a scatter into y, and we don't differentiate w.r.t. scatter indices, so that needs to be updated as well. If we could do it with a concatenate followed by a gather, that might fix that part; otherwise we add a grad for scatter indices. (A scatter-free sketch follows below.)
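For illustration only, one fully differentiable formulation that sidesteps the scatter entirely is to run every expert densely and weight by the routing scores. It wastes compute, so treat it as a sketch of where gradients could flow rather than a real implementation (`experts`, `x` of shape (tokens, dims), and `scores` of shape (tokens, num_experts) are assumed):

```python
import mlx.core as mx

def dense_moe(x, experts, scores):
    # Run every expert on every token, then mix by the router scores.
    # No scatter/gather on indices, so grads reach both the experts and the gate.
    outputs = mx.stack([expert(x) for expert in experts], axis=1)  # (tokens, E, dims)
    return (outputs * scores[..., None]).sum(axis=1)
```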

@awni
Member

awni commented Feb 3, 2024

@mzbac a couple of requests for you:

  • Could you point me to a command that tries to train an MoE (maybe with just 4 smaller models so I can iterate quickly)?
  • What do you currently see vs. what is the expected result? I assume it's something like the gate matrices not getting updated at all during LoRA fine-tuning, so the scores are unchanged?

@mzbac
Contributor Author

mzbac commented Feb 3, 2024

@awni I am using mlx-moe (https://github.com/mzbac/mlx-moe) as a playground to test MoE fine-tuning. It is easy to set up the training by following the readme, but it defaults to using Mixtral 4x7B. Tomorrow I will see if I can create a phi-2 4x MoE model for quick testing.

Or, if you prefer an mlx-lm command, you can use the following:

python -m mlx_lm.lora --model mlx-community/Mixtral-8x7B-v0.1-hf-4bit-mlx --train --iters 600 --data ../lora/data 

Regarding the fine-tuning part, I noticed that the gate weights are being updated, but the changes in values seem very small. I didn't analyze the results thoroughly because during training I felt there wasn't much decrease in loss. My guess was that fine-tuning only q, v, and the gate is not sufficient. Tomorrow I will do more testing on my local machine to see if I can find anything.

Thoughts:
I think we may need to take into account the difference between fine-tuning an MoE base model and an MoE instruct model. I feel the issue I am having is more likely due to fine-tuning the instruct model, where LoRA on only q, v, and the gate isn't enough to reduce the loss.

@mzbac
Contributor Author

mzbac commented Feb 4, 2024

@awni, I have created a phi2 2x4 MoE model with 4 identical phi2 experts. You can check out my fork branch here: https://github.com/mzbac/mlx-examples/tree/mlx-lm/moe. (I needed to disable quantization for the gate; somehow there is an issue with LoRA on the quantized gate and I get a NaN loss. I also needed to add phixtral model support.)
You can use the normal LoRA command to do fine-tuning:

python -m mlx_lm.lora --model mzbac/phi2-2x4-hf-4bit-mlx --train --iters 600 --data ../lora/data --lora-layers 32

I have tested it on my local machine and it seems to work well with LoRA on q, v, and the gate. However, I assume that this is because the WikiSQL training dataset has very simple patterns to learn, so even a regular phi2 model can be fine-tuned easily without much help from the experts, which means the gate routing doesn't really matter in this case. In my experiment fine-tuning on a more complex dataset (guanaco) without being able to apply LoRA to the experts' MLP layers, the loss doesn't decrease much (always around 1.xx).

@awni
Member

awni commented Feb 4, 2024

In my experiment fine-tuning on a more complex dataset (guanaco)

Could you point me to this dataset? I can try training on it to see if I can repro

without being able to apply LoRA to the experts' MLP layers, the loss doesn't decrease much

So you think there is a capacity issue? Did you try training a larger (but non-MoE) model to see if it can learn this dataset? It might not be related to the experts... it could be some other setting in the training config/setup.

@mzbac
Contributor Author

mzbac commented Feb 4, 2024

Yeah, just try phi2-2x4. Even when I try to overfit it, the loss doesn't go below 0.6x, and the result feels much worse than fine-tuning Mixtral (which, I assume, was already trained as an MoE during pretraining, so it is more effective than LoRA on 2x4 identical phi2 experts).

python -m mlx_lm.lora --model mzbac/phi2-2x4-hf-4bit-mlx --train --iters 600 --data ../lora/data --lora-layers 32

I also tried guanaco here; the loss always seems to be around 0.9x-1.xx.
I intuitively feel that fine-tuning an MoE without applying LoRA to the experts' MLPs is ineffective. The QLoRA paper states that applying LoRA to all linear layers approaches full fine-tuning performance, so I suspect that enabling fine-tuning of the experts' MLPs may improve results.

@awni
Member

awni commented Feb 4, 2024

Sorry I am still not understanding 100%:

Even when I try to overfit it, the loss doesn't go below 0.6x

This is with WikiSQL right? Just curious, where did the loss end up?

I suspect that if we enable fine-tuning of the experts' MLPs

What's stopping you from doing this (other than changing the default settings)? Does it crash if you make the MLP layers LoRA layers?

@awni
Member

awni commented Feb 4, 2024

Regarding capacity issues, some simple things to try are:

  1. Add more LoRA layers (like you have done), but also make more of the linear layers work with LoRA (like all the attention projections and the MLP layers).
  2. Try increasing the rank of the LoRA layers. For smaller models you will likely have to do this; otherwise there are just not enough trainable parameters (a sketch of where the rank enters is below).

If those two tacks are not effective, it suggests something is wrong with the optimization.
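To make point 2 concrete, this is the generic shape of a LoRA linear and where the rank enters (an illustrative sketch, not mlx-lm's exact LoRALinear; `r` and `alpha` are the usual LoRA hyperparameters):

```python
import math
import mlx.core as mx
import mlx.nn as nn

class LoRALinear(nn.Module):
    # y = W x + (alpha / r) * (x A) B, with only A and B trained.
    # Larger r => more trainable parameters => more capacity.
    def __init__(self, linear: nn.Linear, r: int = 16, alpha: float = 16.0):
        super().__init__()
        out_dims, in_dims = linear.weight.shape
        self.linear = linear  # pretrained weight; freeze it (e.g. model.freeze()) so only lora_a/lora_b train
        self.scale = alpha / r
        bound = 1.0 / math.sqrt(in_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(in_dims, r))
        self.lora_b = mx.zeros((r, out_dims))  # starts at zero so the output is unchanged at init

    def __call__(self, x):
        return self.linear(x) + self.scale * ((x @ self.lora_a) @ self.lora_b)
```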

@mzbac
Contributor Author

mzbac commented Feb 4, 2024

This is with WikiSQL right? Just curious, where did the loss end up?

I tried around 2000 iterations (around 3 epochs). The loss was approximately 0.6 or 0.5x, but when I ran inference, the result was not very good.

What's stopping you from doing this (other than changing the default settings)? Does it crash if you make the MLP layers LoRA layers?

Maybe I misunderstood. If the MoE doesn't backpropagate through the top-k, can we still apply LoRA to the experts' MLP layers? I thought LoRA wouldn't get updated since backprop stops at the argpartition indices.

@mzbac
Contributor Author

mzbac commented Feb 7, 2024

@awni please let me know if it would make more sense to raise an issue in mlx for supporting torch.topk-like ops (with gradients) for this.

@awni
Member

awni commented Feb 7, 2024

I think it's a good idea to open issues in mlx for the missing grads. I'm still not certain that is the problem here, though. Sorry, I have been intending to look into this just after we get our next release out.

What is a good dataset to use to investigate? Sounds like WikiSQL is a bit on the easy side?

@mzbac
Contributor Author

mzbac commented Feb 7, 2024

I think there are two questions I am not very sure about:

  1. Given the missing grad for mx.argpartition, are we still able to fine-tune the experts' MLP layers?
  2. Will https://github.com/mzbac/mlx-moe/blob/main/mixral.py#L193-L196 impact LoRA fine-tuning of the experts?

As for the dataset, I feel WikiSQL may be too easy to overfit, so it doesn't require LoRA on the MLP layers. But I think if we enabled LoRA on all linear layers, it would be up to the user to adjust the hyperparameters rather than being mlx's concern.

Please let me know if those make sense; I am happy to set up a playground repo for testing.

@awni
Member

awni commented Feb 8, 2024

Given the missing grad for mx.argpartition, are we still able to fine-tune the experts' MLP layers?

Yes, there can still be gradient flowing to those layers. Did you try adding LoRA to the MLP layers? You can check whether the gradient is zero during the update, but I do not think it should be (a quick way to check is sketched below).

Will https://github.com/mzbac/mlx-moe/blob/main/mixral.py#L193-L196 impact LoRA fine-tuning of the experts?

The conversion to NumPy there is not ideal, but I don't think it should preclude training the expert MLPs. There should still be gradient.
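For reference, a minimal way to inspect the LoRA gradients during a step (illustrative; `model`, `loss_fn`, and `batch` are placeholders for whatever your training script already uses, and `loss_fn(model, *batch)` is assumed to return the scalar training loss):

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

loss_and_grad = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad(model, *batch)

for name, g in tree_flatten(grads):
    if "lora" in name:
        # a non-zero max |grad| means gradient is reaching that LoRA matrix
        print(name, mx.abs(g).max().item())
```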

@mzbac
Contributor Author

mzbac commented Feb 8, 2024

Given the missing grad for mx.argpartition, are we still able to fine-tune the experts' MLP layers?

Yes, there can still be gradient flowing to those layers. Did you try adding LoRA to the MLP layers? You can check whether the gradient is zero during the update, but I do not think it should be.

Will https://github.com/mzbac/mlx-moe/blob/main/mixral.py#L193-L196 impact LoRA fine-tuning of the experts?

The conversion to NumPy there is not ideal, but I don't think it should preclude training the expert MLPs. There should still be gradient.

@awni, thanks for clarifying. I will do some tests. I hadn't really checked the weights of the MLPs' LoRA before; I was assuming it wouldn't work.

@awni
Member

awni commented Feb 9, 2024

Awesome, let me know how it goes!! I was just planning to try some MoE fine-tunes to see how it all works myself.

@awni
Member

awni commented Feb 9, 2024

python download_dataset.py WizardLM/WizardLM_evol_instruct_70k

@mzbac is that dataset a good one to try to see how well the MoE fine-tuning works?

@mzbac
Contributor Author

mzbac commented Feb 9, 2024

python download_dataset.py WizardLM/WizardLM_evol_instruct_70k

@mzbac is that dataset a good one to try to see how well the MoE fine-tuning works?

Yeah, I'm not an expert on this topic, but that's the dataset I have used to fine-tune custom MoE models. From my experience, it reduces loss well during fine-tuning and doesn't easily lead to overfitting (I have tried 9k-15k iterations and didn't see overfitting on phi2-2x3 and qwen2-2x3). However, if you want to try something with fewer examples, I personally think https://huggingface.co/datasets/timdettmers/openassistant-guanaco would also be a good option. Guanaco is not in pair format, so preprocessing may be required if you intend to fine-tune chat-based models.

@awni
Member

awni commented Feb 9, 2024

@mzbac I tried LoRA fine tuning Phixtral on the wizard dataset and it works pretty well as far as I can tell:

Iter 1: Val loss 1.716, Val took 0.168s
Iter 600: Val loss 0.343, Val took 0.442s

Is that what you expect? Maybe that dataset is also too easy?

This is the setup I used:

python -m mlx_lm.lora \
    --model mlabonne/phixtral-4x2_8 \
    --train \
    --data . \
    --iters 600 \
    --val-batches 1 \
    --batch-size 1

I'm guessing the gates did not change much, though I didn't check yet.

@mzbac
Contributor Author

mzbac commented Feb 9, 2024

@mzbac I tried LoRA fine tuning Phixtral on the wizard dataset and it works pretty well as far as I can tell:

Iter 1: Val loss 1.716, Val took 0.168s
Iter 600: Val loss 0.343, Val took 0.442s

Is that what you expect? Maybe that dataset is also too easy?

This is the setup I used:

python -m mlx_lm.lora \
    --model mlabonne/phixtral-4x2_8 \
    --train \
    --data . \
    --iters 600 \
    --val-batches 1 \
    --batch-size 1

I'm guessing the gates did not change much, though I didn't check yet.

Very strange, I have never gotten a loss below 0.5 on the wizard dataset before, but I did apply LoRA to 32 layers of q, v, and gate. Do you think having all those layers might be causing some instability? I can try 16 layers and see if I can reproduce the results.

@mzbac
Contributor Author

mzbac commented Feb 9, 2024

@mzbac I tried LoRA fine tuning Phixtral on the wizard dataset and it works pretty well as far as I can tell:

Iter 1: Val loss 1.716, Val took 0.168s
Iter 600: Val loss 0.343, Val took 0.442s

Is that what you expect? Maybe that dataset is also too easy?

This is the setup I used:

python -m mlx_lm.lora \
    --model mlabonne/phixtral-4x2_8 \
    --train \
    --data . \
    --iters 600 \
    --val-batches 1 \
    --batch-size 1

I'm guessing the gates did not change much, though I didn't check yet.

@awni, one thing: Phixtral doesn't have a noise linear layer in its implementation, so during fine-tuning the gate may quickly tend to direct all tokens to a favorite expert. My MoE implementation has this noise layer enabled during fine-tuning: https://github.com/mzbac/mlx-moe/blob/phi2-moe/phi2moe.py#L170-L179

@awni
Member

awni commented Feb 9, 2024

Phixtral doesn't have a noise linear layer in its implementation

Right, I did not train with that. It might be collapsing to one expert; I will see if it works with the noise.

@awni
Member

awni commented Feb 9, 2024

So the model (as you've pointed out before) starts out only using two experts for every token, and it stays that way during LoRA fine-tuning. I checked the grad of the gate and it is zero 🤔

grad["transformer"]["h"][-5]["moe"]
{'mlp': [{'fc1': {}, 'fc2': {}, 'act': {}}, {'fc1': {}, 'fc2': {}, 'act': {}}, {'fc1': {}, 'fc2': {}, 'act': {}}, {'fc1': {}, 'fc2': {}, 'act': {}}], 'gate': {}, 'gatee': {'linear': {}, 'lora_dropout': {}, 'lora_a': array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=float32), 'lora_b': array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       ...,
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 0, 0]], dtype=float32)}

@mzbac
Contributor Author

mzbac commented Feb 9, 2024

Yeah, I am currently running LLM evaluation for the qwen 2x3 model. It actually performs worse on benchmarks than the original 7B chat model. I suspect this might be related to the routing issue.
FYI
Qwen-1_5-2x3-hf

MMLU

Groups Version Filter n-shot Metric Value Stderr
- humanities N/A none 0 acc 0.6488 ± 0.0237
- other N/A none 0 acc 0.6294 ± 0.0302
- social_sciences N/A none 0 acc 0.6905 ± 0.0281
- stem N/A none 0 acc 0.5227 ± 0.0375

GSM8K

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 2 get-answer 5 exact_match 0.4102 ± 0.0135

Qwen1.5-7B-Chat
MMLU

Groups Version Filter n-shot Metric Value Stderr
- humanities N/A none 0 acc 0.6533 ± 0.0239
- other N/A none 0 acc 0.6321 ± 0.0301
- social_sciences N/A none 0 acc 0.6934 ± 0.0282
- stem N/A none 0 acc 0.5329 ± 0.0376

GSM8K

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 2 get-answer 5 exact_match 0.0425 ± 0.0056

@awni
Member

awni commented Feb 9, 2024

@mzbac there was a bug in my code so disregard the zero gradient comment above. I fixed that.

Setup

I did some analysis to understand how the expert selection changes over training with Phixtral. Some comments on setup:

  • 16 lora layers
  • Attention + gate only (no MLPs)
  • Wizard dataset
  • Batch size 1
  • No noise layer

Command:

python -m mlx_lm.lora \
    --model mlabonne/phixtral-4x2_8 \
    --train \
    --data . \
    --iters 600 \
    --val-batches 1 \
    --batch-size 1

Findings:

  1. The training converges nicely. The final val loss is: Iter 600: Val loss 0.342, Val took 0.547s
  2. Initially only the first two experts are selected for every token. The gate scores from the pre-trained model start out uniform and argpartition takes the first two indices in every case.
  3. Non-LoRA layers remain fixed as expected, always picking the first two experts (see figure below for an example)
  4. After just a few iterations, the LoRA gate layers start to vary the experts selected quite nicely actually (see figure below)

Layer 0 does not learn new expert routing: only experts 1 and 2 are selected, and experts 3 and 4 have zero counts (the lines are on top of each other):

[figure: per-expert token counts over training, layer 0]

Layer 16 is the first LoRA layer; all experts get some tokens:

[figure: per-expert token counts over training, layer 16]

Same for Layer 31:

[figure: per-expert token counts over training, layer 31]
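(For reference, the per-layer tallies behind these plots can be computed with something as simple as the sketch below; this is illustrative, not the exact analysis script.)

```python
import mlx.core as mx

def expert_histogram(inds, num_experts):
    # inds: (tokens, top_k) expert indices chosen by one layer's router.
    # Returns how many routing slots each expert received.
    return [
        (inds == e).astype(mx.int32).sum().item()
        for e in range(num_experts)
    ]
```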

@awni
Member

awni commented Feb 9, 2024

So it seems like learning the gate layer is working fine 🤔 ... I'm wondering what we are doing differently? I added the changes to work with Phixtral to this PR #426 (and removed phixtral as a standalone model from lms/, since I don't think we need it there).

@mzbac
Contributor Author

mzbac commented Feb 10, 2024

@awni Thanks for the thorough analysis. I will try your setup and see if I can reproduce the result. I will report back if I find anything.

@mzbac
Contributor Author

mzbac commented Feb 12, 2024

@awni I have done some testing and here are my findings:

Test 1:
Using 1000 samples from the wizard dataset, I initialized the gate layer with a uniform distribution. After fine-tuning all 32 layers without gate noise for one epoch, I observed that tokens were being routed to different experts, but not dramatically differently from the uniform initialization.

Test 2:
Using 1000 samples from the wizard dataset, I initialized the gate layer to default to the first two experts. After fine-tuning 16 layers without gate noise for one epoch, I noticed that tokens were routed to different experts, and in layer 31 there was a significant preference towards one expert.

Test 3:
Using 9000 samples from the wizard dataset, I initialized the gate layer with a uniform distribution. After fine-tuning all 32 layers, including lm_head, q, v, k, dense, and the MLPs, with gate noise for one epoch, I found that the loss only decreased to about 0.6. This may just be a capacity issue.

I can confirm that the gate is definitely getting updated, and introducing noise into it might be necessary to avoid imbalanced routing.

@awni
Member

awni commented Feb 12, 2024

Thanks for sharing.

After fine-tuning all 32 layers, including lm_head, q, v, k, dense, and the MLPs, with gate noise for one epoch, I found that the loss only decreased to about 0.6. This may just be a capacity issue.

Interesting. Are you measuring the loss with or without the noise addition? I assume it is better to measure without noise addition since you would not use it when actually generating with the model. Maybe that accounts for the discrepancy?

@mzbac
Contributor Author

mzbac commented Feb 16, 2024

Just sharing my findings: I have done LoRA fine-tuning on all linear layers of a phi2-2x4 model. It looks like the experts' MLPs are updating well despite the noise in the gate's linear layer, even though the loss does not seem to decrease and remains around 0.6x. Nevertheless, the MMLU evaluation indicates an improvement in the MMLU score.
https://huggingface.co/mzbac/phi-2-2x4-hf

I will try to do some fine-tuning without the noise layer later this week and see if there are any differences.

@awni
Member

awni commented Feb 16, 2024

That’s super cool! Thanks for the update!

@chimezie
Contributor

Just sharing my findings: I have done LoRA fine-tuning on all linear layers of a phi2-2x4 model. It looks like the experts' MLPs are updating well despite the noise in the gate's linear layer, even though the loss does not seem to decrease and remains around 0.6x. Nevertheless, the MMLU evaluation indicates an improvement in the MMLU score. https://huggingface.co/mzbac/phi-2-2x4-hf

I will try to do some fine-tuning without the noise layer later this week and see if there are any differences.

@mzbac Do you mind sharing the harness you used for the MMLU benchmark? I'd like to run evaluations on my MLX models but haven't done anything like this before

@mzbac
Contributor Author

mzbac commented Feb 23, 2024

Just sharing my findings: I have done LoRA fine-tuning on all linear layers of a phi2-2x4 model. It looks like the experts' MLPs are updating well despite the noise in the gate's linear layer, even though the loss does not seem to decrease and remains around 0.6x. Nevertheless, the MMLU evaluation indicates an improvement in the MMLU score. https://huggingface.co/mzbac/phi-2-2x4-hf
I will try to do some fine-tuning without the noise layer later this week and see if there are any differences.

@mzbac Do you mind sharing the harness you used for the MMLU benchmark? I'd like to run evaluations on my MLX models but haven't done anything like this before

Yeah, I am using https://github.com/EleutherAI/lm-evaluation-harness. It's slow on the MPS backend, so you have to run it on a CUDA device; I normally run the benchmarks on my PC instead of my Mac.
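For reference, the invocation looks something like this (the model name is just an example; check the harness docs for the exact flags in your version):

lm_eval --model hf --model_args pretrained=mzbac/phi-2-2x4-hf --tasks mmlu --device cuda:0 --batch_size 8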

@chimezie
Contributor

Thank you. I was poking around that very tool, looking at how it evaluates and whether it can be run against an OpenAI-compatible API to calculate perplexity on MMLU tasks (topics) for mlx fine-tuned models, the same way the test/evaluate libraries do.

(Sorry if this is off topic, maybe later I can add a discussion item)

@mzbac
Contributor Author

mzbac commented Feb 24, 2024

Thank you. I was poking around that very tool, how it evaluates, and if it can be done over Open AI to calculate perplexity of MMLU tasks (topics) on mlx fine tuned models in the same way the test/evaluate libraries do.

(Sorry if this is off topic, maybe later I can add a discussion item)

I only used the basic evaluation and didn't set up evaluation with OpenAI. I think this is quite a big topic; maybe we can have a separate discussion. In my opinion, LLM evaluation is currently lagging behind the progress of LLMs themselves, so I normally just check the difference in benchmarks across epochs without paying much attention to the actual scores.
