# Imports

In [1]:
# Quick environment check: report whether CUDA is available and how many
# GPUs are visible to this kernel before any device is selected.
import torch 
print('torch.cuda.is_available:', torch.cuda.is_available())
print('torch.cuda.device_count:', torch.cuda.device_count())

torch.cuda.is_available: True
torch.cuda.device_count: 2


In [2]:
import numpy as np
import torch
# Allow faster, reduced-precision float32 matmul kernels (PyTorch's default is
# 'highest'). NOTE: this may contribute to numerical differences across GPUs.
torch.set_float32_matmul_precision('high')
# NOTE(review): star import hides where get_params_dicts, start_writer,
# get_dataloader, get_signature_kernel, get_generator and train come from;
# an explicit import list would make the dependencies visible.
from train import *

# The original run used the second GPU ('cuda:1'). On a machine with a single
# visible GPU, moving the generator to 'cuda:1' fails with an invalid device
# ordinal, so fall back to 'cuda:0' there. No extra globals are created here
# because a later cell snapshots the notebook namespace with vars().
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# Experiment settings
These settings reproduce the training run that produced the final generator, which was used to evaluate the stylized facts and to train the RL agent for portfolio management.  
Running on different hardware can lead to accumulating numerical differences (likely due to differences in floating-point precision), so results may not match exactly.  
The original experiments were run on a machine with the following specs:
1. OS: Ubuntu 20.04.1
2. CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
3. GPU: Tesla V100-SXM2-32GB

### Dataset related

In [3]:
# samples
# NOTE: every global defined here is captured by get_params_dicts(vars().copy())
# later in the notebook, so renaming these variables changes the logged params.
batch_size = 38 # number of samples in each batch
sample_len = 300 # length of each sample
sample_model = 'Realdt' # one of: GBM, Heston, OU, RealData, Realdt, spx_rates
lead_lag = True # whether to use lead lag transformation
lags = [1] # lag(s) to use for the lead lag transformation: int or list[int]
seed = 42 # seeds both the numpy Generator (rng) and torch before training

# real data parameters
stride = 50 # for real data; NOTE(review): presumably the step between consecutive windows cut from the series — confirm in train.py
start_date = '1995-01-01' # start date for real data
end_date = '2018-09-18' # end date for real data

### Generator and kernel related

In [4]:
# signature kernel
static_kernel_type = 'rq' # type of static kernel to use - rbf, rbfmix, rq, rqmix, rqlinear
n_levels = 10 # number of levels in the truncated signature kernel

# generator
seq_dim = 1 # dimension of sequence vector
# PyTorch activation class names, e.g. Tanh, ReLU (Tanh() appears in the printed GenLSTM).
# NOTE: does NOT change the activation inside transformer layers (the original comment
# was cut off here — confirm the intended wording).
activation = 'Tanh'
hidden_size = 64 # hidden width of the generator (LSTM(..., 64) and 64-unit Linear layers in the printed GenLSTM)
n_lstm_layers = 1 # number of LSTM layers
conditional = True # feed in history for LSTM generators
hist_len = 50 # NOTE(review): presumably the number of past steps fed as history when conditional=True — confirm

### Noise related

In [5]:
noise_dim = 4 # dimension of noise vector
ma = True # whether to use MA noise generator fitted to log returns gaussianized by Lambert W transformation
# NOTE(review): the fit summary printed when the generator is built is a zero-mean
# ARCH model on the 'gaussianized' series with 21 parameters
# (AIC 11754.0 = 2*21 + 2*5855.99 over 6999 observations). That is consistent with
# ma_p = 20 being the lag order of the fitted noise model; the "MA" naming above
# may not match the fitted model type — confirm.
ma_p = 20

### Training related

