In [1]:
import numpy as np
import pandas as pd
from sklearn.externals import joblib
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import pairwise_distances

## Load Dataset

In [2]:
df = pd.read_csv('student-por.csv', sep=';')
print (df.columns)
df.head()

Index(['school', 'sex', 'age', 'address', 'famsize', 'Pstatus', 'Medu', 'Fedu',
       'Mjob', 'Fjob', 'reason', 'guardian', 'traveltime', 'studytime',
       'failures', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery',
       'higher', 'internet', 'romantic', 'famrel', 'freetime', 'goout', 'Dalc',
       'Walc', 'health', 'absences', 'G1', 'G2', 'G3'],
      dtype='object')


Unnamed: 0,school,sex,age,address,famsize,Pstatus,Medu,Fedu,Mjob,Fjob,...,famrel,freetime,goout,Dalc,Walc,health,absences,G1,G2,G3
0,GP,F,18,U,GT3,A,4,4,at_home,teacher,...,4,3,4,1,1,3,4,0,11,11
1,GP,F,17,U,GT3,T,1,1,at_home,other,...,5,3,3,1,1,3,2,9,11,11
2,GP,F,15,U,LE3,T,1,1,at_home,other,...,4,3,2,2,3,3,6,12,13,12
3,GP,F,15,U,GT3,T,4,2,health,services,...,3,2,2,1,1,5,0,14,14,14
4,GP,F,16,U,GT3,T,3,3,other,other,...,4,3,2,1,2,5,0,11,13,13


## Convert binary feature values to 1s and 0s

In [3]:
# 1 school - student's school (binary: 'GP' - Gabriel Pereira or 'MS' - Mousinho da Silveira) 
# 2 sex - student's sex (binary: 'F' - female or 'M' - male) 
# 3 age - student's age (numeric: from 15 to 22) 
# 4 address - student's home address type (binary: 'U' - urban or 'R' - rural) 
# 5 famsize - family size (binary: 'LE3' - less or equal to 3 or 'GT3' - greater than 3) 
# 6 Pstatus - parent's cohabitation status (binary: 'T' - living together or 'A' - apart) 
# 7 Medu - mother's education (numeric: 0 - none, 1 - primary education (4th grade), 2 â€“ 5th to 9th grade, 3 â€“ secondary education or 4 â€“ higher education) 
# 8 Fedu - father's education (numeric: 0 - none, 1 - primary education (4th grade), 2 â€“ 5th to 9th grade, 3 â€“ secondary education or 4 â€“ higher education) 
# 9 Mjob - mother's job (nominal: 'teacher', 'health' care related, civil 'services' (e.g. administrative or police), 'at_home' or 'other') 
# 10 Fjob - father's job (nominal: 'teacher', 'health' care related, civil 'services' (e.g. administrative or police), 'at_home' or 'other') 
# 11 reason - reason to choose this school (nominal: close to 'home', school 'reputation', 'course' preference or 'other') 
# 12 guardian - student's guardian (nominal: 'mother', 'father' or 'other') 
# 13 traveltime - home to school travel time (numeric: 1 - <15 min., 2 - 15 to 30 min., 3 - 30 min. to 1 hour, or 4 - >1 hour) 
# 14 studytime - weekly study time (numeric: 1 - <2 hours, 2 - 2 to 5 hours, 3 - 5 to 10 hours, or 4 - >10 hours) 
# 15 failures - number of past class failures (numeric: n if 1<=n<3, else 4) 
# 16 schoolsup - extra educational support (binary: yes or no) 
# 17 famsup - family educational support (binary: yes or no) 
# 18 paid - extra paid classes within the course subject (Math or Portuguese) (binary: yes or no) 
# 19 activities - extra-curricular activities (binary: yes or no) 
# 20 nursery - attended nursery school (binary: yes or no) 
# 21 higher - wants to take higher education (binary: yes or no) 
# 22 internet - Internet access at home (binary: yes or no) 
# 23 romantic - with a romantic relationship (binary: yes or no) 
# 24 famrel - quality of family relationships (numeric: from 1 - very bad to 5 - excellent) 
# 25 freetime - free time after school (numeric: from 1 - very low to 5 - very high) 
# 26 goout - going out with friends (numeric: from 1 - very low to 5 - very high) 
# 27 Dalc - workday alcohol consumption (numeric: from 1 - very low to 5 - very high) 
# 28 Walc - weekend alcohol consumption (numeric: from 1 - very low to 5 - very high) 
# 29 health - current health status (numeric: from 1 - very bad to 5 - very good) 
# 30 absences - number of school absences (numeric: from 0 to 93) 

