In [None]:
import os
from tensorflow import keras
from keras.callbacks import EarlyStopping
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from datetime import date, timedelta
import yfinance as yf
from tqdm import tqdm


# --- Notebook configuration -------------------------------------------------
# Number of consecutive rows that make up one model input window.
TIMESTEPS = 21
# Number of numeric columns per row fed to (and produced by) the model.
FEATURES = 5
# How many steps ahead the model is rolled forward when forecasting.
PREDICT_DAYS = 8
# Upper bound on training epochs per model.
MAX_EPOCHS = 100
# Early-stopping patience, derived as a tenth of the epoch budget.
PATIENCE = MAX_EPOCHS // 10
# Number of independent train/forecast runs performed per ticker.
TEST_RUN = 10
# Tickers analysed by the driver loop.
STOCK_CODES = [
    "AAPL", "AMD", "CVNA", "GOOGL", "MSFT", "NVDA",
    "PLTR", "RBLX", "RIVN", "T", "TSLA",
]

In [None]:
def predict(stock_code):
  """Train a fresh LSTM on one ticker's price history and roll it forward.

  The price history is read from ``data/<stock_code>.csv``; when that file
  does not exist it is downloaded with yfinance (from today's month/day in
  2020 up to today) and cached there first.

  NOTE(review): an existing cache file is never refreshed, so forecasts are
  based on whatever date range was downloaded the first time.

  Args:
    stock_code: ticker symbol; used for the download and the cache file name.

  Returns:
    A ``(results, mae)`` tuple where

    * ``results`` is a list of ``PREDICT_DAYS + 1`` rows, each a plain list of
      ``FEATURES`` floats. Row 0 is the last row of the loaded history; rows
      1.. are the successive model forecasts, each fed back in as input for
      the next step.
    * ``mae`` is the ``'mae'`` series from the Keras training history.

  Raises:
    ValueError: if the download returns no rows, if the CSV does not contain
      exactly ``FEATURES`` columns after dropping ``Volume`` and ``Date``, or
      if there are too few rows to build a train and a validation window.
  """
  end_date = date.today().isoformat()
  # Same month/day as today, but in 2020.
  start_date = "2020" + end_date[4:]
  path = os.path.join("data", f'{stock_code}.csv')
  if not os.path.exists(path):
    print(f"Downloading {stock_code} data from {start_date} to {end_date}")
    downloaded = yf.download(stock_code, start=start_date, end=end_date)
    if downloaded.empty:
      # Refuse to cache an empty frame: it would poison every later run.
      raise ValueError(f"No price data downloaded for {stock_code}")
    # The cache directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    downloaded.to_csv(path)

  # Round-trip through CSV so "Date" is a regular column that can be dropped.
  data = pd.read_csv(path).drop(columns=["Volume", "Date"])
  # presumably the remaining columns are Open, High, Low, Close, Adj Close
  # (matches the header printed by print_results) -- verify against the CSV.
  arr = data.to_numpy(dtype=float)
  if arr.ndim != 2 or arr.shape[1] != FEATURES:
    raise ValueError(
        f"{path} has {arr.shape[1] if arr.ndim == 2 else 0} price columns, expected {FEATURES}")
  # At least two windows are needed so the 80/20 split leaves one for training.
  if len(arr) < TIMESTEPS + 2:
    raise ValueError(
        f"{path} has only {len(arr)} rows; need at least {TIMESTEPS + 2}")

  # Sliding windows: TIMESTEPS consecutive rows predict the row that follows.
  window_count = len(arr) - TIMESTEPS
  inputs = np.stack([arr[i:i + TIMESTEPS] for i in range(window_count)])
  target = arr[TIMESTEPS:]

  # shuffle=False keeps chronological order: the newest 20% is the hold-out.
  x_train, x_test, y_train, y_test = train_test_split(
      inputs, target, test_size=0.2, shuffle=False)

  model = keras.Sequential([
    keras.layers.Input((TIMESTEPS, FEATURES)),
    keras.layers.LSTM(128),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(FEATURES)
  ])
  model.compile(loss='mse',
                optimizer=keras.optimizers.Adam(learning_rate=0.001),
                metrics=['mae'])

  # Callbacks come from the same `keras` namespace that built the model, so
  # the two never originate from different Keras installations.
  # Training stops when either the training loss or the hold-out loss stalls.
  stoppers = [
    keras.callbacks.EarlyStopping(
        monitor=monitored,
        patience=PATIENCE,
        restore_best_weights=True,
        start_from_epoch=8,
    )
    for monitored in ("loss", "val_loss")
  ]
  history = model.fit(x_train, y_train, batch_size=16, epochs=MAX_EPOCHS,
                      validation_data=(x_test, y_test),
                      callbacks=stoppers,
                      verbose=0)
  mae = history.history['mae']

  # Seed the rolling window with the most recent TIMESTEPS rows. Rows are kept
  # as plain float lists so callers can index them by position.
  window = arr[-TIMESTEPS:].copy()
  results = [window[-1].tolist()]
  for _ in range(PREDICT_DAYS):
    output = model.predict(window[np.newaxis, ...], verbose=0)[0]
    results.append(output.tolist())
    # Drop the oldest row and append the forecast as the newest one.
    window = np.vstack([window[1:], output])
  return results, mae

