In [2]:
import math
import torch
import scipy.special as sc
from scipy.stats import chi2
import random
import numpy as np
import matplotlib.pyplot as plt
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@torch.no_grad()
def _invSqrt(A):
        vals, vecs = torch.linalg.eigh(A)
        return vecs @ torch.diag(1.0 / torch.sqrt(vals)) @ vecs.T

@torch.no_grad()
def rTensorNorm(n, M, Sigma1, Sigma2, Sigma3):
    p1 = Sigma1.shape[0]
    p2 = Sigma2.shape[0]
    p3 = Sigma3.shape[0]
    e1_vals, e1_vecs = torch.linalg.eigh(Sigma1)
    sqrtSigma1 = e1_vecs @ torch.diag(torch.sqrt(e1_vals)) @ e1_vecs.T
    e2_vals, e2_vecs = torch.linalg.eigh(Sigma2)
    sqrtSigma2 = e2_vecs @ torch.diag(torch.sqrt(e2_vals)) @ e2_vecs.T
    e3_vals, e3_vecs = torch.linalg.eigh(Sigma3)
    sqrtSigma3 = e3_vecs @ torch.diag(torch.sqrt(e3_vals)) @ e3_vecs.T
    Z = torch.randn(n, p1, p2, p3, device=device)
    Z = Z.view(n * p1 * p2, p3) @ sqrtSigma3
    Z = Z.view(n, p1, p2, p3)
    Z = Z.permute(0, 1, 3, 2).contiguous()
    Z = Z.view(n * p1 * p3, p2) @ sqrtSigma2
    Z = Z.view(n, p1, p3, p2)
    Z = Z.permute(0, 3, 2, 1).contiguous()
    Z = Z.view(n * p2 * p3, p1) @ sqrtSigma1
    Z = Z.view(n, p2, p3, p1)
    Z = Z.permute(0, 3, 1, 2).contiguous()
    M_expanded = M.unsqueeze(0).expand(n, p1, p2, p3)
    Z = Z + M_expanded
    return Z

@torch.no_grad()
def computeTensorMD(C, invSqrt1, invSqrt2, invSqrt3, returnContributions=False):
    n, p1, p2, p3 = C.shape
    C = C.view(n * p1 * p2, p3) @ invSqrt3
    C = C.view(n, p1, p2, p3)
    C = C.permute(0, 1, 3, 2).contiguous()
    C = C.view(n * p1 * p3, p2) @ invSqrt2
    C = C.view(n, p1, p3, p2)
    C = C.permute(0, 3, 2, 1).contiguous()
    C = C.view(n * p2 * p3, p1) @ invSqrt1
    C = C.view(n, p2, p3, p1)
    C = C.permute(0, 3, 1, 2).contiguous()
    D = C * C
    D = D.view(n, -1)
    TMDsq = D.sum(dim=1)
    if returnContributions:
        D = D.view(n, p1, p2, p3)
        return TMDsq, D
    else:
        return TMDsq

@torch.no_grad()
def updateOneCov(C, invSqrt2, invSqrt3):
    n, p1, p2, p3 = C.shape
    C = C.view(n * p1 * p2, p3) @ invSqrt3
    C = C.view(n, p1, p2, p3)
    C = C.permute(0, 1, 3, 2).contiguous()
    C = C.view(n * p1 * p3, p2) @ invSqrt2
    C = C.view(n, p1, p3, p2)
    C = C.permute(1, 0, 3, 2).contiguous()
    C = C.view(p1, n * p2 * p3)
    Sigma1 = (C @ C.T) / (n * p2 * p3)
    return Sigma1





