In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from collections import namedtuple
from random import shuffle
from itertools import groupby
import datetime as dt

In [3]:
import gensim.models.doc2vec as d2v
import numpy as np
import cvxpy as cvx

In [4]:
from articles import Articles
from helper.collections import Namedtuples
import model.utility as ut
import model.problem as pr
from sp500 import get_sp500_records

In [15]:
# Fetch articles from db, then fetch sp500 records corresponding to the articles dates.
# Weed out articles out of business hours.
articles = list(Articles(['id','date']))
articles = Namedtuples(articles)
articles = sorted(articles,key=lambda a: a.date)
sp500 = get_sp500_records(beg_date=articles[0].date,end_date=articles[-1].date)
articles = [a for a in articles if a.date >= sp500[0].beg_date and a.date < sp500[-1].end_date]

In [16]:
model = d2v.Doc2Vec.load('300model')

In [30]:
sp500[0]._asdict()['beg_date']

datetime.datetime(2014, 1, 2, 16, 0, 0, 1, tzinfo=<DstTzInfo 'America/New_York' LMT-1 day, 19:04:00 STD>)

In [31]:
Sample = namedtuple('Sample',['beg_date','end_date','data','logreturn'])
def create_sample(record):
    articles_that_day = [a for a in articles if
                         a.date >= record.beg_date and a.date <= record.end_date]
    ids = [str(a.id) for a in articles_that_day]
    data = np.mean(model.docvecs[ids],axis=0)
    return Sample(data=data,**record._asdict())
samples = [create_sample(record) for record in sp500]

In [17]:
Sample = namedtuple('Sample',['date','data','logreturn'])

def match(record,article_date):
    return article_date > record.beg_date and article_date <= record.end_date

def group_date(article_date):
    if article_date.time

def create_sample(date,daily_articles):
    ids = [str(a.id) for a in daily_articles]
    data = np.mean(model.docvecs[ids],axis=0)                                     # Aggregate method.
    # logreturn = next(record.logreturn for record in sp500 if record.date >= date) # sp500 already sorted!
    logreturn = next(record.logreturn for record in sp500 if match(record,date))
    return Sample(date=date,data=data,logreturn=logreturn)

samples = Namedtuples([create_sample(*grp) for grp in groupby(articles,key=lambda article: article.date)])
# shuffle(samples)


In [105]:
train_size = int(0.8 * len(samples))
train_samples = samples[:train_size]
test_samples = samples[train_size:]

From this point on we want to compute the $\hat q$ for our problem, and see how it compares with a test set. 

In [106]:
β = 1
r_threshold = 60
u = ut.LinearPlateauUtility(β,r_threshold)

In [128]:
def add_bias(X):
    n,p = X.shape
    bias = np.ones(n)
    res = np.empty(shape=(n,p+1))
    res[:,1:] = X
    res[:,0] = bias
    return res

x = add_bias(np.array(train_samples.datas))
r = np.array(train_samples.logreturns)

x_test = add_bias(np.array(test_samples.datas))
r_test = np.array(test_samples.logreturns)

In [0]:
2+3

In [138]:
problem = pr.Problem(x,r,λ=1,u=u)
problem.solver = cvx.CVXOPT
problem.solve()

KeyboardInterrupt: 

In [130]:
problem.outsample_cost(x_test,r_test)

0.045054046762413688

In [131]:
np.linalg.norm(problem.q)

303.34697075381393

In [132]:
problem.q

array([  2.93751051e+02,  -1.96597244e-01,  -6.50896597e-01,
        -1.89605620e+00,   6.53254553e+00,   7.50113317e+00,
        -2.88062032e+00,   1.91028488e+00,  -1.74966519e-01,
         7.92277860e+00,   6.36992258e-01,  -2.88137297e+00,
        -7.73442808e+00,   2.48595134e+00,  -9.97884146e+00,
        -6.94288716e-02,   1.82848104e+00,   8.64331864e-01,
        -6.13194719e-01,  -4.67596980e+00,  -4.06346578e+00,
         1.43182442e+00,  -5.80246286e+00,  -3.42225554e+00,
         4.19559403e+00,   8.69184576e-02,   1.48193546e+00,
         3.99227405e+00,  -1.33997772e+00,  -4.02666588e+00,
         6.43198381e+00,   1.29840965e+00,  -5.18256348e+00,
         4.84210189e+00,  -2.52523795e+00,   5.38527929e+00,
         2.65678202e+00,  -3.35771677e+00,  -2.75850241e+00,
        -2.62686961e+00,   2.17131733e+00,   7.45715303e+00,
        -1.69901721e+00,   2.04486786e+00,   7.48859202e+00,
         1.12060230e+00,   1.18295730e+00,  -4.02402646e+00,
        -3.29064085e+00,