In [None]:
def make_decision(results, mae, close_index=3):
  """Score one forecast run as buy (1), short (0) or undecided (0.5).

  The mean forecast close is compared with the last known close (row 0 of
  ``results``). A margin derived from the recent training MAE must be cleared
  in either direction, otherwise the run counts as undecided.

  Args:
    results: ``PREDICT_DAYS + 1`` rows; row 0 is the last known row and the
      remaining rows are forecasts. Rows may be lists, arrays or pandas
      Series -- values are always read by position.
    mae: per-epoch MAE values from training; the last ``PATIENCE`` entries are
      averaged.
    close_index: position of the closing price within each row.

  Returns:
    1 if the forecast average exceeds the last close by more than the margin,
    0 if it falls below the last close by more than the margin, else 0.5.
  """
  assert PREDICT_DAYS == len(results)-1
  # list(row) forces positional access; integer keys on a label-indexed
  # pandas Series are otherwise treated as labels by recent pandas versions.
  closes = [list(row)[close_index] for row in results]
  last_close, forecast_closes = closes[0], closes[1:]
  avg_close = sum(forecast_closes)/PREDICT_DAYS

  # Average over the entries actually present, so a short history (or
  # PATIENCE == 0) neither skews the mean nor divides by zero.
  recent_mae = list(mae)[-PATIENCE:]
  mean_mae = sum(recent_mae)/len(recent_mae) if recent_mae else 0
  # NOTE(review): the /FEATURES*2 scaling is kept unchanged from the original
  # heuristic; its rationale is not evident from the code -- confirm.
  avg_close_mae = mean_mae/FEATURES*2

  if avg_close-avg_close_mae > last_close:
    return 1
  if avg_close+avg_close_mae < last_close:
    return 0
  return 0.5

def print_results(r):
  """Print rows of five price values as a fixed-width table.

  Each row of ``r`` is printed on its own line, prefixed with its 0-based
  position in ``r`` and followed by its first five values formatted with two
  decimals in 8-character columns.
  """
  print("DAY     OPEN     HIGH      LOW    CLOSE   ADJ-CL")
  print("--- -------- -------- -------- -------- -------- ")
  cells_template = "{:8.2f} " * 5
  for day, values in enumerate(r):
    # Day label first (no newline), then the five price cells.
    print(f"{day:<3} ", end="")
    print(cells_template.format(*values))

def print_average_results(predictions):
  """Average several forecast tables element-wise and print the result.

  ``predictions`` is a list of runs; each run is a list of rows and each row a
  sequence of values. Values are summed per (row, column) position into a
  ``(PREDICT_DAYS + 1) x FEATURES`` grid, divided by the number of runs, and
  handed to ``print_results``.
  """
  run_count = len(predictions)
  totals = [[0] * FEATURES for _ in range(PREDICT_DAYS + 1)]
  for run in predictions:
    for day, row in enumerate(run):
      for col, value in enumerate(row):
        totals[day][col] += value
  averages = [[cell / run_count for cell in day_totals] for day_totals in totals]
  print_results(averages)



