# In [40]:
import os
import re
import pandas as pd
import plotly.graph_objects as go

# === Threshold Parameters ===
# A parsed data point is kept only when its accuracy is strictly greater than
# ACCURACY_THRESHOLD and its compression ratio (logged as a percentage, e.g.
# "zstd_ratio = 3.0%") is strictly less than COMPRESSION_THRESHOLD.
# Both values are also echoed in the plot title.
ACCURACY_THRESHOLD = 98
COMPRESSION_THRESHOLD = 3.5

# Separator line: the first one ends the file header; every one may be
# followed by a "name = value" line naming the variable swept in that section.
_SEPARATOR = "------------------------------------------------------------"

_RE_VAR_HEADER = re.compile(r"([a-zA-Z_]+)\s*=\s*([0-9.\-]+)")
_RE_EPOCH = re.compile(r"Epoch\s+(\d+):")
_RE_SPARSE_INLINE = re.compile(
    r"Epoch\s+(\d+).*?sparse_ratio\s*=\s*([\d.]+)%.*?sparse_accuracy\s*=\s*([\d.]+).*?sparsity\s*=\s*([\d.]+)%"
)
_RE_ARROW = re.compile(r"➡️.*?(?:(\w+)\s*=\s*(-?[\d.]+)[^\d]*)?Epoch\s+(\d+):")
_RE_SPARSE_BLOCK = re.compile(
    r"sparse_ratio\s*=\s*([\d.]+)%.*?sparse_accuracy\s*=\s*([\d.]+).*?sparsity\s*=\s*([\d.]+)%"
)
_RE_FULL_LINE = re.compile(r"A_Q\s*=\s*([\d.]+).*?zstd_ratio\s*=\s*([\d.]+)%")
# Word boundary so words merely ending in "to" ("into", "auto") do not match.
_RE_TO = re.compile(r"\bto\s+([\d.]+)")
_RE_ZSTD = re.compile(r"zstd_ratio\s*=\s*([\d.]+)%")


def _parse_floats(*texts):
    """Convert each captured string to float.

    Returns a tuple of floats, or None if any string is not a valid number.
    The capture patterns use digit-or-dot character classes, which also match
    strings such as "." (e.g. the "./" in "saved to ./checkpoints") that
    float() rejects; such matches are treated as "no match".
    """
    try:
        return tuple(float(text) for text in texts)
    except ValueError:
        return None


def _parse_file_header(lines):
    """Return the "key = value" pairs that precede the first separator line.

    Keys and values are stripped strings. Lines starting with "=" are ignored.
    If the file has no separator line, an empty dict is returned.
    """
    for end, line in enumerate(lines):
        if line.strip() == _SEPARATOR:
            break
    else:
        return {}

    params = {}
    for raw in lines[:end]:
        text = raw.strip()
        if '=' in text and not text.startswith('='):
            key, val = text.split('=', 1)
            params[key.strip()] = val.strip()
    return params


