# Imports and general definitions

In [1]:
import os
import ast

import json
import matplotlib.pyplot as plt
import math
import numpy as np
import pandas as pd
import seaborn as sns

# Plot settings

In [2]:
# Size passed as `markersize` to the line plots that mark fine-tuning points.
MARKERSIZE=10
# Font size applied to tick labels and axis labels below.
FONT_SIZE = 18
plt.rc('xtick', labelsize=FONT_SIZE)
plt.rc('ytick', labelsize=FONT_SIZE)
plt.rc('font', size=14)
plt.rc('axes', labelsize=FONT_SIZE)

# IPython magic: selects the interactive "widget" matplotlib backend.
%matplotlib widget

# Matplotlib marker per baseline family, looked up by get_label_color_marker.
# An empty string means the line is drawn without markers.
markers_dict = {
    'noFinetune': '',
    'finetune@1st': '',
    'periodic': '^',
    'exponential': '*',
    'sentence': 'v',
    'reactive': '.',
    'random': 'P',
    'flexico': 'D',
    'optimum': 's',
    'opt-1': 's',
    'opt-5': 's',
}

# Base palette from which the fixed, single-color baselines pick their color.
colors = sns.color_palette(n_colors=12, palette='gist_stern')

# Color per baseline family. Entries are either a single color or a
# 10-color palette; get_label_color_marker indexes the palettes with a value
# derived from the baseline's numeric parameter.
color_palette = {
    'noFinetune': colors[0],
    'finetune@1st': colors[2],
    'random': sns.color_palette(n_colors=10, palette='Oranges'),
    'periodic': sns.color_palette(n_colors=10, palette='Blues_r'),
    'exponential': sns.color_palette(n_colors=10, palette='Purples_r'),
    'sentence': sns.color_palette(n_colors=10, palette='RdPu_r'),
    'reactive': sns.color_palette(n_colors=10, palette='Reds'),
    'flexico': colors[8],
    'optimum': colors[10],
    'opt-1': colors[10],
    'opt-5': 'black',
}

In [None]:
# Print the hex codes of seaborn's 'cubehelix' palette.
print(sns.color_palette(palette='cubehelix').as_hex())

# Plotting functions

## get label and color

In [4]:
def get_label_color_marker(target, num_adaptations=None):
    """Return the legend label, color and marker for a baseline name.

    ``target`` is matched by substring against the known baseline families,
    in this order: no-finetune (``nop`` / ``noFinetune`` / ``no_finetune``),
    ``finetune@1st``, ``reactive``, ``exponential``, ``opt-5``, ``optimum``,
    ``flexico``, ``random``, ``sentence``.  Anything else is treated as a
    periodic baseline.  Parameterized families accept an optional
    ``-<int>`` suffix (e.g. ``'periodic-2'``) that appears in the label and
    selects a shade from that family's palette in ``color_palette``.

    Args:
        target: baseline identifier, e.g. ``'reactive-85'`` or ``'flexico'``.
        num_adaptations: when not ``None``, appended to the label in
            parentheses.

    Returns:
        A ``(label, color, marker)`` tuple.

    Raises:
        ValueError: if a parameterized family's suffix is not an integer.
    """

    def shade(family, index):
        # Clamp to the last shade so a large parameter (e.g. 'random-100')
        # does not index past the end of the family's palette.
        palette = color_palette[family]
        return palette[min(int(index), len(palette) - 1)]

    color = None
    marker = None
    # Dictionary key used for the fallback lookups at the end. Families that
    # are matched through an alias ('nop', 'no_finetune', ...) are mapped to
    # the key that actually exists in color_palette / markers_dict.
    key = target

    if 'nop' in target or 'noFinetune' in target or 'no_finetune' in target:
        label = "No finetune"
        key = 'noFinetune'
    elif 'finetune@1st' in target:
        label = "Finetune@1st"
        key = 'finetune@1st'
    elif 'reactive' in target:
        # Use the reactive family's own marker (this branch previously
        # borrowed the 'random' marker).
        marker = markers_dict['reactive']
        if "-" in target:
            react_threshold = int(target.split("-")[1])
            color = shade('reactive', react_threshold % 10)
            label = f'Reactive-{react_threshold}'
        else:
            color = color_palette['reactive'][5]
            label = 'Reactive-85'
    elif 'exponential' in target:
        marker = markers_dict['exponential']
        if "-" in target:
            base = int(target.split("-")[1])
            color = shade('exponential', base)
            label = f'Exponential-{base}'
        else:
            color = color_palette['exponential'][2]
            label = 'Exponential-2'
    elif 'opt-5' in target:
        label = 'Optimum '
        key = 'opt-5'
    elif 'optimum' in target:
        label = 'Perfect FIP'
        key = 'optimum'
    elif "flexico" in target:
        label = 'Flexico'
        key = 'flexico'
    elif 'random' in target:
        marker = markers_dict['random']
        if "-" in target:
            prob = int(target.split("-")[1])
            color = shade('random', int(prob / 10))
            label = f'Random-{prob}'
        else:
            color = color_palette['random'][5]
            label = 'Random-50'
    elif 'sentence' in target:
        marker = markers_dict['sentence']
        if "-" in target:
            num_sents = int(target.split("-")[1])
            color = shade('sentence', int(num_sents / 1000))
            label = f'Sentence-{num_sents}'
        else:
            color = color_palette['sentence'][0]
            label = 'Sentence-1000'
    else:
        # NOTE(review): 'opt-1' has entries in markers_dict/color_palette but
        # no branch above, so it ends up here and is labelled 'Periodic-1'.
        # TODO confirm the intended label before adding a dedicated branch.
        marker = markers_dict['periodic']
        if "-" in target:
            period = int(target.split("-")[1])
            color = shade('periodic', period - 1)
            label = f'Periodic-{period}'
        else:
            color = color_palette['periodic'][0]
            label = 'Periodic-1'

    if num_adaptations is not None:
        label = label + f" ({num_adaptations})"

    # Single-color families did not set color/marker above.
    if color is None:
        color = color_palette[key]
    if marker is None:
        marker = markers_dict[key]

    return label, color, marker