@torch.no_grad()
def flipFlopMLE(C1, C2, C3,
                Sigma1init, Sigma2init, Sigma3init,
                invSqrt2init, invSqrt3init,
                maxIter,
                tol):

    old1 = Sigma1init
    old2 = Sigma2init
    old3 = Sigma3init
    invSqrt2 = invSqrt2init
    invSqrt3 = invSqrt3init
    for it in range(maxIter):
        Sigma1 = updateOneCov(C1, invSqrt2, invSqrt3)
        invSqrt1 = _invSqrt(Sigma1)
        Sigma2 = updateOneCov(C2, invSqrt3, invSqrt1)
        invSqrt2 = _invSqrt(Sigma2)
        Sigma3 = updateOneCov(C3, invSqrt1, invSqrt2)
        d11_1 = Sigma1[0, 0]
        d11_2 = Sigma2[0, 0]
        Sigma1 = Sigma1 / d11_1
        Sigma2 = Sigma2 / d11_2
        Sigma3 = Sigma3 * (d11_1 * d11_2)
        if it == maxIter - 1:
            break
        frobDiff = torch.sum((Sigma1 - old1) ** 2) \
                 + torch.sum((Sigma2 - old2) ** 2) \
                 + torch.sum((Sigma3 - old3) ** 2)
        if frobDiff < tol:
            break
        invSqrt3 = _invSqrt(Sigma3)
        old1 = Sigma1
        old2 = Sigma2
        old3 = Sigma3
    return {
        "Sigma1": Sigma1,
        "Sigma2": Sigma2,
        "Sigma3": Sigma3,
        "invSqrt1": invSqrt1,
        "invSqrt2": invSqrt2
    }

@torch.no_grad()
def cStep(X, C, Sigma1, Sigma2, Sigma3,
          invSqrt1, invSqrt2, invSqrt3,
          alpha,
          maxIterC,
          tolC,
          maxIterFF,
          tolFF):
    def _ld(S1, S2, S3):
        return (p2 * p3) * torch.logdet(S1) \
             + (p1 * p3) * torch.logdet(S2) \
             + (p1 * p2) * torch.logdet(S3)
    n, p1, p2, p3 = X.shape
    h = int(math.floor(alpha * n))
    ldOld = _ld(Sigma1, Sigma2, Sigma3)
    for it in range(maxIterC):
        TMDsAll = computeTensorMD(C, invSqrt1, invSqrt2, invSqrt3)
        sortedIdx = torch.argsort(TMDsAll)
        subsetIndices = sortedIdx[:h]
        Xsub = X[subsetIndices]
        subMean = Xsub.mean(dim=0)
        C1 = Xsub - subMean
        C2 = C1.permute(0, 2, 3, 1).contiguous()
        C3 = C1.permute(0, 3, 1, 2).contiguous()
        initFF = flipFlopMLE(
            C1, C2, C3,
            Sigma1, Sigma2, Sigma3,
            invSqrt2, invSqrt3,
            maxIterFF,
            tolFF
        )
        Sigma1 = initFF["Sigma1"]
        Sigma2 = initFF["Sigma2"]
        Sigma3 = initFF["Sigma3"]
        invSqrt1 = initFF["invSqrt1"]
        invSqrt2 = initFF["invSqrt2"]
        ldNew = _ld(Sigma1, Sigma2, Sigma3)
        if it == maxIterC - 1 or abs(ldNew - ldOld) < tolC:
            break
        ldOld = ldNew
        C = X - subMean
        vals, vecs = torch.linalg.eigh(Sigma3)
        invSqrt3 = vecs @ torch.diag(1.0 / torch.sqrt(vals)) @ vecs.T
    return {
        "Sigma1": Sigma1,
        "Sigma2": Sigma2,
        "Sigma3": Sigma3,
        "invSqrt1": invSqrt1,
        "invSqrt2": invSqrt2,
        "subsetIndices": subsetIndices,
        "TMDsAll": TMDsAll,
        "ld": ldNew
    }

