In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

class CPDecomposition:
    """
    Masked CP decomposition fitted with gradient-based optimization.

    The tensor is approximated by a sum of `rank` rank-one terms; each term is the
    outer product of one column taken from every factor matrix. Only entries where
    `mask` is non-zero contribute to the data-fit loss.
    """

    def __init__(self, tensor, mask, rank):
        """
        Initialize CP decomposition with masking.

        Args:
        - tensor: torch.Tensor, the tensor to decompose.
        - mask: torch.Tensor, binary mask of the same shape as tensor (1 for observed, 0 for missing).
        - rank: int, the target rank for decomposition.
        """
        assert tensor.shape == mask.shape, "Tensor and mask must have the same shape."
        self.tensor = tensor
        self.mask = mask
        self.rank = rank
        self.dims = tensor.shape
        # One (dim, rank) factor matrix per mode, drawn from a standard normal.
        self.factors = [torch.randn(dim, rank, requires_grad=True) for dim in self.dims]

    def reconstruct(self):
        """
        Reconstruct tensor from factor matrices.

        Works for any number of modes >= 1: each rank-one component is built by
        repeatedly broadcasting the running product against the next mode's column.

        Returns:
        - torch.Tensor with the same shape as the original tensor.
        """
        recon = torch.zeros_like(self.tensor)
        for r in range(self.rank):
            component = self.factors[0][:, r]
            for factor in self.factors[1:]:
                # unsqueeze(-1) * vector is an outer product with the next mode.
                component = component.unsqueeze(-1) * factor[:, r]
            recon = recon + component
        return recon

    def loss(self, reconstruction):
        """
        Compute the squared Frobenius norm of the reconstruction error on observed entries only.

        Args:
        - reconstruction: torch.Tensor, same shape as the original tensor.

        Returns:
        - Scalar torch.Tensor.
        """
        error = self.mask * (self.tensor - reconstruction)
        return torch.norm(error) ** 2

    def optimize(self, lr=0.001, max_iter=1000, tol=1e-6, reg_lambda=0.01, verbose=True):
        """
        Perform CP decomposition using gradient-based optimization with masking.

        Args:
        - lr: float, learning rate.
        - max_iter: int, maximum number of iterations.
        - tol: float, tolerance for convergence (absolute change of the total loss
          between two consecutive iterations).
        - reg_lambda: float, regularization coefficient.
        - verbose: bool, print the loss at every iteration (default True, as before).

        Returns:
        - factors: List of torch.Tensor, factor matrices for each mode (detached).
        """
        optimizer = optim.Adam(self.factors, lr=lr)  # Use Adam optimizer
        prev_loss = float('inf')

        for iteration in range(max_iter):
            optimizer.zero_grad()
            loss = self.loss(self.reconstruct())

            # Add L2 regularization (squared Frobenius norm of every factor)
            for factor in self.factors:
                loss = loss + reg_lambda * torch.norm(factor) ** 2

            loss.backward()
            optimizer.step()

            # The reported value is the loss *before* this iteration's update.
            loss_value = loss.item()
            if verbose:
                print(f"Iteration {iteration + 1}, Loss: {loss_value:.6f}")
            if abs(prev_loss - loss_value) < tol:
                if verbose:
                    print("Converged.")
                break
            prev_loss = loss_value

        return [factor.detach() for factor in self.factors]


# Generate a sample tensor (3D example)
torch.manual_seed(42)
I, J, K = 4, 5, 6  # Dimensions of the tensor
R_true = 3  # True rank of the tensor

# Generate true factor matrices
A_true = torch.randn(I, R_true)
B_true = torch.randn(J, R_true)
C_true = torch.randn(K, R_true)

# Construct the tensor as a sum of R_true rank-one terms
tensor = torch.zeros(I, J, K)
for r in range(R_true):
    # torch.outer replaces the deprecated torch.ger alias
    tensor += torch.outer(A_true[:, r], B_true[:, r]).unsqueeze(-1) * C_true[:, r]

# Add noise
tensor += 0.1 * torch.randn(I, J, K)

# Create a random mask (e.g., 80% observed)
mask = (torch.rand_like(tensor) < 0.8).float()

# Perform CP decomposition with masking
cp = CPDecomposition(tensor, mask, rank=R_true)
factors = cp.optimize(lr=0.001, max_iter=10000, tol=1e-6, reg_lambda=0.01)

