In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from MiniBatch import MiniBatch

df = pd.read_csv('dataset.csv')
df

Unnamed: 0,outlook,temperature,humidity,windy,play
0,sunny,85,85,False,no
1,sunny,80,90,True,no
2,overcast,83,86,False,yes
3,rainy,70,96,False,yes
4,rainy,68,80,False,yes
5,rainy,65,70,True,no
6,overcast,64,65,True,yes
7,sunny,72,95,False,no
8,sunny,69,70,False,yes
9,rainy,75,80,False,yes


In [2]:
def rationalize(list_df) :
    max_df = list_df.max()
    result = list(map(lambda x: x/(max_df)*1, list_df))
    return result

df['temperature'] = rationalize(df['temperature'])
df['humidity'] = rationalize(df['humidity'])
df

Unnamed: 0,outlook,temperature,humidity,windy,play
0,sunny,1.0,0.885417,False,no
1,sunny,0.941176,0.9375,True,no
2,overcast,0.976471,0.895833,False,yes
3,rainy,0.823529,1.0,False,yes
4,rainy,0.8,0.833333,False,yes
5,rainy,0.764706,0.729167,True,no
6,overcast,0.752941,0.677083,True,yes
7,sunny,0.847059,0.989583,False,no
8,sunny,0.811765,0.729167,False,yes
9,rainy,0.882353,0.833333,False,yes


In [3]:
df['outlook'] = pd.Categorical(df['outlook'])
dfOutlook = pd.get_dummies(df['outlook'], prefix = 'outlook')
dfOutlook

Unnamed: 0,outlook_overcast,outlook_rainy,outlook_sunny
0,0,0,1
1,0,0,1
2,1,0,0
3,0,1,0
4,0,1,0
5,0,1,0
6,1,0,0
7,0,0,1
8,0,0,1
9,0,1,0


In [4]:
df['windy'] = pd.Categorical(df['windy'])
dfWindy = pd.get_dummies(df['windy'], prefix = 'windy')
dfWindy

Unnamed: 0,windy_False,windy_True
0,1,0
1,0,1
2,1,0
3,1,0
4,1,0
5,0,1
6,0,1
7,1,0
8,1,0
9,1,0


In [5]:
df['play'] = df['play'].astype('category')
df['play'] = df['play'].cat.codes

df = pd.concat([df, dfOutlook, dfWindy], axis=1)
df = df.drop(columns=['outlook', 'windy'], axis=1)
df

Unnamed: 0,temperature,humidity,play,outlook_overcast,outlook_rainy,outlook_sunny,windy_False,windy_True
0,1.0,0.885417,0,0,0,1,1,0
1,0.941176,0.9375,0,0,0,1,0,1
2,0.976471,0.895833,1,1,0,0,1,0
3,0.823529,1.0,1,0,1,0,1,0
4,0.8,0.833333,1,0,1,0,1,0
5,0.764706,0.729167,0,0,1,0,0,1
6,0.752941,0.677083,1,1,0,0,0,1
7,0.847059,0.989583,0,0,0,1,1,0
8,0.811765,0.729167,1,0,0,1,1,0
9,0.882353,0.833333,1,0,1,0,1,0


In [6]:
features = df.drop(columns=['play'], axis=1)
targets = df['play']

print (features.head)
print ()
print (targets.head)

<bound method NDFrame.head of     temperature  humidity  outlook_overcast  outlook_rainy  outlook_sunny  \
0      1.000000  0.885417                 0              0              1   
1      0.941176  0.937500                 0              0              1   
2      0.976471  0.895833                 1              0              0   
3      0.823529  1.000000                 0              1              0   
4      0.800000  0.833333                 0              1              0   
5      0.764706  0.729167                 0              1              0   
6      0.752941  0.677083                 1              0              0   
7      0.847059  0.989583                 0              0              1   
8      0.811765  0.729167                 0              0              1   
9      0.882353  0.833333                 0              1              0   
10     0.882353  0.729167                 0              0              1   
11     0.847059  0.937500                 1   

