In [1]:
import logging

import torch
import torch.optim as optim


class TensorFactorization:
    def __init__(
        self, 
        tensor, 
        rank, 
        method="cp", 
        mask=None, 
        constraint=None,  
        is_maximize_c=True,
        device=None,
        prev_state=None,   # Added for continual learning
        verbose=False
    ):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Move tensors to the device
        tensor = tensor.to(self.device)
        if mask is None:
            mask = torch.ones_like(tensor, device=self.device)
        else:
            mask = mask.to(self.device)
        if constraint is None:
            constraint = torch.ones_like(tensor, device=self.device)
        else:
            constraint = constraint.to(self.device)

        assert tensor.shape == mask.shape == constraint.shape, \
            "Tensor, mask, and constraint must have the same shape."

        self.tensor = tensor
        self.mask = mask
        self.constraint = constraint
        self.is_maximize_c = is_maximize_c

        self.method = method.lower()
        self.total_params = 0  # Initialize total_params

        if self.method == "cp":
            self.rank = rank
            self.dims = tensor.shape
            # Initialize or create factor parameters
            self.factors = [torch.randn(dim, rank, requires_grad=True, device=self.device) 
                            for dim in self.dims]
            self.total_params = sum(factor.numel() for factor in self.factors)

        elif self.method == "tucker":
            self.rank = rank if isinstance(rank, tuple) else (rank,) * len(tensor.shape)
            self.core = torch.randn(*self.rank, requires_grad=True, device=self.device)
            self.factors = [torch.randn(dim, r, requires_grad=True, device=self.device) 
                            for dim, r in zip(tensor.shape, self.rank)]
            self.total_params = self.core.numel() + sum(factor.numel() for factor in self.factors)

        elif self.method == "train":
            # Automatically expand rank to [1, rank, ..., rank, 1] if rank is int
            if isinstance(rank, int):
                rank = [1] + [rank] * (len(tensor.shape) - 1) + [1]

            self.ranks = rank
            assert self.ranks[0] == self.ranks[-1] == 1, "Tensor Train ranks must start and end with 1."
            assert len(self.ranks) == len(tensor.shape) + 1, \
                "Ranks length must be equal to tensor dimensions + 1."
            
            self.factors = [
                torch.randn(self.ranks[i], tensor.shape[i], self.ranks[i + 1], 
                            requires_grad=True, device=self.device)
                for i in range(len(tensor.shape))
            ]
            self.total_params = sum(factor.numel() for factor in self.factors)


        elif self.method == "ring":
            self.rank = rank
            self.factors = [
                torch.randn(rank, tensor.shape[i], rank, requires_grad=True, device=self.device)
                for i in range(len(tensor.shape))
            ]
            self.total_params = sum(factor.numel() for factor in self.factors)
        else:
            raise ValueError(f"Unsupported method: {method}. Choose from 'cp', 'tucker', 'train', or 'ring'.")

        # Attempt to load previous state if provided
        if prev_state is not None:
            self._load_state(prev_state)
            # print("Loaded from prev state!")

        # For logging
        self.loss = None
        self.mse_loss = None
        self.constraint_loss = None
        self.l2_loss = None
        self.iter_end = None

        self.loss_history = {
            "epoch": [],
            "total": [],
            "mse": [],
            "constraint": [],
            "l2": [],
        }

        # Verbosity
        self.verbose = verbose
        if self.verbose:
            logging.info(f"Initialized {method} decomposition with rank {rank} on device {self.device}.")
            logging.info(f"Total parameters: {self.total_params}")

    def _load_state(self, prev_state):
        """
        Simple demonstration of loading factor parameters from prev_state.
        Modify according to how you save the states.
        prev_state could be a list of Tensors or any structure you define.
        """
        if self.method == "cp":
            # Expecting prev_state to be a list of factor Tensors, same shape as self.factors
            if len(prev_state) == len(self.factors):
                for factor, saved_factor in zip(self.factors, prev_state):
                    factor.data.copy_(saved_factor.data)
        elif self.method == "tucker":
            # Suppose prev_state = (core, [factor1, factor2, ...])
            core, prev_factors = prev_state
            self.core.data.copy_(core.data)
            for f1, f2 in zip(self.factors, prev_factors):
                f1.data.copy_(f2.data)
        elif self.method == "train":
            # Suppose prev_state is a list of TT-cores
            for factor, saved_factor in zip(self.factors, prev_state):
                factor.data.copy_(saved_factor.data)
        elif self.method == "ring":
            # Suppose prev_state is a list of ring factors
            for factor, saved_factor in zip(self.factors, prev_state):
                factor.data.copy_(saved_factor.data)

    def get_state(self):
        """
        Return the current factor parameters (and core if tucker, etc.).
        This can be used for continual optimization in TFSampler.
        """
        if self.method == "cp":
            return [factor.clone().detach() for factor in self.factors]
        elif self.method == "tucker":
            return (
                self.core.clone().detach(),
                [factor.clone().detach() for factor in self.factors]
            )
        elif self.method == "train":
            return [factor.clone().detach() for factor in self.factors]
        elif self.method == "ring":
            return [factor.clone().detach() for factor in self.factors]

    def reconstruct(self):
        """
        Reconstruct the tensor based on the decomposition method.
        """
        if self.method == "cp":
            R = self.rank
            recon = torch.zeros_like(self.tensor, device=self.device)
            for r in range(R):
                # Outer product across all modes
                component = self.factors[0][:, r]
                for mode in range(1, len(self.dims)):
                    component = torch.ger(component, self.factors[mode][:, r]).flatten()
                # Reshape it back to self.dims
                recon += component.view(*self.dims)
            return recon

        elif self.method == "tucker":
            # Start with core
            recon = self.core
            # Repeatedly tensordot with factor matrices
            for i, factor in enumerate(self.factors):
                recon = torch.tensordot(recon, factor, dims=[[0], [1]])
            return recon

        elif self.method == "train":
            # TT decomposition reconstruction with einsum
            recon = self.factors[0]
            for factor in self.factors[1:]:
                recon = torch.einsum("...i,ijk->...jk", recon, factor)
            return recon.squeeze()

        elif self.method == "ring":
            # Very rough ring decomposition reconstruction
            n_modes = len(self.factors)
            result = self.factors[0]
            for i in range(1, n_modes - 1):
                result = torch.einsum('ijk,klm->ijlm', result, self.factors[i])
                s1, s2, s3, s4 = result.shape
                result = result.reshape(s1, s2 * s3, s4)
            result = torch.einsum('ijk,klm->jl', result, self.factors[-1])
            result = result.reshape(self.tensor.shape)
            return result

    def optimize(
        self, 
        lr=0.01, 
        max_iter=None, 
        tol=1e-6, 
        mse_tol=1e-1, 
        const_tol=1e-1, 
        reg_lambda=0.0, 
        constraint_lambda=1
    ):
        """
        Perform optimization for the specified decomposition method.

        Args:
          - lr: float, learning rate
          - max_iter: int or None, maximum number of iterations (if None, stop based on tol)
          - tol: float, tolerance for total loss change
          - mse_tol: float, tolerance for MSE loss
          - const_tol: float, tolerance for constraint loss
          - reg_lambda: float, L2 regularization coefficient
          - constraint_lambda: float, penalty coefficient for constraint violations

        Returns:
          - factors: (Optional) Possibly return the updated factors for reuse
        """
        params = []
        if self.method == "tucker":
            params = [self.core] + self.factors
        else:
            params = self.factors

        optimizer = optim.Adam(params, lr=lr)
        # optimizer = optim.Adam(params, lr=lr, weight_decay=0.01)
        # optimizer = optim.SGD(params, lr=lr)
        # optimizer = optim.SGD(params, lr=lr, momentum=0.01)
        prev_loss = float('inf')
        iteration = 0

        min_iter = 10

        while True:
            optimizer.zero_grad()
            reconstruction = self.reconstruct()

            def loss_fn():
                # Count of observed entries
                n_se = torch.sum(self.mask)
                # Count of constraint-violating entries
                n_c = torch.sum(1 - self.constraint)
                n_c = n_c if n_c > 0 else 1
                
                error_term = self.constraint * self.mask * (self.tensor - reconstruction)
                mse_loss = torch.norm(error_term) ** 2 / n_se if n_se > 0 else 0

                if self.is_maximize_c:
                    sign = 1
                    thr = torch.min(self.tensor)
                else:
                    sign = -1
                    thr = torch.max(self.tensor)

                violation_term = torch.clamp(
                    (1 - self.constraint) * sign * (reconstruction - thr),
                    min=0
                )
                constraint_loss = constraint_lambda * torch.sum(violation_term) / n_c

                # L2 regularization
                l2_loss = torch.tensor(0., device=self.device, dtype=mse_loss.dtype)
                for p in params:
                    l2_loss += torch.norm(p) ** 2 / p.numel()
                l2_loss *= reg_lambda

                total_loss = mse_loss + constraint_loss + l2_loss
                return total_loss, mse_loss, constraint_loss, l2_loss

            loss, mse_loss, c_loss, l2_loss = loss_fn()
            loss.backward()
            optimizer.step()

            # Logging
            self.loss = loss
            self.mse_loss = mse_loss
            self.constraint_loss = c_loss
            self.l2_loss = l2_loss

            self.loss_history["epoch"].append(iteration+1)
            self.loss_history["total"].append(loss.item())
            self.loss_history["mse"].append(mse_loss.item())
            self.loss_history["constraint"].append(c_loss.item())
            self.loss_history["l2"].append(l2_loss.item())

            if self.verbose:
                logging.info(f"Iter: {iteration}, Loss: {loss.item()}")
                logging.info(f"MSE: {mse_loss.item()}, CONST: {c_loss.item()}, L2: {l2_loss.item()}")

            # Check for MSE and constraint convergence
            if mse_loss < mse_tol and c_loss < const_tol and iteration > min_iter:
                if self.verbose:
                    logging.info("Converged based on MSE and constraint tolerance.")
                break

            # Check for total loss difference
            if abs(prev_loss - loss.item()) < tol and iteration > min_iter:
                if self.verbose:
                    logging.info("Converged based on total loss tolerance.")
                break

            if max_iter is not None and iteration >= max_iter - 1 and iteration > min_iter:
                if self.verbose:
                    logging.info("Reached max iteration limit.")
                break

            prev_loss = loss.item()
            iteration += 1

        self.iter_end = iteration

        return [p.detach() for p in params]

