In [107]:
import os
from pathlib import Path
import itertools
from enum import Enum
import hashlib
import math
import pickle

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
import torch.nn.utils as utils
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler, RandomSampler, SubsetRandomSampler, BatchSampler
import torchvision
from torchvision.io import read_image, ImageReadMode
from torchvision.utils import save_image
from torchinfo import summary
import lightning as L
import lightning.pytorch as pl
import lightning.pytorch.callbacks as callbacks
from xformers.factory.model_factory import xFormer, xFormerConfig


In [108]:
# Record the PyTorch build this notebook was executed with (shown as the cell output).
torch.__version__

'2.5.1+cu124'

In [115]:
class CreatePluckerRayEmbedding(nn.Module):
    """Concatenate per-pixel Plücker ray coordinates onto an image batch.

    For every pixel centre a ray direction is built in the camera frame from the
    right/up/forward basis in ``vecs``, rotated by ``R`` and paired with the moment
    ``t x direction``. The result is ``cat([img, direction, moment], dim=1)``.

    ``batch`` is a tuple ``(f, wx, vecs, T, img)``:
        f:    (B,) scale applied to the forward vector, i.e. the distance from the
              camera centre to the image plane when the forward vector is unit length.
        wx:   (B,) width of the image plane (same units as ``f``); the plane height
              is derived as ``wx * H / W`` (square pixels).
        vecs: (B, 3, 3) rows are the camera right, up and forward unit vectors,
              expressed in the camera frame.
        T:    (B, 3, 4) pose ``[R | t]``. ``t`` is used as the ray origin and ``R``
              rotates camera-frame directions, so this presumably expects a
              camera-to-world pose -- TODO confirm against the data source.
        img:  (B, C, H, W) image; any H and W (odd or even) are supported.

    Returns:
        Tensor of shape (B, C + 6, H, W): the image channels, then 3 ray-direction
        channels (NOT normalised), then 3 moment channels. Ray math runs in the
        floating dtype promoted from ``T`` and ``vecs`` (e.g. float32 stays float32);
        ``torch.cat`` then applies normal type promotion against ``img``.
    """

    def forward(self, batch):
        f, wx, vecs, T, img = batch
        device = img.device

        # One floating dtype for all ray math. Hard-coding float64 here used to make
        # the R-contraction below fail for float32 poses (bmm needs matching dtypes).
        dtype = torch.promote_types(T.dtype, vecs.dtype)
        if not dtype.is_floating_point:
            dtype = torch.get_default_dtype()
        # f / wx may arrive as integer tensors; cast so every einsum sees one dtype.
        f = f.to(device=device, dtype=dtype)
        wx = wx.to(device=device, dtype=dtype)
        vecs = vecs.to(device=device, dtype=dtype)
        T = T.to(device=device, dtype=dtype)
        R, t = T[:, :, :3], T[:, :, 3]  # (B, 3, 3), (B, 3)

        H, W = img.shape[-2:]
        wy = wx * (H / W)  # (B,) image-plane height

        # Pixel-centre offsets in [-0.5, 0.5]; y is negated so row 0 is at +up.
        i = torch.arange(W, dtype=dtype, device=device)
        j = torch.arange(H, dtype=dtype, device=device)
        dx = (i + 0.5) / W - 0.5  # (W,)
        dy = -((j + 0.5) / H - 0.5)  # (H,)

        dx2 = torch.einsum('b,i->bi', wx, dx)  # (B, W): dx2[b, i] = wx[b] * dx[i]
        dy2 = torch.einsum('b,j->bj', wy, dy)  # (B, H): dy2[b, j] = wy[b] * dy[j]

        # Pixel point on the image plane, camera frame:
        # q[b, i, j] = dx2[b, i] * right[b] + dy2[b, j] * up[b] + f[b] * forward[b]
        v1 = torch.einsum('bi,bc->bic', dx2, vecs[:, 0, :])  # (B, W, 3)
        v2 = torch.einsum('bj,bc->bjc', dy2, vecs[:, 1, :])  # (B, H, 3)
        v3 = torch.einsum('b,bc->bc', f, vecs[:, 2, :])  # (B, 3)
        q = v1[:, :, None, :] + v2[:, None, :, :] + v3[:, None, None, :]  # (B, W, H, 3)

        # Rotate by R and lay out channel-first:
        # directions[b, k, j, i] = sum_c R[b, k, c] * q[b, i, j, c]
        directions = torch.einsum('bijc,bkc->bkji', q, R)  # (B, 3, H, W)
        origins = t[:, :, None, None].expand_as(directions)  # (B, 3, H, W)
        moments = torch.cross(origins, directions, dim=1)  # (B, 3, H, W)

        return torch.cat([img, directions, moments], dim=1)


In [117]:
# Sanity check on a tiny 4x4 image: identity rotation, camera at (1, 0, 0) looking down -z.
# Ray directions are (x, y, -1) with x in [-3, -1, 1, 3] across columns and
# y in [3, 1, -1, -3] down rows, so the moments (1, 0, 0) x direction = (0, 1, y).
emb = CreatePluckerRayEmbedding()

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
B = 2
f = torch.tensor(1.0, dtype=torch.float64, device=device).repeat(B)
wx = torch.tensor(8.0, dtype=torch.float64, device=device).repeat(B)
vecs = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], dtype=torch.float64, device=device).repeat(B, 1, 1)
T = torch.tensor([[1, 0, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0]], dtype=torch.float64, device=device).repeat(B, 1, 1)
img = torch.zeros((B, 3, 4, 4), dtype=torch.float64, device=device)
rays = emb((f, wx, vecs, T, img))

expected_moments = torch.zeros((B, 3, 4, 4), dtype=torch.float64, device=device)
expected_moments[:, 1] = 1.0
expected_moments[:, 2] = torch.tensor([3., 1., -1., -3.], dtype=torch.float64, device=device)[:, None]
torch.testing.assert_close(rays[:, 6:], expected_moments)

rays[:, 6:, :, :]


tensor([[[[-0., -0., -0., -0.],
          [-0., -0., -0., -0.],
          [ 0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[ 3.,  3.,  3.,  3.],
          [ 1.,  1.,  1.,  1.],
          [-1., -1., -1., -1.],
          [-3., -3., -3., -3.]]],


        [[[-0., -0., -0., -0.],
          [-0., -0., -0., -0.],
          [ 0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[ 3.,  3.,  3.,  3.],
          [ 1.,  1.,  1.,  1.],
          [-1., -1., -1., -1.],
          [-3., -3., -3., -3.]]]], device='cuda:0', dtype=torch.float64)