Skip to content

Commit

Permalink
rewrite merge_categorical_columns to avoid fragmentation warning
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Mar 11, 2024
1 parent 98a6c2d commit 50a37db
Showing 1 changed file with 13 additions and 24 deletions.
37 changes: 13 additions & 24 deletions explainerdashboard/explainer_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,44 +388,33 @@ def retrieve_onehot_value(
def merge_categorical_columns(
X, onehot_dict=None, cols=None, not_encoded_dict=None, sep="_", drop_regular=False
):
"""
Returns a new feature Dataframe X_cats where the onehotencoded
categorical features have been merged back with the old value retrieved
from the encodings.
cat_pieces = []

Args:
X (pd.DataFrame): original dataframe with onehotencoded columns, e.g.
columns=['Age', 'Sex_Male', 'Sex_Female"].
onehot_dict (dict): dict of features with lists for onehot-encoded variables,
e.g. {'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']}
cols (list[str]): list of columns to return
sep (str): separator used in the encoding, e.g. "_" for Sex_Male.
Defaults to "_".
Returns:
pd.DataFrame, with onehot encodings merged back into categorical columns.
"""
X_cats = pd.DataFrame()
not_encoded_dict = not_encoded_dict or {}
for col_name, col_list in onehot_dict.items():
if len(col_list) > 1:
X_cats[col_name] = retrieve_onehot_value(
merged_col = retrieve_onehot_value(
X,
col_name,
col_list,
not_encoded_dict.get(col_name, "NOT_ENCODED"),
sep,
).astype("category")
cat_pieces.append(pd.DataFrame({col_name: merged_col}))
else:
if not drop_regular:
if is_categorical_dtype(X[col_name]):
X_cats[col_name] = pd.Categorical(X[col_name])
cat_pieces.append(
pd.DataFrame({col_name: pd.Categorical(X[col_name])})
)
else:
X_cats.loc[:, col_name] = X[col_name].values
cat_pieces.append(pd.DataFrame({col_name: X[col_name].values}))

X_cats = pd.concat(cat_pieces, axis=1)

if cols:
return X_cats[cols]
else:
return X_cats
X_cats = X_cats[cols]

return X_cats


def matching_cols(cols1, cols2):
Expand Down

0 comments on commit 50a37db

Please sign in to comment.