In [1]:
import pandas as pd
import numpy as np
from lmfit import Model
import glob
import re
from bokeh.plotting import figure, output_file, save
from bokeh.models import Select, ColumnDataSource, Div, CustomJS
from bokeh.layouts import column, row

In [None]:
import pandas as pd
import numpy as np
from lmfit import Model
import glob
import re
from bokeh.plotting import figure, output_file, save
from bokeh.models import Select, ColumnDataSource, Div, CustomJS
from bokeh.layouts import column, row

# Remote roots of the published result tables.
BASE_URL = 'https://raw.githubusercontent.com/igor-sadalski/Scaling-up-measurement-noise-scaling-laws/main/'
CALTECH_URL = 'https://raw.githubusercontent.com/ggdna/scScaling/main/results/'

# Raw labels found in the result tables -> display labels used below.
RENAME_DICT = {
    'celltype.l3': 'Cell type MI',
    'protein_counts': 'Protein MI',
    'clone': 'Clonal MI',
    'author_day': 'Temporal MI',
    'ng_idx': 'Spatial MI',
    'RandomProjection': 'Rand. Proj.',
}

# --- Caltech101 tables (remote) ---
gaussian_df = pd.read_csv(CALTECH_URL + 'Caltech101_Gaussian.csv')
# 'Scale' is replaced by its square; downstream code uses 1/Scale as x.
gaussian_df['Scale'] = gaussian_df['Scale'] ** 2
res_df = pd.read_csv(CALTECH_URL + 'Caltech101_resolution.csv')

# --- Single-cell MI results (remote) with display labels applied ---
df = pd.read_csv(BASE_URL + 'collect_mi_results.csv')
df = df.replace(RENAME_DICT)

# --- Local tables ---
sc_param_df_noise = pd.read_csv('analysis/final_results/scaling_plots_u_bar_138.109_I_max_1.419.csv')
sc_param_df_noise = sc_param_df_noise.replace(RENAME_DICT)
seq_df = pd.read_csv('seq/multisize_gisaid_results.csv')

# Load TissueMNIST data
# Each result file is named result_<tag>.csv; the tag encodes how the images
# were degraded ('clean', a tag containing 'pix', or a tag containing 'gauss').
# glob returns paths in arbitrary order, so sort to make the row order of
# combined_df reproducible across runs and filesystems.
csv_files = sorted(glob.glob('images/tissuemnist_models/result_*.csv'))
if not csv_files:
    # pd.concat([]) would raise an opaque ValueError; fail with the real cause.
    raise ValueError(
        "No TissueMNIST result files matched 'images/tissuemnist_models/result_*.csv'"
    )

dfs_tissue = []
for file in csv_files:
    match = re.search(r'result_(.+)\.csv', file)
    # A name such as 'result_.csv' matches the glob but not the regex; treat
    # it like any other unrecognised tag instead of crashing on None.
    tag = match.group(1) if match else ''

    if tag == 'clean':
        downsampling_type, downsampling_level = 'clean', 0.0
    elif 'pix' in tag or 'gauss' in tag:
        # 'pix' takes precedence if a tag happens to contain both markers.
        downsampling_type = 'pixel' if 'pix' in tag else 'gaussian'
        # Level is the second '_'-separated token with its last character
        # dropped (presumably a unit suffix -- TODO confirm file naming).
        downsampling_level = float(tag.split('_')[1][:-1])
    else:
        downsampling_type, downsampling_level = 'unknown', 0.0

    df_temp = pd.read_csv(file)
    df_temp['downsampling_level'] = downsampling_level
    df_temp['downsampling_type'] = downsampling_type
    dfs_tissue.append(df_temp)
combined_df = pd.concat(dfs_tissue, ignore_index=True)

# ============================================================================
# Fitting Function
# ============================================================================
def info_scaling_model(x, A, B):
    """Evaluate the information scaling law at x.

    Computes ``0.5 * log2((B*x + 1) / (A*x + 1))``. ``x`` may be a scalar or
    a NumPy array; ``A`` and ``B`` are the two fit parameters.
    """
    numerator = x*B + 1
    denominator = 1 + A*x
    ratio = numerator / denominator
    return 0.5 * np.log2(ratio)

