In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../..')

In [3]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

'cuda'

In [4]:
from src.infra import config

# Load the experiment options from YAML and record on the options object
# which file they were read from.
path = '../../config/vec.yaml'
opt = config.load_config(path)
setattr(opt, 'path', path)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Dataset & dataloader

In [7]:
import os
from pathlib import Path

# Root directory of the FAUST_r dataset. The default is a machine-specific
# absolute path; set the FAUST_R_ROOT environment variable to point the
# notebook elsewhere without editing this cell (`~` is expanded).
DEFAULT_FAUST_R_ROOT = '/home/knpob/Documents/Hinton/data/shape-corr/FAUST_r'
data_root = Path(os.environ.get('FAUST_R_ROOT', DEFAULT_FAUST_R_ROOT)).expanduser()

In [13]:
from src.dataset.shape_cor_fast import PairFaustDatasetFast

# Training split of FAUST_r served as shape pairs (each item has a 'first'
# and a 'second' shape). Keyword arguments are grouped by purpose; the
# per-flag meanings are presumably defined by PairFaustDatasetFast.
dataset = PairFaustDatasetFast(
    data_root=data_root,
    phase='train',
    num_evecs=200,
    # mesh connectivity and correspondences
    return_faces=True,
    return_corr=True,
    return_dist=False,
    # spectral / differential operators (L, mass, eigenvectors, gradients)
    return_L=True,
    return_mass=True,
    return_evecs=True,
    return_grad=True,
)
len(dataset)

6400

In [124]:
from src.dataloader.shape_cor_batch import BatchShapePairDataLoader

# shuffle=False keeps the batch order fixed, so the batch picked out below
# is the same on every run; num_workers=0 keeps loading in this process.
dataloader = BatchShapePairDataLoader(
    dataset,
    num_workers=0,
    shuffle=False,
    batch_size=8,
)
len(dataloader)

800

In [125]:
from src.utils.tensor import to_device

# Pick one batch to inspect. Batches before `batch_index` are only iterated
# past, never moved to `device`; only the selected batch is transferred.
batch_index = 1
for idx, data in enumerate(dataloader):
    if idx == batch_index:
        break
else:
    # Without this guard a short loader would silently leave `data` bound to
    # its last batch (or unbound if the loader is empty).
    raise IndexError(f'dataloader yields fewer than {batch_index + 1} batches')
data = to_device(data, device)

data['first']['name'], data['second']['name']

(['tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000',
  'tr_reg_000'],
 ['tr_reg_008',
  'tr_reg_009',
  'tr_reg_010',
  'tr_reg_011',
  'tr_reg_012',
  'tr_reg_013',
  'tr_reg_014',
  'tr_reg_015'])

In [126]:
def corr2pointmap_vectorized(corr_x, corr_y, num_verts_y):
    """
    Convert a pair of template correspondences to a batched point-to-point map.

    Both shapes are in correspondence with a common template: template vertex ``t``
    matches vertex ``corr_x[b, t]`` of shape x and vertex ``corr_y[b, t]`` of
    shape y (row ``t`` of ``corr_y`` corresponds to row ``t`` of ``corr_x``)::

        template -(corr_x)-> shape x <--> shape y <-(corr_y)- template

    Composing the two yields a map from shape-y vertices to shape-x vertices.

    Args:
        corr_x (torch.Tensor): Integer tensor [B, V_t] of shape-x vertex indices,
            one per template vertex (V_t = number of template vertices).
        corr_y (torch.Tensor): Integer tensor [B, V_t] of shape-y vertex indices,
            same shape as ``corr_x``. Valid entries lie in ``[0, num_verts_y)``;
            negative entries are treated as "no correspondence" and skipped.
        num_verts_y (int or 0-dim integer tensor): Width of the output, i.e. the
            (padded) number of vertices of shape y. May be larger or smaller
            than V_t.

    Returns:
        p2p (torch.Tensor): Long tensor [B, num_verts_y] on ``corr_y.device``
        (shape y -> shape x). ``p2p[b, i]`` is the shape-x vertex matched to
        shape-y vertex ``i``, or -1 if no template vertex lands on ``i``
        (this includes padded vertices).

    Note:
        If several template vertices land on the same shape-y vertex, only one of
        their shape-x targets is kept; PyTorch does not specify which one for
        non-accumulating indexed assignment with duplicate indices.
    """
    B, _ = corr_x.shape  # unpacking also enforces 2-D input
    V_y = int(num_verts_y)  # accepts a Python int or a 0-dim tensor such as max(batch['num_verts'])
    dev = corr_y.device

    # Allocate the output at its final width directly: corr_y holds shape-y
    # vertex ids, so a [B, V_t]-wide buffer would be indexed out of range
    # whenever V_y > V_t.
    p2p = torch.full((B, V_y), -1, dtype=torch.long, device=dev)

    valid = corr_y >= 0  # negative ids would otherwise wrap around to the last vertices
    batch_idx = torch.arange(B, device=dev).unsqueeze(1).expand_as(corr_y)  # [B, V_t]
    p2p[batch_idx[valid], corr_y[valid]] = corr_x.to(device=dev)[valid].long()
    return p2p

