In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import numqi
from utils import *

# Grid resolution on the sphere. Every grid point later costs one numerical
# optimisation, so these two numbers dominate the notebook's run time.
N_POLAR = 20    # samples of the polar angle phi in [0, pi]
N_AZIMUTH = 40  # samples of the azimuthal angle theta in [0, 2*pi]

phi = np.linspace(0, np.pi, N_POLAR)
theta = np.linspace(0, 2 * np.pi, N_AZIMUTH)
# Default 'xy' indexing: both grids end up with shape (N_AZIMUTH, N_POLAR).
phi, theta = np.meshgrid(phi, theta)

# Spherical -> Cartesian, so every (a, b, c) satisfies a**2 + b**2 + c**2 == 1.
# These arrays are the pointwise arguments for the GME evaluations below.
a = np.sin(phi) * np.cos(theta)
b = np.sin(phi) * np.sin(theta)
c = np.cos(phi)

def gme_numerical(a, b, c):
    """Numerically estimate the GME of the W-type state built from ``(a, b, c)``.

    A rank-1 ``DetectCanonicalPolyadicRankModel`` over a 2 x 2 x 2 tensor
    (three qubits) is fitted to ``W_type_state(a, b, c)`` with
    ``numqi.optimize.minimize``. The optimiser's final loss is returned.
    The notebook compares this value against the closed-form
    ``W_type_state_gme`` at the same point.

    Parameters
    ----------
    a, b, c
        Coefficients forwarded unchanged to ``W_type_state``.

    Returns
    -------
    The ``.fun`` attribute of the result from ``numqi.optimize.minimize``
    (presumably a scalar float; confirm against the numqi docs).
    """
    # Three local dimensions of 2. rank=1 restricts the fit to a single product term.
    cp_model = numqi.matrix_space.DetectCanonicalPolyadicRankModel([2, 2, 2], rank=1)
    cp_model.set_target(W_type_state(a, b, c))
    # Very tight tolerances so that optimiser error stays far below the
    # analytical-vs-numerical differences being plotted.
    optimum = numqi.optimize.minimize(
        cp_model,
        theta0='uniform',
        tol=1e-14,
        num_repeat=3,
        print_every_round=0,
        early_stop_threshold=1e-14,
    )
    return optimum.fun

# Evaluate the closed-form and the optimised GME at every grid point.
# otypes=[float] is passed explicitly. Without it, np.vectorize calls the
# wrapped function once on the first element just to infer the output dtype
# (cache=False by default, so that result is thrown away). That has two costs:
# (a) the first numerical optimisation runs twice;
# (b) if the first result were an int, the whole output array would get an
#     integer dtype and every later value would be silently truncated.
W_type_state_gme_analytical = np.vectorize(W_type_state_gme, otypes=[float])
f_values_analytical = W_type_state_gme_analytical(a, b, c)

W_type_state_gme_numerical = np.vectorize(gme_numerical, otypes=[float])
f_values_numerical = W_type_state_gme_numerical(a, b, c)

# Pointwise absolute disagreement between the two methods, same shape as the grid.
error = np.abs(f_values_analytical - f_values_numerical)



In [None]:
# Worst-case analytical-vs-numerical disagreement over the whole sphere grid.
print(error.max())

In [None]:
import plotly.express as px
def _add_sphere_panel(figure, col, values, colorscale, colorbar_x, trace_name):
    """Add a surface over the global (a, b, c) grid, coloured by ``values``, to row 1 / column ``col``."""
    figure.add_trace(
        go.Surface(
            x=a,
            y=b,
            z=c,
            surfacecolor=values,
            colorscale=colorscale,
            colorbar=dict(x=colorbar_x, xpad=0),
            showscale=True,
            name=trace_name,
        ),
        row=1,
        col=col,
    )


# Two side-by-side 3D scenes of the same grid, coloured by different quantities.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Analytical result", "Absolute error"),
    specs=[[{'type': 'surface'}, {'type': 'surface'}]],
)

# Left: analytical values. Its colourbar sits at x=0.45, between the two panels.
_add_sphere_panel(fig, 1, f_values_analytical, 'balance', 0.45, 'analytical_result')
# Right: pointwise |analytical - numerical|, colourbar at the right edge.
_add_sphere_panel(fig, 2, error, 'Viridis', 1, 'absolute_error')

# Shared axis titles, tick size and camera for both scenes.
# NOTE(review): c = cos(phi) spans [-1, 1], so the z range [-0.95, 1] clips the
# region near c = -1. Confirm this is intended.
fig.update_scenes(
    xaxis=dict(title_text="a", tickfont=dict(size=10)),
    yaxis=dict(title_text="b", tickfont=dict(size=10)),
    zaxis=dict(title_text="c", tickfont=dict(size=10), range=[-0.95, 1]),
    camera_eye=dict(x=1.5, y=1.5, z=1.5),
)

# Fixed figure size. Margins are left/right/bottom/top, plus padding.
fig.update_layout(
    autosize=False,
    width=900,
    height=500,
    margin=dict(l=50, r=50, b=100, t=100, pad=10),
)

fig.show()