In [1]:
%matplotlib inline

import numpy as np

import matplotlib.pyplot as plt

import pandas.io.sql as psql
import psycopg2 as pg

import pandas as pd

from keras.models import Sequential
from keras.layers import Dense
from keras.engine.saving import model_from_json
from keras.layers import Dropout, BatchNormalization, regularizers
from keras.optimizers import Adam, RMSprop

from sklearn import model_selection

from imblearn.over_sampling import SMOTE


Using TensorFlow backend.


In [2]:
with pg.connect(database='chi-navi-mesh',
                host='localhost',
                user='postgres',
                port=5432) as conn:
    all_season_sql = \
        "with bus_stop_flowting_population as( " \
        "    with bus_stop_mesh as( " \
        "       select " \
        "           m.keycode, bs.bus_stop_id, m.latlng_path, point(bs.latlon) as bus_stop_latlon " \
        "       from " \
        "           bus_stop bs, population p, mesh m " \
        "       where " \
        "           m.keycode = p.keycode " \
        "           and box(polygon(m.latlng_path)) @> point(bs.latlon) " \
        "   ), " \
        "   flowting_population as( " \
        "        select " \
        "           m.latlng_path as latlng, " \
        "           fp.keycode, " \
        "           fp.param ->> 'year' as year, " \
        "           fp.param ->> 'month' as month, " \
        "           fp.param ->> 'hour' as hour," \
        "           fp.param ->> 'day_type' as day_type, " \
        "           fp.param ->> 'cnt_population' as cntPopulation " \
        "       from " \
        "           mesh m ,flowing_population fp " \
        "       where " \
        "           m.keycode = fp.keycode " \
        "   ) " \
        "   select " \
        "       b.bus_stop_id, " \
        "       f.cntpopulation, f.year, f.month, f.hour, f.day_type " \
        "   from " \
        "       flowting_population f, bus_stop_mesh b " \
        "   where " \
        "       b.keycode = f.keycode " \
        "), " \
        "spot_bus_stop as( " \
        "   select " \
        "       tmp.bus_stop_id, " \
        "       max(" \
        "           case " \
        "               when spot_id = 3 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as chitose_station, " \
        "       max(" \
        "           case " \
        "               when spot_id = 4 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as minami_chitose_station, " \
        "       max(" \
        "           case " \
        "               when spot_id = 5 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as osatu_station, " \
        "       max(" \
        "           case " \
        "               when spot_id = 6 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as chitose_airport, " \
        "       max(" \
        "           case " \
        "               when spot_id = 7 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as aeon, " \
        "       max(" \
        "           case " \
        "               when spot_id = 8 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as rera, " \
        "       max(" \
        "           case " \
        "               when spot_id = 9 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as chitose_mall, " \
        "       max(" \
        "           case " \
        "               when spot_id = 10 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as big_house, " \
        "       max(" \
        "           case " \
        "               when spot_id = 11 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as hokuren, " \
        "       max(" \
        "           case " \
        "               when spot_id = 12 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as morimoto, " \
        "       max(" \
        "           case " \
        "               when spot_id = 13 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as city_hospital, " \
        "       max(" \
        "           case " \
        "               when spot_id = 14 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as houyuukai_hospital, " \
        "       max(" \
        "           case " \
        "               when spot_id = 15 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as chitose_hospital, " \
        "       max(" \
        "           case " \
        "               when spot_id = 16 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as hokusei_hospital, " \
        "       max(" \
        "           case " \
        "               when spot_id = 17 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as city_library, " \
        "       max(" \
        "           case " \
        "               when spot_id = 18 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as city_office, " \
        "       max(" \
        "           case " \
        "               when spot_id = 19 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as kouyoudai_city_office, " \
        "       max(" \
        "           case " \
        "               when spot_id = 20 then 1 " \
        "               else 0 " \
        "           end" \
        "       ) as city_service_center, " \
        "       max(" \
        "           case " \
        "               when spot_id = 21 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as sports_center, " \
        "       max(" \
        "           case " \
        "               when spot_id = 22 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as sougou_budoukan, " \
        "       max(" \
        "           case " \
        "               when spot_id = 23 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as onsui_pool, " \
        "       max(" \
        "           case " \
        "               when spot_id = 24 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as mati_library, " \
        "       max(" \
        "           case " \
        "               when spot_id = 25 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as salmon_park, " \
        "       max(" \
        "           case " \
        "               when spot_id = 26 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as sikotu_lake, " \
        "       max(" \
        "           case " \
        "               when spot_id = 27 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as hokuyou_highschool, " \
        "       max(" \
        "           case " \
        "               when spot_id = 28 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as chitose_highschool, " \
        "       max(" \
        "           case " \
        "               when spot_id = 29 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as aoba_junior_high_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 30 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as chitose_junior_high_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 31 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as hokushin_junior_high_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 32 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as tomioka_junior_high_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 33 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as hokuto_junior_high_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 34 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as yuumai_junior_high_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 35 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as kouyoudai_junior_high_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 36 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as chitose_drivers_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 37 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as chitose_rehabilitation_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 38 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as airplane_vacational_school, " \
        "       max(" \
        "           case " \
        "               when spot_id = 39 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as north_chitose_military_post, " \
        "       max(" \
        "           case " \
        "               when spot_id = 40 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as airself_defence_force_fort, " \
        "       max(" \
        "           case " \
        "               when spot_id = 41 " \
        "               then 1 else 0 " \
        "           end" \
        "       ) as east_chitose_military_post " \
        "   from " \
        "       (select " \
        "           bus_stop_id,s.spot_id,s.area,p.latlon, " \
        "           row_number() over (partition by bus_stop_id) as seq " \
        "       from " \
        "           (select spot_id, s.area from spot s order by spot_id) as s " \
        "       left outer join " \
        "           (select bus_stop_id,latlon::point as latlon from bus_stop) as p " \
        "       on " \
        "           ST_Covers( " \
        "               ST_Buffer(ST_POINT(s.area[0], s.area[1])::geography,400), " \
        "               ST_POINT(p.latlon[1], p.latlon[0]) " \
        "           ) " \
        "       order by p.bus_stop_id, s.spot_id " \
        "       ) tmp " \
        "   group by tmp.bus_stop_id " \
        "), " \
        "delay_info as( " \
        "   with route_bus_stop_list as( " \
        "       select " \
        "           t.task_id,r.route_id,b.bus_stop_id,rbs.ordinal " \
        "       from " \
        "           bus_stop b,route r,route_bus_stop rbs,task t " \
        "       where " \
        "           b.bus_stop_id = rbs.bus_stop_id " \
        "           and r.route_id = rbs.route_id " \
        "           and t.route_id = r.route_id " \
        "       order by r.route_id,rbs.ordinal " \
        "   ), " \
        "   delay_list as( " \
        "       with stddev as( " \
        "           select " \
        "               stddev_samp(extract(epoch from (collide_at - departure_time))) as stddev " \
        "           from " \
        "               task_delay " \
        "           where " \
        "               (collide_at - departure_time) > '00:00:00' " \
        "       ), " \
        "       delay_avg as( " \
        "           select " \
        "               avg((extract(epoch from (collide_at - departure_time)))) as delay_avg " \
        "           from " \
        "               task_delay " \
        "           where " \
        "               (collide_at - departure_time) > '00:00:00' " \
        "       ) " \
        "       select " \
        "           td.task_id,td.departure_time,td.bus_stop_id,td.submit_in,td.collide_at, " \
        "           extract(epoch from (td.collide_at - td2.departure_time)) as delay " \
        "       from " \
        "           task_delay td,delay_avg da,stddev std,task_detail td2 " \
        "       where " \
        "           (extract(epoch from (td.collide_at - td2.departure_time)) between " \
        "           da.delay_avg - std.stddev * 3 and da.delay_avg + std.stddev * 3) " \
        "           and extract(epoch from (td.collide_at - td2.departure_time)) > 0 " \
        "           and td.task_id = td2.task_id " \
        "           and td.bus_stop_id = td2.bus_stop_id " \
        "           and td.departure_time = td2.departure_time " \
        "   ) " \
        "   select " \
        "       distinct r.route_id,r.bus_stop_id,r.ordinal,d.collide_at, " \
        "       d.task_id,d.delay,d.submit_in,d.departure_time, " \
        "       date_part('HOUR' , d.departure_time) as hour, date_part('minutes',d.departure_time) as minutes," \
        "       case " \
        "           when " \
        "               extract(dow from submit_in) = 0 or " \
        "               extract(dow from submit_in) = 6 then 0 " \
        "           else 1 " \
        "       end as day_type " \
        "   from " \
        "       delay_list d, route_bus_stop_list r " \
        "   where " \
        "       d.bus_stop_id = r.bus_stop_id " \
        "       and r.task_id = d.task_id " \
        "), " \
        "access_info as( " \
        "   select " \
        "       count(distinct session_id) as access_num,browsing_at::date as date, " \
        "       date_part('HOUR', browsing_at) as hour, " \
        "       case " \
        "           when extract(dow from browsing_at::date) = 0 " \
        "           or extract(dow from browsing_at::date) = 6 then 0 " \
        "           else 1 " \
        "       end as youbi " \
        "   from " \
        "       user_browsing_log ubl " \
        "   group by date, hour " \
        ") " \
        "select " \
        "case " \
        "when delay between 0 and 60 then 0 " \
        "when delay between 61 and 120 then 1 " \
        "when delay between 121 and 180 then 2 " \
        "when delay between 181 and 240 then 3 " \
        "when delay between 241 and 300 then 4 " \
        "when delay between 301 and 360 then 5 " \
        "when delay between 361 and 420 then 6 " \
        "when delay between 421 and 480 then 7 " \
        "when delay between 481 and 540 then 8 " \
        "when delay > 540 then 9 " \
        "end as delay, " \
        "   ai.hour as time, " \
        "   di.minutes, " \
        "   ai.access_num,ai.youbi, " \
        "   case " \
        "       when submit_in = '2017-07-23' then 1 /* 航空祭 */ " \
        "       when (submit_in between '2018-01-26' and '2018-02-18') or " \
        "            (submit_in between '2017-01-27' and '2017-02-19') " \
        "       then 2/* ひょうとう祭り */ " \
        "       when (submit_in between '2017-07-06' and '2017-07-09') or " \
        "            (submit_in between '2018-07-05' and '2018-07-08') " \
        "       then 3/* セガサミーカップ */ " \
        "       else 0 " \
        "   end as event, " \
        "   case " \
        "       when di.route_id in (13,14,23,24,55,56,58,59,60,61,62,63,64,67,68,69,70,71,74, " \
        "                            78,79,80,81,82,91,92,93,94,95,96,97,98,99,106,107,108,109,11) " \
        "       then 1 " \
        "       else 0 " \
        "   end as destination, " \
        "   case " \
        "       when date_part('month',submit_in) in (3,4,5,6,9,10,11) then 1 " \
        "       when date_part('month', submit_in) in (7,8) then 0 " \
        "       when date_part('month', submit_in) in (12,1,2) then 2 " \
        "   end as season, " \
        "   s.*, " \
        "   b.cntpopulation, " \
        "   case " \
        "       when route_id = 11 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_11, " \
        "   case " \
        "       when route_id = 12 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_12, " \
        "   case " \
        "       when route_id = 13 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_13," \
        "   case " \
        "       when route_id = 14 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_14, " \
        "   case " \
        "       when route_id = 23 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_23, " \
        "   case " \
        "       when route_id = 24 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_24, " \
        "   case " \
        "       when route_id = 55 " \
        "       then 1 else 0 " \
        "   end as poor_55, " \
        "   case " \
        "       when route_id = 56 " \
        "       then 1 else 0 " \
        "   end as poor_56, " \
        "   case " \
        "       when route_id = 57 " \
        "       then 1 else 0 " \
        "   end as poor_57, " \
        "   case " \
        "       when route_id = 58 " \
        "       then 1 else 0 " \
        "   end as tosyokan_58, " \
        "   case " \
        "       when route_id = 59 " \
        "       then 1 else 0 " \
        "   end as yuumai_59, " \
        "   case " \
        "       when route_id = 60 " \
        "       then 1 else 0 " \
        "   end as yuumai_60, " \
        "   case " \
        "       when route_id = 61 " \
        "       then 1 else 0 " \
        "   end as sakuragi_61, " \
        "   case " \
        "       when route_id = 62 " \
        "       then 1 else 0 " \
        "   end as sakuragi_62, " \
        "   case " \
        "       when route_id = 63 " \
        "       then 1 else 0 " \
        "   end as sakuragi_63, " \
        "   case " \
        "       when route_id = 63 " \
        "       then 1 else 0 " \
        "   end as sakuragi_63, " \
        "   case " \
        "       when route_id = 64 " \
        "       then 1 else 0 " \
        "   end as sakuragi_64, " \
        "   case " \
        "       when route_id = 67 " \
        "       then 1 else 0 " \
        "   end as inaho_67, " \
        "   case " \
        "       when route_id = 68 " \
        "       then 1 else 0 " \
        "   end as inaho_68, " \
        "   case " \
        "       when route_id = 67 " \
        "       then 1 else 0 " \
        "   end as inaho_67, " \
        "   case " \
        "       when route_id = 68 " \
        "       then 1 else 0 " \
        "   end as inaho_68, " \
        "   case " \
        "       when route_id = 69 " \
        "       then 1 else 0 " \
        "   end as sikotuko_69, " \
        "   case " \
        "       when route_id = 70 " \
        "       then 1 else 0 " \
        "   end as kuukousinai_70, " \
        "   case " \
        "       when route_id = 71 " \
        "       then 1 else 0 " \
        "   end as yamato_71, " \
        "   case " \
        "       when route_id = 72 " \
        "       then 1 else 0 " \
        "   end as higasibutai_72, " \
        "   case " \
        "       when route_id = 74 " \
        "       then 1 else 0 " \
        "   end as midoridai_74, " \
        "   case " \
        "       when route_id = 80 " \
        "       then 1 else 0 " \
        "   end as kuukousinai_80, " \
        "   case " \
        "       when route_id = 81 " \
        "       then 1 else 0 " \
        "   end as sikotuko_81, " \
        "   case " \
        "       when route_id = 82 " \
        "       then 1 else 0 " \
        "   end as yamato_82, " \
        "   case " \
        "       when route_id = 83 " \
        "       then 1 else 0 " \
        "   end as higashi_83, " \
        "   case " \
        "       when route_id = 84 " \
        "       then 1 else 0 " \
        "   end as higashi_84, " \
        "   case " \
        "       when route_id = 85 " \
        "       then 1 else 0 " \
        "   end as higashi_85, " \
        "   case " \
        "       when route_id = 91 " \
        "       then 1 else 0 " \
        "   end as midoridai_91, " \
        "   case " \
        "       when route_id = 92 " \
        "       then 1 else 0 " \
        "   end as sakuragi_92, " \
        "   case " \
        "       when route_id = 93 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_93, " \
        "   case " \
        "       when route_id = 94 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_94, " \
        "   case " \
        "       when route_id = 95 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_95, " \
        "   case " \
        "       when route_id = 96 " \
        "       then 1 else 0 " \
        "   end as kouyoudai_96, " \
        "   case " \
        "       when route_id = 97 " \
        "       then 1 else 0 " \
        "   end as midaridai_97, " \
        "   case " \
        "       when route_id = 98 " \
        "       then 1 else 0 " \
        "   end as midaridai_98, " \
        "   case " \
        "       when route_id = 99 " \
        "       then 1 else 0 " \
        "   end as sakuragi_99, " \
        "   case " \
        "       when route_id = 100 " \
        "       then 1 else 0 " \
        "   end as sakuragi_100, " \
        "   case " \
        "       when route_id = 106 " \
        "       then 1 else 0 " \
        "   end as higashibutai_106, " \
        "   case " \
        "       when route_id = 107 " \
        "       then 1 else 0 " \
        "   end as higashibutai_107, " \
        "   case " \
        "       when route_id = 108 " \
        "       then 1 else 0 " \
        "   end as higashibutai_108, " \
        "   case " \
        "       when route_id = 109 " \
        "       then 1 else 0 " \
        "   end as higashibutai_109, " \
        "   case " \
        "       when route_id = 110 " \
        "       then 1 else 0 " \
        "   end as yuumai_110, " \
        "   di.route_id " \
        "from " \
        "   spot_bus_stop s, " \
        "   delay_info di, " \
        "   access_info ai," \
        "   bus_stop_flowting_population b " \
        "where " \
        "   di.hour = ai.hour " \
        "   and di.submit_in = ai.date " \
        "   and s.bus_stop_id = di.bus_stop_id " \
        "   and b.hour = di.hour::text " \
        "   and date_part('MONTH', di.submit_in)::text = b.month " \
        "   and date_part('YEAR', di.submit_in) = 2018 " \
        "   and di.day_type::text = b.day_type " \
        "   and b.bus_stop_id = di.bus_stop_id " \
        "order by s.bus_stop_id "


