Skip to content
Browse files

Added neural network based prediction

Uses Weka and Jython to run the predictions.
It should set up the folders it needs to pass data from Python -> Jython -> Python.
Yay, we added another language(ish)!
  • Loading branch information...
1 parent 73fb06e commit c2cd17c3a31af6897363a60ee52413877720d27f Chuma Nnaji committed
Showing with 138 additions and 225 deletions.
  1. +0 −6 adjustment/ema.py
  2. +0 −2 adjustment/macd
  3. +0 −41 ai/ann/ann.py
  4. +0 −171 ai/ann/nnlib.py
  5. +87 −0 ai/ann/predict_company.py
  6. +48 −0 ai/ann/weka.py
  7. +1 −1 mine/mdata.sql
  8. +1 −1 webservice/index.php
  9. +1 −3 website/index.html
View
6 adjustment/ema.py
@@ -1,6 +0,0 @@
-#public float CalculateEMA(float todaysPrice, float numberOfDays, float EMAYesterday){
-# float k = 2 / (numberOfDays + 1);
-# return todaysPrice * k + EMAYesterday * (1 – k);
-# }
-#
-#We want to calculate for 12 and 26 day windows
View
2 adjustment/macd
@@ -1,2 +0,0 @@
-#Simply deducting the longer (for example the 26 day) Exponential Moving Average from the shorter (eg 12 day) EMA gives you the MACD value. This value oscillates around a zero point, zero being where the 26 day EMA is identical, and therefore (usually) crosses over the 12 day EMA.
-#Once the MACD has been calculated, a 9 day Exponential moving average of the MACD value is then calculated. This value is then plotted and is known as the signal line.
View
41 ai/ann/ann.py
@@ -1,41 +0,0 @@
-import datetime
-import pymysql
-from nnlib import NN
-
-def get_company(symbol):
- conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='limitless')
- cur = conn.cursor()
- cur.execute('SELECT date, volume, high_price, low_price, open_price, close_price, close_adjusted, price_change, short_ema, long_ema, macd, signal_line, histogram FROM company_%(symbol)s' %{'symbol': symbol })
-
- days = datetime.timedelta(22)
- inputs = []
- results = cur.fetchall()
- for company in results:
- for search in results:
- if (search[0]) == (company[0] + days):
- inputs.append([list(company[1:]), [search[2]]])
- # break
- #print str(company[0]) + " : " + str(company[0] + days)
-
- print "START TRAINING"
- n = NN(12, 12, 1)
- n.train(inputs, 2)
- n.weights()
- n.test(inputs)
-
-if __name__ == '__main__':
- get_company('aapl')
- # Teach network XOR function
- pat = [
- [[0,0], [0]],
- [[0,1], [1]],
- [[1,0], [1]],
- [[1,1], [0]]
- ]
-
- # create a network with two input, two hidden, and one output nodes
- n = NN(2, 2, 1)
- # train it with some patterns
- n.train(pat)
- # test it
- n.test(pat)
View
171 ai/ann/nnlib.py
@@ -1,171 +0,0 @@
-# Back-Propagation Neural Networks
-#
-# Written in Python. See http://www.python.org/
-# Placed in the public domain.
-# Neil Schemenauer <nas@arctrix.com>
-
-import math
-import random
-import string
-
-random.seed(0)
-
-# calculate a random number where: a <= rand < b
-def rand(a, b):
- return (b-a)*random.random() + a
-
-# Make a matrix (we could use NumPy to speed this up)
-def makeMatrix(I, J, fill=0.0):
- m = []
- for i in range(I):
- m.append([fill]*J)
- return m
-
-# our sigmoid function, tanh is a little nicer than the standard 1/(1+e^-x)
-def sigmoid(x):
- return math.tanh(x)
-
-# derivative of our sigmoid function, in terms of the output (i.e. y)
-def dsigmoid(y):
- return 1.0 - y**2
-
-class NN:
- def __init__(self, ni, nh, no):
- # number of input, hidden, and output nodes
- self.ni = ni + 1 # +1 for bias node
- self.nh = nh
- self.no = no
-
- # activations for nodes
- self.ai = [1.0]*self.ni
- self.ah = [1.0]*self.nh
- self.ao = [1.0]*self.no
-
- # create weights
- self.wi = makeMatrix(self.ni, self.nh)
- self.wo = makeMatrix(self.nh, self.no)
- # set them to random vaules
- for i in range(self.ni):
- for j in range(self.nh):
- self.wi[i][j] = rand(-0.2, 0.2)
- for j in range(self.nh):
- for k in range(self.no):
- self.wo[j][k] = rand(-2.0, 2.0)
-
- # last change in weights for momentum
- self.ci = makeMatrix(self.ni, self.nh)
- self.co = makeMatrix(self.nh, self.no)
-
- def update(self, inputs):
- if len(inputs) != self.ni-1:
- raise ValueError, 'wrong number of inputs'
-
- # input activations
- for i in range(self.ni-1):
- #self.ai[i] = sigmoid(inputs[i])
- self.ai[i] = inputs[i]
-
- # hidden activations
- for j in range(self.nh):
- summ = 0.0
- for i in range(self.ni):
- summ = summ + self.ai[i] * self.wi[i][j]
- self.ah[j] = sigmoid(summ)
-
- # output activations
- for k in range(self.no):
- summ = 0.0
- for j in range(self.nh):
- summ = summ + self.ah[j] * self.wo[j][k]
- self.ao[k] = sigmoid(summ)
-
- return self.ao[:]
-
-
- def backPropagate(self, targets, N, M):
- if len(targets) != self.no:
- raise ValueError, 'wrong number of target values'
-
- # calculate error terms for output
- output_deltas = [0.0] * self.no
- for k in range(self.no):
- error = targets[k]-self.ao[k]
- output_deltas[k] = dsigmoid(self.ao[k]) * error
-
- # calculate error terms for hidden
- hidden_deltas = [0.0] * self.nh
- for j in range(self.nh):
- error = 0.0
- for k in range(self.no):
- error = error + output_deltas[k]*self.wo[j][k]
- hidden_deltas[j] = dsigmoid(self.ah[j]) * error
-
- # update output weights
- for j in range(self.nh):
- for k in range(self.no):
- change = output_deltas[k]*self.ah[j]
- self.wo[j][k] = self.wo[j][k] + N*change + M*self.co[j][k]
- self.co[j][k] = change
- #print N*change, M*self.co[j][k]
-
- # update input weights
- for i in range(self.ni):
- for j in range(self.nh):
- change = hidden_deltas[j]*self.ai[i]
- self.wi[i][j] = self.wi[i][j] + N*change + M*self.ci[i][j]
- self.ci[i][j] = change
-
- # calculate error
- error = 0.0
- for k in range(len(targets)):
- error = error + 0.5*(targets[k]-self.ao[k])**2
- return error
-
-
- def test(self, patterns):
- for p in patterns:
- print p[0], '->', self.update(p[0])
-
- def weights(self):
- print 'Input weights:'
- for i in range(self.ni):
- print self.wi[i]
- print
- print 'Output weights:'
- for j in range(self.nh):
- print self.wo[j]
-
- def train(self, patterns, iterations=1000, N=0.5, M=0.1):
- # N: learning rate
- # M: momentum factor
- for i in xrange(iterations):
- error = 0.0
- for p in patterns:
- inputs = p[0]
- targets = p[1]
- self.update(inputs)
- error = error + self.backPropagate(targets, N, M)
- if i % 100 == 0:
- pass #print 'error %-14f' % error
-
-
-def demo():
- # Teach network XOR function
- pat = [
- [[0,0], [0]],
- [[0,1], [1]],
- [[1,0], [1]],
- [[1,1], [0]]
- ]
-
- # create a network with two input, two hidden, and one output nodes
- n = NN(2, 2, 1)
- # train it with some patterns
- n.train(pat)
- # test it
- n.test(pat)
-
-
-
-if __name__ == '__main__':
- demo()
View
87 ai/ann/predict_company.py
@@ -0,0 +1,87 @@
+from string import upper
+import datetime
+import os
+import pymysql
+
+def write_to_file(filename, data, name):
+ f = open(filename, 'w')
+ f.write(
+"""@RELATION %s
+@ATTRIBUTE volume NUMERIC
+@ATTRIBUTE high_price NUMERIC
+@ATTRIBUTE low_price NUMERIC
+@ATTRIBUTE open_price NUMERIC
+@ATTRIBUTE close_price NUMERIC
+@ATTRIBUTE close_adjusted NUMERIC
+@ATTRIBUTE price_change NUMERIC
+@ATTRIBUTE short_ema NUMERIC
+@ATTRIBUTE long_ema NUMERIC
+@ATTRIBUTE macd NUMERIC
+@ATTRIBUTE signal_line NUMERIC
+@ATTRIBUTE histogram NUMERIC
+@ATTRIBUTE target_price NUMERIC
+
+@DATA
+""" %(name))
+
+ for line in data:
+ f.write(','.join(str(x) for x in line[0]))
+ f.write(',')
+ f.write(str(line[1]))
+ f.write('\n')
+ f.close()
+
+
+def get_company(symbol):
+ print 'Running ' + symbol + '...'
+ conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='limitless')
+ cur = conn.cursor()
+ cur.execute('SELECT date, volume, high_price, low_price, open_price, close_price, close_adjusted, price_change, short_ema, long_ema, macd, signal_line, histogram FROM company_%(symbol)s ORDER BY date' %{'symbol': upper(symbol) })
+
+ days = datetime.timedelta(22)
+ train_data = []
+ test_data = []
+ dates = []
+ results = cur.fetchall()
+ for company in results:
+ test_data.append([company[1:], company[2]])
+ dates.append(company[0] + days)
+ for search in results:
+ if (search[0]) == (company[0] + days):
+ train_data.append([company[1:], search[2]])
+ break
+
+ filename = "arff/" + symbol + ".train.arff"
+ filename2 = "arff/" + symbol + ".test.arff"
+ write_to_file(filename, train_data, symbol)
+ write_to_file(filename2, test_data, symbol)
+
+ execs = '/usr/bin/jython weka.py arff/' + symbol + '.train.arff arff/' + symbol + '.test.arff'
+ os.system(execs);
+ return dates;
+
+if __name__ == '__main__':
+ for path in ['arff', 'models', 'predictions']
+ if not os.path.isdir(path):
+ os.makedirs(path)
+
+ newpath = 'C:\Program Files\alex'; if not os.path.exists(newpath): os.makedirs(newpath)
+ conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='limitless')
+ cur = conn.cursor()
+ cur.execute('SELECT symbol FROM companies WHERE avg_volume IS NOT NULL')
+
+ results = cur.fetchall()
+ for symbol in results:
+ dates = get_company(symbol[0])
+
+ print "Saving...\n"
+ cur.execute('DROP TABLE IF EXISTS `weka_prediction_company_%s`' %(symbol[0]))
+ cur.execute('CREATE TABLE IF NOT EXISTS `weka_prediction_company_%s` (`date` date NOT NULL, `open_price` float DEFAULT NULL, `close_price` float DEFAULT NULL)' %(symbol[0]))
+
+ file = open("predictions/%s" %(symbol[0]))
+ counter = 0;
+ for line in file:
+ sql = "INSERT INTO `weka_prediction_company_%s`(date, open_price, close_price) VALUES ('%s', 0, %f)" %((symbol[0], dates[counter], float(line)))
+ cur.execute(sql)
+ counter = counter + 1
+ file.close()
View
48 ai/ann/weka.py
@@ -0,0 +1,48 @@
+import sys
+import java.io.FileReader as FileReader
+import java.lang.StringBuffer as StringBuffer
+import java.lang.Boolean as Boolean
+
+import weka.core.Instances as Instances
+import weka.classifiers.trees.J48 as J48
+import weka.classifiers.Evaluation as Evaluation
+import weka.core.Range as Range
+import weka.classifiers.functions.MultilayerPerceptron as MultilayerPerceptron
+import weka.core.SerializationHelper as SerializationHelper
+
+# check commandline parameters
+if (not (len(sys.argv) == 3)):
+ print "Usage: UsingJ48Ext.py <ARFF-file>"
+ sys.exit()
+
+file = FileReader(sys.argv[1])
+file2 = FileReader(sys.argv[2])
+data = Instances(file)
+test = Instances(file2)
+data.setClassIndex(data.numAttributes() - 1)
+test.setClassIndex(test.numAttributes() - 1)
+evaluation = Evaluation(data)
+buffer = StringBuffer()
+attRange = Range() # no additional attributes output
+outputDistribution = Boolean(False) # we don't want distribution
+nn = MultilayerPerceptron()
+nn.buildClassifier(data) # only a trained classifier can be evaluated
+
+#print evaluation.evaluateModel(nn, ['-t', sys.argv[1], '-T', sys.argv[2]])#;, [buffer, attRange, outputDistribution])
+res = evaluation.evaluateModel(nn, test, [buffer, attRange, outputDistribution])
+f = open('predictions/' + data.relationName(), 'w')
+for d in res:
+ f.write(str(d) + '\n');
+f.close()
+
+SerializationHelper.write("models/" + data.relationName() + ".model", nn)
+
+# print out the built model
+#print "--> Generated model:\n"
+#print nn
+
+#print "--> Evaluation:\n"
+#print evaluation.toSummaryString()
+
+#print "--> Predictions:\n"
+#print buffer
View
2 mine/mdata.sql
@@ -18,7 +18,7 @@
--
-- Table structure for table `companies`
--
-
+use limitless;
DROP TABLE IF EXISTS `companies`;
/*!40101 SET @saved_cs_client = @@character_set_client */;
/*!40101 SET character_set_client = utf8 */;
View
2 webservice/index.php
@@ -1,5 +1,5 @@
<?php
-mysql_connect('localhost', 'root', '') or die("mysql connect: " . mysql_error());
+mysql_connect('localhost', 'root', 'denny') or die("mysql connect: " . mysql_error());
mysql_select_db('limitless');
$stock = $_GET['stock'];
$sql = 'select UNIX_TIMESTAMP(date), close_price from company_' . strtoupper($stock) . ' ORDER BY date ASC';
View
4 website/index.html
@@ -10,9 +10,7 @@
<script type="text/javascript">
$(function() {
var stock = 'msft';
- $.getJSON('http://localhost/limitless/webservice/index.php?stock=' + stock, function(data) {
- // $.getJSON('http://172.16.0.8/~minhnguyen/limitless/webservice/index.php?stock=' + stock, function(data) {
- // Create the chart
+ $.getJSON('/limitless/webservice/index.php?stock=' + stock, function(data) {
console.log(data);
window.chart = new Highcharts.StockChart({
chart : {

0 comments on commit c2cd17c

Please sign in to comment.
Something went wrong with that request. Please try again.