In [None]:
%config InlineBackend.figure_formats = ['svg']
import quimb.tensor as qtn
import numpy as np
import xyzpy as xyz
import autoray as ar
import cotengra as ctg
from quimb.tensor.tensor_core import *

In [None]:
tn = qtn.TN2D_rand(6, 6, 4, seed=666)

In [None]:
# %%timeit
tn.apply_to_arrays(ar.lazy.array)
Z = tn.contract_compressed(
    max_bond=4, cutoff=0.0,
    optimize='greedy',
    # optimize=ctg.HyperCompressedOptimizer()
)
# Z = tn.contract(...)
# Z.compute()

In [None]:
# %%timeit
# tn.apply_to_arrays(ar.lazy.array)
# Z = tn.contract_compressed(
#     max_bond=4, cutoff=0.0,
#     optimize='greedy',
# )
# list(Z.ascend())

In [None]:
# %%timeit
# list(Z.ascend())

In [None]:
Z.plot_history_size_footprint()

In [None]:
Z.plot_circuit(color_by="function", layout='compact', colors={"svd_truncated_numba": "red", "qr_stabilized_numba": "orange"})

In [None]:
Z.plot_circuit(color_by="function", layout='balanced', colors={"svd_truncated_numba": "red", "qr_stabilized_numba": "orange"})

In [None]:
Z.plot_circuit(color_by="function", layout='wide', colors={"svd_truncated_numba": "red", "qr_stabilized_numba": "orange"})

In [None]:
Z.plot_graph(colors={"svd_truncated_numba": "red", "qr_stabilized_numba": "orange"})

In [None]:
# Z.compute()

In [None]:
def frequencies(self):
    """Count how many nodes of the lazy graph ``self`` use each function.

    Parameters
    ----------
    self : iterable
        Any iterable whose items have a ``fn_name`` attribute -- in this
        notebook a lazy array ``Z``, iterating over which yields its nodes.

    Returns
    -------
    dict
        Mapping ``fn_name -> count`` in order of first appearance. Keys may
        be ``None`` for nodes without a function name.
    """
    freq = {}
    # iterate the argument (not the global ``Z``) so this works for any graph
    for node in self:
        freq[node.fn_name] = freq.get(node.fn_name, 0) + 1
    return freq

def show(self):
    """Print the lazy computational graph rooted at ``self`` as a text tree.

    Every printed line starts with a running line number. Nodes with a
    function name are drawn as ``fn_name[shape...]``; nodes whose
    ``fn_name`` is ``None`` (or the string ``'None'``) are drawn as
    ``<-[shape...]``. A node reached a second time is not expanded again:
    a back-reference ``... (<item> from line <n>)`` points at the line where
    it was first drawn.
    """
    counter = 0
    first_line_of = {}
    stack = [(self, ())]
    while stack:
        node, branch_flags = stack.pop()

        # one column per ancestor ('│ ' while that ancestor still has
        # siblings to come, blank otherwise), then this node's own branch
        pieces = [f'{counter:>4} ']
        if branch_flags:
            *ancestor_flags, is_last = branch_flags
            pieces.extend('  ' if flag else '│ ' for flag in ancestor_flags)
            pieces.append('└─' if is_last else '├─')
        prefix = ''.join(pieces)

        if node.fn_name in (None, 'None'):
            label = f"<-{list(node.shape)}"
        else:
            label = f"{node.fn_name}{list(node.shape)}"

        if node in first_line_of:
            # already drawn: point back instead of recursing (the line counter
            # is not advanced for these back-reference lines)
            print(f"{prefix} ... ({label} from line {first_line_of[node]})")
            continue

        print(f"{prefix}{label}")
        first_line_of[node] = counter
        counter += 1

        # children go on a LIFO stack, so the first child is popped -- and
        # drawn -- last; it is therefore the one flagged as the last branch
        for index, child in enumerate(node.deps):
            stack.append((child, branch_flags + (index == 0,)))

In [None]:
# [
#     node.args for node in Z
#     if node.fn_name == "tensordot"
# ]

In [None]:
frequencies(Z)

