Skip to content

Commit

Permalink
adding softargmax functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
b-vanstraaten committed Oct 23, 2023
1 parent c2ea888 commit bc08dcb
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "qarray_rust_core"
version = "1.1.1"
version = "1.2.0"
edition = "2021"

[lib]
Expand Down
34 changes: 17 additions & 17 deletions src/closed_dots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use osqp::{CscMatrix, Problem, Settings};
use rayon::prelude::*;

use crate::charge_configurations::closed_charge_configurations;
use crate::helper_functions::{hard_argmin, soft_argmin};

#[allow(non_snake_case)]
pub fn ground_state_closed_1d<'a>(
v_g: ArrayView<'a, f64, Ix2>,
n_charge: u64,
Expand All @@ -14,6 +16,7 @@ pub fn ground_state_closed_1d<'a>(
c_dd_inv: ArrayView<'a, f64, Ix2>,
threshold: f64,
polish: bool,
T: f64,
) -> Array<f64, Ix2> {
let n = v_g.shape()[0];
let m = c_gd.shape()[0];
Expand All @@ -23,13 +26,15 @@ pub fn ground_state_closed_1d<'a>(

rows.par_iter_mut().enumerate().for_each(|(j, result_row)| {
let v_g_row = v_g.slice(s![j, ..]);
let n_charge =
ground_state_closed_0d(v_g_row, n_charge, c_gd, c_dd, c_dd_inv, threshold, polish);
let n_charge = ground_state_closed_0d(
v_g_row, n_charge, c_gd, c_dd, c_dd_inv, threshold, polish, T,
);
result_row.assign(&n_charge);
});
results_array
}

#[allow(non_snake_case)]
pub fn ground_state_closed_0d<'a>(
v_g: ArrayView<f64, Ix1>,
n_charge: u64,
Expand All @@ -38,6 +43,7 @@ pub fn ground_state_closed_0d<'a>(
c_dd_inv: ArrayView<'a, f64, Ix2>,
threshold: f64,
polish: bool,
T: f64,
) -> Array<f64, Ix1> {
let analytical_solution = analytical_solution(c_gd, c_dd, v_g, n_charge);
if analytical_solution
Expand All @@ -53,6 +59,7 @@ pub fn ground_state_closed_0d<'a>(
v_g,
n_charge,
threshold,
T,
);
} else {
// the analytical solution is not a valid charge configuration
Expand All @@ -64,7 +71,7 @@ pub fn ground_state_closed_0d<'a>(
let n_continuous =
Array1::<f64>::from(result.x().expect("failed to solve problem").to_owned());

return compute_argmin_closed(n_continuous, c_dd_inv, c_gd, v_g, n_charge, threshold);
return compute_argmin_closed(n_continuous, c_dd_inv, c_gd, v_g, n_charge, threshold, T);
}
}

Expand Down Expand Up @@ -126,28 +133,21 @@ fn init_osqp_problem_closed<'a>(
Problem::new(P, q, A, l, u, &settings).expect("failed to setup problem")
}

#[allow(non_snake_case)]
fn compute_argmin_closed(
n_continuous: Array1<f64>,
c_dd_inv: ArrayView<f64, Ix2>,
c_gd: ArrayView<f64, Ix2>,
v_g: ArrayView<f64, Ix1>,
n_charge: u64,
threshold: f64,
T: f64,
) -> Array1<f64> {
let n_list = closed_charge_configurations(n_continuous, n_charge, threshold);
let vg_dash = c_gd.dot(&v_g);

// type conversion from i64 to f64
let n_list = n_list.mapv(|x| x as f64);

let n_min = n_list
.outer_iter()
.map(|x| x.to_owned() - &c_gd.dot(&v_g))
.map(|x| x.dot(&c_dd_inv.dot(&x)))
.enumerate()
.min_by(|(_, x), (_, y)| x.partial_cmp(y).expect("failed to compare floats"))
.map(|(idx, _)| n_list.index_axis(Axis(0), idx))
.expect("failed to compute argmin")
.to_owned();

return n_min;
match T > 0.0 {
false => hard_argmin(n_list, c_dd_inv, vg_dash),
true => soft_argmin(n_list, c_dd_inv, vg_dash, T),
}
}
44 changes: 44 additions & 0 deletions src/helper_functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use ndarray::{Array1, Array2, ArrayView, Axis, Ix2};