@torch.no_grad()
def tmcd(X,
         alpha,
         nSubsets,
         nBest,
         maxIterCshort,
         maxIterFFshort,
         maxIterCfull,
         maxIterFFfull,
         tolC,
         tolFF,
         beta):

    n, p1, p2, p3 = X.shape
    s = int(math.ceil(p1/(p2*p3) + p2/(p1*p3) + p3/(p1*p2))) + 2
    allSubsets = []
    for _ in range(nSubsets):
        allSubsets.append(torch.randperm(n)[:s].to(device))
    shortResults = []
    for i in range(nSubsets):
        idx = allSubsets[i]
        xSub = X[idx]
        subMean = xSub.mean(dim=0)
        C1 = xSub - subMean
        C2 = C1.permute(0, 2, 3, 1).contiguous()
        C3 = C1.permute(0, 3, 1, 2).contiguous()
        initSig1 = torch.eye(p1, device=device)
        initSig2 = torch.eye(p2, device=device)
        initSig3 = torch.eye(p3, device=device)
        initInvSqrt2 = torch.eye(p2, device=device)
        initInvSqrt3 = torch.eye(p3, device=device)
        shortMLE = flipFlopMLE(
            C1, C2, C3,
            initSig1, initSig2, initSig3,
            initInvSqrt2, initInvSqrt3,
            maxIterFFshort,
            tolFF
        )
        C = X - subMean
        curInvSqrt3 = _invSqrt(shortMLE["Sigma3"])
        shortRes = cStep(
            X, C,
            shortMLE["Sigma1"],
            shortMLE["Sigma2"],
            shortMLE["Sigma3"],
            shortMLE["invSqrt1"],
            shortMLE["invSqrt2"],
            curInvSqrt3,
            alpha,
            maxIterCshort,
            tolC,
            maxIterFFshort,
            tolFF
        )
        shortResults.append(shortRes)
    allLd = torch.tensor([res["ld"] for res in shortResults], device=device)
    rankLd = torch.argsort(allLd)
    topIdx = rankLd[:min(nBest, nSubsets)]
    fullResults = []
    for j in range(len(topIdx)):
        chosen = shortResults[topIdx[j]]
        xSub = X[chosen["subsetIndices"]]
        subMean = xSub.mean(dim=0)
        C = X - subMean
        invSqrt3 = _invSqrt(chosen["Sigma3"])
        fullRes = cStep(
            X, C,
            chosen["Sigma1"],
            chosen["Sigma2"],
            chosen["Sigma3"],
            chosen["invSqrt1"],
            chosen["invSqrt2"],
            invSqrt3,
            alpha,
            maxIterCfull,
            tolC,
            maxIterFFfull,
            tolFF
        )
        fullResults.append(fullRes)
    allLdFull = torch.tensor([r["ld"] for r in fullResults], device=device)
    bestFullIdx = torch.argmin(allLdFull)
    bestRaw = fullResults[bestFullIdx]
    dfMain = p1 * p2 * p3
    dfPlus = dfMain + 2
    chiAlpha = chi2.ppf(alpha, dfMain)
    cdfVal = chi2.cdf(chiAlpha, dfPlus)
    gammaAlpha = alpha / cdfVal
    S1 = bestRaw["Sigma1"]
    S2 = bestRaw["Sigma2"]
    S3 = bestRaw["Sigma3"] * gammaAlpha
    invSqrt2 = bestRaw["invSqrt2"]
    vals3, vecs3 = torch.linalg.eigh(S3)
    invSqrt3 = vecs3 @ torch.diag(1.0 / torch.sqrt(vals3)) @ vecs3.T
    TMDsqAll = bestRaw["TMDsAll"]
    cutoff = chi2.ppf(beta, dfMain)
    goodSet = torch.where((TMDsqAll / gammaAlpha) < cutoff)[0]
    finalGood = torch.unique(torch.cat([bestRaw["subsetIndices"], goodSet]))
    outliers = torch.tensor(list(set(range(n)) - set(finalGood.tolist())), device=device)
    Xgood = X[finalGood]
    M = Xgood.mean(dim=0)
    alphaHat = float(len(finalGood)) / n
    C1 = Xgood - M
    C2 = C1.permute(0, 2, 3, 1).contiguous()
    C3 = C1.permute(0, 3, 1, 2).contiguous()
    ffFinal = flipFlopMLE(
        C1, C2, C3,
        S1, S2, S3,
        invSqrt2, invSqrt3,
        maxIterFFfull,
        tolFF
    )
    S1 = ffFinal["Sigma1"]
    S2 = ffFinal["Sigma2"]
    S3 = ffFinal["Sigma3"]
    chiAlphaHat = chi2.ppf(alphaHat, dfMain)
    cdfValHat = chi2.cdf(chiAlphaHat, dfPlus)
    gammaAlphaHat = alphaHat / cdfValHat
    S3 = S3 * gammaAlphaHat
    return {
        "M": M,
        "Sigma1": S1,
        "Sigma2": S2,
        "Sigma3": S3,
        "outliers": outliers,
        "finalGood": finalGood
    }



In [4]:



