In [45]:
:dep ndarray = { version = "0.14", features = ["serde"] }
:dep ndarray-linalg = { version = "0.13.1", features = ["intel-mkl-static"]}
:dep plotters = { git = "https://github.com/38/plotters", default_features = false, features = ["evcxr", "all_series"] }
:dep thiserror = "1.0"
:dep serde = { version = "1.0", features = ["derive"] }
:dep serde_json = "1.0"
:dep dimensioned = "0.7"
:dep num-complex = "0.4"

use ndarray::prelude::*;
use ndarray_linalg::*;

pub use ndarray_linalg::c32;
/// One-dimensional array of single-precision complex numbers.
pub type VecC32 = ndarray::Array<c32, ndarray::Ix1>;
/// One-dimensional array of single-precision reals.
pub type VecF32 = ndarray::Array<f32, ndarray::Ix1>;
/// Two-dimensional array of single-precision complex numbers.
pub type MatrixC32 = ndarray::Array<c32, ndarray::Ix2>;
/// Two-dimensional array of single-precision reals.
pub type MatrixF32 = ndarray::Array<f32, ndarray::Ix2>;


/////////////////////////////////////////////////////////////////////
////////////////////Entanglement Calculations////////////////////////
/////////////////////////////////////////////////////////////////////

/// Builds the pure-state density matrix ρ = |ψ⟩⟨ψ| from the amplitude vector of |ψ⟩.
///
/// `coefs` is turned into a column vector (the ket), its element-wise complex
/// conjugate into a row vector (the bra), and their matrix product is returned.
pub fn create_dens_matrix(coefs: VecC32) -> MatrixC32 {
    // Conjugate first: `into_col` below takes ownership of `coefs`.
    let bra = into_row(coefs.mapv(|amp| amp.conj()));
    let ket = into_col(coefs);
    ket.dot(&bra)
}

/// Returns the purity Tr(ρ²) as a real number.
///
/// The argument must already be ρ² (for example `find_matrix_sqrd(rho)`); this
/// function only takes the trace and keeps its real part.
///
/// # Panics
/// Panics if the trace cannot be computed (e.g. a non-square matrix).
pub fn find_purity(rho_sqrd: MatrixC32) -> f32 {
    rho_sqrd.trace().map(|tr| tr.re).unwrap()
}


/// Computes Tr √(√ρ · σ · √ρ) for two density matrices.
///
/// This is the un-squared ("root") fidelity; square the result if the
/// F = (Tr …)² convention is wanted. Matrix square roots are taken with
/// `find_sqr_root_of_matrix`, which clamps negative eigenvalues to zero.
///
/// # Panics
/// Panics if the trace cannot be computed.
pub fn find_fidelity(rho: MatrixC32, sigma: MatrixC32) -> f32 {
    let root_rho = find_sqr_root_of_matrix(rho);
    let sandwiched = root_rho.dot(&sigma).dot(&root_rho);
    find_sqr_root_of_matrix(sandwiched).trace().unwrap().re
}

pub fn find_concurrence(rho: MatrixC32) -> f32 {
  
  let pauli_y = array![ [  c32::new(0. , 0.)  ,  c32::new(0. , 0.) , c32::new(0. , 0.) ,  c32::new(-1. , 0.)  ] , 
                                              [  c32::new(0. , 0.)  ,  c32::new(0. , 0.) , c32::new(1. , 0.) ,  c32::new(0. , 0.)   ] ,
                                              [  c32::new(0. , 0.)  ,  c32::new(1. , 0.) , c32::new(0. , 0.) ,  c32::new(0. , 0.)   ] ,
                                              [  c32::new(-1. , 0.) ,  c32::new(0. , 0.) , c32::new(0. , 0.) ,  c32::new(0. , 0.)   ] ];

  let rho_star = rho.mapv(|rho| rho.conj());
  let sqrt_rho = find_sqr_root_of_matrix(rho.clone());
  let rho_tilde = pauli_y.dot(&rho_star).dot(&pauli_y);

  let product = sqrt_rho.dot(&rho_tilde).dot(&sqrt_rho);
  let sqrt_product = find_sqr_root_of_matrix(product);

  let (eigvals, _eigvecs) = sqrt_product.eigh(UPLO::Lower).unwrap();
  let mut eigvals = eigvals.to_vec();
  eigvals.sort_by(|a, b| a.partial_cmp(b).unwrap());
  0_f32.max(eigvals[3] - eigvals[2] - eigvals[1] - eigvals[0])
}

