# Tensor Network Contraction Optimization and Sampling with `tnco` 

This notebook demonstrates how to use the `tnco` library to optimize the contraction order of tensor networks and perform efficient circuit sampling. We cover quantum circuits (Cirq and Qiskit), arbitrary tensor networks, and advanced sampling techniques.

In [1]:
import cirq
import qiskit
import pickle
import numpy as np
import quimb.tensor as qt
import more_itertools as mit
from random import Random
from qiskit.circuit.random import random_circuit as qiskit_random_circuit
from tnco.app import Optimizer, Tensor, TensorNetwork
from tnco.app.circuit import Sampler
import tnco.utils.tn as tn_utils

## 1. Initializing the Optimizer

The `Optimizer` class is the main entry point for finding efficient contraction paths. It can be configured with different optimization methods and constraints.

- **Method**: The default optimization method is simulated annealing (`method='sa'`).
- **Memory Constraints (`max_width`)**: 
    - If `max_width` is not provided, the optimizer seeks the lowest total cost (FLOPs) without memory limits.
    - If `max_width` is provided, the optimizer introduces **index slicing** to ensure that every intermediate tensor fits within the specified memory limit (expressed as $2^{max\_width}$ elements).

In [2]:
# Optimizer for unconstrained memory
opt = Optimizer()

# Optimizer with a maximum intermediate tensor width of 2 (i.e., max 2^2 = 4 elements)
opt_fw = Optimizer(max_width=2)

## 2. Optimizing `cirq.Circuit` Objects

The `Optimizer` natively supports `cirq.Circuit` objects. You can specify initial and final states for the qubits to define the contraction problem.

In [3]:
# Generate a random Cirq circuit
circuit = cirq.testing.random_circuit(qubits=8, n_moments=16, op_density=1)
qubits = sorted(circuit.all_qubits())

The `optimize` method returns:
1. A `TensorNetwork` object representing the circuit.
2. A list of `ContractionResults` containing the optimized paths.

**State Specification:**
- By default, initial and final states are assumed to be $|0\rangle$.
- You can provide a dictionary to specify states per qubit.
- Use `None` to leave qubits open (uncontracted).

In [4]:
# Optimization parameters
opt_params = {
    "betas": (0, 1e5),
    "initial_state": '+',  # All qubits start in the |+> state
    "final_state":
        None,  # All qubits are left open (resulting in a final state vector)
    "n_steps": 1_000,
    "n_runs": 4
}

# Optimize for infinite memory
tn, res = opt.optimize(circuit, **opt_params)

# Optimize with finite width constraints
tn_fw, res_fw = opt_fw.optimize(circuit, **opt_params)

### Analyzing Contraction Results

The `ContractionResults` object provides detailed performance metrics:
- `cost`: Contraction cost in floating-point operations (FLOPs).
- `path`: The optimal contraction sequence.
- `runtime_s`: Time spent on the optimization process.
- `slices`: Indices selected for slicing (applicable to finite width).

In [5]:
# Sort results by cost (lowest first)
res = sorted(res, key=lambda x: x.cost)
res_fw = sorted(res_fw, key=lambda x: x.cost)


# Helper to calculate average runtime
def avg_runtime(r):
    return sum(x.runtime_s for x in r) / len(r)


print(f"# Infinite Memory | Avg Runtime: {avg_runtime(res):.2g}s")
print(f"# Infinite Memory | log10(FLOP): {res[0].cost.log10():.2g}")
print("# -")
print(f"# Finite Width    | Avg Runtime: {avg_runtime(res_fw):.2g}s")
print(f"# Finite Width    | log10(FLOP): {res_fw[0].cost.log10():.2g}")
print(f"# Finite Width    | Sliced indices: {len(res_fw[0].slices)}")

# Extract the best paths
path = res[0].path
path_fw = res_fw[0].path

# Infinite Memory | Avg Runtime: 0.0067s
# Infinite Memory | log10(FLOP): 4.3
# -
# Finite Width    | Avg Runtime: 0.0089s
# Finite Width    | log10(FLOP): 16
# Finite Width    | Sliced indices: 44


### Executing the Contraction

You can use the optimized path with external libraries like `quimb` to perform the actual tensor contraction. 

