# サポートベクターマシン 実際のデータによるデモ
UC Irvine Machine Learning Repository で公開されている Adult データセットを用いてMADlibのサポートベクターマシンを試してみる

## 前準備

In [431]:
# ipython-sqlをロード
%reload_ext sql
# 文字コードをUTF-8にセット
# (データベースの文字コードもUTF-8にセットしておくこと)
import sys
reload(sys)
sys.setdefaultencoding("utf-8")

## PostgreSQLに接続＆接続確認

In [435]:
%sql postgresql://postgres@centos72/postgres
%sql SELECT version();

version
"PostgreSQL 9.6.1 on x86_64-pc-linux-gnu, compiled by gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-4), 64-bit"


## MADlibが正常にインストールされているか確認

In [436]:
%sql SELECT madlib.version();

version
"MADlib version: 1.10.0-dev, git revision: rel/v1.9.1-8-g82e56a4, cmake configuration time: 2016年 11月 22日 火曜日 00:25:23 UTC, build type: RelWithDebInfo, build system: Linux-3.10.0-327.el7.x86_64, C compiler: gcc 4.8.5, C++ compiler: g++ 4.8.5"


## AdultデータをローカルPCにダウンロード

In [434]:
#import urllib
#urllib.urlretrieve('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data', '/Users/Masanori/jupyter/MADlib/adult_dataset/adult.data')
#urllib.urlretrieve('http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test', '/Users/Masanori/jupyter/MADlib/adult_dataset/adult.test')

## テーブルを作成しテストデータをPostgreSQLにコピー

In [493]:
# pandasを使ってデータを読み込む 
import pandas as pd
from sqlalchemy import create_engine

adult_data = pd.read_csv(
                             'adult_dataset/adult.data'
                            ,header=None
                            ,names=[
                                 'age'
                                ,'workclass'
                                , 'fnlwgt'
                                , 'education'
                                , 'education_num'
                                , 'marital_status'
                                , 'occupation'
                                , 'relationship'
                                , 'race'
                                , 'sex'
                                , 'capital_gain'
                                , 'capital_loss'
                                , 'hours_per_week'
                                , 'native_country'
                                , 'class'
                            ]
                        )
# ipython-sql の機能を使ってPostgreSQLにデータを送信する
# テーブルは自動作成される
%sql DROP TABLE IF EXISTS adult_data;
%sql PERSIST adult_data
%sql SELECT * FROM adult_data LIMIT 5;

index,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,class
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K


## データが正常にPostgreSQLに送信されたことを確認
ローカルPCのデータの件数とPostgreSQLの対象テーブルのレコード件数を比較

In [441]:
adult_data['age'].count()

32561

In [446]:
%sql SELECT count(1) FROM adult_data;

count
32561


## Adultデータセットのテキストの属性情報を数値に変換する
サポートベクターマシンのインプットデータとしてテキストは使えない模様

In [455]:
%%sql

DROP TABLE IF EXISTS adult_data_parsed;
CREATE TABLE 
    adult_data_parsed
AS
    SELECT
         index
        ,age::FLOAT8
        ,CASE workclass
                    WHEN ' Private' THEN '1'::FLOAT8
                    WHEN ' Self-emp-not-inc' THEN '2'::FLOAT8
                    WHEN ' Self-emp-inc' THEN '3'::FLOAT8
                    WHEN ' Federal-gov' THEN '4'::FLOAT8
                    WHEN ' Local-gov' THEN '5'::FLOAT8
                    WHEN ' State-gov' THEN '6'::FLOAT8
                    WHEN ' Without-pay' THEN '7'::FLOAT8
                    WHEN ' Never-worked' THEN '8'::FLOAT8
                END AS workclass
        ,CASE education
                    WHEN ' Bachelors' THEN '1'::FLOAT8
                    WHEN ' Some-college' THEN '2'::FLOAT8
                    WHEN ' 11th' THEN '3'::FLOAT8
                    WHEN ' HS-grad' THEN '4'::FLOAT8
                    WHEN ' Prof-school' THEN '5'::FLOAT8
                    WHEN ' Assoc-acdm' THEN '6'::FLOAT8
                    WHEN ' Assoc-voc' THEN '7'::FLOAT8
                    WHEN ' 9th' THEN '8'::FLOAT8
                    WHEN ' 7th-8th' THEN '9'::FLOAT8
                    WHEN ' 12th' THEN '10'::FLOAT8
                    WHEN ' Masters' THEN '11'::FLOAT8
                    WHEN ' 1st-4th' THEN '12'::FLOAT8
                    WHEN ' 10th' THEN '13'::FLOAT8
                    WHEN ' Doctorate' THEN '14'::FLOAT8
                    WHEN ' 5th-6th' THEN '15'::FLOAT8
                    END AS education
        ,education_num::FLOAT8
        ,CASE marital_status
                    WHEN ' Married-civ-spouse' THEN '1'::FLOAT8
                    WHEN ' Divorced' THEN '2'::FLOAT8
                    WHEN ' Never-married' THEN '3'::FLOAT8
                    WHEN ' Separated' THEN '4'::FLOAT8
                    WHEN ' Widowed' THEN '5'::FLOAT8
                    WHEN ' Married-spouse-absent' THEN '6'::FLOAT8
                    WHEN ' Married-AF-spouse' THEN '7'::FLOAT8
                    END AS marital_status
        ,CASE occupation
                    WHEN ' Tech-support' THEN '1'::FLOAT8
                    WHEN ' Craft-repair' THEN '2'::FLOAT8
                    WHEN ' Other-service' THEN '3'::FLOAT8
                    WHEN ' Sales' THEN '4'::FLOAT8
                    WHEN ' Exec-managerial' THEN '5'::FLOAT8
                    WHEN ' Prof-specialty' THEN '6'::FLOAT8
                    WHEN ' Handlers-cleaners' THEN '7'::FLOAT8
                    WHEN ' Machine-op-inspct' THEN '8'::FLOAT8
                    WHEN ' Adm-clerical' THEN '9'::FLOAT8
                    WHEN ' Farming-fishing' THEN '10'::FLOAT8
                    WHEN ' Transport-moving' THEN '11'::FLOAT8
                    WHEN ' Priv-house-serv' THEN '12'::FLOAT8
                    WHEN ' Protective-serv' THEN '13'::FLOAT8
                    WHEN ' Armed-Forces' THEN '14'::FLOAT8
                    END AS occupation
        ,CASE relationship
                    WHEN ' Wife' THEN '1'::FLOAT8
                    WHEN ' Own-child' THEN '2'::FLOAT8
                    WHEN ' Husband' THEN '3'::FLOAT8
                    WHEN ' Not-in-family' THEN '4'::FLOAT8
                    WHEN ' Other-relative' THEN '5'::FLOAT8
                    WHEN ' Unmarried' THEN '6'::FLOAT8
                    END AS relationship
        ,CASE race
                    WHEN ' White' THEN '1'::FLOAT8
                    WHEN ' Asian-Pac-Islander' THEN '2'::FLOAT8
                    WHEN ' Amer-Indian-Eskimo' THEN '3'::FLOAT8
                    WHEN ' Other' THEN '4'::FLOAT8
                    WHEN ' Black' THEN '5'::FLOAT8
                    END AS race
        ,CASE sex 
                    WHEN ' Female' THEN '1'::FLOAT8
                    WHEN ' Male' THEN '2'::FLOAT8
                    END AS sex
        ,capital_gain::FLOAT8
        ,capital_loss::FLOAT8
        ,hours_per_week::FLOAT8
        ,CASE native_country
                    WHEN ' United-States' THEN '1'::FLOAT8
                    WHEN ' Cambodia' THEN '2'::FLOAT8
                    WHEN ' England' THEN '3'::FLOAT8
                    WHEN ' Puerto-Rico' THEN '4'::FLOAT8
                    WHEN ' Canada' THEN '5'::FLOAT8
                    WHEN ' Germany' THEN '6'::FLOAT8
                    WHEN ' Outlying-US(Guam-USVI-etc)' THEN '7'::FLOAT8
                    WHEN ' India' THEN '8'::FLOAT8
                    WHEN ' Japan' THEN '9'::FLOAT8
                    WHEN ' Greece' THEN '10'::FLOAT8
                    WHEN ' South, China' THEN '11'::FLOAT8
                    WHEN ' Cuba' THEN '12'::FLOAT8
                    WHEN ' Iran' THEN '13'::FLOAT8
                    WHEN ' Honduras' THEN '14'::FLOAT8
                    WHEN ' Philippines' THEN '15'::FLOAT8
                    WHEN ' Italy' THEN '16'::FLOAT8
                    WHEN ' Poland' THEN '17'::FLOAT8
                    WHEN ' Vietnam' THEN '18'::FLOAT8
                    WHEN ' Mexico' THEN '19'::FLOAT8
                    WHEN ' Portugal' THEN '20'::FLOAT8
                    WHEN ' Ireland' THEN '21'::FLOAT8
                    WHEN ' France' THEN '22'::FLOAT8
                    WHEN ' Dominican-Republic' THEN '23'::FLOAT8
                    WHEN ' Laos' THEN '24'::FLOAT8
                    WHEN ' Ecuador' THEN '25'::FLOAT8
                    WHEN ' Taiwan' THEN '26'::FLOAT8
                    WHEN ' Haiti' THEN '27'::FLOAT8
                    WHEN ' Columbia' THEN '28'::FLOAT8
                    WHEN ' Hungary' THEN '29'::FLOAT8
                    WHEN ' Guatemala' THEN '30'::FLOAT8
                    WHEN ' Nicaragua' THEN '31'::FLOAT8
                    WHEN ' Scotland' THEN '32'::FLOAT8
                    WHEN ' Thailand' THEN '33'::FLOAT8
                    WHEN ' Yugoslavia' THEN '34'::FLOAT8
                    WHEN ' El-Salvador' THEN '35'::FLOAT8
                    WHEN ' Trinadad&Tobago' THEN '36'::FLOAT8
                    WHEN ' Peru' THEN '37'::FLOAT8
                    WHEN ' Hong' THEN '38'::FLOAT8
                    WHEN ' Holand-Netherlands' THEN '39'::FLOAT8
                    END AS native_country
        ,CASE class
                    WHEN ' >50K' THEN '0'::FLOAT8
                    WHEN ' <=50K' THEN '1'::FLOAT8
                    END AS class
    FROM adult_data
