In [3]:
! pip install finance-datareader

Collecting finance-datareader
  Downloading finance_datareader-0.9.31-py3-none-any.whl (17 kB)
Collecting requests-file
  Downloading requests_file-1.5.1-py2.py3-none-any.whl (3.7 kB)
Installing collected packages: requests-file, finance-datareader
Successfully installed finance-datareader-0.9.31 requests-file-1.5.1


In [5]:
import FinanceDataReader as fdr
from tqdm import tqdm

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 
import os
import itertools
import random

from sklearn.preprocessing import MinMaxScaler, StandardScaler

import tensorflow as tf
from tensorflow import keras
from keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense,  GRU, Concatenate, Dropout, LSTM, SimpleRNN, TimeDistributed
from tensorflow.keras import Sequential, Input, Model, layers, optimizers
from tensorflow.keras.optimizers import SGD, Adam

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
path = '/content/drive/MyDrive/주식 종가 예측/open/'
list_name = 'stock_list.csv'
stock_list = pd.read_csv(os.path.join(path,list_name))
stock_list['종목코드'] = stock_list['종목코드'].apply(lambda x : str(x).zfill(6))
stock_list

Unnamed: 0,종목명,종목코드,상장시장
0,삼성전자,005930,KOSPI
1,SK하이닉스,000660,KOSPI
2,NAVER,035420,KOSPI
3,카카오,035720,KOSPI
4,삼성바이오로직스,207940,KOSPI
...,...,...,...
365,맘스터치,220630,KOSDAQ
366,다날,064260,KOSDAQ
367,제이시스메디칼,287410,KOSDAQ
368,크리스에프앤씨,110790,KOSDAQ


In [8]:
start_date = '20210104'
end_date = '20211105'

start_weekday = pd.to_datetime(start_date).weekday()
max_weeknum = pd.to_datetime(end_date).strftime('%V')
business_days = pd.DataFrame(pd.date_range(start_date,end_date,freq='B'), columns = ['Date'])

stock_code = stock_list.loc[1,'종목코드'] # 삼성전자

stock_price = fdr.DataReader(stock_code, start = start_date, end = end_date)[['Close']].reset_index()
stock_price = pd.merge(business_days, stock_price, how = 'outer')
stock_price['weekday'] = stock_price.Date.apply(lambda x : x.weekday())
stock_price['weeknum'] = stock_price.Date.apply(lambda x : x.strftime('%V'))
stock_price.Close = stock_price.Close.ffill()
stock_price = pd.pivot_table(data = stock_price, values = 'Close', columns = 'weekday', index = 'weeknum')

stock_price.head()

weekday,0,1,2,3,4
weeknum,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1,126000.0,130500.0,131000.0,134500.0,138000.0
2,133000.0,129000.0,133000.0,130500.0,127500.0
3,130000.0,130500.0,130500.0,131500.0,128500.0
4,135000.0,129000.0,128500.0,123000.0,122500.0
5,125000.0,130000.0,130000.0,125000.0,127500.0


In [9]:
X_train = stock_price.iloc[0:-12].to_numpy()
X_val = stock_price.iloc[-12:-2].to_numpy()
X_test = stock_price.iloc[-2].to_numpy()

In [10]:
y_train = stock_price.iloc[1:-11].to_numpy()
y_val = stock_price.iloc[-11:-1].to_numpy()
y_test = stock_price.iloc[-1].to_numpy()

In [11]:
scaler = MinMaxScaler()

X_train_scaled = scaler.fit_transform(X_train.reshape(-1, 1)).reshape(32, 5)
X_val_scaled = scaler.fit_transform(X_val.reshape(-1, 1)).reshape(10, 5)
X_test_scaled = scaler.fit_transform(X_test.reshape(-1, 1)).reshape(1, 5)

y_train_scaled = scaler.fit_transform(y_train.reshape(-1, 1)).reshape(32, 5)
y_val_scaled = scaler.fit_transform(y_val.reshape(-1, 1)).reshape(10, 5)
y_test_scaled = scaler.fit_transform(y_test.reshape(-1, 1)).reshape(1, 5)

