In [2]:
%config InlineBackend.figure_formats = ['svg']
import quimb
import quimb.tensor as qtn
import numpy as np
import matplotlib.pyplot as plt

from functions import *

dimension of dataset: 14

In [3]:
# Bind the module itself (already loaded by the star-import above) so the
# `Ommd` factory can always be looked up again, even after the name `Ommd`
# in this namespace has been rebound to its result below.
import functions

L = 9         # number of sites passed to MPS_rand_state and Ommd
D = 8         # bond dimension of the random MPS
sigma = 0.09  # passed to Ommd here and to MMD inside loss_fn
SEED = 0      # fixed seed so the random initial state is reproducible

# create a random MPS as our initial target to optimize
psi = qtn.MPS_rand_state(L, bond_dim=D, seed=SEED)
# Call the factory through the module: the original `Ommd = Ommd(L, sigma)`
# overwrote the imported callable, so re-running this cell would call the
# previous result instead of the factory. The resulting global keeps the
# name `Ommd` because later cells refer to it.
Ommd = functions.Ommd(L, sigma)
dataset = get_bars_and_stripes(3)

In [6]:
# Convert each entry of `dataset` with qtn.MPS_computational_state, keeping order.
MPS_dataset = [qtn.MPS_computational_state(sample) for sample in dataset]


In [10]:
def loss_fn(psi, dataset, Ommd):
    """Return the average of ``MMD(psi, entry, Ommd, sigma, L, D)`` over ``dataset``.

    ``sigma``, ``L`` and ``D`` are not parameters: they are read from the
    notebook's global namespace at call time and forwarded to ``MMD``.

    Args:
        psi: object handed to ``MMD`` as its first argument.
        dataset: sized iterable; each entry is passed to ``MMD`` as-is
            (no conversion happens inside this function).
        Ommd: object forwarded unchanged to ``MMD``.

    Returns:
        The sum of the per-entry ``MMD`` values divided by ``len(dataset)``.

    Raises:
        ZeroDivisionError: if ``dataset`` is empty.
    """
    total = sum(MMD(psi, target, Ommd, sigma, L, D) for target in dataset)
    return total / len(dataset)


In [11]:
tnopt = qtn.TNOptimizer(
    # the tensor network we want to optimize
    psi,
    # the function specifying the loss; it is called with the network plus
    # the entries of `loss_constants` as keyword arguments
    loss_fn=loss_fn,
    # normalization function is currently disabled:
    #norm_fn=norm_fn,
    # we specify constants so that the arguments can be converted
    # to the desired autodiff backend automatically
    loss_constants={"dataset": MPS_dataset, "Ommd": Ommd},
    # the underlying algorithm to use for the optimization;
    # "adam" is selected here instead of the library default ('l-bfgs-b')
    optimizer="adam",
    # which gradient computation backend to use
    autodiff_backend="jax",
)
# display the optimizer's repr as the cell output
tnopt

<TNOptimizer(d=928, backend=jax)>

In [12]:
# Run the optimizer with an argument of 1 (the progress bar below reports a
# total of 1 step) and keep the returned result as `psi_opt`.
psi_opt = tnopt.optimize(1)

  0%|          | 0/1 [00:00<?, ?it/s]2025-02-27 11:26:15.691962: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %multiply.671 = f64[503,503,2,2]{2,3,1,0} multiply(f64[503,503,2,2]{2,3,1,0} %constant.692, f64[503,503,2,2]{2,3,1,0} %broadcast.341)

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-02-27 11:26:16.359547: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.667825091s
Constant folding an instruction is taking > 1s:

  %multiply.671 = f64[503,503,2,2]{2,3,1,0} multiply(f64[503,503,2,2]{2,3,1,0} %constant.692, f64[503,503,2,2]{2,3,1,0} %broadcas