In [6]:
# Upper bound on training epochs (not batches); each epoch presumably iterates over
# the whole dataloader — the training log shows 3/3 batches per epoch.
epochs = 10000
start_lr = 0.001 # starting learning rate
# Epochs without improvement before the lr is multiplied by lr_factor.
# NOTE(review): in the training log, "Saving model" is printed only when
# avg_last_<num_losses>_loss hits a new low, so improvement is presumably measured
# on that running average — confirm in train.py.
patience = 100
lr_factor = 0.5 # factor to multiply lr by for scheduler
early_stopping = patience*3 # stop training after this many epochs without improvement
kernel_sigma = 0.1 # starting kernel_sigma
num_losses = 20 # window for the running average reported as avg_last_20_loss in the training log

### Seed RNGs and log parameters to TensorBoard

In [7]:
# Seed torch's global RNG and create a dedicated numpy Generator from the same
# seed, so later stochastic steps are repeatable on the same machine.
torch.manual_seed(seed)
rng = np.random.default_rng(seed)

# Snapshot the notebook namespace (every config global above, plus `rng`) and
# let get_params_dicts split it into data / model / train parameter groups.
# Any stray global defined before this point ends up in the snapshot too.
data_params, model_params, train_params = get_params_dicts(vars().copy())

# Record the three parameter groups in a TensorBoard writer.
writer = start_writer(data_params, model_params, train_params)

### Data, kernel, generator

In [8]:
# Build the data pipeline, signature kernel and generator from the parameter groups.
# In {**a, **b} the right-hand dict wins on duplicate keys, so each merge order matters.
# NOTE(review): the ARCH fit summary printed by this cell presumably comes from fitting
# the noise model (ma=True) inside one of these builders — confirm which.
dataloader = get_dataloader(**{**data_params, **model_params})
kernel = get_signature_kernel(**{**model_params, **train_params})
generator = get_generator(**{**model_params, **data_params})
# Module.to() returns the module, so as the cell's last expression this also
# displays the generator architecture (the GenLSTM repr in the output).
generator.to(device)

Optimization terminated successfully    (Exit mode 0)
            Current function value: 5855.988337133396
            Iterations: 36
            Function evaluations: 836
            Gradient evaluations: 36
                        Zero Mean - ARCH Model Results                        
Dep. Variable:           gaussianized   R-squared:                       0.000
Mean Model:                 Zero Mean   Adj. R-squared:                  0.000
Vol Model:                       ARCH   Log-Likelihood:               -5855.99
Distribution:                  Normal   AIC:                           11754.0
Method:            Maximum Likelihood   BIC:                           11897.9
                                        No. Observations:                 6999
Date:                Sun, Sep 01 2024   Df Residuals:                     6999
Time:                        18:51:36   Df Model:                            0
                               Volatility Model                              
 

GenLSTM(
  (rnn): LSTM(6, 64, batch_first=True)
  (mean_net): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Tanh()
    (4): Linear(in_features=64, out_features=1, bias=True)
  )
  (var_net): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Tanh()
    (4): Linear(in_features=64, out_features=1, bias=True)
  )
  (output_net): Linear(in_features=64, out_features=1, bias=True)
)

# Train MMD-GAN

In [9]:
# Run MMD-GAN training. On duplicate keys data_params takes precedence over
# model_params, which takes precedence over train_params. The log reports per-epoch
# loss and avg_last_20_loss; "Saving model" appears when the running average hits a new low.
train(generator, kernel, dataloader, rng, writer, device, **{**train_params, **model_params, **data_params})

100%|██████████| 3/3 [00:07<00:00,  2.36s/it]


Epoch 0, loss: 709563.8958333334, avg_last_20_loss: 709563.8958333334
Saving model at epoch 0


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 1, loss: 473359.2083333333, avg_last_20_loss: 591461.5520833334
Saving model at epoch 1


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 2, loss: 337491.7083333333, avg_last_20_loss: 506804.9375
Saving model at epoch 2


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 3, loss: 189212.671875, avg_last_20_loss: 427406.87109375
Saving model at epoch 3


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 4, loss: 114609.55729166667, avg_last_20_loss: 364847.4083333333
Saving model at epoch 4