pub fn hard_argmin(
n_list: Array2<f64>,
c_dd_inv: ArrayView<f64, Ix2>,
vg_dash: Array1<f64>,
) -> Array1<f64> {
let argmin_index = n_list
.outer_iter()
.map(|x| x.to_owned() - &vg_dash)
.map(|x| x.dot(&c_dd_inv.dot(&x)))
.enumerate()
.min_by(|(_, x), (_, y)| x.partial_cmp(y).expect("failed to compare floats"))
.map(|(idx, _)| idx);

match argmin_index {
Some(idx) => n_list.index_axis(Axis(0), idx).to_owned(),
None => panic!("failed to compute argmin"),
}
}

#[allow(non_snake_case)]
pub fn soft_argmin(
n_list: Array2<f64>,
c_dd_inv: ArrayView<f64, Ix2>,
vg_dash: Array1<f64>,
T: f64,
) -> Array1<f64> {
let F: Array1<f64> = n_list
.outer_iter()
.map(|x| x.to_owned() - &vg_dash)
.map(|x| x.dot(&c_dd_inv.dot(&x)))
.map(|x| -x / T)
.collect();

let max_value = F.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
let mut F = F - max_value;
F.mapv_inplace(f64::exp);

let sum_weights = F.sum_axis(Axis(0));
let weighted_n_list = &n_list * &F.insert_axis(Axis(1));
let n_min = weighted_n_list.sum_axis(Axis(0)) / &sum_weights;
n_min.to_owned()
}
10 changes: 8 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pyo3::prelude::{pymodule, PyModule, PyResult, Python};

mod charge_configurations;
mod closed_dots;
mod helper_functions;
mod open_dots;

#[pymodule]
Expand Down Expand Up @@ -36,23 +37,27 @@ fn qarray_rust_core(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
}

#[pyfn(m)]
#[allow(non_snake_case)]
fn ground_state_open<'py>(
py: Python<'py>,
v_g: PyReadonlyArray2<f64>,
c_gd: PyReadonlyArray2<f64>,
c_dd_inv: PyReadonlyArray2<f64>,
threshold: f64,
polish: bool,
T: f64,
) -> &'py PyArray2<f64> {
let v_g = v_g.as_array();
let c_gd = c_gd.as_array();
let c_dd_inv = c_dd_inv.as_array();

let results_array = open_dots::ground_state_open_1d(v_g, c_gd, c_dd_inv, threshold, polish);
let results_array =
open_dots::ground_state_open_1d(v_g, c_gd, c_dd_inv, threshold, polish, T);
results_array.into_pyarray(py)
}

#[pyfn(m)]
#[allow(non_snake_case)]
fn ground_state_closed<'py>(
py: Python<'py>,
v_g: PyReadonlyArray2<f64>,
Expand All @@ -62,14 +67,15 @@ fn qarray_rust_core(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
c_dd_inv: PyReadonlyArray2<f64>,
threshold: f64,
polish: bool,
T: f64,
) -> &'py PyArray2<f64> {
let v_g = v_g.as_array();
let c_gd = c_gd.as_array();
let c_dd = c_dd.as_array();
let c_dd_inv = c_dd_inv.as_array();

let results_array = closed_dots::ground_state_closed_1d(
v_g, n_charge, c_gd, c_dd, c_dd_inv, threshold, polish,
v_g, n_charge, c_gd, c_dd, c_dd_inv, threshold, polish, T,
);
results_array.into_pyarray(py)
}
Expand Down
35 changes: 16 additions & 19 deletions src/open_dots.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
// Computes the ground state for open dots.