**Note on Indices:** 
Output indices are named `(qubit_name, tag)`, where `tag` is `'i'` for the initial state and `'f'` for the final state.

In [6]:
# 1. Calculate the exact state vector using Cirq's simulator for verification
initial_state_vec = np.ones(2**len(qubits)) / np.sqrt(2**len(qubits))
exact_final_state = cirq.Simulator().simulate(
    circuit, initial_state=initial_state_vec,
    qubit_order=qubits).state_vector()

# 2. Perform tensor network contraction using the optimized path
qt_tn = qt.TensorNetwork(map(qt.Tensor, tn.arrays, tn.ts_inds))
final_state = qt_tn.contract(optimize=path, output_inds=tn.output_inds)

# 3. Post-process the result to match the Cirq state vector format
# Re-index: (Qubit, 'f') -> Qubit
final_state.reindex({ind: ind[0] for ind in final_state.inds}, inplace=True)

# Transpose indices to match the qubit order
final_state.transpose(*qubits, inplace=True)

# 4. Validate results
np.testing.assert_allclose(final_state.data.ravel(),
                           exact_final_state,
                           atol=1e-5)
print("Contraction verified successfully!")

Contraction verified successfully!


## 3. Quantum Circuit Sampling

`tnco` implements the algorithm by Bravyi-Gosset-Liu for sampling bitstrings without explicitly computing full marginals.

Reference: *"How to Simulate Quantum Measurement without Computing Marginals"*, Phys. Rev. Lett. 128, 220503 (2022).

In [7]:
# Initialize the sampler
sampler = Sampler()

# Sample random bitstrings from the circuit
bitstrings, qubit_order = sampler.sample(circuit,
                                         n_samples=100,
                                         betas=(0, 1e3),
                                         n_steps=100,
                                         normalize=False,
                                         n_runs=2)
print("Sampled Bitstrings:")
for bitstring, n_hits in mit.take(5, bitstrings.items()):
    print(bitstring, "({}/100 hits)".format(n_hits))
if len(bitstrings) > 5:
    print("...")

Sampled Bitstrings:
00011010 (12/100 hits)
11000000 (10/100 hits)
00111010 (9/100 hits)
11001100 (7/100 hits)
00010010 (7/100 hits)
...


### Reusing Optimized Intermediate States

Optimizing partial tensor networks for sampling can be computationally expensive. `tnco` allows you to generate and save an **optimized intermediate state**, which can be reused for multiple sampling runs without re-optimization.

In [8]:
# 1. Generate the optimized intermediate state
state = sampler.sample(circuit,
                       return_intermediate_state_only=True,
                       betas=(0, 1e3),
                       n_steps=100,
                       n_runs=2)

# 2. The state is pickle-compatible, making it easy to store or distribute
pickled_state = pickle.dumps(state)

# 3. Reuse the state for sampling
bitstrings, qubit_order = sampler.sample(pickle.loads(pickled_state),
                                         n_samples=100)
print("Reused state sampling completed.")

Reused state sampling completed.


## 4. Optimizing `qiskit.QuantumCircuit` Objects

The optimizer provides a consistent interface for Qiskit circuits as well.

In [9]:
# Generate a random Qiskit circuit
qiskit_circuit = qiskit.QuantumCircuit(8)
qiskit_circuit = qiskit_circuit.compose(qiskit_random_circuit(8, 16))

# Optimize the Qiskit circuit
tn_qiskit, res_qiskit = opt.optimize(qiskit_circuit, **opt_params)
best_flops = sorted(res_qiskit, key=lambda x: x.cost)[0].cost
print(f"Qiskit Optimization log10(FLOP): {best_flops.log10():.2g}")

Qiskit Optimization log10(FLOP): 4.9


In [10]:
# Sample random bitstrings from the circuit
bitstrings, qubit_order = sampler.sample(circuit,
                                         n_samples=100,
                                         betas=(0, 1e3),
                                         n_steps=100,
                                         normalize=False,
                                         n_runs=2)
print("Sampled Bitstrings:")
for bitstring, n_hits in mit.take(5, bitstrings.items()):
    print(bitstring, "({}/100 hits)".format(n_hits))
if len(bitstrings) > 5:
    print("...")

