6 changes: 5 additions & 1 deletion Libraries/Embedders/BaseConfiguration.swift
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 import Foundation
+import MLX
 
 /// Base ``LanguageModel`` configuration -- provides `modelType`
 /// and `quantization` (used in loading the model).
@@ -18,12 +19,15 @@ public struct BaseConfiguration: Codable, Sendable {
 
         public let groupSize: Int
         public let bits: Int
+        private var _mode: QuantizationMode? = nil
+        public var mode: QuantizationMode { _mode ?? .affine }
 
-        public var asTuple: (Int, Int) { (groupSize, bits) }
+        public var asTuple: (Int, Int, QuantizationMode) { (groupSize, bits, mode) }
 
         enum CodingKeys: String, CodingKey {
             case groupSize = "group_size"
             case bits = "bits"
+            case _mode = "mode"
         }
     }
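The `_mode ?? .affine` fallback keeps older checkpoints loading unchanged: a quantization block with no "mode" key decodes with `_mode == nil`, and `mode` reports `.affine`. A quick sketch of that behavior, assuming `BaseConfiguration.Quantization` can be decoded standalone (illustrative, not part of the diff):

    import Foundation

    // A legacy config with no "mode" key still decodes; the computed
    // `mode` property falls back to .affine.
    let legacy = #"{"group_size": 64, "bits": 4}"#.data(using: .utf8)!
    let q = try JSONDecoder().decode(BaseConfiguration.Quantization.self, from: legacy)
    print(q.mode)  // affine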

6 changes: 6 additions & 0 deletions Libraries/MLXLLM/LLMModelFactory.swift
@@ -340,6 +340,11 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
         defaultPrompt: ""
     )
 
+    static public let gpt_oss_20b_MXFP4_Q8 = ModelConfiguration(
+        id: "mlx-community/gpt-oss-20b-MXFP4-Q8",
+        defaultPrompt: "Why is the sky blue?"
+    )
+
Review comment (PR author): The MXFP4 quantization is now supported. This model was used to test that and the quantized KV cache. (See the usage sketch after this file's hunks.)


     private static func all() -> [ModelConfiguration] {
         [
             codeLlama13b4bit,
@@ -389,6 +394,7 @@ public class LLMRegistry: AbstractModelRegistry, @unchecked Sendable {
             ling_mini_2_2bit,
             lfm2_8b_a1b_3bit_mlx,
             nanochat_d20_mlx,
+            gpt_oss_20b_MXFP4_Q8,
         ]
     }
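A short usage sketch for the new entry, assuming the `LLMModelFactory.loadContainer` and `ChatSession` conveniences this repo already provides, run from an async context (illustrative, not part of the diff):

    import MLXLLM
    import MLXLMCommon

    // Load the MXFP4-quantized GPT-OSS model via the new registry entry
    // and ask it the configuration's default prompt.
    let container = try await LLMModelFactory.shared.loadContainer(
        configuration: LLMRegistry.gpt_oss_20b_MXFP4_Q8)
    let session = ChatSession(container)
    print(try await session.respond(to: "Why is the sky blue?"))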

3 changes: 2 additions & 1 deletion Libraries/MLXLLM/Models/GPTOSS.swift
@@ -305,7 +305,8 @@ private class AttentionBlock: Module {
             scale: smScale,
             mask: .array(mask),
             groupSize: qcache.groupSize,
-            bits: qcache.bits
+            bits: qcache.bits,
+            mode: qcache.mode
         )
 
         return oProj(vHat.swappedAxes(1, 2).reshaped(B, L, -1))
15 changes: 10 additions & 5 deletions Libraries/MLXLLM/SwitchLayers.swift
@@ -142,24 +142,28 @@ class SwitchLinear: Module, Quantizable {
         return result
     }
 
-    func toQuantized(groupSize: Int = 64, bits: Int = 4) -> Module {
-        QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits)
+    func toQuantized(groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode) -> Module {
+        QuantizedSwitchLinear(self, groupSize: groupSize, bits: bits, mode: mode)
     }
 }
 
 class QuantizedSwitchLinear: SwitchLinear, Quantized {
     @ModuleInfo(key: "scales") var scales: MLXArray
-    @ModuleInfo(key: "biases") var biases: MLXArray
+    @ModuleInfo(key: "biases") var biases: MLXArray?
Review comment (PR author): `biases` are now optional. (See the sketch after this file's hunks.)


     let groupSize: Int
     let bits: Int
+    let mode: QuantizationMode
 
-    init(_ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4) {
+    init(
+        _ other: SwitchLinear, groupSize: Int = 64, bits: Int = 4, mode: QuantizationMode = .affine
+    ) {
         self.groupSize = groupSize
         self.bits = bits
+        self.mode = mode
 
         let (quantizedWeight, scales, biases) = MLX.quantized(
-            other.weight, groupSize: groupSize, bits: bits)
+            other.weight, groupSize: groupSize, bits: bits, mode: mode)
 
         self._scales.wrappedValue = scales
         self._biases.wrappedValue = biases
@@ -183,6 +187,7 @@ class QuantizedSwitchLinear: SwitchLinear, Quantized {
             transpose: true,
             groupSize: self.groupSize,
             bits: self.bits,
+            mode: mode,
             sortedIndices: sortedIndices
         )
6 changes: 4 additions & 2 deletions Libraries/MLXLMCommon/Adapters/LoRA/DoRA+Layers.swift
@@ -147,7 +147,8 @@ public class QDoRALinear: QuantizedLinear, LoRALayer {
         super.init(
             weight: linear.weight, bias: linear.bias,
             scales: linear.scales, biases: linear.biases,
-            groupSize: linear.groupSize, bits: linear.bits
+            groupSize: linear.groupSize, bits: linear.bits,
+            mode: linear.mode
         )
 
         freeze()
@@ -171,7 +172,8 @@ public class QDoRALinear: QuantizedLinear, LoRALayer {
 
     public override func callAsFunction(_ x: MLXArray) -> MLXArray {
         let y = quantizedMatmul(
-            x, weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits)
+            x, weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits,
+            mode: mode)
         return forward(
             x: x, y: y,
             weight: dequantizedWeight, bias: bias,
3 changes: 2 additions & 1 deletion Libraries/MLXLMCommon/Adapters/LoRA/LoRAModel.swift
@@ -79,7 +79,8 @@ extension QuantizedLinear {
             scales: scales,
             biases: biases,
             groupSize: groupSize,
-            bits: bits
+            bits: bits,
+            mode: mode
         )
     }
 }
5 changes: 3 additions & 2 deletions Libraries/MLXLMCommon/AttentionUtils.swift
@@ -52,7 +52,7 @@ public func attentionWithCacheUpdate(
             mask: mask
         )
     }
-    if let quantizedKVCache = cache as? QuantizedKVCache {
+    if let quantizedKVCache = cache as? QuantizedKVCacheProtocol {
         let (quantizedKeys, quantizedValues) = quantizedKVCache.updateQuantized(
             keys: keys, values: values)
         return quantizedScaledDotProductAttention(
@@ -62,7 +62,8 @@ public func attentionWithCacheUpdate(
             scale: scale,
             mask: mask,
             groupSize: quantizedKVCache.groupSize,
-            bits: quantizedKVCache.bits
+            bits: quantizedKVCache.bits,
+            mode: quantizedKVCache.mode
         )
     } else {
         let (cachedKeys, cachedValues) = cache.update(keys: keys, values: values)
23 changes: 10 additions & 13 deletions Libraries/MLXLMCommon/BaseConfiguration.swift
@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 
 import Foundation
+import MLX
 
 /// Base ``LanguageModel`` configuration -- provides `modelType`
 /// and `quantization` (used in loading the model).
@@ -18,20 +19,15 @@ public struct BaseConfiguration: Codable, Sendable {
 
         public let groupSize: Int
         public let bits: Int
-        public var quantMethod: String? = nil
-        public var linearClass: String? = nil
-        public var quantizationMode: String? = nil
Review comment (PR author): These were defined only so that they could be skipped during decoding (below). We can just skip them directly.

-        public var mode: String? = nil
+        private var _mode: QuantizationMode? = nil
+        public var mode: QuantizationMode { _mode ?? .affine }
 
-        public var asTuple: (Int, Int) { (groupSize, bits) }
+        public var asTuple: (Int, Int, QuantizationMode) { (groupSize, bits, mode) }
 
         enum CodingKeys: String, CodingKey {
             case groupSize = "group_size"
             case bits = "bits"
-            case quantMethod = "quant_method"
-            case linearClass = "linear_class"
-            case quantizationMode = "quantization_mode"
-            case mode = "mode"
+            case _mode = "mode"
         }
     }

@@ -115,10 +111,11 @@ public struct BaseConfiguration: Codable, Sendable {
             switch key.stringValue {
             case Quantization.CodingKeys.groupSize.rawValue: continue
             case Quantization.CodingKeys.bits.rawValue: continue
-            case Quantization.CodingKeys.quantMethod.rawValue: continue
-            case Quantization.CodingKeys.linearClass.rawValue: continue
-            case Quantization.CodingKeys.quantizationMode.rawValue: continue
-            case Quantization.CodingKeys.mode.rawValue: continue
+            case Quantization.CodingKeys._mode.rawValue: continue
+
+            // additional keys that are not layer instructions, see
+            // mlx-community/bitnet-b1.58-2B-4T-4bit
+            case "quant_method", "linear_class", "quantization_mode": continue
Review comment (PR author): Skip directly. (A decoding sketch follows this hunk.)


             default:
                 if let f = try? container.decode(Bool.self, forKey: key) {
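For reference, a sketch of the quantization block this decoder now handles. It is abridged and illustrative: it assumes `QuantizationMode` decodes from the strings "affine"/"mxfp4" and that `BaseConfiguration` exposes its `quantization` block as a property. "group_size", "bits", and "mode" feed `Quantization`, the three metadata keys are skipped inline, and any remaining key (like "lm_head") is treated as a per-layer instruction:

    import Foundation

    // Decode a config whose quantization block carries both real settings
    // and metadata keys that the decoder skips.
    let json = #"""
        {
            "model_type": "gpt_oss",
            "quantization": {
                "group_size": 32,
                "bits": 4,
                "mode": "mxfp4",
                "quant_method": "mxfp4",
                "lm_head": false
            }
        }
        """#.data(using: .utf8)!
    let base = try JSONDecoder().decode(BaseConfiguration.self, from: json)
    if let q = base.quantization {
        print(q.asTuple)  // (32, 4, mxfp4), given a string-backed QuantizationMode
    }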