## Training and Testing

In [7]:
features_train, features_test, targets_train, targets_test = train_test_split(features, targets, test_size=0.5, stratify=targets, random_state=None)
print (features_train)
print ()
print (targets_train)
print ()

print (features_test)
print ()
print (targets_test)
print ()

    temperature  humidity  outlook_overcast  outlook_rainy  outlook_sunny  \
1      0.941176  0.937500                 0              0              1   
12     0.952941  0.781250                 1              0              0   
10     0.882353  0.729167                 0              0              1   
5      0.764706  0.729167                 0              1              0   
0      1.000000  0.885417                 0              0              1   
11     0.847059  0.937500                 1              0              0   
3      0.823529  1.000000                 0              1              0   

    windy_False  windy_True  
1             0           1  
12            1           0  
10            0           1  
5             0           1  
0             1           0  
11            0           1  
3             1           0  

1     0
12    1
10    1
5     0
0     0
11    1
3     1
Name: play, dtype: int8

    temperature  humidity  outlook_overcast  outlook_rainy  o

In [8]:
model = MiniBatch(5, 8, 5, 0.25, 0.0001, 5)
model.fit(features, targets)
print ('Predict')
predict = model.predict(features)

print (len(predict))
print (len(targets_test))

print ('Accuracy\t', accuracy_score(targets, predict))