## Heatmap

In [5]:
def plot_heatmap(
    finetune_costs: list,
    delta_thresholds: list,
    data: pd.DataFrame,
    targets: list,
    normalize: bool = False,
):
    """Draw one annotated cost heatmap per baseline on a two-column grid.

    Each panel shows fine-tune cost on the y axis (inverted so values grow
    upwards) and delta threshold on the x axis.

    Args:
        finetune_costs: ``finetuneCost`` values to include.
        delta_thresholds: ``deltaT`` values to include.
        data: results with columns ``baseline``, ``deltaT``, ``finetuneCost``
            and the plotted value column.
        targets: baselines to plot, one panel each, in this order.
        normalize: plot ``normalized_totalCost`` instead of ``totalCost``.
    """
    cmap = 'magma_r'
    fig_size = 10
    n_cols = 2

    # Round the row count up so that an odd number of targets (or a single
    # target) still gets a panel each; the leftover panel is hidden below.
    n_rows = max(1, math.ceil(len(targets) / n_cols))
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(fig_size, fig_size), squeeze=False,
    )
    # squeeze=False always yields a 2-D array, so one flatten covers all cases.
    axes = list(axs.ravel())

    if normalize:
        values = 'normalized_totalCost'
        # Ratios to the optimum are mostly small numbers; keep two decimals
        # so the annotations do not all collapse to the same integer.
        fmt = '.2f'
    else:
        values = 'totalCost'  # 'totalCost-avg'
        fmt = '.0f'

    for baseline, ax in zip(targets, axes):
        print(baseline)
        plot_data = data.loc[
            (data.baseline == baseline)
            & (data.deltaT.isin(delta_thresholds))
            & (data.finetuneCost.isin(finetune_costs))
        ].copy()

        sns.heatmap(
            plot_data.pivot(index='finetuneCost', columns='deltaT', values=values),
            fmt=fmt,
            ax=ax,
            annot=True, annot_kws={'rotation': 30},
            cmap=cmap, cbar=False, cbar_kws={"shrink": .5},
        )
        ax.invert_yaxis()
        ax.set_title(f'{baseline}')
        ax.set_ylabel("Finetune cost")
        ax.set_xlabel("Delta threshold")

    # Hide grid cells that have no baseline assigned.
    for spare_ax in axes[len(targets):]:
        spare_ax.set_visible(False)

    plt.tight_layout()


## Bar plot

