Skip to content

Commit

Permalink
Simplify class constructor attribute 'layers' or 'neurons'
Browse files Browse the repository at this point in the history
  • Loading branch information
nok committed Oct 21, 2017
1 parent 29412ab commit 7820508
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 120 deletions.
Expand Up @@ -6,13 +6,14 @@


iris_data = load_iris()
X, y = iris_data.data, iris_data.target
X = iris_data.data
y = iris_data.target

clf = tree.DecisionTreeClassifier()
clf.fit(X, y)

output = Porter(clf).export()
# output = Porter(clf, language='java').export()
porter = Porter(clf)
output = porter.export()
print(output)

"""
Expand Down
39 changes: 22 additions & 17 deletions examples/estimator/classifier/MLPClassifier/java/basics.ipynb
Expand Up @@ -115,18 +115,25 @@
"\n",
" private Activation hidden;\n",
" private Activation output;\n",
" private double[][] layers;\n",
" private double[][] network;\n",
" private double[][][] weights;\n",
" private double[][] bias;\n",
"\n",
" public MLPClassifier(String hidden, String output, double[][] layers, double[][][] weights, double[][] bias) {\n",
" public MLPClassifier(String hidden, String output, int[] layers, double[][][] weights, double[][] bias) {\n",
" this.hidden = Activation.valueOf(hidden.toUpperCase());\n",
" this.output = Activation.valueOf(output.toUpperCase());\n",
" this.layers = layers;\n",
" this.network = new double[layers.length + 1][];\n",
" for (int i = 0, l = layers.length; i < l; i++) {\n",
" this.network[i + 1] = new double[layers[i]];\n",
" }\n",
" this.weights = weights;\n",
" this.bias = bias;\n",
" }\n",
"\n",
" public Brain(String hidden, String output, int neurons, double[][][] weights, double[][] bias) {\n",
" this(hidden, output, new int[] { neurons }, weights, bias);\n",
" }\n",
"\n",
" private double[] compute(Activation activation, double[] v) {\n",
" switch (activation) {\n",
" case LOGISTIC:\n",
Expand Down Expand Up @@ -167,26 +174,24 @@
" }\n",
"\n",
" public int predict(double[] neurons) {\n",
" double[][] network = new double[this.layers.length + 1][];\n",
" System.arraycopy(new double[][] {neurons}, 0, network, 0, 1);\n",
" System.arraycopy(this.layers, 0, network, 1, this.layers.length);\n",
" this.network[0] = neurons;\n",
" \n",
" for (int i = 0; i < network.length - 1; i++) {\n",
" for (int j = 0; j < network[i + 1].length; j++) {\n",
" for (int l = 0; l < network[i].length; l++) {\n",
" network[i + 1][j] += network[i][l] * this.weights[i][l][j];\n",
" for (int i = 0; i < this.network.length - 1; i++) {\n",
" for (int j = 0; j < this.network[i + 1].length; j++) {\n",
" for (int l = 0; l < this.network[i].length; l++) {\n",
" this.network[i + 1][j] += this.network[i][l] * this.weights[i][l][j];\n",
" }\n",
" network[i + 1][j] += this.bias[i][j];\n",
" this.network[i + 1][j] += this.bias[i][j];\n",
" }\n",
" if ((i + 1) < (network.length - 1)) {\n",
" network[i + 1] = this.compute(this.hidden, network[i + 1]);\n",
" if ((i + 1) < (this.network.length - 1)) {\n",
" this.network[i + 1] = this.compute(this.hidden, this.network[i + 1]);\n",
" }\n",
" }\n",
" network[network.length - 1] = this.compute(this.output, network[network.length - 1]);\n",
" this.network[this.network.length - 1] = this.compute(this.output, this.network[this.network.length - 1]);\n",
" \n",
" int classIdx = 0;\n",
" for (int i = 0; i < network[network.length - 1].length; i++) {\n",
" classIdx = network[network.length - 1][i] > network[network.length - 1][classIdx] ? i : classIdx;\n",
" for (int i = 0; i < this.network[this.network.length - 1].length; i++) {\n",
" classIdx = this.network[this.network.length - 1][i] > this.network[this.network.length - 1][classIdx] ? i : classIdx;\n",
" }\n",
" return classIdx;\n",
" }\n",
Expand All @@ -201,7 +206,7 @@
" }\n",
"\n",
" // Parameters:\n",
" double[][] layers = {new double[50], new double[3]};\n",
" int[] layers = {50, 3};\n",
" double[][][] weights = {{{-0.055317816158370905, -0.25162425407767419, -0.33325197861130057, -0.13177626632767314, -0.23549246545141095, -0.27177010710624283, -0.20915665516342635, 0.015057965302913687, -0.72987930075766627, -0.89266106096455711, -0.053869498543610707, -0.22216066222318273, -0.39899502812673099, -0.28681010148596864, -0.31507011153899617, -0.67039973171352785, -0.46627854428112808, -0.55220519291597547, -0.6575355679434185, -0.2012625909486952, 0.33662471612525224, 0.25425126671040732, -0.12438197591047517, 0.22085158121433821, -0.43004111901976899, 0.12047821481685771, -0.27663295488184064, -0.30729210399464546, -0.2201096819240726, 0.1969894797949216, -0.26776467575302032, -0.5596322304662299, -0.66057684440281672, 0.022109851297338449, 0.054364802303910886, -0.30531205478429846, -0.94689187532840058, -1.1472602273053643, -0.32113622918572421, -0.97806246488102433, 0.12604362195657265, 0.5231042582505061, -0.14636842984592813, -0.11855589566328104, -0.2645119437495197, -0.034737117103522681, -0.057695932963591795, -0.13758846013369239, -0.14148094036584954, -0.24664384043086116}, {-0.32041711985049837, -0.043058968510450234, -0.19224497780614169, -0.15629983284213439, -0.0056178077580114668, -0.29775374218304873, 0.04941098010393874, -0.26163930851937761, -0.19756633578664284, -0.58925308366604867, -0.26510631981125071, -0.28998233921283878, -0.041920189170287073, -0.43290432093263526, -0.30002643148337588, -0.4865930166377811, -0.081521874811590178, -0.27842983575218733, -0.039622604256916022, 0.05770247643254544, 0.71332882051710467, -0.36349619254759102, -0.24047875144770522, 0.53045027219372654, -0.59792191320512733, -0.35631856975017934, 0.28500135453630931, -0.10148787221769563, 0.16720550745832014, 0.25971388860710165, 0.2555334798063561, -0.10856461864559638, -0.41610950088536136, -0.10073289562774027, -0.35234810455711674, 0.13500748765932857, -0.59236329802269883, -0.69284565960163447, 0.10895932947719877, -0.53435965541922292, -0.28749631798822201, 0.95302314181559555, -0.03339139951489787, -0.21465662301256941, -0.061241193383340838, -0.17531266101796375, 0.032502649222591232, 0.049118905364866981, -0.33141470527927364, 0.078095412793859842}, {-0.11556829517891457, -0.28839659290844383, 0.25729079162402252, -0.095152035767087209, 0.27235259530926254, 0.082238817438121087, -0.32278089365740004, 0.610895041675383, -0.53946895727666411, 0.66336833631950098, -0.21843631503665203, -0.34459111930044778, 0.26988257637824442, 0.22285364019619236, -0.28932878616881214, -0.16019244916792269, -0.15188207190529554, -0.17375211722189435, 0.12003588634597379, -0.25048218839808589, -1.1814760275905136, -0.11679924783811843, -0.31445752437419289, -0.71528558652206731, 0.4739675739265925, 0.14299307810727774, 0.035214113037884201, 0.22801710207885825, -0.25054728530241616, -0.55191137079060104, 0.057171971863154362, -0.16787314862568981, -0.35815855192213464, -0.3208968914994626, 0.60992105666976559, -0.22919074588760435, -0.63633627319871822, -0.3709104501371045, 0.24235752387495327, -0.60846958788171768, -0.25014809614706662, -1.0335138637201629, -0.29338371286091719, -0.18073925955574421, -0.3036274299231928, -0.26166657231672047, -0.29878288629035105, 0.14199047853995642, 0.039810711538198082, -0.32495770210408736}, {-0.28534610919825071, 0.22769872913725117, 0.045399612503103627, -0.19780148024951577, -0.16511364103110895, 0.16254807942801869, -0.20304390231246003, 0.20984846083539002, 0.05685419958788708, 0.62516653349039786, -0.17343217055411875, -0.022987144452160622, 0.079240581612972011, 0.40417766317484144, -0.22880223204823208, -0.42337014160558745, -0.38940003536068196, -0.17326402738615926, 0.088427404430028334, 0.04590025494515363, -0.5468653035931812, 0.42810346180594117, 0.053162665137136862, -0.32103853894220169, 0.31367095881914298, 0.24999765928913131, 0.11282020075848748, -0.15671789424443483, -0.2891056817821327, -0.28253677782850595, 0.086477013336133002, -0.36053269357622608, 0.086846992712735341, -0.28897125255757344, 0.059892059841316563, 0.19412094475437708, -0.53174460572971327, 0.049633740361929211, 0.016446620815376744, 0.0064188476524307449, -0.26556784617072887, -0.65689592060012747, 0.15670824169328113, 0.25181970948115312, 0.27187307037723824, 0.28797696815559204, -0.36711566037352317, -0.17708922988507345, 0.077851045457767137, 0.29933962842647954}}, {{-0.33618760940544834, 0.32081899159033467, -0.083051115109439883}, {0.0096383304388566442, 0.2419360820188009, 0.35899310789861344}, {0.050274587597035091, 0.086184558536937531, -0.14428921022123034}, {0.058431568313521046, 0.16824371439906796, 0.2411152128064154}, {0.17164895659196974, 0.13327594625746894, 0.24526414111440331}, {-0.11932084344947159, 0.11492655724963714, -0.033057727812256854}, {-0.079334977171167304, -0.060016494360769022, -0.066295991899890286}, {-0.71486824845646135, 0.1600868064828162, 0.46699990709340145}, {-0.083938207334024614, 0.27187161898314255, 0.047703604119507331}, {-0.58793523568582362, -0.27887692495319361, 0.91235920748293431}, {0.25570703360541275, 0.27175182540995058, 0.10949681033724466}, {-0.041009631889342313, -0.23307375710751005, 0.19163291322703435}, {-0.15189717246377268, -0.025154028942508011, 0.44780865283409288}, {-0.3115278729955171, -0.0056627446166434628, 0.66915144386450198}, {0.046335662746525126, -0.023074380835666274, -0.10585719439044175}, {-0.070740857338027027, -0.51355079283235039, -0.071290569035165702}, {0.31604404174783435, -0.21123109460406667, 0.21568944659130201}, {-0.19292113469545499, 0.13803331731090354, 0.48075587735422864}, {-0.039151124981216957, -0.35635262270414042, 0.040892690463512384}, {-0.18502508791073968, 0.06225274230793363, -0.12632668189370647}, {1.3423577687187591, -0.73226914108310381, -0.21773848562508968}, {-0.25166254943541466, -0.102124129031267, -0.11494109237159073}, {0.1538190166683838, -0.19636071211303363, -0.16955231982502666}, {0.69624421451105989, -0.21575476077610578, -0.22195121116483693}, {-0.67316493619822249, -0.25154862675010187, 0.48832090783946763}, {-0.60011861521328624, 0.2289721749490819, 0.20400673019934396}, {0.029615125302479721, 0.10372152966977911, -0.23919109178136511}, {0.16925716226788123, -0.18703765803400574, 0.013022157605494675}, {0.1919803411529804, -0.32143163003060055, -0.11818935970200066}, {0.80761667938359938, -0.33839244383018491, 0.03954930782973469}, {0.24669667496849926, 0.30268177307271782, 0.21964458163057088}, {-0.14219315348518174, -0.1147780335070106, 0.32706428493072132}, {0.30427505815686307, -0.24247142530193441, 0.35094508479670411}, {-0.31319460193132892, 0.18184804315717218, 0.15593394728341345}, {-0.6898315948318976, 0.15504602360040004, 0.29863974148791939}, {-0.082896141260148076, 0.18651145103423833, -0.044376894020555092}, {-0.065412112288315602, 0.19446473254155389, 0.25684436618993894}, {0.7474704410194063, -1.0212095046332008, 0.33152483608646371}, {0.3194605228960728, 0.091923324270713175, 0.3323621136813108}, {-0.28265940510607479, -0.29937990221663735, 0.38549704422299363}, {-0.36478754627854076, -0.14371799448150674, -0.13593364762469251}, {0.9447018816870989, -0.48391730610634803, -0.70129832484414678}, {0.24382869144366329, 0.12965597232536735, 0.12848807536103687}, {0.25456653136274177, -0.29157794873701726, -0.15670873719185607}, {0.32956189577408246, -0.19924632055817562, -0.16975463655496348}, {-0.16003760303725742, 0.16834508896389047, -0.02895200248099724}, {0.79191574248070318, -0.94258942667724988, -0.33557168526432885}, {0.20093565667960181, -0.13637901222029258, -0.31788160371143598}, {0.062872216044461984, 0.2313758213174964, -0.080066186818751758}, {0.16813372423064626, 0.0074972820959486042, 0.027557136131471489}}};\n",
" double[][] bias = {{0.30011741283138643, -0.029751221601027604, 0.27707089984418304, 0.09437747263089169, -0.073328190572502505, -0.009339555268726818, 0.069540321946648831, 0.030114067358476861, 0.1926469413394869, 0.023144226681427866, -0.070082924717630057, 0.24099343057814002, -0.2575927249328282, -0.39112904027196077, -0.24328056130217912, -0.13877438337256454, -0.3840609961969641, 0.20571932163518283, 0.12990256746838594, -0.32332067950525173, -0.11080976020947253, -0.1415205938964629, -0.24600210345938872, 0.26790419688164752, -0.29106359813082983, 0.24950761348487704, 0.054676119964720049, 0.25255465627456269, 0.22982296359481452, 0.29374548533376299, -0.026746489455462041, -0.044633956332849806, 0.015346134348148844, -0.1428540988439016, -0.059281796245450824, 0.030997844617798281, -0.48568377250041256, -0.21003762891120728, -0.044215767340361145, 0.018240118057367579, -0.1466920807505617, 0.39966256678781126, 0.051904810189690342, -0.28673053302698287, 0.19195282255033613, 0.074687451363289303, -0.35766933287571318, -0.05320421333257852, 0.11937922437695309, 0.27906785198501721}, {0.37658323831187673, 0.4591331949999668, -0.63671974160497635}};\n",
"\n",
Expand Down

0 comments on commit 7820508

Please sign in to comment.