In [30]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np

import os
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path
# (presumably so the local `crypto_price_analysis` package used by later cells
# resolves -- confirm against the repository layout).
# The membership check keeps this cell idempotent: re-running it no longer
# appends the same entry to sys.path again and again.
project_root = os.path.abspath('..')
if project_root not in sys.path:
    sys.path.append(project_root)

import torch

# Use the first CUDA device when PyTorch reports one, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
from crypto_price_analysis import BinanceDataset

# Arguments for the Binance dataset helper, named so they are easy to tune.
# NOTE(review): the names below are guesses at what BinanceDataset expects
# (storage directory, trading pair, candle interval) -- confirm against its
# signature. The meaning of the trailing 45 is not visible from this notebook.
data_dir = Path("../data")
symbol = "ETHUSDT"
interval = "1m"

binance_dataset = BinanceDataset(data_dir, symbol, interval, 45)

# Fetch the raw data first, then preprocess it. The preprocessing call returns
# three values; `mean` and `std` are presumably normalisation statistics --
# TODO confirm.
binance_dataset.download_binance_dataset()
data_chunks, mean, std = binance_dataset.preprocess_binance_dataset()

In [32]:
from crypto_price_analysis.dataset import ChunkedSequenceDataset
from torch.utils.data import DataLoader, SubsetRandomSampler

sequence_length = 45
batch_size = 32
validation_split = 0.05
# Fixed seed so the train/validation partition and the batch sampling order
# are the same after every kernel restart.
random_seed = 42

# One float32 tensor per preprocessed chunk, moved to the selected device.
chunk_tensors = [
    torch.Tensor(chunk.values.astype(np.float32)).to(device)
    for chunk in data_chunks
]
dataset = ChunkedSequenceDataset(chunk_tensors, sequence_length)

dataset_size = len(dataset)
# Seeded permutation of every sample index (replaces the unseeded in-place
# np.random.shuffle, which produced a different split on every run).
indices = np.random.default_rng(random_seed).permutation(dataset_size)

# The first `split` shuffled indices form the validation set, the rest the
# training set. For very small datasets `split` can be 0 (empty validation set).
# NOTE(review): if ChunkedSequenceDataset yields overlapping windows, a random
# split can place near-identical windows in both sets -- confirm whether a
# chronological split is more appropriate.
split = int(validation_split * dataset_size)
train_indices, val_indices = indices[split:], indices[:split]

# A seeded generator makes the order drawn by the samplers reproducible too.
# .tolist() hands the samplers plain Python ints instead of NumPy scalars.
sampler_generator = torch.Generator().manual_seed(random_seed)
train_sampler = SubsetRandomSampler(train_indices.tolist(), generator=sampler_generator)
val_sampler = SubsetRandomSampler(val_indices.tolist(), generator=sampler_generator)


train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler)

In [33]:
import torch.nn as nn

class PriceForecast(nn.Module):
    """Stacked LSTM followed by a linear layer applied to the last time step.

    The network consumes a batch of sequences and emits one vector per
    sequence with the same number of features as each input step.
    """

    def __init__(
        self,
        device: torch.device,
        *,
        n_features: int = 7,
        hidden_size: int = 32,
        num_layers: int = 2,
    ):
        """Build the LSTM and the linear output layer on ``device``.

        Args:
            device: Device on which the layer parameters are created.
            n_features: Size of each input step and of the output vector.
            hidden_size: Width of the LSTM hidden state.
            num_layers: Number of stacked LSTM layers.

        The defaults reproduce the previously hard-coded architecture
        (7 features, 32 hidden units, 2 layers).
        """
        super().__init__()
        self.lstm1 = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            device=device,
        )
        self.linear1 = nn.Linear(hidden_size, n_features, device=device)

    def forward(self, x):
        """Map a batch of sequences to one output vector per sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, n_features)``; the LSTM is
                built with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, n_features)``.
        """
        sequence_output, _ = self.lstm1(x)
        # Only the LSTM output at the final time step feeds the linear layer.
        return self.linear1(sequence_output[:, -1, :])
    

In [34]:
from crypto_price_analysis.train import train_model

# Values a reader is most likely to tune, pulled out of the calls below.
learning_rate = 5e-4
# NOTE(review): assumed to be the epoch count expected by train_model's last
# positional argument -- confirm against its signature.
num_epochs = 1

# Model, optimiser and objective (mean squared error) for the training run.
model = PriceForecast(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss = nn.MSELoss()

# The value returned by train_model is kept in `history` for later inspection.
history = train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss,
    num_epochs,
)

epoch: 0 duration: 220.27 loss: 0.68616 val_loss: 0.63179 