def fit_info_model(x_data, y_data):
    """Fit ``info_scaling_model`` to (x_data, y_data) and derive summary values.

    Both parameters A and B start at 1e-2 and are bounded below by 0. A fit is
    accepted only if lmfit reports a non-zero standard error for both
    parameters and each standard error is smaller than the parameter value.

    Returns:
        Tuple ``(u_bar, I_max, u_bar_err, I_max_err, r_squared, result)``:
        ``u_bar = 1/A``; ``I_max = 0.5*log2(B/A)``; the ``*_err`` values are
        1-sigma uncertainties from first-order error propagation of the
        parameter standard errors; ``r_squared`` is the coefficient of
        determination; ``result`` is the lmfit fit result.
        If the fit raises or is rejected, a tuple of six ``None`` is returned.
    """
    model = Model(info_scaling_model)
    params = model.make_params(A=1e-2, B=1e-2)
    params['A'].min = params['B'].min = 0

    try:
        result = model.fit(y_data, params, x=x_data)
        a, b = result.params['A'], result.params['B']
        if a.stderr and b.stderr and a.stderr < a.value and b.stderr < b.value:
            u_bar = 1/a.value
            I_max = 0.5*np.log2(b.value/a.value)

            # Calculate R²
            ss_res = np.sum(result.residual**2)
            ss_tot = np.sum((y_data - np.mean(y_data))**2)
            r_squared = 1 - (ss_res / ss_tot)

            # 1-sigma uncertainties (first-order propagation; the A-B
            # covariance term is not included).
            u_bar_err = a.stderr / (a.value**2)  # Error propagation for 1/A
            I_max_err = 0.5 / np.log(2) * np.sqrt((b.stderr/b.value)**2 + (a.stderr/a.value)**2)

            return u_bar, I_max, u_bar_err, I_max_err, r_squared, result
    except Exception:
        # Best-effort by design: a failed fit just means "no curve" for this
        # dataset. Unlike a bare `except:`, this lets KeyboardInterrupt and
        # SystemExit propagate so long batch runs can still be interrupted.
        pass
    return None, None, None, None, None, None


# ============================================================================
# 1) Generate Table of Fitted Parameters
# ============================================================================
hue_order = ['Rand. Proj.', 'PCA', 'SCVI', 'Geneformer']
hue_order_metrics = ['Protein MI', 'Clonal MI', 'Temporal MI', 'Spatial MI']

# Storage for all curves
all_curves = []


def _add_fitted_curve(x_data, y_data, *, curve_id, category, metric, method, size='N/A'):
    """Fit one (x, y) series and, if the fit is accepted, append its record.

    Args:
        x_data, y_data: NumPy arrays of equal length holding the measurements.
        curve_id, category, metric, method, size: labels stored verbatim in
            the record (``size`` defaults to 'N/A' for datasets without one).

    The record appended to ``all_curves`` holds the labels, the fitted
    summary values from ``fit_info_model``, the raw data, and the fitted
    curve with its 2-sigma uncertainty evaluated on a 200-point log-spaced
    grid spanning from min(x)/5 to max(x)*5. Nothing is appended when
    ``fit_info_model`` rejects the fit.
    """
    u_bar, I_max, u_bar_err, I_max_err, r_squared, result = fit_info_model(x_data, y_data)
    if u_bar is None or I_max is None:
        return

    # Generate confidence bands
    x_fit = np.logspace(np.log10(x_data.min()/5), np.log10(x_data.max()*5), 200)
    y_fit = result.eval(x=x_fit)
    y_err = result.eval_uncertainty(x=x_fit, sigma=2)

    all_curves.append({
        'curve_id': curve_id,
        'category': category,
        'metric': metric,
        'method': method,
        'size': size,
        'u_bar': u_bar,
        'I_max': I_max,
        'u_bar_err': u_bar_err,
        'I_max_err': I_max_err,
        'r_squared': r_squared,
        'x_data': x_data.tolist(),
        'y_data': y_data.tolist(),
        'x_fit': x_fit.tolist(),
        'y_fit': y_fit.tolist(),
        'y_err': y_err.tolist()
    })