[[[87.53397900611365, 58.22535801478902, 68.10126836201559, 102.25675517515386, 114.64371173210023], [45.61251622654027, 34.10296067748237, 42.43546502720706, 25.94760391324901, 37.88541446222495], [17.263433671676097, 22.94555218507679, 9.824629688576398, 18.17609117846867, 9.365971748962812], [7.998626269148694, 10.169507810145664, 6.306330330410549, 7.971146338117819, 7.7053620570075845], [3.5095223846087302, 4.297715327547799, 2.6534012075053472, 1.6201845887271755, 3.6007921417155924], [0.84372072721366, 0.6534270053413628, 1.1622295037972012, 1.5161015750321538, 1.6181172944421516], [0.5198805871591253, 0.3321068111217794, 0.3420762964941634, 0.4285208924127776, 0.4325417607215127], [0.16608289287404177, 0.13121757367688916, 0.0471856937663127, 0.1321877713221973, 0.07423613532638267], [0.0012689165022382774]]]
[[[87.53397900611365, 58.22535801478902, 68.10126836201559, 102.25675517515386, 114.64371173210023], [45.61251622654027, 34.10296067748237, 42.43546502720706, 25.947603913

[[[246348359.26971957, 245774112.96705392, 246122265.26339749, 247150495.42350873, 247286476.83671337], [1055849.6105807177, 1050872.2827919812, 1057408.3078990357, 1049192.4080877616, 1051326.9922322654], [10657.012493915383, 10780.409827082247, 10452.130676865014, 10620.279542825296, 10432.696389183802], [203.021704340004, 209.7884557320868, 196.60107157465268, 203.3919987498648, 204.0339691427287], [9.32751793588968, 9.818302533445435, 8.787202803765142, 7.953423556200871, 9.279758111846236], [0.7556281305575825, 0.6728078970517847, 0.9352235132448288, 1.019391588060961, 1.0268654371165107], [0.1866138911784237, 0.16792942106656772, 0.1706749705236329, 0.1625327875968137, 0.17036781165787748], [0.04191218776147473, 0.06103854955621654, 0.013087185210843625, 0.056264038856902836, 0.050916146742505365], [-0.0398995201992236]]]
[[[246348359.26971957, 245774112.96705392, 246122265.26339749, 247150495.42350873, 247286476.83671337], [1055849.6105807177, 1050872.2827919812, 1057408.3078990

[[[3.364358670798484e+18, 3.3643584539552707e+18, 3.364358585631391e+18, 3.364358978567974e+18, 3.3643590287758034e+18], [400677539835.25574, 400675357987.9601, 400678302386.68506, 400674663285.3097, 400675519245.79553], [4740703.2888381295, 4742373.108747453, 4737797.262524328, 4740078.624774904, 4737519.593666827], [2830.3367883626015, 2852.9003681156273, 2808.271740322612, 2831.7892115644777, 2834.578524078577], [30.49948888825279, 31.408997238216553, 29.49805030771228, 27.854852593001397, 30.401352353380602], [1.5207762658641115, 1.3918816193328905, 1.812130716926948, 1.942249389089013, 1.9482792600773169], [0.2816626664245324, 0.24972890368504588, 0.27571978520565676, 0.24635451433296895, 0.27066630878903863], [0.07323592726161508, 0.08418919761404355, 0.039194073996923554, 0.06491907704217642, 0.08688571641618997], [0.0035603416139063124]]]
[[[3.364358670798484e+18, 3.3643584539552707e+18, 3.364358585631391e+18, 3.364358978567974e+18, 3.3643590287758034e+18], [400677539835.25574,

[[[8.901037021201577e+27, 8.901037021199257e+27, 8.901037021200666e+27, 8.90103702120487e+27, 8.901037021205408e+27], [4285270432959987.5, 4285270352968258.0, 4285270461003186.5, 4285270327543134.0, 4285270358838175.0], [173811390.34851944, 173817622.51155278, 173800437.4712599, 173808959.57976168, 173799390.04605842], [10626.326604289683, 10658.496790422092, 10594.417588789438, 10628.538362639649, 10632.96999017907], [43.06997140426729, 43.8745466430627, 42.16507824325856, 40.6247686116037, 42.98278347833279], [1.4606838686940586, 1.3663006997052862, 1.6678992329586864, 1.7790310666971816, 1.7948385502806254], [0.22864320052752146, 0.18491951273161664, 0.19029715368737218, 0.19771050235596252, 0.22104122009140767], [0.06293945590365829, 0.0517786245081889, 0.03279355958923, 0.0621270123186225, 0.03579212341425266], [0.002118488714809345]]]
[[[8.901037021201577e+27, 8.901037021199257e+27, 8.901037021200666e+27, 8.90103702120487e+27, 8.901037021205408e+27], [4285270432959987.5, 42852703

[[[4.230385517058426e+34, 4.230385517058425e+34, 4.230385517058426e+34, 4.230385517058426e+34, 4.230385517058426e+34], [1.9284453388039475e+18, 1.9284453378328177e+18, 1.9284453391444946e+18, 1.928445337524194e+18, 1.9284453379040353e+18], [2109161496.7128124, 2109178951.949623, 2109130731.0370145, 2109154607.3569763, 2109127788.44635], [29802.98589269295, 29851.26874603894, 29754.814408907645, 29806.3808862422, 29813.338182805957], [64.32209814624339, 65.22819015914196, 63.31884690196932, 61.55052320643889, 64.21479115197309], [1.6793707486561165, 1.5796090897874773, 1.9086267402093275, 2.00945927871717, 2.0118476168851043], [0.2160074171934605, 0.19488814942352034, 0.21441031533307026, 0.19076891342256244, 0.21223868516966696], [0.04686115856109463, 0.05498531809645129, 0.030296508767581356, 0.044307216582306325, 0.06187666347664184], [0.0036961592726695515]]]
[[[4.230385517058426e+34, 4.230385517058425e+34, 4.230385517058426e+34, 4.230385517058426e+34, 4.230385517058426e+34], [1.928

[[[1.935976041066028e+38, 1.935976041066028e+38, 1.935976041066028e+38, 1.935976041066028e+38, 1.935976041066028e+38], [1.855693919153461e+19, 1.8556939190839665e+19, 1.855693919177832e+19, 1.855693919061882e+19, 1.8556939190890623e+19], [1508574782.2551322, 1508578593.7394185, 1508568054.793153, 1508573269.201468, 1508567411.304494], [6510.408646650076, 6516.490401278301, 6504.308915515067, 6510.8475878249865, 6511.753587122262], [8.076139130816207, 8.156181278663333, 7.979828991619314, 7.81004093944985, 8.070427585469467], [0.17094388074880495, 0.16235528451989317, 0.1764822517686184, 0.19193773164909994, 0.1980896741141342], [0.027479140225994136, 0.01672272734701267, -0.0008072341056373445, 0.01863619111104291, 0.026862448088720913], [-0.0012214700048992985, -0.0021435013254150585, 0.0029060367355585176, 0.026144136952261206, -0.013832946335403211], [-0.04316496624245049]]]
[[[1.935976041066028e+38, 1.935976041066028e+38, 1.935976041066028e+38, 1.935976041066028e+38, 1.935976041066

[[[-1.8166542626557453e+40, -1.8166542626557453e+40, -1.8166542626557453e+40, -1.8166542626557453e+40, -1.8166542626557453e+40], [-6.189750323729012e+19, -6.1897503236482425e+19, -6.189750323757336e+19, -6.1897503236225745e+19, -6.1897503236541645e+19], [-1752636570.1654117, -1752639669.6376517, -1752631098.0262375, -1752635338.4826326, -1752630574.6070056], [-5292.982123416919, -5297.281669173814, -5288.670771182264, -5293.289498526643, -5293.933408176201], [-5.703259993557421, -5.759660804911732, -5.636832478788184, -5.522585494509207, -5.695532931630918], [-0.1084382079324892, -0.10644490176742132, -0.1223047989272599, -0.12845140969618446, -0.13151504403103667], [-0.011998161350022195, -0.004869550089090433, -0.016422213943878786, -0.014207329707920184, -0.010034261264234278], [-0.017531189330553896, -0.005151578193504397, -0.0036923740390277723, 0.011373783815559235, -0.0026429115655779503], [-0.051704997316508364]]]
[[[-1.8166542626557453e+40, -1.8166542626557453e+40, -1.81665426

[[[-6.010300243317872e+42, -6.010300243317872e+42, -6.010300243317872e+42, -6.010300243317872e+42, -6.010300243317872e+42], [-1.3081436929525796e+21, -1.3081436929487408e+21, -1.308143692953926e+21, -1.3081436929475207e+21, -1.308143692949022e+21], [-8326055765.197648, -8326062740.575651, -8326043443.894646, -8326052987.646896, -8326042265.325279], [-11909.803106884132, -11916.617736157294, -11902.961801626894, -11910.290754008454, -11911.323707639516], [-9.016509698861089, -9.098841426515055, -8.931449290915992, -8.772725424425659, -9.00243622395106], [-0.14466391869618636, -0.13812450840337992, -0.1743059309158346, -0.1752522957551046, -0.17014511208211344], [-0.006851349712144125, -0.01360591856486406, -0.033177384618014054, -0.012142786426635205, -0.0069563358421525275], [-0.010096524442465553, -0.009567249436734423, -0.0004878355550272488, 0.01905361181502676, -0.021439752630405134], [-0.049407356575059055]]]
[[[-6.010300243317872e+42, -6.010300243317872e+42, -6.010300243317872e+4

[[[1.4859115867525898e+45, 1.4859115867525898e+45, 1.4859115867525898e+45, 1.4859115867525898e+45, 1.4859115867525898e+45], [4.68244986830645e+22, 4.6824498682999024e+22, 4.682449868308746e+22, 4.682449868297822e+22, 4.682449868300382e+22], [141929880239.42114, 141929968098.02884, 141929725023.33707, 141929845234.0218, 141929710176.38507], [149949.4772272116, 150024.7325262181, 149873.81031531142, 149954.92211156175, 149966.4036364094], [99.50396598852697, 100.28996642105653, 98.61989861242166, 97.00193976402304, 99.40644974903545], [1.5609193675051096, 1.495187087302534, 1.718243162805261, 1.7936730446180398, 1.8028812560168843], [0.15765443353583133, 0.13083119366609353, 0.14896333383936666, 0.1424979491247611, 0.1507962977636], [0.04014723481090831, 0.03829463279425052, 0.026764644559541487, 0.020069485082857576, 0.026414197637174792], [0.0023607206580183618]]]
[[[1.4859115867525898e+45, 1.4859115867525898e+45, 1.4859115867525898e+45, 1.4859115867525898e+45, 1.4859115867525898e+45],

[[[-1.7424027393800386e+46, -1.7424027393800386e+46, -1.7424027393800386e+46, -1.7424027393800386e+46, -1.7424027393800386e+46], [-8.255451674414983e+22, -8.255451674410623e+22, -8.255451674416513e+22, -8.255451674409237e+22, -8.255451674410943e+22], [-94502960900.32318, -94502997310.79047, -94502896561.85402, -94502946381.01079, -94502890407.62924], [-62125.86267466142, -62150.81433845362, -62100.7551223875, -62127.66967981096, -62131.503755408085], [-32.94907289287442, -33.18930331622944, -32.687009382381746, -32.19932682320466, -32.9149339771635], [-0.46275944254158363, -0.44379011557341874, -0.5224511049944651, -0.5371204781607994, -0.5330485501975868], [-0.03555887626505057, -0.03876092920663546, -0.06115231166920572, -0.03842831904097495, -0.035006554149055465], [-0.018284833311549408, -0.015931373148416344, -0.003285127705805198, 0.01497550300284384, -0.02748807486859395], [-0.057905759209866826]]]
[[[-1.7424027393800386e+46, -1.7424027393800386e+46, -1.7424027393800386e+46, -1.

[[[-5.320162031657548e+47, -5.320162031657548e+47, -5.320162031657548e+47, -5.320162031657548e+47, -5.320162031657548e+47], [-5.293031739354207e+23, -5.293031739352836e+23, -5.2930317393546874e+23, -5.2930317393524e+23, -5.2930317393529365e+23], [-296881981565.3032, -296882063695.563, -296881836420.95264, -296881948798.55145, -296881822537.29346], [-140076.74363050092, -140124.9828316901, -140028.16073555205, -140080.2515324705, -140087.70342674197], [-63.64306964908374, -64.06212814270528, -63.17121848122029, -62.29923929531412, -63.58844801091168], [-0.8407372778882982, -0.8107096247778954, -0.9237912511187075, -0.9594654549786457, -0.9639405866872062], [-0.07493561305795388, -0.06193441553805667, -0.08160877031546988, -0.07276435294092186, -0.07100224178354574], [-0.03434495537707115, -0.0212912960382703, -0.009030386386712984, 0.004754931386023498, -0.01787342505493235], [-0.06988087176137156]]]
[[[-5.320162031657548e+47, -5.320162031657548e+47, -5.320162031657548e+47, -5.320162031

[[[9.861295179452561e+48, 9.861295179452561e+48, 9.861295179452561e+48, 9.861295179452561e+48, 9.861295179452561e+48], [3.054590716014128e+24, 3.054590716013662e+24, 3.054590716014291e+24, 3.0545907160135136e+24, 3.054590716013696e+24], [1007958194937.5642, 1007958413215.0714, 1007957809157.3818, 1007958107826.1243, 1007957772255.852], [372114.4977530445, 372228.9024782887, 371999.2270388227, 372122.83061708586, 372140.55993785156], [150.80388052491674, 151.7474542811105, 149.75339612972837, 147.7878064992898, 150.68240953444075], [1.898057781372661, 1.8218250263539288, 2.0802749635902504, 2.1568413866440794, 2.156924310420209], [0.16347495062318926, 0.15111336597765676, 0.17003707903882834, 0.14814029434619314, 0.16230419489460043], [0.033129041915483465, 0.035879421479611, 0.021317694784056054, 0.022555036913560717, 0.046685916636324945], [0.010403198285849681]]]
[[[9.861295179452561e+48, 9.861295179452561e+48, 9.861295179452561e+48, 9.861295179452561e+48, 9.861295179452561e+48], [3.

[[[1.1838370168628228e+50, 1.1838370168628228e+50, 1.1838370168628228e+50, 1.1838370168628228e+50, 1.1838370168628228e+50], [8.810614679466639e+24, 8.810614679466003e+24, 8.810614679466862e+24, 8.8106146794658e+24, 8.810614679466049e+24], [1376711555741.0774, 1376711760085.6375, 1376711194546.148, 1376711474154.3013, 1376711159996.2349], [348247.20103342534, 348336.26080869744, 348157.40151094244, 348253.7038379991, 348267.57813759043], [117.28798185342518, 117.95270751441153, 116.54062368711752, 115.14063414682634, 117.20272403481843], [1.3570060349579816, 1.3079930926275611, 1.4760618190682857, 1.5313993284550813, 1.5370087407947741], [0.11564762337309781, 0.0981603584323573, 0.11224141107078803, 0.10557061177437366, 0.11156961861137035], [0.02842414135213438, 0.025973560410282592, 0.018438680343891053, 0.012280526138085232, 0.020982613120899954], [0.0035103088701108126]]]
[[[1.1838370168628228e+50, 1.1838370168628228e+50, 1.1838370168628228e+50, 1.1838370168628228e+50, 1.18383701686

[[[-5.79712479522544e+50, -5.79712479522544e+50, -5.79712479522544e+50, -5.79712479522544e+50, -5.79712479522544e+50], [-1.7052329531931006e+25, -1.7052329531930185e+25, -1.7052329531931294e+25, -1.7052329531929926e+25, -1.7052329531930246e+25], [-1773611976128.1929, -1773612194520.0037, -1773611590086.09, -1773611888917.0564, -1773611553159.4065], [-372011.62690747075, -372098.8066640326, -371923.70033516356, -372017.9954064875, -372031.6064727243], [-114.72654835756421, -115.35295347208375, -114.02512006394767, -112.7081780837921, -114.64322271409554], [-1.271449573798598, -1.2281062262362594, -1.3902416745143442, -1.4376482447574708, -1.4401206899723107], [-0.10068166821230236, -0.08903039162551858, -0.1153257140272936, -0.09797785585579702, -0.09777001266647081], [-0.04269463529793754, -0.026501192914643607, -0.010177785470745594, 0.005907409337969507, -0.03088678037768855], [-0.08395679343433983]]]
[[[-5.79712479522544e+50, -5.79712479522544e+50, -5.79712479522544e+50, -5.79712479

[[[5.089603342440144e+51, 5.089603342440144e+51, 5.089603342440144e+51, 5.089603342440144e+51, 5.089603342440144e+51], [5.4665201485628475e+25, 5.466520148562687e+25, 5.466520148562904e+25, 5.466520148562636e+25, 5.466520148562699e+25], [3467412143203.4053, 3467412479307.373, 3467411549054.3623, 3467412008956.9443, 3467411492221.2554], [572257.8772071123, 572377.4524813452, 572137.2256258607, 572266.6273018694, 572285.3501298847], [157.23372842391075, 158.03952618952096, 156.32799957005057, 154.61930889080693, 157.13033072533085], [1.659047638919687, 1.600944820762309, 1.7963345516472753, 1.861663747866208, 1.867365904826372], [0.13513684347983407, 0.11632359029525527, 0.12656984502356955, 0.12187038686255691, 0.13270240746820272], [0.03268847998889008, 0.025163595110148936, 0.017591227028000013, 0.021079477235886268, 0.023772631384549336], [0.011270210464018505]]]
[[[5.089603342440144e+51, 5.089603342440144e+51, 5.089603342440144e+51, 5.089603342440144e+51, 5.089603342440144e+51], [5.

Predict
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
14
7
Accuracy	 0.6428571428571429
