# ルール採用の閾値の違いによる予測結果の変化

In [14]:
import json
import os

import sys
project_dir_path = "/Users/keisukeonoue/ws/lukasiewicz_2"
sys.path.append(project_dir_path)

import pandas as pd
import numpy as np

import warnings

version_nums = [71, 72, 73, 74, 75, 76]

# One result.json per experiment version.
# NOTE(review): per the notebook title these versions presumably differ in the
# rule-adoption threshold (`rule_thr`) — confirm against the experiment configs.
result_file_paths = [
    os.path.join(project_dir_path, f"experiment_result/tmp/version_{version_num}/result.json")
    for version_num in version_nums
]


def _load_fold_frames(path):
    """Read one ``result.json`` and return ``(rule_thr, fold_frames)``.

    ``json_data['result']`` maps each fold to ``{model_name: {metric: value}}``.
    Each fold becomes a DataFrame indexed by model name with one column per
    metric. The ``violation_detail`` entry is dropped so that only the
    averaged metrics remain.
    """
    with open(path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    fold_frames = []
    for fold_result in json_data['result'].values():
        frame = pd.DataFrame(fold_result).T.drop(['violation_detail'], axis=1)
        # The frame is built column-per-model, with `violation_detail`
        # (presumably non-scalar) mixed in. That can leave object-dtype
        # columns after the transpose, so restore numeric dtypes before
        # groupby().mean()/std().
        fold_frames.append(frame.infer_objects())
    return json_data['rule_thr'], fold_frames


result_dfs = []  # per version: list of per-fold metric frames
infos = []       # per version: its rule_thr, in version order
for path in result_file_paths:
    rule_thr, fold_frames = _load_fold_frames(path)
    infos.append(rule_thr)
    result_dfs.append(fold_frames)

# The plotting cells build dicts keyed by rule_thr, so two versions sharing a
# threshold would silently overwrite each other there. The check uses `==`
# so it also works for unhashable values.
if any(thr in infos[:i] for i, thr in enumerate(infos)):
    warnings.warn(f"duplicate rule_thr values across versions: {infos}")

# Row order of the summary tables.
model_name_list = [
    'RuleFit Classifier (disc)',
    'tree generator (disc)',
    'RuleFit Classifier (conti)',
    'tree generator (conti)',
    'linear svm (L)',
    'non-linear svm (L)',
    'logistic regression (L)',
    'linear svm',
    'non-linear svm',
    'logistic regression',
]

# Stack all folds of each version, then aggregate per model (index) across folds.
combined_dfs = [pd.concat(dfs, ignore_index=False) for dfs in result_dfs]
dfs_mean = [combined_df.groupby(level=0).mean().reindex(index=model_name_list) for combined_df in combined_dfs]
dfs_std = [combined_df.groupby(level=0).std().reindex(index=model_name_list) for combined_df in combined_dfs]

for df in dfs_mean:
    display(df)  # IPython/Jupyter built-in

Unnamed: 0,accuracy,precision,recall,f1,auc,n_violation,n_rule,violation_rate,n_violation (instance),n_evaluation (instance),violation_rate (instance)
RuleFit Classifier (disc),0.783582,0.742431,0.521271,0.610252,0.815869,28.4,35.0,0.804302,508.2,580.0,0.874425
tree generator (disc),0.741791,0.627843,0.557706,0.582608,0.772322,25.8,35.0,0.732995,497.8,580.0,0.856211
RuleFit Classifier (conti),0.732836,0.601979,0.553471,0.571908,0.795116,30.8,35.0,0.882121,509.8,580.0,0.87818
tree generator (conti),0.78209,0.715158,0.568106,0.62746,0.8326,30.8,35.0,0.880683,513.6,580.0,0.884501
linear svm (L),0.452239,0.388381,0.900178,0.528129,0.443072,23.2,35.0,0.657666,362.6,580.0,0.61644
non-linear svm (L),0.731343,0.572229,0.722588,0.634981,0.821484,22.0,35.0,0.621714,137.6,580.0,0.236895
logistic regression (L),0.770149,0.732412,0.470891,0.570208,0.826615,19.2,35.0,0.551195,95.0,580.0,0.164107
linear svm,0.779104,0.711436,0.551175,0.618693,0.837805,30.8,35.0,0.881336,511.6,580.0,0.880992
non-linear svm,0.780597,0.75907,0.479714,0.587201,0.838064,31.8,35.0,0.908559,519.4,580.0,0.894818
logistic regression,0.78209,0.731143,0.530165,0.612464,0.840573,31.0,35.0,0.886892,514.2,580.0,0.885655


Unnamed: 0,accuracy,precision,recall,f1,auc,n_violation,n_rule,violation_rate,n_violation (instance),n_evaluation (instance),violation_rate (instance)
RuleFit Classifier (disc),0.783582,0.742431,0.521271,0.610252,0.815869,14.6,18.8,0.774747,361.8,405.0,0.882795
tree generator (disc),0.741791,0.627843,0.557706,0.582608,0.772322,14.2,18.8,0.756566,361.8,405.0,0.881569
RuleFit Classifier (conti),0.732836,0.601979,0.553471,0.571908,0.795116,17.0,18.8,0.910101,367.6,405.0,0.899792
tree generator (conti),0.78209,0.715158,0.568106,0.62746,0.8326,16.6,18.8,0.885859,368.8,405.0,0.90318
linear svm (L),0.540299,0.430336,0.871656,0.559643,0.672237,12.0,18.8,0.636364,204.6,405.0,0.506281
non-linear svm (L),0.726866,0.559084,0.755122,0.641246,0.827862,12.6,18.8,0.669697,95.0,405.0,0.234953
logistic regression (L),0.776119,0.757705,0.470891,0.577723,0.825458,10.6,18.8,0.566667,48.4,405.0,0.124301
linear svm,0.779104,0.711436,0.551175,0.618693,0.837805,16.8,18.8,0.89697,367.0,405.0,0.89821
non-linear svm,0.780597,0.75907,0.479714,0.587201,0.838064,17.2,18.8,0.917172,370.8,405.0,0.907616
logistic regression,0.78209,0.731143,0.530165,0.612464,0.840573,17.0,18.8,0.908081,368.8,405.0,0.902848


Unnamed: 0,accuracy,precision,recall,f1,auc,n_violation,n_rule,violation_rate,n_violation (instance),n_evaluation (instance),violation_rate (instance)
RuleFit Classifier (disc),0.783582,0.742431,0.521271,0.610252,0.815869,8.2,10.6,0.791026,169.0,204.6,0.820732
tree generator (disc),0.741791,0.627843,0.557706,0.582608,0.772322,8.4,10.6,0.816026,171.2,204.6,0.829489
RuleFit Classifier (conti),0.732836,0.601979,0.553471,0.571908,0.795116,10.0,10.6,0.953846,174.8,204.6,0.849932
tree generator (conti),0.78209,0.715158,0.568106,0.62746,0.8326,9.6,10.6,0.913462,174.6,204.6,0.845369
linear svm (L),0.404478,0.371566,0.95,0.518457,0.379923,7.2,10.6,0.703205,145.0,204.6,0.664135
non-linear svm (L),0.708955,0.537741,0.784146,0.632848,0.827749,7.2,10.6,0.699634,44.4,204.6,0.21139
logistic regression (L),0.765672,0.71091,0.487739,0.573507,0.825165,5.8,10.6,0.51163,24.4,204.6,0.11189
linear svm,0.779104,0.711436,0.551175,0.618693,0.837805,9.4,10.6,0.888462,172.4,204.6,0.834851
non-linear svm,0.780597,0.75907,0.479714,0.587201,0.838064,9.8,10.6,0.928846,175.8,204.6,0.853747
logistic regression,0.78209,0.731143,0.530165,0.612464,0.840573,9.6,10.6,0.913462,174.2,204.6,0.844572


Unnamed: 0,accuracy,precision,recall,f1,auc,n_violation,n_rule,violation_rate,n_violation (instance),n_evaluation (instance),violation_rate (instance)
RuleFit Classifier (disc),0.783582,0.742431,0.521271,0.610252,0.815869,3.6,4.4,0.804762,71.4,100.4,0.700337
tree generator (disc),0.741791,0.627843,0.557706,0.582608,0.772322,3.8,4.4,0.904762,74.4,100.4,0.730809
RuleFit Classifier (conti),0.732836,0.601979,0.553471,0.571908,0.795116,4.4,4.4,1.0,76.0,100.4,0.74878
tree generator (conti),0.78209,0.715158,0.568106,0.62746,0.8326,4.2,4.4,0.966667,75.2,100.4,0.738171
linear svm (L),0.544776,0.438322,0.854914,0.560922,0.655758,3.4,4.4,0.821429,41.6,100.4,0.396108
non-linear svm (L),0.697015,0.534514,0.750428,0.613081,0.826788,2.4,4.4,0.530952,11.8,100.4,0.115391
logistic regression (L),0.768657,0.716496,0.496076,0.581533,0.825359,2.6,4.4,0.585714,10.2,100.4,0.096674
linear svm,0.779104,0.711436,0.551175,0.618693,0.837805,4.0,4.4,0.866667,73.0,100.4,0.714362
non-linear svm,0.780597,0.75907,0.479714,0.587201,0.838064,4.2,4.4,0.966667,76.0,100.4,0.748804
logistic regression,0.78209,0.731143,0.530165,0.612464,0.840573,4.2,4.4,0.966667,74.8,100.4,0.734067


Unnamed: 0,accuracy,precision,recall,f1,auc,n_violation,n_rule,violation_rate,n_violation (instance),n_evaluation (instance),violation_rate (instance)
RuleFit Classifier (disc),0.783582,0.742431,0.521271,0.610252,0.815869,2.2,3.0,0.78,39.8,68.8,0.500069
tree generator (disc),0.741791,0.627843,0.557706,0.582608,0.772322,2.4,3.0,0.88,42.8,68.8,0.586544
RuleFit Classifier (conti),0.732836,0.601979,0.553471,0.571908,0.795116,3.0,3.0,1.0,44.4,68.8,0.573755
tree generator (conti),0.78209,0.715158,0.568106,0.62746,0.8326,2.8,3.0,0.96,43.6,68.8,0.567118
linear svm (L),0.483582,0.442807,0.863138,0.517722,0.573365,1.4,3.0,0.4,25.0,68.8,0.313926
non-linear svm (L),0.71194,0.541322,0.766184,0.629982,0.826895,1.4,3.0,0.4,6.2,68.8,0.066995
logistic regression (L),0.768657,0.716496,0.496076,0.581533,0.825061,2.0,3.0,0.68,8.0,68.8,0.118951
linear svm,0.779104,0.711436,0.551175,0.618693,0.837805,2.6,3.0,0.86,41.4,68.8,0.520691
non-linear svm,0.780597,0.75907,0.479714,0.587201,0.838064,2.8,3.0,0.96,44.4,68.8,0.573127
logistic regression,0.78209,0.731143,0.530165,0.612464,0.840573,2.8,3.0,0.96,43.2,68.8,0.549155


Unnamed: 0,accuracy,precision,recall,f1,auc,n_violation,n_rule,violation_rate,n_violation (instance),n_evaluation (instance),violation_rate (instance)
RuleFit Classifier (disc),0.783582,0.742431,0.521271,0.610252,0.815869,1.6,2.0,0.833333,23.2,50.4,0.421325
tree generator (disc),0.741791,0.627843,0.557706,0.582608,0.772322,1.8,2.0,0.933333,26.2,50.4,0.509722
RuleFit Classifier (conti),0.732836,0.601979,0.553471,0.571908,0.795116,2.0,2.0,1.0,27.2,50.4,0.499264
tree generator (conti),0.78209,0.715158,0.568106,0.62746,0.8326,2.0,2.0,1.0,26.6,50.4,0.496418
linear svm (L),0.467164,0.383173,0.935227,0.538908,0.470436,1.2,2.0,0.566667,13.0,50.4,0.233329
non-linear svm (L),0.702985,0.54546,0.737755,0.615148,0.826937,0.8,2.0,0.366667,2.0,50.4,0.036706
logistic regression (L),0.768657,0.716496,0.496076,0.581533,0.824521,1.2,2.0,0.633333,5.4,50.4,0.118359
linear svm,0.779104,0.711436,0.551175,0.618693,0.837805,1.8,2.0,0.9,24.4,50.4,0.443561
non-linear svm,0.780597,0.75907,0.479714,0.587201,0.838064,2.0,2.0,1.0,27.4,50.4,0.507526
logistic regression,0.78209,0.731143,0.530165,0.612464,0.840573,2.0,2.0,1.0,26.2,50.4,0.481632


In [3]:
import plotly.express as px

def hex_to_rgb(hex_color):
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints.

    Accepts ``'#RRGGBB'`` / ``'RRGGBB'`` and also the CSS shorthand
    ``'#RGB'`` / ``'RGB'``, in which each digit is doubled. Only the first
    six hex digits are read, so a trailing alpha byte (``'#RRGGBBAA'``) is
    ignored.

    Returns:
        A 3-tuple of ints in ``0..255``. It contains RGB only with no alpha;
        callers format their own ``rgba(...)`` strings.

    Raises:
        ValueError: if the digits are not valid hexadecimal.
    """
    # Drop the leading '#' if present.
    digits = hex_color.lstrip('#')
    # Expand CSS shorthand: 'abc' -> 'aabbcc'.
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

# Two parallel palettes derived from Plotly's qualitative sequence:
# opaque colours (alpha 1) for the mean curves and translucent ones
# (alpha 0.2) for the ±std bands.
color_codes = list(map(hex_to_rgb, px.colors.qualitative.Plotly))
colors_mean = [f'rgba({red},{green},{blue},1)' for (red, green, blue) in color_codes]
colors_std = [f'rgba({red},{green},{blue},0.2)' for (red, green, blue) in color_codes]

In [14]:
# Show the opaque RGBA palette used for the mean curves.
colors_mean

['rgba(99,110,250,1)',
 'rgba(239,85,59,1)',
 'rgba(0,204,150,1)',
 'rgba(171,99,250,1)',
 'rgba(255,161,90,1)',
 'rgba(25,211,243,1)',
 'rgba(255,102,146,1)',
 'rgba(182,232,128,1)',
 'rgba(255,151,255,1)',
 'rgba(254,203,82,1)']

In [4]:
# List the metric columns available in the averaged result tables.
dfs_mean[0].columns

Index(['accuracy', 'precision', 'recall', 'f1', 'auc', 'n_violation', 'n_rule',
       'violation_rate', 'n_violation (instance)', 'n_evaluation (instance)',
       'violation_rate (instance)'],
      dtype='object')

In [10]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Models to compare; each gets a mean curve plus a ±1 std band.
# Other model groups used in this notebook:
#   ['linear svm (L)', 'non-linear svm (L)', 'logistic regression (L)']
#   ['linear svm', 'non-linear svm', 'logistic regression']
index = [
    'RuleFit Classifier (disc)',
    'tree generator (disc)',
    'RuleFit Classifier (conti)',
    'tree generator (conti)',
]

# Metrics to plot, one subplot row each (y axis fixed to [0, 1]).
columns = ['auc', 'violation_rate', 'violation_rate (instance)']

title_text = "rule threshold"


# One subplot row per metric in `columns`, not per model in `index`.
fig = make_subplots(rows=len(columns), cols=1)

for row_num, col in enumerate(columns, start=1):
    # Model x rule_thr tables for this metric (x axis = rule_thr of each version).
    tmp_mean = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_mean)} for model_name in index],
        index=index,
    )
    tmp_std = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_std)} for model_name in index],
        index=index,
    )
    # List each model in the legend only once, from the first row.
    showlegend = row_num == 1
    x = tmp_mean.columns

    for i, model_name in enumerate(tmp_mean.index):
        mean = tmp_mean.loc[model_name]
        std = tmp_std.loc[model_name]
        # Mean curve.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean,
                mode='lines+markers',
                name=model_name,
                line=dict(color=colors_mean[i]),
                marker=dict(color=colors_mean[i]),
                showlegend=showlegend,
            ),
            row=row_num, col=1,
        )
        # Upper edge of the ±1 std band.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean + std,
                mode='lines',
                line=dict(color=colors_std[i]),
                showlegend=False,
            ),
            row=row_num, col=1,
        )
        # Lower edge. fill='tonexty' shades up to the trace added immediately
        # before (the upper edge), so these two traces must stay adjacent.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean - std,
                mode='lines',
                fill='tonexty',
                fillcolor=colors_std[i],
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=False,
            ),
            row=row_num, col=1,
        )

    fig.update_xaxes(title_text=title_text, row=row_num, col=1)
    fig.update_yaxes(title_text=col, range=[0, 1], row=row_num, col=1, side='right', title_font=dict(size=16))

# Style every subplot's x axis uniformly, not just xaxis/xaxis2.
fig.update_xaxes(domain=[0, 1], title_font=dict(size=16))

fig.update_layout(
    height=len(columns) * 300,
    width=600,
    legend=dict(
        x=0.5,
        y=1.15,
        orientation='h',  # horizontal legend above the plots
    ),
)

fig.show()

In [17]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Models to compare; each gets a mean curve plus a ±1 std band.
# Other model groups used in this notebook:
#   ['linear svm', 'non-linear svm', 'logistic regression']
#   ['RuleFit Classifier (disc)', 'tree generator (disc)',
#    'RuleFit Classifier (conti)', 'tree generator (conti)']
index = [
    'linear svm (L)',
    'non-linear svm (L)',
    'logistic regression (L)',
]

# Metrics to plot, one subplot row each (y axis fixed to [0, 1]).
columns = ['auc', 'violation_rate', 'violation_rate (instance)']

title_text = "rule threshold"


# One subplot row per metric in `columns`.
fig = make_subplots(rows=len(columns), cols=1)

for row_num, col in enumerate(columns, start=1):
    # Model x rule_thr tables for this metric (x axis = rule_thr of each version).
    tmp_mean = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_mean)} for model_name in index],
        index=index,
    )
    tmp_std = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_std)} for model_name in index],
        index=index,
    )
    # List each model in the legend only once, from the first row.
    showlegend = row_num == 1
    x = tmp_mean.columns

    for i, model_name in enumerate(tmp_mean.index):
        mean = tmp_mean.loc[model_name]
        std = tmp_std.loc[model_name]
        # Mean curve.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean,
                mode='lines+markers',
                name=model_name,
                line=dict(color=colors_mean[i]),
                marker=dict(color=colors_mean[i]),
                showlegend=showlegend,
            ),
            row=row_num, col=1,
        )
        # Upper edge of the ±1 std band.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean + std,
                mode='lines',
                line=dict(color=colors_std[i]),
                showlegend=False,
            ),
            row=row_num, col=1,
        )
        # Lower edge. fill='tonexty' shades up to the trace added immediately
        # before (the upper edge), so these two traces must stay adjacent.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean - std,
                mode='lines',
                fill='tonexty',
                fillcolor=colors_std[i],
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=False,
            ),
            row=row_num, col=1,
        )

    # Fixed ticks at 0.0, 0.2, ..., 1.0 on the rule-threshold axis.
    fig.update_xaxes(title_text=title_text, row=row_num, col=1, tickvals=[num / 5 for num in range(6)])
    fig.update_yaxes(title_text=col, range=[0, 1], row=row_num, col=1, side='right', title_font=dict(size=16))

# Style every subplot's x axis uniformly, not just xaxis/xaxis2.
fig.update_xaxes(domain=[0, 1], title_font=dict(size=16))

fig.update_layout(
    height=len(columns) * 300,
    width=600,
    legend=dict(
        x=0.5,
        y=1.15,
        orientation='h',  # horizontal legend above the plots
    ),
)

fig.show()

In [16]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Models to compare; each gets a mean curve plus a ±1 std band.
# Other model groups used in this notebook:
#   ['linear svm (L)', 'non-linear svm (L)', 'logistic regression (L)']
#   ['RuleFit Classifier (disc)', 'tree generator (disc)',
#    'RuleFit Classifier (conti)', 'tree generator (conti)']
index = [
    'linear svm',
    'non-linear svm',
    'logistic regression',
]

# Metrics to plot, one subplot row each (y axis fixed to [0, 1]).
columns = ['auc', 'violation_rate', 'violation_rate (instance)']

title_text = "rule threshold"


# One subplot row per metric in `columns`.
fig = make_subplots(rows=len(columns), cols=1)

for row_num, col in enumerate(columns, start=1):
    # Model x rule_thr tables for this metric (x axis = rule_thr of each version).
    tmp_mean = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_mean)} for model_name in index],
        index=index,
    )
    tmp_std = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_std)} for model_name in index],
        index=index,
    )
    # List each model in the legend only once, from the first row.
    showlegend = row_num == 1
    x = tmp_mean.columns

    for i, model_name in enumerate(tmp_mean.index):
        mean = tmp_mean.loc[model_name]
        std = tmp_std.loc[model_name]
        # Mean curve.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean,
                mode='lines+markers',
                name=model_name,
                line=dict(color=colors_mean[i]),
                marker=dict(color=colors_mean[i]),
                showlegend=showlegend,
            ),
            row=row_num, col=1,
        )
        # Upper edge of the ±1 std band.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean + std,
                mode='lines',
                line=dict(color=colors_std[i]),
                showlegend=False,
            ),
            row=row_num, col=1,
        )
        # Lower edge. fill='tonexty' shades up to the trace added immediately
        # before (the upper edge), so these two traces must stay adjacent.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean - std,
                mode='lines',
                fill='tonexty',
                fillcolor=colors_std[i],
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=False,
            ),
            row=row_num, col=1,
        )

    fig.update_xaxes(title_text=title_text, row=row_num, col=1)
    fig.update_yaxes(title_text=col, range=[0, 1], row=row_num, col=1, side='right', title_font=dict(size=16))

# Style every subplot's x axis uniformly, not just xaxis/xaxis2.
fig.update_xaxes(domain=[0, 1], title_font=dict(size=16))

fig.update_layout(
    height=len(columns) * 300,
    width=600,
    legend=dict(
        x=0.5,
        y=1.15,
        orientation='h',  # horizontal legend above the plots
    ),
)

fig.show()

In [15]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Models to compare; each gets a mean curve plus a ±1 std band.
# Other model groups used in this notebook:
#   ['linear svm (L)', 'non-linear svm (L)', 'logistic regression (L)']
#   ['linear svm', 'non-linear svm', 'logistic regression']
index = [
    'RuleFit Classifier (disc)',
    'tree generator (disc)',
]

# Metrics to plot, one subplot row each (y axis fixed to [0, 1]).
columns = ['auc', 'violation_rate', 'violation_rate (instance)']

title_text = "rule threshold"


# One subplot row per metric in `columns`.
fig = make_subplots(rows=len(columns), cols=1)

for row_num, col in enumerate(columns, start=1):
    # Model x rule_thr tables for this metric (x axis = rule_thr of each version).
    tmp_mean = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_mean)} for model_name in index],
        index=index,
    )
    tmp_std = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_std)} for model_name in index],
        index=index,
    )
    # List each model in the legend only once, from the first row.
    showlegend = row_num == 1
    x = tmp_mean.columns

    for i, model_name in enumerate(tmp_mean.index):
        mean = tmp_mean.loc[model_name]
        std = tmp_std.loc[model_name]
        # Mean curve.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean,
                mode='lines+markers',
                name=model_name,
                line=dict(color=colors_mean[i]),
                marker=dict(color=colors_mean[i]),
                showlegend=showlegend,
            ),
            row=row_num, col=1,
        )
        # Upper edge of the ±1 std band.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean + std,
                mode='lines',
                line=dict(color=colors_std[i]),
                showlegend=False,
            ),
            row=row_num, col=1,
        )
        # Lower edge. fill='tonexty' shades up to the trace added immediately
        # before (the upper edge), so these two traces must stay adjacent.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean - std,
                mode='lines',
                fill='tonexty',
                fillcolor=colors_std[i],
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=False,
            ),
            row=row_num, col=1,
        )

    fig.update_xaxes(title_text=title_text, row=row_num, col=1)
    fig.update_yaxes(title_text=col, range=[0, 1], row=row_num, col=1, side='right', title_font=dict(size=16))

# Style every subplot's x axis uniformly, not just xaxis/xaxis2.
fig.update_xaxes(domain=[0, 1], title_font=dict(size=16))

fig.update_layout(
    height=len(columns) * 300,
    width=600,
    legend=dict(
        x=0.5,
        y=1.15,
        orientation='h',  # horizontal legend above the plots
    ),
)

fig.show()

In [24]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Models to compare; each gets a mean curve plus a ±1 std band.
# Other model group used in this notebook:
#   ['linear svm (L)', 'non-linear svm (L)', 'logistic regression (L)']
index = [
    'linear svm',
    'non-linear svm',
    'logistic regression',
    'RuleFit Classifier (disc)',
    'tree generator (disc)',
]

# Metrics to plot, one subplot row each (y axis fixed to [0, 1]).
columns = ['violation_rate']

title_text = "rule threshold"


# One subplot row per metric in `columns` (a single row here).
fig = make_subplots(rows=len(columns), cols=1)

for row_num, col in enumerate(columns, start=1):
    # Model x rule_thr tables for this metric (x axis = rule_thr of each version).
    tmp_mean = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_mean)} for model_name in index],
        index=index,
    )
    tmp_std = pd.DataFrame(
        [{info: df.loc[model_name, col] for info, df in zip(infos, dfs_std)} for model_name in index],
        index=index,
    )
    # List each model in the legend only once, from the first row.
    showlegend = row_num == 1
    x = tmp_mean.columns

    for i, model_name in enumerate(tmp_mean.index):
        mean = tmp_mean.loc[model_name]
        std = tmp_std.loc[model_name]
        # Mean curve.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean,
                mode='lines+markers',
                name=model_name,
                line=dict(color=colors_mean[i]),
                marker=dict(color=colors_mean[i]),
                showlegend=showlegend,
            ),
            row=row_num, col=1,
        )
        # Upper edge of the ±1 std band.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean + std,
                mode='lines',
                line=dict(color=colors_std[i]),
                showlegend=False,
            ),
            row=row_num, col=1,
        )
        # Lower edge. fill='tonexty' shades up to the trace added immediately
        # before (the upper edge), so these two traces must stay adjacent.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean - std,
                mode='lines',
                fill='tonexty',
                fillcolor=colors_std[i],
                line=dict(color='rgba(255,255,255,0)'),
                showlegend=False,
            ),
            row=row_num, col=1,
        )

    fig.update_xaxes(title_text=title_text, row=row_num, col=1)
    fig.update_yaxes(title_text=col, range=[0, 1], row=row_num, col=1, side='right', title_font=dict(size=16))

# Style only the x axes that exist. Configuring layout.xaxis2 directly would
# define an axis no subplot in this single-row figure uses.
fig.update_xaxes(domain=[0, 1], title_font=dict(size=16))

fig.update_layout(
    height=len(columns) * 300,
    width=600,
    legend=dict(
        x=0.5,
        y=2,  # far above the short single-row plot
        orientation='h',  # horizontal legend
    ),
)

fig.show()