# --- Single-cell noise curves ---
for sig in hue_order_metrics:
    for size in df['size'].unique():
        for alg in hue_order:
            data = df[(df['signal']==sig) & (df['size']==size) & (df['algorithm']==alg)]
            # Combinations with fewer than 9 measurements are not fitted.
            if len(data) < 9:
                continue

            _add_fitted_curve(
                data['umis_per_cell'].values,
                data['mi_value'].values,
                curve_id=f"SC_{sig}_{alg}_size{int(size)}",
                category='Single-cell',
                metric=sig,
                method=alg,
                size=f"{int(size)} cells",
            )

# --- Caltech101 Gaussian ---
# NOTE(review): [:-1] skips the last class label returned by unique() --
# confirm this exclusion is intentional.
for class_label in gaussian_df['Class label'].unique()[:-1]:
    data = gaussian_df[gaussian_df['Class label'] == class_label]
    _add_fitted_curve(
        1/data['Scale'].values,
        data['MI'].values,
        curve_id=f"Caltech101_Gaussian_class{class_label}",
        category='Caltech101-Gaussian',
        metric=f'Class {class_label}',
        method='Mobilenet',
    )

# --- Caltech101 Resolution ---
# NOTE(review): same [:-1] exclusion of the last class label as above.
for class_label in res_df['Class label'].unique()[:-1]:
    data = res_df[res_df['Class label'] == class_label]
    _add_fitted_curve(
        1/data['Factor'].values,
        data['MI'].values,
        curve_id=f"Caltech101_Resolution_class{class_label}",
        category='Caltech101-Pixelation',
        metric=f'Class {class_label}',
        method='Mobilenet',
    )

# --- Sequences (ESM2 models) ---
model_sizes = sorted(seq_df['model_size'].unique())
for model_size in model_sizes:
    data = seq_df[seq_df['model_size'] == model_size]
    _add_fitted_curve(
        data['true/error'].values,
        data['mutual_information'].values,
        curve_id=f"Sequences_ESM2_{model_size}",
        category='Sequences',
        metric='Collection month MI',
        method=f'ESM2-{model_size}',
    )

# --- TissueMNIST Pixel Downsampling ---
# Result-table column name -> display name shown in the metric dropdown.
label_map = {
    'ova_mi_continuous_Class_0': 'Collecting Duct',
    'ova_mi_continuous_Class_1': 'Distal Convoluted Tubule',
    'ova_mi_continuous_Class_2': 'Glomerular endothelial',
    'ova_mi_continuous_Class_3': 'Interstitial endothelial',
    'ova_mi_continuous_Class_4': 'Leukocytes',
    'ova_mi_continuous_Class_5': 'Podocytes',
    'ova_mi_continuous_Class_6': 'Proximal Tubule',
    'ova_mi_continuous_Class_7': 'Thick Ascending Limb',
    'mi_score': '8-class MI',
}

pix = combined_df[combined_df['downsampling_type'] == 'pixel'].copy()
pix['inv_factor'] = 1 / pix['downsampling_level']
ova_columns = ['mi_score'] + [col for col in pix.columns if 'ova_mi_continuous' in col]

for col in ova_columns:
    mask = ~pix[col].isna() & ~pix['inv_factor'].isna()
    # x is the squared inverse downsampling factor.
    x_data = (pix[mask]['inv_factor'].values)**2
    y_data = pix[mask][col].values

    if len(x_data) < 3:
        continue

    _add_fitted_curve(
        x_data,
        y_data,
        curve_id=f"TissueMNIST_Pixel_{col}",
        category='TissueMNIST-Pixelation',
        metric=label_map.get(col, col),
        method='Mobilenet',
    )

# --- TissueMNIST Gaussian Noise ---
gauss = combined_df[combined_df['downsampling_type'] == 'gaussian'].copy()

for col in ova_columns:
    mask = ~gauss[col].isna()
    if mask.sum() < 3:
        continue

    _add_fitted_curve(
        1/gauss[mask]['downsampling_level'].values,
        gauss[mask][col].values,
        curve_id=f"TissueMNIST_Gaussian_{col}",
        category='TissueMNIST-Gaussian',
        metric=label_map.get(col, col),
        method='Mobilenet',
    )