In [3]:
# sql_result = psql.read_sql(all_season_sql, conn)
sql_result = pd.read_csv("./softmax_routeDummy.csv")

In [4]:
max_len = sql_result.size

percent_list = [sql_result[sql_result['delay'] == 0].size / max_len * 100,
                sql_result[sql_result['delay'] == 1].size / max_len * 100,
                sql_result[sql_result['delay'] == 2].size / max_len * 100,
                sql_result[sql_result['delay'] == 3].size / max_len * 100,
                sql_result[sql_result['delay'] == 4].size / max_len * 100,
                sql_result[sql_result['delay'] == 5].size / max_len * 100,
                sql_result[sql_result['delay'] == 6].size / max_len * 100,
                sql_result[sql_result['delay'] == 7].size / max_len * 100,
                sql_result[sql_result['delay'] == 8].size / max_len * 100,
                sql_result[sql_result['delay'] == 9].size / max_len * 100]

In [5]:
def delay_threshold(one_line):
    delay_label = 0
    for i, percent in enumerate(percent_list):
        if one_line['delay'] == i:
            if percent >= 8:
                return delay_label
    delay_label = 1
    return delay_label


In [6]:
def data_split(x, y):
    x_train, x_test, y_train, y_test = model_selection.train_test_split(
        x, y, train_size=int(len(x) * 0.8))

    return x_train, x_test, y_train, y_test