Sampled Bitstrings:
11000000 (13/100 hits)
11001100 (10/100 hits)
11000100 (9/100 hits)
00011010 (9/100 hits)
11100000 (9/100 hits)
...


## 5. Optimizing Lists of Gates

Alternatively, you can provide a raw list of unitary matrices and their target qubits using tuples: `(matrix, qubits)`.

In [11]:
# Convert Cirq operations to a list of (unitary, qubits) tuples
gate_list = [(cirq.unitary(op), op.qubits) for op in circuit.all_operations()]

# Optimize the gate list
tn_gates, res_gates = opt.optimize(gate_list, **opt_params)
best_gate_res = sorted(res_gates, key=lambda x: x.cost)[0]
print(f"Gate list optimization log10(FLOP): {best_gate_res.cost.log10():.2g}")

Gate list optimization log10(FLOP): 4.3


## 6. Optimizing Arbitrary Tensor Networks

Beyond circuits, you can optimize arbitrary tensor networks by using the `tnco.app.TensorNetwork` class.

In [12]:
from tnco.testing.utils import generate_random_tensors

# Generate a random tensor network metadata
ts_inds, output_inds = generate_random_tensors(n_tensors=10,
                                               n_inds=20,
                                               n_cc=2,
                                               k=3,
                                               n_output_inds=3)

# Assign random dimensions (2 or 3) to indices
all_inds = frozenset(mit.flatten(ts_inds))
dims = {ind: Random().randint(2, 3) for ind in all_inds}

# Create the TensorNetwork object
tn_arbitrary = TensorNetwork([
    Tensor(inds, [dims[i]
                  for i in inds], tags={'id': j})
    for j, inds in enumerate(ts_inds)
],
                             output_inds=output_inds)

# Optimize with fusion (fuse=10 pre-contracts small tensor groups)
new_tn, res_arb = opt.optimize(tn_arbitrary,
                               fuse=10,
                               betas=(0, 1e5),
                               n_steps=1_000,
                               n_runs=4)
print(f"Arbitrary TN Optimization log10(FLOP): {res_arb[0].cost.log10():.2g}")

Arbitrary TN Optimization log10(FLOP): 6.3


  warn("Cannot decompose hyper-indices if not "


### Contracting the Arbitrary Network

To execute the contraction, we account for any pre-fused tensors using `tn_utils.contract`.

In [13]:
# Generate random data arrays for the tensors
arrays = [np.random.normal(size=t.dims) for t in tn_arbitrary.tensors]

# Account for 'fuse' by pre-contracting tensors accordingly
fused_inds, fused_output, fused_arrays = tn_utils.contract(
    new_tn.tags['fuse_path'],
    ts_inds=tn_arbitrary.ts_inds,
    output_inds=tn_arbitrary.output_inds,
    dims=tn_arbitrary.dims,
    arrays=arrays)

# Perform final contraction using the best path
best_arb_path = sorted(res_arb, key=lambda x: x.cost)[0].path
final_result = qt.TensorNetwork(map(qt.Tensor, fused_arrays,
                                    new_tn.ts_inds)).contract(
                                        optimize=best_arb_path,
                                        output_inds=new_tn.output_inds)
print("Arbitrary TN contraction completed.")

Arbitrary TN contraction completed.


## 7. Advanced: Using Index Maps

The optimizer can also accept an **Index Map**, a low-level format where each tuple represents an index, its dimension, and the IDs of connected tensors.

**Format:** `(dimension, tensor_id_1, tensor_id_2, ..., '*')` 
- The `*` token designates an output index.

In [14]:
# Construct an Index Map from the previous network
inds_map = {ind: [dims[ind]] for ind in dims}
for idx, tensor in enumerate(tn_arbitrary.tensors):
    for ind in tensor.inds:
        inds_map[ind].append(idx)

# Mark outputs
for ind in tn_arbitrary.output_inds:
    inds_map[ind].append('*')

# Optimize using raw map values
tn_map, res_map = opt.optimize(inds_map.values(),
                               betas=(0, 1e5),
                               n_steps=1_000,
                               output_index_token='*',
                               n_runs=4)
print(f"Index Map Optimization log10(FLOP): {res_map[0].cost.log10():.2g}")

Index Map Optimization log10(FLOP): 6.4