# Evaluate reconstruction; no gradients are needed once fitting is done
with torch.no_grad():
    reconstructed_tensor = cp.reconstruct()
    # Relative error restricted to the observed entries
    masked_error = torch.norm(mask * (tensor - reconstructed_tensor)) / torch.norm(mask * tensor)
print(f"Reconstruction Error on Observed Entries: {masked_error.item():.6f}")

Iteration 1, Loss: 235.789047
Iteration 2, Loss: 235.076782
Iteration 3, Loss: 234.368423
Iteration 4, Loss: 233.664047
Iteration 5, Loss: 232.963669
Iteration 6, Loss: 232.267349
Iteration 7, Loss: 231.575089
Iteration 8, Loss: 230.886917
Iteration 9, Loss: 230.202896
Iteration 10, Loss: 229.523102
Iteration 11, Loss: 228.847519
Iteration 12, Loss: 228.176193
Iteration 13, Loss: 227.509155
Iteration 14, Loss: 226.846436
Iteration 15, Loss: 226.188004
Iteration 16, Loss: 225.533905
Iteration 17, Loss: 224.884155
Iteration 18, Loss: 224.238754
Iteration 19, Loss: 223.597733
Iteration 20, Loss: 222.961044
Iteration 21, Loss: 222.328766
Iteration 22, Loss: 221.700775
Iteration 23, Loss: 221.077194
Iteration 24, Loss: 220.457993
Iteration 25, Loss: 219.843048
Iteration 26, Loss: 219.232498
Iteration 27, Loss: 218.626205
Iteration 28, Loss: 218.024185
Iteration 29, Loss: 217.426453
Iteration 30, Loss: 216.832993
Iteration 31, Loss: 216.243683
Iteration 32, Loss: 215.658646
Iteration 33, Los

In [58]:
import torch
import torch.optim as optim

class ConstrainedCPDecomposition_v1:
    """
    CP decomposition with an observation mask and a per-entry constraint tensor.

    Entries with constraint == 1 are fitted (where also observed); on entries with
    constraint == 0 the reconstruction is penalized for exceeding the target value.
    """

    def __init__(self, tensor, mask, rank, constraint=None):
        """
        Initialize Constrained CP decomposition.

        Args:
        - tensor: torch.Tensor, the tensor to decompose.
        - mask: torch.Tensor, binary mask of the same shape as tensor (1 for observed, 0 for missing).
        - rank: int, the target rank for decomposition.
        - constraint: torch.Tensor, tensor representing constraint satisfaction (same shape as tensor).
          Defaults to all ones, which disables the penalty term.
        """
        if constraint is None:
            constraint = torch.ones_like(tensor)
        assert tensor.shape == mask.shape == constraint.shape, "Tensor, mask, and constraint must have the same shape."
        self.tensor = tensor
        self.mask = mask
        self.constraint = constraint
        self.rank = rank
        self.dims = tensor.shape
        # One (dim, rank) factor matrix per mode, drawn from a standard normal.
        self.factors = [torch.randn(dim, rank, requires_grad=True) for dim in self.dims]

    def reconstruct(self):
        """
        Reconstruct tensor from factor matrices.

        Works for any number of modes >= 1.

        Returns:
        - torch.Tensor with the same shape as the original tensor.
        """
        recon = torch.zeros_like(self.tensor)
        for r in range(self.rank):
            component = self.factors[0][:, r]
            for factor in self.factors[1:]:
                # unsqueeze(-1) * vector is an outer product with the next mode.
                component = component.unsqueeze(-1) * factor[:, r]
            recon = recon + component
        return recon

    def _loss_terms(self, reconstruction, reg_lambda, constraint_lambda):
        """Return (total, fit, constraint penalty, L2) loss terms for one reconstruction."""
        # Fit term: squared Frobenius norm (a sum, not a mean, despite the "mse" log label)
        # over entries that are both observed and have constraint == 1.
        error_term = self.constraint * self.mask * (self.tensor - reconstruction)
        mse_loss = torch.norm(error_term) ** 2

        # Penalty term: where constraint == 0, penalize reconstruction above the target.
        violation_term = (1 - self.constraint) * torch.clamp(reconstruction - self.tensor, min=0)
        constraint_loss = constraint_lambda * torch.sum(violation_term)

        # Regularization term: squared Frobenius norm of every factor matrix.
        l2_loss = reg_lambda * sum(torch.norm(factor) ** 2 for factor in self.factors)

        total_loss = mse_loss + constraint_loss + l2_loss
        return total_loss, mse_loss, constraint_loss, l2_loss

    def optimize(self, lr=0.001, max_iter=1000, tol=1e-6, reg_lambda=0.01, constraint_lambda=1, verbose=True):
        """
        Perform Constrained CP decomposition using gradient-based optimization.

        Args:
        - lr: float, learning rate.
        - max_iter: int, maximum number of iterations.
        - tol: float, tolerance for convergence (absolute change of the total loss
          between two consecutive iterations).
        - reg_lambda: float, regularization coefficient for L2 regularization.
        - constraint_lambda: float, penalty coefficient for constraint violations.
        - verbose: bool, print the loss terms at every iteration (default True, as before).

        Returns:
        - factors: List of torch.Tensor, factor matrices for each mode (detached).
        """
        optimizer = optim.Adam(self.factors, lr=lr)  # Use Adam optimizer
        prev_loss = float('inf')

        for iteration in range(max_iter):
            optimizer.zero_grad()

            loss, mse_loss, constraint_loss, l2_loss = self._loss_terms(
                self.reconstruct(), reg_lambda, constraint_lambda
            )
            loss.backward()
            optimizer.step()

            # The reported values are the losses *before* this iteration's update.
            loss_value = loss.item()
            if verbose:
                print(f"Iteration {iteration + 1}, Loss: {loss_value:.6f}")
                print(f"mse: {mse_loss.item():.6f}, constraint: {constraint_loss.item():.6f}, l2: {l2_loss.item():.6f}")
            if abs(prev_loss - loss_value) < tol:
                if verbose:
                    print("Converged.")
                break
            prev_loss = loss_value

        return [factor.detach() for factor in self.factors]