In [7]:
def smote(x_train, y_train):
    print('サイズ調整前の学習データのlabel_zero length:{}'.format((len(y_train[y_train == 0]))))
    print('サイズ調整前の学習データのlabel_one length:{}'.format((len(y_train[y_train == 1]))))
    
    sm = SMOTE(ratio='auto', k_neighbors=5, random_state=7)
    balance_data, balance_target = \
        sm.fit_sample(x_train, y_train)
    
    print('サイズ調整後の学習データのlabel_zero length:{}'.format(len(balance_target[balance_target == 0])))
    print('サイズ調整後の学習データlabel_one length:{}'.format(len(balance_target[balance_target == 1])))

    return balance_data, balance_target

In [8]:
def convert_spot_to_bit(one_line):
    bit_column = one_line[9:46]
    spot_bit = 0
    
    for i, bit in enumerate(bit_column):
        if bit == 1:
            spot_bit += 2 ** i
        
    return spot_bit

In [9]:
def bus_stop_data_selection(bus_stop_id):
    filtered = sql_result[sql_result['bus_stop_id'] == bus_stop_id].copy()
    filtered = filtered[(filtered['delay'] == 0) |
                        (filtered['delay'] == 7) | (filtered['delay'] == 8) | (filtered['delay'] == 9)]

    delay_label = filtered.apply(lambda one_line: delay_threshold(one_line), axis=1)
    filtered['delay_label'] = delay_label

    spot_bit = filtered.apply(lambda one_line: convert_spot_to_bit(one_line), axis=1)
    spot_eigenvalue = spot_bit.apply(lambda bit: bit % 5000)
    filtered['spot_eigenvalue'] = spot_eigenvalue

    label_zero = len(filtered[filtered['delay_label'] == 0])
    label_one = len(filtered[filtered['delay_label'] == 1])

    print('label_zero length:{}'.format(label_zero))
    print('label_one length:{}'.format(label_one))
    
    print(label_one)

    if label_zero == 0:
        print('label_zeroが一件もありません')
        return 0, 0, 0
    elif label_one == 0:
        print('label_oneが一件もありません')
        return 0, 0, 0
    length = len(filtered)
    x = filtered[['time', 'youbi', 'event', 'season', 'access_num',
                  'cntpopulation']]
    y = filtered['delay_label']

    return x, y, length

