In [19]:
#TODO lazy hf dataset

# https://huggingface.co/docs/datasets/en/about_mapstyle_vs_iterable
#  search for yield
#  use this to lazily load the videos (in each iteration, download the next video and return the current one, which has already been downloaded, or download it now if it has not)
# https://huggingface.co/docs/datasets/en/video_load
# https://huggingface.co/docs/datasets/en/video_dataset
#  create video dataset
# https://huggingface.co/docs/datasets/en/about_map_batch
#  use to map transformations (resizing etc)
# https://github.com/iejMac/video2dataset
#  check how to parallelize the yield (and how to build it generically for any dataset of scenes with a list of videos, not just Panoptic)
#  actually I think I can do that just using dataset.map (batched) + yield and dataset.take on a streaming dataset, but you would need ... (note left unfinished)

# make a dataset that creates a uniform distribution of different video sizes/aspect ratios/cropping options
# then evaluate the model in these environments:
#  same size/aspect/cropping on entire dataset
#  same size/aspect/cropping for videos in a scene but varying for all scenes
#  varying size/aspect/cropping for all videos in all scenes


### Importing stuff

In [20]:
import os
from pathlib import Path
import itertools
from enum import Enum
import hashlib
import math
import pickle
import json
import asyncio
import aiohttp
import random
import progressbar

from matplotlib import pyplot as plt
import open3d as o3d
from open3d.visualization import draw_plotly
from mpl_toolkits.mplot3d import Axes3D

import einops
import einx
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
import torch.nn.utils as utils
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, RandomSampler, SubsetRandomSampler, BatchSampler
import torchvision
from torchvision.io import read_image, ImageReadMode
from torchvision.utils import save_image
from torchinfo import summary
from torchcodec.decoders import VideoDecoder
import lightning as L
import lightning.pytorch as pl
import lightning.pytorch.callbacks as callbacks
import xformers
# from xformers.factory.model_factory import xFormer, xFormerConfig

In [21]:
from src.panoptic_dataset import PanopticDataset
from src.plenoptic_dataset import PlenopticDataset

from src.draw import get_camera_geometry

In [22]:
# Display the installed PyTorch version, to help reproduce the outputs in this notebook
torch.__version__

'2.7.0+cu126'

In [23]:
# Pick the best available accelerator: CUDA first, then Apple's MPS backend, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device


'cuda'

### Loading datasets

Panoptic dataset

In [24]:
# Panoptic dataset read from res/tmp/panoptic/ (relative to the working directory).
# Judging from the output of the next cell, each item is indexable and item[0] is a dict
# with 'video', 'K', 'Kinv', 'R', 't', 'time' and 'shape' entries.
dataset_panoptic = PanopticDataset('res/tmp/panoptic/')

In [25]:
# First item of the Panoptic dataset. Use subscripting rather than calling the dunder directly.
# Per the output below, v[0] is a dict with 'video', 'K', 'Kinv', 'R', 't', 'time' and 'shape'
# (presumably one such dict per camera; confirm against PanopticDataset).
v = dataset_panoptic[0]
v[0]

{'video': <torchcodec.decoders._video_decoder.VideoDecoder at 0x7fa810e62e40>,
 'K': tensor([[1.4107e+03, 0.0000e+00, 9.6000e+02],
         [0.0000e+00, 1.3299e+03, 5.4000e+02],
         [0.0000e+00, 0.0000e+00, 1.0000e+00]]),
 'Kinv': tensor([[ 7.0888e-04,  0.0000e+00, -6.8053e-01],
         [ 0.0000e+00,  7.5194e-04, -4.0605e-01],
         [ 0.0000e+00,  0.0000e+00,  1.0000e+00]]),
 'R': tensor([[[-0.6212, -0.0284,  0.7832],
          [ 0.0751,  0.9926,  0.0955],
          [-0.7801,  0.1182, -0.6144]]]),
 't': tensor([[-15.3971, 117.3840, 288.2436]]),
 'time': tensor([0.0000e+00, 3.3367e-02, 6.6733e-02,  ..., 2.0254e+02, 2.0257e+02,
         2.0260e+02]),
 'shape': [6073, 3, 1080, 1920]}

