In [1]:
import numpy as np
import time
import plotly.express as px 
import plotly.graph_objs as go

In [2]:
COLORS = px.colors.qualitative.Plotly
# blue, red, green, purple, cyan, pink, ...

LINE_STYLES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"]

SYMBOLS = [
    "circle",
    "square",
    "star",
    "x",
    "triangle-up",
    "pentagon",
    "cross",
]

In [67]:
n = 100

iter_list = list(range(n + 1))

# Constant eta then (almost) 1
eta_cst = 0.95

eta_list = [eta_cst]
lmbda_list = [0.]

step_size_list = [eta_cst]
beta_list = [0.]

for i in range(n):
    if step_size_list[-1] <= 0.5 or beta_list[-1] >= 0.5:
        print(i)
        eta_cst = 1.
    
    # Update eta to its current constant value
    eta_k = eta_list[-1]
    eta_k_plus_1 = eta_cst
    eta_list.append(eta_k_plus_1)
    
    # Update lmbda
    lmbda_k = lmbda_list[-1]
    lmbda_k_plus_1 = eta_k * (1. + lmbda_k - eta_k) / eta_k_plus_1
    lmbda_list.append(lmbda_k_plus_1)

    # Update step size and beta
    step_size_k = eta_k / (lmbda_k_plus_1 + 1.)
    step_size_list.append(step_size_k)
    beta_k = lmbda_k / (lmbda_k_plus_1 + 1.)
    beta_list.append(beta_k)

18
19


In [69]:
fig = go.Figure(layout_yaxis_range=[0, 1.01])

fig.add_trace(
    go.Scatter(
        x=iter_list,
        y=eta_list,
        name=r"$\Huge\eta_k \phantom{aaaaa},$",
        legendgroup="1",
        mode="lines",
        line=dict(color=COLORS[0], dash=LINE_STYLES[0], width=5),
        marker=dict(symbol=SYMBOLS[0], size=10),
    )
)

fig.add_trace(
    go.Scatter(
        x=iter_list,
        y=lmbda_list,
        name=r"$\Huge\zeta_k$",
        mode="lines",
        legendgroup="2",
        line=dict(color=COLORS[1], dash=LINE_STYLES[1], width=5),
        marker=dict(symbol=SYMBOLS[1], size=10),
    )
)

fig.update_layout(
    # width=1000,
    # height=500,
    margin={"l": 20, "r": 20, "t": 20, "b": 20},
    template="plotly_white", 
    font=dict(size=20,),
    xaxis_title=r"iteration",
    legend=dict(
        orientation="h",
        xanchor="center",
        yanchor="top",
        y=1.12, 
        x=.5,)
)

fig.write_image("experiments/results/example_momentum_eta_zeta_param_seq.pdf")
fig.show()

In [70]:
fig = go.Figure(layout_yaxis_range=[0, 1.01])

fig.add_trace(
    go.Scatter(
        x=iter_list,
        y=step_size_list,
        name=r"$\Huge\gamma_k \phantom{aaaaa},$",
        legendgroup="1",
        mode="lines",
        line=dict(color=COLORS[3], dash=LINE_STYLES[3], width=5),
        marker=dict(symbol=SYMBOLS[3], size=10),
    )
)

fig.add_trace(
    go.Scatter(
        x=iter_list,
        y=beta_list,
        name=r"$\Huge\beta_k$",
        mode="lines",
        legendgroup="2",
        line=dict(color=COLORS[4], dash=LINE_STYLES[4], width=5),
        marker=dict(symbol=SYMBOLS[4], size=10),
    )
)

fig.update_layout(
    margin={"l": 20, "r": 20, "t": 20, "b": 20},
    template="plotly_white", 
    font=dict(size=20,),
    xaxis_title=r"iteration",
    legend=dict(
        orientation="h",
        xanchor="center",
        yanchor="top",
        y=1.12, 
        x=.5,)
)

fig.write_image("experiments/results/example_momentum_gamma_beta_param_seq.pdf")
fig.show()