# Create DataFrame and save to CSV
# Scalar summary fields copied from each curve record into the table; the
# per-point arrays (x_data, y_fit, ...) are deliberately left out.
_summary_fields = (
    'curve_id', 'category', 'metric', 'method', 'size',
    'u_bar', 'I_max', 'u_bar_err', 'I_max_err', 'r_squared',
)
_summary_rows = []
for _curve in all_curves:
    _summary_rows.append({field: _curve[field] for field in _summary_fields})
params_df = pd.DataFrame(_summary_rows)

# Persist the table, then echo a short preview.
params_df.to_csv('analysis/noise_fit_parameters.csv', index=False)
print(f"Saved {len(params_df)} curves to 'analysis/noise_fit_parameters.csv'")
print(params_df.head(20))

# ============================================================================
# 2) Bokeh Interactive Visualization with Responsive Dropdowns
# ============================================================================

# Get unique values for initial category
categories = sorted({c['category'] for c in all_curves})

# The first fitted curve is what the page shows before any interaction.
first_curve = all_curves[0]

# Valid dropdown options for the first curve's category / metric / method.
# Sorted in descending order to match the ordering the JavaScript callback
# applies (`.sort().reverse()`), so the option lists do not reorder the
# first time the user touches a dropdown.
first_cat_curves = [c for c in all_curves if c['category'] == first_curve['category']]
initial_metrics = sorted({c['metric'] for c in first_cat_curves}, reverse=True)

first_metric_curves = [c for c in first_cat_curves if c['metric'] == first_curve['metric']]
initial_methods = sorted({c['method'] for c in first_metric_curves}, reverse=True)

first_method_curves = [c for c in first_metric_curves if c['method'] == first_curve['method']]
initial_sizes = sorted({c['size'] for c in first_method_curves}, reverse=True)

# Scatter points are drawn in ascending x order.
x_data = np.array(first_curve['x_data'])
y_data = np.array(first_curve['y_data'])
sort_idx = np.argsort(x_data)
x_sorted = x_data[sort_idx]
y_sorted = y_data[sort_idx]

x_fit = np.array(first_curve['x_fit'])
y_fit = np.array(first_curve['y_fit'])
y_err = np.array(first_curve['y_err'])

# Data sources updated in place by the JavaScript callback.
source_scatter = ColumnDataSource(data=dict(x=x_sorted, y=y_sorted))
source_fit = ColumnDataSource(data=dict(x=x_fit, y=y_fit))
# Combined source for the confidence band (lower edge y1, upper edge y2)
source_ci = ColumnDataSource(data=dict(x=x_fit, y1=y_fit-y_err, y2=y_fit+y_err))

# Create figure
# Bokeh only renders a label as LaTeX when the whole string is wrapped in
# MathJax delimiters ($$...$$, \[...\] or \(...\)); single-dollar "$...$"
# is shown literally.
p = figure(title=f"{first_curve['category']} | {first_curve['metric']} | {first_curve['method']} | {first_curve['size']}", 
           x_axis_type="log", width=600, height=400,
           x_axis_label=r"$$\eta$$", y_axis_label="auxiliary MI (bits)")

p.xaxis.axis_label_text_font_size = "14pt"
p.yaxis.axis_label_text_font_size = "14pt"
p.xaxis.major_label_text_font_size = "12pt"
p.yaxis.major_label_text_font_size = "12pt"

# Band first so the data points and fit line are drawn on top of it.
p.varea(x='x', y1='y1', y2='y2', source=source_ci, 
        alpha=0.2, color='lightblue', legend_label="2 sigma CI")

p.scatter('x', 'y', source=source_scatter, size=10, color='purple', alpha=0.7, legend_label="data")
p.line('x', 'y', source=source_fit, line_width=2, color='lightblue', line_dash='dashed', legend_label="fit")
p.legend.location = "bottom_right"

# Create parameter display
u_bar = first_curve['u_bar']
I_max = first_curve['I_max']
u_bar_err = first_curve['u_bar_err']
I_max_err = first_curve['I_max_err']
r_squared = first_curve['r_squared']

param_div = Div(text=f"""
<div style="font-size: 16px; padding: 10px; background-color: #f0f0f0; border-radius: 5px;">
    <b>Fitted Parameters:</b><br>
    <span style="color: darkgreen;">ū = {u_bar:.4f} ± {u_bar_err:.4f}</span><br>
    <span style="color: darkblue;">I<sub>max</sub> = {I_max:.4f} ± {I_max_err:.4f} bits</span><br>
    <span style="color: purple;">R² = {r_squared:.4f}</span><br>
</div>
""", width=300)