;
SELECT * FROM adult_data_parsed LIMIT 5;

index,age,workclass,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,class
0,39.0,6.0,1.0,13.0,3.0,9.0,4.0,1.0,2.0,2174.0,0.0,40.0,1.0,1.0
1,50.0,2.0,1.0,13.0,1.0,5.0,3.0,1.0,2.0,0.0,0.0,13.0,1.0,1.0
2,38.0,1.0,4.0,9.0,2.0,7.0,4.0,1.0,2.0,0.0,0.0,40.0,1.0,1.0
3,53.0,1.0,3.0,7.0,1.0,7.0,3.0,5.0,2.0,0.0,0.0,40.0,1.0,1.0
4,28.0,1.0,1.0,13.0,1.0,6.0,1.0,5.0,1.0,0.0,0.0,40.0,12.0,1.0


## テストデータもPostgreSQLにロード
作成したモデルを評価するためのデータセットも提供されている。PostgreSQLに送信する。

In [444]:
adult_test = pd.read_csv(
                            'adult_dataset/adult.test'
                            , header=None
                            , skiprows=1
                            ,names=[
                                'age'
                                , 'workclass'
                                , 'fnlwgt'
                                , 'education'
                                , 'education_num'
                                ,'marital_status'
                                , 'occupation'
                                , 'relationship'
                                , 'race'
                                , 'sex'
                                ,'capital_gain'
                                , 'capital_loss'
                                , 'hours_per_week'
                                , 'native_country'
                                , 'class'
                            ]
                        )
%sql DROP TABLE IF EXISTS adult_test;
%sql PERSIST adult_test

u'Persisted adult_test'

In [445]:
adult_test['age'].count()

16281

In [447]:
%sql SELECT count(1) FROM adult_test;

count
16281


In [459]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed;
CREATE TABLE 
    adult_test_parsed