import torch

# Generate a sample 2D tensor
torch.manual_seed(42)
I, J = 4, 5  # Dimensions of the 2D tensor
R_true = 2  # True rank of the tensor

# Generate true factor matrices
A_true = torch.randn(I, R_true)
B_true = torch.randn(J, R_true)

# Construct the tensor as a sum of R_true rank-one terms
tensor = torch.zeros(I, J)
for r in range(R_true):
    tensor += torch.outer(A_true[:, r], B_true[:, r])

# Add noise
tensor += 0.1 * torch.randn(I, J)

# Create a random mask (80% observed)
mask = (torch.rand_like(tensor) < 0.8).float()

# Create a random constraint tensor (1 where the uniform draw exceeds 0.2, else 0)
constraint = (torch.rand_like(tensor) > 0.2).float()

# Zero the target wherever constraint == 0
tensor[constraint == 0] = 0

# Perform constrained CP decomposition
cp = ConstrainedCPDecomposition_v1(tensor, mask, rank=R_true, constraint=constraint)
factors = cp.optimize(lr=0.001, max_iter=10000, tol=1e-6, reg_lambda=0.01, constraint_lambda=1.0)

# Evaluate reconstruction; no gradients are needed once fitting is done
with torch.no_grad():
    reconstructed_tensor = cp.reconstruct()
    # Relative error over all observed entries (including those with constraint == 0)
    masked_error = torch.norm(mask * (tensor - reconstructed_tensor)) / torch.norm(mask * tensor)
print(f"Reconstruction Error on Observed Entries: {masked_error.item():.6f}")

# Print original and reconstructed tensors
print("\nOriginal tensor:")
print(tensor)
print("\nReconstructed tensor:")
print(reconstructed_tensor)


Iteration 1, Loss: 16.438507
mse: 15.821542, constraint: 0.423571, l2: 0.193394
Iteration 2, Loss: 16.368248
mse: 15.753337, constraint: 0.421692, l2: 0.193221
Iteration 3, Loss: 16.298206
mse: 15.685341, constraint: 0.419816, l2: 0.193048
Iteration 4, Loss: 16.228380
mse: 15.617560, constraint: 0.417945, l2: 0.192876
Iteration 5, Loss: 16.158783
mse: 15.550002, constraint: 0.416078, l2: 0.192704
Iteration 6, Loss: 16.089413
mse: 15.482665, constraint: 0.414215, l2: 0.192532
Iteration 7, Loss: 16.020271
mse: 15.415553, constraint: 0.412357, l2: 0.192361
Iteration 8, Loss: 15.951361
mse: 15.348667, constraint: 0.410503, l2: 0.192191
Iteration 9, Loss: 15.882686
mse: 15.282012, constraint: 0.408653, l2: 0.192021
Iteration 10, Loss: 15.814251
mse: 15.215591, constraint: 0.406808, l2: 0.191851
Iteration 11, Loss: 15.746059
mse: 15.149408, constraint: 0.404968, l2: 0.191683
Iteration 12, Loss: 15.678108
mse: 15.083462, constraint: 0.403132, l2: 0.191514
Iteration 13, Loss: 15.610405
mse: 15