# # these grades are related with the course subject, Math or Portuguese: 
# 31 G1 - first period grade (numeric: from 0 to 20) 
# 31 G2 - second period grade (numeric: from 0 to 20) 
# 32 G3 - final grade (numeric: from 0 to 20, output target)

df['school'][df['school'] == 'GP'] = 1 # GP is 1, MS is 0
df['school'][df['school'] == 'MS'] = 0
df['sex'][df['sex'] == 'F'] = 1 # Female is 1, Male is 0
df['sex'][df['sex'] == 'M'] = 0
df['address'][df['address'] == 'U'] = 1 # U is 1, R is 0
df['address'][df['address'] == 'R'] = 0
df['famsize'][df['famsize'] == 'LE3'] = 0 # LE3 is 0, GT3 is 1
df['famsize'][df['famsize'] == 'GT3'] = 1
df['Pstatus'][df['Pstatus'] == 'T'] = 1 # T is 1, A is 0
df['Pstatus'][df['Pstatus'] == 'A'] = 0
# df[df['Mjob'] == 'teacher'] = 1 # categorical values
# df[df['Mjob'] == 'health'] = 2
# df[df['Mjob'] == 'services'] = 3
# df[df['Mjob'] == 'home'] = 4
# df[df['Mjob'] == 'other'] = 5
# df[df['Fjob'] == 'teacher'] = 1 # same as MJob
# df[df['Fjob'] == 'health'] = 2
# df[df['Fjob'] == 'services'] = 3
# df[df['Fjob'] == 'home'] = 4
# df[df['Fjob'] == 'other'] = 5
df['schoolsup'][df['schoolsup'] == 'yes'] = 1
df['schoolsup'][df['schoolsup'] == 'no'] = 0
df['famsup'][df['famsup'] == 'yes'] = 1
df['famsup'][df['famsup'] == 'no'] = 0
df['paid'][df['paid'] == 'yes'] = 1
df['paid'][df['paid'] == 'no'] = 0
df['activities'][df['activities'] == 'yes'] = 1
df['activities'][df['activities'] == 'no'] = 0
df['nursery'][df['nursery'] == 'yes'] = 1
df['nursery'][df['nursery'] == 'no'] = 0
df['higher'][df['higher'] == 'yes'] = 1
df['higher'][df['higher'] == 'no'] = 0
df['internet'][df['internet'] == 'yes'] = 1
df['internet'][df['internet'] == 'no'] = 0
df['romantic'][df['romantic'] == 'yes'] = 1
df['romantic'][df['romantic'] == 'no'] = 0

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
A value is trying to be set on a copy of a slice from a

In [4]:
common_cols_to_drop = ["G1", "G2"]
immutable_cols = ["sex","age", "famsize", "Pstatus", "reason", "guardian", "failures", "nursery", "Fjob", "Mjob", "Fedu", "Medu"]
mutable_cols = [x for x in df.columns if x not in immutable_cols and x not in common_cols_to_drop and x != 'G3']
cols_all = mutable_cols + immutable_cols
# immutable_cols += mutable_cols
# mutable_cols += ["sex", "Medu", "Fedu"]

In [5]:
cols_all, mutable_cols