AS
    SELECT
        index
        ,age::FLOAT8
        ,CASE workclass
                    WHEN ' Private' THEN '1'::FLOAT8
                    WHEN ' Self-emp-not-inc' THEN '2'::FLOAT8
                    WHEN ' Self-emp-inc' THEN '3'::FLOAT8
                    WHEN ' Federal-gov' THEN '4'::FLOAT8
                    WHEN ' Local-gov' THEN '5'::FLOAT8
                    WHEN ' State-gov' THEN '6'::FLOAT8
                    WHEN ' Without-pay' THEN '7'::FLOAT8
                    WHEN ' Never-worked' THEN '8'::FLOAT8
                    END AS workclass
        ,CASE education
                    WHEN ' Bachelors' THEN '1'::FLOAT8
                    WHEN ' Some-college' THEN '2'::FLOAT8
                    WHEN ' 11th' THEN '3'::FLOAT8
                    WHEN ' HS-grad' THEN '4'::FLOAT8
                    WHEN ' Prof-school' THEN '5'::FLOAT8
                    WHEN ' Assoc-acdm' THEN '6'::FLOAT8
                    WHEN ' Assoc-voc' THEN '7'::FLOAT8
                    WHEN ' 9th' THEN '8'::FLOAT8
                    WHEN ' 7th-8th' THEN '9'::FLOAT8
                    WHEN ' 12th' THEN '10'::FLOAT8
                    WHEN ' Masters' THEN '11'::FLOAT8
                    WHEN ' 1st-4th' THEN '12'::FLOAT8
                    WHEN ' 10th' THEN '13'::FLOAT8
                    WHEN ' Doctorate' THEN '14'::FLOAT8
                    WHEN ' 5th-6th' THEN '15'::FLOAT8
                    END AS education
        ,education_num::FLOAT8
        ,CASE marital_status
                    WHEN ' Married-civ-spouse' THEN '1'::FLOAT8
                    WHEN ' Divorced' THEN '2'::FLOAT8
                    WHEN ' Never-married' THEN '3'::FLOAT8
                    WHEN ' Separated' THEN '4'::FLOAT8
                    WHEN ' Widowed' THEN '5'::FLOAT8
                    WHEN ' Married-spouse-absent' THEN '6'::FLOAT8
                    WHEN ' Married-AF-spouse' THEN '7'::FLOAT8
                    END AS marital_status
        ,CASE occupation
                    WHEN ' Tech-support' THEN '1'::FLOAT8
                    WHEN ' Craft-repair' THEN '2'::FLOAT8
                    WHEN ' Other-service' THEN '3'::FLOAT8
                    WHEN ' Sales' THEN '4'::FLOAT8
                    WHEN ' Exec-managerial' THEN '5'::FLOAT8
                    WHEN ' Prof-specialty' THEN '6'::FLOAT8
                    WHEN ' Handlers-cleaners' THEN '7'::FLOAT8
                    WHEN ' Machine-op-inspct' THEN '8'::FLOAT8
                    WHEN ' Adm-clerical' THEN '9'::FLOAT8
                    WHEN ' Farming-fishing' THEN '10'::FLOAT8
                    WHEN ' Transport-moving' THEN '11'::FLOAT8
                    WHEN ' Priv-house-serv' THEN '12'::FLOAT8
                    WHEN ' Protective-serv' THEN '13'::FLOAT8
                    WHEN ' Armed-Forces' THEN '14'::FLOAT8
                    END AS occupation
        ,CASE relationship
                    WHEN ' Wife' THEN '1'::FLOAT8
                    WHEN ' Own-child' THEN '2'::FLOAT8
                    WHEN ' Husband' THEN '3'::FLOAT8
                    WHEN ' Not-in-family' THEN '4'::FLOAT8
                    WHEN ' Other-relative' THEN '5'::FLOAT8
                    WHEN ' Unmarried' THEN '6'::FLOAT8
                    END AS relationship
        ,CASE race
                    WHEN ' White' THEN '1'::FLOAT8
                    WHEN ' Asian-Pac-Islander' THEN '2'::FLOAT8
                    WHEN ' Amer-Indian-Eskimo' THEN '3'::FLOAT8
                    WHEN ' Other' THEN '4'::FLOAT8
                    WHEN ' Black' THEN '5'::FLOAT8
                    END AS race
        ,CASE sex
                    WHEN ' Female' THEN '1'::FLOAT8
                    WHEN ' Male' THEN '2'::bigint::FLOAT8
                    END AS sex
        ,capital_gain::FLOAT8
        ,capital_loss::FLOAT8
        ,hours_per_week::FLOAT8
        ,CASE native_country
                    WHEN ' United-States' THEN '1'::FLOAT8
                    WHEN ' Cambodia' THEN '2'::FLOAT8
                    WHEN ' England' THEN '3'::FLOAT8
                    WHEN ' Puerto-Rico' THEN '4'::FLOAT8
                    WHEN ' Canada' THEN '5'::FLOAT8
                    WHEN ' Germany' THEN '6'::FLOAT8
                    WHEN ' Outlying-US(Guam-USVI-etc)' THEN '7'::FLOAT8
                    WHEN ' India' THEN '8'::FLOAT8
                    WHEN ' Japan' THEN '9'::FLOAT8
                    WHEN ' Greece' THEN '10'::FLOAT8
                    WHEN ' South, China' THEN '11'::FLOAT8
                    WHEN ' Cuba' THEN '12'::FLOAT8
                    WHEN ' Iran' THEN '13'::FLOAT8
                    WHEN ' Honduras' THEN '14'::FLOAT8
                    WHEN ' Philippines' THEN '15'::FLOAT8
                    WHEN ' Italy' THEN '16'::FLOAT8
                    WHEN ' Poland' THEN '17'::FLOAT8
                    WHEN ' Vietnam' THEN '18'::FLOAT8
                    WHEN ' Mexico' THEN '19'::FLOAT8
                    WHEN ' Portugal' THEN '20'::FLOAT8
                    WHEN ' Ireland' THEN '21'::FLOAT8
                    WHEN ' France' THEN '22'::FLOAT8
                    WHEN ' Dominican-Republic' THEN '23'::FLOAT8
                    WHEN ' Laos' THEN '24'::FLOAT8
                    WHEN ' Ecuador' THEN '25'::FLOAT8
                    WHEN ' Taiwan' THEN '26'::FLOAT8
                    WHEN ' Haiti' THEN '27'::FLOAT8
                    WHEN ' Columbia' THEN '28'::FLOAT8
                    WHEN ' Hungary' THEN '29'::FLOAT8
                    WHEN ' Guatemala' THEN '30'::FLOAT8
                    WHEN ' Nicaragua' THEN '31'::FLOAT8
                    WHEN ' Scotland' THEN '32'::FLOAT8
                    WHEN ' Thailand' THEN '33'::FLOAT8
                    WHEN ' Yugoslavia' THEN '34'::FLOAT8
                    WHEN ' El-Salvador' THEN '35'::FLOAT8
                    WHEN ' Trinadad&Tobago' THEN '36'::FLOAT8
                    WHEN ' Peru' THEN '37'::FLOAT8
                    WHEN ' Hong' THEN '38'::FLOAT8
                    WHEN ' Holand-Netherlands' THEN '39'::FLOAT8
                    END AS native_country
        ,CASE class 
                    WHEN ' >50K.' THEN '0'::FLOAT8
                    WHEN ' <=50K.' THEN '1'::FLOAT8
                    END AS class
    FROM adult_test
;
SELECT * FROM adult_test_parsed LIMIT 5;

index,age,workclass,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,class
0,25.0,1.0,3.0,7.0,3.0,8.0,2.0,5.0,2.0,0.0,0.0,40.0,1.0,1.0
1,38.0,1.0,4.0,9.0,1.0,10.0,3.0,1.0,2.0,0.0,0.0,50.0,1.0,1.0
2,28.0,5.0,6.0,12.0,1.0,13.0,3.0,1.0,2.0,0.0,0.0,40.0,1.0,0.0
3,44.0,1.0,2.0,10.0,1.0,8.0,3.0,5.0,2.0,7688.0,0.0,40.0,1.0,0.0
4,18.0,,2.0,10.0,3.0,,2.0,1.0,1.0,0.0,0.0,30.0,1.0,1.0


## 線形モデル作製
まずは線形モデルを作成、トレーニングを行う。ARRAYに記載しているカラムを独立変数として選択。

In [449]:
%%sql

DROP TABLE IF EXISTS
     adult_data_parsed_svm
    ,adult_data_parsed_svm_summary
;
SELECT
    madlib.svm_classification(
         'adult_data_parsed'
        ,'adult_data_parsed_svm'
        ,'class = 0.0'
        ,'ARRAY[
            age
            ,workclass
            ,education
            ,education_num
            ,marital_status
            ,occupation
            ,relationship
            ,race
            ,sex
            ,capital_gain
            ,capital_loss
            ,hours_per_week
            ,native_country
         ]'
    )
;
SELECT * FROM adult_data_parsed_svm;

coef,loss,norm_of_gradient,num_iterations,num_rows_processed,num_rows_skipped,dep_var_mapping
"[0.149209009007529, -0.0363348195001957, 0.27651499541279, 1.67171366788677, -33.9858870913864, 0.0278045141390653, 1.39414013066558, 0.183006182555921, -1.37951835928823, 0.00488959391793084, 0.00343621919868571, 0.166854534901831, -0.0824172601010893]",76077.9980464,364549.829656,100,29900,2661,"[False, True]"


## 線形モデルテスト
作成したモデルをテストデータを使って評価する。テストデータを使って所得が>50Kか予測させる。

In [452]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred;
SELECT
    madlib.svm_predict(
         'adult_data_parsed_svm'
        ,'adult_test_parsed'
        ,'index'
        ,'adult_test_parsed_pred'
    )
;
SELECT * FROM adult_test_parsed_pred LIMIT 10;

index,prediction,decision_function
0,False,-77.9737555024
1,False,-2.05605209361
2,True,0.289557726197
3,True,36.5569540315
5,False,-75.2887505237
7,True,24.0024455574
8,False,-67.2993069221
9,False,-13.392109812
10,True,31.6298505376
11,True,1.69748525202


## Area under the ROC curve による予測結果の検証を行うため評価用のテーブルを作成する。

In [454]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_roc_input;
CREATE TABLE
    adult_test_parsed_pred_roc_input
AS
    SELECT 
        ob.class::FLOAT8 AS observed
        ,CASE pre.prediction
            WHEN True THEN '0'::FLOAT8
            WHEN False THEN '1'::FLOAT8
            END AS prediction
    FROM adult_test_parsed_pred AS pre
        LEFT OUTER JOIN adult_test_parsed AS ob
            ON pre.index = ob.index
;
SELECT * FROM adult_test_parsed_pred_roc_input LIMIT 10;

observed,prediction
1.0,1.0
1.0,1.0
0.0,0.0
0.0,0.0
1.0,1.0
0.0,0.0
1.0,1.0
1.0,1.0
0.0,0.0
1.0,0.0


## area_under_roc 算出

In [460]:
%%sql 

DROP TABLE IF EXISTS adult_test_parsed_pred_roc_output;
SELECT
    madlib.area_under_roc(
        'adult_test_parsed_pred_roc_input'
        ,'adult_test_parsed_pred_roc_output'
        ,'prediction'
        ,'observed'
    )
;
SELECT * FROM adult_test_parsed_pred_roc_output;

area_under_roc
0.7573908867338783