In [61]:
# Residual (target minus reconstruction) kept only where constraint is non-zero.
# Relies on `constraint`, `tensor` and `reconstructed_tensor` left in the kernel by an earlier cell.
arr = constraint * (tensor - reconstructed_tensor)
arr

tensor([[ 4.8142e-02,  3.1889e-02,  0.0000e+00, -3.4242e-02, -2.8156e-02],
        [-7.6940e-02, -2.7126e-04,  0.0000e+00,  3.5103e-03,  3.3835e-02],
        [-3.3695e-02,  0.0000e+00, -1.1161e+00,  0.0000e+00,  0.0000e+00],
        [ 1.5573e-03,  0.0000e+00,  0.0000e+00, -1.3336e-02,  7.9098e-03]],
       grad_fn=<MulBackward0>)

In [62]:
# Reconstruction minus target, kept only where constraint == 0
# (positive values are what the v1 penalty term clamps and sums).
# Relies on `constraint`, `tensor` and `reconstructed_tensor` left in the kernel by an earlier cell.
arr2 = (1 - constraint) * (reconstructed_tensor - tensor)
arr2

tensor([[-0.0000, -0.0000, -0.0229,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0252, -0.0000, -0.0000],
        [ 0.0000, -1.2274,  0.0000, -0.0269, -1.2300],
        [-0.0000, -0.7897, -0.0888,  0.0000, -0.0000]], grad_fn=<MulBackward0>)

In [None]:
import torch
import torch.optim as optim

class ConstrainedCPDecomposition:
    """
    CP decomposition with an observation mask and a per-entry constraint tensor.

    Entries with constraint == 1 are fitted (where also observed); on entries with
    constraint == 0 any positive reconstructed value is penalized.
    """

    def __init__(self, tensor, mask, rank, constraint=None):
        """
        Initialize Constrained CP decomposition.

        Args:
        - tensor: torch.Tensor, the tensor to decompose.
        - mask: torch.Tensor, binary mask of the same shape as tensor (1 for observed, 0 for missing).
        - rank: int, the target rank for decomposition.
        - constraint: torch.Tensor, tensor representing constraint satisfaction (same shape as tensor).
          Defaults to all ones, which disables the penalty term.
        """
        if constraint is None:
            constraint = torch.ones_like(tensor)
        assert tensor.shape == mask.shape == constraint.shape, "Tensor, mask, and constraint must have the same shape."
        self.tensor = tensor
        self.mask = mask
        self.constraint = constraint
        self.rank = rank
        self.dims = tensor.shape
        # One (dim, rank) factor matrix per mode, drawn from a standard normal.
        self.factors = [torch.randn(dim, rank, requires_grad=True) for dim in self.dims]

    def reconstruct(self):
        """
        Reconstruct tensor from factor matrices.

        Works for any number of modes >= 1.

        Returns:
        - torch.Tensor with the same shape as the original tensor.
        """
        recon = torch.zeros_like(self.tensor)
        for r in range(self.rank):
            component = self.factors[0][:, r]
            for factor in self.factors[1:]:
                # unsqueeze(-1) * vector is an outer product with the next mode.
                component = component.unsqueeze(-1) * factor[:, r]
            recon = recon + component
        return recon

    def _loss_terms(self, reconstruction, reg_lambda, constraint_lambda):
        """Return (total, fit, constraint penalty, L2) loss terms for one reconstruction."""
        # Fit term: squared Frobenius norm (a sum, not a mean, despite the "mse" log label)
        # over entries that are both observed and have constraint == 1.
        error_term = self.constraint * self.mask * (self.tensor - reconstruction)
        mse_loss = torch.norm(error_term) ** 2

        # Penalty term: where constraint == 0, penalize positive reconstructed values.
        violation_term = torch.clamp((1 - self.constraint) * reconstruction, min=0)
        constraint_loss = constraint_lambda * torch.sum(violation_term)

        # Regularization term: squared Frobenius norm of every factor matrix.
        l2_loss = reg_lambda * sum(torch.norm(factor) ** 2 for factor in self.factors)

        total_loss = mse_loss + constraint_loss + l2_loss
        return total_loss, mse_loss, constraint_loss, l2_loss

    def optimize(self, lr=0.001, max_iter=1000, tol=1e-6, reg_lambda=0.01, constraint_lambda=1, verbose=True):
        """
        Perform Constrained CP decomposition using gradient-based optimization.

        Args:
        - lr: float, learning rate.
        - max_iter: int, maximum number of iterations.
        - tol: float, tolerance for convergence (absolute change of the total loss
          between two consecutive iterations).
        - reg_lambda: float, regularization coefficient for L2 regularization.
        - constraint_lambda: float, penalty coefficient for constraint violations.
        - verbose: bool, print the loss terms at every iteration (default True, as before).

        Returns:
        - factors: List of torch.Tensor, factor matrices for each mode (detached).
        """
        optimizer = optim.Adam(self.factors, lr=lr)  # Use Adam optimizer
        prev_loss = float('inf')

        for iteration in range(max_iter):
            optimizer.zero_grad()

            loss, mse_loss, constraint_loss, l2_loss = self._loss_terms(
                self.reconstruct(), reg_lambda, constraint_lambda
            )
            loss.backward()
            optimizer.step()

            # The reported values are the losses *before* this iteration's update.
            loss_value = loss.item()
            if verbose:
                print(f"Iteration {iteration + 1}, Loss: {loss_value:.6f}")
                print(f"mse: {mse_loss.item():.6f}, constraint: {constraint_loss.item():.6f}, l2: {l2_loss.item():.6f}")
            if abs(prev_loss - loss_value) < tol:
                if verbose:
                    print("Converged.")
                break
            prev_loss = loss_value

        return [factor.detach() for factor in self.factors]


