# Plotlyライブラリの学習

In [56]:
# データ構造
import numpy as np
import pandas as pd

# Plotly関連
import plotly
import plotly.plotly as py
import plotly.graph_objs as go
import plotly.figure_factory as ff
from plotly import tools

# その他
import mojimoji
import warnings
warnings.filterwarnings('ignore')

In [121]:
def multi_title(df: pd.DataFrame, x: bool = True) -> str:
    '''Join the level names of one axis of a DataFrame with '-'.

    Parameters
    ----------
    df : pd.DataFrame
        Frame whose axis level names are joined (typically one whose index
        and/or columns are a MultiIndex).
    x : bool, default True
        Truthy  -> use the column level names (``df.columns.names``).
        Falsy   -> use the index level names (``df.index.names``).

    Returns
    -------
    str
        The selected level names joined by '-', e.g. 'Sex-Survived'.
        Each name is passed through ``str()`` so non-string names
        (ints, or None for an unnamed level) do not raise.
    '''
    # Pick the requested axis first; every name on that axis goes into the
    # title (previously the selected names were ignored and the index names
    # were always used).
    names = df.columns.names if x else df.index.names
    return '-'.join(str(name) for name in names)

def multi_label(df: pd.DataFrame) -> list:
    '''Build one '-'-joined label per row of ``df``'s index.

    For a MultiIndex each row key is a tuple such as ('female', 0); its
    elements are converted with ``str()`` and joined by '-', giving
    'female-0'. All levels are included, matching ``multi_title``. For a
    plain (single-level) index the label is simply ``str(key)``.

    Parameters
    ----------
    df : pd.DataFrame
        Frame whose index supplies the labels.

    Returns
    -------
    list of str
        Labels in row order; ``len(result) == len(df)``.
    '''
    # Iterate the actual index values instead of rebuilding them from
    # levels/codes: a missing entry has code -1, which would silently index
    # the *last* level value; iteration yields NaN for it instead.
    keys = (key if isinstance(key, tuple) else (key,) for key in df.index)
    return ['-'.join(str(part) for part in key) for key in keys]

In [113]:
# Load the training CSV (first column becomes the index) and drop the
# columns that are not used below. Presumably the Titanic dataset, judging
# by the column names — confirm.
df = pd.read_csv('train.csv', index_col=0).drop(
    columns=['Name', 'Ticket', 'SibSp', 'Parch'])
# Count table: rows are (Sex, Survived) pairs, columns are Pclass values.
ct = pd.crosstab(index=[df['Sex'], df['Survived']], columns=[df['Pclass']])
ct

Pclass            1   2    3
Sex    Survived
female 0          3   6   72
       1         91  70   72
male   0         77  91  300
       1         45  17   47


In [117]:
# Raw count matrix; rows follow ct.index and columns follow ct.columns.
z = ct.to_numpy()
z

array([[  3,   6,  72],
       [ 91,  70,  72],
       [ 77,  91, 300],
       [ 45,  17,  47]])

In [124]:
# x-axis labels: each Pclass value converted to full-width characters via
# mojimoji (e.g. 1 -> '１').
# NOTE(review): presumably done so Plotly treats the labels as categories
# rather than parsing them as numbers — confirm.
x = [mojimoji.han_to_zen(str(pclass)) for pclass in ct.columns]
x

['１', '２', '３']

In [122]:
# y-axis labels: one '-'-joined label per crosstab row, e.g. 'female-0'.
y = multi_label(df=ct)
y

['female-0', 'female-1', 'male-0', 'male-1']

## ヒートマップ

In [125]:
# Annotated heatmap of the crosstab counts: z is the count matrix, x the
# Pclass labels, y the (Sex, Survived) row labels.
fig = ff.create_annotated_heatmap(
    z,
    x=x,
    y=y,
    colorscale='Blues',
    reversescale=True,
    showscale=True,
    opacity=0.75,
)
# Figure title ('クロス集計表' = "cross tabulation").
fig['layout'].update(title='クロス集計表')
# x axis: titled with the column level name (Pclass), no grid, ticks drawn
# along the bottom edge.
fig['layout']['xaxis'].update(title=ct.columns.names[0],
                              showgrid=False,
                              side='bottom')
# y axis: titled with the joined index level names.
fig['layout']['yaxis'].update(title=multi_title(ct, x=False))
# NOTE(review): `py` is plotly.plotly, the online (Chart Studio) API, which
# uploads the figure and needs credentials; plotly.offline.iplot may be
# what is intended — confirm.
py.iplot(fig, show_link=False)