In [1]:
import os
from pathlib import Path
import itertools
from enum import Enum
import hashlib
import math
import pickle
import json
import asyncio
import aiohttp
import random
import progressbar

import einops
import einx
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
import torch.nn.utils as utils
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, RandomSampler, SubsetRandomSampler, BatchSampler
import torchvision
from torchvision.io import read_image, ImageReadMode
from torchvision.utils import save_image
from torchinfo import summary
from torchcodec.decoders import VideoDecoder
import lightning as L
import lightning.pytorch as pl
import lightning.pytorch.callbacks as callbacks
import xformers
# from xformers.factory.model_factory import xFormer, xFormerConfig

In [2]:
from src.panoptic_downloader import PanopticDownloader, PanopticScene
from src.panoptic_dataset import PanopticDataset
from src.plenoptic_dataset import PlenopticDataset

In [3]:
torch.__version__

'2.7.0+cu126'

In [4]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device


'cuda'

In [5]:
# d = PanopticDownloader(device='cuda')
# # d._get_scene_names = lambda _: ['161029_hands2']  # TODO
# await d.load(scene_names_file='res/tmp/panoptic_scene_names.txt')

In [6]:
# d.download_views('res/tmp/scenes/', 'hd', 4, 3)

In [7]:
d = PanopticDataset('res/tmp/scenes/')

In [8]:
v = d.__getitem__(1)

In [9]:
v[0]['video'], v[0]['shape']

(<torchcodec.decoders._video_decoder.VideoDecoder at 0x7f714ee9c2f0>,
 [11598, 3, 1080, 1920])

In [10]:
v[0]

{'video': <torchcodec.decoders._video_decoder.VideoDecoder at 0x7f714ee9c2f0>,
 'K': tensor([[1.3961e+03, 0.0000e+00, 9.4912e+02],
         [0.0000e+00, 1.3927e+03, 5.4855e+02],
         [0.0000e+00, 0.0000e+00, 1.0000e+00]]),
 'R': tensor([[-0.9413, -0.0377,  0.3354],
         [ 0.1810,  0.7822,  0.5962],
         [-0.2848,  0.6219, -0.7295]]),
 't': tensor([[ -6.4672],
         [112.8077],
         [359.5437]]),
 'fps': 29.97,
 'shape': [11598, 3, 1080, 1920]}

In [11]:
v, K, R, t = [v[0][i] for i in ['video', 'K', 'R', 't']]
v, K, R, t


(<torchcodec.decoders._video_decoder.VideoDecoder at 0x7f714ee9c2f0>,
 tensor([[1.3961e+03, 0.0000e+00, 9.4912e+02],
         [0.0000e+00, 1.3927e+03, 5.4855e+02],
         [0.0000e+00, 0.0000e+00, 1.0000e+00]]),
 tensor([[-0.9413, -0.0377,  0.3354],
         [ 0.1810,  0.7822,  0.5962],
         [-0.2848,  0.6219, -0.7295]]),
 tensor([[ -6.4672],
         [112.8077],
         [359.5437]]))

In [12]:
d2 = PlenopticDataset('res/tmp/plenoptic/')

In [13]:
d2

<src.plenoptic_dataset.PlenopticDataset at 0x7f714ef52270>

In [14]:
v = d2.__getitem__(0)

In [15]:
v

