In [1]:
import yaml
from pathlib import Path

# Read the experiment hyper-parameters from hparams.yaml into a plain dict.
params_path = Path("hparams.yaml")
hparams = yaml.safe_load(params_path.read_text())

# Quick sanity check of the parsed config.
print(type(hparams))
print(hparams["batch_size"])
print(hparams["features"])

<class 'dict'>
500
['track', 'groundspeed', 'altitude', 'timedelta']


In [2]:
from sklearn.preprocessing import MinMaxScaler
from data.dataset import TrafficDataset

# Build the training dataset. The feature list is taken from hparams.yaml so
# the features are defined in a single place instead of being duplicated here.
# NOTE(review): the source is a .pkl file, presumably unpickled by from_file —
# only load trusted files.
dataset_tcvae = TrafficDataset.from_file(
    "data/traffic_noga_tilFAF_train.pkl",
    features=hparams["features"],
    scaler=MinMaxScaler(feature_range=(-1, 1)),
    shape="image",
    info_params={"features": ["latitude", "longitude"], "index": -1},
)

# Peek at the first sample to check its structure and tensor shapes.
sample = next(iter(dataset_tcvae))

len(sample), sample[0].shape, sample[1].shape

(2, torch.Size([4, 200]), torch.Size([2]))

In [None]:
from torch.utils.data import DataLoader

# Shuffled mini-batches of `batch_size` samples, loaded by 4 worker processes.
loader = DataLoader(
    dataset_tcvae,
    batch_size=hparams["batch_size"],
    shuffle=True,
    num_workers=4,
)

# Only the first batch is needed to inspect tensor shapes.
batch = next(iter(loader))

len(batch), batch[0].size(), batch[1].size()

(2, torch.Size([500, 4, 200]), torch.Size([500, 2]))

In [13]:
from tcn import TemporalBlock

# One temporal block (a dilated Conv1d + dropout, per its repr):
# 4 input channels -> the last hidden width from hparams.
t = TemporalBlock(
    in_channels=4,
    out_channels=hparams["h_dims"][-1],
    dilation=hparams["dilation_base"],
    kernel_size=hparams["kernel_size"],
    dropout=0.2,
    out_activ=None,
)

t

  WeightNorm.apply(module, name, dim)