100%|██████████| 3/3 [00:04<00:00,  1.53s/it]


Epoch 5, loss: 119154.27083333333, avg_last_20_loss: 323898.5520833333
Saving model at epoch 5


100%|██████████| 3/3 [00:04<00:00,  1.53s/it]


Epoch 6, loss: 83619.80208333333, avg_last_20_loss: 289573.01636904763
Saving model at epoch 6


100%|██████████| 3/3 [00:04<00:00,  1.53s/it]


Epoch 7, loss: 96361.3203125, avg_last_20_loss: 265421.5543619792
Saving model at epoch 7


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 8, loss: 91099.2421875, avg_last_20_loss: 246052.40856481483
Saving model at epoch 8


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 9, loss: 76247.15104166667, avg_last_20_loss: 229071.8828125
Saving model at epoch 9


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 10, loss: 79205.76432291667, avg_last_20_loss: 215447.69022253787
Saving model at epoch 10


100%|██████████| 3/3 [00:04<00:00,  1.51s/it]


Epoch 11, loss: 71667.53255208333, avg_last_20_loss: 203466.01041666666
Saving model at epoch 11


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 12, loss: 56534.044270833336, avg_last_20_loss: 192163.5514823718
Saving model at epoch 12


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 13, loss: 60967.720052083336, avg_last_20_loss: 182792.42066592263
Saving model at epoch 13


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 14, loss: 58027.140625, avg_last_20_loss: 174474.73532986114
Saving model at epoch 14


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 15, loss: 70765.92447916667, avg_last_20_loss: 167992.93465169272
Saving model at epoch 15


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 16, loss: 47617.727864583336, avg_last_20_loss: 160912.04013480394
Saving model at epoch 16


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 17, loss: 73541.75651041667, avg_last_20_loss: 156058.13548900464
Saving model at epoch 17


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 18, loss: 58384.65625, avg_last_20_loss: 150917.4260553728
Saving model at epoch 18


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 19, loss: 55033.266927083336, avg_last_20_loss: 146123.21809895834
Saving model at epoch 19


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 20, loss: 86222.3359375, avg_last_20_loss: 114956.14010416667
Saving model at epoch 20


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 21, loss: 73555.36588541667, avg_last_20_loss: 94965.94798177082
Saving model at epoch 21


100%|██████████| 3/3 [00:04<00:00,  1.52s/it]


Epoch 22, loss: 42568.720052083336, avg_last_20_loss: 80219.79856770832
Saving model at epoch 22


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 23, loss: 61695.973958333336, avg_last_20_loss: 73843.96367187498
Saving model at epoch 23


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 24, loss: 55326.174479166664, avg_last_20_loss: 70879.79453125
Saving model at epoch 24


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 25, loss: 66421.0546875, avg_last_20_loss: 68243.13372395834
Saving model at epoch 25


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 26, loss: 60068.032552083336, avg_last_20_loss: 67065.54524739584
Saving model at epoch 26


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 27, loss: 45016.23828125, avg_last_20_loss: 64498.29114583334
Saving model at epoch 27


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 28, loss: 55358.166666666664, avg_last_20_loss: 62711.23736979168
Saving model at epoch 28


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 29, loss: 80294.82161458333, avg_last_20_loss: 62913.6208984375


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 30, loss: 65396.486979166664, avg_last_20_loss: 62223.15703125
Saving model at epoch 30


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 31, loss: 54874.783854166664, avg_last_20_loss: 61383.51959635418
Saving model at epoch 31


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 32, loss: 53549.579427083336, avg_last_20_loss: 61234.296354166676
Saving model at epoch 32


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 33, loss: 43153.186197916664, avg_last_20_loss: 60343.569661458336
Saving model at epoch 33


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 34, loss: 45043.703125, avg_last_20_loss: 59694.397786458336
Saving model at epoch 34