[{'video': <torchcodec.decoders._video_decoder.VideoDecoder at 0x7f714ee9c4a0>,
  'K': tensor([[1.4585e+03, 0.0000e+00, 1.3520e+03],
          [0.0000e+00, 1.4585e+03, 1.0140e+03],
          [0.0000e+00, 0.0000e+00, 1.0000e+00]]),
  'R': tensor([[-0.0272,  0.8776,  0.4786],
          [ 0.9996,  0.0286,  0.0042],
          [-0.0100,  0.4786, -0.8780]], dtype=torch.float64),
  't': tensor([[ 5.4591],
          [-1.0853],
          [ 0.6145]], dtype=torch.float64),
  'fps': 60,
  'shape': [1200, 3, 2028, 2704]},
 {'video': <torchcodec.decoders._video_decoder.VideoDecoder at 0x7f703f128860>,
  'K': tensor([[1.4585e+03, 0.0000e+00, 1.3520e+03],
          [0.0000e+00, 1.4585e+03, 1.0140e+03],
          [0.0000e+00, 0.0000e+00, 1.0000e+00]]),
  'R': tensor([[-0.0239,  0.9283,  0.3710],
          [ 0.9997,  0.0238,  0.0049],
          [-0.0043,  0.3710, -0.9286]], dtype=torch.float64),
  't': tensor([[ 4.6324],
          [-1.1316],
          [ 0.0874]], dtype=torch.float64),
  'fps': 60,
  'sh

(tirado do TCC)

O raio $\mathbf{r}_{ij}$ que passa pelo píxel $(i, j)$ é dado por:

$$
\begin{align}
  \mathbf{r}_{ij} &= (\mathbf{o}_{ij}, \mathbf{d}_{ij}) \\
  \mathbf{o}_{ij} &= - Q R^{T} \mathbf{t} \\
  \mathbf{d}_{ij} &= \frac{\mathbf{d}'_{ij}}{\lVert \mathbf{d}'_{ij} \rVert} \\
  \mathbf{d}'_{ij} &= Q R^{T} K^{-1} Q^{-1} \mathbf{x}_{ij, cam} \\
  \mathbf{x}'_{cam} &=
  \begin{bmatrix}
    fx + z p_{x} & fy + z p_{y} & z & 1
  \end{bmatrix}^{T}
\end{align}
$$

In [13]:
def new_create_plucker_embeddings():
    # (...B, H, W, C)
    # B = n videos x t frames
    pass


In [None]:
def compute_plucker_embeddings(f, wx, vecs, T):
    # TODO We assume images have even width and height
    # Input shapes: (B,), (B,), (B, 3, 3), (B, 3, 4)
    # vecs contains the right (vecs[b, 0, :]), up (vecs[b, 1, :]), and forward (vecs[b, 2, :]) unit vectors of the camera in the camera frame
    R, t = T[:, :, :3], T[:, :, 3] # Shapes: (B, 3, 3), (B, 3)

    # TODO wrong, res_x and res_y should come from image
    ry, rx = T.shape[-2], T.shape[-1]
    wy = wx * (ry / rx) # Shape (B,)

    # Creating tensors with indices
    i = torch.arange(rx, dtype=torch.float64, device=T.device)
    j = torch.arange(ry, dtype=torch.float64, device=T.device)

    # Computing displacements
    # Shapes: (W,), (H,)
    dx = ((i + 0.5) / rx - 0.5)
    dy = -((j + 0.5) / ry - 0.5)

    dx2 = torch.einsum('b,i->bi', wx, dx) # dx2_bi = wx_b * dx_i
    dy2 = torch.einsum('b,j->bj', wy, dy) # dy2_bj = wy_b * dy_j
    
    # Computing pixel point in camera frame
    v1 = torch.einsum('bi,bc->bic', dx2, vecs[:, 0, :]) # v1_bic = dx2_bi * vr_c
    v2 = torch.einsum('bj,bc->bjc', dy2, vecs[:, 1, :]) # v2_bjc = dy2_bj * vu_c
    v3 = torch.einsum('b,bc->bc', f, vecs[:, 2, :]) # v3_bc = f_b * vf_c

    # q_bijc = v1_bic + v2_bjc + v3_bc
    q = v1[:, :, None, :] + v2[:, None, :, :] + v3[:, None, None, :] # TODO test speed with unsqueeze
    
    p = t[:, :, None, None]
    l = torch.einsum('bijc,bkc->bkji', q, R) # l_bijk = q_bijc * R_bkc # TODO test speed with unsqueeze
    m = torch.cross(p, l, dim=1)
    
    # Plucker ray embeddings
    pl = torch.cat([l, m], dim=1) # Shape: (B, 6, H, W)

    return pl

B = 2
f = torch.tensor(1, device=device).repeat(B)
wx = torch.tensor(8, device=device).repeat(B)
vecs = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], dtype=torch.float64, device=device).repeat(B, 1, 1)
T = torch.tensor([[1, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]], dtype=torch.float64, device=device).repeat(B, 1, 1)
# img = torch.zeros((B, 3, 4, 4), dtype=torch.float64, device=device)
# CreatePluckerRayEmbedding()((f, wx, vecs, T, img))[:, 6:, :, :]
compute_plucker_embeddings(f, wx, vecs, T)