In [10]:
def data_selection():
    labeled_list = sql_result.copy()
    labeled = sql_result.apply(lambda one_line:
                               delay_threshold(one_line), axis=1)
    labeled_list['delay_label'] = labeled

    spot_bit = sql_result.apply(lambda one_line:
                                convert_spot_to_bit(one_line), axis=1)
    spot_eigenvalue = spot_bit.apply(lambda bit: bit % 5000)
    labeled_list['spot_eigenvalue'] = spot_eigenvalue

    label_zero = labeled_list[labeled_list['delay_label'] == 0]
    label_one = labeled_list[labeled_list['delay_label'] == 1]
    length = 0

    print('label_zero length:{}'.format(len(label_zero)))
    print('label_one length:{}'.format(len(label_one)))

    if len(label_zero) == 0:
        print('label_zeroが一件もありません')
    elif len(label_one) == 0:
        print('label_oneが一件もありません')
    elif len(label_zero) <= len(label_one):
        label_one = label_one.sample(len(label_zero))
        length = len(label_zero)
    elif len(label_one) <= len(label_zero):
        label_zero = label_zero.sample(len(label_one))
        length = len(label_one)

    concat = pd.concat([label_zero, label_one], ignore_index=True, axis=0)
    y = concat['delay_label']
    x = concat[['time', 'youbi', 'event',
                'access_num', 'season', 'spot_eigenvalue', 'cntpopulation', 
                'kouyoudai_11', 'kouyoudai_12', 'kouyoudai_13', 'kouyoudai_14', 'kouyoudai_23', 'kouyoudai_24', 
                'poor_55', 'poor_56', 'poor_57', 
                'tosyokan_58', 
                'yuumai_59', 'yuumai_60', 
                'sakuragi_61', 'sakuragi_62', 'sakuragi_63', 'sakuragi_63', 'sakuragi_64', 
                'inaho_67', 'inaho_68', 'inaho_67', 'inaho_68', 
                'sikotuko_69', 'kuukousinai_70', 'yamato_71', 
                'higasibutai_72', 'midoridai_74', 
                'kuukousinai_80', 'sikotuko_81', 'yamato_82', 'higashi_83', 'higashi_84', 'higashi_85', 
                'midoridai_91', 'sakuragi_92', 'kouyoudai_93', 'kouyoudai_94', 'kouyoudai_95', 'kouyoudai_96', 
                'midaridai_97', 'midaridai_98', 
                'sakuragi_99', 'sakuragi_100', 
                'higashibutai_106', 'higashibutai_107', 'higashibutai_108', 'higashibutai_109', 
                'yuumai_110']]
       
    return x, y, length