In [None]:
show(Z)

In [None]:
{node.fn_name for node in Z}

In [None]:
Z.compute()

In [None]:
tn.contract_()

In [None]:
import xyzpy as xyz

In [None]:
@xyz.label(['err'], harvester=True)
def run(m, n, r, q=1, seed=None):
    """Frobenius-norm error of a randomized range finder on a Gaussian matrix.

    Draws a random ``m x n`` Gaussian matrix ``X`` and a Gaussian sketch
    ``G`` of shape ``(n, r)``. It builds an orthonormal basis ``Q`` for the
    range of ``X @ G``, refines it with ``q`` rounds of QR-stabilized power
    iteration, and returns ``||X - Q @ Q.T @ X||_F``.

    Parameters
    ----------
    m, n : int
        Shape of the random test matrix ``X``.
    r : int
        Number of sketch columns (target rank).
    q : int, optional
        Number of power iterations (each does two QR factorizations).
    seed : int or None, optional
        Seed for *all* random draws, so a given seed is fully reproducible.

    Returns
    -------
    float
        Frobenius norm of the reconstruction error.
    """
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(m, n))
    # draw the sketch from the same seeded generator: the global
    # ``np.random`` state made results irreproducible for a given seed
    G = rng.normal(size=(n, r))
    Q = np.linalg.qr(X @ G)[0]  # m x min(m, r)
    for _ in range(q):
        Q = np.linalg.qr(X.T @ Q)[0]
        Q = np.linalg.qr(X @ Q)[0]
    B = Q.T @ X  # (r, m) (m, n)
    Xc = Q @ B  # (m, r) (r, n)
    return np.linalg.norm(X - Xc)


run.harvest_combos(
    cases=[
        {"m": 10, "n": 100},
    ],
    combos={
        "seed": range(100),
        "q": range(4),
        "r": range(1, 11),
    },
)

In [None]:
run.full_ds.xyz.infiniplot(
    x="r",
    y="err",
    color="q",
    aggregate="seed",
    height=10,
)[0]

In [None]:
# Scratch check: reduced factor of a 4-index tensor obtained from the
# eigendecomposition of its gram ("squared") matrix over the last index.
X = np.random.default_rng(42).standard_normal((3, 3, 3, 3))
# contract the first three indices of conj(X) and X -> (3, 3) gram matrix
# (the conjugate keeps this correct for complex X; a no-op for real X)
XX = np.tensordot(X.conj(), X, axes=([0, 1, 2], [0, 1, 2]))
s2, W = do("linalg.eigh", XX)

side = "right"

# if keep is not None:
#     # outer dimension smaller -> exactly low-rank
#     s2 = s2[-keep:]
#     W = W[:, -keep:]

# might have negative eigenvalues due to numerical error from squaring;
# clip relative to the largest one (eigh returns them in ascending order)
s2 = do("clip", s2, s2[-1] * 1e-12, None)
s = do("sqrt", s2)

if side == "right":
    factor = decomp.ldmul(s, dag(W))
else:  # 'left'
    factor = decomp.rdmul(W, s)

# sanity check: the factor reproduces the gram matrix it was built from
if side == "right":
    assert np.allclose(dag(factor) @ factor, XX)
else:
    assert np.allclose(factor @ dag(factor), XX)

factor

In [None]:
Y = np.einsum(
    'bi,bj,bk,ijkl->bl',
    np.random.randn(3, 3),
    np.random.randn(3, 3),
    np.random.randn(3, 3),
    X
)
Q = np.linalg.qr(Y)[0]
U, s, VH = np.linalg.svd(Q, full_matrices=False)
VH

In [None]:
Q

In [None]:
L = 8
D = 3
chi = 9
# tn = qtn.TN3D_rand(L, L, L, D)
tn = qtn.TN3D_classical_ising_partition_function(L, L, L, 0.3)

In [None]:
%%time
tn.contract_ctmrg(max_bond=chi, cutoff=0.0).contract(..., optimize='auto-hq')

In [None]:
%%time
tn.contract_boundary(max_bond=chi, cutoff=0.0)