In [27]:
def _compute_week_day_avg(row):
    avg = 0
    week_adaptations = ast.literal_eval(row['adaptationsWeekDay'])
    if not isinstance(week_adaptations[0], list):
        week_adaptations = [week_adaptations]

    total_adaptations = ast.literal_eval(row['totalAdaptations'])
    if not isinstance(total_adaptations, list):
        total_adaptations = [total_adaptations]

    for week, total in zip(week_adaptations, total_adaptations):
        if int(total):
            avg += float(sum(week) / int(total))

    return avg / len(week_adaptations)
    
def plot_bar(
    finetune_costs: list,
    delta_thresholds: list,
    data: pd.DataFrame,
    targets: list,
    y_col: str = 'totalCost',
):
    """Draw and save one grouped bar chart per fine-tune cost.

    Bars are grouped by delta threshold and colored by baseline.  Each chart
    is written to a PNG in the current working directory.

    Args:
        finetune_costs: ``finetuneCost`` values to plot, one figure each.
            The value ``-1`` disables the fine-tune cost filter for that
            figure.
        delta_thresholds: ``deltaT`` values to include.
        data: results with columns ``baseline``, ``deltaT``,
            ``finetuneCost`` and ``y_col``.  The frame is not modified.
        targets: baselines to plot; also fixes the bar and legend order.
        y_col: column on the y axis.  If its name contains ``'WeekDay'``,
            the plotted value is computed per row as
            ``sum(adaptationsWeekDay) / totalAdaptations`` (0 when there
            are no adaptations).
    """
    labels = []
    bar_colors = []
    for target in targets:
        label, color, _ = get_label_color_marker(target, None)
        labels.append(label)
        bar_colors.append(color)

    if 'WeekDay' in y_col:
        # Work on a copy: the callers in this notebook pass a `.loc` slice,
        # and plotting should not add a column to the caller's frame.
        data = data.copy()
        data['avg-weekDay'] = [
            float(sum(ast.literal_eval(week_days)) / total) if total else 0.0
            for week_days, total in zip(
                data['adaptationsWeekDay'], data['totalAdaptations']
            )
        ]
        y_col = 'avg-weekDay'

    # Part of the row filter that does not depend on the fine-tune cost.
    in_scope = data.baseline.isin(targets) & data.deltaT.isin(delta_thresholds)

    for fc in finetune_costs:
        # fc == -1 means "do not filter on the fine-tune cost".
        selected = in_scope if fc == -1 else in_scope & (data.finetuneCost == fc)
        plot_data = data.loc[selected].copy()

        if plot_data.empty:
            # Nothing to draw for this cost; do not create or save a figure.
            print(f"plot_bar: no rows for finetuneCost={fc}, skipping")
            continue

        fig, ax = plt.subplots(1, 1, figsize=(8, 5))

        sns.barplot(
            data=plot_data,
            x='deltaT',
            y=y_col,
            hue='baseline',
            ax=ax,
            hue_order=targets,
            palette=bar_colors,
        )

        if 'totalCost' in y_col:
            ax.set_ylabel("Total cost")
        elif 'totalAdaptations' in y_col:
            ax.set_ylabel("#fine-tunings")
        elif 'weekDay' in y_col:
            ax.set_ylabel("Avg adaptation week day")
        ax.set_xlabel("Delta threshold")

        legend = ax.legend(
            loc='upper right',
            bbox_to_anchor=(1.0, 1.35),
            ncol=3,
            frameon=False,
        )

        # The legend is built from the raw baseline ids (in `targets` order);
        # replace them with the display labels.
        for text, label in zip(legend.get_texts(), labels):
            text.set_text(label)

        fig.tight_layout()
        if 'totalCost' in y_col:
            fig_name = f"hk_news-scenarioB-sys_u-fc_{fc}-{len(delta_thresholds)}_deltas.png"
        else:
            fig_name = f"hk_news-scenarioB-{y_col}_u-fc_{fc}-{len(delta_thresholds)}_deltas.png"

        fig.savefig(fig_name, format='png', dpi=1200, bbox_inches='tight', pad_inches=0.01)


## Line Plot

