# Metric functions

## Overview
Metrics quantify quantum states and circuits. They are used to evaluate the quality of a circuit for generating a target graph state. 
In this context, we can think of metrics as cost functions.

We will consider the following:
1. General features of metrics/cost functions
2. `Infidelity`, `TraceDistance`, and `CircuitDepth`
3. Joint metrics (trading off between metric classes)

## Metric objects

Current metrics include `Infidelity`, `TraceDistance`, and `CircuitDepth`. These are all subclasses of the abstract class `MetricBase`.
Each metric object logs its evaluated value every `log_step` steps.
A metric is evaluated on a given state and circuit as `metric.evaluate(state, circuit)`, which returns a scalar.
Some metrics depend only on the produced state (e.g. `Infidelity`, `TraceDistance`), and others only on the circuit (e.g. `CircuitDepth`). We nevertheless require both arguments so that the solver can be agnostic to the metric type: it passes both arguments without having to check which kind of metric it is running.
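The interface described above (an abstract base, a uniform `evaluate(state, circuit)` call, and logging every `log_step` calls) can be sketched in plain Python. This is not graphiq's `MetricBase`: the class names `MetricSketch`, `LinearEntropy`, and `GateCount` are hypothetical, the circuit is a toy list of gate names, and the exact logging rule in graphiq may differ.

```python
from abc import ABC, abstractmethod

import numpy as np


class MetricSketch(ABC):
    """Hypothetical stand-in for an abstract metric base class such as `MetricBase`."""

    def __init__(self, log_step=1):
        self.log_step = log_step
        self.log = []
        self._n_calls = 0

    def evaluate(self, state, circuit):
        """Every metric takes both arguments, even if it ignores one of them."""
        value = self._compute(state, circuit)
        # Append to the log on calls 0, log_step, 2 * log_step, ...
        if self._n_calls % self.log_step == 0:
            self.log.append(value)
        self._n_calls += 1
        return value

    @abstractmethod
    def _compute(self, state, circuit):
        """Subclasses return a scalar cost."""


class LinearEntropy(MetricSketch):
    """State-only toy metric: 1 - Tr(rho^2); ignores `circuit`."""

    def _compute(self, state, circuit):
        return float(1.0 - np.real(np.trace(state @ state)))


class GateCount(MetricSketch):
    """Circuit-only toy metric: number of gates in a list; ignores `state`."""

    def _compute(self, state, circuit):
        return len(circuit)


# A solver loop can call every metric the same way, without checking its type.
rho = np.eye(2) / 2  # maximally mixed qubit
circuit = ["H", "CNOT", "H"]  # hypothetical gate list
metrics = [LinearEntropy(), GateCount(log_step=2)]
for _ in range(5):
    for metric in metrics:
        metric.evaluate(rho, circuit)

print(metrics[0].log)  # five entries of 0.5
print(metrics[1].log)  # [3, 3, 3]: logged on calls 0, 2 and 4
```

Because both toy metrics accept `(state, circuit)`, the loop never branches on metric type, which is the design choice motivating the two-argument signature.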

### `Infidelity`
Evaluated as $1 - F(\rho, \sigma)$, where $F(\rho, \sigma)$ is the fidelity between the target state $\sigma$ and the produced state $\rho$.
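For reference, the standard (Uhlmann) fidelity is $F(\rho, \sigma) = \left(\operatorname{Tr}\sqrt{\sqrt{\rho}\,\sigma\sqrt{\rho}}\right)^2$, which reduces to $\langle\psi|\rho|\psi\rangle$ when $\sigma = |\psi\rangle\langle\psi|$ is pure. Below is a minimal NumPy sketch of the infidelity under that convention. Some references define $F$ without the square, and the helper names here (`psd_sqrt`, `fidelity`, `infidelity`) are our own, so treat the convention as an assumption rather than a description of graphiq's helpers.

```python
import numpy as np


def psd_sqrt(mat, tol=1e-12):
    """Square root of a Hermitian positive semidefinite matrix via eigendecomposition."""
    evals, evecs = np.linalg.eigh(mat)
    # Treat eigenvalues below `tol` as zero to suppress round-off noise.
    evals = np.where(evals > tol, evals, 0.0)
    return (evecs * np.sqrt(evals)) @ evecs.conj().T


def fidelity(rho, sigma):
    """Uhlmann fidelity F = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2."""
    sqrt_rho = psd_sqrt(rho)
    inner = sqrt_rho @ sigma @ sqrt_rho
    inner = (inner + inner.conj().T) / 2  # re-symmetrize against round-off
    return float(np.real(np.trace(psd_sqrt(inner))) ** 2)


def infidelity(rho, sigma):
    return 1.0 - fidelity(rho, sigma)


# 4-qubit GHZ target as a density matrix, compared with itself and with the maximally mixed state
ghz = np.zeros(16)
ghz[0] = ghz[15] = 1 / np.sqrt(2)
sigma = np.outer(ghz, ghz.conj())
mixed = np.eye(16) / 16
print(infidelity(sigma, sigma))  # approximately 0
print(infidelity(mixed, sigma))  # approximately 15/16, since F = <psi|I/16|psi> = 1/16
```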

### `TraceDistance`
Evaluated as the trace distance between the target state and the produced state.
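The standard trace distance is $T(\rho, \sigma) = \tfrac{1}{2}\operatorname{Tr}\lvert\rho - \sigma\rvert$, and because $\rho - \sigma$ is Hermitian it equals half the sum of the absolute eigenvalues of $\rho - \sigma$. A minimal NumPy sketch, assuming that definition (the function name is ours, not graphiq's):

```python
import numpy as np


def trace_distance(rho, sigma):
    """T(rho, sigma) = 1/2 * Tr|rho - sigma| for Hermitian density matrices."""
    diff = rho - sigma
    # diff is Hermitian, so Tr|diff| is the sum of the absolute values of its eigenvalues
    evals = np.linalg.eigvalsh(diff)
    return 0.5 * float(np.sum(np.abs(evals)))


zero = np.array([[1.0, 0.0], [0.0, 0.0]])  # |0><0|
one = np.array([[0.0, 0.0], [0.0, 1.0]])   # |1><1|
plus = 0.5 * np.array([[1.0, 1.0], [1.0, 1.0]])  # |+><+|
print(trace_distance(zero, one))   # orthogonal pure states: 1.0
print(trace_distance(zero, plus))  # sqrt(1 - |<0|+>|^2) = sqrt(0.5), about 0.7071
```

For two pure states the result matches $\sqrt{1 - |\langle\psi|\phi\rangle|^2}$, which is a handy sanity check.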

### `CircuitDepth` 
Evaluated as the depth of the circuit, defined as the number of layers of gates in the circuit.
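One common way to count layers is to place each gate in the earliest layer after the last gate on any of its qubits; the depth is then the deepest layer reached. The sketch below assumes a hypothetical circuit representation (a list of tuples of qubit indices, one tuple per gate). graphiq's `depth` attribute may count operations differently, so this toy count need not agree with the depth of 8 reported below for graphiq's GHZ circuit.

```python
def circuit_depth(gates, n_qubits):
    """Number of layers when each gate is scheduled as early as its qubits allow.

    `gates` is a hypothetical representation: a list of tuples of qubit indices.
    """
    frontier = [0] * n_qubits  # layer of the most recent gate on each qubit
    for qubits in gates:
        layer = max(frontier[q] for q in qubits) + 1
        for q in qubits:
            frontier[q] = layer
    return max(frontier, default=0)


# H on qubit 0 followed by a CNOT ladder: every gate waits for the previous one
ghz_gates = [(0,), (0, 1), (1, 2), (2, 3)]
print(circuit_depth(ghz_gates, 4))  # 4
```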

```python
""" Evaluating metrics """
from graphiq.benchmarks.circuits import ghz4_state_circuit
import graphiq.metrics as met

# consider a 4-qubit GHZ target state
ghz4_circuit, ghz4_target = ghz4_state_circuit()

# initialize metrics
infidelity = met.Infidelity(ghz4_target)
trace_dist = met.TraceDistance(ghz4_target)
circ_depth = met.CircuitDepth()

# Let's look at optimal results
print("Cost functions results on perfect state/circuit:")
print(f"Infidelity: {infidelity.evaluate(ghz4_target, ghz4_circuit)}")
print(f"Trace distance: {trace_dist.evaluate(ghz4_target, ghz4_circuit)}")
print(f"Circuit depth: {circ_depth.evaluate(ghz4_target, ghz4_circuit)}")

# look at the logged values
print(f"\nInfidelity log: {infidelity.log}")
print(f"\nTrace Distance log: {trace_dist.log}")
print(f"\nCircuit depth log: {circ_depth.log}")
```

```
Cost functions results on perfect state/circuit:
Infidelity: 0.0
Trace distance: 0.0
Circuit depth: 8

Infidelity log: [0.0]

Trace Distance log: [0.0]

Circuit depth log: [8]
```


### `CircuitDepth` metric: normalization

`Infidelity` and `TraceDistance` have a natural normalization (both take values in $[0, 1]$), but `CircuitDepth` does not: in the example above, the circuit depth was not normalized at all. `CircuitDepth` therefore accepts an optional `depth_penalty` function, which is applied to the raw depth to produce the metric value.

```python
""" CircuitDepth metric: normalization """

circ_depth_quadratic = met.CircuitDepth(depth_penalty=lambda x: (x / 16) ** 2)
print(
    f"Circuit depth penalty: {circ_depth_quadratic.evaluate(ghz4_target, ghz4_circuit)}"
)
```

```
Circuit depth penalty: 0.25
```
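Since `depth_penalty` is just a callable applied to the raw depth, the choice of function sets how strongly deeper circuits are penalized. The plain-Python sketch below compares the quadratic penalty used above with two hypothetical alternatives; none of these are graphiq defaults, and the scale 16 is simply carried over from the example.

```python
import math

# Candidate depth penalties (hypothetical examples); each maps a raw depth to a cost.
penalties = {
    "quadratic": lambda d: (d / 16) ** 2,  # same form as the example above
    "linear": lambda d: d / 16,
    "saturating": lambda d: 1 - math.exp(-d / 16),  # stays in [0, 1) for d >= 0
}

for name, penalty in penalties.items():
    values = ", ".join(f"{penalty(d):.3f}" for d in (4, 8, 16, 32))
    print(f"{name:>10}: {values}")
```

The quadratic and linear forms exceed 1 once the depth passes 16, so they are only on the same scale as `Infidelity` and `TraceDistance` for shallow circuits; a saturating form stays bounded at any depth, at the cost of flattening out for very deep circuits.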


### Function implementation
`Infidelity` and `TraceDistance` are currently implemented only in the density matrix representation, so the state input must be a density matrix. At present the state input is a NumPy array; an upcoming change will replace it with a `QuantumState` object, which must still provide a density matrix representation. Another upcoming change will add the option to evaluate `Infidelity` in the stabilizer formalism.
They are implemented from helper functions in `graphiq/backends/density_matrix/functions.py`.
`CircuitDepth` is implemented from a `depth` attribute in the circuit class.
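Because these two metrics expect density matrices, a pure target given as a state vector $|\psi\rangle$ has to be converted to $\rho = |\psi\rangle\langle\psi|$ first. A minimal NumPy sketch; the helper name `ket_to_density_matrix` is ours, not part of graphiq.

```python
import numpy as np


def ket_to_density_matrix(psi):
    """Return rho = |psi><psi| for a state vector, normalizing it first."""
    psi = np.asarray(psi, dtype=complex)
    psi = psi / np.linalg.norm(psi)
    return np.outer(psi, psi.conj())


# 4-qubit GHZ state (|0000> + |1111>)/sqrt(2) as a 16x16 density matrix
ghz = np.zeros(16, dtype=complex)
ghz[0] = ghz[-1] = 1
rho_ghz = ket_to_density_matrix(ghz)
print(rho_ghz.shape)  # (16, 16)
```

A valid pure-state density matrix has unit trace and satisfies $\rho^2 = \rho$, which makes an easy check on the conversion.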

## Joint metrics

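It can be useful to combine several metrics into a single cost function, for instance to trade off the quality of the produced state against circuit depth. `Metrics` takes a list of metric objects and an optional list of weights, `metric_weight`.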


```python
""" Combo metric, default weighting """

combo_metric = met.Metrics([infidelity, trace_dist, circ_depth_quadratic])
print(
    f"Combined metric on correct state/circuit: {combo_metric.evaluate(ghz4_target, ghz4_circuit)}"
)

combo_metric = met.Metrics(
    [infidelity, trace_dist, circ_depth_quadratic], metric_weight=[0.4, 0.4, 0.2]
)
print(f"Weighted metric: {combo_metric.evaluate(ghz4_target, ghz4_circuit)}")
```

```
Combined metric on correct state/circuit: 0.25
Weighted metric: 0.05
```
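These outputs are consistent with `Metrics` returning a weighted sum $\sum_i w_i m_i$ of its component values, with unit weights by default: $0 + 0 + 0.25 = 0.25$, and $0.4 \cdot 0 + 0.4 \cdot 0 + 0.2 \cdot 0.25 = 0.05$. The plain-Python sketch below reproduces that rule; it is inferred from the printed values, not taken from graphiq's source, and `combined_cost` is a hypothetical name.

```python
def combined_cost(values, weights=None):
    """Weighted sum of metric values; every weight defaults to 1."""
    if weights is None:
        weights = [1.0] * len(values)
    if len(weights) != len(values):
        raise ValueError("need exactly one weight per metric value")
    return sum(w * v for w, v in zip(weights, values))


# infidelity, trace distance, and quadratic depth penalty on the ideal GHZ circuit
values = [0.0, 0.0, 0.25]
print(combined_cost(values))                   # 0.25
print(combined_cost(values, [0.4, 0.4, 0.2]))  # 0.05
```

With unit weights, an unbounded metric such as raw circuit depth can dominate the sum, which is one reason to normalize it with `depth_penalty` before combining.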