100%|██████████| 3/3 [00:04<00:00,  1.51s/it]


Epoch 35, loss: 35456.052734375, avg_last_20_loss: 57928.90419921875
Saving model at epoch 35


100%|██████████| 3/3 [00:04<00:00,  1.51s/it]


Epoch 36, loss: 41625.126302083336, avg_last_20_loss: 57629.27412109374
Saving model at epoch 36


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 37, loss: 35736.473307291664, avg_last_20_loss: 55739.0099609375
Saving model at epoch 37


100%|██████████| 3/3 [00:04<00:00,  1.55s/it]


Epoch 38, loss: 36667.406901041664, avg_last_20_loss: 54653.14749348959
Saving model at epoch 38


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 39, loss: 50434.805338541664, avg_last_20_loss: 54423.2244140625
Saving model at epoch 39


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 40, loss: 40125.3203125, avg_last_20_loss: 52118.373632812494
Saving model at epoch 40


100%|██████████| 3/3 [00:04<00:00,  1.52s/it]


Epoch 41, loss: 36505.322916666664, avg_last_20_loss: 50265.87148437499
Saving model at epoch 41


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 42, loss: 30810.900390625, avg_last_20_loss: 49677.980501302074
Saving model at epoch 42


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 43, loss: 37782.427083333336, avg_last_20_loss: 48482.30315755208
Saving model at epoch 43


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 44, loss: 42587.416666666664, avg_last_20_loss: 47845.36526692708
Saving model at epoch 44


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 45, loss: 33156.866536458336, avg_last_20_loss: 46182.15585937501
Saving model at epoch 45


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 46, loss: 25197.73046875, avg_last_20_loss: 44438.64075520833
Saving model at epoch 46


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 47, loss: 32740.358723958332, avg_last_20_loss: 43824.84677734375
Saving model at epoch 47


100%|██████████| 3/3 [00:04<00:00,  1.51s/it]


Epoch 48, loss: 33949.711588541664, avg_last_20_loss: 42754.4240234375
Saving model at epoch 48


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 49, loss: 32026.943359375, avg_last_20_loss: 40341.030110677086
Saving model at epoch 49


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 50, loss: 41292.598958333336, avg_last_20_loss: 39135.835709635416
Saving model at epoch 50


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 51, loss: 32144.763020833332, avg_last_20_loss: 37999.33466796875
Saving model at epoch 51


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 52, loss: 25153.161458333332, avg_last_20_loss: 36579.51376953126
Saving model at epoch 52


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 53, loss: 33010.536458333336, avg_last_20_loss: 36072.381282552094
Saving model at epoch 53


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 54, loss: 40085.52734375, avg_last_20_loss: 35824.472493489586
Saving model at epoch 54


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 55, loss: 25497.25, avg_last_20_loss: 35326.53235677084
Saving model at epoch 55


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 56, loss: 39952.09765625, avg_last_20_loss: 35242.88092447917
Saving model at epoch 56


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 57, loss: 40291.486979166664, avg_last_20_loss: 35470.631608072916


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 58, loss: 33171.462239583336, avg_last_20_loss: 35295.834375


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 59, loss: 27644.652994791668, avg_last_20_loss: 34156.326757812494
Saving model at epoch 59


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 60, loss: 26516.30859375, avg_last_20_loss: 33475.876171875
Saving model at epoch 60


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 61, loss: 30671.360026041668, avg_last_20_loss: 33184.17802734375
Saving model at epoch 61


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 62, loss: 30900.690104166668, avg_last_20_loss: 33188.66751302083


100%|██████████| 3/3 [00:04<00:00,  1.46s/it]


Epoch 63, loss: 23473.911458333332, avg_last_20_loss: 32473.24173177083
Saving model at epoch 63


100%|██████████| 3/3 [00:04<00:00,  1.46s/it]


