In [1]:
import sqlite3
from contextlib import closing

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly_gif import GIF, three_d_scatter_rotate

In [2]:
def fetch_data(db_path, simulation_id):
    """Load the MCLUT rows belonging to one simulation from a SQLite database.

    Parameters
    ----------
    db_path : str or path-like
        Path to the SQLite database file.
    simulation_id : int
        Value matched against the ``simulation_id`` column of the ``mclut`` table.

    Returns
    -------
    pandas.DataFrame
        Columns ``mu_s``, ``mu_a``, ``g`` and ``reflectance``; empty (but with
        those columns) when no row matches ``simulation_id``.
    """
    # Bind simulation_id as a query parameter rather than formatting it into the
    # SQL text, so the value can never change the structure of the statement.
    query = "SELECT mu_s, mu_a, g, reflectance FROM mclut WHERE simulation_id = ?"
    # sqlite3's built-in context manager only commits/rolls back; closing() is
    # what guarantees the connection is released even if the query raises.
    with closing(sqlite3.connect(db_path)) as conn:
        return pd.read_sql_query(query, conn, params=(simulation_id,))

def _wireframe_traces(pivot_table):
    """Build black Scatter3d line traces that draw a wireframe over a pivoted grid.

    One trace runs along ``mu_a`` for every ``mu_s`` row, and one runs along
    ``mu_s`` for every ``mu_a`` column, so together they trace the grid lines.

    Parameters
    ----------
    pivot_table : pandas.DataFrame
        Reflectance values indexed by ``mu_s`` (rows) with ``mu_a`` as columns.

    Returns
    -------
    list of plotly.graph_objects.Scatter3d
    """
    mu_s_values = pivot_table.index.values
    mu_a_values = pivot_table.columns.values
    reflectance_values = pivot_table.values

    # Lines of constant mu_s: one per row of the pivot table.
    traces = [
        go.Scatter3d(x=[mu_s] * len(mu_a_values),
                     y=mu_a_values,
                     z=reflectance_values[i, :],
                     mode='lines',
                     line=dict(color='black', width=3))
        for i, mu_s in enumerate(mu_s_values)
    ]
    # Lines of constant mu_a: one per column of the pivot table.
    traces += [
        go.Scatter3d(x=mu_s_values,
                     y=[mu_a] * len(mu_s_values),
                     z=reflectance_values[:, j],
                     mode='lines',
                     line=dict(color='black', width=3))
        for j, mu_a in enumerate(mu_a_values)
    ]
    return traces


def plot_wireframe(df, gif_path_template='mclut_wireframe_g-{g}.gif'):
    """Render one 3-D reflectance wireframe per distinct ``g`` value and save it as a GIF.

    For each unique ``g`` in ``df``, the matching rows are pivoted into a
    ``mu_s`` x ``mu_a`` grid of ``reflectance`` and drawn as a black wireframe.
    The figure is handed to ``three_d_scatter_rotate`` together with a fresh
    ``GIF`` recorder (frame capture is defined by ``plotly_gif``), shown
    inline, and then written to disk with ``gif.create_gif``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``mu_s``, ``mu_a``, ``g`` and ``reflectance``
        (for example, the output of ``fetch_data``).
    gif_path_template : str, optional
        ``str.format`` template for each output file; ``{g}`` is replaced by the
        current ``g`` value. The default reproduces the original file names.

    Raises
    ------
    ValueError
        Raised by ``DataFrame.pivot`` if a ``(mu_s, mu_a)`` pair occurs more
        than once within a single ``g``.
    """
    for g in df['g'].unique():
        subset = df[df['g'] == g]
        pivot_table = subset.pivot(index='mu_s', columns='mu_a', values='reflectance')

        fig = go.Figure(data=_wireframe_traces(pivot_table))
        fig.update_layout(title=f"Reflectance Wireframe vs mu_s & mu_a (g={g})",
                          scene=dict(xaxis_title='mu_s', yaxis_title='mu_a', zaxis_title='Reflectance'),
                          # Every grid line is its own trace; without this the legend
                          # fills with meaningless "trace 0 … trace N" entries.
                          showlegend=False)

        gif = GIF()
        three_d_scatter_rotate(gif, fig)
        fig.show()
        gif.create_gif(gif_path=gif_path_template.format(g=g))


In [None]:
# Run configuration: which database to read and which simulation's lookup table to plot.
DB_PATH = 'databases/hsdfm_data.db'
SIMULATION_ID = 14

df = fetch_data(DB_PATH, SIMULATION_ID)
# With no matching rows plot_wireframe would silently draw nothing; fail loudly instead.
assert not df.empty, f"No mclut rows found for simulation_id={SIMULATION_ID} in {DB_PATH}"
plot_wireframe(df)