# Phase retrieval
This notebook explains how to retrieve the transmission matrix (TM) of the OPU. 

We start with the **imports**. If you only want to use the OPU, these two imports are enough:

In [1]:
from phase_retrieval_encode import PhaseRetriever
from lightonml.projections.torch import OPUMap

The `PhaseRetriever` object will do all the work to retrieve the TM.

If you also want to experiment with synthetic matrices, you need these as well:

In [2]:
import torch
import numpy as np
import scipy.stats

from complex import ComplexTensor

We start by setting the device we want to use: `opu` uses the real OPU, while `cpu` uses a synthetic matrix known a priori.

In [3]:
rp_device = "opu"

Next, we define the matrix to recover. The OPU performs this operation:  
$$ Y = |X A|^2$$

with the random matrix $A$ on the **right**. 
The transposed TM has size `(n_components, col)`. We define these two dimensions below:

In [4]:
n_components = 1000
col = 1100
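
To make the shapes concrete, here is a minimal NumPy sketch of the operation above on toy sizes. It is not the OPU API: `A_T` is a synthetic stand-in for the transposed TM, and the names `n_samples`, `toy_col`, `toy_components`, `X`, and `rng` are introduced only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes for illustration only (the notebook uses n_components=1000, col=1100).
n_samples, toy_col, toy_components = 4, 6, 3

# Synthetic stand-in for the transposed TM, shape (n_components, col).
A_T = (rng.standard_normal((toy_components, toy_col))
       + 1j * rng.standard_normal((toy_components, toy_col)))

# Binary inputs, one sample per row, shape (n_samples, col).
X = rng.integers(0, 2, size=(n_samples, toy_col)).astype(float)

# Y = |X A|^2 with A = A_T.T on the right; only intensities are observed.
Y = np.abs(X @ A_T.T) ** 2

print(Y.shape)  # (4, 3)
```

Each row of `Y` holds the `n_components` intensities produced by one input sample, which is why the phase of the TM has to be retrieved rather than read off directly.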

Then we set some parameters that control the accuracy of the retriever. Higher values give a more accurate retrieval, but they also increase the computational resources the algorithm needs. The values below were the defaults in the paper, so we use them here:

In [5]:
anchors = 20
circ_N = int(0.5 * 1.5 * col)
n_signals = 2 * circ_N
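
For `col = 1100`, these expressions work out as follows. This is a plain arithmetic check; nothing here is specific to the retriever.

```python
col = 1100

circ_N = int(0.5 * 1.5 * col)   # 0.75 * 1100 = 825
n_signals = 2 * circ_N          # 2 * 825 = 1650

print(circ_N, n_signals)  # 825 1650
```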

Then we proceed to the definition of the random matrix:
- if `rp_device` == `opu`, we define the OPUMap object as usual
- otherwise we generate a dummy random matrix

In [6]:
if rp_device == "opu":
    A = OPUMap(n_components=n_components)
else:
    A = ComplexTensor(real=torch.randn(n_components, col), imag=torch.randn(n_components, col))

OPU output is detached from the computational graph.


As a last step, we initialize the retriever object. We pass all the arguments defined above, along with `batch_size`, which controls how many rows of the TM are recovered at the same time. For optimal results, make sure that the number of rows of the TM is divisible by `batch_size`.  
A `batch_size` between 100 and 500 usually yields the maximum speedup:

In [7]:
batch_size = 500

retriever = PhaseRetriever(n_components, col, circ_N, n_signals=n_signals, n_anchors=anchors, batch_size=batch_size)
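
The divisibility condition can be checked up front. The helper below is a hypothetical sketch, not part of `PhaseRetriever`; it assumes the intent is that `batch_size` divides the number of rows evenly, as 500 divides 1000 here.

```python
def count_batches(n_rows, batch_size):
    """Return the number of batches, or raise if batch_size does not divide n_rows."""
    n_full, remainder = divmod(n_rows, batch_size)
    if remainder != 0:
        raise ValueError(
            f"batch_size={batch_size} does not divide n_rows={n_rows} "
            f"({remainder} rows left over)"
        )
    return n_full


print(count_batches(1000, 500))  # 2
```

With the values used in this notebook, the rows are split into two batches of 500.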

Then call `fit` to recover the TM:

In [8]:
rec_A = retriever.fit(A)

100%|██████████| 150/150 [00:00<00:00, 28988.88it/s]


torch.Size([100, 3402])
OPUMap was not fit to data. Performing fit on the first batch with default parameters...


100%|██████████| 1/1 [00:00<00:00, 27.78it/s]
100%|██████████| 150/150 [00:00<00:00, 27244.00it/s]

Recovering rows...
torch.Size([100, 3402])



100%|██████████| 1/1 [00:00<00:00, 33.29it/s]

Recovering rows...





The recovered TM is a `ComplexTensor`, which is **NOT** a native dtype in PyTorch, since complex tensors are not fully supported yet at the time of writing.  
You can convert it to a NumPy array by calling `.numpy()` on the TM:

In [9]:
rec_A.numpy()

array([[ 0.03065051+0.2785847j , -0.06160724-0.31389287j,
         0.12339653-0.2755082j , ...,  0.21879664-0.52970487j,
         0.19090652-0.06901937j, -0.1025134 +0.09424819j],
       [ 0.3519583 +0.0348193j , -0.3281318 +0.18054375j,
        -0.08244599+0.28247464j, ..., -0.17507008+0.25962567j,
        -0.1248703 +0.002277j  , -0.07638101-0.40856355j],
       [ 0.4054241 -0.18987256j, -0.18094121-0.18648839j,
         0.04063346-0.19954702j, ..., -0.23087256+0.11478726j,
        -0.03398113+0.02732213j,  0.1612791 -0.37866104j],
       ...,
       [ 0.0246226 -0.114633j  , -0.44970965+0.40815488j,
        -0.08802416+0.29696688j, ...,  0.0646004 +0.0685879j ,
         0.19173981-0.24357633j, -0.19181368+0.1938564j ],
       [ 0.20641218+0.19747394j, -0.32064572-0.03209302j,
        -0.13389167+0.04634472j, ...,  0.26521596+0.12376404j,
         0.30993614-0.01816016j, -0.14432073+0.04973661j],
       [ 0.48918098+0.4864396j ,  0.00594345+0.06432842j,
         0.00494148-0.06564788
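
Once in NumPy, the matrix behaves like any complex array. A minimal sketch, using two values copied from the output above as a stand-in for `rec_A.numpy()` (the variable names are chosen for this illustration):

```python
import numpy as np

# Stand-in for rec_A.numpy(): two entries copied from the first row printed above.
tm = np.array([[0.03065051 + 0.2785847j, -0.06160724 - 0.31389287j]],
              dtype=np.complex64)

real_part = tm.real        # real component of each entry
imag_part = tm.imag        # imaginary component of each entry
magnitude = np.abs(tm)     # modulus of each entry
phase = np.angle(tm)       # argument of each entry, in radians

print(tm.dtype, magnitude.shape)  # complex64 (1, 2)
```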

We can also test how close the two matrices are by performing a random projection **with modulus square** using both the original and the recovered matrix, then comparing the outputs:

In [10]:
def test_matrix(retriever, A, reconstructed_A, dummy):
    """
    Compare the random projection computed with the original matrix against the one
    computed with the reconstructed matrix, and print summary statistics.

    retriever: retriever object, used here only for its `time_logger`
    A: ComplexTensor or OPUMap
    reconstructed_A: ComplexTensor, the reconstructed projection matrix
    dummy: torch tensor, a random binary input
    """
    # Reference intensities: computed directly for a synthetic matrix, measured on the OPU otherwise.
    if isinstance(A, ComplexTensor):
        original_input = (A @ dummy).abs() ** 2
    else:
        original_input = A(dummy.bool().T).T.float()

    # Intensities predicted by the reconstructed matrix.
    rec_input = (reconstructed_A @ dummy).abs() ** 2

    min_range_orig, max_range_orig = torch.min(original_input).item(), torch.max(original_input).item()
    mean_orig, std_orig = torch.mean(original_input).item(), torch.std(original_input).item()

    min_range_rec, max_range_rec = torch.min(rec_input).item(), torch.max(rec_input).item()
    mean_rec, std_rec = torch.mean(rec_input).item(), torch.std(rec_input).item()

    # Despite its name, this is the L2 norm of the error divided by the number of rows,
    # not a mean of squared errors.
    MSE = np.linalg.norm(rec_input.numpy() - original_input.numpy()) / original_input.shape[0]

    # Pearson correlation is undefined if the reconstruction contains NaNs.
    if np.isnan(reconstructed_A.numpy()).any():
        cross_correlation, p_value = 0, 0
    else:
        cross_correlation, p_value = scipy.stats.pearsonr(rec_input.numpy().squeeze(), original_input.numpy().squeeze())

    print("MSE = {}\nCross correlation = {}".format(MSE, cross_correlation))
    print("\toriginal input: min = {}\tmax = {}\tmean = {}\tstd = {}"
          .format(min_range_orig, max_range_orig, mean_orig, std_orig))
    print("\tRange of values (rec): min = {}\tmax = {}\tmean = {}\tstd = {}"
          .format(min_range_rec, max_range_rec, mean_rec, std_rec))

    # Total logged time, and the part not spent on random projections ("RP").
    total_time = sum(retriever.time_logger.values())
    time_no_RP = total_time - retriever.time_logger["RP"]

    print("\nRetrieval time (log) = {0:3.4f} s\nTime no RP = {1:3.4f} s\n".format(total_time, time_no_RP))

In [11]:
dummy = torch.randint(low=0, high=2, size=(col, 1)).float()

test_matrix(retriever, A, rec_A, dummy)

MSE = 0.5006128692626953
Cross correlation = 0.9777085192479201
	original input: min = 0.0	max = 134.0	mean = 21.579999923706055	std = 23.407751083374023
	Range of values (rec): min = 0.17202019691467285	max = 131.1925811767578	mean = 21.436668395996094	std = 23.952713012695312

Retrieval time (log) = 3.9967 s
Time no RP = 0.0899 s
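
A note on what this test can and cannot show: because it compares intensities, it is insensitive to a unit-modulus phase factor applied to a whole row of the matrix, since $|e^{i\varphi} a^\top x|^2 = |a^\top x|^2$. The sketch below checks this on a small synthetic matrix. `A_syn`, `A_rot`, and the sizes are made up for the illustration, and it says nothing about what `PhaseRetriever` actually returns.

```python
import numpy as np

rng = np.random.default_rng(0)
rows, cols = 5, 8  # toy sizes, for illustration only

# Small synthetic complex matrix and a random binary input column.
A_syn = rng.standard_normal((rows, cols)) + 1j * rng.standard_normal((rows, cols))
x = rng.integers(0, 2, size=(cols, 1)).astype(float)

# Multiply each row by its own unit-modulus phase factor.
phases = np.exp(1j * rng.uniform(0, 2 * np.pi, size=(rows, 1)))
A_rot = phases * A_syn

# Intensities from the two matrices.
y_ref = np.abs(A_syn @ x) ** 2
y_rot = np.abs(A_rot @ x) ** 2

print(np.allclose(y_ref, y_rot))  # True
```

So a high cross correlation on intensities confirms that the two matrices produce the same projections, which is the quantity of interest when the recovered TM is used to simulate the OPU.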