def genPositiveDefMat(dim, rangeVar=(0.1, 1.0)):
    eta = 1.0
    def _rcor_onion(d):
        if d == 1:
            return torch.ones(1, 1, dtype=torch.double)
        if d == 2:
            rho = 2 * torch.distributions.Beta(eta, eta).sample() - 1
            return torch.tensor([[1.0, rho], [rho, 1.0]], dtype=torch.double)
        beta = eta + (d - 2) / 2
        r12 = 2 * torch.distributions.Beta(beta, beta).sample() - 1
        R = torch.tensor([[1.0, r12], [r12, 1.0]], dtype=torch.double)
        for m in range(2, d):
            beta -= 0.5
            y = torch.distributions.Beta(m / 2, beta).sample()
            z = torch.randn(m, dtype=torch.double)
            z = z / torch.linalg.norm(z)
            w = torch.sqrt(y) * z
            q = torch.mv(torch.linalg.cholesky(R), w)
            R = torch.block_diag(R, torch.ones(1, dtype=torch.double))
            R[:-1, -1] = q
            R[-1, :-1] = q
        return R
    low, high = rangeVar
    variances = (high - low) * torch.rand(dim, dtype=torch.double) + low
    D = torch.diag(torch.sqrt(variances))
    Sigma = D @ _rcor_onion(dim) @ D
    return Sigma.to(torch.get_default_dtype())

def KL(Sigma1_1, Sigma2_1, Sigma3_1,
       Sigma1_2, Sigma2_2, Sigma3_2):

    p1 = Sigma1_1.shape[0]
    p2 = Sigma2_1.shape[0]
    p3 = Sigma3_1.shape[0]


    A3 = torch.linalg.solve(Sigma3_2, Sigma3_1)
    A2 = torch.linalg.solve(Sigma2_2, Sigma2_1)
    A1 = torch.linalg.solve(Sigma1_2, Sigma1_1)

    tr3 = torch.trace(A3)
    tr2 = torch.trace(A2)
    tr1 = torch.trace(A1)

    det3 = torch.det(A3)
    det2 = torch.det(A2)
    det1 = torch.det(A1)

    val = 0.5 * (
        (tr3 * tr2 * tr1)
        - (p1 * p2) * math.log(det3)
        - (p1 * p3) * math.log(det2)
        - (p2 * p3) * math.log(det1)
        - (p1 * p2 * p3)
    )
    return val.item()