In [26]:
# Pull the decoder and camera parameters out of the first entry. The right-hand side is
# fully evaluated (using the old `v`) before `v` is rebound to the video decoder.
v, K, R, t2 = v[0]['video'], v[0]['K'], v[0]['R'], v[0]['t']
v, K, R, t2


(<torchcodec.decoders._video_decoder.VideoDecoder at 0x7fa810e62e40>,
 tensor([[1.4107e+03, 0.0000e+00, 9.6000e+02],
         [0.0000e+00, 1.3299e+03, 5.4000e+02],
         [0.0000e+00, 0.0000e+00, 1.0000e+00]]),
 tensor([[[-0.6212, -0.0284,  0.7832],
          [ 0.0751,  0.9926,  0.0955],
          [-0.7801,  0.1182, -0.6144]]]),
 tensor([[-15.3971, 117.3840, 288.2436]]))

Plenoptic dataset

In [27]:
# Plenoptic dataset read from res/tmp/plenoptic/ (relative to the working directory).
# Per the output of the next cell, its items have the same dict layout as the Panoptic ones,
# but R and t come back as float64 there.
dataset_plenoptic = PlenopticDataset('res/tmp/plenoptic/')

In [28]:
# First item of the Plenoptic dataset. Use subscripting rather than calling the dunder directly.
# Per the output below, v[0] has the same keys as the Panoptic samples.
v = dataset_plenoptic[0]
v[0]

{'video': <torchcodec.decoders._video_decoder.VideoDecoder at 0x7fa810e62f00>,
 'K': tensor([[1.4585e+03, 0.0000e+00, 1.3520e+03],
         [0.0000e+00, 1.4585e+03, 1.0140e+03],
         [0.0000e+00, 0.0000e+00, 1.0000e+00]]),
 'Kinv': tensor([[ 6.8564e-04,  0.0000e+00, -9.2698e-01],
         [ 0.0000e+00,  6.8564e-04, -6.9523e-01],
         [ 0.0000e+00,  0.0000e+00,  1.0000e+00]]),
 'R': tensor([[[-0.0272,  0.8776,  0.4786],
          [ 0.9996,  0.0286,  0.0042],
          [-0.0100,  0.4786, -0.8780]]], dtype=torch.float64),
 't': tensor([[ 5.4591, -1.0853,  0.6145]], dtype=torch.float64),
 'time': tensor([0.0000e+00, 3.3333e-02, 6.6667e-02,  ..., 3.9900e+01, 3.9933e+01,
         3.9967e+01]),
 'shape': [1200, 3, 2028, 2704]}

### Pose encoder

Auxiliary functions

In [29]:
def compute_pad(hw, p):
    """Work out the padding that makes a (height, width) pair divisible by ``p``.

    Nothing is padded here; only the amounts are computed.

    Args:
        hw: two ints, (height, width).
        p: the divisor (e.g. the patch size).

    Returns:
        hw_padded: list [height, width] rounded up to the next multiple of ``p``.
        pad: (left, right, top, bottom), the order ``F.pad`` uses for the last two dims.
            When a total is odd, the extra pixel goes to the end (right / bottom).
    """
    height, width = hw
    extra_h = -height % p
    extra_w = -width % p

    top = extra_h // 2
    left = extra_w // 2
    pad = (left, extra_w - left, top, extra_h - top)

    return [height + extra_h, width + extra_w], pad

compute_pad([5, 4], 4)


([8, 4], (0, 0, 1, 2))

