In [1]:
from deep_traffic_generation.tcvae import TCVAE
from deep_traffic_generation.core.datasets import TrafficDataset
from deep_traffic_generation.SecondStageVAE import VAE
from deep_traffic_generation.core.utils import get_dataloaders

from sklearn.preprocessing import MinMaxScaler

import numpy as np

import torch
import pytorch_lightning as pl

import warnings
# Silence every warning raised from here on (keeps the training logs short).
warnings.filterwarnings(action="ignore")

# Fix the global RNGs up front so stochastic steps further down
# (e.g. the q.rsample() latent draw) are reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load First Stage

In [2]:
# Training trajectories for the first stage. Each trajectory is described by the
# four features below, scaled with a MinMaxScaler targeting [-1, 1] and shaped as
# an "image". info_params additionally requests latitude/longitude at index -1
# (presumably the last point of each trajectory — TODO confirm in TrafficDataset).
TRAIN_DATA_PATH = "../../deep_traffic_generation/data/traffic_noga_tilFAF_train.pkl"
TRAJECTORY_FEATURES = ["track", "groundspeed", "altitude", "timedelta"]

dataset = TrafficDataset.from_file(
    TRAIN_DATA_PATH,
    features=TRAJECTORY_FEATURES,
    scaler=MinMaxScaler(feature_range=(-1, 1)),
    shape="image",
    info_params=dict(features=["latitude", "longitude"], index=-1),
)
dataset

Dataset TrafficDataset
    Number of datapoints: 14000
    MinMaxScaler(feature_range=(-1, 1))

In [3]:
from os import walk
from pathlib import Path

# Trained first-stage TCVAE run whose encoder produces the latents for stage two.
FIRST_STAGE_RUN = Path("../../deep_traffic_generation/lightning_logs/tcvae/version_14")
checkpoint_dir = FIRST_STAGE_RUN / "checkpoints"

# os.walk lists files in arbitrary (filesystem-dependent) order, so sort to make the
# chosen checkpoint reproducible, and fail with a clear message (instead of a bare
# IndexError) when the run has no checkpoint at all.
filenames = sorted(next(walk(checkpoint_dir), (None, None, []))[2])
if not filenames:
    raise FileNotFoundError(f"No checkpoint file found in {checkpoint_dir}")
if len(filenames) > 1:
    # NOTE(review): several candidates — make the selection explicit if this matters.
    print(f"{len(filenames)} files in {checkpoint_dir}; loading {filenames[0]!r}")

FirstStage = TCVAE.load_from_checkpoint(
    str(checkpoint_dir / filenames[0]),
    hparams_file=str(FIRST_STAGE_RUN / "hparams.yaml"),
    dataset_params=dataset.parameters,
)
FirstStage.eval()

TCVAE(
  (encoder): Sequential(
    (0): TCN(
      (network): Sequential(
        (0): ResidualBlock(
          (tmp_block1): TemporalBlock(
            (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,))
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (tmp_block2): TemporalBlock(
            (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,))
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
        )
        (1): ResidualBlock(
          (tmp_block1): TemporalBlock(
            (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (tmp_block2): TemporalBlock(
            (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (2): ResidualBlock(
          (tmp_block1): TemporalBlock(


In [4]:
# Latent Space

# Encode the full training set with the frozen first stage and draw a latent sample
# z ~ q; these samples are the training data of the second stage.
# torch.no_grad(): this is pure inference, so don't build an autograd graph over
# every datapoint (the previous version built it and then discarded it via .detach()).
# Same values and same RNG consumption as before, much lower memory.
with torch.no_grad():
    h = FirstStage.encoder(dataset.data)
    q = FirstStage.lsr(h)
    z = q.rsample()
input_SecondStage = z.cpu()

# Rescale each latent dimension to [-1, 1], the range of the second-stage decoder's
# Tanh output activation. `scaler` is kept as a global, presumably to map generated
# latents back with inverse_transform later — TODO confirm.
# NOTE(review): the scaler is fitted on all datapoints *before* the train/val split,
# so validation statistics leak into the scaling.
scaler = MinMaxScaler(feature_range=(-1, 1))
input_SecondStage = torch.Tensor(scaler.fit_transform(input_SecondStage))

# NOTE(review): train_ratio + val_ratio == 1.0, so the test loader is presumably
# empty — confirm against get_dataloaders.
SecondStage_train_loader, SecondStage_val_loader, SecondStage_test_loader = get_dataloaders(
    input_SecondStage,
    0.8,  # train_ratio
    0.2,  # val_ratio
    200,  # batch_size
    200,  # test_batch_size
)

# Second Stage Training

In [5]:
# Re-seed right before building the model. This makes the second-stage weight init
# independent of how much torch randomness earlier cells consumed (e.g. q.rsample()).
torch.manual_seed(42)

# The second-stage VAE models the rescaled first-stage latents, so its input width
# must match their dimensionality. It is derived from the data rather than
# hard-coded to 256, so it stays correct if the first stage's latent size changes.
SecondStage = VAE(
    input_dim=input_SecondStage.shape[1],
    latent_dim=256,
    h_dims=[1024, 1024, 1024],
)

# CPU-only run (gpus=0) for 1000 epochs over the training loader.
# NOTE(review): SecondStage_val_loader is built in the previous cell but never
# passed here; pass it as val_dataloaders once VAE defines a validation_step.
trainer = pl.Trainer(gpus=0, max_epochs=1000, progress_bar_refresh_rate=1)
trainer.fit(SecondStage, SecondStage_train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name      | Type   | Params
-------------------------------------
0 | encoder   | FCN    | 2.4 M 
1 | decoder   | FCN    | 2.6 M 
2 | fc_mu     | Linear | 262 K 
3 | fc_var    | Linear | 262 K 
4 | out_activ | Tanh   | 0     
-------------------------------------
5.5 M     Trainable params
0         Non-trainable params
5.5 M     Total params
22.089    Total estimated model params size (MB)


Training: -1it [00:00, ?it/s]