In [12]:
X_train_scaled = np.reshape(X_train_scaled, (X_train_scaled.shape[0], X_train_scaled.shape[1]))
X_val_scaled = np.reshape(X_val_scaled, (X_val_scaled.shape[0], X_val_scaled.shape[1]))
X_test_scaled = np.reshape(X_test_scaled, (X_test_scaled.shape[0], X_test_scaled.shape[1]))

In [13]:
def LSTM_model(X_train, y_train, X_val, y_val):

  model = Sequential()
  model.add(LSTM(units = 16, input_shape = (X_train.shape[1], 1), activation = 'tanh', return_sequences = True))
  model.add(LSTM(units = 16, activation = 'tanh', return_sequences = True))
  model.add(LSTM(units = 16, activation = 'tanh', return_sequences = True))
  model.add(TimeDistributed(Dense(units = 1)))

  model.compile(optimizer = Adam(learning_rate = 0.01), loss = 'mean_absolute_error', metrics=['mae']) 

  es = EarlyStopping(monitor ='val_loss', mode ='min', verbose = 1, patience = 5)
  model_save_folder_path = '/content/drive/MyDrive/주식 종가 예측/'
  model_path = model_save_folder_path + 'lstm_best_model.h5'
  mc = ModelCheckpoint(model_path, monitor ='val_loss', mode ='min', verbose = 1, save_best_only = True)

  model.fit(X_train, y_train, epochs = 100, batch_size = 2, verbose = 1, shuffle = False, callbacks = [es, mc], validation_data=(X_val, y_val))

  return model

In [14]:
model = LSTM_model(X_train_scaled, y_train_scaled, X_val_scaled, y_val_scaled)

Epoch 1/100
Epoch 00001: val_loss improved from inf to 0.37193, saving model to /content/drive/MyDrive/주식 종가 예측/lstm_best_model.h5
Epoch 2/100
Epoch 00002: val_loss improved from 0.37193 to 0.27146, saving model to /content/drive/MyDrive/주식 종가 예측/lstm_best_model.h5
Epoch 3/100
Epoch 00003: val_loss improved from 0.27146 to 0.26621, saving model to /content/drive/MyDrive/주식 종가 예측/lstm_best_model.h5
Epoch 4/100
Epoch 00004: val_loss did not improve from 0.26621
Epoch 5/100
Epoch 00005: val_loss improved from 0.26621 to 0.25732, saving model to /content/drive/MyDrive/주식 종가 예측/lstm_best_model.h5
Epoch 6/100
Epoch 00006: val_loss did not improve from 0.25732
Epoch 7/100
Epoch 00007: val_loss improved from 0.25732 to 0.25385, saving model to /content/drive/MyDrive/주식 종가 예측/lstm_best_model.h5
Epoch 8/100
Epoch 00008: val_loss did not improve from 0.25385
Epoch 9/100
Epoch 00009: val_loss improved from 0.25385 to 0.25055, saving model to /content/drive/MyDrive/주식 종가 예측/lstm_best_model.h5
Epoch

In [15]:
model = keras.models.load_model("/content/drive/MyDrive/주식 종가 예측/lstm_best_model.h5")

In [51]:
y_pred = model.predict(X_test_scaled)

In [52]:
y_pred = y_pred.reshape(1,5)

In [53]:
y_pred = scaler.inverse_transform(y_pred)

In [54]:
y_true = scaler.inverse_transform(y_test_scaled)

In [20]:
np.mean(np.abs(y_true - y_pred) / y_true) * 100

0.6907755684417394

In [55]:
submission_name = 'sample_submission.csv'
submission = pd.read_csv(os.path.join(path,submission_name))

