# Feature Importance Analysis

Analyze which features are most important for predicting citation impact:
1. Load trained models
2. Extract feature importance
3. Analyze top features by category
4. Visualize importance rankings
5. Compare category importance between the classification and regression models

In [None]:
import sys
# Add the parent directory to the module search path.
sys.path.append('../')

import pandas as pd
import numpy as np
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import seaborn as sns

# Do not truncate the column list when pandas displays a DataFrame.
pd.set_option('display.max_columns', None)
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## 1. Load Models and Features

In [None]:
# NOTE: unpickling can execute arbitrary code, so only load artifacts that
# come from a trusted source.

# Classifier used for the classification importance ranking.
with Path('../models/classification/lightgbm.pkl').open('rb') as clf_file:
    clf_model = pickle.load(clf_file)

# Regressor used for the regression importance ranking.
with Path('../models/regression/random_forest.pkl').open('rb') as reg_file:
    reg_model = pickle.load(reg_file)

# The column labels of the pickled feature matrix serve as the feature names.
X = pd.read_pickle('../data/features/X_all.pkl')
feature_names = X.columns.tolist()

print(f"Loaded models and {len(feature_names)} features")

## 2. Extract Feature Importance - Classification

In [None]:
# Pair each feature name with the importance reported by the loaded
# classifier, then rank the rows from most to least important.
clf_scores = pd.DataFrame({
    'feature': feature_names,
    'importance': clf_model.feature_importances_,
})
clf_importance = clf_scores.sort_values(by='importance', ascending=False)

print("Top 20 features for Classification:")
print(clf_importance.iloc[:20])

## 3. Extract Feature Importance - Regression

In [None]:
# Pair each feature name with the importance reported by the loaded
# regressor, then rank the rows from most to least important.
reg_scores = pd.DataFrame({
    'feature': feature_names,
    'importance': reg_model.feature_importances_,
})
reg_importance = reg_scores.sort_values(by='importance', ascending=False)

print("Top 20 features for Regression:")
print(reg_importance.iloc[:20])

## 4. Categorize Features

In [None]:
def categorize_feature(feature):
    """Map a feature name to one of 'Text', 'Venue', 'Author' or 'Other'.

    Names starting with the exact (case-sensitive) prefix ``tfidf_`` are
    'Text'. Otherwise the lower-cased name is searched for keyword
    substrings; venue keywords are checked before author keywords, so a name
    matching both is reported as 'Venue'. Anything else is 'Other'.
    """
    venue_keywords = ('snip', 'citescore', 'sjr', 'venue', 'views', 'field_weighted_view')
    author_keywords = ('author', 'institution', 'team', 'collab', 'num_countries')

    if feature.startswith('tfidf_'):
        return 'Text'

    lowered = feature.lower()
    # Order matters: the first group with a matching keyword wins.
    for label, keywords in (('Venue', venue_keywords), ('Author', author_keywords)):
        if any(keyword in lowered for keyword in keywords):
            return label
    return 'Other'

# Tag every row of both importance tables with its feature category.
for importance_table in (clf_importance, reg_importance):
    importance_table['category'] = importance_table['feature'].apply(categorize_feature)

# Summed importance per category, largest first, for each model.
print("Classification - Importance by Category:")
clf_category_totals = clf_importance.groupby('category')['importance'].sum()
print(clf_category_totals.sort_values(ascending=False))

print("\nRegression - Importance by Category:")
reg_category_totals = reg_importance.groupby('category')['importance'].sum()
print(reg_category_totals.sort_values(ascending=False))

## 5. Top Features by Category

In [None]:
# (heading, category) pairs in the order they are printed. Each listing shows
# at most the first 10 rows of the already-ranked classification table.
category_listings = (
    ("Classification - Top 10 Text Features:", 'Text'),
    ("\nClassification - Top Venue Features:", 'Venue'),
    ("\nClassification - Top Author Features:", 'Author'),
)

for heading, wanted_category in category_listings:
    print(heading)
    in_category = clf_importance['category'] == wanted_category
    print(clf_importance[in_category].head(10))

## 6. Visualize Top 20 Features - Classification

In [None]:
from matplotlib.patches import Patch

# Single source of truth for category colors: both the bars and the legend
# are derived from it, so every color that can appear (including 'Other')
# is explained in the legend.
category_colors = {'Text': 'steelblue', 'Venue': 'coral', 'Author': 'green', 'Other': 'gray'}