def extract_metrics_from_file(filepath, accuracy_threshold, compression_threshold):
    """Extract (accuracy, compression ratio) data points from one log file.

    The file is scanned line by line while tracking context:

    * header "key = value" lines before the first separator line, attached
      (as strings, one dict shared by all results of the file) under
      'header_params';
    * a "name = number" line directly after any separator line, reported as
      'var_name' / 'var_value';
    * the most recent "Epoch N:" marker;
    * an optional "name = number" parameter on a "➡️ ... Epoch N:" line, added
      to every later result as an extra key.

    Recognised record formats:

    1. one line with "Epoch N ... sparse_ratio = X% ... sparse_accuracy = Y
       ... sparsity = Z%";
    2. a "➡️ ... Epoch N:" line, then the next line carrying the
       sparse_ratio / sparse_accuracy / sparsity triple;
    3. "A_Q = Y ... zstd_ratio = X%" on one line;
    4. a line containing "to Y" whose following line has "zstd_ratio = X%".

    Only points with accuracy > accuracy_threshold and
    ratio < compression_threshold are returned. Number-like captures that are
    not valid floats (e.g. the "." in "to ./checkpoints") are skipped instead
    of raising.

    Args:
        filepath: Path of the log file, read as UTF-8 text.
        accuracy_threshold: Exclusive lower bound on accuracy.
        compression_threshold: Exclusive upper bound on the compression ratio.

    Returns:
        A list of dicts with keys 'accuracy', 'zstd_ratio', 'sparsity'
        (sparse records only), 'source', 'var_name', 'var_value',
        'header_params', 'epoch', 'sparse' ('Y' or 'N'). If a ➡️ parameter
        has been seen, that parameter's name is an extra key.
    """
    # Explicit UTF-8: the log contains the "➡️" marker matched below.
    with open(filepath, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    header_params = _parse_file_header(lines)
    source = os.path.basename(filepath)
    results = []

    current_var_name = None
    current_var_value = None
    last_epoch_seen = None
    current_extra_param_name = None
    current_extra_param_value = None
    epoch_context = None  # epoch announced by a ➡️ line, awaiting its metrics

    def record(accuracy, ratio, epoch, sparse, sparsity=None):
        # Closure: reads the current_* context variables at call time.
        if not (accuracy > accuracy_threshold and ratio < compression_threshold):
            return
        entry = {'accuracy': accuracy, 'zstd_ratio': ratio}
        if sparsity is not None:
            entry['sparsity'] = sparsity
        entry.update({
            'source': source,
            'var_name': current_var_name,
            'var_value': current_var_value,
            'header_params': header_params,
            'epoch': epoch,
            'sparse': sparse,
        })
        # Only add the extra parameter once one has been seen; otherwise the
        # dict would get a spurious None key (and the DataFrame a None column).
        if current_extra_param_name:
            entry[current_extra_param_name] = current_extra_param_value
        results.append(entry)

    for i, line in enumerate(lines):
        next_line = lines[i + 1] if i + 1 < len(lines) else None

        # Update var_name / var_value from the line after a separator.
        if _SEPARATOR in line and next_line is not None:
            m_var = _RE_VAR_HEADER.match(next_line.strip())
            parsed = _parse_floats(m_var.group(2)) if m_var else None
            if parsed is not None:
                current_var_name = m_var.group(1)
                current_var_value = parsed[0]

        # Track the most recent epoch.
        m_epoch = _RE_EPOCH.search(line)
        if m_epoch:
            last_epoch_seen = int(m_epoch.group(1))

        # Format 1: epoch and sparse metrics on the same line.
        if "sparse_ratio" in line:
            m_sparse = _RE_SPARSE_INLINE.search(line)
            values = _parse_floats(*m_sparse.group(2, 3, 4)) if m_sparse else None
            if values is not None:
                ratio, acc, sparsity = values
                record(acc, ratio, int(m_sparse.group(1)), 'Y', sparsity)

        # Format 2 (opening line): "➡️ [name = value ...] Epoch N:".
        m_arrow = _RE_ARROW.search(line)
        if m_arrow:
            epoch_context = int(m_arrow.group(3))
            if m_arrow.group(1):
                parsed = _parse_floats(m_arrow.group(2))
                if parsed is not None:
                    current_extra_param_name = m_arrow.group(1)
                    current_extra_param_value = parsed[0]
            continue  # the metrics are on a following line

        # Format 2 (metrics line): first sparse triple after the ➡️ line.
        if epoch_context is not None:
            m_block = _RE_SPARSE_BLOCK.search(line)
            if m_block:
                values = _parse_floats(*m_block.groups())
                if values is not None:
                    ratio, acc, sparsity = values
                    record(acc, ratio, epoch_context, 'Y', sparsity)
                epoch_context = None  # reset once the data has been found

        # Format 3: "A_Q = ... zstd_ratio = ...%".
        m_full = _RE_FULL_LINE.search(line)
        if m_full:
            values = _parse_floats(*m_full.groups())
            if values is not None:
                acc, ratio = values
                record(acc, ratio, last_epoch_seen, 'N')
            continue

        # Format 4: "to <accuracy>" with zstd_ratio on the following line.
        if "to" in line and next_line is not None:
            m_to = _RE_TO.search(line)
            m_ratio = _RE_ZSTD.search(next_line) if m_to else None
            if m_ratio:
                values = _parse_floats(m_to.group(1), m_ratio.group(1))
                if values is not None:
                    acc, ratio = values
                    record(acc, ratio, last_epoch_seen, 'N')

    return results

def header_params_to_str(params, exclude_keys=None):
    """Render *params* as "key=value" entries joined by HTML <br> tags.

    Entries whose key appears in *exclude_keys* are left out; the insertion
    order of *params* is preserved. An empty result yields "".
    """
    skipped = [] if exclude_keys is None else exclude_keys
    entries = (
        f"{key}={value}"
        for key, value in params.items()
        if key not in skipped
    )
    return '<br>'.join(entries)

def efficient_frontier(df):
    """Return the Pareto-efficient rows of *df*: low 'zstd_ratio', high 'accuracy'.

    Rows are scanned from the lowest compression ratio upward, with ties broken
    by higher accuracy first. A row is kept only when its accuracy is strictly
    higher than that of every row scanned before it (starting from -1).

    Args:
        df: DataFrame with 'accuracy' and 'zstd_ratio' columns. It may be empty;
            an empty frame built from no records has no columns at all.

    Returns:
        The frontier rows in scan order (increasing 'zstd_ratio'), with the
        original index labels and column dtypes preserved. An empty input yields
        an empty copy instead of raising KeyError on the missing sort columns.
    """
    if df.empty:
        return df.copy()

    df_sorted = df.sort_values(by=['zstd_ratio', 'accuracy'], ascending=[True, False])

    keep_positions = []
    best_accuracy = -1
    for position, accuracy in enumerate(df_sorted['accuracy']):
        if accuracy > best_accuracy:
            keep_positions.append(position)
            best_accuracy = accuracy

    # Positional selection keeps index labels (callers match on df.index) and
    # dtypes; .copy() lets callers add columns without SettingWithCopyWarning.
    return df_sorted.iloc[keep_positions].copy()

def build_label(row):
    """Build the HTML hover text for one result row.

    Lines, joined with <br>: the source file; 'delta' and 'r' when the row has
    them and they are not NaN; the swept "var_name = var_value"; accuracy,
    compression, epoch and sparse flag; sparsity when present and not NaN; then
    the header parameters (without the promoted 'delta'/'r') under a bold
    "Params:" heading.

    A line beginning with "<word> =" or "<word>=" is treated as a parameter
    entry; when the same parameter appears on several lines only its last
    occurrence is kept.
    """
    # 'delta' and 'r' are shown right after the file name and excluded from
    # the header parameter list.
    promoted = [name for name in ('delta', 'r') if name in row and pd.notna(row[name])]

    lines = [f"File: {row['source']}"]
    lines.extend(f"{name} = {row[name]}" for name in promoted)
    lines.extend([
        f"{row['var_name']} = {row['var_value']}",
        f"Accuracy: {row['accuracy']}",
        f"Compression: {row['zstd_ratio']}%",
        f"Epoch: {row.get('epoch', 'N/A')}",
        f"Sparse: {row.get('sparse', 'N')}",
    ])
    if 'sparsity' in row and pd.notna(row['sparsity']):
        lines.append(f"Sparsity: {row['sparsity']}%")

    params_html = header_params_to_str(row['header_params'], promoted)
    if params_html:
        lines.append("<b>Params:</b>")
        lines.extend(params_html.split("<br>"))

    def param_key(text):
        # Parameter name of a "key = value" / "key=value" line, else None.
        found = re.match(r"(\w+)\s*=", text)
        return found.group(1) if found else None

    keys = [param_key(text) for text in lines]
    # Later positions overwrite earlier ones, so each key maps to its last line.
    last_position = {key: pos for pos, key in enumerate(keys) if key is not None}
    kept = [
        text
        for pos, (text, key) in enumerate(zip(lines, keys))
        if key is None or last_position[key] == pos
    ]
    return "<br>".join(kept)

def plot_results(df, frontier_df, recent_files, accuracy_threshold, compression_threshold):
    """Build the Pareto-frontier scatter figure.

    Three traces are drawn, in this order:
    * blue markers: points neither from *recent_files* nor on the frontier;
    * red dashed line with diamond markers: the frontier rows;
    * yellow markers: every point whose source is in *recent_files*.

    Side effects: adds 'is_recent_file', 'label' and 'is_frontier' columns to
    *df* and a 'label' column to *frontier_df* (both modified in place).
    Frontier membership is decided by matching index labels of the two frames.

    Returns:
        A plotly ``go.Figure`` whose title records both thresholds.
    """
    df['is_recent_file'] = df['source'].apply(lambda name: name in recent_files)
    df['label'] = df.apply(build_label, axis=1)
    frontier_df['label'] = frontier_df.apply(build_label, axis=1)
    df['is_frontier'] = df.index.isin(frontier_df.index)

    def scatter(points, **style):
        # Shared axes and hover configuration for every trace.
        return go.Scatter(
            x=points['accuracy'],
            y=points['zstd_ratio'],
            text=points['label'],
            hoverinfo='text',
            **style,
        )

    dominated_mask = (~df['is_recent_file']) & (~df['is_frontier'])
    traces = [
        scatter(
            df[dominated_mask],
            mode='markers',
            name='Unefficient points',
            marker=dict(size=8, color='blue'),
        ),
        scatter(
            frontier_df,
            mode='markers+lines',
            name='Pareto Frontier',
            line=dict(color='red', width=2, dash='dash'),
            marker=dict(size=10, symbol='diamond'),
        ),
        scatter(
            df[df['is_recent_file']],
            mode='markers',
            name='Last two tests',
            marker=dict(size=9, color='yellow', line=dict(color='black', width=1)),
        ),
    ]

    figure = go.Figure(data=traces)
    figure.update_layout(
        title=f'Pareto Frontier [ACCURACY_THRESHOLD={accuracy_threshold}, COMPRESSION_THRESHOLD={compression_threshold}]',
        xaxis_title='Accuracy (%)',
        yaxis_title='Compression Ratio (%)',
        template='plotly_white',
    )
    return figure

# === File discovery and sorting ===
# Logs are picked up from the current working directory by filename prefix and
# ordered oldest -> newest by modification time, so the tail of the list holds
# the most recent runs (highlighted in the plot). Directories that happen to
# match the prefix are skipped because they cannot be opened as logs.
file_list = [
    f for f in os.listdir('.')
    if f.startswith('output-Test-2025-') and os.path.isfile(f)
]
file_list.sort(key=os.path.getmtime)
# Slicing already copes with fewer than two files.
last_two_files = file_list[-2:]

# === Extract metrics from all files ===
all_data = []
for filename in file_list:
    all_data.extend(extract_metrics_from_file(filename, ACCURACY_THRESHOLD, COMPRESSION_THRESHOLD))

df = pd.DataFrame(all_data)

if df.empty:
    # No log produced a point inside the thresholds (or no log was found):
    # there are no columns to sort or plot, so report it instead of crashing.
    frontier_df = df
    fig = None
    print(
        f"No data points with accuracy > {ACCURACY_THRESHOLD} and "
        f"compression < {COMPRESSION_THRESHOLD}% in {len(file_list)} file(s); "
        "ParetoFrontier.html was not written."
    )
else:
    frontier_df = efficient_frontier(df)
    # Generate plot and save to HTML
    fig = plot_results(df, frontier_df, last_two_files, ACCURACY_THRESHOLD, COMPRESSION_THRESHOLD)
    fig.write_html("ParetoFrontier.html", include_plotlyjs='embed', full_html=True)