In [7]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

In [2]:
torch.manual_seed(0)

<torch._C.Generator at 0x1ee102cc430>

In [13]:
#scaling function for input data
#데이터마다 크기가 다르기 때문에 모두 0-1사이의 값으로 만들어주어 계산을 더 편하게 하기 위함
def minmax_scaler(data):
    numerator = data - np.min(data, 0)
    denominator = np.max(data, 0) - np.mean(data, 0)
    
    return numerator / (denominator + 1e-7)

In [17]:
#input을 위한 dataset
def build_dataset(time_series, seq_length):
    dataX = []
    dataY = []
    for i in range(0, len(time_series) - seq_length):
        _x = time_series[i : i+seq_length, :]
        _y = time_series[i + seq_length, [-1]]
        print(_x, '->', _y)
        dataX.append(_x)
        dataY.append(_y)
    return np.array(dataX), np.array(dataY)

In [18]:
#hyper parameters
seq_length = 7
data_dim = 5 #시가, 종가, 최고가 등 5개 dimension
hidden_dim = 10
output_dim = 1
learning_rate = 0.01
iterations = 500

In [19]:
xy = np.loadtxt('data-02-stock_daily.csv', delimiter = ',')
xy = xy[::-1]

# split train, test
train_size = int(len(xy) * 0.7)
train_set = xy[0:train_size]
test_set = xy[train_size - seq_length:]

train_set = minmax_scaler(train_set)
test_set = minmax_scaler(test_set)

trainX, trainY = build_dataset(train_set, seq_length)
testX, testY = build_dataset(test_set, seq_length)

trainX_tensor = torch.FloatTensor(trainX)
trainY_tensor = torch.FloatTensor(trainY)

testX_tensor = torch.FloatTensor(testX)
testY_tensor = torch.FloatTensor(testY)

[[3.96506638e-01 3.88826810e-01 3.76642468e-01 5.69911869e-04
  3.71378979e-01]
 [3.59748067e-01 3.80351198e-01 4.09777299e-01 3.64962793e-03
  3.80000013e-01]
 [3.90506479e-01 3.83428284e-01 3.99750439e-01 3.17835465e-04
  3.62983138e-01]
 [3.46287741e-01 3.91256291e-01 4.10007599e-01 0.00000000e+00
  4.20400135e-01]
 [5.69434301e-01 5.87656350e-01 4.30061664e-01 1.52561023e-02
  4.19498688e-01]
 [4.06507009e-01 4.92911391e-01 4.41241098e-01 5.57987559e-01
  4.34937505e-01]
 [4.32454313e-01 4.41570720e-01 3.19477749e-01 6.97199492e-01
  2.85056223e-01]] -> [0.25693949]