# Create dropdown menus with initial valid options
select_category = Select(title="data domain:", value=first_curve['category'], options=categories, width=200)
select_metric = Select(title="metric:", value=first_curve['metric'], options=initial_metrics, width=200)
select_method = Select(title="method:", value=first_curve['method'], options=initial_methods, width=200)
select_size = Select(title="dataset size:", value=first_curve['size'], options=initial_sizes, width=200)

# Prepare all curve data as JSON for JavaScript callback
import json

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native types.

    NumPy integers become ``int``, NumPy floats become ``float`` and arrays
    become (nested) lists; anything else is deferred to the base encoder,
    which raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        # np.int64 / np.float64 are covered by the abstract scalar bases.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

# (JavaScript field name, curve-record key) pairs, in output order. Only the
# raw data arrays are renamed (x_data -> x, y_data -> y); curve_id is omitted.
_js_fields = (
    ('category', 'category'),
    ('metric', 'metric'),
    ('method', 'method'),
    ('size', 'size'),
    ('x', 'x_data'),
    ('y', 'y_data'),
    ('x_fit', 'x_fit'),
    ('y_fit', 'y_fit'),
    ('y_err', 'y_err'),
    ('u_bar', 'u_bar'),
    ('I_max', 'I_max'),
    ('u_bar_err', 'u_bar_err'),
    ('I_max_err', 'I_max_err'),
    ('r_squared', 'r_squared'),
)
_js_curves = []
for _curve in all_curves:
    _js_curves.append({js_key: _curve[py_key] for js_key, py_key in _js_fields})
curve_data_json = json.dumps(_js_curves, cls=NumpyEncoder)