/// Negativity N(ρ) = (‖ρ^T‖₁ − 1) / 2.
///
/// Here ρ^T is the partial transpose produced by `find_partial_transpose`.
pub fn find_negativity(rho: MatrixC32) -> f32 {
    0.5 * (find_trace_norm(rho) - 1.)
}

/// Trace norm of the *partial transpose* of `rho` (not of `rho` itself).
///
/// Computes ‖A‖₁ = Tr √(A†A) with A = `find_partial_transpose(rho)` and
/// returns the real part of that trace.
///
/// # Panics
/// Panics if the trace cannot be computed.
pub fn find_trace_norm(rho: MatrixC32) -> f32 {
    let pt = find_partial_transpose(rho);
    let pt_dagger = pt.t().mapv(|z| z.conj());
    let gram = pt_dagger.dot(&pt);
    find_sqr_root_of_matrix(gram).trace().unwrap().re
}

/// Logarithmic negativity E_N = log₂(2N + 1), where N is `find_negativity(rho)`.
pub fn find_log_negativity(rho: MatrixC32) -> f32 {
    let negativity = find_negativity(rho);
    (1. + 2. * negativity).log2()
}

/////////////////////////////////////////////////////////////////////
/////////////////////////Matrix Operations///////////////////////////
/////////////////////////////////////////////////////////////////////

/// Returns the number of columns of `matrix` as an `i32`.
pub fn find_dim(matrix: MatrixC32) -> i32 {
    let (_rows, cols) = matrix.dim();
    cols as i32
}

/// Returns the matrix product `matrix · matrix`.
pub fn find_matrix_sqrd(matrix: MatrixC32) -> MatrixC32 {
    matrix.dot(&matrix)
}

/// Principal square root of a Hermitian matrix, after clamping negative eigenvalues to zero.
///
/// The input is diagonalised as `matrix = S · D · S†` by `rescale_neg_eigvals`
/// (a Hermitian eigendecomposition that reads only the lower triangle). Any
/// negative eigenvalue is replaced by 0, and `S · √D · S†` is returned.
///
/// # Panics
/// Panics if the eigendecomposition fails.
pub fn find_sqr_root_of_matrix(matrix: MatrixC32) -> MatrixC32 {
    let (matrix_d, matrix_s) = rescale_neg_eigvals(matrix);

    // The Hermitian eigensolver returns an orthonormal eigenbasis, so S⁻¹ = S†.
    // Using the adjoint instead of an LU inverse avoids extra rounding error and
    // a panic path. It also makes the result Hermitian by construction, which
    // callers that run `eigh(UPLO::Lower)` on it rely on.
    let matrix_s_dagger = matrix_s.t().mapv(|z| z.conj());

    // D is diagonal with non-negative real entries (off-diagonal zeros stay zero).
    let sqrt_matrix_d = matrix_d.mapv(|z| z.sqrt());

    matrix_s.dot(&sqrt_matrix_d).dot(&matrix_s_dagger)
}

/// Eigendecomposes a Hermitian matrix and clamps negative eigenvalues to zero.
///
/// Returns `(D, S)`, where `D` is the diagonal matrix of (clamped) eigenvalues as
/// complex numbers and `S` holds the corresponding eigenvectors as columns. Only
/// the lower triangle of `rho` is read (`UPLO::Lower`), so `rho` is assumed Hermitian.
///
/// # Panics
/// Panics if the eigendecomposition fails.
pub fn rescale_neg_eigvals(rho: MatrixC32) -> (MatrixC32, MatrixC32) {
    let (mut eigvals, vecs) = rho.eigh(UPLO::Lower).unwrap();

    // Clamp every negative eigenvalue independently (small negatives come from
    // rounding on PSD inputs). The earlier loop stopped at the first non-negative
    // value and was only correct when the eigenvalues arrived in ascending order.
    // `< 0.0` keeps NaN untouched, as before.
    eigvals.mapv_inplace(|v| if v < 0.0 { 0.0 } else { v });

    let eigvals_c32 = eigvals.mapv(|v| c32::new(v, 0.0));
    let matrix_d = MatrixC32::from_diag(&eigvals_c32);

    (matrix_d, vecs)
}

