51 changes: 36 additions & 15 deletions thicket/ensemble.py
@@ -4,6 +4,7 @@
# SPDX-License-Identifier: MIT

from collections import OrderedDict
import warnings

from hatchet import GraphFrame
import numpy as np
@@ -360,11 +361,11 @@ def _handle_statsframe():
return combined_th

@staticmethod
def _index(thickets, from_statsframes=False):
def _index(thickets):
"""Unify a list of thickets into a single thicket

Arguments:
from_statsframes (bool): Whether this method was invoked from from_statsframes
thickets (list): list of Thicket objects

Returns:
unify_graph (hatchet.Graph): unified graph,
@@ -376,17 +377,37 @@ def _index(thickets, from_statsframes=False):
unify_profile_mapping (dict): profile mapping
"""

def _fill_perfdata(perfdata, fill_value=np.nan):
# Fill missing rows in dataframe with NaN's
perfdata = perfdata.reindex(
pd.MultiIndex.from_product(perfdata.index.levels), fill_value=fill_value
)
# Replace "NaN" with "None" in columns of string type
for col in perfdata.columns:
if pd.api.types.is_string_dtype(perfdata[col].dtype):
perfdata[col].replace({fill_value: None}, inplace=True)

return perfdata
def _fill_perfdata(df, numerical_fill_value=np.nan):
"""Create a full index for the DataFrame and fill the created rows with NaNs (or Nones for string columns).

Arguments:
df (DataFrame): DataFrame to fill missing rows in
numerical_fill_value (any): value to fill numerical rows with

Returns:
(DataFrame): filled DataFrame
"""
try:
# Fill missing rows in dataframe with NaNs
df = df.reindex(
pd.MultiIndex.from_product(df.index.levels),
fill_value=numerical_fill_value,
)
# Replace "NaN" with "None" in columns of string type
for col in df.columns:
if pd.api.types.is_string_dtype(df[col].dtype):
df[col].replace({numerical_fill_value: None}, inplace=True)
except ValueError as e:
estr = str(e)
if estr == "cannot handle a non-unique multi-index!":
warnings.warn(
"Non-unique multi-index for DataFrame in _fill_perfdata. Cannot fill missing rows.",
RuntimeWarning,
)
else:
raise

return df

# Add missing indices to thickets
helpers._resolve_missing_indicies(thickets)
@@ -421,6 +442,9 @@ def _fill_perfdata(perfdata, fill_value=np.nan):
# Sort by keys
unify_profile_mapping = OrderedDict(sorted(unify_profile_mapping.items()))

# Validate unify_df before next operation
validate_dataframe(unify_df)

# Insert missing rows in dataframe
unify_df = _fill_perfdata(unify_df)

@@ -433,9 +457,6 @@ def _fill_perfdata(perfdata, fill_value=np.nan):
unify_inc_metrics = list(set(unify_inc_metrics))
unify_exc_metrics = list(set(unify_exc_metrics))

# Validate unify_df
validate_dataframe(unify_df)

unify_parts = (
unify_graph,
unify_df,
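For context, the core of the new `_fill_perfdata` is the reindex-onto-the-full-level-product trick. A minimal pandas-only sketch (toy data, not part of the PR) shows what it does, and why a non-unique multi-index makes the same `reindex` call raise the ValueError that the new try/except guards against:

```python
import numpy as np
import pandas as pd

# Toy performance data: profile 1 has no row for node "B", so the full
# (node, profile) product contains one combination that is missing here.
df = pd.DataFrame(
    {"time": [1.0, 2.0, 3.0], "name": ["A", "B", "A"]},
    index=pd.MultiIndex.from_tuples(
        [("A", 0), ("B", 0), ("A", 1)], names=["node", "profile"]
    ),
)

# Same pattern as _fill_perfdata: reindex onto the cartesian product of the
# index levels; rows created this way are filled with NaN.
filled = df.reindex(pd.MultiIndex.from_product(df.index.levels), fill_value=np.nan)

# String-typed columns get None instead of NaN.
for col in filled.columns:
    if pd.api.types.is_string_dtype(filled[col].dtype):
        filled[col] = filled[col].replace({np.nan: None})

print(filled)  # ("B", 1) now exists, with time=NaN and name=None
```

If `df.index` contained duplicate (node, profile) pairs, `reindex` would instead raise "cannot handle a non-unique multi-index!", which is the case the new warning reports.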
13 changes: 9 additions & 4 deletions thicket/tests/test_concat_thickets.py
@@ -7,12 +7,14 @@

import hatchet as ht
import pandas as pd
import pytest

from test_filter_metadata import filter_one_column
from test_filter_metadata import filter_multiple_and
from test_filter_stats import check_filter_stats
from test_query import check_query
from thicket import Thicket
from thicket.utils import DuplicateIndexError


def test_concat_thickets_index(mpi_scaling_cali):
@@ -22,15 +24,18 @@ def test_concat_thickets_index(mpi_scaling_cali):
tk = Thicket.concat_thickets([th_27, th_64])

# Check dataframe shape
tk.dataframe.shape == (90, 7)

# Check that the two Thickets are equivalent
assert tk
assert tk.dataframe.shape == (90, 7)

# Check specific values. Row order can vary so use "sum" to check
node = tk.dataframe.index.get_level_values("node")[8]
assert sum(tk.dataframe.loc[node, "Min time/rank"]) == 0.000453

# Check error thrown
with pytest.raises(
DuplicateIndexError,
):
Thicket.from_caliperreader([mpi_scaling_cali[0], mpi_scaling_cali[0]])


def test_concat_thickets_columns(thicket_axis_columns):
thickets, thickets_cp, combined_th = thicket_axis_columns
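The new `pytest.raises(DuplicateIndexError)` case covers reading the same profile twice. The check itself lives in `thicket.utils` and is not shown in this diff; a hypothetical pandas-level version of the idea looks like this:

```python
import pandas as pd

# Reading the same .cali file twice would yield two identical (node, profile)
# rows, i.e. a non-unique performance-data index.
idx = pd.MultiIndex.from_tuples(
    [("main", 0), ("main", 0), ("loop", 0)], names=["node", "profile"]
)
df = pd.DataFrame({"time": [1.0, 1.0, 2.0]}, index=idx)

# Illustrative check only -- the real validation (validate_dataframe /
# DuplicateIndexError) is defined in thicket.utils, outside this diff.
dupes = df.index[df.index.duplicated()].tolist()
if dupes:
    raise ValueError(f"duplicate index entries: {dupes}")
```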
36 changes: 31 additions & 5 deletions thicket/tests/test_from_statsframes.py
@@ -3,21 +3,24 @@
#
# SPDX-License-Identifier: MIT

from thicket import Thicket as th
import pytest

import thicket as th
from thicket.utils import DuplicateValueError

def test_from_statsframes(mpi_scaling_cali):

def test_single_trial(mpi_scaling_cali):
th_list = []
for file in mpi_scaling_cali:
th_list.append(th.from_caliperreader(file))
th_list.append(th.Thicket.from_caliperreader(file))

# Add arbitrary value to aggregated statistics table
t_val = 0
for t in th_list:
t.statsframe.dataframe["test"] = t_val
t_val += 2

tk = th.from_statsframes(th_list)
tk = th.Thicket.from_statsframes(th_list)

# Check level values
assert set(tk.dataframe.index.get_level_values("profile")) == {
@@ -30,7 +33,7 @@ def test_from_statsframes(mpi_scaling_cali):
# Check performance data table values
assert set(tk.dataframe["test"]) == {0, 2, 4, 6, 8}

tk_named = th.from_statsframes(th_list, metadata_key="mpi.world.size")
tk_named = th.Thicket.from_statsframes(th_list, metadata_key="mpi.world.size")

# Check level values
assert set(tk_named.dataframe.index.get_level_values("mpi.world.size")) == {
Expand All @@ -42,3 +45,26 @@ def test_from_statsframes(mpi_scaling_cali):
}
# Check performance data table values
assert set(tk_named.dataframe["test"]) == {0, 2, 4, 6, 8}


def test_multi_trial(rajaperf_cali_alltrials):
tk = th.Thicket.from_caliperreader(rajaperf_cali_alltrials)

# Simulate multiple trials by grouping on tuning.
gb = tk.groupby("tuning")

# Arbitrary data in statsframe.
for _, ttk in gb.items():
ttk.statsframe.dataframe["mean"] = 1

stk = th.Thicket.from_statsframes(list(gb.values()), metadata_key="tuning")

# Check error thrown for simulated multi-trial
with pytest.raises(
DuplicateValueError,
):
th.Thicket.from_statsframes(
[list(gb.values())[0], list(gb.values())[0]], metadata_key="tuning"
)

assert stk.dataframe.shape == (222, 2)
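`test_multi_trial` expects `DuplicateValueError` when two thickets share the same `metadata_key` value. The reason is visible in plain pandas: identical key values collapse to the same second index level, leaving the concatenated table with a non-unique (node, key) index, which `_fill_perfdata` can no longer fill. A small sketch with illustrative names:

```python
import pandas as pd

# Two statsframes whose thickets carry the same "tuning" metadata value.
a = pd.DataFrame({"mean": [1.0]}, index=pd.Index(["main"], name="node"))
b = pd.DataFrame({"mean": [2.0]}, index=pd.Index(["main"], name="node"))

for df in (a, b):
    df["tuning"] = "block_128"              # duplicate metadata value
    df.set_index("tuning", append=True, inplace=True)

combined = pd.concat([a, b])
print(combined.index.is_unique)  # False: ("main", "block_128") appears twice
```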
82 changes: 37 additions & 45 deletions thicket/thicket.py
@@ -19,7 +19,7 @@
from thicket.ensemble import Ensemble
import thicket.helpers as helpers
from .groupby import GroupBy
from .utils import verify_thicket_structures
from .utils import verify_thicket_structures, check_duplicate_metadata_key
from .external.console import ThicketRenderer


@@ -329,9 +329,6 @@ def concat_thickets(thickets, axis="index", calltree="union", **kwargs):
axis (str): axis to concatenate on -> "index" or "column"
calltree (str): calltree to use -> "union" or "intersection"

Keyword Arguments:
from_statsframes (bool): (if axis="index") Whether this method was invoked from from_statsframes

Keyword Arguments:
headers (list): (if axis="columns") List of headers to use for the new columnar multi-index
metadata_key (str): (if axis="columns") Name of the column from the metadata tables to replace the 'profile'
@@ -342,10 +339,8 @@ def concat_thickets(thickets, axis="index", calltree="union", **kwargs):
(thicket): concatenated thicket
"""

def _index(thickets, from_statsframes=False):
thicket_parts = Ensemble._index(
thickets=thickets, from_statsframes=from_statsframes
)
def _index(thickets):
thicket_parts = Ensemble._index(thickets=thickets)

return Thicket(
graph=thicket_parts[0],
@@ -653,14 +648,14 @@ def tree(
)

@staticmethod
def from_statsframes(th_list, metadata_key=None):
def from_statsframes(tk_list, metadata_key=None):
"""Compose a list of Thickets with data in their statsframes.

The Thickets' individual aggregated statistics tables are ensembled and become the
new Thickets performance data table.
new Thicket's performance data table. This also results in aggregation of the metadata.

Arguments:
th_list (list): list of thickets
tk_list (list): list of thickets
metadata_key (str, optional): name of the metadata column to use as
the new second-level index. Uses the first value so this only makes
sense if provided column is all equal values and each thicket's columns
@@ -670,59 +665,56 @@ def from_statsframes(th_list, metadata_key=None):
(thicket): New Thicket object.
"""
# Pre-check of data structures
for th in th_list:
for tk in tk_list:
verify_thicket_structures(
th.dataframe, index=["node", "profile"]
tk.dataframe, index=["node", "profile"]
) # Required for deepcopy operation
verify_thicket_structures(
th.statsframe.dataframe, index=["node"]
tk.statsframe.dataframe, index=["node"]
) # Required for deepcopy operation

# Setup names list
th_names = []
tk_names = []
if metadata_key is None:
for i in range(len(th_list)):
th_names.append(i)
idx_name = "profile" # Set index name to general "profile"
for i in range(len(tk_list)):
tk_names.append(i)
else: # metadata_key was provided.
for th in th_list:
check_duplicate_metadata_key(tk_list, metadata_key)
idx_name = metadata_key # Set index name to metadata_key
for tk in tk_list:
# Get name from metadata table
name_list = th.metadata[metadata_key].tolist()
name_list = tk.metadata[metadata_key].tolist()

if len(name_list) > 1:
if len(set(name_list)) > 1:
warnings.warn(
f"Multiple values for name {name_list} at thicket.metadata[{metadata_key}]. Only the first will be used."
f"Multiple values for name {name_list} at thicket.metadata['{metadata_key}']. Only the first value will be used for the new DataFrame index."
)
th_names.append(name_list[0])
tk_names.append(name_list[0])

th_copy_list = []
for i in range(len(th_list)):
th_copy = th_list[i].deepcopy()
tk_copy_list = []
for i in range(len(tk_list)):
tk_copy = tk_list[i].deepcopy()

th_id = th_names[i]

if metadata_key is None:
idx_name = "profile"
else:
idx_name = metadata_key
tk_id = tk_names[i]

# Modify graph
# Necessary so node ids match up
th_copy.graph = th_copy.statsframe.graph
# Modify graph. Necessary so node ids match up
tk_copy.graph = tk_copy.statsframe.graph

# Modify the performance data table
df = th_copy.statsframe.dataframe
df[idx_name] = th_id
df = tk_copy.statsframe.dataframe
df[idx_name] = tk_id
df.set_index(idx_name, inplace=True, append=True)
th_copy.dataframe = df
tk_copy.dataframe = df

# Adjust profile and profile_mapping
th_copy.profile = [th_id]
profile_paths = list(th_copy.profile_mapping.values())
th_copy.profile_mapping = OrderedDict({th_id: profile_paths})
tk_copy.profile = [tk_id]
profile_paths = list(tk_copy.profile_mapping.values())
tk_copy.profile_mapping = OrderedDict({tk_id: profile_paths})

# Modify metadata dataframe
th_copy.metadata[idx_name] = th_id
th_copy.metadata.set_index(idx_name, inplace=True)
tk_copy.metadata[idx_name] = tk_id
tk_copy.metadata.set_index(idx_name, inplace=True)

def _agg_to_set(obj):
"""Aggregate values in 'obj' into a set to remove duplicates."""
@@ -739,12 +731,12 @@ def _agg_to_set(obj):
return _set

# Execute aggregation
th_copy.metadata = th_copy.metadata.groupby(idx_name).agg(_agg_to_set)
tk_copy.metadata = tk_copy.metadata.groupby(idx_name).agg(_agg_to_set)

# Append copy to list
th_copy_list.append(th_copy)
tk_copy_list.append(tk_copy)

return Thicket.concat_thickets(th_copy_list, from_statsframes=True)
return Thicket.concat_thickets(tk_copy_list)

def to_json(self, ensemble=True, metadata=True, stats=True):
jsonified_thicket = {}
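The metadata aggregation at the end of `from_statsframes` groups each copied metadata table by the new index and collapses every column with `_agg_to_set`, whose body is mostly collapsed in this view. A rough pandas-only sketch of the intent (the real helper also unwraps single-element sets, which is not shown here):

```python
import pandas as pd

# Metadata rows for one thicket after the new index column has been added.
meta = pd.DataFrame(
    {"cluster": ["quartz", "quartz"], "trial": [1, 2], "profile": [0, 0]}
).set_index("profile")

def agg_to_set(obj):
    # Collapse a column's values into a set so duplicates disappear.
    return set(obj)

aggregated = meta.groupby("profile").agg(agg_to_set)
print(aggregated.loc[0, "cluster"])  # {'quartz'}
print(aggregated.loc[0, "trial"])    # {1, 2}
```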