def run_tmcd_check(n,
                   p1, p2, p3,
                   outlier_fraction,
                   method_outliers,
                   infection_fraction=None,
                   infection_range=None,
                   outlier_shift=None,
                   alpha=0.6,
                   nSubsets=100,
                   nBest=10,
                   maxIterCshort=2,
                   maxIterFFshort=2,
                   maxIterCfull=100,
                   maxIterFFfull=100,
                   tolC=1e-4,
                   tolFF=1e-3,
                   beta=0.99,
                   seed=None):

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_runs = 1

    Sigma1_true = genPositiveDefMat(p1, rangeVar=(0.1, 1.0)).to(device)


    Sigma2_true = 0.5 * torch.ones(p2, p2, device=device)
    Sigma2_true.fill_diagonal_(0.7)


    Sigma3_true = torch.zeros(p3, p3, device=device)
    for i in range(p3):
        for j in range(p3):
            if i == j:
                Sigma3_true[i, j] = 0.7
            else:
                Sigma3_true[i, j] = 0.5 ** abs(i - j)


    Mu_true = torch.randn(p1, p2, p3, device=device) + 1.0

    X_clean = rTensorNorm(n, Mu_true, Sigma1_true, Sigma2_true, Sigma3_true)

    X_outliers = X_clean.clone()
    n_outliers = int(math.floor(n * outlier_fraction))
    outlier_indices = []

    if n_outliers > 0:
        outlier_indices = random.sample(range(n), n_outliers)

        if method_outliers == "method1":

            if outlier_shift is None:
                raise ValueError("For method1, you must provide outlier_shift.")
            Mu_shifted = Mu_true + outlier_shift
            X_outlier_sample = rTensorNorm(n_outliers, Mu_shifted,
                                           Sigma1_true, Sigma2_true, Sigma3_true)
            for i, idx in enumerate(outlier_indices):
                X_outliers[idx] = X_outlier_sample[i]

        elif method_outliers == "method2":

            if (infection_fraction is None) or (infection_range is None):
                raise ValueError("For method2, must provide infection_fraction and infection_range.")
            total_cells = p1 * p2 * p3
            cells_to_infect = int(math.floor(infection_fraction * total_cells))

            for idx in outlier_indices:

                infected_cells_lin = random.sample(range(total_cells), cells_to_infect)
                slice_ijk = X_outliers[idx]
                for lin in infected_cells_lin:
                    i = lin // (p2 * p3)
                    rem = lin % (p2 * p3)
                    j = rem // p3
                    k = rem % p3

                    val = random.uniform(infection_range[0], infection_range[1])
                    slice_ijk[i, j, k] += val

        else:
            raise ValueError("Unknown method_outliers. Must be 'method1' or 'method2'.")



    elapsed = []

    final_tmcd_result = None


    rng_state = torch.get_rng_state().clone()
    np_rng_state = np.random.get_state()
    py_rng_state = random.getstate()

    for i in range(n_runs):

        torch.set_rng_state(rng_state.clone())
        np.random.set_state(np_rng_state)
        random.setstate(py_rng_state)

        start_time = time.time()
        tmcd_result = tmcd(
            X_outliers,
            alpha=alpha,
            nSubsets=nSubsets,
            nBest=nBest,
            maxIterCshort=maxIterCshort,
            maxIterFFshort=maxIterFFshort,
            maxIterCfull=maxIterCfull,
            maxIterFFfull=maxIterFFfull,
            tolC=tolC,
            tolFF=tolFF,
            beta=beta
        )
        end_time = time.time()
        elapsed.append(end_time - start_time)

        final_tmcd_result = tmcd_result

    elapsed_sorted = sorted(elapsed)
    print("\nElapsed times (seconds) from fastest to slowest:")
    print(" - ".join(f"{e:.3f}" for e in elapsed_sorted))


    s1_true = Sigma1_true[0, 0].item()
    s2_true = Sigma2_true[0, 0].item()
    Sigma1_true_scaled = Sigma1_true / s1_true
    Sigma2_true_scaled = Sigma2_true / s2_true
    Sigma3_true_scaled = Sigma3_true * (s1_true * s2_true)

    S1_est = final_tmcd_result["Sigma1"]
    S2_est = final_tmcd_result["Sigma2"]
    S3_est = final_tmcd_result["Sigma3"]
    Mu_est = final_tmcd_result["M"]

    Mu_diff = Mu_est - Mu_true
    Mu_diff_norm = Mu_diff.norm().item()

    Sigma1_diff = S1_est - Sigma1_true_scaled
    Sigma2_diff = S2_est - Sigma2_true_scaled
    Sigma3_diff = S3_est - Sigma3_true_scaled

    Sigma1_diff_norm = Sigma1_diff.norm().item()
    Sigma2_diff_norm = Sigma2_diff.norm().item()
    Sigma3_diff_norm = Sigma3_diff.norm().item()

    print("\nDifference norms:")
    print(f"  Mu diff norm        : {Mu_diff_norm:.4f}")
    print(f"  Sigma1 diff norm    : {Sigma1_diff_norm:.4f}")
    print(f"  Sigma2 diff norm    : {Sigma2_diff_norm:.4f}")
    print(f"  Sigma3 diff norm    : {Sigma3_diff_norm:.4f}")


    kl_value = KL(Sigma1_true_scaled, Sigma2_true_scaled, Sigma3_true_scaled,
                  S1_est, S2_est, S3_est)
    print(f"\nKL divergence         : {kl_value:.4f}")


    flagged_outliers = final_tmcd_result["outliers"].cpu().numpy()
    actual_outliers  = np.array(outlier_indices, dtype=int)

    n_flagged = flagged_outliers.size
    flagged_set = set(flagged_outliers.tolist())
    actual_set  = set(actual_outliers.tolist())
    n_correct   = len(flagged_set.intersection(actual_set))

    if n_flagged == 0:
        precision = 0.0
    else:
        precision = n_correct / n_flagged

    if actual_outliers.size == 0:
        recall = 1.0
    else:
        recall = n_correct / actual_outliers.size

    if (precision + recall) == 0:
        F_score = 0.0
    else:
        F_score = 2.0 * precision * recall / (precision + recall)

    print("\nOutlier detection summary:")
    print(f"  Actual number of outliers : {actual_outliers.size}")
    print(f"  Flagged as outliers       : {n_flagged}")
    print(f"  Correctly flagged         : {n_correct}")
    print(f"  Precision                 : {precision:.3f}")
    print(f"  Recall                    : {recall:.3f}")
    print(f"  F-score                   : {F_score:.3f}")


    M_full = X_outliers.mean(dim=0)
    C1_full = X_outliers - M_full
    C2_full = C1_full.permute(0, 2, 3, 1).contiguous()
    C3_full = C1_full.permute(0, 3, 1, 2).contiguous()

    mle_full = flipFlopMLE(
        C1_full, C2_full, C3_full,
        torch.eye(p1, device=device),
        torch.eye(p2, device=device),
        torch.eye(p3, device=device),
        torch.eye(p2, device=device),
        torch.eye(p3, device=device),
        maxIter=maxIterFFfull,
        tol=tolFF
    )
    S1_full = mle_full["Sigma1"]
    S2_full = mle_full["Sigma2"]
    S3_full = mle_full["Sigma3"]

    kl_full = KL(Sigma1_true_scaled, Sigma2_true_scaled, Sigma3_true_scaled,
                 S1_full, S2_full, S3_full)
    print(f"\nKL divergence (True vs. full-sample MLE): {kl_full:.4f}\n")

    return {
        "Sigma1_true_scaled": Sigma1_true_scaled,
        "Sigma2_true_scaled": Sigma2_true_scaled,
        "Sigma3_true_scaled": Sigma3_true_scaled,
        "Sigma1_est": S1_est,
        "Sigma2_est": S2_est,
        "Sigma3_est": S3_est,
        "Mu_true": Mu_true,
        "Mu_est": Mu_est,
        "outliers_flagged": flagged_outliers,
        "outliers_actual": actual_outliers,
        "Mu_diff_norm": Mu_diff_norm,
        "Sigma1_diff_norm": Sigma1_diff_norm,
        "Sigma2_diff_norm": Sigma2_diff_norm,
        "Sigma3_diff_norm": Sigma3_diff_norm,
        "KL_value": kl_value,
        "precision": precision,
        "recall": recall,
        "F_score": F_score
    }



