In [1]:
import numpy as np
import pandas as pd
import random
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
data = pd.read_csv('data.csv') #Three stocks (R,X_s,X_b) Without predictors

In [3]:
#Initialize Neural Network and set-up the placeholders
tf.reset_default_graph()
NN_input = tf.placeholder(shape=[1,2],dtype=tf.float32)
NN_weights = tf.Variable(tf.random_uniform([2,10],0,0.01))
Q_FA = tf.matmul(NN_input,NN_weights)
A_Max = tf.argmax(Q_FA,1)

# Calculate loss for the NN from the Q values
Q_Next = tf.placeholder(shape=[1,10],dtype=tf.float32)
loss = tf.reduce_sum(tf.square(Q_Next - Q_FA))
trainer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
updateModel = trainer.minimize(loss)

#Define Action Matrix (Now discrete case) 
A = np.linspace(0,1,10) # portfolio weights of stocks (1-weight) is the weight in the bonds

In [5]:
# initialize
init = tf.global_variables_initializer()
gamma = 1 
epsilon = 0.1
train_data = 241 # also an expanding window
jList = []
TWlistTrain = []
TWlist = []
Index = []
MWeights = []
# data parsing
dates = data['Date']
mdata = data[['r','xs','xb']]
mdata.index = pd.DatetimeIndex(dates)
n = data.size/4-4
periods = 60
epochs = 20

with tf.Session() as sess:
    for i in range(241,int(n-periods-1)):
        OptimalWeights = np.zeros(periods-1)
        currentK = 0
        print(i)
        
        while currentK < periods - 1:
            #Initilization
            NN_data = mdata[i-200:i+currentK]  # rolling window
            # NN_data = mdata[0:i+currentK]   #Expanding window
            sess.run(init) # initialize the Neural Network again
            rAll = 0
            currentEpoch = 0
            
            while currentEpoch < epochs:
                #Training of the Q-Network for the data available (with Neural Nets) 
                for j in range(0,len(NN_data)):
                    s = NN_data.iloc[j,1:4].values.reshape(1,2)
                    #Choose an action by greedily (with e chance of random action) from the Q-network
                    a_int,allQ = sess.run([A_Max,Q_FA],feed_dict={NN_input:s})
                    a = A[a_int-1]  # -1 because the output neurons are labeled 1 till 101 and it will be an index
                    if np.random.rand(1) < epsilon:
                        a = random.choice(A)

                    #Get new state and reward from environment
                    s1 = data.iloc[j+i,2:4].values.reshape(1,2)
                    r = (a*s1[0][0] + (1-a)*s1[0][1]) #reward: this is now the wealth gained from this step, but could be other rewards like utility
                    Q = sess.run(Q_FA,feed_dict={NN_input:s1})

                    #Obtain maxQ' and set our target value for chosen action.
                    print('Q:')
                    print(Q)
                    print('allQ:')
                    print(allQ)
                    Q1 = np.max(Q)
                    print('Q1:')
                    print(Q1)
                    targetQ = allQ
                    targetQ[0,a_int] = r + gamma*Q1

                    #Train the neural network using target and predicted Q values
                    _,W1 = sess.run([updateModel,NN_input],feed_dict={NN_input:s,Q_Next:targetQ})
                    rAll += r
                    s = s1
                    if currentEpoch  > 5:
                        # decrease amount of random actions over time in order to improve exploitation rather than exploration
                        # only increase exploitation when a good action has been found (otherwise one exploits a bad solution)
                        e = 1./((i/50) + 10)
                        break
                currentEpoch += 1
        
            # After training now calculate the optimal weights for the K=60 periods to come
            s = mdata.iloc[i+periods,1:4].values.reshape(1,2)
            a_int,allQ = sess.run([A_Max,Q_FA],feed_dict={NN_input:s})
            aOpt = A[a_int-1]
            OptimalWeights[currentK] = aOpt
            currentK += 1
        
        # For insight purposes
        MWeights.append(np.mean(OptimalWeights))
        TWlist.append(np.exp(sum(OptimalWeights*mdata[i+1:i+currentK+1]['xs'] + (1-OptimalWeights)*mdata[i+1:i+currentK+1]['xb'])))
        Index.append(i)
        
        print('Writing away results')
        df = pd.DataFrame({'index date':Index,'TW':TWlist, 'Mean Weights Xs':MWeights})
        df.to_excel('Results_NN_g10_e20.xlsx', sheet_name='sheet1', index=False)
    # close session
