In [23]:
from bson import ObjectId

from trading.db import get_database, transform_son
from trading.classifier.random_forest import RFClassifier

In [31]:
def get_chart_data(db, chart_id):
    """Fetch one chart document from ``db.candle_data`` by its id.

    ``chart_id`` is wrapped in ``ObjectId`` to build the ``_id`` query, and
    whatever ``find_one`` returns is passed through ``transform_son`` before
    being handed back to the caller.
    """
    query = {'_id': ObjectId(chart_id)}
    document = db.candle_data.find_one(query)
    return transform_son(document)

In [25]:
def get_random_forest_classifier():
    """Build a new ``RFClassifier`` whose config has ``classifier_id`` set to ``None``."""
    return RFClassifier({'classifier_id': None})

In [26]:
def save_random_forest_classifier(db, rf):
    """Serialize ``rf`` and store it in ``db.classifiers`` under a fresh id.

    The payload returned by ``rf.serialize()`` is tagged with
    ``flags = ['pattern']`` and given a newly generated ``ObjectId`` as its
    ``_id`` before it is inserted.

    Args:
        db: database handle exposing a ``classifiers`` collection.
        rf: classifier object providing ``serialize()``.

    Returns:
        The generated ``ObjectId`` of the stored classifier, so callers can
        reference it later instead of copying it from the printed output.
    """
    classifier_id = ObjectId()

    serialized_classifier = rf.serialize()
    serialized_classifier['flags'] = ['pattern']
    serialized_classifier['_id'] = classifier_id

    # NOTE(review): if `db` is a pymongo database, `insert` is deprecated in
    # pymongo 3 in favour of `insert_one` -- confirm the driver version before
    # switching.
    db.classifiers.insert(serialized_classifier)

    # Single-argument call form: prints the same text on Python 2 and is
    # valid syntax on Python 3 (the bare `print ...` statement is not).
    print('Inserted Classifier Id {classifier_id}'.format(classifier_id=classifier_id))
    return classifier_id

In [68]:
def prepare_candle_data(candles, pattern):
    """Turn candle ticks into a feature matrix and a matching label list.

    Args:
        candles: iterable of tick dicts, each holding a dict under
            ``'candle'`` with ``close``, ``open``, ``high`` and ``low`` keys.
        pattern: accepted for interface compatibility; it is not used when
            building the features or labels.

    Returns:
        ``(X, y)`` where ``X[i]`` is ``[close, open, high, low]`` for tick
        ``i`` and ``y[i]`` is that tick's ``'pattern'`` label, or the string
        ``'None'`` when the tick carries no label.

    Raises:
        KeyError: if a tick has no ``'candle'`` entry or the candle lacks one
            of the four feature keys.
    """
    features = ('close', 'open', 'high', 'low')

    X = []
    y = []

    for tick_data in candles:
        candle_data = tick_data['candle']
        X.append([candle_data[feature] for feature in features])

        # The label used to be tested for on the tick but read from the
        # nested candle, which raised KeyError (label only on the tick) or
        # silently dropped it (label only on the candle). Look in both
        # places; the candle wins when both carry one, matching the value
        # that was returned before in that case.
        if 'pattern' in candle_data:
            y.append(candle_data['pattern'])
        elif 'pattern' in tick_data:
            y.append(tick_data['pattern'])
        else:
            y.append('None')

    return X, y

In [69]:
def train_pattern_classifier(target_chart_id, pattern='buy', limit=2):
    """Train a random-forest classifier on one chart's candles and store it.

    Loads the chart identified by ``target_chart_id``, builds features and
    labels from its ``'candles'`` list, trains a new classifier on them and
    saves it through ``save_random_forest_classifier``.

    Args:
        target_chart_id: id of the chart document to train on.
        pattern: forwarded to ``prepare_candle_data``. Defaults to ``'buy'``,
            the value that was previously hard-coded.
        limit: number of leading candles used for training. Defaults to 2,
            the previously hard-coded slice (NOTE(review): that looks like a
            debugging leftover -- confirm). Pass ``None`` to use every candle.
    """
    db = get_database()
    chart_data = get_chart_data(db, target_chart_id)
    classifier = get_random_forest_classifier()

    candles = chart_data['candles']
    # Interpolate the count; passing it as a second print argument printed a
    # tuple on Python 2 and left the literal "%s" in the output on Python 3.
    print('Found %s data points' % len(candles))

    X, y = prepare_candle_data(candles[0:limit], pattern)
    classifier.train(X, y)
    save_random_forest_classifier(db, classifier)


In [57]:
# Id (as a hex string) of the chart document passed to train_pattern_classifier.
target_chart_id = '572e06a17f9b5e6f45a0c8c6'

In [53]:
# Train on the target chart and store the resulting classifier in the database.
train_pattern_classifier(target_chart_id)

{u'date': {u'utc': 1420117200.0, u'hour': 13, u'month': 1, u'year': 2015, u'day': 1, u'minute': 0}, u'candle': {u'high': 1.20977, u'close': 1.20962, u'open': 1.20965, u'low': 1.20962}, u'id': 1462634008.363468}


In [73]:
# Module-level database handle used by the exploratory cells that follow.
db = get_database()

In [76]:
# find_one() with no filter: takes a single document from candle_data, which is
# not necessarily the chart identified by target_chart_id.
chart = db.candle_data.find_one()

In [78]:

def make_date_id_map(candles):
    """Index candles by a ``year-month-day-hour-minute`` string key.

    Each candle's ``'date'`` dict supplies the five components, joined with
    ``'-'`` without zero-padding. The value stored is the candle dict itself;
    when two candles share a key, the later one replaces the earlier one.
    """
    key_parts = ('year', 'month', 'day', 'hour', 'minute')

    def date_key(entry):
        stamp = entry['date']
        return '-'.join(str(stamp[part]) for part in key_parts)

    return {date_key(entry): entry for entry in candles}

In [79]:
# The chart document's list of candle records.
candles = chart['candles']

In [80]:
# Number of candle records in the chart.
len(candles)

31017

In [83]:
# Index the candles by their year-month-day-hour-minute key.
id_map = make_date_id_map(candles)

In [85]:
# Matches len(candles) only when no two candles share the same date key,
# since a repeated key overwrites the earlier entry in the map.
len(id_map)

31017

In [88]:
# Same count as len(id_map); the .keys() call does not change the result.
len(id_map.keys())

31017