In [26]:
import matplotlib.pyplot as plt

from os import getcwd, path
import tarfile
from pandas import Series, DataFrame
import pandas as pd
import numpy as np


import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
pio.templates.default = "plotly_white"

import warnings
warnings.filterwarnings("ignore")

from typing import Set, List
from ast import literal_eval

In [27]:
# input data is read from ./data and plots are saved to ./plots,
# both resolved against the current working directory
data_dir, plot_dir = (path.join(getcwd(), sub_dir) for sub_dir in ("data", "plots"))

In [28]:
# discrete color palettes

# eleven "rgb(r, g, b)" colour strings, built from their channel triples
prism_light = [
    f"rgb({red}, {green}, {blue})"
    for red, green, blue in (
        (175, 163, 200),
        (142, 180, 203),
        (156, 211, 210),
        (135, 194, 170),
        (185, 215, 164),
        (246, 214, 132),
        (240, 190, 130),
        (230, 168, 159),
        (202, 154, 183),
        (183, 160, 184),
        (179, 179, 179),
    )
]

# extended palette: plotly's qualitative Prism colours followed by the prism_light colours
prism_ext = px.colors.qualitative.Prism + prism_light 

In [29]:
# utility functions

# literal evaluation of specific columns in a dataframe
def literal_eval_cols(data: DataFrame, cols: List[str]):
    """Parse string-encoded Python literals in the given columns, in place.

    Each string value in the listed columns is replaced by the object that
    ``ast.literal_eval`` produces for it. Values that are not strings are left
    untouched, so calling the function again on an already converted frame
    (e.g. when a notebook cell is re-run) is harmless.

    Parameters
    ----------
    data : DataFrame
        Frame to modify; the listed columns are overwritten.
    cols : List[str]
        Names of the columns to convert.

    Returns
    -------
    None

    Raises
    ------
    ValueError, SyntaxError
        Propagated from ``literal_eval`` if a string is not a valid literal.
    KeyError
        If a name in ``cols`` is not a column of ``data``.
    """
    for col_name in cols:
        # work on the single column instead of building a row Series for every
        # row of the whole frame (DataFrame.apply with axis=1)
        data[col_name] = data[col_name].map(
            lambda value: literal_eval(value) if isinstance(value, str) else value
        )
        
# position to integer representation
def to_int(position):
    """Encode an iterable of non-zero signed integers as a single integer.

    Every element ``s`` contributes ``3 ** (abs(s) - 1)``, counted once if
    ``s`` is positive and twice if it is negative; the contributions are
    summed. An empty ``position`` yields 0.
    """
    return sum(
        (2 if sentence < 0 else 1) * 3 ** (abs(sentence) - 1)
        for sentence in position
    )

# list of (set, set) pairs of global optima or fixed points -> tuple of (int, int) pairs
def sets_to_int_tuples(row, label):
    """Convert the pairs stored under ``row[label]`` to integer pairs.

    For every entry of ``row[label]`` the first two components are passed
    through ``to_int``. The resulting pairs are returned as a tuple in the
    same order as the input; no sorting or de-duplication takes place.
    """
    return tuple(
        (to_int(pair[0]), to_int(pair[1]))
        for pair in row[label]
    )

In [30]:
# data loading
#
# The data file is either a plain CSV or a gzipped tar archive containing one.
# An archive is unpacked into data_dir and the extracted file is read.

data_file_name = 'weightings_ensemble.csv.tar.gz'

# test the suffix directly so that file names with a dot in their stem are
# still recognised as archives
if data_file_name.endswith('.csv.tar.gz'):
    with tarfile.open(path.join(data_dir, data_file_name)) as tar:
        # only regular files can be read as CSV (skips directory entries)
        member_names = [member.name for member in tar.getmembers() if member.isfile()]
        if not member_names:
            raise ValueError(f"archive {data_file_name!r} contains no files")
        # NOTE(review): extractall trusts the member paths stored in the archive;
        # only use it on archives from a trusted source
        tar.extractall(data_dir)
    # as before, the last member of the archive is the file that gets read
    file_name = member_names[-1]
    re_data = pd.read_csv(path.join(data_dir, file_name))

else:
    re_data = pd.read_csv(path.join(data_dir, data_file_name))

print(re_data.columns)
re_data.shape