plt.plot(MWeights)
plt.plot(TWlist)
        

241
Q:
[[ -5.75205020e-04  -7.75923472e-05  -3.77528573e-04  -2.56071915e-04
   -2.16473461e-04  -4.86345030e-04  -3.16921156e-04  -2.19412323e-04
   -1.33099966e-04  -5.52457874e-04]]
allQ:
[[  5.95199235e-05   1.62809811e-05   3.55584198e-05   8.52110134e-07
    4.43023237e-05   4.94430897e-05   4.76137538e-05   4.35118418e-05
    7.93888830e-06   5.53595382e-05]]
Q1:
-7.75923e-05
Q:
[[-0.00077125 -0.00013672 -0.00049371 -0.00024347 -0.00037704 -0.0006503
  -0.00048424 -0.00037669 -0.000156   -0.00073554]]
allQ:
[[ -5.30241407e-04  -1.36516639e-04  -3.21832282e-04  -3.58189682e-05
   -3.72083770e-04  -4.43107885e-04  -4.09483619e-04  -3.66218563e-04
   -7.74071523e-05  -4.97031608e-04]]
Q1:
-0.000136722
Q:
[[ -3.90181958e-04  -1.26094135e-04  -2.26211734e-04   3.68030451e-05
   -3.41854437e-04  -3.23666085e-04  -3.47486988e-04  -3.34146956e-04
   -3.89662637e-05  -3.60525300e-04]]