In [2]:
from time import time
import numpy as np


def benchmarking(
    shape,
    rank=3,
    method="train"
):
    accuracy = []
    time_taken = []

    for _ in range(10):
        # 元テンソル
        original_tensor = torch.randn(shape)   

        # テンソル分解
        rank = 3
        method = "train"

        decomp = TensorFactorization(
            tensor=original_tensor, 
            rank=rank, 
            method=method
        )

        s = time()
        decomp.optimize()
        e = time()
        # print(f"Time: {e-s}")

        reconstructed_tensor = decomp.reconstruct()
        mse = torch.mean((original_tensor - reconstructed_tensor) ** 2)
        # print(f"MSE: {mse}")

        accuracy.append(mse.item())
        time_taken.append(e-s)

    return accuracy, time_taken

In [3]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
method = "train"

for rank in rank_list:
    accuracy, time_taken = benchmarking(
        shape,
        rank=rank,
        method=method
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print()

Rank: 2
Accuracy: 0.998685771226883
Time taken: 5.486046981811524

Rank: 3
Accuracy: 0.9982442557811737
Time taken: 6.0472664594650265

Rank: 4
Accuracy: 1.0008959352970124
Time taken: 7.091440510749817

Rank: 5
Accuracy: 1.000096708536148
Time taken: 5.578088855743408

Rank: 6
Accuracy: 1.0009174346923828
Time taken: 6.026907014846802

Rank: 7
Accuracy: 0.9991318643093109
Time taken: 7.623436284065247

Rank: 8
Accuracy: 0.9985550880432129
Time taken: 6.29259238243103

Rank: 9
Accuracy: 0.9991213262081147
Time taken: 5.931344485282898

Rank: 10
Accuracy: 1.000857025384903
Time taken: 6.459856557846069



In [4]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
method = "ring"

for rank in rank_list:
    accuracy, time_taken = benchmarking(
        shape,
        rank=rank,
        method="ring"
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print()

Rank: 2
Accuracy: 1.0015193164348601
Time taken: 6.328765225410462

Rank: 3
Accuracy: 1.0028085827827453
Time taken: 7.8232728242874146

Rank: 4
Accuracy: 0.9991190433502197
Time taken: 6.5175905466079715

Rank: 5
Accuracy: 0.9983286440372467
Time taken: 6.043672394752503

Rank: 6
Accuracy: 1.0012374877929688
Time taken: 6.21276969909668

Rank: 7
Accuracy: 0.9987181007862092
Time taken: 7.223810911178589

Rank: 8
Accuracy: 0.9992061495780945
Time taken: 7.778573894500733

Rank: 9
Accuracy: 1.0012026131153107
Time taken: 7.613510584831237



KeyboardInterrupt: 

In [None]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
method = "cp"

for rank in rank_list:
    accuracy, time_taken = benchmarking(
        shape,
        rank=rank,
        method="ring"
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print()

In [5]:
from time import time
import numpy as np


def benchmarking(
    shape,
    rank=3,
    method="train"
):
    accuracy = []
    time_taken = []
    step_taken = []

    for _ in range(10):
        # 元テンソル
        original_tensor = torch.randn(shape)   

        # テンソル分解
        rank = 3
        method = "train"

        decomp = TensorFactorization(
            tensor=original_tensor, 
            rank=rank, 
            method=method
        )

        s = time()
        decomp.optimize()
        e = time()
        # print(f"Time: {e-s}")

        reconstructed_tensor = decomp.reconstruct()
        mse = torch.mean((original_tensor - reconstructed_tensor) ** 2)
        # print(f"MSE: {mse}")

        accuracy.append(mse.item())
        time_taken.append(e-s)
        step_taken.append(decomp.loss_history["epoch"][-1])

    return accuracy, time_taken, step_taken

In [7]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
method = "train"

for rank in rank_list:
    accuracy, time_taken, step_taken = benchmarking(
        shape,
        rank=rank,
        method=method
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print(f"Step taken: {np.mean(step_taken)}")
    print()

Rank: 2
Accuracy: 1.0020390748977661
Time taken: 6.7535542249679565
Step taken: 2343.8

Rank: 3
Accuracy: 1.0002936720848083
Time taken: 7.102569842338562
Step taken: 2514.6

Rank: 4
Accuracy: 1.0009375631809234
Time taken: 7.201712512969971
Step taken: 2302.7

Rank: 5
Accuracy: 1.000908410549164
Time taken: 7.112463784217835
Step taken: 2621.6

Rank: 6
Accuracy: 1.004307508468628
Time taken: 6.628024053573609
Step taken: 2510.3

Rank: 7
Accuracy: 1.0007879614830018
Time taken: 6.007804942131043
Step taken: 2191.3

Rank: 8
Accuracy: 1.0001854240894317
Time taken: 6.704331183433533
Step taken: 2403.2

Rank: 9
Accuracy: 1.0000901699066163
Time taken: 6.372687268257141
Step taken: 2423.8

Rank: 10
Accuracy: 0.9984685897827148
Time taken: 7.606977391242981
Step taken: 2714.4



In [8]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
method = "ring"

for rank in rank_list:
    accuracy, time_taken, step_taken = benchmarking(
        shape,
        rank=rank,
        method=method
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print(f"Step taken: {np.mean(step_taken)}")
    print()

Rank: 2
Accuracy: 0.9987840533256531
Time taken: 7.539689731597901
Step taken: 2575.2

Rank: 3
Accuracy: 1.0015023589134215
Time taken: 8.104411673545837
Step taken: 2898.8

Rank: 4
Accuracy: 1.000743991136551
Time taken: 6.626764225959778
Step taken: 2535.0

Rank: 5
Accuracy: 0.9986528515815735
Time taken: 7.0723450660705565
Step taken: 2670.5

Rank: 6
Accuracy: 1.0023885428905488
Time taken: 6.656932139396668
Step taken: 2519.3

Rank: 7
Accuracy: 1.000109899044037
Time taken: 6.311275243759155
Step taken: 2376.7

Rank: 8
Accuracy: 0.9974024713039398
Time taken: 6.328418850898743
Step taken: 2371.4

Rank: 9
Accuracy: 0.9978284299373626
Time taken: 7.016313552856445
Step taken: 2661.0

Rank: 10
Accuracy: 1.000124216079712
Time taken: 6.521069741249084
Step taken: 2475.4



In [6]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [2, 3, 4, 5, 6, 7, 8, 9, 10]
method = "cp"

for rank in rank_list:
    accuracy, time_taken, step_taken = benchmarking(
        shape,
        rank=rank,
        method=method
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print(f"Step taken: {np.mean(step_taken)}")
    print()

Rank: 2
Accuracy: 1.0019038379192353
Time taken: 6.192322373390198
Step taken: 2169.6

Rank: 3
Accuracy: 1.00194793343544
Time taken: 7.4224138259887695
Step taken: 2724.9

Rank: 4
Accuracy: 1.0003471314907073
Time taken: 6.919440889358521
Step taken: 2505.5

Rank: 5
Accuracy: 1.0014607787132264
Time taken: 6.8373579502105715
Step taken: 2536.5

Rank: 6
Accuracy: 0.9976076781749725
Time taken: 8.152693128585815
Step taken: 2321.6

Rank: 7
Accuracy: 1.0016708433628083
Time taken: 7.26188292503357
Step taken: 2415.7

Rank: 8
Accuracy: 1.00036159157753
Time taken: 5.317220759391785
Step taken: 1971.3

Rank: 9
Accuracy: 1.0005896508693695
Time taken: 6.72185447216034
Step taken: 2394.5

Rank: 10
Accuracy: 1.0012098729610444
Time taken: 7.401428127288819
Step taken: 2530.2



In [9]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
method = "train"

for rank in rank_list:
    accuracy, time_taken, step_taken = benchmarking(
        shape,
        rank=rank,
        method=method
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print(f"Step taken: {np.mean(step_taken)}")
    print()

Rank: 10
Accuracy: 0.9998419880867004
Time taken: 6.392620706558228
Step taken: 2421.4

Rank: 20
Accuracy: 1.000387966632843
Time taken: 5.656297183036804
Step taken: 2117.5

Rank: 30
Accuracy: 0.9990622818470001
Time taken: 7.381319355964661
Step taken: 2773.5

Rank: 40
Accuracy: 0.9994167149066925
Time taken: 5.943284726142883
Step taken: 2196.2

Rank: 50
Accuracy: 1.000184828042984
Time taken: 6.265983366966248
Step taken: 2231.8

Rank: 60
Accuracy: 0.9982640624046326
Time taken: 6.768451595306397
Step taken: 2328.4

Rank: 70
Accuracy: 0.998406320810318
Time taken: 7.097509336471558
Step taken: 2437.0

Rank: 80
Accuracy: 1.0008914411067962
Time taken: 6.898731017112732
Step taken: 2423.4

Rank: 90
Accuracy: 1.000275868177414
Time taken: 7.612854647636413
Step taken: 2911.2

Rank: 100
Accuracy: 0.9982196629047394
Time taken: 6.03432605266571
Step taken: 2424.0



In [10]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
method = "ring"

for rank in rank_list:
    accuracy, time_taken, step_taken = benchmarking(
        shape,
        rank=rank,
        method=method
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print(f"Step taken: {np.mean(step_taken)}")
    print()

Rank: 10
Accuracy: 1.0001344203948974
Time taken: 6.37033360004425
Step taken: 2561.7

Rank: 20
Accuracy: 1.0027559459209443
Time taken: 5.56131808757782
Step taken: 2232.0

Rank: 30
Accuracy: 0.9990893423557281
Time taken: 6.2405870199203495
Step taken: 2537.8

Rank: 40
Accuracy: 0.9998206496238708
Time taken: 5.925774693489075
Step taken: 2366.2

Rank: 50
Accuracy: 0.9981076419353485
Time taken: 5.402565693855285
Step taken: 2170.9

Rank: 60
Accuracy: 1.001162850856781
Time taken: 5.616523742675781
Step taken: 2305.2

Rank: 70
Accuracy: 0.9991840958595276
Time taken: 5.926495456695557
Step taken: 2376.9

Rank: 80
Accuracy: 0.9990240037441254
Time taken: 5.999073195457458
Step taken: 2410.2

Rank: 90
Accuracy: 0.999521940946579
Time taken: 7.303771543502807
Step taken: 2934.8

Rank: 100
Accuracy: 0.9983904242515564
Time taken: 5.676851296424866
Step taken: 2265.2



In [11]:
# ------------------------------------------------------------------------------
shape = (7,) * 6

rank_list = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
method = "cp"

for rank in rank_list:
    accuracy, time_taken, step_taken = benchmarking(
        shape,
        rank=rank,
        method=method
    )

    print(f"Rank: {rank}")
    print(f"Accuracy: {np.mean(accuracy)}")
    print(f"Time taken: {np.mean(time_taken)}")
    print(f"Step taken: {np.mean(step_taken)}")
    print()

Rank: 10
Accuracy: 0.9997683227062225
Time taken: 6.618064308166504
Step taken: 2647.3

Rank: 20
Accuracy: 1.0010520696640015
Time taken: 7.25626494884491
Step taken: 2911.6

Rank: 30
Accuracy: 0.9988561868667603
Time taken: 6.383202576637268
Step taken: 2555.8

Rank: 40
Accuracy: 1.0018435418605804
Time taken: 5.53312463760376
Step taken: 2232.0

Rank: 50
Accuracy: 1.0001993834972382
Time taken: 5.640861392021179
Step taken: 2237.2

Rank: 60
Accuracy: 1.0012370705604554
Time taken: 5.654657244682312
Step taken: 2256.7

Rank: 70
Accuracy: 1.001906579732895
Time taken: 5.949770569801331
Step taken: 2364.3

Rank: 80
Accuracy: 1.0022528767585754
Time taken: 5.410847902297974
Step taken: 2154.1

Rank: 90
Accuracy: 1.000049215555191
Time taken: 5.961901760101318
Step taken: 2349.9

Rank: 100
Accuracy: 1.0018891990184784
Time taken: 5.722808074951172
Step taken: 2252.6