## 非線形モデル作製
n_componentの値を変えながら複数の非線形モデルを作成して評価する。

## n_components=10

In [462]:
%%sql

DROP TABLE IF EXISTS
    adult_data_parsed_svm_gaussian_10
    ,adult_data_parsed_svm_gaussian_10_summary
    ,adult_data_parsed_svm_gaussian_10_random
    ;
SELECT
    madlib.svm_classification(
        'adult_data_parsed'
        ,'adult_data_parsed_svm_gaussian_10'
        ,'class = 0.0'
        ,'ARRAY[
            workclass
            ,education
            ,marital_status
            ,occupation
            ,relationship
            ,race
            ,sex
            ,native_country
        ]'
        ,'gaussian'
        ,'n_components=10'
        ,''
        ,'init_stepsize=2,
            max_iter=200'
    )
;
SELECT * FROM adult_data_parsed_svm_gaussian_10;

coef,loss,norm_of_gradient,num_iterations,num_rows_processed,num_rows_skipped,dep_var_mapping
"[-2.93651172290274, 0.341683783738794, -0.440133878862007, 0.257267436381406, 0.35899688525636, 0.134067316852449, -0.048281495957981, 0.505773126346411, -0.179603594680169, 0.686461643517295]",19719.5132261,106.026378711,160,29900,2661,"[False, True]"