allQ:
[[-0.00065287 -0.00015634 -0.00040112 -0.00010056 -0.00042695 -0.00054668
  -0.00048303 -0.0004212

allQ:
[[  5.03914955e-04   1.62244891e-04   3.57161014e-04   6.64056715e-05
    4.41429700e-04   4.96287423e-04   4.76607063e-04   4.34936257e-04
    9.87039093e-05   5.97604667e-04]]
Q1:
0.000614192
Q:
[[ -4.68744656e-05   1.67365761e-06  -3.53133801e-05  -4.71923886e-05
    3.25904603e-06  -4.18924428e-05  -1.20568420e-05   2.09569112e-06
   -1.69524792e-05  -4.52450477e-05]]
allQ:
[[  2.69616023e-04   5.41859881e-05   1.95163098e-04   1.15335264e-04
    1.49946078e-04   2.57221749e-04   1.92201405e-04   1.49910687e-04
    6.79311779e-05   3.14212055e-04]]
Q1:
3.25905e-06
Q:
[[ -5.97164326e-05  -5.56453779e-05  -3.77853430e-05   8.12238577e-05
   -1.48585736e-04  -6.80932426e-05  -1.26593222e-04  -1.43977129e-04
    5.18319575e-06  -1.06481049e-04]]
allQ:
[[ -4.42306977e-04  -1.35105904e-04  -3.14405363e-04  -7.61532283e-05
   -3.68154841e-04  -4.33750858e-04  -4.04277671e-04  -3.63225176e-04
   -9.00215309e-05  -5.54470753e-04]]
Q1:
8.12239e-05
Q:
[[ 0.00046014  0.00010112  0.000332

Q:
[[ -2.85376504e-04  -7.99061600e-05  -2.03759308e-04  -6.91884416e-05
   -2.19915091e-04  -2.77960557e-04  -2.46641721e-04  -2.15641470e-04
   -3.58060061e-06  -4.36298025e-04]]
allQ:
[[ -1.22982869e-03  -2.86588474e-04  -8.85164074e-04  -4.41230368e-04
   -7.87473284e-04  -1.18252193e-03  -9.49373934e-04  -7.79777358e-04
   -7.45746220e-05  -1.78114371e-03]]
Q1:
-3.5806e-06
Q:
[[  1.10482296e-03   3.19396146e-04   7.87617464e-04   2.42988346e-04
    8.79251689e-04   1.07878074e-03   9.74599912e-04   8.60842702e-04
    7.20611715e-05   1.70633465e-03]]
allQ:
[[ -5.39544621e-04  -1.80021190e-04  -3.81695834e-04  -5.91206299e-05
   -4.96082765e-04  -5.33214537e-04  -5.23200375e-04  -4.82630450e-04
   -1.44299138e-05  -8.74532969e-04]]
Q1:
0.00170633
Q:
[[  3.57141107e-04   1.11304522e-04   2.53617007e-04   5.85921916e-05
    3.06576374e-04   3.50863615e-04   3.30881187e-04   2.99130072e-04
   -4.40665317e-05   5.65403956e-04]]
allQ:
[[ -4.37237206e-04  -1.45576414e-04  -3.09357274e-04

  -0.00082141 -0.00059602 -0.00021003 -0.00165408]]
allQ:
[[ -2.12528059e-04  -6.26612746e-05  -1.51354485e-04  -4.68050966e-05
   -1.74817484e-04  -2.07872072e-04  -1.90049323e-04  -1.68979313e-04
   -3.38734571e-06  -3.36049270e-04]]
Q1:
-0.000210031
Q:
[[ 0.00072457  0.00013935  0.00052472  0.00033834  0.0003821   0.00068858
   0.00050218  0.0003843   0.00012016  0.00098133]]
allQ:
[[  2.60806148e-04   7.84417789e-05   1.85555153e-04   5.37157175e-05
    2.18982328e-04   2.55511259e-04   2.36255510e-04   2.11357677e-04
    1.39116592e-05   4.15808317e-04]]
Q1:
0.000981327
Q:
[[  1.27234060e-04   4.66244746e-05   8.95438425e-05   6.09220842e-06
    1.30896442e-04   1.26914063e-04   1.31655877e-04   1.24689177e-04
   -1.88211379e-06   2.27054406e-04]]
allQ:
[[  3.89069406e-04   1.15775467e-04   2.76956038e-04   8.31260477e-05
    3.23095388e-04   3.80833721e-04   3.50004470e-04   3.12091172e-04
    2.20434449e-05   6.31718198e-04]]
Q1:
0.000227054
Q:
[[  2.36872016e-04   1.41996992e-0

    6.75158954e-05   6.34887547e-04]]
Q1:
4.99221e-05
Q:
[[ -8.85167858e-04  -2.32643026e-04  -6.32519193e-04  -2.28768506e-04
   -6.93554175e-04  -8.61815759e-04  -7.60543277e-04  -6.62853650e-04
    5.10266618e-05  -1.64181914e-03]]
allQ:
[[ -6.39180944e-05  -9.98268588e-05  -3.54189106e-05   1.86792080e-04
   -3.06773814e-04  -8.43728776e-05  -2.19047113e-04  -2.64061440e-04
    1.22074984e-04  -4.06597799e-04]]
Q1:
5.10267e-05
Q:
[[ -7.45477329e-04  -2.87059054e-04  -5.21443668e-04   3.08791641e-07
   -8.65844486e-04  -7.50112173e-04  -8.20664864e-04  -7.95541564e-04
    1.72917615e-04  -1.69887021e-03]]
allQ:
[[  4.14298993e-04   1.16414252e-04   2.95118196e-04   9.90568660e-05
    3.47884896e-04   4.05376282e-04   3.70847672e-04   3.29844625e-04
   -3.46152228e-05   7.94558146e-04]]
Q1:
0.000172918
Q:
[[ 0.00139142  0.00039005  0.00099126  0.0003348   0.00116551  0.0013612
   0.00124366  0.00110538 -0.00011494  0.00263115]]
allQ:
[[ -1.32738787e-05  -3.89557608e-05  -5.10437530e-

4.01853e-05
Q:
[[ -3.36498561e-05   1.60778964e-05  -2.70739474e-05  -7.14387133e-05
    5.25821233e-05  -2.62910507e-05   1.92274983e-05   3.82234648e-05
   -4.56642083e-05   4.69103761e-05]]
allQ:
[[  7.04621547e-04   8.03135990e-05   5.15346008e-04   4.12980007e-04
    2.52715836e-04   6.61216327e-04   4.19568940e-04   2.82732304e-04
   -2.33225510e-05   1.06660393e-03]]
Q1:
5.25821e-05
Q:
[[ -1.56373673e-04  -8.36548406e-06  -1.15538460e-04  -1.16214498e-04
   -2.53771977e-05  -1.44230304e-04  -7.44638819e-05  -3.81776481e-05
   -1.70421063e-05  -1.90695951e-04]]
allQ:
[[ -5.26542390e-05  -2.12657524e-05  -3.66221939e-05   8.78143419e-06
   -6.84414263e-05  -5.34625324e-05  -6.14504315e-05  -6.07769180e-05
    3.75995223e-05  -1.54871348e-04]]
Q1:
-8.36548e-06
Q:
[[ 0.000533    0.00010369  0.00038452  0.00020088  0.00033057  0.00051157
   0.00040204  0.0003254  -0.00011802  0.00101942]]
allQ:
[[ -2.36863212e-04  -1.06383945e-04  -1.63417542e-04   6.73454851e-05
   -3.42687243e-04  

Q:
[[  8.00747439e-05  -2.46110103e-05   6.27130721e-05   1.24777347e-04
   -8.30313802e-05   6.63750980e-05  -1.77836773e-05  -5.41309055e-05
    5.87231843e-05  -6.55634940e-05]]
allQ:
[[  9.63702551e-05   3.25162691e-05   6.78391079e-05   7.24007987e-06
    1.06345928e-04   9.61761907e-05   9.98662144e-05   9.46232976e-05
   -3.30426192e-05   2.87013478e-04]]
Q1:
0.000124777
Q:
[[  8.34340521e-04   1.03508399e-04   6.09204173e-04   4.72147542e-04
    3.29771603e-04   7.85982702e-04   5.17196255e-04   3.61506973e-04
    1.10543697e-05   1.43850036e-03]]
allQ:
[[  3.46821616e-04   9.27735600e-05   2.47122400e-04   8.18320477e-05
    3.02227214e-04   3.39764927e-04   3.12079763e-04   2.78186897e-04
   -7.84416770e-05   8.91925767e-04]]
Q1:
0.0014385
Q:
[[  1.27121530e-04   2.53228536e-05   9.16454446e-05   4.99645394e-05
    8.19551860e-05   1.22258454e-04   9.74436043e-05   7.96413078e-05
   -1.42600547e-05   2.80138513e-04]]
allQ:
[[  5.47345553e-06   1.15211897e-05   2.66404277e-06 

allQ:
[[ -9.73160859e-05  -3.04173227e-05  -6.88245564e-05  -1.27172289e-05
   -9.91746710e-05  -9.65212821e-05  -9.61767655e-05  -8.93871038e-05
    3.09766292e-05  -3.12296062e-04]]
Q1:
0.00128321
Q:
[[ -7.61422212e-04  -7.33880879e-05  -5.57887251e-04  -4.81539231e-04
   -2.30021411e-04  -7.11174041e-04  -4.31605644e-04  -2.77007988e-04
   -3.94735689e-05  -1.22490106e-03]]
allQ:
[[  2.50735320e-04   7.04799968e-05   1.78256203e-04   5.10791506e-05
    2.29353900e-04   2.46576965e-04   2.32417238e-04   2.10059443e-04
   -6.95126873e-05   7.46219303e-04]]
Q1:
-3.94736e-05
Q:
[[ -3.50336311e-04  -8.31611251e-05  -2.50869984e-04  -1.06917360e-04
   -2.69662589e-04  -3.40429222e-04  -2.94882309e-04  -2.54201266e-04
    6.98387739e-05  -9.20621795e-04]]
allQ:
[[ -6.84024199e-05  -1.74527686e-05  -4.88386322e-05  -1.80537209e-05
   -5.66832932e-05  -6.67932763e-05  -5.99452833e-05  -5.27518678e-05
    1.58018138e-05  -1.88634454e-04]]
Q1:
6.98388e-05
Q:
[[ -5.03702497e-04  -1.44994163e-04

    2.68417090e-04  -1.62464811e-03]]
allQ:
[[ -7.57634843e-05  -3.14932322e-06  -5.59345426e-05  -5.72759745e-05
   -8.78677565e-06  -6.95723866e-05  -3.48159083e-05  -1.68960687e-05
   -3.22909091e-06  -9.32068797e-05]]
Q1:
0.000268417
Q:
[[ -5.57609674e-05  -4.56115813e-05  -3.61476414e-05   5.81574459e-05
   -1.50748921e-04  -6.28745984e-05  -1.10027904e-04  -1.23488862e-04
    1.01788275e-04  -4.43640834e-04]]
allQ:
[[ -5.26941003e-05  -7.64508331e-06  -3.82683793e-05  -2.71549197e-05
   -2.43507493e-05  -4.98634399e-05  -3.48846952e-05  -2.57902429e-05
    1.09221946e-05  -1.12237212e-04]]
Q1:
0.000101788
Q:
[[ -5.17727574e-04  -6.18600316e-05  -3.77528573e-04  -2.97495222e-04
   -1.95101675e-04  -4.86345030e-04  -3.16921156e-04  -2.19412323e-04
    7.50230392e-05  -9.87990526e-04]]
allQ:
[[  5.01432260e-05   1.46742068e-05   3.55584198e-05   8.70518943e-06
    4.78180809e-05   4.94430897e-05   4.76137538e-05   4.35118418e-05
   -2.81349130e-05   1.70868225e-04]]
Q1:
7.5023e-05
Q

Q1:
0.000368693
Q:
[[ -2.85197602e-04  -7.09960877e-05  -2.03759308e-04  -7.12535402e-05
   -2.30511811e-04  -2.77960557e-04  -2.46641721e-04  -2.15641470e-04
    1.31464316e-04  -9.60775826e-04]]
allQ:
[[-0.0012293  -0.00024753 -0.00088516 -0.00045284 -0.0007989  -0.00118252
  -0.00094937 -0.00077978  0.00045729 -0.00355004]]
Q1:
0.000131464
Q:
[[ 0.00110409  0.00028502  0.00078762  0.00025051  0.00092623  0.00107878
   0.0009746   0.00086084 -0.00045837  0.00382225]]
allQ:
[[ -5.39084314e-04  -1.63504083e-04  -3.81695834e-04  -6.16688703e-05
   -5.33274666e-04  -5.33214537e-04  -5.23200375e-04  -4.82630450e-04
    2.65401381e-04  -2.11234624e-03]]
Q1:
0.00382225
Q:
[[  3.56869568e-04   1.00282145e-04   2.53617007e-04   6.06470348e-05
    3.26537905e-04   3.50863615e-04   3.30881187e-04   2.99130072e-04
   -2.23614421e-04   1.31792191e-03]]
allQ:
[[ -4.36865521e-04  -1.32187750e-04  -3.09357274e-04  -5.07563091e-05
   -4.31114138e-04  -4.32025234e-04  -4.23383899e-04  -3.90313857e-04


KeyboardInterrupt: 

In [None]:
plt.plot(MWeights)
plt.plot(TWlist)

In [25]:
print(mdata[i+1:i+currentK+1]['xs'])
print(optimalweights.size)
TWlist.append(np.exp(sum(optimalweights*mdata[i+1:i+currentK+1]['xs'] + (1-optimalweights)*mdata[i+1:i+currentK+1]['xb'])))


Date
1974-04-01   -0.052899
1974-05-01   -0.048840
1974-06-01   -0.029525
1974-07-01   -0.077171
1974-08-01   -0.095223
1974-09-01   -0.116435
1974-10-01    0.159368
1974-11-01   -0.047272
1974-12-01   -0.032963
1975-01-01    0.136383
1975-02-01    0.053828
1975-03-01    0.025616
1975-04-01    0.041822
1975-05-01    0.050782
1975-06-01    0.047023
1975-07-01   -0.065493
1975-08-01   -0.028924
1975-09-01   -0.043360
1975-10-01    0.050890
1975-11-01    0.026656
1975-12-01   -0.015531
1976-01-01    0.121953
1976-02-01    0.003247
1976-03-01    0.022632
1976-04-01   -0.014439
1976-05-01   -0.013702
1976-06-01    0.039979
1976-07-01   -0.010515
1976-08-01   -0.005718
1976-09-01    0.020251
1976-10-01   -0.024473
1976-11-01    0.001406
1976-12-01    0.057966
1977-01-01   -0.040187
1977-02-01   -0.019652
1977-03-01   -0.013079
1977-04-01    0.001169
1977-05-01   -0.014918
1977-06-01    0.047184
1977-07-01   -0.017135
1977-08-01   -0.017918
1977-09-01   -0.003268
1977-10-01   -0.044413
1977-1

ValueError: operands could not be broadcast together with shapes (60,) (59,) 