In [7]:
def plot_res(
    finetune_cost: int,
    delta_threshold: float,
    data: pd.DataFrame,
    targets: list,
    y_axis: str = 'cost',
    title = None,
):
    """Plot one time series per baseline and mark its fine-tuning points.

    For every target, the first row of ``data`` matching ``delta_threshold``
    and the baseline is used.  The ``y_axis`` column of that row is parsed as
    a Python list literal and plotted against its index; series whose column
    name contains ``'cost'`` are plotted cumulatively.

    Args:
        finetune_cost: kept for interface compatibility; currently unused,
            because rows are not filtered by ``finetuneCost``.
        delta_threshold: ``deltaT`` value to select.
        data: results with columns ``deltaT``, ``baseline``, ``adaptations``
            and ``y_axis``.
        targets: baselines to plot.
        y_axis: name of the column holding the series.
        title: optional axes title.

    Raises:
        IndexError: if a target has no row for ``delta_threshold``.
    """
    all_results = data.loc[data.deltaT == delta_threshold]

    fig, ax = plt.subplots()
    for target in targets:
        results = all_results.loc[all_results.baseline == target]
        if results.empty:
            raise IndexError(
                f"no rows for baseline {target!r} with deltaT == {delta_threshold}"
            )

        y = ast.literal_eval(results[y_axis].to_numpy()[0])

        print(f"{target}: {len(y)}")

        # `adaptations` is summed to obtain the number of fine-tunings, so it
        # is treated as one truthy/falsy flag per time step.
        finetunings = ast.literal_eval(results['adaptations'].to_numpy()[0])

        label, color, marker = get_label_color_marker(target, sum(finetunings))

        # No markers for the no-finetune baseline. (The previous test chained
        # `not in` checks with `or`, which is true for every target.)
        is_no_finetune = any(
            tag in target for tag in ('nop', 'noFinetune', 'no_finetune')
        )
        if is_no_finetune:
            marker_on = []
        else:
            # Pass explicit positions: matplotlib reads a list of ints as
            # indices, so raw 0/1 flags would only mark points 0 and 1.
            marker_on = [
                step for step, flag in enumerate(finetunings)
                if flag and step < len(y)
            ]

        x = list(range(len(y)))
        if 'cost' in y_axis:
            y = np.cumsum(y)
        ax.plot(
            x, y, label=label, color=color, linestyle='solid',
            markevery=marker_on, markersize=MARKERSIZE, marker=marker,
        )

    ax.set(xlabel='Time', ylabel=y_axis)
    ax.legend(
        loc='best',
        ncol=1,
    )
    ax.grid()
    if title is not None:
        ax.set_title(title)

## FIP accuracy

In [8]:
def scatter_plot(real: np.ndarray, pred: np.ndarray, title: str):
    """Scatter predicted against real values with a y = x reference line.

    Both axes are forced to the same range and an equal aspect ratio, so
    points on the dotted diagonal are exact predictions.

    Args:
        real: values for the x axis.
        pred: values for the y axis.
        title: axes title.
    """
    figure, axis = plt.subplots()
    axis.scatter(real, pred, label="linear-reg")

    # Common range spanning both automatically chosen axis ranges.
    auto_ranges = [axis.get_xlim(), axis.get_ylim()]
    shared_range = [np.min(auto_ranges), np.max(auto_ranges)]

    # Diagonal on which predicted == real.
    axis.plot(
        shared_range, shared_range, ':',
        alpha=0.75, zorder=0, label="optimum", color="black",
    )
    axis.set_aspect('equal')
    axis.set_xlim(shared_range)
    axis.set_ylim(shared_range)

    axis.set(xlabel='Real', ylabel='Predicted')
    axis.legend(loc='best')
    axis.grid()
    axis.set_title(title)

    plt.show()
    

# Load data

In [9]:
# Two directory levels above the current working directory.
BASE_DIR = os.getcwd() + "/../../"
# Directory the result CSVs are read from.
RESULTS_DIR = f"{BASE_DIR}src/framework/framework_results/"

In [10]:
# Result tables, one CSV per experiment family, all read from RESULTS_DIR.
baselines = pd.read_csv(
    f"{RESULTS_DIR}adaptiveMT-topicWeights-dataset_hk-news-prism_False.csv"
)

flexico = pd.read_csv(
    f"{RESULTS_DIR}adaptiveMT-flexico-dataset_hk-news-prism_True-topicWeights.csv"
)

opt = pd.read_csv(
    f"{RESULTS_DIR}optimum_baseline-dataset_hk-news-default_nop-topicWeights_real_delayed.csv"
)

