From 058ee6f11c5aab4c0a445044e0ed011b458663c5 Mon Sep 17 00:00:00 2001 From: Sachin Desai Date: Thu, 11 Sep 2025 09:59:44 -0700 Subject: [PATCH 1/2] port of lille-130m --- Libraries/MLXLLM/LLMModelFactory.swift | 7 + Libraries/MLXLLM/Models/Lille130m.swift | 246 ++++++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 Libraries/MLXLLM/Models/Lille130m.swift diff --git a/Libraries/MLXLLM/LLMModelFactory.swift b/Libraries/MLXLLM/LLMModelFactory.swift index dd76def8..9f0533de 100644 --- a/Libraries/MLXLLM/LLMModelFactory.swift +++ b/Libraries/MLXLLM/LLMModelFactory.swift @@ -57,6 +57,7 @@ public class LLMTypeRegistry: ModelTypeRegistry, @unchecked Sendable { "baichuan_m1": create(BaichuanM1Configuration.self, BaichuanM1Model.init), "exaone4": create(Exaone4Configuration.self, Exaone4Model.init), "gpt_oss": create(GPTOSSConfiguration.self, GPTOSSModel.init), + "lille-130m": create(Lille130mConfiguration.self, Lille130mModel.init), ] } } @@ -296,6 +297,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable { defaultPrompt: "Why is the sky blue?" ) + static public let lille_130m_bf16 = ModelConfiguration( + id: "mlx-community/lille-130m-instruct-bf16", + defaultPrompt: "Why is the sky blue?" + ) + private static func all() -> [ModelConfiguration] { [ codeLlama13b4bit, @@ -338,6 +344,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable { lfm2_1_2b_4bit, baichuan_m1_14b_instruct_4bit, exaone_4_0_1_2b_4bit, + lille_130m_bf16, ] } diff --git a/Libraries/MLXLLM/Models/Lille130m.swift b/Libraries/MLXLLM/Models/Lille130m.swift new file mode 100644 index 00000000..0ee351b7 --- /dev/null +++ b/Libraries/MLXLLM/Models/Lille130m.swift @@ -0,0 +1,246 @@ +// +// Lille130m.swift +// mlx-swift-examples +// +// Created by Sachin Desai on 9/10/25. +// + +import Foundation +import MLX +import MLXLMCommon +import MLXNN + +// MARK: - Attention + +private final class Lille130mAttention: Module { + let args: Lille130mConfiguration + let headDim: Int + let scale: Float + + @ModuleInfo(key: "qkv_proj") var qkvProj: Linear + @ModuleInfo(key: "out_proj") var outProj: Linear + @ModuleInfo(key: "norm") var norm: RMSNorm + + var rope: RoPE + + init(_ args: Lille130mConfiguration) { + self.args = args + self.headDim = args.hiddenSize / args.attentionHeads + self.scale = pow(Float(headDim), -0.5) + + let qkvOut = (args.attentionHeads + 2 * args.kvHeads) * headDim + _qkvProj.wrappedValue = Linear(args.hiddenSize, qkvOut, bias: false) + _outProj.wrappedValue = Linear(args.attentionHeads * headDim, args.hiddenSize, bias: false) + + _norm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.layerNormEps) + + self.rope = RoPE( + dimensions: headDim, + traditional: true, + base: args.ropeTheta, + scale: 1.0 + ) + } + + func callAsFunction( + _ x: MLXArray, + mask: MLXFast.ScaledDotProductAttentionMaskMode, + cache: KVCache? + ) -> MLXArray { + let (B, L) = (x.dim(0), x.dim(1)) + + let qkv = qkvProj(norm(x)) + + let qSize = args.attentionHeads * headDim + let kvSize = args.kvHeads * headDim + + var queries = qkv[.ellipsis, 0 ..< qSize] + let kv = qkv[.ellipsis, qSize...] + var keys = kv[.ellipsis, 0 ..< kvSize] + var values = kv[.ellipsis, kvSize...] 
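+        // The fused QKV projection packs [queries | keys | values] along the
+        // last axis; keys and values use kvHeads rather than attentionHeads
+        // (grouped-query attention), so the key/value slices are narrower.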
+ + queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) + keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) + + // Apply RoPE with cache-aware offset if available + if let cache { + queries = rope(queries, offset: cache.offset) + keys = rope(keys, offset: cache.offset) + } else { + queries = rope(queries) + keys = rope(keys) + } + + let output = attentionWithCacheUpdate( + queries: queries, + keys: keys, + values: values, + cache: cache, + scale: scale, + mask: mask + ) + .transposed(0, 2, 1, 3) + .reshaped(B, L, -1) + + return outProj(output) + } +} + +// MARK: - MLP + +private final class Lille130mMLP: Module, UnaryLayer { + @ModuleInfo(key: "norm") var norm: RMSNorm + @ModuleInfo(key: "gate_proj") var gate: Linear + @ModuleInfo(key: "up_proj") var up: Linear + @ModuleInfo(key: "down_proj") var down: Linear + + init(_ args: Lille130mConfiguration) { + let numerator = (8 * args.hiddenSize) / 3 + let rounded = Int(round(Float(numerator) / 256.0)) + let hiddenDim = max(256 * rounded, 1) + + _norm.wrappedValue = RMSNorm(dimensions: args.hiddenSize, eps: args.layerNormEps) + _gate.wrappedValue = Linear(args.hiddenSize, hiddenDim, bias: false) + _up.wrappedValue = Linear(args.hiddenSize, hiddenDim, bias: false) + _down.wrappedValue = Linear(hiddenDim, args.hiddenSize, bias: false) + } + + func callAsFunction(_ x: MLXArray) -> MLXArray { + let h = norm(x) + return down(silu(gate(h)) * up(h)) + } +} + +// MARK: - Block + +private final class Lille130mBlock: Module { + @ModuleInfo(key: "attention") var attention: Lille130mAttention + @ModuleInfo(key: "feed_forward") var feedForward: Lille130mMLP + + init(_ args: Lille130mConfiguration) { + _attention.wrappedValue = Lille130mAttention(args) + _feedForward.wrappedValue = Lille130mMLP(args) + } + + func callAsFunction( + _ x: MLXArray, + mask: MLXFast.ScaledDotProductAttentionMaskMode, + cache: KVCache? + ) -> MLXArray { + let h = x + attention(x, mask: mask, cache: cache) + return h + feedForward(h) + } +} + +// MARK: - Model (inner) + +private final class Lille130mModelInner: Module { + @ModuleInfo(key: "tok_embeddings") var embedTokens: Embedding + + let layers: [Lille130mBlock] + let norm: RMSNorm + + init(_ args: Lille130mConfiguration) { + precondition(args.vocabularySize > 0) + _embedTokens.wrappedValue = Embedding(embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + self.layers = (0 ..< args.hiddenLayers).map { _ in Lille130mBlock(args) } + self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.layerNormEps) + } + + func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? 
= nil) -> MLXArray { + var h = embedTokens(inputs) + let mask = createAttentionMask(h: h, cache: cache) + for (i, layer) in layers.enumerated() { + h = layer(h, mask: mask, cache: cache?[i]) + } + return norm(h) + } +} + +// MARK: - Public Model + +public final class Lille130mModel: Module, LLMModel, KVCacheDimensionProvider { + public let vocabularySize: Int + public let kvHeads: [Int] + + @ModuleInfo(key: "transformer") fileprivate var transformer: Lille130mModelInner + private let configuration: Lille130mConfiguration + + public init(_ args: Lille130mConfiguration) { + self.configuration = args + self.vocabularySize = args.vocabularySize + self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } + + _transformer.wrappedValue = Lille130mModelInner(args) + } + + public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { + let out = transformer(inputs, cache: cache) + return transformer.embedTokens.asLinear(out) + } + + public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { + let w = weights.filter { key, _ in !key.contains("rotary_emb") } + return w + } +} + +// MARK: - Configuration + +public struct Lille130mConfiguration: Codable, Sendable { + public var modelType: String + public var blockSize: Int + public var layerNormEps: Float + public var hiddenSize: Int // n_embd + public var attentionHeads: Int // n_head + public var kvHeads: Int // n_kv_heads + public var hiddenLayers: Int // n_layer + public var ropeTheta: Float + public var vocabularySize: Int + public var tieWordEmbeddings: Bool = true + + enum CodingKeys: String, CodingKey { + case modelType = "model_type" + case blockSize = "block_size" + case layerNormEps = "layer_norm_eps" + case hiddenSize = "n_embd" + case attentionHeads = "n_head" + case kvHeads = "n_kv_heads" + case hiddenLayers = "n_layer" + case ropeTheta = "rope_theta" + case vocabularySize = "vocab_size" + } + + public init(from decoder: Decoder) throws { + let container: KeyedDecodingContainer = + try decoder.container(keyedBy: Lille130mConfiguration.CodingKeys.self) + + self.modelType = try container.decode( + String.self, forKey: Lille130mConfiguration.CodingKeys.modelType) + self.blockSize = try container.decode( + Int.self, forKey: Lille130mConfiguration.CodingKeys.blockSize) + self.layerNormEps = try container.decode( + Float.self, forKey: Lille130mConfiguration.CodingKeys.layerNormEps) + self.hiddenSize = try container.decode( + Int.self, forKey: Lille130mConfiguration.CodingKeys.hiddenSize) + self.attentionHeads = try container.decode( + Int.self, forKey: Lille130mConfiguration.CodingKeys.attentionHeads) + self.kvHeads = try container.decode( + Int.self, forKey: Lille130mConfiguration.CodingKeys.kvHeads) + self.hiddenLayers = try container.decode( + Int.self, forKey: Lille130mConfiguration.CodingKeys.hiddenLayers) + self.ropeTheta = try container.decode( + Float.self, forKey: Lille130mConfiguration.CodingKeys.ropeTheta) + self.vocabularySize = try container.decode( + Int.self, forKey: Lille130mConfiguration.CodingKeys.vocabularySize) + } +} + +// MARK: - LoRA + +extension Lille130mModel: LoRAModel { + public func loraLinearLayers() -> LoRALinearLayers { + transformer.layers.map { ($0.attention, ["qkv_proj"]) } + } +} From aa62ff456ace4868bb0f899ff26812ae9e7445ff Mon Sep 17 00:00:00 2001 From: Sachin Desai Date: Thu, 11 Sep 2025 16:34:52 -0700 Subject: [PATCH 2/2] swift format --- Libraries/MLXLLM/Models/Lille130m.swift | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git 
a/Libraries/MLXLLM/Models/Lille130m.swift b/Libraries/MLXLLM/Models/Lille130m.swift index 0ee351b7..3906ad67 100644 --- a/Libraries/MLXLLM/Models/Lille130m.swift +++ b/Libraries/MLXLLM/Models/Lille130m.swift @@ -143,7 +143,8 @@ private final class Lille130mModelInner: Module { init(_ args: Lille130mConfiguration) { precondition(args.vocabularySize > 0) - _embedTokens.wrappedValue = Embedding(embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) + _embedTokens.wrappedValue = Embedding( + embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) self.layers = (0 ..< args.hiddenLayers).map { _ in Lille130mBlock(args) } self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.layerNormEps) } @@ -192,10 +193,10 @@ public struct Lille130mConfiguration: Codable, Sendable { public var modelType: String public var blockSize: Int public var layerNormEps: Float - public var hiddenSize: Int // n_embd + public var hiddenSize: Int // n_embd public var attentionHeads: Int // n_head - public var kvHeads: Int // n_kv_heads - public var hiddenLayers: Int // n_layer + public var kvHeads: Int // n_kv_heads + public var hiddenLayers: Int // n_layer public var ropeTheta: Float public var vocabularySize: Int public var tieWordEmbeddings: Bool = true @@ -211,11 +212,11 @@ public struct Lille130mConfiguration: Codable, Sendable { case ropeTheta = "rope_theta" case vocabularySize = "vocab_size" } - + public init(from decoder: Decoder) throws { let container: KeyedDecodingContainer = try decoder.container(keyedBy: Lille130mConfiguration.CodingKeys.self) - + self.modelType = try container.decode( String.self, forKey: Lille130mConfiguration.CodingKeys.modelType) self.blockSize = try container.decode(
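
Usage note: with this patch applied, the new model loads through the same
factory path as the other registered LLMs. A minimal sketch, assuming the
current MLXLMCommon generation API (`LLMModelFactory.shared.loadContainer(configuration:)`,
`ChatSession`, and `respond(to:)` are the assumed entry points; none of them
are part of this diff):

    import MLXLLM
    import MLXLMCommon

    // Download (if needed) and load the bf16 weights registered above.
    let container = try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.lille_130m_bf16)

    // Generate a response to the registry's default prompt.
    let session = ChatSession(container)
    print(try await session.respond(to: "Why is the sky blue?"))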