In [1]:
import yaml
from pathlib import Path

# Load hparams.yaml: the hyper-parameter dict that the later cells index into
# (batch_size, features, h_dims, kernel_size, ...).
params_path = Path("hparams.yaml")
# Pin the encoding so parsing does not depend on the platform's locale default.
with open(params_path, "r", encoding="utf-8") as f:
    hparams = yaml.safe_load(f)

# Quick sanity check of what was loaded.
print(type(hparams))
print(hparams["batch_size"])
print(hparams["features"])

<class 'dict'>
500
['track', 'groundspeed', 'altitude', 'timedelta']


In [2]:
from sklearn.preprocessing import MinMaxScaler
from data.dataset import TrafficDataset

# Build the training dataset. The feature list is read from hparams (it prints
# as ['track', 'groundspeed', 'altitude', 'timedelta'] in the cell above) so
# the dataset and the rest of the configuration share one source of truth
# instead of a hand-copied literal. list(...) passes a copy so the config
# itself cannot be mutated by the dataset.
dataset_tcvae = TrafficDataset.from_file(
    "data/traffic_noga_tilFAF_train.pkl",
    features=list(hparams["features"]),
    scaler=MinMaxScaler(feature_range=(-1, 1)),
    shape="image",
    info_params={"features": ["latitude", "longitude"], "index": -1},
)

# Peek at the first item only; raises immediately if the dataset is empty
# rather than leaving `sample` undefined.
sample = next(iter(dataset_tcvae))

len(sample), sample[0].shape, sample[1].shape

(2, torch.Size([4, 200]), torch.Size([2]))

In [3]:
from torch.utils.data import DataLoader

# Mini-batch loader over the dataset built above; batch size comes from hparams.
loader_options = dict(
    batch_size=hparams["batch_size"],
    shuffle=True,
    num_workers=4,
)
loader = DataLoader(dataset_tcvae, **loader_options)

# Keep only the first batch; the cells below reuse it as example input.
for batch in loader:
    break

len(batch), batch[0].shape, batch[1].shape

(2, torch.Size([500, 4, 200]), torch.Size([500, 2]))

In [4]:
from tcn import TemporalBlock

# Smoke-test a single TemporalBlock. in_channels=4 matches the per-sample
# shape [4, 200] shown earlier; the remaining sizes come from hparams.
temporal_block_kwargs = dict(
    in_channels=4,
    out_channels=hparams["h_dims"][-1],
    kernel_size=hparams["kernel_size"],
    dilation=hparams["dilation_base"],
    out_activ=None,
    dropout=0.2,
)
t = TemporalBlock(**temporal_block_kwargs)

t

  WeightNorm.apply(module, name, dim)


