
Commit

add feature importance comparison
kasunamare committed Dec 14, 2023
1 parent 8d23956 commit cc35aa7
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions src/triage/component/postmodeling/report_generator.py
@@ -264,6 +264,7 @@ def plot_feature_importance(self, n_top_features=20):
        """ plot all feature importance """
        self._make_plot_grid(plot_type='plot_feature_importance', subplot_width=7, n_top_features=n_top_features)


    def plot_feature_group_importance(self, n_top_groups=20):
        """ plot all feature group importance """
        self._make_plot_grid(plot_type='plot_feature_group_importance', subplot_width=7, n_top_groups=n_top_groups)
@@ -441,6 +442,96 @@ def display_crosstab_pos_vs_neg(

        if return_dfs:
            return dfs

    def _pairwise_feature_importance_comparison_single_split(self, train_end_time, n_top_features, model_group_ids=None, plot=True):
        """ For a given train_end_time, compares the top n features (highest absolute importance) of each pair of models

        Args:
            train_end_time (str): The prediction date we care about in YYYY-MM-DD format
            n_top_features (int): Number of features to consider for the comparison
            model_group_ids (List[int], optional): Model group ids to consider. If not provided, all model groups included in the report are used
            plot (bool, optional): Whether to plot the results. Defaults to True.

        Returns:
            dict: For each metric ('jaccard', 'overlap', 'rank_corr'), a DataFrame of pairwise scores indexed by model group id
        """

        feature_lists = dict()

        if model_group_ids is not None:
            for mg in model_group_ids:
                feature_lists[mg] = self.models[mg][train_end_time].get_feature_importances(n_top_features=n_top_features)
                if feature_lists[mg].empty:
                    logging.warning(f'No feature importance values were found for model group {mg}. Excluding from comparison')
                    feature_lists.pop(mg)
        # By default all feature importance values are considered
        else:
            model_group_ids = self.model_groups
            for mg, m in self.models.items():
                feature_lists[mg] = m[train_end_time].get_feature_importances(n_top_features=n_top_features)
                if feature_lists[mg].empty:
                    logging.warning(f'No feature importance values were found for model group {mg}. Excluding from comparison')
                    feature_lists.pop(mg)

        pairs = list(itertools.combinations(feature_lists.keys(), 2))

        logging.info(f'Performing {len(pairs)} comparisons')
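        # e.g. (illustrative) model groups [1, 2, 3] yield the pairs
        # (1, 2), (1, 3), and (2, 3)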

        metrics = ['jaccard', 'overlap', 'rank_corr']
        results = dict()

        for m in metrics:
            results[m] = pd.DataFrame(index=model_group_ids, columns=model_group_ids)
            # fill the diagonal with 1 (each model group is identical to itself)
            np.fill_diagonal(results[m].values, 1)

        for model_group_pair in pairs:
            logging.info(f'Comparing {model_group_pair[0]} and {model_group_pair[1]}')

            df1 = feature_lists[model_group_pair[0]]
            df2 = feature_lists[model_group_pair[1]]

            f1 = set(df1.feature)
            f2 = set(df2.feature)

            if len(f1) == 0 or len(f2) == 0:
                logging.error('No feature importance available for the models!')
                continue

            inter = f1.intersection(f2)
            un = f1.union(f2)
            results['jaccard'].loc[model_group_pair[1], model_group_pair[0]] = len(inter) / len(un)

            # If the list sizes are not equal, using the smaller list size to calculate simple overlap
            results['overlap'].loc[model_group_pair[1], model_group_pair[0]] = len(inter) / min(len(f1), len(f2))
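            # e.g. with hypothetical feature sets f1 = {'a', 'b', 'c'} and
            # f2 = {'b', 'c', 'd'}: jaccard = 2/4 = 0.5, overlap = 2/3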

            # calculating rank correlation: sort both lists by importance, then
            # correlate the resulting orderings of feature names
            df1.sort_values('feature_importance', ascending=False, inplace=True)
            df2.sort_values('feature_importance', ascending=False, inplace=True)

            # only keeping the correlation coefficient, not the p-value
            results['rank_corr'].loc[model_group_pair[0], model_group_pair[1]] = spearmanr(df1.feature, df2.feature)[0]


        if plot:
            fig, axes = plt.subplots(1, len(metrics), figsize=(10, 3))

            for i, m in enumerate(metrics):
                sns.heatmap(
                    data=results[m].fillna(0),
                    cmap='Greens',
                    vmin=0,
                    vmax=1,
                    annot=True,
                    linewidth=0.1,
                    ax=axes[i]
                )
                axes[i].set_title(m)

            fig.suptitle(train_end_time)
            fig.tight_layout()

        return results
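    # Illustrative reading of the return value (model group ids hypothetical):
    # for model groups [1, 2, 3], each results[metric] is a 3x3 DataFrame
    # indexed by model group id, e.g. results['jaccard'].loc[2, 1] holds the
    # Jaccard similarity between the top feature lists of groups 1 and 2.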


    def _pairwise_list_comparison_single_fold(self, threshold_type, threshold, train_end_time, matrix_uuid=None, plot=True):
        """For a given train_end_time, compares the lists generated by the analyzed model groups
@@ -577,5 +668,34 @@ def pairwise_top_k_list_comparison(self, threshold_type, threshold, train_end_ti
                matrix_uuid=matrix_uuid,
                plot=plot
            )

    def pairwise_feature_importance_comparison(self, n_top_features, model_groups=None, train_end_times=None, plot=True):
        """
        Compare the top-n most important features of all model groups considered (pairwise) for the given train_end_times

        Args:
            n_top_features (int): Number of top features (by absolute importance) to consider for each comparison
            model_groups (List[int], optional): Model group ids to consider. If not provided, all model groups included in the report are used
            train_end_times (List[str], optional): The prediction dates we care about in YYYY-MM-DD format. If not provided, all train_end_times are used
            plot (bool, optional): Whether to plot the results. Defaults to True.
        """

        # If no train_end_times are provided, we consider all the train_end_times
        # NOTE -- Assuming that all model groups have the same train_end_times
        if train_end_times is None:
            train_end_times = self.models[self.model_groups[0]].keys()

        for train_end_time in train_end_times:
            self._pairwise_feature_importance_comparison_single_split(
                train_end_time=train_end_time,
                n_top_features=n_top_features,
                model_group_ids=model_groups,
                plot=plot
            )
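    # A minimal usage sketch (illustrative; assumes `report` is an instance of
    # the report class defined in this module):
    #
    #     report.pairwise_feature_importance_comparison(n_top_features=20)
    #
    # which runs the comparison for every train_end_time and, with plot=True,
    # draws one row of heatmaps (jaccard, overlap, rank_corr) per split.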