# Example usage: constrained CP decomposition of a small 2D tensor
import torch

# Generate a sample 2D tensor
torch.manual_seed(42)
I, J = 4, 5  # Dimensions of the 2D tensor
R_true = 2  # True rank of the tensor

# Generate true factor matrices
A_true = torch.randn(I, R_true)
B_true = torch.randn(J, R_true)

# Construct the tensor as a sum of R_true rank-one terms
tensor = torch.zeros(I, J)
for r in range(R_true):
    tensor += torch.outer(A_true[:, r], B_true[:, r])

# Add noise
tensor += 0.1 * torch.randn(I, J)

# Create a random mask (80% observed)
mask = (torch.rand_like(tensor) < 0.8).float()

# Create a random constraint tensor (1 where the uniform draw exceeds 0.2, else 0)
constraint = (torch.rand_like(tensor) > 0.2).float()

# Zero the target wherever constraint == 0
tensor[constraint == 0] = 0

# Perform constrained CP decomposition (no L2 regularization in this run)
cp = ConstrainedCPDecomposition(tensor, mask, rank=R_true, constraint=constraint)
factors = cp.optimize(lr=0.01, max_iter=1000, tol=1e-6, reg_lambda=0, constraint_lambda=1.0)

# Evaluate reconstruction; no gradients are needed once fitting is done
with torch.no_grad():
    reconstructed_tensor = cp.reconstruct()
    # Relative error over all observed entries (including those with constraint == 0)
    masked_error = torch.norm(mask * (tensor - reconstructed_tensor)) / torch.norm(mask * tensor)
print(f"Reconstruction Error on Observed Entries: {masked_error.item():.6f}")

# Print original and reconstructed tensors
print("\nOriginal tensor:")
print(tensor)
print("\nReconstructed tensor:")
print(reconstructed_tensor)

Iteration 1, Loss: 16.245113
mse: 15.821542, constraint: 0.423571, l2: 0.000000
Iteration 2, Loss: 15.553709
mse: 15.148753, constraint: 0.404956, l2: 0.000000
Iteration 3, Loss: 14.883753
mse: 14.497003, constraint: 0.386751, l2: 0.000000
Iteration 4, Loss: 14.235314
mse: 13.866352, constraint: 0.368962, l2: 0.000000
Iteration 5, Loss: 13.608415
mse: 13.256819, constraint: 0.351595, l2: 0.000000
Iteration 6, Loss: 13.003011
mse: 12.668355, constraint: 0.334656, l2: 0.000000
Iteration 7, Loss: 12.419025
mse: 12.100877, constraint: 0.318148, l2: 0.000000
Iteration 8, Loss: 11.856312
mse: 11.554236, constraint: 0.302076, l2: 0.000000
Iteration 9, Loss: 11.314688
mse: 11.028247, constraint: 0.286441, l2: 0.000000
Iteration 10, Loss: 10.793919
mse: 10.522671, constraint: 0.271248, l2: 0.000000
Iteration 11, Loss: 10.293721
mse: 10.037224, constraint: 0.256497, l2: 0.000000
Iteration 12, Loss: 9.813770
mse: 9.571581, constraint: 0.242190, l2: 0.000000
Iteration 13, Loss: 9.353706
mse: 9.125

