**Obtain CALDERA decomposition of a tensor**

In [None]:
import torch

from dataclasses import field, dataclass
from collections import namedtuple
from copy import deepcopy
from tqdm import tqdm

In [None]:
from src.caldera.utils.dataclasses import CalderaParams
from src.caldera.utils.quantization import QuantizerFactory
from src.caldera.decomposition.alg import caldera

In [None]:
# Quantizer factories for the two CALDERA components. Both use uniform
# quantization with a block size of 64; they are kept as separate objects so
# each component can be configured independently.
quant_factory_Q = QuantizerFactory(method="uniform", block_size=64)
quant_factory_LR = QuantizerFactory(method="uniform", block_size=64)

# Decomposition hyper-parameters.
# NOTE(review): the per-field meanings below are inferred from the field names
# and should be confirmed against the CalderaParams definition.
quant_params = CalderaParams(
    compute_quantized_component=True,
    compute_low_rank_factors=True,
    Q_bits=4,  # presumably bit-width used to quantize Q
    L_bits=4,  # presumably bit-width used to quantize the left factor L
    R_bits=4,  # presumably bit-width used to quantize the right factor R
    rank=16,  # presumably the inner (low-rank) dimension of L and R
    iters=20,  # presumably number of outer alternating iterations
    lplr_iters=5,  # presumably inner iterations of the low-rank (LPLR) step
    activation_aware_LR=True,
    update_order=["Q", "LR"],
    quant_factory_Q=quant_factory_Q,
    quant_factory_LR=quant_factory_LR,
    rand_svd=False,
    sigma_reg=1e-8,  # presumably a regularizer added to H; H below is singular
)

In [None]:
# Seed the global RNG so the random test matrix W is reproducible.
torch.manual_seed(42)

# Random 1024 x 1024 matrix to decompose.
W = torch.randn((1024, 1024))

# Synthetic activations: a 1024 x 128 matrix whose top 128 x 128 block is the
# identity and whose remaining rows are zero.
X = torch.eye(1024, 128)

# H = X X^T is 1024 x 1024 and diagonal, with ones in the first 128 diagonal
# entries and zeros elsewhere. It therefore has rank 128 and is singular.
H = X @ X.T

In [None]:
# Run the CALDERA decomposition of W on the CPU, with a tqdm progress bar.
# H (built from X above) is passed alongside W; presumably it weights the
# reconstruction error (activation-aware) — confirm against caldera().
# NOTE(review): scale_W=True presumably rescales W internally before
# decomposing — confirm.
caldera_decom = caldera(
    W=W,
    H=H,
    quant_params=quant_params,
    device="cpu",
    scale_W=True,
    use_tqdm=True,
)

In [None]:
# Print the full object returned by caldera().
print(caldera_decom)

In [None]:
# Show the shapes of the two low-rank factors, one per line.
# With a 1024 x 1024 W and rank=16 these are presumably (1024, 16) and
# (16, 1024) — confirm against the caldera() output.
print(caldera_decom.L.shape, caldera_decom.R.shape, sep="\n")

In [None]:
# Number of error values recorded under the 'Q' key. This is presumably one
# entry per outer iteration (iters=20) — confirm. As the last expression in a
# notebook cell, its value is displayed.
len(caldera_decom.errors['Q'])