Skip to content

Commit

Permalink
Fix state ordering in Python (#1122)
Browse files Browse the repository at this point in the history
This change uses the vector returned from the internals directly rather
than converting into a hashmap so that the state ordering can be
preserved for display.

Fixes #1119

Before:

![image](https://github.com/microsoft/qsharp/assets/10567287/4ff3d4d1-021b-4b27-b797-266312cc13cc)

After:

![image](https://github.com/microsoft/qsharp/assets/10567287/eff67e4e-a756-45e3-b246-16829f9befcf)
  • Loading branch information
swernli committed Feb 6, 2024
1 parent 09be372 commit 1fccd44
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
3 changes: 1 addition & 2 deletions pip/src/displayable_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@ mod tests;
use num_bigint::BigUint;
use num_complex::{Complex64, ComplexFloat};
use qsc::{fmt_basis_state_label, fmt_complex, format_state_id, get_phase};
use rustc_hash::FxHashMap;
use std::fmt::Write;

#[derive(Clone)]
pub struct DisplayableState(pub FxHashMap<BigUint, Complex64>, pub usize);
pub struct DisplayableState(pub Vec<(BigUint, Complex64)>, pub usize);

impl DisplayableState {
pub fn to_plain(&self) -> String {
Expand Down
29 changes: 19 additions & 10 deletions pip/src/displayable_output/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,39 @@

use num_bigint::BigUint;
use num_complex::Complex;
use rustc_hash::FxHashMap;

use crate::displayable_output::DisplayableState;

#[test]
fn display_neg_zero() {
let s = DisplayableState(
vec![(BigUint::default(), Complex::new(-0.0, -0.0))]
.into_iter()
.collect::<FxHashMap<_, _>>(),
1,
);
let s = DisplayableState(vec![(BigUint::default(), Complex::new(-0.0, -0.0))], 1);
// -0 should be displayed as 0.0000 without a minus sign
assert_eq!("STATE:\n|0⟩: 0.0000+0.0000𝑖", s.to_plain());
}

#[test]
fn display_rounds_to_neg_zero() {
let s = DisplayableState(
vec![(BigUint::default(), Complex::new(-0.00001, -0.00001))]
.into_iter()
.collect::<FxHashMap<_, _>>(),
vec![(BigUint::default(), Complex::new(-0.00001, -0.00001))],
1,
);
// -0.00001 should be displayed as 0.0000 without a minus sign
assert_eq!("STATE:\n|0⟩: 0.0000+0.0000𝑖", s.to_plain());
}

#[test]
fn display_preserves_order() {
let s = DisplayableState(
vec![
(BigUint::from(0_u64), Complex::new(0.0, 0.0)),
(BigUint::from(1_u64), Complex::new(0.0, 1.0)),
(BigUint::from(2_u64), Complex::new(1.0, 0.0)),
(BigUint::from(3_u64), Complex::new(1.0, 1.0)),
],
2,
);
assert_eq!(
"STATE:\n|00⟩: 0.0000+0.0000𝑖\n|01⟩: 0.0000+1.0000𝑖\n|10⟩: 1.0000+0.0000𝑖\n|11⟩: 1.0000+1.0000𝑖",
s.to_plain()
);
}
15 changes: 8 additions & 7 deletions pip/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ use qsc::{
PackageType, SourceMap,
};
use resource_estimator::{self as re, estimate_expr};
use rustc_hash::FxHashMap;
use std::fmt::Write;

#[pymodule]
Expand Down Expand Up @@ -168,10 +167,7 @@ impl Interpreter {
/// pairs of real and imaginary amplitudes.
fn dump_machine(&mut self) -> StateDump {
let (state, qubit_count) = self.interpreter.get_quantum_state();
StateDump(DisplayableState(
state.into_iter().collect::<FxHashMap<_, _>>(),
qubit_count,
))
StateDump(DisplayableState(state, qubit_count))
}

fn run(
Expand Down Expand Up @@ -336,7 +332,13 @@ impl StateDump {
// Pass by value is needed for compatiblity with the pyo3 API.
#[allow(clippy::needless_pass_by_value)]
fn __getitem__(&self, key: BigUint) -> Option<(f64, f64)> {
self.0 .0.get(&key).map(|state| (state.re, state.im))
self.0 .0.iter().find_map(|state| {
if state.0 == key {
Some((state.1.re, state.1.im))
} else {
None
}
})
}

fn __len__(&self) -> usize {
Expand Down Expand Up @@ -459,7 +461,6 @@ impl Receiver for OptionalCallbackReceiver<'_> {
qubit_count: usize,
) -> core::result::Result<(), Error> {
if let Some(callback) = &self.callback {
let state = state.into_iter().collect::<FxHashMap<_, _>>();
let out = DisplayableOutput::State(DisplayableState(state, qubit_count));
callback
.call1(
Expand Down

0 comments on commit 1fccd44

Please sign in to comment.