# JavaScript callback for dropdowns
#
# The JS body is built with an f-string: {curve_data_json} is substituted
# with the serialized curve list, and every literal JS brace is therefore
# written doubled ({{ / }}).
#
# What the JS does on each invocation:
#   1. Reads the current value of the four selects.
#   2. Cascades category -> metric -> method -> size: at each level it
#      computes the options available under the selections above it (sorted,
#      then reversed), resets the select to the first available option if its
#      current value is no longer valid, and assigns the new option list.
#   3. Looks up the curve matching all four selections; logs and returns if
#      none is found.
#   4. Replaces the scatter data (sorted by x), the fit line, and the
#      confidence band (y_fit -/+ y_err), then updates the plot title and the
#      HTML of the parameter panel.
#
# NOTE(review): assigning select_*.value inside the callback presumably
# re-triggers this same callback, since it is registered on 'value' of every
# select -- confirm the re-entrancy is benign.
callback = CustomJS(args=dict(
    # Data sources and widgets the JS code reads or mutates, by name.
    source_scatter=source_scatter, 
    source_fit=source_fit,
    source_ci=source_ci,
    param_div=param_div,
    p=p,
    select_category=select_category,
    select_metric=select_metric,
    select_method=select_method,
    select_size=select_size
), code=f"""
    const all_curves = {curve_data_json};
    
    // Determine which dropdown triggered the callback
    const trigger = cb_obj;
    
    // Get current selections
    let sel_cat = select_category.value;
    let sel_met = select_metric.value;
    let sel_meth = select_method.value;
    let sel_size = select_size.value;
    
    // Update available options based on selections
    // Start from the triggered dropdown and cascade down
    
    // Filter by category
    const cat_curves = all_curves.filter(c => c.category === sel_cat);
    const avail_metrics = [...new Set(cat_curves.map(c => c.metric))].sort().reverse();
    
    // If metric is no longer valid, select first available
    if (!avail_metrics.includes(sel_met)) {{
        sel_met = avail_metrics[0];
        select_metric.value = sel_met;
    }}
    select_metric.options = avail_metrics;
    
    // Filter by category and metric
    const met_curves = cat_curves.filter(c => c.metric === sel_met);
    const avail_methods = [...new Set(met_curves.map(c => c.method))].sort().reverse();
    
    // If method is no longer valid, select first available
    if (!avail_methods.includes(sel_meth)) {{
        sel_meth = avail_methods[0];
        select_method.value = sel_meth;
    }}
    select_method.options = avail_methods;
    
    // Filter by category, metric, and method
    const meth_curves = met_curves.filter(c => c.method === sel_meth);
    const avail_sizes = [...new Set(meth_curves.map(c => c.size))].sort().reverse();
    
    // If size is no longer valid, select first available
    if (!avail_sizes.includes(sel_size)) {{
        sel_size = avail_sizes[0];
        select_size.value = sel_size;
    }}
    select_size.options = avail_sizes;
    
    // Find matching curve with final selections
    const curve = all_curves.find(c => 
        c.category === sel_cat && 
        c.metric === sel_met && 
        c.method === sel_meth && 
        c.size === sel_size
    );
    
    if (!curve) {{
        console.log("No matching curve found");
        return;
    }}
    
    // Update scatter data
    const x = curve.x;
    const y = curve.y;
    
    // Sort by x
    const indices = [...x.keys()].sort((a, b) => x[a] - x[b]);
    const x_sorted = indices.map(i => x[i]);
    const y_sorted = indices.map(i => y[i]);
    
    source_scatter.data = {{x: x_sorted, y: y_sorted}};
    
    // Update fit line and confidence bands
    source_fit.data = {{x: curve.x_fit, y: curve.y_fit}};
    
    const y_upper = curve.y_fit.map((val, i) => val + curve.y_err[i]);
    const y_lower = curve.y_fit.map((val, i) => val - curve.y_err[i]);

    source_ci.data = {{x: curve.x_fit, y1: y_lower, y2: y_upper}};
    
    // Update title
    p.title.text = sel_cat + " | " + sel_met + " | " + sel_meth + " | " + sel_size;
    
    // Update parameter display
    param_div.text = `
    <div style="font-size: 16px; padding: 10px; background-color: #f0f0f0; border-radius: 5px;">
        <b>Fitted Parameters:</b><br>
        <span style="color: darkgreen;">ū = ${{curve.u_bar.toFixed(4)}} ± ${{curve.u_bar_err.toFixed(4)}}</span><br>
        <span style="color: darkblue;">I<sub>max</sub> = ${{curve.I_max.toFixed(4)}} ± ${{curve.I_max_err.toFixed(4)}} bits</span><br>
        <span style="color: purple;">R² = ${{curve.r_squared.toFixed(4)}}</span><br>
    </div>
    `;
""")

# All four dropdowns share one callback, fired whenever a selection changes.
for _select in (select_category, select_metric, select_method, select_size):
    _select.js_on_change('value', callback)

# Page layout: the dropdown row on top, the plot next to the parameter
# panel underneath.
dropdowns = row(select_category, select_metric, select_method, select_size)
_plot_row = row(p, param_div)
layout = column(dropdowns, _plot_row)

# Write the standalone HTML page.
output_file("analysis/noise_fits_interactive.html", title="Noise scaling zoo")
save(layout)
print("Saved interactive visualization to 'analysis/noise_fits_interactive.html'")

Saved 183 curves to 'analysis/noise_fit_parameters.csv'
                              curve_id     category      metric       method  \
0    SC_Protein MI_Rand. Proj._size100  Single-cell  Protein MI  Rand. Proj.   
1            SC_Protein MI_PCA_size100  Single-cell  Protein MI          PCA   
2           SC_Protein MI_SCVI_size100  Single-cell  Protein MI         SCVI   
3     SC_Protein MI_Geneformer_size100  Single-cell  Protein MI   Geneformer   
4    SC_Protein MI_Rand. Proj._size215  Single-cell  Protein MI  Rand. Proj.   
5            SC_Protein MI_PCA_size215  Single-cell  Protein MI          PCA   
6           SC_Protein MI_SCVI_size215  Single-cell  Protein MI         SCVI   
7     SC_Protein MI_Geneformer_size215  Single-cell  Protein MI   Geneformer   
8    SC_Protein MI_Rand. Proj._size464  Single-cell  Protein MI  Rand. Proj.   
9            SC_Protein MI_PCA_size464  Single-cell  Protein MI          PCA   
10          SC_Protein MI_SCVI_size464  Single-cell  Protein MI 