In [11]:
# Combine the baseline runs with the optimum runs into a single table.
# Only rows whose fipFeatureSet is 'all' or 'None' are kept. Recent pandas
# versions (2.0+) parse the text 'None' as a missing value when reading a
# CSV, so missing values are matched too; otherwise the 'None' rows would be
# dropped silently.
# NOTE(review): this also keeps rows whose fipFeatureSet cell was empty in
# the CSV -- confirm that is acceptable for this dataset.
res = pd.concat(
    [
        baselines.loc[
            baselines.fipFeatureSet.isin(['all', 'None'])
            | baselines.fipFeatureSet.isna()
        ],
        opt,
    ],
    ignore_index=True,
)

In [None]:
# Show the column names of the combined results table.
res.columns

In [None]:
# Show selected configuration and outcome columns for 'flexico' and
# 'periodic-2' at finetuneCost == 1 and five delta thresholds.
res.loc[
    (res.baseline.isin(['flexico', 'periodic-2']))
    & (res.finetuneCost == 1)
    & (res.deltaT.isin([0.2, 0.4, 0.6, 0.8, 1.0]))
][[
    'runID', 'baseline', 'finetuneCostType', 'finetuneCost', 'deltaT',
       'fipTargetMetrics', 'fipModel', 'fidType', 'fipFeatureSet', 'totalAdaptations', 'totalCost',
]]

# Normalize totalCost to optimum

In [None]:
# List the distinct baseline names present in the combined results.
res.baseline.unique()

In [15]:
# Function to normalize totalCost of every baseline by a reference baseline
def normalize_total_cost(
    df: pd.DataFrame,
    reference: str = 'opt-5',
):
    """Express each run's ``totalCost`` relative to a reference baseline.

    For every distinct ``(finetuneCost, deltaT)`` pair, the ``totalCost`` of
    each row (including the reference's own rows) is divided by the
    ``totalCost`` of the first ``reference`` row for that pair.  Pairs
    without a ``reference`` row keep their raw ``totalCost``.

    Args:
        df: results with columns ``finetuneCost``, ``deltaT``, ``baseline``
            and ``totalCost``.
        reference: name of the baseline whose cost is the divisor.

    Returns:
        A new DataFrame with columns ``finetuneCost``, ``deltaT``,
        ``baseline`` and ``normalized_totalCost``.  The columns are present
        even when there are no rows.
    """
    columns = ['finetuneCost', 'deltaT', 'baseline', 'normalized_totalCost']

    unique_combinations = df[['finetuneCost', 'deltaT']].drop_duplicates()
    baseline_names = df['baseline'].unique()

    normalized_data = []
    for _, combination in unique_combinations.iterrows():
        fc_value = combination['finetuneCost']
        dt_value = combination['deltaT']

        subset = df[(df['finetuneCost'] == fc_value) & (df['deltaT'] == dt_value)]

        reference_costs = subset.loc[
            subset['baseline'] == reference, 'totalCost'
        ].values
        # None means "no reference run for this pair": costs are kept as-is.
        # NOTE: a reference cost of 0 is not special-cased and produces
        # inf/nan through the division below.
        divisor = reference_costs[0] if len(reference_costs) > 0 else None

        for baseline in baseline_names:
            costs = subset.loc[subset['baseline'] == baseline, 'totalCost'].values
            if divisor is not None:
                costs = costs / divisor

            normalized_data.extend(
                {
                    'finetuneCost': fc_value,
                    'deltaT': dt_value,
                    'baseline': baseline,
                    'normalized_totalCost': cost,
                }
                for cost in costs
            )

    # Passing the columns explicitly keeps the schema stable for empty input.
    return pd.DataFrame(normalized_data, columns=columns)


In [16]:
# Normalized costs, computed only on the runs with 'delayed' topic weights.
normalized_df = normalize_total_cost(res[res['topicWeightType'] == 'delayed'])

# Merging the normalized values back into `res` is currently disabled:
# res = res.merge(normalized_df, on=['finetuneCost', 'deltaT', 'baseline'], how='left')

# PRISM latency analysis

In [17]:
# Per-run summary statistics (mean, 95th and 99th percentile) of the latency
# lists. Each list is stored as the text of a Python literal, so it is parsed
# once per source column and reused for the three derived columns.
for _source_col, _prefix in (
    ('formalVerificationLatencies', 'prismLatency'),
    ('mapeLatencies', 'mapeLatencies'),
):
    _samples = [ast.literal_eval(raw) for raw in flexico[_source_col]]
    flexico[f'{_prefix}-avg'] = [np.mean(run) for run in _samples]
    flexico[f'{_prefix}-95th'] = [np.percentile(run, 95) for run in _samples]
    flexico[f'{_prefix}-99th'] = [np.percentile(run, 99) for run in _samples]

