3 changes: 1 addition & 2 deletions ACKNOWLEDGMENTS.md
@@ -6,8 +6,7 @@

MLX Swift was developed with contributions from the following individuals:

-- [John Mai](https://github.com/johnmai-dev): Added support for multiple models (Qwen2, Starcoder2, InternLM2, Qwen3, Qwen3 MoE, GLM-4, MiMo, BitNet, SmolLM3, LFM2).
-
+- [John Mai](https://github.com/johnmai-dev): Added support for multiple models (Qwen2, Starcoder2, InternLM2, Qwen3, Qwen3 MoE, GLM-4, MiMo, BitNet, SmolLM3, LFM2, Baichuan-M1).

<a href="https://github.com/ml-explore/mlx-swift-examples/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx-swift-examples&anon=0&columns=20&max=100&r=true" />
7 changes: 7 additions & 0 deletions Libraries/MLXLLM/LLMModelFactory.swift
@@ -52,6 +52,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable {
"smollm3": create(SmolLM3Configuration.self, SmolLM3Model.init),
"ernie4_5": create(Ernie45Configuration.self, Ernie45Model.init),
"lfm2": create(LFM2Configuration.self, LFM2Model.init),
"baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init),
]
}

@@ -234,6 +235,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
defaultPrompt: "Why is the sky blue?"
)

static public let baichuan_m1_14b_instruct_4bit = ModelConfiguration(
id: "mlx-community/Baichuan-M1-14B-Instruct-4bit-ft",
defaultPrompt: "Why is the sky blue?"
)

static public let smollm3_3b_4bit = ModelConfiguration(
id: "mlx-community/SmolLM3-3B-4bit",
defaultPrompt: "Why is the sky blue?"
@@ -284,6 +290,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
smollm3_3b_4bit,
ernie_45_0_3BPT_bf16_ft,
lfm2_1_2b_4bit,
baichuan_m1_14b_instruct_4bit,
]
}

300 changes: 300 additions & 0 deletions Libraries/MLXLLM/Models/BaichuanM1.swift
@@ -0,0 +1,300 @@
//
// BaichuanM1.swift
// mlx-swift-examples
//
// Created by John Mai on 2025/6/16.
//

import Foundation
import MLX
import MLXFast
import MLXLMCommon
import MLXNN
import MLXRandom

public struct BaichuanM1Configuration: Codable, Sendable {
var vocabularySize: Int
var hiddenSize: Int
var intermediateSize: Int
var hiddenLayers: Int
var attentionHeads: Int
var kvHeads: Int
var ropeTheta: Float
var slidingWindow: Int
var slidingWindowLayers: [Int]
var convWindow: Int
var rmsNormEps: Float
var swaAttentionHeads: Int?
var swaKvHeads: Int?
var tieWordEmbeddings: Bool = false

enum CodingKeys: String, CodingKey {
case vocabularySize = "vocab_size"
case hiddenSize = "hidden_size"
case intermediateSize = "intermediate_size"
case hiddenLayers = "num_hidden_layers"
case attentionHeads = "num_attention_heads"
case kvHeads = "num_key_value_heads"
case ropeTheta = "rope_theta"
case slidingWindow = "sliding_window"
case slidingWindowLayers = "sliding_window_layers"
case convWindow = "conv_window"
case rmsNormEps = "rms_norm_eps"
case swaAttentionHeads = "num_swa_attention_heads"
case swaKvHeads = "num_swa_key_value_heads"
case tieWordEmbeddings = "tie_word_embeddings"
}
}

private class Attention: Module {
let config: BaichuanM1Configuration
let layerIdx: Int
let isSWA: Bool
let numHeads: Int
let numKVHeads: Int
let hiddenSize: Int
let headDim: Int
let scale: Float

@ModuleInfo(key: "W_pack") var wPack: Linear
@ModuleInfo(key: "o_proj") var oProj: Linear
let rope: RoPE

@ParameterInfo(key: "conv_k") var convK: MLXArray
@ParameterInfo(key: "conv_v") var convV: MLXArray

init(_ config: BaichuanM1Configuration, layerIdx: Int) {
self.config = config
self.layerIdx = layerIdx

self.isSWA = config.slidingWindowLayers.contains(layerIdx)
self.numHeads =
isSWA && config.swaAttentionHeads != nil
? config.swaAttentionHeads! : config.attentionHeads
self.numKVHeads = isSWA && config.swaKvHeads != nil ? config.swaKvHeads! : config.kvHeads

self.hiddenSize = config.hiddenSize
self.headDim = hiddenSize / numHeads
self.scale = pow(Float(headDim), -0.5)

_wPack.wrappedValue = Linear(
config.hiddenSize, config.hiddenSize + 2 * numKVHeads * headDim, bias: false)
_oProj.wrappedValue = Linear(numHeads * headDim, config.hiddenSize, bias: false)

self.rope = RoPE(dimensions: headDim, traditional: false, base: config.ropeTheta)

_convK.wrappedValue = MLXArray.zeros([1, 1, numKVHeads, 1, config.convWindow])
_convV.wrappedValue = MLXArray.zeros([1, 1, numKVHeads, 1, config.convWindow])
}

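// Two-tap per-head convolution over time: out[t] = w0 * u[t-1] + w1 * u[t].
// The weight indexing below assumes convWindow == 2; `state` supplies u[t-1]
// for the first position of the current chunk.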
func customConvolution(_ u: MLXArray, _ weights: MLXArray, state: MLXArray? = nil) -> MLXArray {
let (B, H, L, D) = (u.dim(0), u.dim(1), u.dim(2), u.dim(3))
let reshapedWeights = weights.reshaped(1, H, config.convWindow, 1, 1)
let w0 = reshapedWeights[0..., 0..., 0]
let w1 = reshapedWeights[0..., 0..., 1]

let state = state ?? MLXArray.zeros([B, H, 1, D], dtype: u.dtype)

let uPrev: MLXArray =
L > 1 ? concatenated([state, u[0..., 0..., ..<(L - 1), 0...]], axis: 2) : state

return uPrev * w0 + u * w1
}

func callAsFunction(
_ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
) -> MLXArray {
let (B, L, D) = (x.dim(0), x.dim(1), x.dim(2))

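// W_pack fuses the query, key and value projections; split the result back into
// [hiddenSize | numKVHeads * headDim | numKVHeads * headDim].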
let proj = wPack(x)
let qkv = split(proj, indices: [D, D + self.numKVHeads * self.headDim], axis: -1)

var queries = qkv[0].reshaped(B, L, numHeads, headDim).transposed(0, 2, 1, 3)
var keys = qkv[1].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3)
var values = qkv[2].reshaped(B, L, numKVHeads, headDim).transposed(0, 2, 1, 3)

var offset = 0
var lastK: MLXArray? = nil
var lastV: MLXArray? = nil

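// The per-layer cache is a CacheList: index 0 is a MambaCache holding the last
// raw key/value step (convolution state), index 1 is the attention KV cache.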
if let cacheList = cache as? CacheList {
offset = cacheList[1].offset
if let mambaCache = cacheList[0] as? MambaCache {
lastK = mambaCache[0]
lastV = mambaCache[1]
}
}

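// Keep the pre-convolution keys/values; their final time step becomes the next conv state.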
let kInit = keys
let vInit = values

keys = customConvolution(keys, convK, state: lastK)
values = customConvolution(values, convV, state: lastV)

queries = rope(queries, offset: offset)
keys = rope(keys, offset: offset)

if let cache = cache as? CacheList {
let kvCache = cache[1]
let (cachedKeys, cachedValues) = kvCache.update(keys: keys, values: values)
keys = cachedKeys
values = cachedValues

if L > 0 {
let convCache = cache[0] as! MambaCache
convCache[0] = kInit[0..., 0..., (L - 1)..., 0...]
convCache[1] = vInit[0..., 0..., (L - 1)..., 0...]
}
}

let out = MLXFast.scaledDotProductAttention(
queries: queries, keys: keys, values: values, scale: scale, mask: mask
)
.transposed(0, 2, 1, 3)
.reshaped(B, L, -1)

return oProj(out)
}
}

private class MLP: Module, UnaryLayer {
@ModuleInfo(key: "gate_proj") var gateProj: Linear
@ModuleInfo(key: "up_proj") var upProj: Linear
@ModuleInfo(key: "down_proj") var downProj: Linear

init(_ config: BaichuanM1Configuration) {
_gateProj.wrappedValue = Linear(config.hiddenSize, config.intermediateSize, bias: false)
_upProj.wrappedValue = Linear(config.hiddenSize, config.intermediateSize, bias: false)
_downProj.wrappedValue = Linear(config.intermediateSize, config.hiddenSize, bias: false)
}

func callAsFunction(_ x: MLXArray) -> MLXArray {
return downProj(silu(gateProj(x)) * upProj(x))
}
}

private class DecoderLayer: Module {
@ModuleInfo(key: "self_attn") var attention: Attention
let mlp: MLP
@ModuleInfo(key: "input_layernorm") var inputLayernorm: RMSNorm
@ModuleInfo(key: "post_attention_layernorm") var postAttentionLayernorm: RMSNorm

init(_ config: BaichuanM1Configuration, layerIdx: Int) {
_attention.wrappedValue = Attention(config, layerIdx: layerIdx)
self.mlp = MLP(config)
_inputLayernorm.wrappedValue = RMSNorm(
dimensions: config.hiddenSize, eps: config.rmsNormEps)
_postAttentionLayernorm.wrappedValue = RMSNorm(
dimensions: config.hiddenSize, eps: config.rmsNormEps)
}

func callAsFunction(
_ x: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode, cache: KVCache?
) -> MLXArray {
var r = attention(inputLayernorm(x), mask: mask, cache: cache)
let x = x + r
r = mlp(postAttentionLayernorm(x))
return x + r
}
}

private class BaichuanM1ModelInner: Module {
let args: BaichuanM1Configuration
@ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

fileprivate let layers: [DecoderLayer]
let norm: RMSNorm

init(_ config: BaichuanM1Configuration) {
self.args = config
_embedTokens.wrappedValue = Embedding(
embeddingCount: config.vocabularySize, dimensions: config.hiddenSize)
self.layers = (0 ..< config.hiddenLayers).map { DecoderLayer(config, layerIdx: $0) }
norm = RMSNorm(dimensions: config.hiddenSize, eps: config.rmsNormEps)
}

func callAsFunction(
_ inputs: MLXArray, mask: MLXFast.ScaledDotProductAttentionMaskMode? = nil,
cache: [KVCache]?
) -> MLXArray {
var x = embedTokens(inputs)

let mask = mask ?? createAttentionMask(h: x, cache: cache)

for (i, layer) in layers.enumerated() {
x = layer(x, mask: mask, cache: cache?[i])
}

return norm(x)
}
}

public class BaichuanM1Model: Module, LLMModel, KVCacheDimensionProvider {

public let vocabularySize: Int
public let kvHeads: [Int]

private let model: BaichuanM1ModelInner
let configuration: BaichuanM1Configuration

@ModuleInfo(key: "lm_head") var lmHead: Linear?

public init(_ config: BaichuanM1Configuration) {
self.configuration = config
self.vocabularySize = config.vocabularySize
self.kvHeads = Array(repeating: config.kvHeads, count: config.hiddenLayers)
self.model = BaichuanM1ModelInner(config)

if !config.tieWordEmbeddings {
_lmHead.wrappedValue = Linear(config.hiddenSize, config.vocabularySize, bias: false)
}
}

public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
var outputs = model(inputs, cache: cache)

if let lmHead {
outputs = lmHead(outputs)
}

return outputs
}

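/// One CacheList per layer: a MambaCache for the convolution state plus a rotating
/// KV cache for sliding-window layers (a simple KV cache otherwise).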
public func newCache(parameters: GenerateParameters?) -> [KVCache] {
return model.layers.enumerated().map { (i, _) in
let isSWA = configuration.slidingWindowLayers.contains(i)
let convCache = MambaCache()
let kvCache: KVCache =
isSWA ? RotatingKVCache(maxSize: configuration.slidingWindow) : KVCacheSimple()
return CacheList(convCache, kvCache)
}
}

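/// For unquantized checkpoints, normalizes each row of `lm_head.weight` to unit L2 norm
/// (NormHead-style); drops `lm_head.weight` entirely when embeddings are tied.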
public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
var weights = weights
let isQuantized = weights["lm_head.scales"] != nil

if !isQuantized, let w = weights["lm_head.weight"] {
let originalType = w.dtype
var w = w
if w.dtype != .float32 {
w = w.asType(.float32)
}

// Normalize each row in float32, then restore the original dtype.
let norm = sqrt(sum(w * w, axes: [-1], keepDims: true))
w = (w / (norm + 1e-7)).asType(originalType)
weights["lm_head.weight"] = w
}

if configuration.tieWordEmbeddings {
weights["lm_head.weight"] = nil
}

return weights
}
}

extension BaichuanM1Model: LoRAModel {
public func loraLinearLayers() -> LoRALinearLayers {
model.layers.map { ($0.attention, ["W_pack"]) }
}
}
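
The new registry entry can be exercised like any other configuration. A minimal sketch, assuming the factory's existing loadContainer(configuration:) entry point (not part of this diff):

import MLXLLM
import MLXLMCommon

// Downloads (if needed) and loads mlx-community/Baichuan-M1-14B-Instruct-4bit-ft.
let container = try await LLMModelFactory.shared.loadContainer(
    configuration: LLMRegistry.baichuan_m1_14b_instruct_4bit)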