In [64]:
# Residual (target minus reconstruction) on entries that are both observed (mask)
# and have a non-zero constraint, i.e. the entries the fit term of the loss uses.
# Relies on `mask`, `constraint`, `tensor` and `reconstructed_tensor` from an earlier cell.
arr = mask * constraint * (tensor - reconstructed_tensor)
arr

tensor([[ 0.0000e+00,  1.1502e-02,  0.0000e+00, -9.3460e-03, -1.1270e-02],
        [-1.2022e-02, -8.1643e-03,  0.0000e+00,  4.7992e-03,  9.5445e-03],
        [-1.1836e-03,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 6.6620e-04,  0.0000e+00,  0.0000e+00,  7.5340e-04,  6.3181e-05]],
       grad_fn=<MulBackward0>)

In [65]:
# Reconstruction minus target on observed entries where constraint == 0.
# Relies on `mask`, `constraint`, `tensor` and `reconstructed_tensor` from an earlier cell.
arr2 = mask * (1 - constraint) * (reconstructed_tensor - tensor)
arr2

tensor([[-0.0000, -0.0000, -0.0596,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
        [ 0.0000, -4.7956,  0.0000, -1.5213, -3.7144],
        [-0.0000, -0.1004, -0.3084, -0.0000, -0.0000]], grad_fn=<MulBackward0>)

In [66]:
# Scratch cell: display a 4x2 tensor of random normal draws (same call used to initialize factors).
torch.randn(4, 2)

tensor([[-0.0716, -0.0909],
        [-1.3297, -0.5426],
        [ 0.5471,  0.6431],
        [-0.7905, -0.9058]])

In [74]:
# Truthiness check: a non-empty list is truthy, so the fallback value 9 is not selected.
cond = ["ajsdl;fkjadls"]
if cond:
    a = 0
else:
    a = 9
print(a)

0


In [76]:
# Truthiness check: None is falsy, so the fallback value 9 is selected.
cond = None
if cond:
    a = 0
else:
    a = 9
print(a)

9


In [90]:
import torch
import torch.optim as optim


class TensorDecomposition:
    """
    Gradient-fitted tensor decomposition supporting CP, Tucker, Tensor Train and Tensor Ring.

    All methods share the same loss: a masked, constraint-weighted squared error,
    a penalty on positive reconstructed values where constraint == 0, and an L2
    penalty on every trainable parameter.
    """

    def __init__(self, tensor, rank, method="cp", mask=None, constraint=None):
        """
        Initialize a general tensor decomposition class supporting CP, Tucker, Tensor Train, and Ring decompositions.

        Args:
        - tensor: torch.Tensor, the tensor to decompose.
        - rank: int or tuple/list, the target rank for decomposition.
            * "cp" / "ring": int.
            * "tucker": int (same rank for every mode) or one rank per mode.
            * "train": int r (expanded to [1, r, ..., r, 1]) or an explicit sequence
              of len(tensor.shape) + 1 ranks that starts and ends with 1.
        - method: str, the decomposition method ("cp", "tucker", "train", "ring").
        - mask: torch.Tensor, binary mask of the same shape as tensor (1 for observed, 0 for missing).
        - constraint: torch.Tensor, tensor representing constraint satisfaction (same shape as tensor).

        Raises:
        - ValueError: if method is not one of the supported names.
        """
        if mask is None:
            mask = torch.ones_like(tensor)
        if constraint is None:
            constraint = torch.ones_like(tensor)
        assert tensor.shape == mask.shape == constraint.shape, "Tensor, mask, and constraint must have the same shape."

        self.tensor = tensor
        self.mask = mask
        self.constraint = constraint
        self.method = method.lower()
        self.dims = tensor.shape
        n_modes = len(self.dims)

        if self.method == "cp":
            self.rank = rank
            self.factors = [torch.randn(dim, rank, requires_grad=True) for dim in self.dims]

        elif self.method == "tucker":
            self.rank = tuple(rank) if isinstance(rank, (tuple, list)) else (rank,) * n_modes
            self.core = torch.randn(*self.rank, requires_grad=True)
            self.factors = [torch.randn(dim, r, requires_grad=True) for dim, r in zip(self.dims, self.rank)]

        elif self.method == "train":
            if isinstance(rank, (list, tuple)):
                self.ranks = list(rank)
            else:
                # Boundary bond dimensions of a tensor train are always 1.
                self.ranks = [1] + [rank] * (n_modes - 1) + [1]
            assert len(self.ranks) == n_modes + 1, "Tensor Train needs one more rank than the tensor has modes."
            assert self.ranks[0] == self.ranks[-1] == 1, "Tensor Train ranks must start and end with 1."
            self.factors = [
                torch.randn(self.ranks[i], self.dims[i], self.ranks[i + 1], requires_grad=True)
                for i in range(n_modes)
            ]

        elif self.method == "ring":
            self.rank = rank
            self.factors = [
                torch.randn(rank, self.dims[i], rank, requires_grad=True)
                for i in range(n_modes)
            ]

        else:
            raise ValueError(f"Unsupported method: {method}. Choose from 'cp', 'tucker', 'train', or 'ring'.")

    def reconstruct(self):
        """
        Reconstruct the tensor based on the decomposition method.

        Returns:
        - torch.Tensor with the same shape as the original tensor.
        """
        if self.method == "cp":
            recon = torch.zeros_like(self.tensor)
            for r in range(self.rank):
                component = self.factors[0][:, r]
                for factor in self.factors[1:]:
                    # unsqueeze(-1) * vector is an outer product with the next mode.
                    component = component.unsqueeze(-1) * factor[:, r]
                recon = recon + component
            return recon

        elif self.method == "tucker":
            recon = self.core
            for factor in self.factors:
                # Contract the leading (rank) axis; the new data axis is appended last,
                # so after all modes the axes come out in the original order.
                recon = torch.tensordot(recon, factor, dims=[[0], [1]])
            return recon

        elif self.method == "train":
            recon = self.factors[0]
            for factor in self.factors[1:]:
                recon = torch.einsum("...i,ijk->...jk", recon, factor)
            # Drop only the two size-1 boundary bond axes; a bare squeeze() would
            # also remove genuine size-1 data modes.
            return recon.squeeze(-1).squeeze(0)

        elif self.method == "ring":
            cores = self.factors

            # A single core closes on itself: trace over its two bond axes.
            if len(cores) == 1:
                return torch.diagonal(cores[0], dim1=0, dim2=2).sum(-1)

            # Start with first core
            result = cores[0]  # [r, d1, r]

            # Contract with middle cores
            for core in cores[1:-1]:
                # Contract [r,D,r] with [r,d,r] -> [r,D,d,r]
                result = torch.einsum('ijk,klm->ijlm', result, core)
                # Merge the data axes to [r,D*d,r] for the next iteration
                s1, s2, s3, s4 = result.shape
                result = result.reshape(s1, s2 * s3, s4)

            # Close the ring: the last core's trailing bond index is tied to the
            # first core's leading bond index (a trace), not summed independently.
            result = torch.einsum('ijk,kli->jl', result, cores[-1])

            # Reshape to match original tensor dimensions
            return result.reshape(self.tensor.shape)

        raise ValueError(f"Unsupported method: {self.method}. Choose from 'cp', 'tucker', 'train', or 'ring'.")

    def _loss_terms(self, reconstruction, params, reg_lambda, constraint_lambda):
        """Return (total, fit, constraint penalty, L2) loss terms for one reconstruction."""
        # Fit term: squared Frobenius norm (a sum, not a mean, despite the "mse" log label)
        # over entries that are both observed and have constraint == 1.
        error_term = self.constraint * self.mask * (self.tensor - reconstruction)
        mse_loss = torch.norm(error_term) ** 2

        # Penalty term: where constraint == 0, penalize positive reconstructed values.
        violation_term = torch.clamp((1 - self.constraint) * reconstruction, min=0)
        constraint_loss = constraint_lambda * torch.sum(violation_term)

        # Regularization term over every trainable parameter (including the Tucker core).
        l2_loss = reg_lambda * sum(torch.norm(param) ** 2 for param in params)

        total_loss = mse_loss + constraint_loss + l2_loss
        return total_loss, mse_loss, constraint_loss, l2_loss

    def optimize(self, lr=0.01, max_iter=1000, tol=1e-6, reg_lambda=0.01, constraint_lambda=1, verbose=True):
        """
        Perform optimization for the specified decomposition method.

        Args:
        - lr: float, learning rate.
        - max_iter: int, maximum number of iterations.
        - tol: float, tolerance for convergence (absolute change of the total loss
          between two consecutive iterations).
        - reg_lambda: float, regularization coefficient for L2 regularization.
        - constraint_lambda: float, penalty coefficient for constraint violations.
        - verbose: bool, print the loss terms at every iteration (default True, as before).

        Returns:
        - factors: Optimized (detached) factor matrices or tensors for the decomposition
          method. For "tucker" the core tensor comes first, followed by the factor matrices.
        """
        params = self.factors if self.method != "tucker" else [self.core] + self.factors
        optimizer = optim.Adam(params, lr=lr)
        prev_loss = float('inf')

        for iteration in range(max_iter):
            optimizer.zero_grad()

            loss, mse_loss, constraint_loss, l2_loss = self._loss_terms(
                self.reconstruct(), params, reg_lambda, constraint_lambda
            )
            loss.backward()
            optimizer.step()

            # The reported values are the losses *before* this iteration's update.
            loss_value = loss.item()
            if verbose:
                print(f"Iteration {iteration + 1}, Loss: {loss_value:.6f}")
                print(f"mse: {mse_loss.item():.6f}, constraint: {constraint_loss.item():.6f}, l2: {l2_loss.item():.6f}")
            if abs(prev_loss - loss_value) < tol:
                if verbose:
                    print("Converged.")
                break
            prev_loss = loss_value

        return [param.detach() for param in params]

# Example Usage: fit the same random 3D tensor with each supported method.
# All four fits draw their initial parameters from the one seeded RNG stream,
# so the order of the statements below determines the results.
torch.manual_seed(42)
I, J, K = 10, 8, 6
rank = 3
tensor = torch.randn(I, J, K)
# mask is 1 where a uniform draw exceeds 0.2, constraint is 1 where a uniform draw exceeds 0.5
mask = (torch.rand_like(tensor) > 0.2).float()
constraint = (torch.rand_like(tensor) > 0.5).float()

# Perform CP decomposition
cp_decomp = TensorDecomposition(tensor, rank=rank, method="cp", mask=mask, constraint=constraint)
factors_cp = cp_decomp.optimize()

# Perform Tucker decomposition (the returned list holds the core first, then the factor matrices)
tucker_decomp = TensorDecomposition(tensor, rank=(3, 3, 3), method="tucker", mask=mask, constraint=constraint)
factors_tucker = tucker_decomp.optimize()

# Perform Tensor Train decomposition (rank list has one more entry than the tensor has modes and starts/ends with 1)
tt_decomp = TensorDecomposition(tensor, rank=[1, 3, 3, 1], method="train", mask=mask, constraint=constraint)
factors_tt = tt_decomp.optimize()

# Perform Ring decomposition (every core has shape (rank, dim, rank))
ring_decomp = TensorDecomposition(tensor, rank=3, method="ring", mask=mask, constraint=constraint)
factors_ring = ring_decomp.optimize()


Iteration 1, Loss: 784.352600
mse: 654.484802, constraint: 129.197815, l2: 0.669991
Iteration 2, Loss: 756.508240
mse: 630.608704, constraint: 125.240692, l2: 0.658889
Iteration 3, Loss: 729.872803
mse: 607.841248, constraint: 121.383591, l2: 0.647942
Iteration 4, Loss: 704.416565
mse: 586.142639, constraint: 117.636833, l2: 0.637093
Iteration 5, Loss: 680.118225
mse: 565.505127, constraint: 113.986671, l2: 0.626405
Iteration 6, Loss: 656.938904
mse: 545.895874, constraint: 110.427109, l2: 0.615899
Iteration 7, Loss: 634.842773
mse: 527.276123, constraint: 106.961044, l2: 0.605590
Iteration 8, Loss: 613.791809
mse: 509.605499, constraint: 103.590851, l2: 0.595487
Iteration 9, Loss: 593.748962
mse: 492.845276, constraint: 100.318108, l2: 0.585593
Iteration 10, Loss: 574.685913
mse: 476.957703, constraint: 97.152275, l2: 0.575909
Iteration 11, Loss: 556.580811
mse: 461.902802, constraint: 94.111671, l2: 0.566371
Iteration 12, Loss: 539.372375
mse: 447.644714, constraint: 91.170677, l2: 0