-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
37 lines (26 loc) · 1013 Bytes
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def plot(df, type, explain, k=10, **options):
    """Plot an explanation, dispatching on the kind of input.

    Parameters
    ----------
    df : pandas.DataFrame or array-like
        For ``type='tabular'``, the frame whose column names label the
        explanation values. For ``type='image'``, the array handed to
        ``plt.imshow``.
    type : str
        Either ``'tabular'`` or ``'image'``. The name shadows the builtin but
        is kept so existing keyword callers (``type=...``) keep working.
    explain : array-like
        One explanation value per column of *df*; used only for tabular plots.
    k : int, default 10
        Number of top-ranked variables shown in a tabular plot.
    **options
        Forwarded to the tabular plot (``title``, ``figsize``, ``color``).
        Not used for image plots.

    Raises
    ------
    ValueError
        If *type* is not one of the supported kinds.
    """
    if type == 'tabular':
        # Honour the caller's k and styling options instead of fixed values.
        _tabular_plot(df, explain, k=k, **options)
    elif type == 'image':
        _image_plot(df)
    else:
        raise ValueError(
            f"unknown plot type {type!r}; expected 'tabular' or 'image'"
        )
def _tabular_plot(df, explain, k, **options):
    """Draw a horizontal bar chart of the *k* highest explanation values.

    Parameters
    ----------
    df : pandas.DataFrame
        Only its column names are used, as the bar labels.
    explain : array-like
        One value per column of *df*, in the same order as ``df.columns``.
    k : int
        Maximum number of variables to show.
    **options
        ``title`` (default ``'Explanations'``), ``figsize`` (default
        ``(10, 10)``) and ``color`` (default ``'blue'``).
    """
    title = options.get('title', 'Explanations')
    figsize = options.get('figsize', (10, 10))
    color = options.get('color', 'blue')

    # Rank by signed value, highest first, and keep at most k rows.
    # NOTE(review): if explanation values can be negative and "importance"
    # should mean magnitude, rank by abs(value) instead — confirm.
    top = (
        pd.DataFrame({'variable': df.columns, 'value': explain})
        .sort_values('value', ascending=False)
        .head(k)
    )
    # barh places the first category at the bottom of the y-axis, so reverse
    # the rows to put the highest-ranked variable at the top of the chart.
    top = top.iloc[::-1]

    plt.figure(figsize=figsize)
    plt.barh('variable', 'value', data=top, color=color)
    plt.title(title)
    plt.show()
def _image_plot(df, **options):
    """Show an array as a heat map.

    Parameters
    ----------
    df : array-like
        Image data passed straight to ``plt.imshow``; despite the name it is
        expected to be an array rather than a DataFrame.
    **options
        ``cmap`` (default ``'hot'``), ``interpolation`` (default
        ``'nearest'``) and an optional ``title``.
    """
    plt.imshow(
        df,
        cmap=options.get('cmap', 'hot'),
        interpolation=options.get('interpolation', 'nearest'),
    )
    title = options.get('title')
    if title is not None:
        plt.title(title)
    plt.show()