fig, ax = plt.subplots(figsize=(12, 8))
top_20 = clf_importance.head(20)
# Convert to a plain list so matplotlib receives colors by position rather
# than a Series that still carries the sorted frame's index labels.
colors = top_20['category'].map(category_colors).tolist()
bar_positions = range(len(top_20))

ax.barh(bar_positions, top_20['importance'], color=colors)
ax.set_yticks(bar_positions)
# regex=False: 'tfidf_' is a literal substring, independent of the pandas
# version's default for `regex`. Labels are truncated to 40 characters.
ax.set_yticklabels(top_20['feature'].str.replace('tfidf_', '', regex=False).str[:40])
ax.set_xlabel('Importance')
ax.set_title('Top 20 Features - Classification (LightGBM)')
# Most important feature at the top of the chart.
ax.invert_yaxis()

legend_elements = [Patch(facecolor=color, label=category)
                   for category, color in category_colors.items()]
ax.legend(handles=legend_elements, loc='lower right')

plt.tight_layout()
plt.show()

## 7. Visualize Top 20 Features - Regression

In [None]:
from matplotlib.patches import Patch

# Single source of truth for category colors: both the bars and the legend
# are derived from it, so every color that can appear (including 'Other')
# is explained in the legend.
category_colors = {'Text': 'steelblue', 'Venue': 'coral', 'Author': 'green', 'Other': 'gray'}

fig, ax = plt.subplots(figsize=(12, 8))
top_20 = reg_importance.head(20)
# Convert to a plain list so matplotlib receives colors by position rather
# than a Series that still carries the sorted frame's index labels.
colors = top_20['category'].map(category_colors).tolist()
bar_positions = range(len(top_20))

ax.barh(bar_positions, top_20['importance'], color=colors)
ax.set_yticks(bar_positions)
# regex=False: 'tfidf_' is a literal substring, independent of the pandas
# version's default for `regex`. Labels are truncated to 40 characters.
ax.set_yticklabels(top_20['feature'].str.replace('tfidf_', '', regex=False).str[:40])
ax.set_xlabel('Importance')
ax.set_title('Top 20 Features - Regression (Random Forest)')
# Most important feature at the top of the chart.
ax.invert_yaxis()

legend_elements = [Patch(facecolor=color, label=category)
                   for category, color in category_colors.items()]
ax.legend(handles=legend_elements, loc='lower right')

plt.tight_layout()
plt.show()

## 8. Category Importance Comparison

In [None]:
# Raw per-category importance totals for each model.
clf_cat = clf_importance.groupby('category')['importance'].sum()
reg_cat = reg_importance.groupby('category')['importance'].sum()

# A category missing from one model contributes 0 rather than NaN.
raw_totals = pd.DataFrame({
    'Classification': clf_cat,
    'Regression': reg_cat
}).fillna(0)

# The two models may report importance on different scales (LightGBM's
# default importance is an unnormalized split count, while scikit-learn
# forests report values that sum to 1), so raw totals cannot share an axis.
# Convert each model's column to a share of that model's total importance.
model_totals = raw_totals.sum(axis=0)
# Guard against an all-zero column: dividing by 1 leaves its shares at 0.
comparison = raw_totals.div(model_totals.where(model_totals != 0, 1), axis='columns')

fig, ax = plt.subplots(figsize=(10, 6))
comparison.plot(kind='bar', ax=ax, color=['steelblue', 'coral'])
ax.set_ylabel('Share of Total Importance')
ax.set_title('Feature Category Importance: Classification vs Regression')
ax.tick_params(axis='x', labelrotation=45)
plt.tight_layout()
plt.show()

print("Category Importance Comparison:")
print(comparison)

## Summary

In [None]:
# Final recap: feature count, the top-ranked feature of each model, and the
# per-category comparison table built earlier in the notebook.
banner = "=" * 60
print(banner)
print("FEATURE IMPORTANCE SUMMARY")
print(banner)
print(f"\nTotal features: {len(feature_names)}")

print("\nClassification - Most important feature:")
# Both tables are sorted by descending importance, so row 0 is the top one.
best_clf = clf_importance.iloc[0]
print(f"  {best_clf['feature']}: {best_clf['importance']:.4f}")

print("\nRegression - Most important feature:")
best_reg = reg_importance.iloc[0]
print(f"  {best_reg['feature']}: {best_reg['importance']:.4f}")

print("\nKey takeaway: Which category drives citation impact most?")
print(comparison)