In [1]:
import yaml
from pathlib import Path

# Load the experiment hyperparameters from hparams.yaml.
# yaml.safe_load only constructs plain Python objects (dicts, lists, scalars).
params_path = Path("hparams.yaml")
with params_path.open("r") as f:
    hparams = yaml.safe_load(f)

# Keys the cells below rely on. Fail fast with a clear message instead of a
# TypeError/KeyError deep inside a model-construction cell.
REQUIRED_HPARAMS = {
    "batch_size",
    "features",
    "h_dims",
    "kernel_size",
    "dilation_base",
    "sampling_factor",
    "encoding_dim",
    "dropout",
}
assert isinstance(hparams, dict), (
    f"{params_path} must contain a YAML mapping, got {type(hparams).__name__}"
)
missing_hparams = sorted(REQUIRED_HPARAMS - set(hparams))
assert not missing_hparams, f"{params_path} is missing keys: {missing_hparams}"

print(type(hparams))
print(hparams["batch_size"])
print(hparams["features"])

<class 'dict'>
500
['track', 'groundspeed', 'altitude', 'timedelta']


In [2]:
from sklearn.preprocessing import MinMaxScaler
from data.dataset import TrafficDataset

# Build the training dataset. The feature list comes from hparams so the data
# channels stay in sync with the config instead of being duplicated here.
# Features are scaled to [-1, 1] by MinMaxScaler.
dataset_tcvae = TrafficDataset.from_file(
    "data/traffic_noga_tilFAF_train.pkl",
    features=hparams["features"],
    scaler=MinMaxScaler(feature_range=(-1, 1)),
    shape="image",
    info_params={"features": ["latitude", "longitude"], "index": -1},
)

# Peek at the first sample: a (features, info) pair. The info part presumably
# holds the lat/lon at index -1 requested via info_params — confirm in TrafficDataset.
sample = next(iter(dataset_tcvae))
len(sample), sample[0].shape, sample[1].shape

(2, torch.Size([4, 200]), torch.Size([2]))

In [5]:
from torch.utils.data import DataLoader

import torch

# Seed the shuffle so the batch inspected below is identical on every
# Restart & Run All (an unseeded shuffle picks different samples each run).
SHUFFLE_SEED = 42