Index(['model_name', 'ds', 'n_sentence_pool', 'ds_infer_dens',
       'ds_n_consistent_complete_positions', 'account_penalties',
       'faithfulness_penalties', 'weight_account', 'weight_systematicity',
       'weight_faithfulness', 'init_coms', 'init_coms_dia_consistent',
       'fixed_point_coms', 'fixed_point_coms_consistent', 'fixed_point_theory',
       'fixed_point_dia_consistent', 'n_branches', 'fixed_points',
       'n_fixed_points', 'fp_coms_consistent', 'fp_union_consistent',
       'fixed_point_is_global_optimum', 'fixed_point_is_re_state',
       'fixed_point_is_full_re_state', 'global_optima', 'n_global_optima',
       'go_coms_consistent', 'go_union_consistent', 'go_full_re_state',
       'full_re_states', 'n_full_re_states', 'go_fixed_point',
       'fp_full_re_state', 'fp_global_optimum'],
      dtype='object')


(276000, 34)

In [31]:
# number of distinct values in the "ds" column
re_data["ds"].nunique()

10

In [32]:
# summary statistics of the number of distinct "init_coms" values per "ds"
re_data.groupby("ds")["init_coms"].nunique().describe()

count     10.0
mean     100.0
std        0.0
min      100.0
25%      100.0
50%      100.0
75%      100.0
max      100.0
Name: init_coms, dtype: float64

In [33]:
# configuration of weights to single column:
# one (account, systematicity, faithfulness) tuple per row.
# Zipping the three columns yields the same tuples as a row-wise apply
# without constructing a Series object for every row.
re_data["configuration"] = list(zip(re_data["weight_account"],
                                    re_data["weight_systematicity"],
                                    re_data["weight_faithfulness"]))

In [34]:
# number of distinct weight configurations
re_data["configuration"].nunique()

276

### Centroids of Full RE Fixed Points

Proposition: Regions of weight configurations that yield a global optimum are convex. Consequently, centroids of such regions are meaningful representatives.  

In [37]:
# fixed points from all branches
#
# The selection is copied explicitly: the following cells overwrite and add
# columns of df, and an independent frame makes sure these assignments neither
# touch re_data nor trigger pandas' SettingWithCopyWarning.
df = re_data[['ds',
              'init_coms',
              'weight_account',
              'weight_systematicity',
              'weight_faithfulness',
              'fixed_points',
              "fp_coms_consistent",
              'fp_union_consistent',
              'fp_global_optimum',
              'fp_full_re_state']].copy()

In [38]:
# convert strings to objects:
# the listed columns hold string representations of Python literals, which
# literal_eval_cols replaces by the evaluated objects (df is modified in place)
literal_eval_cols(df, ['fixed_points',
                       "fp_coms_consistent",
                       'fp_union_consistent',
                       'fp_global_optimum',
                       'fp_full_re_state'])

In [39]:
# convert fixed points to integer representations:
# sets_to_int_tuples is applied row by row to the "fixed_points" column and the
# resulting tuple of integer pairs is stored in a new column
df["fixed_points_int"] = df.apply(lambda row: sets_to_int_tuples(row, "fixed_points"), axis=1)

In [40]:
# explode fixed points and corresponding information:
# the identifying columns are moved into the index, every remaining column is
# exploded element-wise, and the index is turned back into columns.
# NOTE(review): assumes that all list-like columns of a row have the same
# length, so that the exploded columns line up -- TODO confirm
df_ex = df.set_index(["ds",
                      "init_coms", 
                      "weight_account", 
                      "weight_systematicity", 
                      "weight_faithfulness"]).apply(pd.Series.explode).reset_index()

In [41]:
# (rows, columns) of the exploded frame
df_ex.shape

(674297, 11)

In [42]:
# (rows, columns) of the rows whose "fp_full_re_state" entry is true
df_ex[df_ex["fp_full_re_state"]].shape

(77526, 11)

In [43]:
# keep only the rows whose "fp_full_re_state" entry is true
sub_df_ex = df_ex[df_ex["fp_full_re_state"]]

In [44]:
# centroid per fixed point: mean of the three weights over all rows that share
# the same "ds", "init_coms" and "fixed_points_int".
# The columns are selected with a list; selecting several groupby columns with
# a bare tuple of names is no longer accepted by current pandas versions.
centroid_df = (sub_df_ex
               .groupby(["ds", "init_coms", "fixed_points_int"])
               [["weight_account", "weight_systematicity", "weight_faithfulness"]]
               .mean()
               .reset_index())

