Skip to content

Commit b89259d

Browse files
authored
Make Module.update() throw instead of crash for incompatible parameters (#266)
* replace fatal errors with throwing errors
1 parent 8a25603 commit b89259d

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

Source/MLXNN/Module.swift

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,9 @@ open class Module {
516516
break
517517

518518
default:
519-
fatalError(
520-
"Unable to set \(path.joined(separator: ".")) on \(modulePath.joined(separator: ".")): \(item) not compatible with \(value.mapValues { $0.shape.description })"
521-
)
519+
throw UpdateError.incompatibleItems(
520+
path: path, modules: modulePath, item: item.description,
521+
value: String(describing: (value.mapValues { $0.shape.description })))
522522
}
523523
}
524524

@@ -630,9 +630,7 @@ open class Module {
630630

631631
switch (item, value) {
632632
case (.value(.parameters), .value):
633-
fatalError(
634-
"Unable to set \(path.joined(separator: ".")) on \(modulePath.joined(separator: ".")): parameters (MLXArray) cannot be updated with a Module"
635-
)
633+
throw UpdateError.settingArrayWithModule(path: path, modules: modulePath)
636634

637635
case (.array(let items), .array(let values)):
638636
// Could be:
@@ -689,9 +687,7 @@ open class Module {
689687
}
690688
}
691689
default:
692-
fatalError(
693-
"Unexpected structure for \(key) on \(self): not @ModuleInfo var modules = [...]"
694-
)
690+
throw UpdateError.unexpectedStructure(key: key, item: self.description)
695691
}
696692

697693
case (.dictionary, .dictionary(let values)):
@@ -723,9 +719,9 @@ open class Module {
723719
try module.update(modules: NestedDictionary(values: values), verify: verify)
724720

725721
default:
726-
fatalError(
727-
"Unable to set \(path.joined(separator: ".")) on \(modulePath.joined(separator: ".")): \(item) not compatible with \(value)"
728-
)
722+
throw UpdateError.incompatibleItems(
723+
path: path, modules: modulePath, item: item.description,
724+
value: value.description)
729725
}
730726
}
731727

@@ -1573,6 +1569,9 @@ enum UpdateError: Error {
15731569
case unableToSet(String)
15741570
case unableToCast(String)
15751571
case unhandledKeys(path: [String], modules: [String], keys: [String])
1572+
case settingArrayWithModule(path: [String], modules: [String])
1573+
case incompatibleItems(path: [String], modules: [String], item: String, value: String)
1574+
case unexpectedStructure(key: String, item: String)
15761575
}
15771576

15781577
extension UpdateError: LocalizedError {
@@ -1599,6 +1598,15 @@ extension UpdateError: LocalizedError {
15991598
case .unhandledKeys(let path, let modules, let keys):
16001599
return
16011600
"Unhandled keys \(keys) in \(path.joined(separator: ".")) in \(modules.joined(separator: "."))"
1601+
case .settingArrayWithModule(let path, let modules):
1602+
return
1603+
"Unable to set \(path.joined(separator: ".")) on \(modules.joined(separator: ".")): parameters (MLXArray) cannot be updated with a Module"
1604+
case .incompatibleItems(let path, let modules, let item, let value):
1605+
return
1606+
"Unable to set \(path.joined(separator: ".")) on \(modules.joined(separator: ".")): \(item) not compatible with \(value)"
1607+
case .unexpectedStructure(let key, let item):
1608+
return "Unexpected structure for \(key) on \(item): not @ModuleInfo var modules = [...]"
1609+
16021610
}
16031611
}
16041612
}

0 commit comments

Comments
 (0)