In [1]:
import gc
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import T4sigWGAN as T4

# Seed torch's global RNG (it seeds all devices) so model initialisation and
# anything else drawing from the default generator is reproducible on re-run.
SEED = 42
torch.manual_seed(SEED)

# Trailing semicolon keeps the collected-object count out of the cell output.
gc.collect();

  from .autonotebook import tqdm as notebook_tqdm


10

In [2]:
# Build the windowed dataset and split it into train/validation subsets.
# A dedicated, seeded generator makes the split reproducible regardless of how
# much global RNG state earlier cells have consumed.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9

total_dataset = T4.StockTimeSeriesDataset(T4.args.window_size)
train_size = int(TRAIN_FRACTION * len(total_dataset))  # training portion
val_size = len(total_dataset) - train_size  # remainder goes to validation
train_dataset, val_dataset = random_split(
    total_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# drop_last=True discards a trailing partial batch, so a split smaller than one
# batch would produce an empty loader and silently skip training/validation.
batch_size = T4.args.batch_size
assert train_size >= batch_size, f"train split ({train_size}) is smaller than batch_size ({batch_size})"
assert val_size >= batch_size, f"validation split ({val_size}) is smaller than batch_size ({batch_size})"

# Reshuffle training batches every epoch; validation order does not matter.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
                              drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True)

In [3]:
# Instantiate every T4sigWGAN sub-network on the configured device, assemble
# them into the full model, and wrap it in the fine-tuning trainer.
device = T4.args.device

Encoder = T4.LogSigRNNEncoder(**T4.encoder_config).to(device)
Decoder = T4.TimesFormerDecoder(**T4.decoder_config).to(device)
Supervisor = T4.ModernTCN(T4.supervisor_config).to(device)  # config passed as one positional arg, not unpacked
Generator = T4.LogSigRNNGenerator(**T4.logsig_config).to(device)
Discriminator = T4.tailGANDiscriminator(T4.discriminator_config).to(device)  # likewise positional

# Note the constructor order: Generator comes before Supervisor.
model = T4.T4sigWGAN(
    Encoder,
    Decoder,
    Generator,
    Supervisor,
    Discriminator,
    T4.args.batch_size,
).to(device)
trainer = T4.FinetuneTrainer(T4.args, model, train_dataloader, val_dataloader)

Total Parameters: 22890978


In [4]:
stage = "Pretrain_1"

# Train for T4.args.epochs epochs in this stage, validating after each epoch.
# The latest validation loss is shown on the progress bar instead of being
# computed and discarded.
progress = tqdm(range(T4.args.epochs), desc=stage)
for epoch in progress:
    trainer.train(epoch, stage)
    val_loss = trainer.valid(epoch, stage)
    progress.set_postfix(val_loss=val_loss)

# Save the training-log figure for this stage; the string is presumably the
# plot title/label used by trainer.evaluate — confirm in T4sigWGAN.
trainer.evaluate('RMSE_loss_for_ER')

 10%|█         | 1/10 [00:09<01:24,  9.40s/it]


KeyboardInterrupt: 

In [None]:
# Persist the trainer state for the current stage ("Pretrain_1"); what is
# written (weights, optimizer, logs) is defined by FinetuneTrainer.save.
# NOTE(review): the recorded run of the training cell above ended in a
# KeyboardInterrupt, so saving here would checkpoint a partially trained model.
trainer.save(stage)

In [None]:
stage = "Pretrain_2"

# Train for T4.args.epochs epochs in this stage, validating after each epoch.
# The latest validation loss is shown on the progress bar instead of being
# computed and discarded.
progress = tqdm(range(T4.args.epochs), desc=stage)
for epoch in progress:
    trainer.train(epoch, stage)
    val_loss = trainer.valid(epoch, stage)
    progress.set_postfix(val_loss=val_loss)

# Save the training-log figure for this stage; the string is presumably the
# plot title/label used by trainer.evaluate — confirm in T4sigWGAN.
trainer.evaluate('SigW1_supervisor_loss_for_S')

In [None]:
# Persist the trainer state for the current stage ("Pretrain_2"); what is
# written is defined by FinetuneTrainer.save.
trainer.save(stage)

In [None]:
stage = "Finetune"

# Train for T4.args.epochs epochs in this stage, validating after each epoch.
# The latest validation loss is shown on the progress bar instead of being
# computed and discarded.
progress = tqdm(range(T4.args.epochs), desc=stage)
for epoch in progress:
    trainer.train(epoch, stage)
    val_loss = trainer.valid(epoch, stage)
    progress.set_postfix(val_loss=val_loss)

# Save the training-log figures for this stage; the strings are presumably
# plot titles/labels used by the trainer — confirm in T4sigWGAN.
trainer.evaluate('SigW1_supervisor_loss + RMSE_for_SR')
trainer.evaluate_tail('loss_D and loss_G + SigW1_generator_loss')

In [None]:
# Persist the trainer state for the current stage ("Finetune"); what is
# written is defined by FinetuneTrainer.save.
trainer.save(stage)

In [None]:
N_SAMPLES = 600   # number of generated / real windows to compare
SAMPLE_IDX = 1    # which window to overlay in the detail plot
FEATURE_IDX = 1   # which feature (last axis) to overlay
RESULT_DIR = Path('./result')

# Sampling only needs a forward pass; no_grad avoids building an autograd
# graph for the whole generated batch.
with torch.no_grad():
    x_fake = model(N_SAMPLES)
# Clamp in case the dataset holds fewer windows than requested.
n_real = min(N_SAMPLES, len(total_dataset))
x_real = torch.stack([total_dataset[i] for i in range(n_real)])

# savefig raises FileNotFoundError if the output directory does not exist yet.
RESULT_DIR.mkdir(parents=True, exist_ok=True)
T4.plot_summary(x_fake=x_fake.detach(), x_real=x_real.detach(), trainer="T4sigWGAN", G="LogSigRNN")
plt.savefig(RESULT_DIR / 'T4sigWGAN.png')
plt.close()

# Overlay one generated window and one real window for a single feature.
fake_window = x_fake[SAMPLE_IDX].detach().cpu().numpy()
real_window = total_dataset[SAMPLE_IDX].detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(12, 1.5), dpi=400)
ax.plot(fake_window[:, FEATURE_IDX], label="Fake")
ax.plot(real_window[:, FEATURE_IDX], label="Real")
ax.set_title(f"Sample {SAMPLE_IDX}, feature {FEATURE_IDX}: generated vs. real", fontsize="small")
ax.legend(loc="upper right")
fig.tight_layout()
plt.show()

In [None]:
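# TODO(review): these metrics are disabled and would not run as written —
# `ori_data`, `generated_data` and `np` are not defined anywhere in this
# notebook. Presumably they should be built from x_real / x_fake above
# (and numpy imported in the first cell) before uncommenting.
# metric_iteration = 5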
#
# discriminative_score = list()
# for _ in range(metric_iteration):
#     temp_disc = T4.discriminative_score_metrics(ori_data, generated_data)
#     discriminative_score.append(temp_disc)
#
# print('Discriminative score: ' + str(np.round(np.mean(discriminative_score), 4)))
#
#
# predictive_score = list()
# for tt in range(metric_iteration):
#     temp_pred = T4.predictive_score_metrics(ori_data, generated_data)
#     predictive_score.append(temp_pred)
#
# print('Predictive score: ' + str(np.round(np.mean(predictive_score), 4)))
#
#
# T4.visualization(ori_data, generated_data, 'pca')
# T4.visualization(ori_data, generated_data, 'tsne')