In [1]:
import numpy as np
import pandas as pd
import plotly.express as px

In [2]:
def tt_solve_time_complexity(s: int, r: int, I:int, d: int, num_iterations: int) -> float:
    t1 = s**3 * r**2 * I**2  # prepare local system
    t2 = s**3 * r * I**3  # in-between contractions
    t3 = s**6 * I**6  # direct solve of local system
    t4 = s**3 * I**3  # SVD truncation
    return num_iterations * d * (t1 + t2 + t3 + t4)
    

In [3]:
def cg_time_complexity(z: int, cond_num: int) -> float:
    return z * np.sqrt(cond_num)

In [82]:
# Setting 1: we can keep s, number of iterations, d are log(n)
def tt_solve_time_complexity_log_assumption(r: int, I:int, n:int) -> float:
    return tt_solve_time_complexity(s=np.log(n), r=r, I=I, d=np.log(n), num_iterations=np.log(n))

def tt_solve_time_complexity_full_log_assumption(n:int) -> float:
    return tt_solve_time_complexity_log_assumption(r=np.log(n), I=np.log(n), n=n)

# assume that TTM-rank and mode sizes are log(n)^2 big
def tt_solve_time_complexity_full_log2_assumption(n:int) -> float:
    return tt_solve_time_complexity_log_assumption(r=np.log(n)**2, I=np.log(n)**2, n=n)

In [104]:
xs = np.linspace(10, int(1e+10), int(1e+4))
xs.shape

(10000,)

In [105]:
tt1_ys = [tt_solve_time_complexity_full_log_assumption(x) for x in xs]
tt2_ys = [tt_solve_time_complexity_full_log2_assumption(x) for x in xs]  
cg_cond1_ys = [cg_time_complexity(x, cond_num=x) for x in xs]  # cond num is O(n)
cg_cond2_ys = [cg_time_complexity(x, cond_num=x**5) for x in xs]  # cond num is O(n^5)
cg_cond3_ys = [cg_time_complexity(x, cond_num=x**10) for x in xs]  # cond num is O(n^10)
# cg_cond4_ys = [cg_time_complexity(x, cond_num=x**20) for x in xs]  # cond num is O(n^20)

In [122]:
df = pd.DataFrame(np.transpose(np.stack([xs, tt1_ys, tt2_ys, cg_cond1_ys, cg_cond2_ys, cg_cond3_ys])), columns=["x", "$\\text{TT-solve, all } O(\log(n))$", "$\\text{TT-solve}, I=r=O(\log(n)^2)$", "$\\text{CG}, z = \kappa = O(n)$", "$\\text{CG}, z = O(n), \kappa = O(n^5)$", "$\\text{CG}, z = O(n), \kappa = O(n^{10})$"])

In [123]:
melted_df = df.melt(id_vars=["x"])
melted_df.head()

Unnamed: 0,x,variable,value
0,10.0,"$\text{TT-solve, all } O(\log(n))$",122195.0
1,1000110.0,"$\text{TT-solve, all } O(\log(n))$",9229695000000000.0
2,2000210.0,"$\text{TT-solve, all } O(\log(n))$",1.831583e+16
3,3000310.0,"$\text{TT-solve, all } O(\log(n))$",2.694085e+16
4,4000410.0,"$\text{TT-solve, all } O(\log(n))$",3.520256e+16


In [124]:
line_styles = {
    "$\\text{TT-solve, all } O(\log(n))$": 'dot',
    "$\\text{TT-solve}, I=r=O(\log(n)^2)$": 'dashdot',
    "$\\text{CG}, z = \kappa = O(n)$": 'solid',
    "$\\text{CG}, z = O(n), \kappa = O(n^5)$": 'dash',
    "$\\text{CG}, z = O(n), \kappa = O(n^{10})$": 'longdash',
}


In [126]:
fig = px.line(melted_df, x="x", y="value", color="variable", line_dash='variable', 
              line_dash_map=line_styles,
              log_y=True, log_x=True,
                  labels={
                      "x": "Matrix size (n)",
                      "value": "Estimated runtime (FLOPs)",
                      "variable": "Solver, assumptions",
                      "CG (n^5)": "CG (n^6)",
                  })
fig.update_layout(
        title={
            'text': "Runtime comparison of TT-solve and CG",
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        plot_bgcolor='white',  # Plot area background color
        paper_bgcolor='white',  # Entire figure background color
        font=dict(color='black'),  # Font color
    )
fig.show()
fig.write_image("plots/tt_solve_vs_cg_runtime_estimates.pdf")