In [104]:
def LSTM_model(X_train, y_train, X_val, y_val):

  model = Sequential()
  model.add(LSTM(units = 16, input_shape = (X_train.shape[1], 1), activation = 'tanh', return_sequences = True))
  model.add(Dropout(rate = 0.2))
  model.add(LSTM(units = 16, activation = 'tanh', return_sequences = True))
  model.add(Dropout(rate = 0.2))
  model.add(LSTM(units = 16, activation = 'tanh', return_sequences = True))
  model.add(Dropout(rate = 0.2))
  model.add(TimeDistributed(Dense(units = 1)))

  model.compile(optimizer = Adam(learning_rate = 0.01), loss = 'mean_absolute_error', metrics=['mae']) 

  es = EarlyStopping(monitor ='val_loss', mode ='min', verbose = 1, patience = 5)
  model_save_folder_path = '/content/drive/MyDrive/주식 종가 예측/lstm'
  model_path = model_save_folder_path + f'{code}.h5'
  mc = ModelCheckpoint(model_path, monitor ='val_loss', mode ='min', verbose = 100, save_best_only = True)

  model.fit(X_train, y_train, epochs = 500, batch_size = 2, verbose = 1, shuffle = False, callbacks = [es, mc], validation_data=(X_val, y_val))

  return model

In [105]:
submission_name = 'sample_submission.csv'
submission = pd.read_csv(os.path.join(path,submission_name))

In [106]:
for code in tqdm(stock_list['종목코드'].values):
  
  # 해당 종목의 데이터 가져온 후 주 단위 데이터로 나타내기
  stock_price = fdr.DataReader(code, start = start_date, end = end_date)[['Close']].reset_index()
  stock_price = pd.merge(business_days, stock_price, how = 'outer')
  stock_price['weekday'] = stock_price.Date.apply(lambda x : x.weekday())
  stock_price['weeknum'] = stock_price.Date.apply(lambda x : x.strftime('%V'))
  stock_price.Close = stock_price.Close.ffill()
  stock_price = pd.pivot_table(data = stock_price, values = 'Close', columns = 'weekday', index = 'weeknum')

  X_train = stock_price.iloc[0:-12].to_numpy()
  X_val = stock_price.iloc[-12:-2].to_numpy()
  X_test = stock_price.iloc[-2].to_numpy()

  y_train = stock_price.iloc[1:-11].to_numpy()
  y_val = stock_price.iloc[-11:-1].to_numpy()
  y_test = stock_price.iloc[-1].to_numpy()

  # 스케일링 시행
  scaler = MinMaxScaler()

  X_train_scaled = scaler.fit_transform(X_train.reshape(-1, 1)).reshape(32, 5)
  X_val_scaled = scaler.fit_transform(X_val.reshape(-1, 1)).reshape(10, 5)
  X_test_scaled = scaler.fit_transform(X_test.reshape(-1, 1)).reshape(1, 5)

  y_train_scaled = scaler.fit_transform(y_train.reshape(-1, 1)).reshape(32, 5)
  y_val_scaled = scaler.fit_transform(y_val.reshape(-1, 1)).reshape(10, 5)
  y_test_scaled = scaler.fit_transform(y_test.reshape(-1, 1)).reshape(1, 5)

  X_train_scaled = np.reshape(X_train_scaled, (X_train_scaled.shape[0], X_train_scaled.shape[1]))
  X_val_scaled = np.reshape(X_val_scaled, (X_val_scaled.shape[0], X_val_scaled.shape[1]))
  X_test_scaled = np.reshape(X_test_scaled, (X_test_scaled.shape[0], X_test_scaled.shape[1]))
  
  # 모델 학습
  model = LSTM_model(X_train_scaled, y_train_scaled, X_val_scaled, y_val_scaled)
  model = keras.models.load_model(f"/content/drive/MyDrive/주식 종가 예측/lstm{code}.h5")

  # 예측
  y_pred = model.predict(X_test_scaled)
  y_pred = y_pred.reshape(1,5)
  y_pred = scaler.inverse_transform(y_pred)

  print(y_pred)
    
  for i in range (0, 5):
    submission.loc[i, code] = y_pred[0][i]
    submission.loc[i + 5, code] = y_pred[0][i] # 임시로 채워주기 위한 용도. 최종 예측 시에는 바꿔주어야 함!

Output hidden; open in https://colab.research.google.com to view.

In [107]:
submission.isna().sum().sum()

0

In [109]:
submission.to_csv("/content/drive/MyDrive/주식 종가 예측/submission_lstm3.csv", index = False)

In [110]:
submission