In [None]:
# Summary statistics of the latency columns, restricted to the runs with
# 'delayed' topic weights.
flexico.loc[
    flexico['topicWeightType'] == 'delayed'
][[
    'prismLatency-avg',
    'prismLatency-95th',
    'prismLatency-99th',
    'mapeLatencies-avg',
    'mapeLatencies-95th',
    'mapeLatencies-99th',
]].describe()

In [None]:
# Summary statistics of the latency columns over all flexico runs.
flexico[[
    'prismLatency-avg',
    'prismLatency-95th',
    'prismLatency-99th',
    'mapeLatencies-avg',
    'mapeLatencies-95th',
    'mapeLatencies-99th',
]].describe()

# Plot results

## Baselines

In [None]:
# List the distinct baseline names present in the baselines CSV.
baselines.baseline.unique()

## Line plots

In [None]:
# Cumulative cost over time for the selected baselines at deltaT == 0.5,
# using the 'delayed' topic-weight runs.
# NOTE: plot_res does not filter on finetune_cost (that filter is commented
# out in its body), so the value passed here has no effect.
plot_res(
    finetune_cost=10,
    delta_threshold=0.5,
    data=res.loc[res['topicWeightType'] == 'delayed'], 
    targets=[
        'periodic-2',
        'reactive-85',
#         'random',
        'exponential',
        'flexico',
        'opt-5',
    ],
    title = None
)

In [None]:
# Same selection as the cost plot, but plotting the 'adaptationsWeekDay'
# series. The column name does not contain 'cost', so plot_res draws the
# raw values rather than a cumulative sum.
plot_res(
    finetune_cost=10,
    delta_threshold=0.5,
    data=res.loc[res['topicWeightType'] == 'delayed'], 
    targets=[
        'periodic-2',
        'reactive-85',
#         'random',
        'exponential',
        'flexico',
        'opt-5',
    ],
    title = None,
    y_axis = 'adaptationsWeekDay',
)

## Heatmaps

In [None]:
# Heatmaps of the normalized total cost, one panel per baseline.
plot_heatmap(
    finetune_costs = [1, 5, 10, 15, 20, 25],
    # Build 0.1 .. 1.0 from rounded values: np.arange(0.1, 1.1, 0.1) yields
    # 0.30000000000000004 and 0.7000000000000001, which the exact `isin`
    # match inside plot_heatmap would not find among 0.3 / 0.7.
    delta_thresholds = [round(0.1 * step, 1) for step in range(1, 11)],
    data=normalized_df,
    targets=[
        'periodic-2',
#         'reactive-85',
#         'random',
#         'random-75',
        'exponential',
        'sentence',
        'sentence-2000',
        'flexico',
        'opt-5',
    ],
    normalize=True,
)

## Bar plots

In [None]:
# Total-cost bar charts (one figure per fine-tune cost) on a coarse grid of
# five delta thresholds, using the 'delayed' topic-weight runs.
plot_bar(
    finetune_costs = [1, 5, 10, 15, 20, 25], 
    # delta_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    delta_thresholds = [0.2, 0.4, 0.6, 0.8, 1.0],
    data=res.loc[res['topicWeightType'] == 'delayed'], 
    targets=[
        'random',
        'random-75',
        'reactive-85',
        'periodic-2',
        'exponential',
        'sentence',
        'sentence-2000',
        'flexico',
        'opt-5',
    ],
)

In [None]:
# Same bar charts on the full grid of ten delta thresholds. The output file
# names include the number of thresholds, so these files do not overwrite
# the ones produced with the five-threshold grid.
plot_bar(
    finetune_costs = [1, 5, 10, 15, 20, 25], 
    delta_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    #delta_thresholds = [0.2, 0.4, 0.6, 0.8, 1.0],
    data=res.loc[res['topicWeightType'] == 'delayed'], 
    targets=[
        'random',
        'random-75',
        'reactive-85',
        'periodic-2',
        'exponential',
        'sentence',
        'sentence-2000',
        'flexico',
        'opt-5',
    ],
)