In [127]:
# Show the correspondence tensor of the second shape of every pair in the
# batch (one row per pair); it is passed as corr_y to corr2pointmap_vectorized.
data['second']['corr']

tensor([[4466, 1211, 2037,  ..., 3712,  182, 2703],
        [4414, 1138, 1930,  ..., 3604,  141, 2499],
        [4420, 1137, 1903,  ..., 3627, 2400, 2591],
        ...,
        [4402, 1169, 1929,  ..., 3664,  160, 2617],
        [4306, 1089, 1845,  ..., 3604,  146, 2490],
        [4433, 1149, 1941,  ..., 3656,  180, 2601]], device='cuda:0')

In [128]:
# Show the correspondence tensor of the first shape of every pair in the
# batch (one row per pair); it is passed as corr_x to corr2pointmap_vectorized.
data['first']['corr']

tensor([[4414, 1129, 1929,  ..., 3617,  173, 2482],
        [4414, 1129, 1929,  ..., 3617,  173, 2482],
        [4414, 1129, 1929,  ..., 3617,  173, 2482],
        ...,
        [4414, 1129, 1929,  ..., 3617,  173, 2482],
        [4414, 1129, 1929,  ..., 3617,  173, 2482],
        [4414, 1129, 1929,  ..., 3617,  173, 2482]], device='cuda:0')

In [129]:
# Uses the corr2pointmap_vectorized defined in this notebook instead of
# importing it from src.utils.fmap.
corr_first = data['first']['corr']
corr_second = data['second']['corr']
num_verts_second = max(data['second']['num_verts'])  # padded width of shape y across the batch

p2p = corr2pointmap_vectorized(corr_first, corr_second, num_verts_second)
p2p

tensor([[   0,   16,    2,  ...,   -1,   -1,   -1],
        [   0, 2383,   96,  ...,   -1,   -1,   -1],
        [   0,   28,    2,  ..., 1949,   -1,   -1],
        ...,
        [  71,   16,    2,  ...,   -1,   -1,   -1],
        [   0,    1,  118,  ..., 1949,   -1,   -1],
        [  70,   28,  118,  ...,   -1,   -1,   -1]], device='cuda:0')

In [None]:
import torch

# Dense point-to-point matrix built from p2p: Pyx[b, i, j] = 1 when vertex i
# of shape y is mapped to vertex j of shape x; rows whose p2p entry is -1
# (unmatched or padded y-vertices) stay all-zero.
# NOTE: float32 and dense, so it takes B * V_y * V_x * 4 bytes.
B, V_y, V_x = len(p2p), max(data['second']['num_verts']), max(data['first']['num_verts'])
V_y, V_x = int(V_y), int(V_x)

Pyx = torch.zeros(B, V_y, V_x, device=p2p.device)
batch_idx, y_idx = (p2p >= 0).nonzero(as_tuple=True)
Pyx[batch_idx, y_idx, p2p[batch_idx, y_idx]] = 1.0

Pyx

tensor([[[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

In [158]:
# Sum over dim 1 of Pyx: assuming Pyx is [B, V_y, V_x] with one-hot rows,
# each entry counts how many y-vertices map to that x-vertex
# (0 = not hit, >1 = many-to-one).
Pyx.sum(dim=1)

tensor([[1., 1., 1.,  ..., 1., 1., 2.],
        [1., 1., 1.,  ..., 0., 1., 2.],
        [1., 1., 1.,  ..., 1., 1., 2.],
        ...,
        [0., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 0., 1., 1.],
        [1., 0., 0.,  ..., 0., 1., 1.]], device='cuda:0')