In [14]:
from sklearn.model_selection import StratifiedKFold

kFold = StratifiedKFold(n_splits=5, shuffle=True,
                        random_state=np.random.seed(7))

all_epoch_acc = []
all_epoch_std = []

X, Y, length = bus_stop_data_selection(184)

# for i in range(5):
epoch_acc = []
epoch_std = []
scores = []
# for i in range(8):
epochs = 2
# if i == 0:
#     epochs = 2
# elif i == 1:
#     epochs = 6
# elif i == 2:
#     epochs = 10
# elif i == 3:
#     epochs = 14
# elif i == 4:
#     epochs = 18
# elif i == 5:
#     epochs = 22
# elif i == 6:
#     epochs = 26
# elif i == 7:
#     epochs = 30

for j in range(20):
    print('epochs{}'.format(epochs))

    if length != 0:
        for train, test in kFold.split(X, Y):
            balance_data, balance_target = smote(X.iloc[train], Y.iloc[train])
            balance_target_one_hot = np.identity(2)[balance_target.astype(int)]
            Y_test_one_hot = np.identity(2)[Y.iloc[test].astype(int)]

            model = Sequential()

            # 入力層
            model.add(Dense(10, input_dim=balance_data.shape[1],
                            activation='selu', kernel_initializer='he_normal',
                            kernel_regularizer=regularizers.l2(0.01),
                            activity_regularizer=regularizers.l1(0.01)
                            ))
            model.add(BatchNormalization())

            # model.add(Dense(5, activation='selu', kernel_initializer='he_normal',
            #                 kernel_regularizer=regularizers.l2(0.01),
            #                 activity_regularizer=regularizers.l1(0.01)))
            # model.add(BatchNormalization())

            # 出力層
            model.add(Dense(2, activation='softmax', kernel_initializer='he_normal'))

            model.compile(
                loss='categorical_crossentropy',
                optimizer=RMSprop(lr=0.01),
                metrics=['accuracy']
            )

            batch_size = int(len(balance_data) * 0.1)

            model.fit(balance_data, balance_target_one_hot, epochs=epochs, batch_size=batch_size, verbose=0)
            score = model.evaluate(X.iloc[test], Y_test_one_hot)

            scores.append(score[1] * 100)

        epoch_acc.append(np.mean(scores))
        epoch_std.append(np.std(scores))

    all_epoch_acc.append(np.mean(epoch_acc))
    all_epoch_std.append(epoch_std)


label_zero length:1253
label_one length:106
106
epochs2
サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:84
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/273 [==>...........................] - ETA: 20s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 20s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 20s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 21s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 20s



epochs2
サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:84
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/273 [==>...........................] - ETA: 21s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 21s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 21s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 22s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 22s



epochs2
サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:84
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/273 [==>...........................] - ETA: 22s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 23s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 23s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 23s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 23s



epochs2
サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:84
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/273 [==>...........................] - ETA: 23s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 24s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 23s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 24s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 23s



epochs2
サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:84
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/273 [==>...........................] - ETA: 24s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 25s



サイズ調整前の学習データのlabel_zero length:1002
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1002
サイズ調整後の学習データlabel_one length:1002


 32/272 [==>...........................] - ETA: 25s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


 32/271 [==>...........................] - ETA: 25s



サイズ調整前の学習データのlabel_zero length:1003
サイズ調整前の学習データのlabel_one length:85
サイズ調整後の学習データのlabel_zero length:1003
サイズ調整後の学習データlabel_one length:1003


In [13]:
print(np.mean(all_epoch_acc[0:20]))


82.16503978076052