loader = DataLoader(
    dataset_tcvae,
    batch_size=hparams["batch_size"],
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Pull a single batch; it is reused to shape-check each model component below.
batch = next(iter(loader))
len(batch), batch[0].shape, batch[1].shape

(2, torch.Size([500, 4, 200]), torch.Size([500, 2]))

In [6]:
from tcn import TemporalBlock

# A single dilated temporal conv block mapping the input feature channels to the
# last hidden width. Input channels are read from the batch's channel axis
# (Conv1d layout: batch, channels, length) instead of a hard-coded 4.
t = TemporalBlock(
    in_channels=batch[0].shape[1],
    out_channels=hparams["h_dims"][-1],
    kernel_size=hparams["kernel_size"],
    dilation=hparams["dilation_base"],
    out_activ=None,
    # NOTE(review): dropout is hard-coded here, while the decoder cell uses
    # hparams["dropout"] — confirm this is intentional.
    dropout=0.2,
)

t

  WeightNorm.apply(module, name, dim)


TemporalBlock(
  (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
  (dropout): Dropout(p=0.2, inplace=False)
)

In [7]:
# Forward the batch through the block; same-length output shows the padding preserves the sequence length
t(batch[0]).size()

torch.Size([500, 64, 200])

In [8]:
from tcn import ResidualBlock

# Residual block: two TemporalBlocks plus a 1x1 Conv1d "downsample" branch
# (present in the repr because in/out channel counts differ — presumably the
# skip-connection projection; confirm in tcn.ResidualBlock).
# Input channels are read from the batch instead of a hard-coded 4.
r = ResidualBlock(
    in_channels=batch[0].shape[1],
    out_channels=hparams["h_dims"][-1],
    kernel_size=hparams["kernel_size"],
    dilation=hparams["dilation_base"],
    h_activ=None,
    # NOTE(review): hard-coded dropout; the decoder cell uses hparams["dropout"].
    dropout=0.2,
)

r

ResidualBlock(
  (tmp_block1): TemporalBlock(
    (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (tmp_block2): TemporalBlock(
    (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
)

In [9]:
# Forward the batch through the residual block and report the output dimensions
r(batch[0]).size()

torch.Size([500, 64, 200])

In [10]:
from tcn import TCN

# Full TCN: a Sequential stack of ResidualBlocks whose dilation grows with depth
# (the printed repr shows 1, 2, 4, ...). Input channels are read from the
# batch's channel axis instead of a hard-coded 4.
tn = TCN(
    input_dim=batch[0].shape[1],
    out_dim=hparams["h_dims"][-1],
    h_dims=hparams["h_dims"],
    kernel_size=hparams["kernel_size"],
    dilation_base=hparams["dilation_base"],
    h_activ=None,
    # NOTE(review): hard-coded dropout; the decoder cell uses hparams["dropout"].
    dropout=0.2,
)

tn

TCN(
  (network): Sequential(
    (0): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (tmp_block2): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
    )
    (1): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (tmp_block2): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
    )
    (2): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(4,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
  

In [11]:
# Forward the batch through the whole TCN stack and report the output dimensions
tn(batch[0]).size()

torch.Size([500, 64, 200])

In [12]:
from encoder import Encoder

# Encoder wrapping a TCN (see repr) with the configured sampling_factor; the
# cell below shows its output is flattened to 2-D. Input channels are read from
# the batch's channel axis instead of a hard-coded 4.
e = Encoder(
    input_dim=batch[0].shape[1],
    out_dim=hparams["h_dims"][-1],
    h_dims=hparams["h_dims"],
    kernel_size=hparams["kernel_size"],
    dilation_base=hparams["dilation_base"],
    sampling_factor=hparams["sampling_factor"],
    h_activ=None,
    # NOTE(review): hard-coded dropout; the decoder cell uses hparams["dropout"].
    dropout=0.2,
)

e

Encoder(
  (tcn): TCN(
    (network): Sequential(
      (0): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (tmp_block2): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
      )
      (1): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (tmp_block2): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (2): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilati

In [26]:
# The encoder returns a flattened tensor. Before flattening it was 500x64x20,
# i.e. (batch, h_dims[-1], seq_len / sampling_factor), so 64 * 20 = 1280 per sample.
encoder_output = e(batch[0])
encoder_output.shape

torch.Size([500, 1280])

In [24]:
# Width of the flattened encoder output: last hidden channel count times the
# sequence length after downsampling by sampling_factor. NormalLSR below takes
# this as its input_dim, so it must match what the encoder actually produces.
h_dim = hparams["h_dims"][-1] * int(
    batch[0][0].shape[-1] / hparams["sampling_factor"]
)
# Fail fast if the analytical width disagrees with the real encoder output.
assert h_dim == encoder_output.shape[-1], (h_dim, tuple(encoder_output.shape))
h_dim

1280

In [30]:
from lsr import NormalLSR

# Latent-space layer: maps the flattened encoder output (h_dim wide) to a
# distribution over an encoding_dim-sized latent (repr: Independent(Normal)).
lsr = NormalLSR(
    input_dim=h_dim,
    out_dim=hparams["encoding_dim"],
)

# Run the layer once and reuse the distribution, so the displayed distribution
# and the sampled shape come from the same forward pass (previously called twice).
latent_dist = lsr(encoder_output)
latent_dist, latent_dist.rsample().shape

(Independent(Normal(loc: torch.Size([500, 64]), scale: torch.Size([500, 64])), 1),
 torch.Size([500, 64]))

In [31]:
# Reparameterised draw (rsample) from the latent distribution, fed to the decoder below
latent_dist = lsr(encoder_output)
sampled = latent_dist.rsample()

In [35]:
# Number of feature channels in a single sample (first axis of one input sequence)
batch[0][0].size(0)

4

In [36]:
from decoder import TCDecoder
from torch import nn

# Decoder mirrors the encoder: hidden dims are reversed, and the output channel
# count and sequence length are taken from one input sample (channels, seq_len).
sample_shape = batch[0][0].shape
decoder = TCDecoder(
    input_dim=hparams["encoding_dim"],
    out_dim=sample_shape[0],
    h_dims=hparams["h_dims"][::-1],
    seq_len=sample_shape[-1],
    kernel_size=hparams["kernel_size"],
    dilation_base=hparams["dilation_base"],
    sampling_factor=hparams["sampling_factor"],
    dropout=hparams["dropout"],
    h_activ=nn.ReLU(),  # alternative previously tried: h_activ=None
)
d = decoder

# Decode the latent sample; the shape should match the input batch.
reconstructed = d(sampled)
reconstructed.shape

  WeightNorm.apply(module, name, dim)


torch.Size([500, 4, 200])