TemporalBlock(
  (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
  (dropout): Dropout(p=0.2, inplace=False)
)

In [5]:
# Forward the example batch through the TemporalBlock and display the output shape.
t(batch[0]).shape

torch.Size([500, 64, 200])

In [6]:
from tcn import ResidualBlock

# Smoke-test a ResidualBlock with the same channel/kernel/dilation settings
# as the TemporalBlock cell above.
residual_block_kwargs = dict(
    in_channels=4,
    out_channels=hparams["h_dims"][-1],
    kernel_size=hparams["kernel_size"],
    dilation=hparams["dilation_base"],
    h_activ=None,
    dropout=0.2,
)
r = ResidualBlock(**residual_block_kwargs)

r

ResidualBlock(
  (tmp_block1): TemporalBlock(
    (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (tmp_block2): TemporalBlock(
    (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
)

In [7]:
# Forward the example batch through the ResidualBlock and display the output shape.
r(batch[0]).shape

torch.Size([500, 64, 200])

In [8]:
from tcn import TCN

# Full TCN stack: every hidden width in hparams["h_dims"] is used, and the
# output width equals the last hidden width.
hidden_dims = hparams["h_dims"]
tn = TCN(
    input_dim=4,
    out_dim=hidden_dims[-1],
    h_dims=hidden_dims,
    kernel_size=hparams["kernel_size"],
    dilation_base=hparams["dilation_base"],
    h_activ=None,
    dropout=0.2,
)

tn

TCN(
  (network): Sequential(
    (0): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (tmp_block2): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
    )
    (1): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (tmp_block2): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
    )
    (2): ResidualBlock(
      (tmp_block1): TemporalBlock(
        (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(4,))
        (dropout): Dropout(p=0.2, inplace=False)
      )
  

In [9]:
# Forward the example batch through the TCN and display the output shape.
tn(batch[0]).shape

torch.Size([500, 64, 200])

In [10]:
from encoder import Encoder

# Encoder under test. Unlike the TCN cell, the last entry of h_dims is split
# off and passed as out_dim, with the remaining entries as hidden widths.
hidden_dims = hparams["h_dims"]
encoder_kwargs = dict(
    input_dim=4,
    out_dim=hidden_dims[-1],
    h_dims=hidden_dims[:-1],
    kernel_size=hparams["kernel_size"],
    dilation_base=hparams["dilation_base"],
    sampling_factor=hparams["sampling_factor"],
    h_activ=None,
    dropout=0.2,
)
e = Encoder(**encoder_kwargs)

e

Encoder(
  (tcn): TCN(
    (network): Sequential(
      (0): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(4, 64, kernel_size=(16,), stride=(1,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (tmp_block2): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (downsample): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
      )
      (1): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (tmp_block2): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilation=(2,))
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
      (2): ResidualBlock(
        (tmp_block1): TemporalBlock(
          (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), dilati

In [11]:
# Before the flatten, the encoder output had shape 500x64x20; flattened, that is
# 64 * 20 = 1280 features per sample.
encoder_output = e(batch[0])
encoder_output.shape

torch.Size([500, 1280])

In [13]:
# Width of the flattened encoder output: last hidden width times the input
# sequence length divided by the sampling factor (64 * 20 = 1280 for the
# values displayed in the next cell).
last_hidden_width = hparams["h_dims"][-1]
input_seq_len = batch[0][0].shape[-1]
reduced_seq_len = int(input_seq_len / hparams["sampling_factor"])
h_dim = last_hidden_width * reduced_seq_len
h_dim

1280

In [14]:
# The three quantities that enter the h_dim computation: last hidden width,
# per-sample sequence length, and the sampling factor.
hparams["h_dims"][-1], batch[0][0].shape[-1], hparams["sampling_factor"]

(64, 200, 10)

In [15]:
from lsr import NormalLSR

# Latent-space head: takes the flattened encoder output (h_dim features) and
# returns a distribution over hparams["encoding_dim"] latent dimensions.
lsr = NormalLSR(input_dim=h_dim, out_dim=hparams["encoding_dim"])

# Two separate forward passes, in the same order as the displayed tuple:
# the distribution object, then the shape of one reparameterised sample.
posterior = lsr(encoder_output)
sample_shape = lsr(encoder_output).rsample().shape
posterior, sample_shape

(Independent(Normal(loc: torch.Size([500, 64]), scale: torch.Size([500, 64])), 1),
 torch.Size([500, 64]))

In [16]:
# Draw one reparameterised latent sample per batch element; fed to the decoder below.
sampled = lsr(encoder_output).rsample()

In [17]:
# First dimension of a single sample from the batch; the decoder cell below
# passes this same expression as out_dim.
batch[0][0].shape[0]

4

In [18]:
from lsr import VampPriorLSR

# VampPrior variant of the latent-space head. It receives the encoder `e`
# itself plus the original input geometry (4 channels x 200 steps).
vamp_kwargs = dict(
    original_dim=4,
    original_seq_len=200,
    input_dim=h_dim,
    out_dim=hparams["encoding_dim"],
    encoder=e,
    n_components=hparams["n_components"],  # number of components in the VampPrior
)
vamplsr = VampPriorLSR(**vamp_kwargs)

# Two separate forward passes, in the same order as the displayed tuple.
vamp_posterior = vamplsr(encoder_output)
vamp_sample_shape = vamplsr(encoder_output).rsample().shape
vamp_posterior, vamp_sample_shape

(Independent(Normal(loc: torch.Size([500, 64]), scale: torch.Size([500, 64])), 1),
 torch.Size([500, 64]))

In [19]:
# Draw one reparameterised latent sample per batch element from the VampPrior head.
sampled_vamp = vamplsr(encoder_output).rsample()

In [20]:
from decoder import TCDecoder
from torch import nn

# Decoder under test: hidden widths are passed in reverse order, and the output
# geometry (channels, sequence length) is read from one sample of the batch.
single_sample_shape = batch[0][0].shape
d = decoder = TCDecoder(
    input_dim=hparams["encoding_dim"],
    out_dim=single_sample_shape[0],
    h_dims=hparams["h_dims"][::-1],
    seq_len=single_sample_shape[-1],
    kernel_size=hparams["kernel_size"],
    dilation_base=hparams["dilation_base"],
    sampling_factor=hparams["sampling_factor"],
    dropout=hparams["dropout"],
    h_activ=nn.ReLU(),
)

# Decode the latent sample drawn from the NormalLSR head.
reconstructed = d(sampled)
reconstructed.shape

torch.Size([500, 4, 200])

In [21]:
# Decode the latent sample drawn from the VampPrior head with the same decoder.
reconstructed_vamp = d(sampled_vamp)
reconstructed_vamp.shape

torch.Size([500, 4, 200])

In [22]:
from tcvae import TCVAE

# End-to-end check: build the full model from hparams and run one forward pass,
# which returns the latent distribution parameters, a latent sample and the reconstruction.
# NOTE(review): the training cell further down constructs TCVAE with extra
# `inputdim` / `seq_length` arguments -- confirm which signature is current.
tcv = TCVAE(hparams=hparams)
lsr_distparams, z, x_hat = tcv(batch[0])

In [23]:
# Shape of the reconstruction returned by the forward pass above.
x_hat.shape

torch.Size([500, 4, 200])

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

def get_device():
    """Return the torch device to use: MPS if available, else CUDA, else CPU."""
    # Probed in priority order; the first backend reporting availability wins.
    candidates = (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    )
    for backend_name, is_available in candidates:
        if is_available():
            return torch.device(backend_name)
    return torch.device("cpu")


device = get_device()
print("Using device:", device)

Using device: mps


In [2]:
import yaml
from pathlib import Path

# Load hparams.yaml: the hyper-parameter dict that the later cells index into
# (batch_size, features, lr, ...).
params_path = Path("hparams.yaml")
# Pin the encoding so parsing does not depend on the platform's locale default.
with open(params_path, "r", encoding="utf-8") as f:
    hparams = yaml.safe_load(f)

# Quick sanity check of what was loaded.
print(type(hparams))
print(hparams["batch_size"])
print(hparams["features"])

<class 'dict'>
500
['track', 'groundspeed', 'altitude', 'timedelta']


In [3]:
from sklearn.preprocessing import MinMaxScaler
from data.dataset import TrafficDataset

# Build the training dataset. The feature list is read from hparams (it prints
# as ['track', 'groundspeed', 'altitude', 'timedelta'] in the cell above) so
# the dataset and the model configuration share one source of truth instead of
# a hand-copied literal. list(...) passes a copy so the config itself cannot
# be mutated by the dataset.
dataset_tcvae = TrafficDataset.from_file(
    "data/traffic_noga_tilFAF_train.pkl",
    features=list(hparams["features"]),
    scaler=MinMaxScaler(feature_range=(-1, 1)),
    shape="image",
    info_params={"features": ["latitude", "longitude"], "index": -1},
)

# Peek at the first item only; raises immediately if the dataset is empty
# rather than leaving `sample` undefined.
sample = next(iter(dataset_tcvae))

len(sample), sample[0].shape, sample[1].shape

(2, torch.Size([4, 200]), torch.Size([2]))

In [4]:
# Training loader.
# - pin_memory: page-locked host buffers only speed up host-to-CUDA copies.
#   The previous `device.type != "cpu"` test also enabled it on MPS (the device
#   this notebook reports), where it brings no benefit and recent PyTorch
#   versions warn that it is unsupported.
# - persistent_workers: the training cell iterates this loader once per epoch
#   for hundreds of epochs; keeping the 4 worker processes alive avoids
#   re-creating them at the start of every epoch.
loader = DataLoader(
    dataset_tcvae,
    batch_size=hparams["batch_size"],
    shuffle=True,
    num_workers=4,
    pin_memory=(device.type == "cuda"),
    persistent_workers=True,
)

In [5]:
from tcvae import TCVAE

# Sequence length is read from the dataset (expected to be 200 for this config).
seq_len = dataset_tcvae.seq_len
model = TCVAE(
    hparams=hparams,
    inputdim=4,
    seq_length=seq_len,
)
# Move the model to the device selected at the top of the notebook.
model = model.to(device)

  WeightNorm.apply(module, name, dim)


In [6]:
# Adam over all model parameters; the learning rate comes from hparams["lr"].
opt = torch.optim.Adam(model.parameters(), lr=hparams["lr"])

In [7]:
def empty_device_cache():
    """Empty the cache of the active accelerator backend.

    Dispatches on the notebook-level ``device``: calls the CUDA or the MPS
    ``empty_cache`` and does nothing for any other device type.
    """
    backend = device.type
    if backend == "cuda":
        torch.cuda.empty_cache()
        return
    if backend == "mps":
        torch.mps.empty_cache()

In [None]:
from tqdm.auto import tqdm

# Number of passes over the training set. Defined once so the loop bound and
# the progress-bar label cannot drift apart.
NUM_EPOCHS = 500

model.train()

for epoch in range(NUM_EPOCHS):
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False)
    # Per-epoch running sums, kept as plain Python floats.
    running_loss = 0.0
    running_kld_loss = 0.0
    running_llv_loss = 0.0
    n_batches = 0

    for batch_idx, batch in enumerate(pbar):
        # Move the inputs to the device. y is unused by the model but is kept
        # so the batch retains the (x, y) tuple structure training_step expects.
        x, y = batch
        x = x.to(device, non_blocking=True)
        batch = (x, y)

        opt.zero_grad(set_to_none=True)

        # The model's own training_step returns the ELBO loss and its two terms.
        elbo, kld_loss, llv_loss = model.training_step(batch, batch_idx)

        # Backprop + optimizer step.
        elbo.backward()
        opt.step()

        # Track progress. float(...) converts the two loss terms to Python
        # numbers before accumulating, matching elbo.item(): summing the raw
        # return values keeps tensor objects (and, if they are still attached
        # to autograd, their graphs) alive for the whole epoch. float() also
        # accepts plain numbers, so this is safe either way.
        n_batches += 1
        running_loss += elbo.item()
        running_kld_loss += float(kld_loss)
        running_llv_loss += float(llv_loss)
        pbar.set_postfix({"train_loss": f"{running_loss / n_batches:.4f}"})

    # Optional: free device cache each epoch.
    empty_device_cache()

    # Epoch summary; max(1, ...) guards against an empty loader.
    avg_loss = running_loss / max(1, n_batches)
    avg_kld_loss = running_kld_loss / max(1, n_batches)
    avg_llv_loss = running_llv_loss / max(1, n_batches)
    print(f"Epoch {epoch+1:03d}: train_loss={avg_loss:.6f}, kld_loss={avg_kld_loss:.6f}, llv_loss={avg_llv_loss:.6f}")



Epoch 1/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 001: train_loss=789.793483, kld_loss=2.213984, llv_loss=787.579590


Epoch 2/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 002: train_loss=767.447486, kld_loss=4.271811, llv_loss=763.175720


Epoch 3/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 003: train_loss=762.608259, kld_loss=4.826683, llv_loss=757.781616


Epoch 4/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 004: train_loss=759.357191, kld_loss=5.284576, llv_loss=754.072571


Epoch 5/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 005: train_loss=756.502947, kld_loss=5.504741, llv_loss=750.998169


Epoch 6/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 006: train_loss=754.807641, kld_loss=5.475673, llv_loss=749.332031


Epoch 7/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 007: train_loss=753.632193, kld_loss=5.358947, llv_loss=748.273315


Epoch 8/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 008: train_loss=753.044231, kld_loss=5.323489, llv_loss=747.720764


Epoch 9/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 009: train_loss=752.439168, kld_loss=5.349605, llv_loss=747.089478


Epoch 10/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 010: train_loss=751.919918, kld_loss=5.355347, llv_loss=746.564636


Epoch 11/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 011: train_loss=751.570836, kld_loss=5.367347, llv_loss=746.203552


Epoch 12/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 012: train_loss=751.225963, kld_loss=5.421280, llv_loss=745.804565


Epoch 13/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 013: train_loss=750.782937, kld_loss=5.419363, llv_loss=745.363464


Epoch 14/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 014: train_loss=750.478186, kld_loss=5.469743, llv_loss=745.008423


Epoch 15/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 015: train_loss=750.113538, kld_loss=5.444224, llv_loss=744.669373


Epoch 16/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 016: train_loss=749.949504, kld_loss=5.437018, llv_loss=744.512512


Epoch 17/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 017: train_loss=749.790944, kld_loss=5.455531, llv_loss=744.335388


Epoch 18/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 018: train_loss=749.708533, kld_loss=5.414989, llv_loss=744.293457


Epoch 19/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 019: train_loss=749.468416, kld_loss=5.460172, llv_loss=744.008301


Epoch 20/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 020: train_loss=749.384103, kld_loss=5.436135, llv_loss=743.948059


Epoch 21/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 021: train_loss=749.257710, kld_loss=5.448231, llv_loss=743.809509


Epoch 22/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 022: train_loss=749.227910, kld_loss=5.434752, llv_loss=743.793152


Epoch 23/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 023: train_loss=749.180814, kld_loss=5.479869, llv_loss=743.700745


Epoch 24/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 024: train_loss=749.053990, kld_loss=5.455201, llv_loss=743.598816


Epoch 25/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 025: train_loss=748.947739, kld_loss=5.458354, llv_loss=743.489380


Epoch 26/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 026: train_loss=748.903229, kld_loss=5.479332, llv_loss=743.423828


Epoch 27/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 027: train_loss=748.811066, kld_loss=5.471159, llv_loss=743.339844


Epoch 28/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 028: train_loss=748.737248, kld_loss=5.451046, llv_loss=743.286133


Epoch 29/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 029: train_loss=748.750229, kld_loss=5.474953, llv_loss=743.275269


Epoch 30/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 030: train_loss=748.688725, kld_loss=5.485872, llv_loss=743.202820


Epoch 31/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 031: train_loss=748.610807, kld_loss=5.486404, llv_loss=743.124390


Epoch 32/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 032: train_loss=748.584630, kld_loss=5.441709, llv_loss=743.142944


Epoch 33/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 033: train_loss=748.557273, kld_loss=5.463021, llv_loss=743.094238


Epoch 34/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 034: train_loss=748.559915, kld_loss=5.492224, llv_loss=743.067749


Epoch 35/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 035: train_loss=748.444007, kld_loss=5.469955, llv_loss=742.974060


Epoch 36/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 036: train_loss=748.452595, kld_loss=5.487327, llv_loss=742.965332


Epoch 37/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 037: train_loss=748.435604, kld_loss=5.461471, llv_loss=742.974182


Epoch 38/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 038: train_loss=748.417779, kld_loss=5.451975, llv_loss=742.965698


Epoch 39/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 039: train_loss=748.339724, kld_loss=5.458650, llv_loss=742.881042


Epoch 40/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 040: train_loss=748.308910, kld_loss=5.461600, llv_loss=742.847229


Epoch 41/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 041: train_loss=748.267829, kld_loss=5.451199, llv_loss=742.816589


Epoch 42/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 042: train_loss=748.305073, kld_loss=5.495147, llv_loss=742.809998


Epoch 43/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 043: train_loss=748.258214, kld_loss=5.469005, llv_loss=742.789185


Epoch 44/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 044: train_loss=748.224625, kld_loss=5.470803, llv_loss=742.753784


Epoch 45/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 045: train_loss=748.207535, kld_loss=5.498902, llv_loss=742.708557


Epoch 46/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 47/500:   0%|          | 0/28 [01:20<?, ?it/s]

Epoch 046: train_loss=748.179367, kld_loss=5.466260, llv_loss=742.713074


IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out


Epoch 047: train_loss=748.179753, kld_loss=5.480611, llv_loss=742.699219


Epoch 48/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 048: train_loss=748.164540, kld_loss=5.469508, llv_loss=742.694946


Epoch 49/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 049: train_loss=748.142953, kld_loss=5.521706, llv_loss=742.621216


Epoch 50/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 050: train_loss=748.153013, kld_loss=5.512982, llv_loss=742.640015


Epoch 51/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 051: train_loss=748.123389, kld_loss=5.498786, llv_loss=742.624634


Epoch 52/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 052: train_loss=748.097933, kld_loss=5.483574, llv_loss=742.614319


Epoch 53/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 053: train_loss=748.066341, kld_loss=5.487374, llv_loss=742.579102


Epoch 54/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 054: train_loss=748.077645, kld_loss=5.475915, llv_loss=742.601685


Epoch 55/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 055: train_loss=748.123252, kld_loss=5.467258, llv_loss=742.656067


Epoch 56/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 056: train_loss=748.081979, kld_loss=5.498965, llv_loss=742.583008


Epoch 57/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 057: train_loss=748.032469, kld_loss=5.512986, llv_loss=742.519531


Epoch 58/500:   0%|          | 0/28 [00:00<?, ?it/s]

Epoch 058: train_loss=748.099418, kld_loss=5.466068, llv_loss=742.633240


Epoch 59/500:   0%|          | 0/28 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x10e67f7e0>
Traceback (most recent call last):
  File "/Users/meldor/Desktop/vae-paper/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/Users/meldor/Desktop/vae-paper/.venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1628, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/meldor/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/meldor/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/meldor/.local/share/uv/python/cpython-3.12.11-macos-aarch64-none/lib/

In [50]:
import torch
from torch.distributions import Normal, Independent

# Two 3-dimensional mean vectors, one per row: all zeros for sample 1 and all
# ones for sample 2. The standard deviation is 1 everywhere.
sample_1_mean = [0.0] * 3
sample_2_mean = [1.0] * 3
loc = torch.tensor([sample_1_mean, sample_2_mean])
scale = torch.ones_like(loc)

In [51]:
# Display the mean and standard-deviation tensors defined above.
loc, scale

(tensor([[0., 0., 0.],
         [1., 1., 1.]]),
 tensor([[1., 1., 1.],
         [1., 1., 1.]]))

In [52]:
# Two latent vectors to evaluate, one per row.
z = torch.tensor([
    [0.0, 0.0, 0.0],
    [1.0, 2.0, 3.0],
])
# A plain Normal scores every coordinate separately, so log_prob returns one
# value per element (same 2x3 layout as z).
base = Normal(loc=loc, scale=scale)
logp = base.log_prob(z)
print(logp)

tensor([[-0.9189, -0.9189, -0.9189],
        [-0.9189, -1.4189, -2.9189]])


In [53]:
# Summing the per-coordinate log-densities over the last dimension gives one
# joint log-density per row.
logp.sum(dim=-1)   # shape [2]

tensor([-2.7568, -5.2568])

In [48]:
# Independent(..., 1) reinterprets the last batch dimension as part of the
# event, so log_prob returns one value per row -- the same numbers as summing
# the element-wise log-densities over the last dimension.
per_coordinate_normal = Normal(loc, scale)
diag_gauss = Independent(per_coordinate_normal, reinterpreted_batch_ndims=1)
logp_indep = diag_gauss.log_prob(z)
print(logp_indep)

tensor([-2.7568, -5.2568])


In [49]:
# Draw one reparameterised sample per row. No RNG seed is set in the visible
# cells, so the displayed values will differ between runs.
diag_gauss.rsample()

tensor([[ 0.6864,  0.4128, -0.5349],
        [ 1.6668,  1.8176,  1.6142]])