Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion metatomic-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ bench = false
[dependencies]
metatensor = { version = "0.3.0" }
once_cell = "1"
dlpk = "0.3"
dlpk = { version = "0.3", features = ["ndarray"]}
json = "0.12"
libloading = "0.8"
ndarray = "0.17"


[build-dependencies]
Expand Down
10 changes: 9 additions & 1 deletion metatomic-core/include/metatomic.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,19 @@ typedef enum mta_status_t {
* Status code indicating serialization/deserialization errors
*/
MTA_SERIALIZATION_ERROR = 3,
/**
* Status code indicating dlpack errors
*/
MTA_DLPACK_ERROR = 4,
/**
* Status code indicating metatensor errors
*/
MTA_METATENSOR_ERROR = 5,
/**
* Status code used by plugins when a model is not supported by the
* current plugin
*/
MTA_MODEL_NOT_SUPPORTED_ERROR = 4,
MTA_MODEL_NOT_SUPPORTED_ERROR = 6,
/**
* Status code used when there is an internal error
*/
Expand Down
9 changes: 8 additions & 1 deletion metatomic-core/src/c_api/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ pub enum mta_status_t {
MTA_IO_ERROR = 2,
/// Status code indicating serialization/deserialization errors
MTA_SERIALIZATION_ERROR = 3,
/// Status code indicating dlpack errors
MTA_DLPACK_ERROR = 4,
/// Status code indicating metatensor errors
MTA_METATENSOR_ERROR = 5,
/// Status code used by plugins when a model is not supported by the
/// current plugin
MTA_MODEL_NOT_SUPPORTED_ERROR = 4,
MTA_MODEL_NOT_SUPPORTED_ERROR = 6,
/// Status code used when there is an internal error
MTA_INTERNAL_ERROR = 255,
}
Expand Down Expand Up @@ -107,8 +111,11 @@ impl From<Error> for mta_status_t {
Error::InvalidParameter(_) => mta_status_t::MTA_INVALID_PARAMETER_ERROR,
Error::Io(_) => mta_status_t::MTA_IO_ERROR,
Error::Serialization(_) => mta_status_t::MTA_SERIALIZATION_ERROR,
Error::Dlpack(_) => mta_status_t::MTA_DLPACK_ERROR,
Error::Metatensor(_) => mta_status_t::MTA_METATENSOR_ERROR,
Error::CallbackError(_) => unreachable!("already handled above"),
Error::Internal(_) => mta_status_t::MTA_INTERNAL_ERROR,

}
}
}
Expand Down
20 changes: 20 additions & 0 deletions metatomic-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ pub enum Error {
InvalidParameter(String),
/// I/O error
Io(Arc<std::io::Error>),
/// Error related to dlpack tensors, such as invalid tensor shapes or types
Dlpack(Arc<dlpk::ndarray::DLPackNDarrayError>),
/// Error coming from metatensor
Metatensor(metatensor::Error),
/// Error coming from an external function used as a callback
CallbackError(mta_status_t),
/// Any other internal error, usually these are internal bugs.
Expand All @@ -58,6 +62,8 @@ impl std::fmt::Display for Error {
Error::Serialization(e) => write!(f, "serialization error: {}", e),
Error::InvalidParameter(e) => write!(f, "invalid parameter: {}", e),
Error::Io(e) => write!(f, "io error: {}", e),
Error::Dlpack(e) => write!(f, "dlpack error: {}", e),
Error::Metatensor(e) => write!(f, "metatensor error: {}", e),
Error::CallbackError(e) => write!(f, "callback error, status code: {:?}", e),
Error::Internal(e) => write!(f,
"internal metatomic error (this is likely a bug, please report it): {}", e
Expand All @@ -74,6 +80,8 @@ impl std::error::Error for Error {
| Error::Internal(_)
| Error::CallbackError(_) => None,
Error::Io(e) => Some(e),
Error::Dlpack(e) => Some(e),
Error::Metatensor(e) => Some(e),
}
}

Expand Down Expand Up @@ -102,3 +110,15 @@ impl From<std::io::Error> for Error {
Error::Io(Arc::new(error))
}
}

impl From<dlpk::ndarray::DLPackNDarrayError> for Error {
fn from(error: dlpk::ndarray::DLPackNDarrayError) -> Self {
Error::Dlpack(Arc::new(error))
}
}

impl From<metatensor::Error> for Error {
fn from(error: metatensor::Error) -> Self {
Error::Metatensor(error)
}
}
8 changes: 4 additions & 4 deletions metatomic-core/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ use crate::units::validate_unit;
#[derive(Debug, Clone)]
pub struct PairListOptions {
/// Cutoff radius for this pair list in the length unit of the model
cutoff: f64,
pub cutoff: f64,
/// Whether the list is a full list (contains both the pair `i -> j` and `j -> i`)
/// or a half list (contains only `i -> j`)
full_list: bool,
pub full_list: bool,
/// Whether the list guarantees that only atoms within the cutoff are
/// included (strict) or may also include pairs slightly beyond the cutoff
/// (non-strict)
strict: bool,
pub strict: bool,
/// List of strings describing who requested this pair list
requestors: Vec<String>,
pub requestors: Vec<String>,
}

impl std::cmp::PartialEq for PairListOptions {
Expand Down
23 changes: 18 additions & 5 deletions metatomic-core/src/quantities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ fn is_valid_identifier(s: &str) -> bool {
/// All components (namespace, name, variant) must be non-empty if they are
/// present, and must be valid identifiers (alphanumeric + underscore, not
/// starting with a digit).
fn validate_quantity_name(name: &str) -> Result<(), Error> {
pub(crate) fn validate_quantity_name(name: &str) -> Result<(), Error> {
if STANDARD_QUANTITIES.contains(&name) {
return Ok(());
}
Expand All @@ -67,7 +67,12 @@ fn validate_quantity_name(name: &str) -> Result<(), Error> {
}
}

for component in main_part.split("::") {
if STANDARD_QUANTITIES.contains(&main_part) {
return Ok(());
}

let components: Vec<_> = main_part.split("::").collect();
for component in &components {
if !is_valid_identifier(component) {
return Err(Error::InvalidParameter(format!(
"invalid quantity name component '{}' in '{}': must be a valid identifier (alphanumeric or underscore, not starting with a digit)",
Expand All @@ -76,6 +81,13 @@ fn validate_quantity_name(name: &str) -> Result<(), Error> {
}
}

if components.len() == 1 {
return Err(Error::InvalidParameter(format!(
"'{}' is not a standard quantity name; custom quantity names must use '<namespace>::<name>'",
name
)));
}

Ok(())
}

Expand Down Expand Up @@ -289,7 +301,7 @@ mod tests {
vec![Gradients::Positions, Gradients::Strain],
] {
let quantity = Quantity {
name: "test".into(),
name: "test_ns::test".into(),
unit: "unit".into(),
description: Some("Hello".to_string()),
gradients: grads.clone(),
Expand Down Expand Up @@ -367,9 +379,7 @@ mod tests {
"my_model::energy",
"org::my_model::custom_qty",
"ns1::ns2::ns3::energy",
"custom_name",
"some_ns::name_with_underscores",
"_underscore_start",
"_ns::_name",
];
for name in custom {
Expand All @@ -388,6 +398,9 @@ mod tests {
let error = validate_quantity_name("").expect_err("expected an error");
assert_eq!(error.to_string(), "invalid parameter: quantity name cannot be empty in ''");

let error = validate_quantity_name("not_a_standard_name").expect_err("expected an error");
assert_eq!(error.to_string(), "invalid parameter: 'not_a_standard_name' is not a standard quantity name; custom quantity names must use '<namespace>::<name>'");

let error = validate_quantity_name("/variant").expect_err("expected an error");
assert_eq!(error.to_string(), "invalid parameter: quantity name cannot be empty in '/variant'");

Expand Down
Loading
Loading