TemporalBlock(
  (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
  (dropout): Dropout(p=0.2, inplace=False)
)

In [7]:
t(batch[0]).size()  # output shape of the TemporalBlock on the first batch

torch.Size([500, 64, 200])

In [8]:
from tcn import ResidualBlock

# Residual block: two stacked TemporalBlocks plus a 1x1 Conv1d shortcut
# (`downsample`) mapping 4 input channels to the output width.
r = ResidualBlock(
    dropout=0.2,
    h_activ=None,
    dilation=hparams["dilation_base"],
    kernel_size=hparams["kernel_size"],
    out_channels=hparams["h_dims"][-1],
    in_channels=4,
)

r

ResidualBlock(
  (tmp_block1): TemporalBlock(
    (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (tmp_block2): TemporalBlock(
    (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
)

In [9]:
r(batch[0]).size()  # output shape of the ResidualBlock on the first batch

torch.Size([500, 64, 200])

In [10]:
from tcn import TCN

# Stack of ResidualBlocks; the repr shows the dilation growing level by level
# (1, 2, 4, ...) with base `dilation_base`.
tn = TCN(
    h_dims=hparams["h_dims"],
    input_dim=4,
    out_dim=hparams["h_dims"][-1],
    dilation_base=hparams["dilation_base"],
    kernel_size=hparams["kernel_size"],
    dropout=0.2,
    h_activ=None,
)

tn

TCN(
  (network): Sequential(
    (0): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (tmp_block2): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
    )
    (1): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (tmp_block2): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
    )
    (2): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(4,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
  

In [11]:
tn(batch[0]).size()  # output shape of the full TCN on the first batch

torch.Size([500, 64, 200])

In [37]:
from encoder import Encoder

# Encoder wraps a TCN (see `e.tcn` in the repr). All hidden widths except the
# last go to `h_dims`; the last one is the output width.
e = Encoder(
    input_dim=4,
    h_dims=hparams["h_dims"][:-1],
    out_dim=hparams["h_dims"][-1],
    sampling_factor=hparams["sampling_factor"],
    dilation_base=hparams["dilation_base"],
    kernel_size=hparams["kernel_size"],
    dropout=0.2,
    h_activ=None,
)

e

Encoder(
  (tcn): TCN(
    (network): Sequential(
      (0): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (tmp_block2): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
      )
      (1): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (tmp_block2): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (2): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilati

In [26]:
# Before flattening, the encoder output was (500, 64, 20); the Encoder returns
# it flattened to (batch, 64 * 20).
encoder_output = e(batch[0])
encoder_output.size()

torch.Size([500, 1280])

In [24]:
# Flattened encoder output size fed to the LSR: channels of the last hidden
# layer times the sequence length after downsampling by `sampling_factor`.
# Integer division instead of float division + int() truncation.
h_dim = hparams["h_dims"][-1] * (batch[0].shape[-1] // hparams["sampling_factor"])

# The analytic size must match what the encoder actually produced.
assert h_dim == encoder_output.shape[-1], (h_dim, tuple(encoder_output.shape))

h_dim

1280

In [40]:
(hparams["h_dims"][-1], batch[0].shape[-1], hparams["sampling_factor"])  # (last width, seq_len, sampling factor)

(64, 200, 10)

In [30]:
from lsr import NormalLSR

# Latent-space regulariser: maps the flattened encoder output (h_dim) to a
# distribution over `encoding_dim` latent variables (an Independent Normal, per
# the displayed output).
lsr = NormalLSR(
    input_dim=h_dim,
    out_dim=hparams["encoding_dim"],
)

# Run the forward pass once so the displayed distribution and the drawn
# sample come from the same posterior (and the module is not evaluated twice).
posterior = lsr(encoder_output)
posterior, posterior.rsample().shape

(Independent(Normal(loc: torch.Size([500, 64]), scale: torch.Size([500, 64])), 1),
 torch.Size([500, 64]))

In [31]:
# Reparameterised (differentiable) sample from the NormalLSR posterior; fed to the decoder below.
sampled = lsr(encoder_output).rsample()

In [35]:
batch[0].shape[1]  # number of input channels (features) per time step

4

In [42]:
from lsr import VampPriorLSR

# Input geometry (channels, sequence length) comes from the data instead of
# the hard-coded 4 and 200.
original_dim, original_seq_len = batch[0].shape[1:]

# LSR with a VampPrior; it receives the encoder `e`, presumably to encode its
# learned pseudo-inputs.
vamplsr = VampPriorLSR(
    original_dim=original_dim,
    original_seq_len=original_seq_len,
    input_dim=h_dim,
    out_dim=hparams["encoding_dim"],
    encoder=e,
    n_components=hparams["n_components"],  # Number of components in the VampPrior
)

# Single forward pass: display the posterior and the shape of one sample from it.
vamp_posterior = vamplsr(encoder_output)
vamp_posterior, vamp_posterior.rsample().shape

(Independent(Normal(loc: torch.Size([500, 64]), scale: torch.Size([500, 64])), 1),
 torch.Size([500, 64]))

In [44]:
# Reparameterised sample from the VampPrior LSR posterior; decoded further below.
sampled_vamp = vamplsr(encoder_output).rsample()

In [36]:
from decoder import TCDecoder
from torch import nn

# Decoder: hidden widths are the encoder's in reverse order; it maps a latent
# sample back to (channels, seq_len) taken from the batch.
decoder = TCDecoder(
    input_dim=hparams["encoding_dim"],
    out_dim=batch[0].shape[1],
    h_dims=hparams["h_dims"][::-1],
    seq_len=batch[0].shape[-1],
    sampling_factor=hparams["sampling_factor"],
    dilation_base=hparams["dilation_base"],
    kernel_size=hparams["kernel_size"],
    dropout=hparams["dropout"],
    h_activ=nn.ReLU(),
)
d = decoder

reconstructed = d(sampled)
reconstructed.size()

  WeightNorm.apply(module, name, dim)


torch.Size([500, 4, 200])

In [45]:
# Decode the VampPrior sample; should match the input shape (batch, channels, seq_len).
reconstructed_vamp = d(sampled_vamp)
reconstructed_vamp.size()

torch.Size([500, 4, 200])

In [5]:
from tcvae import TCVAE

# End-to-end smoke test of the full model built from hparams only (inputdim /
# seq_length are not passed here, so the class defaults apply).
tcv = TCVAE(hparams=hparams)
rebuilt = tcv(batch[0])

In [10]:
len(rebuilt)  # forward returns several items; the last one appears to be the reconstruction

3

In [11]:
rebuilt[-1].size()  # shape of the last forward output; matches the input batch shape

torch.Size([500, 4, 200])

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

def get_device() -> torch.device:
    """Select the torch device to train on.

    Preference order: Apple MPS, then CUDA, then CPU.

    Returns:
        torch.device: the first available device in that order.
    """
    # `torch.backends.mps` is missing on older PyTorch builds; look it up
    # defensively so those installs fall through instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = get_device()
print("Using device:", device)

Using device: mps


In [2]:
import yaml
from pathlib import Path

# Read the experiment hyper-parameters from hparams.yaml into a plain dict.
params_path = Path("hparams.yaml")
hparams = yaml.safe_load(params_path.read_text())

# Quick sanity check of the parsed config.
print(type(hparams))
print(hparams["batch_size"])
print(hparams["features"])

<class 'dict'>
500
['track', 'groundspeed', 'altitude', 'timedelta']


In [3]:
from sklearn.preprocessing import MinMaxScaler
from data.dataset import TrafficDataset

# Build the training dataset. The feature list is taken from hparams.yaml so
# the features are defined in a single place instead of being duplicated here.
# NOTE(review): the source is a .pkl file, presumably unpickled by from_file —
# only load trusted files.
dataset_tcvae = TrafficDataset.from_file(
    "data/traffic_noga_tilFAF_train.pkl",
    features=hparams["features"],
    scaler=MinMaxScaler(feature_range=(-1, 1)),
    shape="image",
    info_params={"features": ["latitude", "longitude"], "index": -1},
)

# Peek at the first sample to check its structure and tensor shapes.
sample = next(iter(dataset_tcvae))

len(sample), sample[0].shape, sample[1].shape

(2, torch.Size([4, 200]), torch.Size([2]))

In [4]:
# Training loader: shuffled mini-batches, loaded by 4 worker processes.
loader = DataLoader(
    dataset_tcvae,
    batch_size=hparams["batch_size"],
    shuffle=True,
    num_workers=4,
    # Pinned host memory only speeds up host->CUDA copies; PyTorch does not
    # support it for MPS (it warns and ignores it), so enable it for CUDA only.
    pin_memory=(device.type == "cuda"),
    # Keep the workers alive between epochs instead of re-spawning them at the
    # start of every epoch of the long training run.
    persistent_workers=True,
)

In [5]:
from tcvae import TCVAE

# Derive the model's input geometry from the data instead of hard-coding it:
# samples are (channels, seq_len) tensors.
n_features = sample[0].shape[0]
seq_len = dataset_tcvae.seq_len
assert sample[0].shape[-1] == seq_len, (tuple(sample[0].shape), seq_len)

model = TCVAE(hparams=hparams, inputdim=n_features, seq_length=seq_len).to(device)

  WeightNorm.apply(module, name, dim)


In [6]:
# Adam over every model parameter, learning rate taken from hparams.yaml.
opt = torch.optim.Adam(params=model.parameters(), lr=hparams["lr"])

In [7]:
def empty_device_cache(dev=None):
    """Release cached, unused allocator memory on an accelerator.

    Args:
        dev: device (``torch.device`` or string such as ``"cuda"``) whose cache
            should be emptied. Defaults to the notebook-level ``device``.

    Returns:
        None. For device types other than CUDA and MPS (e.g. CPU) this is a no-op.
    """
    if dev is None:
        target = device
    elif isinstance(dev, str):
        target = torch.device(dev)
    else:
        target = dev

    if target.type == "cuda":
        torch.cuda.empty_cache()
        # torch.cuda.ipc_collect()  # optional
    elif target.type == "mps":
        torch.mps.empty_cache()

In [None]:
# Train the TCVAE with Adam by minimising the loss returned by
# `model.training_step`. Logs per-epoch averages of the total loss and of its
# KL and log-likelihood terms.
N_EPOCHS = 500  # number of full passes over `loader`


def _as_float(value):
    """Return a loss term as a plain Python float, detaching it first if it is a tensor."""
    return value.detach().item() if torch.is_tensor(value) else float(value)


model.train()

for epoch in range(N_EPOCHS):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{N_EPOCHS}", leave=False)
    running_loss = 0.0
    running_kld_loss = 0.0
    running_llv_loss = 0.0
    n_batches = 0

    for batch_idx, (x, y) in enumerate(pbar):
        # Only x is moved to the device; y is unused by the model, but
        # training_step expects the (x, y) tuple structure.
        x = x.to(device, non_blocking=True)
        batch = (x, y)

        opt.zero_grad(set_to_none=True)

        # training_step returns (elbo, kld_loss, llv_loss); elbo is the scalar
        # that gets back-propagated.
        elbo, kld_loss, llv_loss = model.training_step(batch, batch_idx)

        # Stop before a NaN/inf loss is back-propagated into the weights.
        if not torch.isfinite(elbo):
            raise FloatingPointError(
                f"Non-finite loss at epoch {epoch+1}, batch {batch_idx}: {elbo.item()}"
            )

        elbo.backward()
        opt.step()

        # Accumulate plain floats. If kld_loss / llv_loss are graph-attached
        # tensors, summing them directly would keep every batch's autograd
        # graph alive until the end of the epoch.
        n_batches += 1
        running_loss += elbo.item()
        running_kld_loss += _as_float(kld_loss)
        running_llv_loss += _as_float(llv_loss)
        pbar.set_postfix({"train_loss": f"{running_loss / n_batches:.4f}"})

    # Free cached accelerator memory between epochs (no-op on CPU).
    empty_device_cache()

    denom = max(1, n_batches)  # guard against an empty loader
    avg_loss = running_loss / denom
    avg_kld_loss = running_kld_loss / denom
    avg_llv_loss = running_llv_loss / denom
    print(f"Epoch {epoch+1:03d}: train_loss={avg_loss:.6f}, kld_loss={avg_kld_loss:.6f}, llv_loss={avg_llv_loss:.6f}")

Epoch 1/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 001: train_loss=773.400796, kld_loss=1.897652, llv_loss=771.503113


Epoch 2/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 002: train_loss=733.959558, kld_loss=4.774056, llv_loss=729.185486


Epoch 3/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 003: train_loss=704.441956, kld_loss=5.848205, llv_loss=698.593628


Epoch 4/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 004: train_loss=675.307537, kld_loss=6.671826, llv_loss=668.635742


Epoch 5/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 005: train_loss=645.748217, kld_loss=6.597014, llv_loss=639.151184


Epoch 6/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 006: train_loss=615.792234, kld_loss=6.791328, llv_loss=609.000916


Epoch 7/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 007: train_loss=585.346608, kld_loss=7.121823, llv_loss=578.224792


Epoch 8/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 008: train_loss=553.389550, kld_loss=7.491145, llv_loss=545.898438


Epoch 9/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 009: train_loss=519.693260, kld_loss=7.833515, llv_loss=511.859711


Epoch 10/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 010: train_loss=483.825415, kld_loss=8.285417, llv_loss=475.540039


Epoch 11/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 011: train_loss=445.838542, kld_loss=8.715195, llv_loss=437.123352


Epoch 12/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 012: train_loss=405.414779, kld_loss=9.192568, llv_loss=396.222168


Epoch 13/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 013: train_loss=362.082567, kld_loss=9.801116, llv_loss=352.281464


Epoch 14/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 014: train_loss=314.776851, kld_loss=10.415311, llv_loss=304.361572


Epoch 15/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 015: train_loss=263.677580, kld_loss=11.196975, llv_loss=252.480591


Epoch 16/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 016: train_loss=207.185485, kld_loss=12.033170, llv_loss=195.152328


Epoch 17/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 017: train_loss=145.424917, kld_loss=13.035581, llv_loss=132.389343


Epoch 18/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 018: train_loss=74.188703, kld_loss=14.372531, llv_loss=59.816174


Epoch 19/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 019: train_loss=-5.488967, kld_loss=16.010706, llv_loss=-21.499674


Epoch 20/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 020: train_loss=-99.579615, kld_loss=18.279213, llv_loss=-117.858818


Epoch 21/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 021: train_loss=-207.592559, kld_loss=21.392752, llv_loss=-228.985321


Epoch 22/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 022: train_loss=-345.635713, kld_loss=25.385279, llv_loss=-371.020996


Epoch 23/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 023: train_loss=-527.074342, kld_loss=32.035297, llv_loss=-559.109619


Epoch 24/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 024: train_loss=-700.010289, kld_loss=43.343433, llv_loss=-743.353577


Epoch 25/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 025: train_loss=-909.606203, kld_loss=66.138588, llv_loss=-975.744934


Epoch 26/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 026: train_loss=-1075.720149, kld_loss=82.221680, llv_loss=-1157.941895


Epoch 27/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 027: train_loss=-1122.824958, kld_loss=90.264732, llv_loss=-1213.089722


Epoch 28/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 028: train_loss=-1184.938875, kld_loss=96.117226, llv_loss=-1281.056030


Epoch 29/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 029: train_loss=-1277.170384, kld_loss=105.808754, llv_loss=-1382.979126


Epoch 30/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 030: train_loss=-1314.317522, kld_loss=110.711601, llv_loss=-1425.029297


Epoch 31/500:   0%|          | 0/28 [00:00<?, ?it/s]