def compareSigmas(res, digits=3, col_gap=6):


    def format_cell(val, w):
        return f"{val:.{digits}f}".rjust(w)

    def printSigma(true_mat, est_mat, name):
        print(f"\n{name} (scaled): Left = True, Right = Estimated")
        p = true_mat.shape[0]
        all_vals = torch.cat((true_mat.flatten(), est_mat.flatten()))
        test_strs = [f"{v:.{digits}f}" for v in all_vals]
        cell_width = max(len(s) for s in test_strs)
        gap = " " * col_gap

        for i in range(p):
            row_strs = []
            for j in range(p):
                left_str = format_cell(true_mat[i, j].item(), cell_width)
                right_str = format_cell(est_mat[i, j].item(), cell_width)
                row_strs.append(f"{left_str} | {right_str}")
            print(gap.join(row_strs))

    printSigma(res["Sigma1_true_scaled"], res["Sigma1_est"], "Sigma1")
    printSigma(res["Sigma2_true_scaled"], res["Sigma2_est"], "Sigma2")
    printSigma(res["Sigma3_true_scaled"], res["Sigma3_est"], "Sigma3")






In [None]:
print("PyTorch version :", torch.__version__)
print("CUDA version    :", torch.version.cuda)
print("CUDA available  :", torch.cuda.is_available())

PyTorch version : 2.6.0+cu124
CUDA version    : 12.4
CUDA available  : True


In [10]:
seed = 101
params = dict(
                n=350,
                p1=200,
                p2=250,
                p3=50,
                outlier_shift=0.5,
                outlier_fraction=0.25,
                method_outliers="method1",
                alpha=0.6,
                nSubsets=500,
                nBest=10,
                maxIterCshort=2,
                maxIterFFshort=2,
                maxIterCfull=100,
                maxIterFFfull=100,
                tolC=1e-4,
                tolFF=1e-3,
                beta=0.99,
                seed=seed
)


print("\n--- Running TMCD Check ---")
res = run_tmcd_check(**params)



--- Running TMCD Check ---

Elapsed times (seconds) from fastest to slowest:
444.616

Difference norms:
  Mu diff norm        : 57.3175
  Sigma1 diff norm    : 0.1113
  Sigma2 diff norm    : 0.1280
  Sigma3 diff norm    : 0.0195

KL divergence         : 138.3750

Outlier detection summary:
  Actual number of outliers : 87
  Flagged as outliers       : 140
  Correctly flagged         : 87
  Precision                 : 0.621
  Recall                    : 1.000
  F-score                   : 0.767

KL divergence (True vs. full-sample MLE): -inf



In [9]:
import gc

for var in list(globals().values()):
    if hasattr(var, 'iterator') and hasattr(var, 'num_workers'):
        var._iterator._shutdown_workers()
for k in list(globals().keys()):
    obj = globals()[k]
    if torch.is_tensor(obj) or isinstance(obj, torch.nn.Module):
        del globals()[k]


gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
print("allocated :", torch.cuda.memory_allocated() / 1024**2, "MB")
print("reserved  :", torch.cuda.memory_reserved()  / 1024**2, "MB")


allocated : 8.125 MB
reserved  : 20.0 MB


In [8]:
del tensors_and_models
gc.collect()
torch.cuda.empty_cache()


NameError: name 'tensors_and_models' is not defined

In [3]:
import sys, subprocess, importlib.util, urllib.request
!pip install rdata tqdm

import rdata, torch, numpy as np
from pathlib import Path

# ---------- load RData ----------
url  = "https://wis.kuleuven.be/stat/robust/Programs/DO/do-video-data-rdata"
fbin = Path("do-video-data.rdata")
if not fbin.exists():
    urllib.request.urlretrieve(url, fbin)

video_py = rdata.conversion.convert(rdata.parser.parse_file(fbin))["Video"]
print("Video shape:", video_py.shape)      # (633, 128, 160, 3)

Collecting rdata
  Downloading rdata-0.11.2-py3-none-any.whl.metadata (11 kB)
Downloading rdata-0.11.2-py3-none-any.whl (46 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/46.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.5/46.5 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdata
Successfully installed rdata-0.11.2
Video shape: (633, 128, 160, 3)


In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_full = torch.tensor(video_py, dtype=torch.float32, device=device).contiguous()
n, p1, p2, p3 = X_full.shape
assert p3 == 3, "RGB expected"

seed = 102
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)

start = time.perf_counter()
tmcd_res = tmcd(
    X_full,
    alpha          = 0.75,
    nSubsets       = 10,
    nBest          = 10,
    maxIterCshort  = 2,
    maxIterFFshort = 2,
    maxIterCfull   = 25,
    maxIterFFfull  = 100,
    tolC           = 1e-4,
    tolFF          = 1e-3,
    beta           = 0.9999
)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start

print(f"TMCD done – good frames: {tmcd_res['finalGood'].numel()}")
print(f"Elapsed time: {elapsed:.2f}s")
print("Outlier indices:", sorted(tmcd_res["outliers"].cpu().tolist()))

# number of frames *not* flagged as outliers from index 483 onward
start      = 483
n_frames   = X_full.shape[0]               # 633
outliers   = set(tmcd_res["outliers"].cpu().tolist())

missing = [i for i in range(start, n_frames) if i not in outliers]

print(f"From frame {start} to {n_frames-1}: "
      f"{len(missing)} frames are NOT in the outlier list.")
print("Indices:", missing)



TMCD done – good frames: 474
Elapsed time: 4.95s
Outlier indices: [409, 410, 413, 421, 443, 447, 456, 458, 459, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632]
From frame 483 to 632: 0 frames are NOT in the outlier list.
Indices: []


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_full = torch.tensor(video_py, dtype=torch.float32, device=device).contiguous()
n, p1, p2, p3 = X_full.shape
assert p3 == 3, "RGB expected"

seed = 102
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)

start = time.perf_counter()
tmcd_res = tmcd(
    X_full,
    alpha          = 0.75,
    nSubsets       = 500,
    nBest          = 20,
    maxIterCshort  = 2,
    maxIterFFshort = 2,
    maxIterCfull   = 100,
    maxIterFFfull  = 100,
    tolC           = 1e-4,
    tolFF          = 1e-3,
    beta           = 0.9999
)
if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.perf_counter() - start

print(f"TMCD done – good frames: {tmcd_res['finalGood'].numel()}")
print(f"Elapsed time: {elapsed:.2f}s")
print("Outlier indices:", sorted(tmcd_res["outliers"].cpu().tolist()))

# number of frames *not* flagged as outliers from index 483 onward
start      = 483
n_frames   = X_full.shape[0]               # 633
outliers   = set(tmcd_res["outliers"].cpu().tolist())

missing = [i for i in range(start, n_frames) if i not in outliers]

print(f"From frame {start} to {n_frames-1}: "
      f"{len(missing)} frames are NOT in the outlier list.")
print("Indices:", missing)



TMCD done – good frames: 474
Elapsed time: 67.78s
Outlier indices: [409, 410, 413, 421, 443, 447, 450, 456, 458, 459, 483, 484, 485, 486, 487, 488, 489, 490, 491, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632]
From frame 483 to 632: 1 frames are NOT in the outlier list.
Indices: [492]