tensor([[[[-3.0000, -1.0000,  1.0000,  3.0000],
          [-3.0000, -1.0000,  1.0000,  3.0000],
          [-3.0000, -1.0000,  1.0000,  3.0000]],

         [[ 2.0000,  2.0000,  2.0000,  2.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [-2.0000, -2.0000, -2.0000, -2.0000]],

         [[-1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000, -1.0000]],

         [[-0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 1.0000,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  1.0000]],

         [[ 2.0000,  2.0000,  2.0000,  2.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [-2.0000, -2.0000, -2.0000, -2.0000]]],


        [[[-3.0000, -1.0000,  1.0000,  3.0000],
          [-3.0000, -1.0000,  1.0000,  3.0000],
          [-3.0000, -1.000

In [None]:
def patchify_flatten_embeddings(embeddings, p):
    # Input shape: (B, C, H, W)
    patches = embeddings.unfold(2, p, p).unfold(3, p, p) # Shape: (B, C, H/p, W/p, p, p)
    patches = patches.permute(0, 2, 3, 4, 5, 1)
    patches = patches.flatten(3, 5).flatten(0, 2) # Shape: (BWH/p^2, Cp^2)
    
    return patches

def reverse_patchify_flatten_embeddings(patches, p, W, H):
    patches = patches.unflatten(0, (-1, H//p, W//p)).unflatten(3, (p, p, -1))
    patches = patches.permute(0, 5, 1, 2, 3, 4)
    embeddings = patches.permute(0, 1, 2, 4, 3, 5).reshape((*patches.shape[0:2], H, W))
    
    return embeddings

x = torch.arange(2*3*16*16).reshape((2, 3, 16, 16))
for i in range(4):
    for j in range(4):
        assert torch.equal(x.unfold(2, 4, 4).unfold(3, 4, 4)[:, :, i, j, :, :], x[:, :, i*4:i*4+4, j*4:j*4+4])

for T in [1, 2, 4, 8, 16]:
    assert torch.equal(x, reverse_patchify_flatten_embeddings(patchify_flatten_embeddings(x, T), T, 16, 16))


Perceptual loss (from [LVSM](https://arxiv.org/pdf/2410.17242) paper that says its from
[GS-LRM](https://arxiv.org/pdf/2404.19702) paper that says its from
[this paper](https://arxiv.org/pdf/1707.09405) which uses a feature reconstruction loss,
introduced in [this paper](https://arxiv.org/pdf/1603.08155))


In [6]:
def perceptual_loss(perceptual_layers, perceptual_params, imgs, target_imgs):
    # 
    # TODO Fix so that the implementation is right (use every layer and use 1 norm instead of 2 and use hyperparameters for each layer)
    x1, x2 = imgs, target_imgs
    losses = []

    for l in perceptual_layers:
        x1, x2 = l(x1), l(x2)
        C, H, W = x1.shape[-3], x1.shape[-2], x1.shape[-1]
        losses.append(torch.norm(x1 - x2, p=2, dim=-1).sum() / (C * H * W))

    print(losses)
    res = (torch.concat(losses, device=imgs.device) *  perceptual_params).sum()
    # res = torch.concat(losses, device=imgs.device).sum()
    return res

perceptual_model = torchvision.models.convnext_tiny(torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT)
perceptual_model
perceptual_model = perceptual_model.features
perceptual_layers = [lambda x: x] + list(perceptual_model)
perceptual_params = torch.ones(8, dtype=torch.float64)
shape = (4, 3, 64, 64)
I = torch.rand(shape)
[
    perceptual_loss(perceptual_layers, perceptual_params, I, I),
    perceptual_loss(perceptual_layers, perceptual_params, torch.zeros(shape), torch.ones(shape)),
    perceptual_loss(perceptual_layers, perceptual_params, torch.ones(shape), torch.zeros(shape)),
    perceptual_loss(perceptual_layers, perceptual_params, torch.zeros(shape), torch.zeros(shape)),
    perceptual_loss(perceptual_layers, perceptual_params, torch.ones(shape), torch.ones(shape)),
]




[tensor(0.), tensor(0., grad_fn=<DivBackward0>), tensor(0., grad_fn=<DivBackward0>), tensor(0., grad_fn=<DivBackward0>), tensor(0.0336, grad_fn=<DivBackward0>), tensor(0.0081, grad_fn=<DivBackward0>), tensor(0.0938, grad_fn=<DivBackward0>), tensor(0.0443, grad_fn=<DivBackward0>), tensor(0.1299, grad_fn=<DivBackward0>)]


TypeError: concat() received an invalid combination of arguments - got (list, device=torch.device), but expected one of:
 * (tuple of Tensors tensors, int dim = 0, *, Tensor out = None)
 * (tuple of Tensors tensors, name dim, *, Tensor out = None)


TODO document these in obsidian

TODO where to send data to device: send it when loading from dataloader

```py
for x_data, y_data in train_dataloader:
    x_data, y_data = x_data.to(device), y_data.to(device)
```

check
pytorch lightning dataloader for device
https://lightning.ai/docs/pytorch/stable/data/datamodule.html
    LightningDataModule.transfer_batch_to_device()
https://github.com/Lightning-AI/pytorch-lightning/issues/3341

TODO how does it compute gradients for a batch? it keeps aggregating the gradient for individual elements of the batch until we use step and zero_grad, then a new batch starts

TODO
- first layer creates batch tensor with multiple source and target images and a second tensor with indices of target images in first tensor
- second layer (plucker embeddings) receives only batch tensor and outputs plucker ray embeddings
- third layer projects embeddings onto linear tokens
- fourth layer is transformer
- then only in fifth layer we use the target images' indices tensor to get resulting images and compare to actual target images

TODO check tensor views


In [1]:
class LVSM(nn.Module):
    def __init__(self, p, d, l, C, N, h):
        # p is patch size, d is latent size, l is number of latent tokens, C is number of channels in each image
        # N is number of encoder/decoder layers, h is number of attention heads
        super().__init__()
        
        self.p = p
        self.d = d
        self.l = l
        self.C = C
        
        self.linear_in = nn.Linear(in_features=(C + 6) * p * p, out_features=d)
        self.target_linear = nn.Linear(in_features=6 * p * p, out_features=d)
        self.layer_out = nn.Sequential([
            nn.Linear(in_features=d, out_features=(C + 6) * p * p),
            nn.Sigmoid(),
        ])
        
        xformer_config = [
            {
                "reversible": False,
                "block_type": "encoder",
                "num_layers": N,
                "dim_model": d,
                "residual_norm_style": "pre",
                # "position_encoding_config": {
                #     "name": "sine",
                #     "seq_len": self.hparams.block_size,
                # },
                "multi_head_config": {
                    "num_heads": h,
                    "residual_dropout": 0.1,
                    "use_rotary_embeddings": False,
                    "attention": {
                        "name": self.hparams.attention,
                        "dropout": 0.1,
                        "causal": False,
                        "seq_len": self.hparams.block_size,
                        "num_rules": self.hparams.n_head,
                    },
                },
                "feedforward_config": {
                    "name": "MLP",
                    "dropout": 0.1,
                    "activation": "gelu",
                    "hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
                },
            }
        ]
        
        config = xFormerConfig(xformer_config)
        config.weight_init = 'small'
        self.model = xFormer.from_config(config)
        
    def forward(self, source_batch, target_batch):
        # TODO We assume p divides W and H
        # TODO Here we compute the plucker embeddings in the whole image before breaking it into patches, but in the paper they break the image first, then compute the patches later (test patching before and after computing plucker ray embeddings)
        # TODO we assume batch size 1, so B will be the number of input images of that batch
        # Shapes: (B,), (B,), (B, 3, 3), (B, 3, 4), (B, C, H, W)
        f, wx, vecs, T, imgs = source_batch
        f2, wx2, vecs2, T2 = target_batch
        W, H = imgs.shape[3], imgs.shape[2]

        # Sources and target plucker rays
        # Shapes: (B, C + 6, H, W), (1, 6, H, W)
        source_pl = compute_plucker_embeddings(f, wx, vecs, T)
        source_pl = torch.cat([imgs, source_pl], dim=1)
        target_pl = compute_plucker_embeddings(f2, wx2, vecs2, T2)
        
        # Creates and flattens patches
        source_flattened_patches = patchify_flatten_embeddings(source_pl, self.p)
        target_flattened_patches = patchify_flatten_embeddings(target_pl, self.p)

        # Linear transformation so that they have same shape
        # TODO test directly sending them flattened instead of using linear. in that case, the target patches would need images, but these would be either zeros or learned latent tokens
        source_tokens = self.linear_in(source_flattened_patches)
        target_tokens = self.target_linear(target_flattened_patches)
        
        # Generates output
        # TODO use inheritance use abstract function here and specialize into enc-dec and dec-only
        output = self.model(torch.cat([source_tokens, target_tokens]))
        
        # Generates target embeddings back from output
        # Source outputs are discarded
        target_tokens_out = output[source_tokens.shape[0]:]
        target_flattened_patches_out = self.layer_out(target_tokens_out)
        target_embs_out = reverse_patchify_flatten_embeddings(target_flattened_patches_out, self.p, W, H)
        
        # TODO test using linear instead of just stripping off the final plucker embeddings from the tokens
        out_imgs = target_embs_out[:, :-6, :, :]
        return out_imgs

NameError: name 'L' is not defined

In [14]:
a = nn.Linear(in_features=10, out_features=5)

a(torch.randn(14, 7, 10)).shape


torch.Size([14, 7, 5])

In [None]:
class Model(L.LightningModule):
    def __init__(self, perceptual_params):
        super().__init__()
        # perceptual_params is the parameters for the weighted perceptual loss of the rendered images
        
        self.lvsm = LVSM() #TODO

        # Perceptual loss layers
        # self.perceptual_layers = torchvision.models.vgg19(torchvision.models.VGG19_Weights.DEFAULT).features
        perceptual_model = torchvision.models.convnext_tiny(torchvision.models.ConvNeXt_Tiny_Weights.DEFAULT)
        perceptual_model.eval()
        self.perceptual_layers = perceptual_model.features
        self.perceptual_layers = [lambda x: x] + list(self.perceptual_layers) # This makes the first layer the identity (this layer is used to compute the MSE loss)
        self.perceptual_params = torch.tensor(perceptual_params)

    def step(self, batch):
        f, wx, vecs, T, imgs = batch

        # For each batch, we choose the last image as the target image
        # TODO create multiple targets instead of a single one
        # TODO to do this, probably we would pass the source tokens and copy the transformer state at that point and create target tokens for each image from there
        source_batch = f[:-1], wx[:-1], vecs[:-1], T[:-1], imgs[:-1]
        target_batch = f[-1:], wx[-1:], vecs[-1:], T[-1:]
        target_imgs = imgs[-1:]
        
        gen_imgs = self.lvsm(source_batch, target_batch)


In [None]:
lvsm = LVSM(16, 20, )

In [146]:
np.load('poses_bounds.npy').shape

(18, 17)

In [None]:
class DatasetStandardizer(nn.Module):
    # axis_permutation shape (3, 3) converts permuted/negative axes to <TODO standard>
    # intrinsic_01_range whether the intrinsics map to the 0-1 range
    # time_scaling = 1/FPS
    def __init__(self, axis_permutation, intrinsic_01_range, time_scaling):
        super().__init__()

    def forward(self, I, E, K, t):
        


In [None]:
class PoseEncoder(nn.Module):
    # intrinsic_01_range = if true, intrinsics map coords into points in sensor in range 0-1, otherwise, maps to range 0-W or 0-H
    def __init__(self, p, n_oct, intrinsic_01_range=True):
        super().__init__()
        self.p = p
        self.n_oct = n_oct
        self.intrinsic_01_range = intrinsic_01_range

    # I = images, HW = tuple with height and width
    # Set both if image has been resized, specifying original image height and width in HW
    # We assume images are already resized (always resize them maintaining aspect ratio)
    def forward(self, E, K, t, I = None, HW = None):
        assert (I == None) ^ (HW == None), 'Either I or HW or both should be set'
        
        #TODO corrige hw
        #TODO tem que retornar quanto de padding teve pra tirar o padding na comparacao da loss function
        #TODO na verdade no lugar de retornar o padding ja retorna a visao prevista com padding retirado no modelo final

        HW = I.shape[-2:] if HW == None else HW
        R, T = E[:, :3, :3], E[:, :3, 3]
        Kinv = K.inverse()
        
        # Pads the input so that it is divisible by 'p'
        pad = [((self.p - i) % self.p) for i in HW]
        pad_s = [i // 2 for i in pad]
        pad = (pad_s[1], pad[1] - pad_s[1], pad_s[0], pad[0] - pad_s[0])
        I = F.pad(I, pad, 'constant', 0)

        # Creates vectors for each pixel in screen
        # No need to unflip y axis since it being flipped does not affect the topological structure of the representation
        ranges = [torch.arange(l, dtype=torch.float64) - o + 0.5 for o, l in zip(pad_s, I.shape[-2:])]
        vecs = torch.meshgrid(*ranges, indexing='ij')
        if self.intrinsic_01_range:
            vecs = [v / l for v, l in zip(vecs, HW)]
        vecs = torch.concat([torch.stack([*vecs[::-1]]), torch.ones((1, *vecs[0].shape))], dim=-3)
        vecs = einops.repeat(vecs, 'c h w -> b c h w', b=I.shape[-4])

        # Computes view rays
        o = -einops.einsum(R, T, 'b h w, b h -> b w')  # -R^T t
        o = einops.repeat(o, 'b c -> b c h w', h=I.shape[-2], w=I.shape[-1]) # repeat o for each vec
        d = einops.einsum(R.to(torch.float64), Kinv.to(torch.float64), vecs, 'b x1 c2, b x1 c, b c h w -> b c2 h w') # R^T K^-1 x_ij,cam
        d = d / einx.sum('b [c] h w -> b 3 h w', d * d).sqrt() # normalize d
        
        
        return d

a = PoseEncoder(p=4, n_oct=6)
a.forward(torch.arange(64).reshape((4, 4, 4)) + 0.0, torch.linalg.inv(torch.arange(9).reshape((3, 3)) + 4.0).unsqueeze(0).repeat((4, 1, 1)), 0, torch.ones((4, 2, 5, 4)))

tensor([[[[0.4748, 0.4749, 0.4750, 0.4751],
          [0.4749, 0.4750, 0.4751, 0.4751],
          [0.4749, 0.4750, 0.4751, 0.4751],
          [0.4749, 0.4750, 0.4751, 0.4751],
          [0.4750, 0.4750, 0.4751, 0.4751],
          [0.4750, 0.4750, 0.4751, 0.4751],
          [0.4750, 0.4750, 0.4751, 0.4751],
          [0.4750, 0.4750, 0.4751, 0.4751]],

         [[0.5719, 0.5719, 0.5719, 0.5719],
          [0.5719, 0.5719, 0.5719, 0.5719],
          [0.5719, 0.5719, 0.5719, 0.5719],
          [0.5719, 0.5719, 0.5719, 0.5719],
          [0.5719, 0.5719, 0.5719, 0.5719],
          [0.5719, 0.5719, 0.5719, 0.5719],
          [0.5719, 0.5719, 0.5719, 0.5719],
          [0.5719, 0.5719, 0.5719, 0.5719]],

         [[0.6690, 0.6688, 0.6688, 0.6687],
          [0.6689, 0.6688, 0.6688, 0.6687],
          [0.6689, 0.6688, 0.6688, 0.6687],
          [0.6689, 0.6688, 0.6688, 0.6687],
          [0.6688, 0.6688, 0.6687, 0.6687],
          [0.6688, 0.6688, 0.6687, 0.6687],
          [0.6688, 0.6688, 0

In [None]:
# TODO change to RawDVST, create DVST that also has CNN to reduce dims and PoseWrapper to add a pose estimator to both
class DVST(nn.Module):
    # not specified: H, W, C, N_{context}
    # n_heads has to divide d_lat
    # p has to divide H and W (padding, cropping and resizing)
    def __init__(self, N_enc, N_dec, n_heads, d_lat, e_ff, n_lat, p, n_oct):
        super().__init__()
        
        assert d_lat % n_heads == 0, "n_heads should divide d_lat"

        self.N_enc = N_enc
        self.N_dec = N_dec
        self.n_heads = n_heads
        self.d_lat = d_lat
        self.e_ff = e_ff
        self.n_lat = n_lat
        self.p = p
        self.n_oct = n_oct
        
        self.pose_encoder = PoseEncoder()

    def forward(self, x):
        
