In [None]:
import torch
import itertools

import numpy as np
import qiskit.quantum_info as qi

from torch import nn
from codes import swapper, physical_imposition_operator, get_marginals, get_state_of_rank

### Choose the case

Available cases:
1. N3k2
2. N4k3
3. N5k3
4. N6k4

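where `N` is the number of qubits in the global state and `k` is the number of qubits in each reduced (marginal) system. Set `num_of_qubits` and `num_of_qubits_in_marginals` in the next cell to pick a case.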

In [None]:
# Case selection: local dimension, size of the global state, and size of
# each marginal. Only the (N, k) pairs listed above have a checkpoint.
d = 2
num_of_qubits = 6
num_of_qubits_in_marginals = 4

checkpoint_path = f"./checkpoints/cpN{num_of_qubits}k{num_of_qubits_in_marginals}"

# Channel-width multiplier for ConvDenoiser. It has to agree with the layer
# shapes stored in the checkpoint, otherwise load_state_dict rejects it.
SCALE = 4 if num_of_qubits in (4, 6) else 2

### Load the model from the checkpoint

In [None]:
class ConvDenoiser(nn.Module):
    """Convolutional encoder-decoder that maps a noisy matrix to a denoised one.

    The input holds the real and imaginary parts of a square matrix as two
    channels: shape ``(2, n, n)`` or ``(batch, 2, n, n)``. The output has the
    same layout. With three 2x2 max-pools in the encoder and the transposed
    convolutions in the decoder, the output side equals the input side whenever
    ``n`` is a multiple of 8 (e.g. 8, 16, 32, 64).

    Args:
        scale: Channel-width multiplier. When omitted, the notebook-level
            ``SCALE`` constant is used, so ``ConvDenoiser()`` behaves exactly
            as before. It must match the checkpoint being loaded.
    """

    def __init__(self, scale=None):
        super().__init__()

        if scale is None:
            # Looked up at construction time rather than bound as a default
            # argument, so re-running the config cell with a new case is honored.
            scale = SCALE
        self.scale = scale

        # Module order/indices are kept unchanged so state_dict keys
        # (encoder.0.weight, ..., decoder.6.weight) still match the checkpoints.
        self.encoder = nn.Sequential(
                nn.Conv2d(2, scale*60, 3, padding=1),
                nn.Tanh(),
                nn.MaxPool2d(2, 2),

                nn.Conv2d(scale*60, scale*120, 3, padding=1),
                nn.Tanh(),
                nn.MaxPool2d(2, 2),

                nn.Conv2d(scale*120, scale*60, 3, padding=1),
                nn.Tanh(),
                nn.MaxPool2d(2, 2)
            )

        # Spatial sizes for an encoded side s: 2s-1 -> 4s-1 -> 8s+2 -> 8s,
        # i.e. back to the original side when it was a multiple of 8.
        self.decoder = nn.Sequential(
                nn.ConvTranspose2d(scale*60, scale*60, 3, padding=1, stride=2),
                nn.Tanh(),

                nn.ConvTranspose2d(scale*60, scale*120, 5, padding=1, stride=2),
                nn.Tanh(),

                nn.ConvTranspose2d(scale*120, scale*60, 6, stride=2),
                nn.Tanh(),
                nn.Conv2d(scale*60, 2, 3)
            )

    def forward(self, x):
        """Encode then decode ``x``; returns a tensor with the same layout."""
        encoded = self.encoder(x)
        output = self.decoder(encoded)

        return output
    
# Restore the trained weights onto the CPU. weights_only=True limits
# torch.load to tensors and plain containers instead of arbitrary pickles.
cpu_device = torch.device('cpu')
model_params = torch.load(checkpoint_path, map_location=cpu_device, weights_only=True)

model = ConvDenoiser()
model.load_state_dict(model_params['model_state_dict'])
# Switch to inference mode; as the last expression, the model repr is displayed.
model.eval()

### Impose the target marginals on a random density matrix

In [None]:
# Build the target marginals from a reference state, then impose them on an
# unrelated random density matrix to obtain the "noisy" input for the model.
# A swapper matrix and the list of marginal labels are needed for this.
dn = d**num_of_qubits  # dimension of the global Hilbert space
swapper_d = swapper(d)

rank = dn  # requested rank of the reference state (here: the full dimension)
rho_noisless = get_state_of_rank(rank, d, num_of_qubits)

# One marginal per k-element subset of the qubit indices 0..num_of_qubits-1.
k = num_of_qubits_in_marginals
labels_marginals = [subset for subset in itertools.combinations(range(num_of_qubits), r=k)]
target_marginals = get_marginals(rho_noisless, d, num_of_qubits, labels_marginals)

# Starting matrix for the imposition (a random density matrix, not an RNG seed).
initial_seed = qi.random_density_matrix(dn).data

rho_noisy = physical_imposition_operator(
    d,
    num_of_qubits,
    initial_seed,
    target_marginals,
    swapper_d,
)

noisy_eigenvalues = np.linalg.eigvalsh(rho_noisy)
print(f"Eigenvals of the noisy matrix =  \n\n   {noisy_eigenvalues}")
print("---" * 40)
print(40*"---")

### Use the model to predict a quantum state

In [None]:
# Run the denoiser on the noisy matrix: its real and imaginary parts are fed
# as the two input channels and read back from the two output channels.
tensor_rho_noisy = torch.tensor(
    np.stack((rho_noisy.real, rho_noisy.imag)), dtype=torch.float32
)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = model(tensor_rho_noisy)

predicted_state = output[0] + 1j * output[1]
# Hermitian part of the prediction: (A + A^dagger) / 2.
predicted_state = (predicted_state + predicted_state.conj().T) / 2

predicted_state_np = predicted_state.numpy()
# The trace of a Hermitian matrix is real; divide by its real part so that
# round-off imaginary residue does not enter the normalisation.
# NOTE: nothing here enforces positive semidefiniteness — the printed
# eigenvalues show whether the prediction is a valid state.
predicted_state_np = predicted_state_np / np.trace(predicted_state_np).real

predicted_marginals = get_marginals(predicted_state_np, d, num_of_qubits, labels_marginals)

print(f"\nEigenvals renormalized predicted state =  \n   {np.linalg.eigvalsh(predicted_state_np)}")

### Compare the marginals of the predicted state with the `target_marginals`

In [None]:
# Marginals of the predicted state, computed with the same local dimension `d`
# used for the targets (previously hard-coded to 2).
predicted_marginals = get_marginals(predicted_state_np, d, num_of_qubits, labels_marginals)

# Fidelity of each predicted marginal against its target (1.0 = identical).
# validate=True makes qiskit raise if an input is not a valid density matrix;
# NOTE(review): the predicted state is not constrained to be PSD, so this can
# raise when the network output has negative eigenvalues.
for label, target_marginal in target_marginals.items():
    fidelity = qi.state_fidelity(target_marginal, predicted_marginals[label], validate=True)
    print(f"{label}: {fidelity}")