PC: Add support for Starcoder2 #518

Closed
Blaizzy wants to merge 12 commits into ml-explore:main from Blaizzy:pc/starcoder2

@Blaizzy
Contributor

Blaizzy commented Mar 2, 2024

Hey @awni and @Muhtasham,

I was super excited about integrating StarCoder2 into MLX. However, while working on the mlx-langchain integration, I realized that @Muhtasham had beaten me to the punch and submitted a PR. Kudos to you!

I took a look at the original draft PR a few days ago and noticed a few areas that could use some improvement (the MLP and others). When I tested the code with the quantized models uploaded to the hub, it didn't quite work as expected. But not to worry, I've fixed it! ✨ ✅

This implementation works well and matches the output of the original model in transformers, in both full-precision and quantized configurations.

python -m mlx_lm.generate --model mlx-community/starcoder2-3b-4bit --prompt "Write a quick sort in C++"  --colorize --max-tokens 200

Output:

Prompt: Write a quick sort in C++

```cpp
//https://www.hackerrank.com/challenges/quicksort3/problem

#include <bits/stdc++.h>

using namespace std;

vector<int> quickSort(vector<int> ar) {
    vector<int> output;
    int pivot = ar[0];
    for(int i = 1; i < ar.size(); i++){
        if(ar[i] < pivot)
            output.push_back(ar[i]);
    }
    output.push_back(pivot);
    for(int i = 1; i < ar.size(); i++){
        if(ar[i] > pivot)
            output.push_back(ar[i]);
    }
    return output;
}

int main(void) {
    vector <int>  ar;
    int num_of_elements;
    cin >> num
```

@Blaizzy changed the title from "PC: Add support for starcoder2" to "PC: Add support for Starcoder2" on Mar 3, 2024
@mzbac
Contributor

mzbac commented Mar 3, 2024

Nice work! Just curious, does this work with the FIM prompt?
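
For context, fill-in-the-middle (FIM) prompting wraps the code before and after a gap in sentinel tokens and asks the model to generate the missing middle. A minimal sketch, assuming the StarCoder-family sentinels `<fim_prefix>`, `<fim_suffix>` and `<fim_middle>` (worth checking against the tokenizer's special tokens before relying on them); `build_fim_prompt` is a hypothetical helper name, not part of mlx_lm:

```python
def build_fim_prompt(prefix: str, suffix: str) -> str:
    """Assemble a prefix-suffix-middle FIM prompt.

    Assumption: the model was trained with the StarCoder-style
    sentinel tokens used below; other models use different markers.
    """
    return f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"


# The model is expected to generate the body that belongs between the two parts.
prompt = build_fim_prompt(
    prefix="def add(a, b):\n    ",
    suffix="\n\nprint(add(1, 2))\n",
)
print(prompt)
```

The resulting string could then be passed as the `--prompt` argument of the `mlx_lm.generate` command shown earlier in this thread.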

@mzbac mentioned this pull request on Mar 3, 2024
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
    scores += mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
Contributor

Maybe we need to keep scores.astype(mx.float32) in case softmax overflows on float16 values.
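
To make the concern concrete, here is a minimal sketch in NumPy rather than MLX (an assumption, chosen so the snippet runs without Apple hardware). `softmax_naive` is a hypothetical helper that deliberately skips the usual max-subtraction so the float16 overflow is visible; `mx.softmax` may well guard against this internally, so treat this as an illustration of why upcasting the scores can help, not as a description of MLX's behaviour:

```python
import numpy as np


def softmax_naive(x):
    # Textbook softmax with no max-subtraction, so exp() can overflow.
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


# Illustrative attention scores; exp(12) is about 162755,
# which is above the largest finite float16 value (65504).
scores = np.array([12.0, 11.0, 10.0], dtype=np.float16)

# Staying in float16: exp(12) becomes inf, and inf / inf yields nan.
with np.errstate(over="ignore", invalid="ignore"):
    probs_half = softmax_naive(scores)

# Upcast to float32 for the softmax, then cast back to the input dtype.
probs_safe = softmax_naive(scores.astype(np.float32)).astype(scores.dtype)

print(probs_half)  # contains nan
print(probs_safe)  # finite probabilities that sum to roughly 1
```

The same pattern applied to the snippet under review would be to compute the softmax on `scores.astype(mx.float32)` and cast the result back with `.astype(values.dtype)`.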

@awni
Member

awni commented Mar 3, 2024

@Blaizzy thanks for adding! I just merged #502 since it was basically done. I'm going to close this in favor of that. Appreciate the insights here which helped with #502!

@awni closed this on Mar 3, 2024
@Blaizzy
Contributor Author

Blaizzy commented Mar 3, 2024

> Nice work! Just curious, does this work with the FIM prompt?

Thanks! It should work without a problem.

If not I can investigate.

@Blaizzy
Contributor Author

Blaizzy commented Mar 3, 2024

> @Blaizzy thanks for adding! I just merged #502 since it was basically done. I'm going to close this in favor of that. Appreciate the insights here which helped with #502!

Most welcome!

It's my pleasure 😊.

3 participants