use ndarray::{s, Array, Array1, Array2, ArrayView, Axis, Ix1, Ix2};
use ndarray::{s, Array, Array1, Array2, ArrayView, Ix1, Ix2};
use osqp::{CscMatrix, Problem, Settings};
use rayon::prelude::*;

use crate::charge_configurations::open_charge_configurations;
use crate::helper_functions::{hard_argmin, soft_argmin};

#[allow(non_snake_case)]
pub fn ground_state_open_1d<'a>(
v_g: ArrayView<'a, f64, Ix2>,
c_gd: ArrayView<'a, f64, Ix2>,
c_dd_inv: ArrayView<'a, f64, Ix2>,
threshold: f64,
polish: bool,
T: f64,
) -> Array<f64, Ix2> {
let n = v_g.shape()[0];
let m = c_gd.shape()[0];
Expand All @@ -21,23 +24,25 @@ pub fn ground_state_open_1d<'a>(

rows.par_iter_mut().enumerate().for_each(|(j, result_row)| {
let v_g_row = v_g.slice(s![j, ..]);
let n_charge = ground_state_open_0d(v_g_row, c_gd, c_dd_inv, threshold, polish);
let n_charge = ground_state_open_0d(v_g_row, c_gd, c_dd_inv, threshold, polish, T);
result_row.assign(&n_charge);
});
return results_array;
}

#[allow(non_snake_case)]
fn ground_state_open_0d<'a>(
v_g: ArrayView<f64, Ix1>,
c_gd: ArrayView<'a, f64, Ix2>,
c_dd_inv: ArrayView<'a, f64, Ix2>,
threshold: f64,
polish: bool,
T: f64,
) -> Array<f64, Ix1> {
let analytical_solution = analytical_solution(c_gd, v_g);

if analytical_solution.iter().all(|x| x >= &0.0) {
return compute_argmin_open(analytical_solution, c_dd_inv, c_gd, v_g, threshold);
return compute_argmin_open(analytical_solution, c_dd_inv, c_gd, v_g, threshold, T);
} else {
let mut problem = init_osqp_problem_open(v_g, c_gd, c_dd_inv, polish);
let result = problem.solve();
Expand All @@ -48,7 +53,7 @@ fn ground_state_open_0d<'a>(

// clip the continuous part to be positive, as we have turned off polishing in the solver
n_continuous.mapv_inplace(|x| x.max(0.0));
return compute_argmin_open(n_continuous, c_dd_inv, c_gd, v_g, threshold);
return compute_argmin_open(n_continuous, c_dd_inv, c_gd, v_g, threshold, T);
}
}

Expand Down Expand Up @@ -79,33 +84,25 @@ fn init_osqp_problem_open<'a>(
CscMatrix::from(identity.rows())
};

let settings = Settings::default()
.alpha(1.0)
.adaptive_rho(true)
.verbose(false)
.polish(polish);
let settings = Settings::default().alpha(1.0).verbose(false).polish(polish);
return Problem::new(P, q, A, l, u, &settings).expect("failed to setup problem");
}

#[allow(non_snake_case)]
fn compute_argmin_open(
n_continuous: Array1<f64>,
c_dd_inv: ArrayView<f64, Ix2>,
c_gd: ArrayView<f64, Ix2>,
v_g: ArrayView<f64, Ix1>,
threshold: f64,
T: f64,
) -> Array1<f64> {
let vg_dash = c_gd.dot(&v_g);

let n_list = open_charge_configurations(n_continuous, threshold);

let n_min = n_list
.outer_iter()
.map(|x| x.to_owned() - &vg_dash)
.map(|x| x.dot(&c_dd_inv.dot(&x)))
.enumerate()
.min_by(|(_, x), (_, y)| x.partial_cmp(y).expect("failed to compare floats"))
.map(|(idx, _)| n_list.index_axis(Axis(0), idx))
.expect("failed to compute argmin")
.to_owned();
return n_min;
match T > 0.0 {
false => hard_argmin(n_list, c_dd_inv, vg_dash),
true => soft_argmin(n_list, c_dd_inv, vg_dash, T),
}
}

0 comments on commit bc08dcb

Please sign in to comment.