Unnamed: 0,Day,000060,000080,000100,000120,000150,000240,000250,000270,000660,000670,000720,000810,000880,000990,001230,001440,001450,001740,002380,002790,003000,003090,003380,003410,003490,003670,003800,004000,004020,004170,004370,004490,004800,004990,005250,005290,005300,005380,005385,...,272290,273130,278280,278530,282330,285130,287410,290510,290650,292150,293490,293780,294090,294870,298000,298020,298050,298380,299030,299660,299900,307950,314130,316140,319400,319660,321550,323990,326030,330590,330860,336260,336370,347860,348150,348210,352820,357780,363280,950130
0,2021-11-01,28704.029297,34302.140625,60774.695312,146845.484375,111234.132812,17458.220703,49625.878906,85911.820312,106385.132812,693920.0625,51551.265625,232377.96875,32554.511719,57832.320312,16738.251953,2307.753662,26100.445312,5193.036133,325437.625,50854.253906,13532.287109,33258.003906,9223.833008,7995.249023,30694.681641,147419.4375,47462.820312,89884.179688,42975.1875,237255.21875,284290.4375,80531.773438,101149.078125,33661.847656,29205.464844,32957.75,143545.328125,211419.703125,98572.484375,...,39395.90625,107282.945312,281292.90625,12888.289062,166394.078125,177804.515625,7135.273926,7425.48291,34662.75,13442.373047,87294.617188,49796.53125,55435.152344,25240.892578,269757.96875,589742.875,669157.6875,22413.632812,51699.476562,72314.382812,29166.59375,119684.195312,36824.519531,13404.897461,3328.766602,39126.398438,18881.566406,64567.632812,96447.601562,5688.470703,48876.789062,52844.050781,84341.960938,35094.394531,25330.822266,50912.808594,358977.46875,252184.4375,26432.646484,17745.361328
1,2021-11-02,28963.724609,34456.144531,61210.578125,147639.140625,120794.476562,17615.291016,50105.65625,86059.71875,106229.609375,695123.9375,51447.648438,232290.609375,32648.642578,58163.316406,16910.486328,2330.95752,26239.214844,5216.418457,324640.875,50896.792969,13668.485352,33207.1875,9253.379883,7984.702148,30720.070312,147720.703125,47511.03125,88637.84375,43472.277344,237035.421875,283303.59375,80710.015625,101646.609375,33713.929688,29688.085938,33114.363281,143993.453125,212085.6875,98863.273438,...,39471.136719,107156.164062,282753.25,12902.09668,166552.109375,176327.984375,7074.989258,7453.399902,34735.664062,13466.375,88899.804688,49889.605469,55730.878906,25394.972656,268116.90625,593863.4375,673835.125,22520.066406,51938.265625,77680.914062,29788.964844,119488.34375,38892.527344,13405.113281,3337.875244,39176.535156,18852.451172,66213.71875,97959.710938,5674.993164,48799.851562,52924.539062,84726.679688,35211.769531,25589.96875,50969.246094,361372.46875,253801.1875,26657.714844,17832.925781
2,2021-11-03,29196.884766,34504.226562,61119.773438,147754.703125,124673.828125,17509.740234,49938.078125,86101.703125,106131.296875,694010.875,51455.375,232563.859375,32650.679688,58475.90625,16905.728516,2326.994141,26321.220703,5222.967773,324963.21875,50485.417969,13666.895508,33000.945312,9310.296875,7987.53125,30471.552734,147608.859375,47551.34375,86597.40625,43288.519531,235950.359375,283155.15625,80594.34375,101199.195312,33565.898438,29516.167969,33482.558594,144375.53125,213084.53125,99643.78125,...,39821.335938,107113.140625,278475.34375,12902.824219,166636.0625,175815.546875,7391.284668,7455.554199,34875.878906,13482.244141,95975.9375,49975.445312,55896.597656,25470.9375,266867.5,590592.4375,676121.4375,22598.984375,52074.570312,79204.757812,31743.144531,118579.007812,38200.125,13441.694336,3351.780029,39195.945312,18879.042969,66511.453125,97371.859375,5674.041016,48680.667969,53091.375,84563.507812,35542.921875,25573.373047,51558.691406,373327.34375,258959.78125,26783.949219,17926.666016
3,2021-11-04,29360.066406,34496.28125,60898.222656,147321.21875,121688.3125,17044.255859,49940.074219,86114.398438,106386.25,691479.9375,51384.488281,232520.6875,32607.919922,58844.097656,16920.056641,2324.156494,26095.962891,5216.301758,324775.96875,50071.953125,13643.571289,33108.691406,9332.359375,8012.536133,30361.628906,146810.15625,47584.34375,84120.359375,42971.34375,234993.484375,283111.90625,80391.914062,100455.828125,33386.148438,29737.298828,33911.320312,144031.90625,213619.953125,99310.710938,...,40767.417969,107041.9375,273917.4375,12897.216797,166574.328125,175474.625,7246.680664,7434.592285,34851.988281,13500.085938,94046.84375,50054.777344,55989.15625,25289.71875,264613.40625,585577.75,677151.5,22665.03125,52116.628906,77226.28125,31533.128906,117890.648438,38779.425781,13462.295898,3378.971191,39258.902344,18879.146484,66474.71875,96586.945312,5659.055664,48899.515625,53082.699219,83481.820312,35622.675781,25480.732422,52232.378906,368357.78125,258173.71875,26833.1875,17996.013672
4,2021-11-05,29418.212891,34453.789062,60644.917969,146485.453125,118303.609375,16895.970703,49798.625,86114.734375,106526.585938,688765.375,51223.535156,231688.15625,32417.974609,58715.234375,16915.404297,2312.398438,25841.443359,5198.772461,323279.5,49700.996094,13568.760742,32677.484375,9335.970703,8074.314941,30295.746094,145987.546875,47611.269531,85574.835938,42596.890625,234252.859375,283016.5,79979.914062,99701.882812,33234.851562,29537.570312,33916.453125,141906.703125,213308.09375,98631.28125,...,40047.457031,106902.273438,271301.5,12873.130859,166451.734375,174401.015625,7059.459961,7394.550781,34752.027344,13481.767578,94648.0625,50129.984375,56012.867188,25109.578125,260685.4375,576688.0625,677491.25,22711.53125,52004.96875,73877.21875,31715.875,117130.3125,38256.9375,13467.110352,3365.054932,39333.582031,18842.326172,66143.65625,96004.351562,5640.241211,49297.914062,53024.050781,83157.953125,35534.308594,25382.191406,52127.488281,364534.09375,257329.296875,26738.369141,18036.457031
5,2021-11-29,28704.029297,34302.140625,60774.695312,146845.484375,111234.132812,17458.220703,49625.878906,85911.820312,106385.132812,693920.0625,51551.265625,232377.96875,32554.511719,57832.320312,16738.251953,2307.753662,26100.445312,5193.036133,325437.625,50854.253906,13532.287109,33258.003906,9223.833008,7995.249023,30694.681641,147419.4375,47462.820312,89884.179688,42975.1875,237255.21875,284290.4375,80531.773438,101149.078125,33661.847656,29205.464844,32957.75,143545.328125,211419.703125,98572.484375,...,39395.90625,107282.945312,281292.90625,12888.289062,166394.078125,177804.515625,7135.273926,7425.48291,34662.75,13442.373047,87294.617188,49796.53125,55435.152344,25240.892578,269757.96875,589742.875,669157.6875,22413.632812,51699.476562,72314.382812,29166.59375,119684.195312,36824.519531,13404.897461,3328.766602,39126.398438,18881.566406,64567.632812,96447.601562,5688.470703,48876.789062,52844.050781,84341.960938,35094.394531,25330.822266,50912.808594,358977.46875,252184.4375,26432.646484,17745.361328
6,2021-11-30,28963.724609,34456.144531,61210.578125,147639.140625,120794.476562,17615.291016,50105.65625,86059.71875,106229.609375,695123.9375,51447.648438,232290.609375,32648.642578,58163.316406,16910.486328,2330.95752,26239.214844,5216.418457,324640.875,50896.792969,13668.485352,33207.1875,9253.379883,7984.702148,30720.070312,147720.703125,47511.03125,88637.84375,43472.277344,237035.421875,283303.59375,80710.015625,101646.609375,33713.929688,29688.085938,33114.363281,143993.453125,212085.6875,98863.273438,...,39471.136719,107156.164062,282753.25,12902.09668,166552.109375,176327.984375,7074.989258,7453.399902,34735.664062,13466.375,88899.804688,49889.605469,55730.878906,25394.972656,268116.90625,593863.4375,673835.125,22520.066406,51938.265625,77680.914062,29788.964844,119488.34375,38892.527344,13405.113281,3337.875244,39176.535156,18852.451172,66213.71875,97959.710938,5674.993164,48799.851562,52924.539062,84726.679688,35211.769531,25589.96875,50969.246094,361372.46875,253801.1875,26657.714844,17832.925781
7,2021-12-01,29196.884766,34504.226562,61119.773438,147754.703125,124673.828125,17509.740234,49938.078125,86101.703125,106131.296875,694010.875,51455.375,232563.859375,32650.679688,58475.90625,16905.728516,2326.994141,26321.220703,5222.967773,324963.21875,50485.417969,13666.895508,33000.945312,9310.296875,7987.53125,30471.552734,147608.859375,47551.34375,86597.40625,43288.519531,235950.359375,283155.15625,80594.34375,101199.195312,33565.898438,29516.167969,33482.558594,144375.53125,213084.53125,99643.78125,...,39821.335938,107113.140625,278475.34375,12902.824219,166636.0625,175815.546875,7391.284668,7455.554199,34875.878906,13482.244141,95975.9375,49975.445312,55896.597656,25470.9375,266867.5,590592.4375,676121.4375,22598.984375,52074.570312,79204.757812,31743.144531,118579.007812,38200.125,13441.694336,3351.780029,39195.945312,18879.042969,66511.453125,97371.859375,5674.041016,48680.667969,53091.375,84563.507812,35542.921875,25573.373047,51558.691406,373327.34375,258959.78125,26783.949219,17926.666016
8,2021-12-02,29360.066406,34496.28125,60898.222656,147321.21875,121688.3125,17044.255859,49940.074219,86114.398438,106386.25,691479.9375,51384.488281,232520.6875,32607.919922,58844.097656,16920.056641,2324.156494,26095.962891,5216.301758,324775.96875,50071.953125,13643.571289,33108.691406,9332.359375,8012.536133,30361.628906,146810.15625,47584.34375,84120.359375,42971.34375,234993.484375,283111.90625,80391.914062,100455.828125,33386.148438,29737.298828,33911.320312,144031.90625,213619.953125,99310.710938,...,40767.417969,107041.9375,273917.4375,12897.216797,166574.328125,175474.625,7246.680664,7434.592285,34851.988281,13500.085938,94046.84375,50054.777344,55989.15625,25289.71875,264613.40625,585577.75,677151.5,22665.03125,52116.628906,77226.28125,31533.128906,117890.648438,38779.425781,13462.295898,3378.971191,39258.902344,18879.146484,66474.71875,96586.945312,5659.055664,48899.515625,53082.699219,83481.820312,35622.675781,25480.732422,52232.378906,368357.78125,258173.71875,26833.1875,17996.013672
9,2021-12-03,29418.212891,34453.789062,60644.917969,146485.453125,118303.609375,16895.970703,49798.625,86114.734375,106526.585938,688765.375,51223.535156,231688.15625,32417.974609,58715.234375,16915.404297,2312.398438,25841.443359,5198.772461,323279.5,49700.996094,13568.760742,32677.484375,9335.970703,8074.314941,30295.746094,145987.546875,47611.269531,85574.835938,42596.890625,234252.859375,283016.5,79979.914062,99701.882812,33234.851562,29537.570312,33916.453125,141906.703125,213308.09375,98631.28125,...,40047.457031,106902.273438,271301.5,12873.130859,166451.734375,174401.015625,7059.459961,7394.550781,34752.027344,13481.767578,94648.0625,50129.984375,56012.867188,25109.578125,260685.4375,576688.0625,677491.25,22711.53125,52004.96875,73877.21875,31715.875,117130.3125,38256.9375,13467.110352,3365.054932,39333.582031,18842.326172,66143.65625,96004.351562,5640.241211,49297.914062,53024.050781,83157.953125,35534.308594,25382.191406,52127.488281,364534.09375,257329.296875,26738.369141,18036.457031