[[3.59748067e-01 3.80351198e-01 4.09777299e-01 3.64962793e-03
  3.80000013e-01]
 [3.90506479e-01 3.83428284e-01 3.99750439e-01 3.17835465e-04
  3.62983138e-01]
 [3.46287741e-01 3.91256291e-01 4.10007599e-01 0.00000000e+00
  4.20400135e-01]
 [5.69434301e-01 5.87656350e-01 4.30061664e-01 1.52561023e-02
  4.19498688e-01]
 [4.06507009e-01 4.92911391e-01 4.41241098e-01 5.57987559e-01
  4.34937505e-01]
 [4.32454313e-01 4.4157

 [0.24849901 0.24846383 0.28259705 0.12498606 0.27767494]] -> [0.23242857]
[[0.19784803 0.2112679  0.25078762 0.17046941 0.23823248]
 [0.25385072 0.2492198  0.28236639 0.24290301 0.25327679]
 [0.23244422 0.25224303 0.27694984 0.18614198 0.26325034]
 [0.23974194 0.25915317 0.29413375 0.19530441 0.27344874]
 [0.24990476 0.24603468 0.28514414 0.16604163 0.26944842]
 [0.24849901 0.24846383 0.28259705 0.12498606 0.27767494]
 [0.23920135 0.2452787  0.25528224 0.23096774 0.23242857]] -> [0.23214682]
[[0.25385072 0.2492198  0.28236639 0.24290301 0.25327679]
 [0.23244422 0.25224303 0.27694984 0.18614198 0.26325034]
 [0.23974194 0.25915317 0.29413375 0.19530441 0.27344874]
 [0.24990476 0.24603468 0.28514414 0.16604163 0.26944842]
 [0.24849901 0.24846383 0.28259705 0.12498606 0.27767494]
 [0.23920135 0.2452787  0.25528224 0.23096774 0.23242857]
 [0.21006472 0.21337325 0.24341155 0.16644715 0.23214682]] -> [0.21845469]
[[0.23244422 0.25224303 0.27694984 0.18614198 0.26325034]
 [0.23974194 0.259153

 [0.23091749 0.23025763 0.25290596 0.28380515 0.22289252]] -> [0.23348561]
[[0.23405286 0.25838418 0.27901015 0.15944381 0.28149261]
 [0.24594537 0.26027353 0.29935215 0.12804385 0.26796941]
 [0.23497173 0.23225506 0.2438587  0.26288281 0.22407591]
 [0.2062136  0.24062275 0.25440428 0.16627179 0.26616629]
 [0.23437719 0.24094664 0.28056616 0.11199864 0.26611023]
 [0.23091749 0.23025763 0.25290596 0.28380515 0.22289252]
 [0.22778212 0.22032415 0.24316722 0.20784247 0.23348561]] -> [0.26272926]
[[0.24594537 0.26027353 0.29935215 0.12804385 0.26796941]
 [0.23497173 0.23225506 0.2438587  0.26288281 0.22407591]
 [0.2062136  0.24062275 0.25440428 0.16627179 0.26616629]
 [0.23437719 0.24094664 0.28056616 0.11199864 0.26611023]
 [0.23091749 0.23025763 0.25290596 0.28380515 0.22289252]
 [0.22778212 0.22032415 0.24316722 0.20784247 0.23348561]
 [0.20691637 0.25384925 0.25221448 0.21164554 0.26272926]] -> [0.2690964]
[[0.23497173 0.23225506 0.2438587  0.26288281 0.22407591]
 [0.2062136  0.2406227

 [1.18490748 1.20010679 1.26152599 0.14508641 1.24236267]] -> [1.16883106]
[[1.12787788 1.15529858 1.18620934 0.24771439 1.1789733 ]
 [1.04846879 1.13840111 1.07009404 0.37670079 1.16015389]
 [1.12182358 1.20501932 1.19226031 0.26350752 1.20601971]
 [1.23761263 1.25328235 1.34013297 0.21962431 1.3111614 ]
 [1.23750474 1.26165003 1.28186799 0.18673382 1.23464328]
 [1.18490748 1.20010679 1.26152599 0.14508641 1.24236267]
 [1.1839345  1.19986362 1.19185689 0.23961506 1.16883106]] -> [1.34333515]
[[1.04846879 1.13840111 1.07009404 0.37670079 1.16015389]
 [1.12182358 1.20501932 1.19226031 0.26350752 1.20601971]
 [1.23761263 1.25328235 1.34013297 0.21962431 1.3111614 ]
 [1.23750474 1.26165003 1.28186799 0.18673382 1.23464328]
 [1.18490748 1.20010679 1.26152599 0.14508641 1.24236267]
 [1.1839345  1.19986362 1.19185689 0.23961506 1.16883106]
 [1.23015279 1.28329826 1.29535217 0.29246343 1.34333515]] -> [1.41089408]
[[1.12182358 1.20501932 1.19226031 0.26350752 1.20601971]
 [1.23761263 1.253282

 [1.31470114 1.33197518 1.3902111  0.13481787 1.3005662 ]] -> [1.37677907]
[[1.0905722  1.28338466 1.24671962 0.16656945 1.34487952]
 [1.21177536 1.23007937 1.27145705 0.18586853 1.22034847]
 [1.14864591 1.24920142 1.31694669 0.11519714 1.2577209 ]
 [1.19116474 1.32935561 1.32862883 0.16521392 1.38145114]
 [1.28009031 1.27631251 1.39322076 0.33619459 1.34301069]
 [1.31470114 1.33197518 1.3902111  0.13481787 1.3005662 ]
 [1.27036811 1.32320035 1.44640735 0.09020023 1.37677907]] -> [1.44097982]
[[1.21177536 1.23007937 1.27145705 0.18586853 1.22034847]
 [1.14864591 1.24920142 1.31694669 0.11519714 1.2577209 ]
 [1.19116474 1.32935561 1.32862883 0.16521392 1.38145114]
 [1.28009031 1.27631251 1.39322076 0.33619459 1.34301069]
 [1.31470114 1.33197518 1.3902111  0.13481787 1.3005662 ]
 [1.27036811 1.32320035 1.44640735 0.09020023 1.37677907]
 [1.31781212 1.37336166 1.44326012 0.13366911 1.44097982]] -> [1.58766792]
[[1.14864591 1.24920142 1.31694669 0.11519714 1.2577209 ]
 [1.19116474 1.329355