In [463]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_10;
SELECT
    madlib.svm_predict(
        'adult_data_parsed_svm_gaussian_10'
        ,'adult_test_parsed'
        ,'index'
        ,'adult_test_parsed_pred_gaussian_10'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_10 LIMIT 10;

index,prediction,decision_function
0,False,-1.32993233113
1,False,-1.09707010755
2,True,0.321027412741
3,False,-2.14599864169
5,False,-0.597154717257
7,False,-0.715990222365
8,False,-1.48916354527
9,False,-1.24085773818
10,False,-0.923307390366
11,True,0.111854319862


In [466]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_10_roc_input;
CREATE TABLE
    adult_test_parsed_pred_gaussian_10_roc_input
AS
    SELECT
        ob.class::FLOAT8 AS observed
        ,CASE pre.prediction
            WHEN True THEN '0.0'::FLOAT8
            WHEN False THEN '1.0'::FLOAT8
            END AS prediction
    FROM adult_test_parsed_pred_gaussian_10 AS pre
            LEFT OUTER JOIN adult_test_parsed AS ob
                ON pre.index = ob.index
;
SELECT * FROM adult_test_parsed_pred_gaussian_10_roc_input LIMIT 10;

observed,prediction
1.0,1.0
1.0,1.0
0.0,0.0
0.0,1.0
1.0,1.0
0.0,1.0
1.0,1.0
1.0,1.0
0.0,1.0
1.0,0.0


In [469]:
%%sql 

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_10_roc_output;
SELECT
    madlib.area_under_roc(
        'adult_test_parsed_pred_gaussian_10_roc_input'
        ,'adult_test_parsed_pred_gaussian_10_roc_output'
        ,'prediction'
        ,'observed'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_10_roc_output;

area_under_roc
0.5323055370802969


## n_components=20

In [470]:
%%sql

DROP TABLE IF EXISTS
    adult_data_parsed_svm_gaussian_20
    ,adult_data_parsed_svm_gaussian_20_summary
    ,adult_data_parsed_svm_gaussian_20_random
;
SELECT
    madlib.svm_classification(
        'adult_data_parsed'
        ,'adult_data_parsed_svm_gaussian_20'
        ,'class = 0.0'
        ,'ARRAY[
            workclass
            ,education
            ,marital_status
            ,occupation
            ,relationship
            ,race
            ,sex
            ,native_country
        ]'
        ,'gaussian'
        ,'n_components=20'
        ,''
        ,'init_stepsize=2,
            max_iter=200'
    )
;
SELECT * FROM adult_data_parsed_svm_gaussian_20;

coef,loss,norm_of_gradient,num_iterations,num_rows_processed,num_rows_skipped,dep_var_mapping
"[-3.86814810268387, 0.449707764693438, -0.755980138466978, 0.181375858849701, 0.829237319869862, 0.00458928801845787, 0.171058151350693, -0.513100445299128, -0.430969534150961, 0.105338585609318, -0.237240804803337, 0.359755473290516, -0.335599590033492, 0.135979395068039, -0.0112713606630472, -0.274565699889732, -0.931919949164275, -0.860925475812818, 0.321416675647694, -1.0238848578281]",18849.1150496,89.7174444837,183,29900,2661,"[False, True]"


In [471]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_20;
SELECT
    madlib.svm_predict(
        'adult_data_parsed_svm_gaussian_20'
        ,'adult_test_parsed'
        ,'index'
        ,'adult_test_parsed_pred_gaussian_20'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_20 LIMIT 10;

index,prediction,decision_function
0,False,-1.17190698142
1,False,-1.18132717988
2,False,-0.0445688331689
3,False,-1.77818365967
5,False,-0.971939899064
7,False,-0.874358793898
8,False,-0.882414054576
9,False,-1.65001996431
10,False,-1.00112560513
11,False,-0.131453982087


In [472]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_20_roc_input;
CREATE TABLE
    adult_test_parsed_pred_gaussian_20_roc_input
AS
    SELECT
        ob.class::FLOAT8 AS observed
        ,CASE pre.prediction
            WHEN True THEN '0.0'::FLOAT8
            WHEN False THEN '1.0'::FLOAT8
            END AS prediction
    FROM adult_test_parsed_pred_gaussian_20 AS pre
            LEFT OUTER JOIN adult_test_parsed AS ob
                ON pre.index = ob.index
;
SELECT * FROM adult_test_parsed_pred_gaussian_20_roc_input LIMIT 10;

observed,prediction
1.0,1.0
1.0,1.0
0.0,1.0
0.0,1.0
1.0,1.0
0.0,1.0
1.0,1.0
1.0,1.0
0.0,1.0
1.0,1.0


In [473]:
%%sql 

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_20_roc_output;
SELECT
    madlib.area_under_roc(
        'adult_test_parsed_pred_gaussian_20_roc_input'
        ,'adult_test_parsed_pred_gaussian_20_roc_output'
        ,'prediction'
        ,'observed'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_20_roc_output;

area_under_roc
0.5274727718989202


## n_components=30

In [474]:
%%sql

DROP TABLE IF EXISTS
    adult_data_parsed_svm_gaussian_30
    ,adult_data_parsed_svm_gaussian_30_summary
    ,adult_data_parsed_svm_gaussian_30_random
;
SELECT
    madlib.svm_classification(
        'adult_data_parsed'
        ,'adult_data_parsed_svm_gaussian_30'
        ,'class = 0.0'
        ,'ARRAY[
            workclass
            ,education
            ,marital_status
            ,occupation
            ,relationship
            ,race
            ,sex
            ,native_country
        ]'
        ,'gaussian'
        ,'n_components=30'
        ,''
        ,'init_stepsize=2,
            max_iter=200'
    )
;
SELECT * FROM adult_data_parsed_svm_gaussian_30;

coef,loss,norm_of_gradient,num_iterations,num_rows_processed,num_rows_skipped,dep_var_mapping
"[-4.89899524001805, 0.502068590111403, -0.448432712468678, -0.264644263446127, 0.953312188977701, 0.542099805581305, -0.0491652636454902, -1.46433389026839, -0.694065328395931, 1.07884194353142, -0.539551708488271, 0.00681979333013868, -0.356586988637732, 0.0622683338217195, 0.277744926344081, -0.432933514181231, -2.21751016511408, -0.267193976008079, -0.510322287251929, -1.51443422598838, 0.0687754778661905, 0.711737438466957, 0.0931018297577553, 0.535800897591793, -0.246127314688611, 2.62781277064238, 0.144360687030306, 0.628837085118467, -0.296165387275386, 0.119662047659125]",17647.5505522,273.251910947,194,29900,2661,"[False, True]"


In [475]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_30;
SELECT
    madlib.svm_predict(
        'adult_data_parsed_svm_gaussian_30'
        ,'adult_test_parsed'
        ,'index'
        ,'adult_test_parsed_pred_gaussian_30'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_30 LIMIT 10;

index,prediction,decision_function
0,False,-1.14259429013
1,False,-1.27152168076
2,False,-0.207596574986
3,False,-1.05631585258
5,False,-0.946978327983
7,False,-0.896884632668
8,False,-1.09335208809
9,False,-1.52032030881
10,False,-0.999999955925
11,True,0.154337355323


In [476]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_30_roc_input;
CREATE TABLE
    adult_test_parsed_pred_gaussian_30_roc_input
AS
    SELECT
        ob.class::FLOAT8 AS observed
        ,CASE pre.prediction
            WHEN True THEN '0.0'::FLOAT8
            WHEN False THEN '1.0'::FLOAT8
            END AS prediction
    FROM adult_test_parsed_pred_gaussian_30 AS pre
            LEFT OUTER JOIN adult_test_parsed AS ob
                ON pre.index = ob.index
;
SELECT * FROM adult_test_parsed_pred_gaussian_30_roc_input LIMIT 10;

observed,prediction
1.0,1.0
1.0,1.0
0.0,1.0
0.0,1.0
1.0,1.0
0.0,1.0
1.0,1.0
1.0,1.0
0.0,1.0
1.0,0.0


In [477]:
%%sql 

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_30_roc_output;
SELECT
    madlib.area_under_roc(
        'adult_test_parsed_pred_gaussian_30_roc_input'
        ,'adult_test_parsed_pred_gaussian_30_roc_output'
        ,'prediction'
        ,'observed'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_30_roc_output;

area_under_roc
0.6180593427318534


## n_components=40

In [478]:
%%sql

DROP TABLE IF EXISTS
    adult_data_parsed_svm_gaussian_40
    ,adult_data_parsed_svm_gaussian_40_summary
    ,adult_data_parsed_svm_gaussian_40_random
;
SELECT
    madlib.svm_classification(
        'adult_data_parsed'
        ,'adult_data_parsed_svm_gaussian_40'
        ,'class = 0.0'
        ,'ARRAY[
            workclass
            ,education
            ,marital_status
            ,occupation
            ,relationship
            ,race
            ,sex
            ,native_country
        ]'
        ,'gaussian'
        ,'n_components=40'
        ,''
        ,'init_stepsize=2,
            max_iter=200'
    )
;
SELECT * FROM adult_data_parsed_svm_gaussian_40;

coef,loss,norm_of_gradient,num_iterations,num_rows_processed,num_rows_skipped,dep_var_mapping
"[-4.88484314019201, 0.536204691911348, -1.27689416201869, -0.397085816795131, 1.16781310467534, 0.603824276083918, -0.0903523770323435, -1.61516267518376, -0.512712743985378, 1.58585108956316, -0.589430217592102, -0.694512246389889, -0.751021698709412, 0.770702693019932, 0.290135390494713, -0.429994127137603, -2.73049804431851, -0.423242109692993, -0.762852709237821, -1.89291388009236, -0.195074127605914, 0.431432035565612, 0.194014387093558, -0.292411939725685, -0.318556318845158, 1.87150422933071, -0.0919452127657904, 0.884950245249454, -0.184842415790891, 0.00634895368966422, -1.75886473453548, -0.693352426131809, 1.22592506673662, 0.0349468994590692, 0.829953797887951, 0.905333260868029, -0.235126637148866, 2.5739680904738, 0.678486783527511, 0.704528701931597]",16803.2914882,218.945039033,193,29900,2661,"[False, True]"


In [479]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_40;
SELECT
    madlib.svm_predict(
        'adult_data_parsed_svm_gaussian_40'
        ,'adult_test_parsed'
        ,'index'
        ,'adult_test_parsed_pred_gaussian_40'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_40 LIMIT 10;

index,prediction,decision_function
0,False,-1.47216629147
1,False,-1.29243941586
2,False,-0.501686418949
3,False,-1.00116038862
5,False,-1.41009743909
7,False,-0.881206391154
8,False,-1.56859401587
9,False,-1.40380233705
10,False,-0.999999944356
11,False,-0.0198800864604


In [480]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_40_roc_input;
CREATE TABLE
    adult_test_parsed_pred_gaussian_40_roc_input
AS
    SELECT
        ob.class::FLOAT8 AS observed
        ,CASE pre.prediction
            WHEN True THEN '0.0'::FLOAT8
            WHEN False THEN '1.0'::FLOAT8
            END AS prediction
    FROM adult_test_parsed_pred_gaussian_40 AS pre
            LEFT OUTER JOIN adult_test_parsed AS ob
                ON pre.index = ob.index
;
SELECT * FROM adult_test_parsed_pred_gaussian_40_roc_input LIMIT 10;

observed,prediction
1.0,1.0
1.0,1.0
0.0,1.0
0.0,1.0
1.0,1.0
0.0,1.0
1.0,1.0
1.0,1.0
0.0,1.0
1.0,1.0


In [481]:
%%sql 

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_40_roc_output;
SELECT
    madlib.area_under_roc(
        'adult_test_parsed_pred_gaussian_40_roc_input'
        ,'adult_test_parsed_pred_gaussian_40_roc_output'
        ,'prediction'
        ,'observed'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_40_roc_output;

area_under_roc
0.6308386841105325


## n_components=1000

In [482]:
%%sql

DROP TABLE IF EXISTS
    adult_data_parsed_svm_gaussian_1000
    ,adult_data_parsed_svm_gaussian_1000_summary
    ,adult_data_parsed_svm_gaussian_1000_random
;
SELECT
    madlib.svm_classification(
        'adult_data_parsed'
        ,'adult_data_parsed_svm_gaussian_1000'
        ,'class = 0.0'
        ,'ARRAY[
            workclass
            ,education
            ,marital_status
            ,occupation
            ,relationship
            ,race
            ,sex
            ,native_country
        ]'
        ,'gaussian'
        ,'n_components=1000'
        ,''
        ,'init_stepsize=2,
            max_iter=200'
    )
;
SELECT * FROM adult_data_parsed_svm_gaussian_1000;

coef,loss,norm_of_gradient,num_iterations,num_rows_processed,num_rows_skipped,dep_var_mapping
"[-2.01824758710825, 0.142415330341722, -0.406952768146646, 0.323547279368753, 4.02757766368734, -0.736682098940432, -3.90502080708794, 2.09180936648094, -3.27643943025332, -1.78192192853058, -0.452869151390114, -5.79066527242473, -0.733466331531628, -5.68343719987323, -1.18998340218889, -0.88826142820505, -0.251904145544684, 3.88242628207631, -3.11886536657036, -1.90390904010437, 1.80267483350202, -3.69442774749305, -1.7395081067012, 0.991108483711886, -2.24051174401476, 2.09374165192408, -0.855821446232527, 0.565624422007473, -0.47204639432617, 0.673272521128769, -4.18526150826022, 4.50868541923335, 3.28084614597216, -1.83819039852294, 0.209500724447511, 5.91861796059941, 3.55913466083447, 7.81849911968738, 1.74133781387628, 10.1638054272517, -3.22369122918122, 2.37482750253116, -1.03779836354897, -3.66447170193058, 0.0350408516223365, 0.454528613577024, -1.19385053146103, 1.54694978640823, -1.24393227933923, -0.189506571607644, -0.356518047196438, 2.39130821865772, 7.86337290910466, 2.24370088600951, 0.857772317015588, -2.56506851657179, 0.413988496163076, 4.14905378313412, 1.37017732590804, -1.20695730132281, 1.57899457317195, 2.14692974288438, -2.42833341417327, 0.0430553944138281, 1.86211912116226, -0.00165507460752983, -1.21230431073205, 0.491551848744662, -0.24228190027641, 2.54654366294532, 1.63312376257851, 1.75189211712619, 0.348313469337073, 1.08010307456071, 1.4475806032459, 1.17555580372987, -5.24690391659617, 1.91843036934891, 2.72642738408904, -1.67502204100129, 1.03125566874432, -1.98654557542641, -1.14724673575079, 2.95963555880111, -3.1699074879827, 1.56512534015439, 0.280919254004124, 1.8138864319314, -0.799843848181833, -2.70183459617245, -2.6846761941514, 1.23619907454178, 4.83405871474735, -1.41270953582195, 0.163779400439891, 1.85126580362659, 1.64491558086739, -0.844527637738523, 1.93959737692736, 2.34865135591704, 0.145814495126528, -0.181799271294751, -1.54842093618882, 1.75290409031993, 1.08300599596827, 0.57155411177913, 0.399713835898495, -0.958369480412752, -1.24684291230027, 2.50776577918768, 6.05498231494876, -2.11599093772155, 0.260814542957255, 0.759979162442359, -7.18898952998077, 0.987106401364634, 0.498371515691701, -1.62388863377201, -1.62797444257821, -1.89017266630003, 0.819978163941632, 0.214647367818759, 1.91191696646278, 2.21717895881137, -1.21422616142453, 2.49551837719482, 0.647186660849757, 0.402355914533151, -1.31640062446456, -0.420828686024598, 0.0312071422413146, -0.447012810439768, 2.00283150929903, -0.711210981982425, -2.68528453998336, 2.17627634421483, -2.71857339812224, 5.34703351663374, -9.8147617098087, -4.76425179914937, -2.64893634963102, 2.02137004212944, -2.05950517597814, -1.61403641180907, 0.314959378199402, -1.27001755623414, 2.30945628940531, -0.282106657889014, 0.566783026711567, 2.91911067554183, -2.40459624791333, 0.58716948941795, -0.942129981400272, 2.42987958746258, 0.619646152537155, -0.405719120915164, 1.48610165557033, -0.461827462320926, 1.56500890135932, -1.76909198135687, -4.80317363328185, 0.524941154717851, -0.303775487819387, -2.43040103221232, 0.759644931200224, 0.829941848448204, 0.104320093244976, -2.53037580977074, 7.06568231220562, -4.07548770576049, -2.06353582162627, -1.63814180571116, 10.6384521520404, 0.854329429950901, -3.49722500315062, -1.35682982781711, 1.71879363405341, -1.73199583142812, 0.33882492701572, -1.54329725714477, -1.41149069417492, 0.857307173599383, 0.812726529482659, -5.60336383002867, -0.867887255580424, 1.6823572866346, -0.0649239896874719, 1.59075896785058, 0.417983035352976, -5.42700165820152, -0.608482751281484, -1.42562254429006, -2.94367515053271, -0.902577754366346, 0.321585649143886, -0.11001988050023, -1.30383490825469, 0.439353901291115, 0.980601987452019, 1.43484944868761, -0.0370424167254366, 1.55428858674486, 2.77958911896975, 0.127175718220849, 0.0781875801298774, 1.36049288702175, 1.241997993516, -1.37396500558922, 1.83205993866774, -1.46719466945584, 0.81721300634866, -3.38946476136628, -2.85092737888607, -1.53760292288551, 0.596420357577416, -1.07676058151849, -0.986358956595394, -2.80759021160902, 8.62796002607404, -0.90609953306694, 2.42862194119554, 2.41274002606631, 0.294841879078176, 0.179777931395681, -6.32785034891294, 0.0779296831303358, -0.628510321135943, 1.08937582589772, -2.14684458833304, 3.63496189103067, 7.36197601972843, 2.24001720882167, -3.19987555304572, 1.89106177425055, 0.526876009543636, -1.07318314933889, 4.95751347492776, 3.41106623070051, 3.21366407615174, 1.05470715592338, 0.0574346023062929, -4.59111890329718, -2.41628086850112, 1.30862145854652, -1.50713090444051, 2.1831187500181, 2.93280276310908, 0.991533892382647, 0.980188399712388, -2.78387028658602, -0.374861483363181, 2.10436468737492, 1.74619303272059, -2.74453087998132, -3.2326392797732, -0.690069374044381, 2.97716156253064, -4.62383277770297, 3.40399184100349, 2.30809534799524, 0.851669828838599, -0.380474626912707, 6.41822735995989, 0.662735364846171, -3.96321753256218, 7.35246572680609, 1.61878010661309, -0.212139099005121, 1.83975186995664, -0.717885978402878, 0.242922120806054, -0.168780243032213, 0.757385118432017, 0.955577593556555, 2.60845330590662, 0.81328909798238, 1.39679741972812, 1.13063555043757, 1.31823794426527, 1.64856084963928, 5.95311617707755, 1.94238065279933, 1.628509632388, 0.926818527492127, -1.06799862948255, -2.36360286867811, -1.71487252283197, -0.380312706407827, -4.47000214426787, 2.37555737843041, 1.37456776013176, 0.349438400834982, 0.624927544010469, 1.3482552493821, -8.2070942241132, 2.29405924319056, -0.998096772819238, -1.67830494235115, 0.54894241054814, 2.85557189495556, -0.408735041702337, 2.43831107320463, -0.39492518247199, 0.766679491977223, 2.04746471483282, -0.721500221661909, 3.35827124114523, 0.374203733810742, -3.79784780676543, 1.09032444070965, 6.80952216189456, -2.34897225980757, 1.87610305681386, 1.73018860298458, 5.06095888807717, 0.894174419581468, 0.930989974134458, 1.42238213188739, -0.204049093984008, 2.6367186569255, -1.81850923473634, 0.984088570690374, -1.27058044818047, 2.91306390412121, 0.888022337281047, 2.27079103036391, 3.05973330230213, 1.34770866091384, 0.251486897267717, 3.67045433017008, 0.210125995358086, -3.36462511571752, -1.75813800327195, -0.551753590653634, -4.19933456690116, 1.33782389528109, -1.60639258121759, -0.381978043897073, 0.472336481674122, -2.83914124004781, -2.95824331656919, 6.69753440633239, 1.75599005983504, -0.802724459103696, -4.42426141003174, 2.93603023674471, -2.58365680363243, 1.66536377114323, -2.57837538565523, -3.5356900547746, -0.270732968103115, 1.42805803501485, 1.15282036455697, -0.410485490105691, 0.516491361456642, -0.399014026810564, 1.7307206514003, 2.24356543660392, 3.27964731720119, -0.998568941541142, 0.231260274885132, -0.650414349657995, 1.1757993241609, -1.97764406299704, -6.08485915812574, 0.285868770091467, 3.22200395297331, -1.27081103478941, 0.309974066102388, -0.816847158288933, 7.09048985105319, -0.347776745363437, -0.717410706314042, 2.39529049037691, 1.95828983751614, 1.31946959107038, 0.270744910275313, -6.64126177806145, -4.91721436465928, -2.00199977186345, 2.57767762672991, 2.40413810264162, -8.55393724720554, 1.79550584228167, 0.928131230940308, -3.79342074830297, 0.542071648436192, 1.82748116637129, -5.66056944953469, 1.22802673833125, -0.423118413171415, -4.37445296943835, 0.547183232113799, 0.990564193315451, 0.617860889198911, 0.48164785270983, 1.81552554633034, -0.381052215469875, 0.718647941571689, 2.35350278313863, -1.29972669887515, 1.92479895468402, 6.58836895666794, 0.198601211156792, -1.92515935361195, -0.709616311885843, -4.16379182504992, 1.63463247852211, 0.968146366289041, 2.51433518611445, -2.29029836429131, -5.33991570734851, 0.288163574475366, -2.23164743552353, 2.69294491548063, -5.6605076092617, -5.33536081620689, -3.240786835776, -3.2182130555368, 1.18024203194265, -3.79074057405403, -0.719529549754878, -3.33079677183237, -2.80311942206107, -0.6900827264415, 1.39354429275414, 9.23348691193084, 1.2707980598789, 3.24741865601111, 0.548813247839667, -1.12798844017083, 2.2011273208097, -4.14234400484444, 3.3992188024009, 2.03138619720895, -0.466391367515463, -4.66139366903303, -3.77297053501255, -1.73591002598755, -1.87354025437309, -0.195326729941067, -0.995620713149734, 2.11246680483591, -2.74163998576275, 1.15888659813539, 0.111704678747003, 3.07156438749981, 0.254510547873408, -1.92322430715187, -0.860601707314247, -7.84924082048632, -1.98528235691064, -0.912011727866459, -0.971628580133712, 3.81105256582439, -8.07747460607845, -0.743641930678321, 2.24411458228977, -1.4143074945712, 1.53006367805961, 1.43402530485993, -0.417296162776374, -2.42598830711082, -0.94558163737544, 1.09018801842302, -1.33496782153588, 2.50101251860882, 5.77255858724619, -2.02436427239368, -0.874502570122143, -1.02158088227658, -0.456099914639248, 4.3506735419001, 3.1798993009928, 0.162724552735261, -3.58285151084936, -0.552513193872595, 17.1289987303823, -2.64580896977823, 2.15917324442827, -1.0275382842954, -1.94621270617105, 0.742809158547212, -1.98612020514191, 3.21431272511935, -3.15066694442771, -0.0391909545850595, 1.81266327678238, 0.518062935012981, -0.326762841518968, 0.937595162158657, -1.592410680205, -2.51339846124423, -0.902477215116669, 0.958403924081741, 1.01675503285035, -2.70162688667608, -2.42043746765158, -3.88559584256774, 0.824480614309686, -0.419838745145085, 1.40300304862858, 1.25367952001717, -1.95972500614612, -0.636802458817755, -1.79001073581158, -5.70097777149932, 1.65762914773518, 3.51958762203445, -3.25869960627186, -0.275089714262975, 1.12580670900319, -1.44307346524956, -1.26371151600462, 0.204490192871032, 1.02401563434433, -3.02856456148297, 0.425133648830485, 0.982149245692121, -0.973690695573789, -1.47040751051563, -3.50285386141624, -2.0893143063621, -2.6355342424028, 0.278714642888715, 4.56311087554184, -0.569350217562004, -0.729602784898523, 3.71948989389452, -0.724958505439015, -3.45970366448802, -1.92024672310163, -1.7174510215634, -4.16514675647426, 0.437335841851158, -3.68847517832551, 1.80129090699342, -3.86683658673551, 3.26137569134298, -0.19794780373231, 4.82513757083991, 0.281445494497228, -4.09100711950687, 0.985503226481729, -1.28736414988975, 2.18261819883172, 1.39336249374498, 0.871627482501457, -1.06301276788759, 0.853711257650794, -0.114331086639339, 0.175384749184649, -0.817237943030045, -0.676396003268472, -2.80072703144623, 0.0899269690000173, 1.7714603399254, -0.802625287327766, 0.203991142144639, -0.226979032324133, 1.0483859781461, 2.72586249339241, 2.2629340276588, 5.05080311608595, 1.27550700329154, 0.239879589794741, 4.09128999090813, -0.762454497808261, 5.42207307175356, -5.00049420705978, -1.83856604212971, 2.89528121698059, 0.202470028130204, -2.91798425430511, 2.74924412318307, 0.379399405198308, -0.322409559124229, 1.93067823029175, -0.28128142060197, -0.0195688782626132, 0.956918129506866, 0.977371568149236, 0.432510102254331, 1.26045544218567, 4.02413361539169, 4.12099921119353, -2.10854696067953, -2.45789243225281, 1.24432287135059, -1.97220908600972, -0.790525940627018, 2.640506678871, -2.05521914078308, 1.41820237874983, 1.3762018701991, 6.32035448683575, -0.240622871709259, -1.22987449926001, 0.780317375552188, 1.18660912598983, 0.363432621696084, -0.537353082512488, 1.89208929360724, -1.18440145641904, -0.608927703189967, -0.885225589892745, 1.69828128358948, 6.02615024384923, -0.0810671822304382, -0.205015319318975, -4.12838771861146, 0.525088565749195, 0.622937189224143, -1.84434663478604, -0.469800515974432, 0.675173297653204, 0.0454068469227292, 2.46728813900492, 1.00473748441637, -2.66368220214627, -0.943041091624968, -0.913050799915377, -1.14149463693269, -1.75849916643794, -2.03722909948611, -0.566211726086684, -1.6303350102775, -0.406733652035376, 1.93456981707628, 1.63865053790222, -3.70372966870405, 1.24235723286685, -3.46558111223201, 1.59923069440176, -0.961202508685938, 1.01388113170396, 0.516875409687929, 3.66272996668639, 0.391292681479203, 1.77930512849479, 0.210194995352734, -0.214105715355006, 3.30075663093631, 1.15375825720297, -0.183190413308839, 1.14668423261979, 1.37288395719693, -0.846659477859848, 1.0068464565697, -0.00296613586290941, 1.16764927993089, -0.820419373848885, 3.09155565638894, 0.450687083059663, -4.00957294081476, 0.078024164220119, -0.32183017502586, -0.00362337098307917, 0.714861382471852, -1.40347211648658, -1.27379216006984, 2.34625674211512, 1.61814058570015, 2.62856782749745, 2.15112303726139, -2.00322707615552, -1.01390532345447, 12.7915177204758, -8.08615712144256, 2.2304528816945, 1.32036754342691, 2.2702494863983, -3.72539051119995, 1.54986092819454, 0.724982398164898, -1.09121418499668, 0.0945548835855781, -4.48734469566858, -0.238957974905104, -0.612841860247726, 2.15303357142573, -0.798648572201587, -2.57408750136018, 1.10809952806822, -0.655239490002755, 0.602533573256594, 0.349263593473057, -2.09490440352192, -0.148557642396387, -0.21217366596356, -0.0174673581355355, 3.24297779853635, 1.31077495052936, 0.00700680175224654, -0.956168336863797, 3.00977359917886, 5.8488951257441, 6.01130518065335, -1.24529405083245, 1.21376489947943, -1.75533669907323, -2.48941775800101, -1.39746870340337, -0.585788733420767, 2.79186019581625, 6.95191887366949, -2.49667647767497, 0.85689242927665, 6.05505934380818, -1.84353644423164, -3.05739172616633, -1.54084232993153, 1.80563146646424, -0.431780755621594, 4.21547178602458, 1.91306205211088, 0.614214117136056, 4.83233504342067, 0.419761937347761, -1.97477252490264, 0.00995964657075922, -1.1611140407772, -0.921496965038808, 2.01011879387675, -0.492688175545836, 0.369383843475892, -2.36715992679585, 2.41163706282417, 5.37613780294965, 2.57019156813478, 1.63642794838333, -1.40761389283495, 0.0486890516488719, 1.53173097313344, 1.22514423527104, -7.87774062459924, 4.69075522019316, -1.92575040738476, -0.447821620488205, -1.96228587945921, 2.33303909350511, -0.850776818052441, -0.972671048859905, -0.625492773879719, -0.169022632882361, -2.56960522095849, -0.627246085080542, -1.37568415937089, 2.49693528461527, 0.853030133094958, 1.56410618996722, 2.70995945650986, 2.90192608200292, 6.70747853736404, -3.0323498631161, 2.79134531200972, 1.11789401658906, 0.337097433226428, -4.25575709395813, -0.916516112985071, 0.274295336020196, 0.462092212092527, 0.611333622236719, 2.22802945228043, 1.86212782297195, -3.04058535835153, 0.501120360559417, -0.252592358144802, 2.86591802328413, 1.55184926665455, 5.1192433523522, 1.91331260360041, 0.012277211591537, 2.8544971185633, 1.91337032148145, 0.117116502103378, 0.77041071207476, 0.90625247883272, -2.75432525465143, -2.22981669801874, 1.61357798033004, -0.92825890519647, 1.30940635608328, -1.52328145333143, 4.50399059862043, 0.364723716895453, -0.147484762156478, -1.92487978940188, 2.47652934854498, 1.86105000750261, -4.57263062516964, -0.419421311937118, -0.391761729996905, 0.61212827856411, 0.0281387967300959, 0.773151999282133, -1.6880187309389, 1.03345654026867, -1.35620986322534, 1.90006277750095, -1.72386990226026, 1.68899000563477, -0.114240734836126, -1.36993036271712, 7.19058700534276, -0.590004538234607, -3.32240089630418, 2.6394893808645, -1.92303716126017, 1.15794632959984, 1.54628112135738, 3.49509996150198, -5.33161321146804, 2.09224179830962, -0.254251634820641, 3.38667254848687, 1.70534215376173, -0.621684855266498, -0.909791979351218, -5.30218888692287, -6.74035913089964, -0.925521611799997, 0.222915995313889, 2.65992731395045, 1.23119491830485, -1.68579497181653, -0.247226606482449, -0.258433998492756, 0.288169851972106, 0.772174653248034, 2.17473848822524, 1.50995544452039, -0.704893170276075, 0.285749779061449, 1.1971561212752, -1.217459097368, 0.499370901286226, 5.59553385426875, -2.61713720272148, -1.84265973189907, -0.24392143654274, 2.87321477201843, 1.34886751357067, -1.22453941440383, -0.292048566625244, 1.52103728778329, -9.06749647328285, 0.429361249490179, -3.43787448387318, 0.940309800200014, -0.833002796828883, 0.693425569061337, -0.689897724246691, -2.40828323034832, 1.59177150823457, -1.03813363880376, -1.15333875510192, -22.3135926217558, -0.598037182875771, 0.346051096119365, -5.42281386330272, -2.94818860421179, 1.56080203834123, -0.0699015299885074, 0.61606852798695, -0.0157416654424327, -0.48119781109517, 5.47522714547851, 1.06847181538689, 2.11230026090632, -0.888853186320215, 4.87427934683587, 1.33773204480111, 0.950757881352634, -2.56612612026103, 3.17058828227055, 1.16415203664664, -1.17863981697432, 0.83934468798233, -4.03334348588841, 4.30387884966592, -1.5812134841177, -1.52027338830363, 1.7405701962862, -2.79910052143091, 3.03032260917178, 6.1244170358685, 5.30735381860436, -1.32058649042565, -5.40108299091297, -1.29306598195044, -0.0832083551207922, -1.7015502643233, -0.221726450468085, -0.709035852223573, 0.543175216251043, -3.01860758919566, 0.0679974084687863, -1.38653632334608, 1.33725766394818, 0.0567243350223979, 0.63882224298841, -0.454392298408734, -0.192219517561247, 1.85449593695456, -1.07420113869111, -0.621295411121019, -2.49472873967776, -2.84681428994615, 0.562047559622932, 3.67283839160803, -0.258229067260464, -0.57811818430399, -1.81936862269067, -0.100821331113075, 2.65879489015578, 2.49406721894068, -0.435096332939509, -0.0837219949006843, 1.68095211399727, 0.485175616535479, -2.27681042076767, -1.77128171962889, -2.86301717349838, -2.74987192182243, -1.92368122641289, 0.973583702267197, 0.0864427910532248, -2.23313189771042, 0.62836082726975, -1.99486300418142, 1.86391458074731, 0.657819934283818, -1.29882053995899, -1.59995574427656, 3.33526640310776, -0.0506631567633932, -11.2791028249808, 0.42573686705446, 2.01464964330714, -2.69315343977122, 0.347701288877952, -6.43739951779589, -1.98152526042192, -0.282811832473117, 3.35429703498427, 1.58664113425862, -1.17058348808448, -1.78805863813681, -2.45044477356293, -0.974234998119589, 3.40783597411539, -0.828747179172742, -1.58726440961032, 2.70826853848943, -0.369716718428016, 3.16793808553943, 0.55921506054703, 2.3126750627441, 2.01906190348261, -0.620681594078998, 2.16526877826374, -1.74027256861319, -0.792115031912475, -0.385630433645413, 2.27888416210375, 1.05333794631991, -1.78092401433396, 3.43985783197436, -1.06842140953276, 3.06791736608014, 4.07117541924771, -0.919215402876125, -0.595558917199812, 1.20320056360266, 0.777244765617223, -7.59698462900782, 0.94088959140672, -0.191004796236771, -8.02425415523561, -0.372499199601035, -2.12108359837473, -1.59069121259246, -1.88531774069186, 4.24759945761543, 3.12045663666777, -1.42753689861479, 0.567403515289283, -0.589566095169243, 0.363951068852293, 1.65896337975769, 0.533537284369876, 1.47225699930881, 0.509396266590304, 0.0151258732408307, 2.14446290932801, 0.54967973668871, -0.969515498995302, 0.242784855513648, 2.5650379476082]",11107.0638326,378.584010449,193,29900,2661,"[False, True]"


In [483]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_1000;
SELECT
    madlib.svm_predict(
        'adult_data_parsed_svm_gaussian_1000'
        ,'adult_test_parsed'
        ,'index'
        ,'adult_test_parsed_pred_gaussian_1000'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_1000 LIMIT 10;

index,prediction,decision_function
0,False,-2.75168489053
1,False,-1.01947139686
2,True,1.00000001042
3,False,-1.03874735588
5,False,-2.38269391935
7,True,1.00000001097
8,False,-1.29193105377
9,False,-1.49512161449
10,False,-1.01893068037
11,True,1.09882108507


In [484]:
%%sql

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_1000_roc_input;
CREATE TABLE
    adult_test_parsed_pred_gaussian_1000_roc_input
AS
    SELECT
        ob.class::FLOAT8 AS observed
        ,CASE pre.prediction
            WHEN True THEN '0.0'::FLOAT8
            WHEN False THEN '1.0'::FLOAT8
            END AS prediction
    FROM adult_test_parsed_pred_gaussian_1000 AS pre
            LEFT OUTER JOIN adult_test_parsed AS ob
                ON pre.index = ob.index
;
SELECT * FROM adult_test_parsed_pred_gaussian_1000_roc_input LIMIT 10;

observed,prediction
1.0,1.0
1.0,1.0
0.0,0.0
0.0,1.0
1.0,1.0
0.0,0.0
1.0,1.0
1.0,1.0
0.0,1.0
1.0,0.0


In [485]:
%%sql 

DROP TABLE IF EXISTS adult_test_parsed_pred_gaussian_1000_roc_output;
SELECT
    madlib.area_under_roc(
        'adult_test_parsed_pred_gaussian_1000_roc_input'
        ,'adult_test_parsed_pred_gaussian_1000_roc_output'
        ,'prediction'
        ,'observed'
    )
;
SELECT * FROM adult_test_parsed_pred_gaussian_1000_roc_output;

area_under_roc
0.712228973997911


## 結果をまとめてみる

In [492]:
%%sql
SELECT 
    aur.area_under_roc AS aur
    ,aur10.area_under_roc AS aur10
    ,aur20.area_under_roc AS aur20
    ,aur30.area_under_roc AS aur30
    ,aur40.area_under_roc AS aur40
    ,aur1000.area_under_roc AS aur1000
FROM
    adult_test_parsed_pred_roc_output AS aur
    ,adult_test_parsed_pred_gaussian_10_roc_output AS aur10
    ,adult_test_parsed_pred_gaussian_20_roc_output AS aur20
    ,adult_test_parsed_pred_gaussian_30_roc_output AS aur30
    ,adult_test_parsed_pred_gaussian_40_roc_output AS aur40
    ,adult_test_parsed_pred_gaussian_1000_roc_output AS aur1000
;

aur,aur10,aur20,aur30,aur40,aur1000
0.7573908867338783,0.5323055370802969,0.5274727718989202,0.6180593427318534,0.6308386841105325,0.712228973997911


更なるチューニングの余地あり…