In [47]:
file_name = "full_re_fp_centroids_large_ensemble"

# ternary scatter plot of the full RE fixed point centroids: one marker per row
# of centroid_df, positioned by its three mean weights
# (10 randomly generated structures and 100 random sets of initial commitments
# per structure)
fig = px.scatter_ternary(centroid_df,
                         a="weight_account",
                         b="weight_systematicity",
                         c="weight_faithfulness",
                         opacity=0.40,
                         height=615)

# relabel axes (identical tick and font settings for all three axes)
axis_style = {"ticks": "outside", "title_font_size": 16, "tickfont_size": 14}
fig.update_ternaries(aaxis={"title": r"$\alpha_{A}$", **axis_style},
                     baxis={"title": r"$\alpha_{S}$", **axis_style},
                     caxis={"title": r"$\alpha_{F}$", **axis_style})

fig.update_traces(marker_color="#29ae80",
                  marker_size=6,
                  opacity=0.6,
                  textposition='bottom center')

fig.update_layout(margin=dict(l=5, r=5, t=45, b=45))

fig.show()

# export once, after all styling has been applied; an earlier export to the
# same path would only be overwritten by this one
fig.write_image(path.join(plot_dir, file_name + ".png"), scale=2)

In [48]:
file_name = "full_re_fp_centroids_large_ensemble_selection"

# ternary scatter plot of the full RE fixed point centroids: one marker per row
# of centroid_df, positioned by its three mean weights
# (10 randomly generated structures and 100 random sets of initial commitments
# per structure)
fig = px.scatter_ternary(centroid_df,
                         a="weight_account",
                         b="weight_systematicity",
                         c="weight_faithfulness",
                         opacity=0.15,
                         height=615)

# relabel axes (identical tick and font settings for all three axes)
axis_style = {"ticks": "outside", "title_font_size": 16, "tickfont_size": 14}
fig.update_ternaries(aaxis={"title": r"$\alpha_{A}$", **axis_style},
                     baxis={"title": r"$\alpha_{S}$", **axis_style},
                     caxis={"title": r"$\alpha_{F}$", **axis_style})

# style the centroid trace before further traces are added, so that colour,
# size and opacity do not carry over to the selection markers and labels
fig.update_traces(marker_color="#29ae80", marker_size=6, opacity=0.6)

fig.update_layout(margin=dict(l=5, r=5, t=45, b=45))

# add selection of weightings as (account, systematicity, faithfulness)
selected_weightings = [(0.35, 0.55, 0.10),
                       (0.55, 0.35, 0.10),
                       (0.10, 0.55, 0.35),
                       (0.10, 0.35, 0.55),
                       (0.46, 0.10, 0.44),
                       (0.55, 0.20, 0.25),
                       (0.70, 0.20, 0.10)]

fig.add_trace(
    go.Scatterternary(a=[weights[0] for weights in selected_weightings],
                      b=[weights[1] for weights in selected_weightings],
                      c=[weights[2] for weights in selected_weightings],
                      mode='markers',
                      marker=dict(color='black', size=12, symbol="hexagon"),
                      showlegend=False))

# text: the labels are derived from the marker coordinates so that both cannot
# drift apart; the anchor points are hand-placed offsets next to the markers
fig.add_trace(
    go.Scatterternary(a=[0.34, 0.50,
                         0.07, 0.07,
                         0.43, 0.52, 0.85],
                      b=[0.55, 0.35,
                         0.55, 0.35,
                         0.10, 0.10, 0.16],
                      c=[0.17, 0.10,
                         0.30, 0.60,
                         0.44, 0.25, 0.14],
                      text=[f"({w_a:.2f}, {w_s:.2f}, {w_f:.2f})"
                            for w_a, w_s, w_f in selected_weightings],
                      mode='text',
                      textfont=dict(size=14),
                      showlegend=False))

fig.update_traces(textposition='bottom center')

fig.show()

# single export of the finished figure. file_name already carries the
# "_selection" suffix, so it is not appended a second time, and no export is
# made before the selection has been added.
fig.write_image(path.join(plot_dir, file_name + ".png"), scale=2)