In [30]:
def compute_view_rays(vecs, Kinv, R, t):
    """Compute per-pixel view rays: origin ``o`` and unit direction ``d``.

    Uses the convention implied by o = -R^T t (i.e. x_cam = R x_world + t): the ray through
    homogeneous pixel x has direction R^T K^-1 x.

    Args:
        vecs: (3, H, W) pixel grid; the first axis is (x, y, 1).
        Kinv: inverse intrinsics, either shared (3, 3) or per sample (..., 3, 3).
        R: (..., 3, 3) rotations.
        t: (..., 3) translations.

    Returns:
        o, d: (..., 3, H, W) float64 tensors. ``d`` has unit norm along the channel axis;
        ``o`` is the camera centre broadcast (as an expanded view) to every pixel.
    """
    # TODO check without double precision
    vecs, Kinv, R, t = [x.to(torch.float64) for x in (vecs, Kinv, R, t)]
    h, w = vecs.shape[-2:]
    R_T = R.transpose(-1, -2)

    # Camera centre o = -R^T t, repeated for every pixel.
    o = -(R_T @ t.unsqueeze(-1)).squeeze(-1)
    o = o[..., None, None].expand(*o.shape, h, w)

    # d = R^T K^-1 x for every pixel x. Matmul broadcasting accepts both a shared (3, 3)
    # Kinv and a per-sample (..., 3, 3) Kinv.
    d = (R_T @ (Kinv @ vecs.flatten(-2))).unflatten(-1, (h, w))
    # Normalise over the channel axis, independent of the number of leading dims.
    d = d / torch.linalg.vector_norm(d, dim=-3, keepdim=True)

    return o, d

def compute_plucker_rays(o, d):
    """Stack ray directions and moments into 6-channel Plücker coordinates (d, o x d).

    Args:
        o, d: (B, 3, H, W) ray origins and directions; channels are on axis -3.

    Returns:
        (B, 6, H, W) tensor: the first 3 channels are ``d``, the last 3 are ``o x d``.
    """
    moment = torch.linalg.cross(o, d, dim=-3)
    return torch.cat((d, moment), dim=-3)


In [31]:
def compute_octaves(v, n_oct, dim=-1):
    """Sinusoidal feature expansion of ``v`` over ``n_oct`` octaves.

    Each element x becomes the 2 * n_oct features
    [sin(pi x), cos(pi x), sin(2 pi x), cos(2 pi x), ..., sin(2^(n_oct-1) pi x), cos(2^(n_oct-1) pi x)].

    Args:
        v: input tensor.
        n_oct: number of octaves; must be >= 1.
        dim: negative index of the axis to expand. Its size C becomes C * 2 * n_oct, and the
            2 * n_oct features of each original channel stay contiguous.

    Returns:
        Tensor shaped like ``v`` except along ``dim``, which is 2 * n_oct times larger.
    """
    assert dim < 0, 'No positive dim allowed'
    # n_oct < 1 would otherwise silently produce one octave (2 features), disagreeing with
    # the 2 * n_oct feature count that callers size their layers with.
    assert n_oct >= 1, 'n_oct must be at least 1'

    base = v * torch.pi
    features = []
    for k in range(n_oct):
        # Multiplying by a power of two is exact, matching repeated doubling.
        scaled = base * (2 ** k)
        features += [torch.sin(scaled), torch.cos(scaled)]

    # Stack the features as a new axis at `dim`, then merge it into the axis before it.
    return torch.stack(features, dim=dim).flatten(dim - 1, dim)

v = torch.zeros((3, 6, 2))
v[0, 0, 0] = 1
compute_octaves(v, n_oct=4, dim=-2)