(['school',
  'address',
  'traveltime',
  'studytime',
  'schoolsup',
  'famsup',
  'paid',
  'activities',
  'higher',
  'internet',
  'romantic',
  'famrel',
  'freetime',
  'goout',
  'Dalc',
  'Walc',
  'health',
  'absences',
  'sex',
  'age',
  'famsize',
  'Pstatus',
  'reason',
  'guardian',
  'failures',
  'nursery',
  'Fjob',
  'Mjob',
  'Fedu',
  'Medu'],
 ['school',
  'address',
  'traveltime',
  'studytime',
  'schoolsup',
  'famsup',
  'paid',
  'activities',
  'higher',
  'internet',
  'romantic',
  'famrel',
  'freetime',
  'goout',
  'Dalc',
  'Walc',
  'health',
  'absences'])

In [6]:
df_mutable = df[mutable_cols]
df_mutable['G3'] = pd.Series(df['G3'], index=df.index)
df_all = df[cols_all]
df_all['G3'] = pd.Series(df['G3'], index=df.index)

print (df_mutable.columns, df_all.columns)

def same_cols(v1, v2):
    if np.all(v1 == v2):
        return 1
    else:
        return 0
def check_all_rows_unique(df):
    if np.all(np.sum(pairwise_distances(np.array(df), metric=same_cols), axis=1) == 1):
        return True
    else:
        return False, "{} entries > 1".format(np.count_nonzero(np.sum(pairwise_distances(np.array(df), metric=same_cols), axis=1) > 1))

Index(['school', 'address', 'traveltime', 'studytime', 'schoolsup', 'famsup',
       'paid', 'activities', 'higher', 'internet', 'romantic', 'famrel',
       'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences', 'G3'],
      dtype='object') Index(['school', 'address', 'traveltime', 'studytime', 'schoolsup', 'famsup',
       'paid', 'activities', 'higher', 'internet', 'romantic', 'famrel',
       'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences', 'sex', 'age',
       'famsize', 'Pstatus', 'reason', 'guardian', 'failures', 'nursery',
       'Fjob', 'Mjob', 'Fedu', 'Medu', 'G3'],
      dtype='object')


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  after removing the cwd from sys.path.


In [7]:
def convert_df_to_array(df):
    X = None
    feature_info = []
    for fname in df.columns:
        if fname in ['Mjob', 'Fjob', 'reason', 'guardian']:
            ohe = OneHotEncoder()
            new_df = ohe.fit_transform(np.array([df[fname]]).reshape((len(df), 1))).toarray()
            X = new_df if X is None else np.append(X, new_df, axis=1)
            for cat in ohe.categories_[0]:
                print ("Cat: {}".format(cat))
                feature_info.append('{}_{}'.format(fname, cat))
            print (fname, new_df, ohe.categories_)
            print ()
        elif fname in ['school','sex','address','famsize','Pstatus',
                       'schoolsup','famsup','paid','activities',
                       'nursery','higher','internet','romantic']:
            X = np.array(df[fname]).reshape((len(df), 1)) if X is None else \
                np.append(X, np.array(df[fname]).reshape((len(df), 1)), axis=1)
            feature_info.append('{}'.format(fname))
            print (X.shape, np.array(df[fname]).reshape((len(df), 1)).shape)
        elif fname == 'G3':
            Y = np.array(df[fname]).flatten()
        else:
            X = np.array(df[fname]).reshape((len(df), 1)) if X is None else \
                np.append(X, np.array(df[fname]).reshape((len(df), 1)), axis=1)
            feature_info.append('{}'.format(fname))
    return X, Y, feature_info

In [8]:
X_all, Y_all, feature_info_all = convert_df_to_array(df_all)
X_mutable, Y_mutable, feature_info_mutable = convert_df_to_array(df_mutable)
X_all.shape, X_mutable.shape

(649, 1) (649, 1)
(649, 2) (649, 1)
(649, 5) (649, 1)
(649, 6) (649, 1)
(649, 7) (649, 1)
(649, 8) (649, 1)
(649, 9) (649, 1)
(649, 10) (649, 1)
(649, 11) (649, 1)
(649, 19) (649, 1)
(649, 21) (649, 1)
(649, 22) (649, 1)
Cat: course
Cat: home
Cat: other
Cat: reputation
reason [[1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 ...
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]] [array(['course', 'home', 'other', 'reputation'], dtype=object)]