In [None]:
# For every ticker: train/forecast TEST_RUN times, add up the per-run decision
# scores, and only report a call when the runs clearly agree.
for s in STOCK_CODES:
  print(f"Analyzing {s}:")
  decisions, predictions = 0, []
  for t in tqdm(range(TEST_RUN)):
    prediction, mae = predict(s)
    decision = make_decision(prediction, mae)
    predictions.append(prediction)
    decisions += decision

  # Scores between 30% and 70% of the runs (inclusive) are treated as no call.
  if TEST_RUN*0.3 <= decisions <= TEST_RUN*0.7:
    print("NOT SURE")
    continue

  print(f"SHORT {s}" if decisions < TEST_RUN*0.3 else f"BUY {s}")
  print_average_results(predictions)

Analyzing AAPL:


100%|██████████| 10/10 [05:54<00:00, 35.44s/it]


NOT SURE
Analyzing AMD:


100%|██████████| 10/10 [05:22<00:00, 32.26s/it]


SHORT AMD
DAY     OPEN     HIGH      LOW    CLOSE   ADJ-CL
--- -------- -------- -------- -------- -------- 
0      95.20    97.43    93.45    97.40    97.40 
1      96.65    98.61    94.50    96.52    96.53 
2      95.86    97.78    93.72    95.72    95.73 
3      95.10    97.00    92.97    94.95    94.97 
4      94.42    96.30    92.30    94.26    94.28 
5      93.85    95.70    91.73    93.68    93.70 
6      93.37    95.22    91.26    93.20    93.23 
7      92.99    94.82    90.88    92.82    92.84 
8      92.68    94.50    90.58    92.50    92.53 
Analyzing CVNA:


100%|██████████| 10/10 [05:01<00:00, 30.17s/it]


NOT SURE
Analyzing GOOGL:


100%|██████████| 10/10 [07:05<00:00, 42.53s/it]


SHORT GOOGL
DAY     OPEN     HIGH      LOW    CLOSE   ADJ-CL
--- -------- -------- -------- -------- -------- 
0     116.11   118.48   116.01   116.51   116.51 
1     116.19   117.45   114.86   116.20   116.21 
2     115.70   116.96   114.37   115.71   115.72 
3     115.24   116.50   113.92   115.26   115.27 
4     114.84   116.09   113.52   114.85   114.86 
5     114.48   115.73   113.16   114.49   114.50 
6     114.16   115.41   112.84   114.18   114.19 
7     113.88   115.13   112.57   113.90   113.91 
8     113.64   114.88   112.33   113.66   113.67 
Analyzing MSFT:


100%|██████████| 10/10 [06:43<00:00, 40.33s/it]


NOT SURE
Analyzing NVDA:


100%|██████████| 10/10 [05:12<00:00, 31.20s/it]


NOT SURE
Analyzing PLTR:


100%|██████████| 10/10 [06:42<00:00, 40.28s/it]


SHORT PLTR
DAY     OPEN     HIGH      LOW    CLOSE   ADJ-CL
--- -------- -------- -------- -------- -------- 
0       9.21     9.60     9.02     9.52     9.52 
1       9.22     9.49     8.98     9.25     9.25 
2       9.08     9.35     8.84     9.11     9.11 
3       8.96     9.23     8.73     9.00     9.00 
4       8.85     9.12     8.63     8.89     8.89 
5       8.76     9.02     8.54     8.80     8.80 
6       8.68     8.93     8.46     8.72     8.71 
7       8.60     8.86     8.38     8.64     8.64 
8       8.54     8.79     8.32     8.58     8.58 
Analyzing RBLX:


100%|██████████| 10/10 [04:14<00:00, 25.40s/it]


NOT SURE
Analyzing RIVN:


100%|██████████| 10/10 [03:39<00:00, 21.96s/it]


NOT SURE
Analyzing T:


100%|██████████| 10/10 [10:05<00:00, 60.52s/it]


NOT SURE
Analyzing TSLA:


100%|██████████| 10/10 [06:39<00:00, 39.94s/it]

NOT SURE