tensor([[[-8.7423e-08,  0.0000e+00],
         [-1.0000e+00,  1.0000e+00],
         [ 1.7485e-07,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 3.4969e-07,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 6.9938e-07,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
 

Pose encoder

In [32]:
class PoseEncoder(nn.Module):
    """Encode camera rays, timestamps and (optionally) images into per-patch tokens.

    For every pixel, a Plücker ray (d, o x d) is computed from the camera parameters and
    expanded with ``compute_octaves``. The image is concatenated channel-wise; when no image
    is given, the learned ``im_parameter`` patch is used instead. The result is cut into
    p x p patches, the time octaves are appended to every patch, and a linear layer projects
    each patch to ``d_lat``.

    Args:
        d_lat: output token size.
        n_oct: number of octaves for the sinusoidal expansion of rays and time.
        C: number of image channels.
        p: patch size in pixels (height and width).
    """

    def __init__(self, d_lat, n_oct, C, p):
        super().__init__()
        self.d_lat = d_lat
        self.n_oct = n_oct
        self.C = C
        self.p = p

        # Learned (C, p, p) patch used in place of the image when forward() receives no image.
        # TODO initialize w gaussian
        self.im_parameter = nn.Parameter(torch.zeros((self.C, self.p, self.p)))

        # Input per patch: p*p pixels x (12 * n_oct ray features + C image channels),
        # plus 2 * n_oct time features.
        # TODO check without double precision
        self.linear = nn.Linear(
            in_features=(12 * self.n_oct + self.C) * self.p ** 2 + 2 * self.n_oct,
            #in_features=(6 + self.C) * self.p ** 2 + 1, # Without octaves, just for testing
            out_features=d_lat,
            dtype=torch.float64
        )

    def _compute_view_rays(self, Kinv, R, t, pad, hw_padded):
        """Build the pixel grid of the padded image and return its view rays (o, d).

        Split out of forward() so the view-ray stage can be inspected on its own.
        ``pad`` is the (left, right, top, bottom) tuple from compute_pad, and ``hw_padded``
        is the padded (H, W).
        """
        pad_s = pad[-2::-2]  # (top, left): the same (h, w) order as hw_padded

        # Pixel-centre coordinates (+0.5) in the unpadded image's frame, so padding pixels
        # fall outside [0, size). The grid is created on Kinv's device so the encoder also
        # works when the camera tensors live on the GPU.
        # No need to unflip y axis since it being flipped does not affect the topological structure of the representation TODO is it true?
        ranges = [
            torch.arange(l, dtype=torch.float64, device=Kinv.device) - o + 0.5
            for o, l in zip(pad_s, hw_padded)
        ]
        # In the original LVSM impl, the K^{-1} multiplication is done here bc its faster, maybe change the code to do that too (https://github.com/Haian-Jin/LVSM/blob/ebeff4989a3e1ec38fcd51ae24919d0eadf38c8f/utils/data_utils.py#L71-L73)
        # Used torch.ones since it seems to be used by most of the vision models similar to this (e.g. lvsm, see https://github.com/Haian-Jin/LVSM/blob/ebeff4989a3e1ec38fcd51ae24919d0eadf38c8f/utils/data_utils.py#L73)
        ys, xs = torch.meshgrid(*ranges, indexing='ij')
        vecs = torch.stack([xs, ys, torch.ones_like(xs)])  # (3, H, W): homogeneous (x, y, 1)

        o, d = compute_view_rays(vecs, Kinv, R, t) # o, d: (B, 3, H, W)
        return o, d

    def forward(self, Kinv, R, t, time, I=None, hw=None):
        """Encode a batch of views into tokens.

        Exactly one of ``I`` and ``hw`` must be given:
          - ``I``: (B, C, H, W) float images with colours in [0, 1]; H and W are read from it.
          - ``hw``: (H, W) of the view when no image is available; ``im_parameter`` is used instead.

        Images are expected to be resized already (keeping the aspect ratio). H and W do NOT
        need to be divisible by ``p``: the pixel grid (and ``I``) is zero-padded here via
        compute_pad. K is assumed to map to pixel (xy) coordinates in [(0, 0), (h, w)], not to
        normalised uv coordinates in [(0, 0), (1, 1)].

        Args:
            Kinv: inverse intrinsics, passed through to compute_view_rays.
            R: (B, 3, 3) rotations.
            t: (B, 3) translations.
            time: (B,) timestamps.

        Returns:
            (B, H_pad * W_pad / p^2, d_lat) float64 tokens.
        """
        # TODO fix hw
        # TODO return how much padding was added, so it can be removed before comparing in the loss function
        # TODO actually, instead of returning the padding, the final model should directly return the predicted view with the padding removed

        assert (I is None) ^ (hw is None), 'Exactly one of I or hw should be set'

        if I is not None:
            hw = I.shape[-2:]
            I = I * 2 - 1  # [0, 1] -> [-1, 1]

        # Pad so that p divides H and W (zero in the normalised range is mid-grey).
        hw_padded, pad = compute_pad(hw, self.p)
        I = F.pad(I, pad, 'constant', 0) if I is not None else None

        o, d = self._compute_view_rays(Kinv, R, t, pad, hw_padded)
        plucker_rays = compute_plucker_rays(o, d) # (B, 6, H, W)

        # (B, 2 * 6 * n_oct, H, W)
        plucker_octs = compute_octaves(plucker_rays, self.n_oct, dim=-3)

        # Concatenate the image (or the learned patch, broadcast to every patch) with the ray
        # octaves and cut into p x p patches: (B, HW/p^2, (12 * n_oct + C) * p^2)
        if I is None:
            patches = einx.rearrange('... c1 (h p1) (w p2), c2 p1 p2 -> ... (h w) ((c1 + c2) p1 p2)', plucker_octs, self.im_parameter, p1=self.p, p2=self.p)
        else:
            patches = einx.rearrange('... c1 (h p1) (w p2), ... c2 (h p1) (w p2) -> ... (h w) ((c1 + c2) p1 p2)', plucker_octs, I, p1=self.p, p2=self.p)

        time_octs = compute_octaves(time.unsqueeze(-1), self.n_oct, dim=-1) # (B, 2 * n_oct)

        # Append the time features to every patch: (B, HW/p^2, (12 * n_oct + C) * p^2 + 2 * n_oct)
        tokens = einx.rearrange('... hw c1, ... c2 -> ... hw (c1 + c2)', patches, time_octs)
        tokens = self.linear(tokens)

        return tokens

# Smoke test on dummy inputs: a 5x4 image is padded to 8x4, so with p=4 each view yields 2 tokens.
B = 4
C = 2
# Well-conditioned pinhole intrinsics with the principal point at the image centre. The
# earlier dummy matrix (arange(9) + 4) is singular, so its "inverse" was numerical noise.
K = torch.tensor([[100.0, 0.0, 2.0], [0.0, 100.0, 2.5], [0.0, 0.0, 1.0]])
Kinv = torch.linalg.inv(K)
R, t = torch.eye(3).repeat(B, 1, 1), torch.arange(B * 3).reshape((B, 3))  # identity rotations
I = torch.ones((B, C, 5, 4))

pose_encoder = PoseEncoder(d_lat=12, n_oct=6, C=C, p=4)
# Call the module itself (not .forward) so nn.Module hooks run
pose_encoder(Kinv, R, t, torch.arange(B) / 4, I).shape # (4, 2, 12)
#pose_encoder(Kinv, R, t, torch.arange(B) / 4, None, I.shape[-2:])

torch.Size([4, 2, 12])

In [None]:
# TODO change to RawDVST, create DVST that also has CNN to reduce dims and PoseWrapper to add a pose estimator to both
class DVST(nn.Module):
    """Work-in-progress model; so far it only builds its PoseEncoder.

    Not specified by the constructor: H, W, N_{context}.
    Constraints: n_heads has to divide d_lat; p has to divide H and W (via padding, cropping
    or resizing; PoseEncoder zero-pads for this).

    Args:
        N_enc, N_dec, n_heads, d_lat, e_ff, n_lat, p, n_oct: hyper-parameters stored as
            attributes (only d_lat, n_oct and p are used so far, by the PoseEncoder).
        C: number of image channels. It defaults to 3, matching the 3-channel videos the
            datasets above report in 'shape'.
    """

    def __init__(self, N_enc, N_dec, n_heads, d_lat, e_ff, n_lat, p, n_oct, C=3):
        super().__init__()

        assert d_lat % n_heads == 0, "n_heads should divide d_lat"

        self.N_enc = N_enc
        self.N_dec = N_dec
        self.n_heads = n_heads
        self.d_lat = d_lat
        self.e_ff = e_ff
        self.n_lat = n_lat
        self.p = p
        self.n_oct = n_oct
        # Must be set before building the PoseEncoder, which reads it.
        self.C = C

        self.pose_encoder = PoseEncoder(self.d_lat, self.n_oct, self.C, self.p)

    def forward(self, x):
        # Intended inputs, not wired up yet: Kinv, R, t, time, I, hw.
        # Fail explicitly until the model is implemented.
        raise NotImplementedError('DVST.forward is not implemented yet')