Cat: father
Cat: mother
Cat: other
guardian [[0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 ...
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]] [array(['father', 'mother', 'other'], dtype=object)]

(649, 31) (649, 1)
Cat: at_home
Cat: health
Cat: other
Cat: services
Cat: teacher
Fjob [[0. 0. 0. 0. 1.]
 [0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0.]
 ...
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 1. 0. 0.]] [array(['at_home', 'health', 'other', 'services', 'teacher'], dtype=object)]

Cat: at_home
Cat: health
Cat: other
Cat: services
Cat: teacher
Mjob [[1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [

((649, 43), (649, 18))

In [9]:
assert np.all(Y_all == Y_mutable)

In [10]:
check_all_rows_unique(X_all), check_all_rows_unique(X_mutable)

(True, (False, '4 entries > 1'))

In [11]:
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import MinMaxScaler

scaler_mutable, scaler_all = MinMaxScaler(), MinMaxScaler()
X_mutable = scaler_mutable.fit_transform(X_mutable)
X_all = scaler_all.fit_transform(X_all)

ridge_mutable, ridge_all = Ridge(alpha=0.1), Ridge(alpha=200)
linreg_mutable, linreg_all = LinearRegression(), LinearRegression()
lasso_mutable, lasso_all = Lasso(), Lasso()
ridge_mutable.fit(X_mutable, Y_mutable)
lasso_mutable.fit(X_mutable, Y_mutable)
linreg_mutable.fit(X_mutable, Y_mutable)
ridge_all.fit(X_all, Y_all)
lasso_all.fit(X_all, Y_all)
linreg_all.fit(X_all, Y_all)



LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
         normalize=False)

In [12]:
(mean_absolute_error(Y_mutable, ridge_mutable.predict(X_mutable)), 
 mean_absolute_error(Y_mutable, lasso_mutable.predict(X_mutable)), 
 mean_absolute_error(Y_mutable, linreg_mutable.predict(X_mutable)))

(2.0278044518666696, 2.4058822272501725, 2.0278889211650775)

In [13]:
(mean_absolute_error(Y_all, ridge_all.predict(X_all)), 
 mean_absolute_error(Y_all, lasso_all.predict(X_all)), 
 mean_absolute_error(Y_all, linreg_all.predict(X_all)))

(2.0458734675051757, 2.4058822272501725, 1.8637326656394453)

In [14]:
# Females
print (mean_absolute_error(Y_mutable[df['sex'] == 1], ridge_mutable.predict(X_mutable[df['sex'] == 1])), 
 mean_absolute_error(Y_mutable[df['sex'] == 1], lasso_mutable.predict(X_mutable[df['sex'] == 1])), 
 mean_absolute_error(Y_mutable[df['sex'] == 1], linreg_mutable.predict(X_mutable[df['sex'] == 1])))
print (mean_absolute_error(Y_all[df['sex'] == 1], ridge_all.predict(X_all[df['sex'] == 1])), 
 mean_absolute_error(Y_all[df['sex'] == 1], lasso_all.predict(X_all[df['sex'] == 1])), 
 mean_absolute_error(Y_all[df['sex'] == 1], linreg_all.predict(X_all[df['sex'] == 1])))

1.999914780936217 2.3855700877429427 1.9998368139531053
2.066929253501052 2.3855700877429427 1.8222095300261096


In [15]:
# Males
print (mean_absolute_error(Y_mutable[df['sex'] == 0], ridge_mutable.predict(X_mutable[df['sex'] == 0])), 
 mean_absolute_error(Y_mutable[df['sex'] == 0], lasso_mutable.predict(X_mutable[df['sex'] == 0])), 
 mean_absolute_error(Y_mutable[df['sex'] == 0], linreg_mutable.predict(X_mutable[df['sex'] == 0])))
print (mean_absolute_error(Y_all[df['sex'] == 0], ridge_all.predict(X_all[df['sex'] == 0])), 
 mean_absolute_error(Y_all[df['sex'] == 0], lasso_all.predict(X_all[df['sex'] == 0])), 
 mean_absolute_error(Y_all[df['sex'] == 0], linreg_all.predict(X_all[df['sex'] == 0])))

2.067961384071043 2.435128653683516 2.068279737188331
2.0155563019547214 2.435128653683516 1.9234609962406015


In [20]:
df_ridge_mutable = pd.DataFrame(data=np.append(ridge_mutable.coef_, ridge_mutable.intercept_), 
                    index=feature_info_mutable + ["Intercept"],
                    columns = ["Ridge Regression,\nRegularization Const. = {}".format(ridge_mutable.alpha)])
df_ridge_all = pd.DataFrame(data=np.append(ridge_all.coef_, ridge_all.intercept_), 
                    index=feature_info_all + ["Intercept"],
                    columns = ["Ridge Regression,\nRegularization Const. = {}".format(ridge_all.alpha)])
df_ridge_mutable

Unnamed: 0,"Ridge Regression, Regularization Const. = 0.1"
school,1.440692
address,0.313617
traveltime,-0.248103
studytime,1.601607
schoolsup,-1.466385
famsup,0.0881
paid,-0.790441
activities,0.196532
higher,2.509282
internet,0.560564


In [21]:
df_ridge_all

Unnamed: 0,"Ridge Regression, Regularization Const. = 200"
school,0.606459
address,0.279544
traveltime,-0.092589
studytime,0.412547
schoolsup,-0.250455
famsup,0.057057
paid,-0.103463
activities,0.121087
higher,0.629339
internet,0.243342


In [22]:
common_index = df_ridge_all.index.intersection(df_ridge_mutable.index)
common_index, len(common_index), len(df_ridge_all)

(Index(['school', 'address', 'traveltime', 'studytime', 'schoolsup', 'famsup',
        'paid', 'activities', 'higher', 'internet', 'romantic', 'famrel',
        'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences', 'Intercept'],
       dtype='object'), 19, 44)

In [59]:
# print (len(df_ridge_mutable), len(df_ridge_all))
# print (df_ridge_all.index, df_ridge_mutable.index)
weights = pd.merge(df_ridge_mutable, df_ridge_all, left_index=True, right_index=True, how='inner')
immutable_weights = df_ridge_all.loc[df_ridge_all.index.difference(df_ridge_mutable.index)]
weights = weights.round(4)
immutable_weights = immutable_weights.round(4)
weights_selected = weights.loc[['school','traveltime', 'higher', 'absences']]
all_features_weights = pd.merge(df_ridge_mutable, df_ridge_all, how='outer', 
                                        left_index=True, right_index=True)
index_order = list(df_ridge_all.index)
all_features_weights = all_features_weights.loc[index_order]
all_features_weights = all_features_weights.round(4)
all_features_weights

Unnamed: 0,"Ridge Regression, Regularization Const. = 0.1","Ridge Regression, Regularization Const. = 200"
school,1.4407,0.6065
address,0.3136,0.2795
traveltime,-0.2481,-0.0926
studytime,1.6016,0.4125
schoolsup,-1.4664,-0.2505
famsup,0.0881,0.0571
paid,-0.7904,-0.1035
activities,0.1965,0.1211
higher,2.5093,0.6293
internet,0.5606,0.2433


In [77]:
import matplotlib.pyplot as plt
import pandas as pd
from pandas.plotting import table

row_colors_all, row_colors_immutable = np.array(['white'] * len(weights)), np.array(['white'] * len(immutable_weights))
row_colors_selected = np.array(['white'] * len(weights_selected.columns))
def init_table_plot():
    fig = plt.figure()
    ax = fig.add_subplot(111, frame_on=False) # no visible frame
    ax.xaxis.set_visible(False)  # hide the x axis
    ax.yaxis.set_visible(False)  # hide the y axis
    ax.axis('off')
    ax.axis('tight')
    return ax, fig

def set_table_font(table, fontsize):
    table.auto_set_font_size(False)
    table.set_fontsize(fontsize)

def set_cell_heights(table, num_cols, num_rows, cell_height, col_label_height):
    cellDict = table.get_celld()
    #iterate over all items in the grid, if column labels set it to a larger height
    for i in range(0, num_rows + 1):
        for j in range(num_cols):
            cellDict[(i, j)].set_height(col_label_height if i == 0 else cell_height)
    # for the row labels
    for i in range(1, num_rows + 1):
        cellDict[(i, -1)].set_height(cell_height)

def rotate_col_labels(table, num_cols):
    cellDict = table.get_celld()
    # for the column labels
    for i in range(num_cols):
        cellDict[(0, i)].get_text().set_rotation(90)

def get_pruned_labels(row_labels):
    for i in range(len(row_labels)):
        if '_' in row_labels[i]:
            row_labels[i] = "{}\n({})".format(row_labels[i].split('_',1)[0], row_labels[i].split('_',1)[1])
    return row_labels
        
ax, fig = init_table_plot()
table = ax.table(cellText=weights.values.T, rowLabels=list(weights.columns), 
         colLabels=list(weights.index), loc='center', colWidths = [0.25 for x in weights.index])
set_table_font(table, 18)
set_cell_heights(table, num_cols=len(weights.index), 
                 num_rows=len(weights.columns), cell_height=0.25, col_label_height=0.45)
rotate_col_labels(table, len(weights.index))
plt.savefig('common_weights.pdf', bbox_inches='tight')
plt.close()

ax, fig = init_table_plot()
table = ax.table(cellText=weights_selected.values.T, rowColours=row_colors_selected, 
                 colLabels=list(weights_selected.index), rowLabels=list(weights_selected.columns), 
                 loc='center', colWidths = [0.3 for x in weights_selected.index])
set_table_font(table, 18)
set_cell_heights(table, len(weights_selected.index), len(weights_selected.columns), 0.25, 0.45)
rotate_col_labels(table, len(weights_selected.index))
plt.savefig('common_weights_selected.pdf', bbox_inches='tight')
plt.close()

# ax, fig = init_table_plot()
# table = ax.table(cellText=all_features_weights.values, colLabels=list(all_features_weights.columns), 
#                  rowLabels=list(all_features_weights.index), loc='center', 
#                  colWidths = [0.9 for x in all_features_weights.columns])
# set_table_font(table, 20)
# set_cell_heights(table, num_cols=len(all_features_weights.columns), 
#                  num_rows=len(all_features_weights.index), cell_height=0.25, col_label_height=0.45)
# plt.savefig('all_weights.pdf', bbox_inches='tight')
# plt.close()

# ax, fig = init_table_plot()
# table = ax.table(cellText=all_features_weights.values.T, 
#                  colLabels=get_pruned_labels(list(all_features_weights.index)), 
#                  rowLabels=list(all_features_weights.columns), loc='center', 
#                  colWidths = [0.25 for x in all_features_weights.index])
# set_table_font(table, 18)
# set_cell_heights(table, num_cols=len(all_features_weights.index), 
#                  num_rows=len(all_features_weights.columns), cell_height=0.25, col_label_height=0.45)
# rotate_col_labels(table, len(all_features_weights.index))
# plt.savefig('all_weights_transpose.pdf', bbox_inches='tight')
# plt.close()


# just the immutable weights
ax, fig = init_table_plot()
table = ax.table(cellText=immutable_weights.values.T, 
                 rowLabels=list(immutable_weights.columns), 
                 colLabels=get_pruned_labels(list(immutable_weights.index)), loc='center', 
                 colWidths=[0.25 for x in immutable_weights.index])
set_table_font(table, 19)
set_cell_heights(table, num_cols=len(immutable_weights.index), 
                 num_rows=len(immutable_weights.columns), cell_height=0.25, col_label_height=0.55)
rotate_col_labels(table, len(immutable_weights.index))
plt.savefig('only_immutable_weights.pdf', bbox_inches='tight')
plt.close()