Epoch 64, loss: 28431.796875, avg_last_20_loss: 31765.4607421875
Saving model at epoch 64


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 65, loss: 18922.771484375, avg_last_20_loss: 31053.755989583333
Saving model at epoch 65


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 66, loss: 32690.96875, avg_last_20_loss: 31428.41790364583


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 67, loss: 31536.305989583332, avg_last_20_loss: 31368.21526692708


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 68, loss: 29361.512369791668, avg_last_20_loss: 31138.805305989583


100%|██████████| 3/3 [00:04<00:00,  1.46s/it]


Epoch 69, loss: 31062.215494791668, avg_last_20_loss: 31090.568912760413


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 70, loss: 25123.915364583332, avg_last_20_loss: 30282.13473307291
Saving model at epoch 70


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 71, loss: 27094.880859375, avg_last_20_loss: 30029.640625
Saving model at epoch 71


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 72, loss: 24852.135091145832, avg_last_20_loss: 30014.58930664063
Saving model at epoch 72


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 73, loss: 24777.979166666668, avg_last_20_loss: 29602.961442057287
Saving model at epoch 73


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 74, loss: 22903.2021484375, avg_last_20_loss: 28743.84518229166
Saving model at epoch 74


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 75, loss: 26307.65625, avg_last_20_loss: 28784.365494791662


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 76, loss: 23220.022135416668, avg_last_20_loss: 27947.76171875
Saving model at epoch 76


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 77, loss: 19814.1103515625, avg_last_20_loss: 26923.892887369788
Saving model at epoch 77


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 78, loss: 20932.03515625, avg_last_20_loss: 26311.921533203124
Saving model at epoch 78


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 79, loss: 27430.202473958332, avg_last_20_loss: 26301.199007161456
Saving model at epoch 79


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 80, loss: 23313.182291666668, avg_last_20_loss: 26141.04269205729
Saving model at epoch 80


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 81, loss: 18376.716145833332, avg_last_20_loss: 25526.31049804687
Saving model at epoch 81


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 82, loss: 20420.068359375, avg_last_20_loss: 25002.279410807292
Saving model at epoch 82


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 83, loss: 21493.6875, avg_last_20_loss: 24903.268212890624
Saving model at epoch 83


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 84, loss: 22997.530598958332, avg_last_20_loss: 24631.554899088536
Saving model at epoch 84


100%|██████████| 3/3 [00:04<00:00,  1.48s/it]


Epoch 85, loss: 23742.240885416668, avg_last_20_loss: 24872.528369140626


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 86, loss: 25598.506510416668, avg_last_20_loss: 24517.90525716146
Saving model at epoch 86


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 87, loss: 18796.965494791668, avg_last_20_loss: 23880.938232421875
Saving model at epoch 87


100%|██████████| 3/3 [00:04<00:00,  1.46s/it]


Epoch 88, loss: 25265.657552083332, avg_last_20_loss: 23676.14549153646
Saving model at epoch 88


100%|██████████| 3/3 [00:04<00:00,  1.47s/it]


Epoch 89, loss: 20460.940755208332, avg_last_20_loss: 23146.081754557294
Saving model at epoch 89


100%|██████████| 3/3 [00:04<00:00,  1.50s/it]


Epoch 90, loss: 25443.932942708332, avg_last_20_loss: 23162.08263346354


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 91, loss: 20865.880208333332, avg_last_20_loss: 22850.632600911456
Saving model at epoch 91


100%|██████████| 3/3 [00:04<00:00,  1.51s/it]


Epoch 92, loss: 15025.775716145834, avg_last_20_loss: 22359.314632161455
Saving model at epoch 92


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 93, loss: 18072.0244140625, avg_last_20_loss: 22024.016894531247
Saving model at epoch 93


100%|██████████| 3/3 [00:04<00:00,  1.49s/it]


Epoch 94, loss: 16158.642252604166, avg_last_20_loss: 21686.788899739586
Saving model at epoch 94


 67%|██████▋   | 2/3 [00:02<00:01,  1.49s/it]