<a href="https://colab.research.google.com/github/jsxhhyf/Optiver/blob/main/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Copy the Kaggle API credentials file from ./drive/MyDrive (presumably the
# mounted Google Drive -- confirm) into ~/.kaggle.
!mkdir -p ~/.kaggle
!cp ./drive/MyDrive/Colab\ Notebooks/Kaggle/kaggle.json ~/.kaggle/
# Restrict the credentials file to owner read/write only.
!chmod 600 ~/.kaggle/kaggle.json

In [2]:
# Create the local working directories for the parquet and CSV inputs.
!mkdir -p ./optiver-realized-volatility-prediction/parquet
!mkdir -p ./optiver-realized-volatility-prediction/csv
# Copy the preprocessed *.parquet and *.csv files from Drive into those directories.
!cp ./drive/MyDrive/Colab\ Notebooks/Kaggle/preprocessed/*.parquet ./optiver-realized-volatility-prediction/parquet
!cp ./drive/MyDrive/Colab\ Notebooks/Kaggle/preprocessed/*.csv ./optiver-realized-volatility-prediction/csv

In [3]:
# Install the third-party packages imported below: wandb (experiment tracking),
# pytorch-lightning (training loop) and tensor-sensor (imported as `tsensor`).
# NOTE(review): versions are not pinned, so a re-run may pull newer, incompatible
# releases -- consider pinning the versions this notebook was developed against.
!pip install wandb
!pip install pytorch-lightning
!pip install tensor-sensor[torch]

Collecting wandb
  Downloading wandb-0.11.2-py2.py3-none-any.whl (1.8 MB)
[?25l[K     |▏                               | 10 kB 21.2 MB/s eta 0:00:01[K     |▍                               | 20 kB 23.6 MB/s eta 0:00:01[K     |▌                               | 30 kB 21.0 MB/s eta 0:00:01[K     |▊                               | 40 kB 17.2 MB/s eta 0:00:01[K     |█                               | 51 kB 8.2 MB/s eta 0:00:01[K     |█                               | 61 kB 8.7 MB/s eta 0:00:01[K     |█▎                              | 71 kB 7.4 MB/s eta 0:00:01[K     |█▍                              | 81 kB 8.2 MB/s eta 0:00:01[K     |█▋                              | 92 kB 8.5 MB/s eta 0:00:01[K     |█▉                              | 102 kB 7.8 MB/s eta 0:00:01[K     |██                              | 112 kB 7.8 MB/s eta 0:00:01[K     |██▏                             | 122 kB 7.8 MB/s eta 0:00:01[K     |██▎                             | 133 kB 7.8 MB/s eta 0:00:01

In [4]:
# IMPORTS
import os, sys, random, datetime
import pandas as pd
# import modin.pandas as pd
import numpy as np

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import *

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import Metric

from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import accuracy_score

from tqdm import tqdm, trange

import wandb
import tsensor

In [5]:
# Start a Weights & Biases run under the "Optiver" project.
wandb.init(project="Optiver")

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## 数据准备

In [6]:
# GLOBAL VARIABLES
# Root of the local data directory; the cells below read from its 'csv' and
# 'parquet' sub-directories.
BASE_DIRECTORY = "./optiver-realized-volatility-prediction"

In [7]:
# Load the training labels from <BASE_DIRECTORY>/csv/train.csv and preview them.
labels_csv_path = os.path.join(BASE_DIRECTORY, 'csv', 'train.csv')
train_labels = pd.read_csv(labels_csv_path)
train_labels.head()

Unnamed: 0,stock_id,time_id,target
0,0,5,0.004136
1,0,11,0.001445
2,0,16,0.002168
3,0,31,0.002195
4,0,62,0.001747


In [8]:
# Build the list of valid ids: one integer id per '<id>.parquet' file found in
# the parquet directory (presumably stock ids -- confirm against the data).
# os.listdir returns entries in arbitrary order, so the ids are sorted to keep
# the list (and anything derived from it) stable across runs.
PARQUET_SUFFIX = '.parquet'
file_list = os.listdir(os.path.join(BASE_DIRECTORY, 'parquet'))
id_list = sorted(
    int(file_name[:-len(PARQUET_SUFFIX)])
    for file_name in file_list
    if file_name.endswith(PARQUET_SUFFIX)
)

In [9]:
# Split the ids into training (70%) and validation (30%) sets.
# A fixed random_state makes the shuffle reproducible for a given ordering of
# id_list; without it every kernel restart produces a different split.
train_index, valid_index = train_test_split(id_list, test_size=0.3, random_state=42)

In [10]:
# Extract data
# For every training id, read its parquet file and split it into one DataFrame
# per time_id; each per-time_id frame becomes one training sample.

train_list = []
for file_id in tqdm(train_index):  # `file_id` rather than `id`, which shadows the builtin
    parquet_path = os.path.join(BASE_DIRECTORY, 'parquet', f'{file_id}.parquet')
    if not os.path.exists(parquet_path):
        continue
    stock_frame = pd.read_parquet(parquet_path)
    for _, time_frame in stock_frame.groupby('time_id'):
        # Report any group whose row count differs from 600.
        if len(time_frame) != 600:
            print(len(time_frame))
        train_list.append(time_frame.reset_index(drop=True))

100%|██████████| 78/78 [01:15<00:00,  1.03it/s]


In [11]:
# For every validation id, read its parquet file and split it into one
# DataFrame per time_id; each per-time_id frame becomes one validation sample.
valid_list = []
for file_id in tqdm(valid_index):  # `file_id` rather than `id`, which shadows the builtin
    parquet_path = os.path.join(BASE_DIRECTORY, 'parquet', f'{file_id}.parquet')
    if not os.path.exists(parquet_path):
        continue
    stock_frame = pd.read_parquet(parquet_path)
    for _, time_frame in stock_frame.groupby('time_id'):
        valid_list.append(time_frame.reset_index(drop=True))

100%|██████████| 34/34 [00:43<00:00,  1.28s/it]


In [12]:
# Number of samples per batch, used by both the training and validation DataLoaders.
BATCHSIZE = 512

In [13]:
class MyDataset(Dataset):
    """Dataset that wraps a list and serves its elements by position.

    No conversion happens here; items are returned exactly as stored, so any
    tensor conversion is left to the DataLoader's collate function.
    """

    def __init__(self, df_list):
        """Store the list of samples (DataFrames, judging by the parameter name)."""
        self.df_list = df_list

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.df_list)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.df_list[idx]

def my_collate(data, labels=None):
    """Collate a list of per-sample DataFrames into feature tensors and targets.

    Args:
        data: list of DataFrames. Each must contain the columns 'stock_id',
            'time_id' and 'WAP'; every remaining column is used as a feature.
        labels: DataFrame with 'stock_id', 'time_id' and 'target' columns used
            to look up each sample's target. Defaults to the notebook-level
            ``train_labels``.

    Returns:
        (x, y): ``x`` is a list of float16 tensors, one per sample, holding the
        feature columns; ``y`` is a list with the matching target values.

    Raises:
        IndexError: if a sample's (stock_id, time_id) pair has no row in ``labels``.
    """
    if labels is None:
        labels = train_labels

    x = []
    y = []
    for df in data:
        features = df.drop(['stock_id', 'time_id', 'WAP'], axis=1).values
        x.append(torch.from_numpy(features).half())

        # Positional access so the lookup does not depend on the frame's index labels.
        stock_id = df['stock_id'].iloc[0]
        time_id = df['time_id'].iloc[0]
        match = (labels['stock_id'] == stock_id) & (labels['time_id'] == time_id)
        # Select the target by column name: a positional column index silently
        # picks the wrong column if the CSV carries an extra leading column.
        y.append(labels.loc[match, 'target'].iloc[0])
    return x, y

In [14]:
# Options shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(
    batch_size=BATCHSIZE,
    collate_fn=my_collate,
    drop_last=True,
    num_workers=2,
)

train_dataset = MyDataset(train_list)
valid_dataset = MyDataset(valid_list)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
valid_loader = DataLoader(valid_dataset, **_loader_kwargs)

## Lightning Model

In [15]:
class RMSPE(torch.nn.Module):
    """Root mean squared percentage error: sqrt(mean(((pred - true) / true) ** 2))."""

    def __init__(self):
        super().__init__()

    def forward(self, z_pred, z_true):
        """Return the RMSPE of ``z_pred`` against ``z_true`` as a scalar tensor."""
        relative_error = (z_pred - z_true) / z_true
        mean_squared = relative_error.pow(2).mean()
        return mean_squared.sqrt()

In [29]:
class RMSPE(Metric):
    """Root mean squared percentage error accumulated over ``update`` calls.

    NOTE(review): an earlier cell of this notebook defines another class named
    ``RMSPE`` (a ``torch.nn.Module`` loss). Running this cell rebinds the name,
    so ``LitModel``'s ``self.loss = RMSPE()`` would pick up this metric instead
    -- confirm which definition is intended, and consider renaming one of them.
    NOTE(review): the squared errors stored in ``update`` are not detached, so
    they presumably keep their autograd history while accumulated -- verify.
    """
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Buffer of per-element squared relative errors; starts empty and is
        # registered with a 'cat' reduction.
        self.add_state('squared_error_rate', default=torch.Tensor([]), dist_reduce_fx='cat')

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Append ((preds - target) / target) ** 2 for this batch to the buffer.

        ``preds`` and ``target`` must have identical shapes.
        """
        assert preds.shape == target.shape

        self.squared_error_rate = torch.cat([self.squared_error_rate, ((preds-target)/target).pow(2)])

    def compute(self):
        """Return the square root of the mean of all accumulated squared errors."""
        return self.squared_error_rate.mean().sqrt()

In [16]:
class LitModel(pl.LightningModule):
    """LSTM regressor trained with the RMSPE loss.

    Pipeline: batch-normalise the input features, run them through an LSTM,
    then map the last layer's final hidden state through a small fully
    connected head to one scalar prediction per sample.

    Args:
        lstm_input_size: number of features per time step.
        lstm_hidden_size: hidden size of the LSTM.
        num_lstm_layers: number of stacked LSTM layers.
        num_fc_features1: width of the hidden layer in the head.
        num_fc_features2: currently unused (the second hidden layer is
            disabled); kept so existing call sites keep working.
        lr: learning rate for Adam.
        batch_size: optional batch size stored on the module as
            ``self.batch_size``. The model itself never reads it.
    """

    def __init__(
        self,
        lstm_input_size,
        lstm_hidden_size,
        num_lstm_layers,
        num_fc_features1,
        num_fc_features2,
        lr,
        batch_size=None,
    ):
        super().__init__()

        # Previously read an undefined name `batch_size`; now an optional argument.
        self.batch_size = batch_size

        self.lstm = torch.nn.LSTM(
            input_size=lstm_input_size,
            hidden_size=lstm_hidden_size,
            num_layers=num_lstm_layers,
            batch_first=True,
        )

        self.lstm_output_size = lstm_hidden_size

        # Normalises each input feature over the batch and time axes.
        # BatchNorm1d (not 2d) because the input is (batch, seq, features);
        # sized from lstm_input_size instead of a hard-coded 9.
        self.batch_norm = torch.nn.BatchNorm1d(lstm_input_size)

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(
                in_features=self.lstm_output_size, out_features=num_fc_features1
            ),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=num_fc_features1, out_features=1),
        )

        self.loss = RMSPE()

        self.lr = lr

    def forward(self, xs):
        """Predict one value per sample.

        Args:
            xs: sequence of tensors, one per sample. Assumes each is
                (seq_len, lstm_input_size) with the same seq_len, as required
                for stacking -- TODO confirm against the collate function.

        Returns:
            Tensor of shape (batch,) with one prediction per sample.
        """
        inputs = torch.stack(list(xs))  # (batch, seq, features)
        # BatchNorm1d expects channels on dim 1, so swap seq/features around it.
        inputs = self.batch_norm(inputs.transpose(1, 2)).transpose(1, 2).contiguous()
        _, (h_n, _) = self.lstm(inputs)
        # h_n is (num_layers, batch, hidden): use the last layer only, and
        # squeeze just the feature axis so a batch of one keeps its batch dim.
        return self.fc(h_n[-1]).squeeze(-1)

    def _compute_loss(self, batch):
        """Run the model on ``batch`` = (xs, ys) and return the RMSPE loss."""
        xs, ys = batch
        zs = self(xs)
        # Build targets on the predictions' device/dtype instead of assuming CUDA + half.
        ys = torch.as_tensor(ys, dtype=zs.dtype, device=zs.device)
        return self.loss(zs, ys)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        self.log("train/loss", loss, on_epoch=True, on_step=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._compute_loss(val_batch)
        self.log("valid/loss", loss, on_epoch=True, on_step=True)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

## 训练

In [17]:
# Weights & Biases logger for this run; name and version are both "0.0".
wandb_logger = WandbLogger(name="0.0", version="0.0", notes='initial')

In [18]:
# Hyperparameters for the LSTM regressor, gathered in one place.
model_hparams = {
    "lstm_input_size": 9,
    "lstm_hidden_size": 16,
    "num_lstm_layers": 1,
    "num_fc_features1": 32,
    "num_fc_features2": 192,
    "lr": 1e-4,
}
model = LitModel(**model_hparams)

In [20]:
# Configure the Lightning trainer (W&B logging, 16-bit precision, at most 10
# epochs) and fit the model on the training/validation loaders.
# NOTE(review): `gpus=-1` presumably means "use every available GPU" and will
# likely fail on a CPU-only runtime -- confirm for the installed Lightning version.
# NOTE(review): `auto_scale_batch_size` presumably only takes effect through
# `trainer.tune(...)`, which is never called here, and the loaders are built with
# a fixed BATCHSIZE -- verify whether this option does anything in this setup.
trainer = pl.Trainer(
    # callbacks=[
    #     early_stop_callback,
    #     checkpoint_callback,
    # ],
    gpus=-1,
    logger=wandb_logger,
    log_every_n_steps=10,
    max_epochs=10,
    deterministic=False,
    precision=16,
    default_root_dir='./drive/MyDrive/Colab Notebooks/checkpoints',
    # val_check_interval=1,
    auto_scale_batch_size='binsearch',
)

trainer.fit(model, train_loader, valid_loader)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | lstm | LSTM       | 1.7 K 
1 | fc   | Sequential | 577   
2 | loss | RMSPE      | 0     
------------------------------------
2.3 K     Trainable params
0         Non-trainable params
2.3 K     Total params
0.009     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."




  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
