In [1]:
import pandas as pd
from matplotlib.ticker import ScalarFormatter
from glch_utils import compute_hulls, tree_nodes
from glch_threed_utils import plot_3d_lch
from glch_experiments_functions import glch_rate_vs_dist

def glch3d_rdc(
    csv_path,
    complexity_axis="params",
    constrained=True,
    start="left",
    lambdas=("5e-3", "1e-2", "2e-2"),
    fldr="glch_results",
    debug_folder="debug",
    debug=True,
    select_function="angle_rule",
    axes_aliases=None
):
    """Run the 2-D GLCH search once per lambda and plot its hull points in 3-D.

    For each multiplier in ``lambdas``, ``glch_rate_vs_dist`` is run on the
    (``complexity_axis``, ``"loss"``) plane. The nodes of each returned tree
    are then looked up in the CSV, and ``plot_3d_lch`` draws them on top of
    the full point cloud in (``bpp_loss``, ``mse_loss``, ``complexity_axis``)
    space.

    Parameters
    ----------
    csv_path : str
        CSV file with a ``labels`` column (used as the index) and at least the
        columns ``bpp_loss``, ``mse_loss`` and ``complexity_axis``. It is also
        passed to ``glch_rate_vs_dist``, which presumably needs a ``loss``
        column as well (TODO confirm).
    complexity_axis : str
        Column used as the complexity measure. It is both the first axis of
        the 2-D search and the third axis of the 3-D plot.
    constrained, start, select_function, axes_aliases, fldr, debug_folder, debug
        Forwarded unchanged to ``glch_rate_vs_dist``. ``weights`` and
        ``axes_ranges`` are always passed as ``None``.
    lambdas : iterable of str
        One GLCH run is made per entry. Must contain at least one value.

    Returns
    -------
    The object returned by ``plot_3d_lch``. Downstream cells use it as a
    matplotlib figure.

    Raises
    ------
    ValueError
        If ``lambdas`` is empty (there would be no hull to plot).
    """
    lambdas = list(lambdas)
    if not lambdas:
        raise ValueError("lambdas must contain at least one multiplier")

    # One independent 2-D search per lambda. Only the returned root node is
    # needed here; the second return value (a tree string) is not used.
    rs = []
    for lmbda in lambdas:
        r, _tree_str = glch_rate_vs_dist(
            csv_path,
            [complexity_axis, "loss"],
            algo="glch",
            constrained=constrained,
            weights=None,
            start=start,
            lambdas=[lmbda],
            axes_ranges=None,
            axes_aliases=axes_aliases,
            fldr=fldr,
            debug_folder=debug_folder,
            debug=debug,
            select_function=select_function
        )
        rs.append(r)

    data = pd.read_csv(csv_path).set_index("labels")

    axes = ["bpp_loss", "mse_loss", complexity_axis]

    # Rows of each estimated hull: the root (looked up by str(r), which is
    # assumed to equal its row label) plus the labels that tree_nodes()
    # collects from the tree. A fresh [] accumulator is passed on every call.
    estimated_hulls = [
        data.loc[[str(r)] + tree_nodes(r, [], "lch"), :]
        for r in rs
    ]
    combined_estimated_hull = pd.concat(estimated_hulls, axis=0)

    cloud = data.loc[:, axes].values.tolist()
    combined_estimated_hull_cloud = combined_estimated_hull.loc[:, axes].values.tolist()

    # Full cloud: colour "b", marker "o", alpha 0.05. Hull points: "g", "s",
    # alpha 1. The normalizations (1, 1e-3, 1e6) presumably rescale the
    # R, D and C axes; tick labels set later assume this scaling.
    fig = plot_3d_lch([cloud, combined_estimated_hull_cloud], ["b", "g"], ['o', 's'], [0.05, 1],
                      ax_labels=["$R$", "$D$", "$C$"],
                      title=None, normalizations=[1, 0.001, 1000000], figsize=(10, 10), fontsize=14)

    return fig

  self[key] = other[key]


In [2]:

results_folder = "glch_results"
debug_folder = "glch_debug"

select_function = "angle_rule"
constrained = True

# Rate / distortion / complexity table for the bmshj2018-factorized sweep.
# NOTE(review): absolute local path; make it configurable before sharing.
csv_path = "/home/lucas/Documents/perceptronac/scripts/tradeoffs/bpp-mse-psnr-loss-flops-params_bmshj2018-factorized_10000-epochs_D-3-4_L-2e-2-1e-2-5e-3_N-32-64-96-128-160-192-224_M-32-64-96-128-160-192-224-256-288-320.csv"

# Axis label for each supported complexity measure. Set complexity_axis to
# "flops" to reproduce the FLOPs variant. The z tick labels set in the
# animation cell assume the parameter-count scale, so adjust them as well.
complexity_aliases = {
    "flops": "Complexity (FLOPs)",
    "params": "Complexity (number of parameters)",
}
complexity_axis = "params"

fig = glch3d_rdc(
    csv_path,
    complexity_axis=complexity_axis,
    constrained=constrained,
    start="left",
    lambdas=["5e-3", "1e-2", "2e-2"],
    fldr=results_folder,
    debug_folder=debug_folder,
    debug=False,
    select_function=select_function,
    # Raw string: "\l" is an invalid escape sequence in a normal literal.
    axes_aliases=[complexity_aliases[complexity_axis], r"$L=R+\lambda D$"],
)

In [3]:
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation

def rotate(angle):
    """FuncAnimation frame callback: set the azimuth of the first axes of the global `fig` to `angle`."""
    first_axes = fig.axes[0]
    first_axes.view_init(azim=angle)

# The figure was built with normalizations=[1, 0.001, 1000000]. Presumably the
# D and C axes are therefore drawn in units of 1e-3 and 1e6, so the ticks
# below relabel those scaled positions with the original values (TODO confirm).
ax3d = fig.axes[0]

ax3d.set_yticks([5, 10, 15])
# Explicit strings: floats would render 0.010 as "0.01".
ax3d.set_yticklabels(["0.005", "0.010", "0.015"])

ax3d.set_zticks([0, 2, 4, 6, 8])
# Raw strings: "\c" is an invalid escape sequence in a normal literal.
ax3d.set_zticklabels(["0", r"$2 \cdot 10^6$", r"$4 \cdot 10^6$", r"$6 \cdot 10^6$", r"$8 \cdot 10^6$"])

print("Making animation")
# Azimuth sweep from 135 to 301 degrees in 2-degree steps, 100 ms between frames.
rot_animation = animation.FuncAnimation(fig, rotate, frames=np.arange(135, 302, 2), interval=100)
# The 'imagemagick' writer needs ImageMagick installed; 'pillow' is a pure-Python alternative.
rot_animation.save('rotation.gif', dpi=80, writer='imagemagick')

Making animation