/// Schmidt number K = 1 / Σ λᵢ² from a joint spectral intensity matrix.
///
/// Takes the element-wise square root of `jsi` to obtain the amplitude matrix and
/// computes its singular values sᵢ. They are normalised so that Σ sᵢ² = 1, and the
/// function returns 1 / Σ sᵢ⁴.
///
/// # Panics
/// Panics if the SVD fails.
pub fn find_schmidt_number(jsi: MatrixF32) -> f32 {
    let jsa = jsi.mapv(f32::sqrt);
    let (_u, singular_values, _vt) = jsa.svd(true, true).unwrap();

    let norm_const = 1. / singular_values.mapv(|sv| sv * sv).sum();
    let scale = norm_const.sqrt();

    // Σ (normalised sᵢ)⁴ = Σ λᵢ², with λᵢ the normalised Schmidt weights.
    let sum_weights_sqrd = singular_values.mapv(|sv| (sv * scale).powf(4.)).sum();
    1. / sum_weights_sqrd
}

pub fn find_partial_transpose(matrix: MatrixC32) -> MatrixC32 {

  let dim = find_dim(matrix.clone()) as usize;
  let mut partial_transpose_matrix = MatrixC32::zeros((dim , dim).f());

  let upper_left_block  = matrix.slice(s! [0..(dim / 2)   , 0..(dim / 2)  ] );
  let upper_right_block = matrix.slice(s! [0..(dim / 2)   , (dim / 2)..dim] );
  let lower_left_block  = matrix.slice(s! [(dim / 2)..dim , 0..(dim / 2)  ] );
  let lower_right_block = matrix.slice(s! [(dim / 2)..dim , (dim / 2)..dim] );

  let upper_right_block_transpose = upper_right_block.t();
  let lower_left_block_transpose = lower_left_block.t();

//TODO: Stack/concatenate
//Find for loops and see how to optimize
  let mut i = 0;
  for _index_1 in 0..dim/2 {
    let mut j = 0;
    for _index_2 in 0..dim/2 {
      partial_transpose_matrix[[i         , j        ]] = upper_left_block[ [i , j] ];
      partial_transpose_matrix[[i         , j + dim/2]] = upper_right_block_transpose[ [i , j] ];
      partial_transpose_matrix[[i + dim/2 , j        ]] = lower_left_block_transpose[ [i , j] ];
      partial_transpose_matrix[[i + dim/2 , j + dim/2]] = lower_right_block[ [i , j] ];
      j += 1;
      }
    i += 1;
    }

  partial_transpose_matrix

}

////////////////////////////////////////////////////////

extern crate plotters;
use plotters::prelude::*;
use core::f32::consts::PI;

// Sample the concurrence of |ψ(θ)⟩ = cos θ |00⟩ + sin θ |11⟩ at 1° steps over one
// full turn. Analytically this family should trace C = |sin 2θ|.
let points: Vec<(f32, f32)> = (0..360)
    .map(|deg| {
        let theta = (deg as f32) * PI / 180.;
        let psi_part_entangled: VecC32 = array![
            c32::new(theta.cos(), 0.0),
            c32::new(0.0, 0.0),
            c32::new(0.0, 0.0),
            c32::new(theta.sin(), 0.0)
        ];
        let rho_part_entangled = create_dens_matrix(psi_part_entangled);
        (theta, find_concurrence(rho_part_entangled))
    })
    .collect();

// Scatter-plot concurrence against θ. Drawing errors are propagated with `?`
// instead of being silently discarded.
let figure = evcxr_figure((600, 400), |root| {
    root.fill(&WHITE)?;
    let mut chart = ChartBuilder::on(&root)
        .caption("Concurrence of angle theta", ("Arial", 20).into_font())
        .margin(5)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_cartesian_2d(0_f32..2. * PI, 0_f32..1_f32)?;

    chart.configure_mesh().draw()?;

    chart.draw_series(
        points
            .iter()
            .map(|&(x, y)| Circle::new((x, y), 3, BLUE.filled())),
    )?;

    Ok(())
});
figure