In [1]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torchvision.io import read_image, ImageReadMode
from pathlib import Path
import itertools


Dataset requirements:

- Should be in the form of a folder with subfolders, where each subfolder is a set of pictures from a scene;
- For each picture in a subfolder, there should be a `.npy` file with the same name as the picture file that contains a $3 \times 4$ matrix with the camera extrinsics;
- All subfolders must have the same number of pictures;
- All pictures must have same size, same number of channels, and equal width and height;

TODO add support for varying number of images in each scene

TODO
- add support for varying image sizes
- in that case, we would resize images respecting aspect ratio so that they would fill in the model width and height
- then the model would not need to predict the tokens for the non-existing regions (the regions of the smallest size, that would have no content), or it would need to predict that they are unknown data
- how to deal with unknown data?
- test resizing by interpolation, by just getting the color at that specific point, or interpolating the 4 points close to it


TODO create preprocessing scripts to resize all images (test resizing average, or choosing color of middle point)


In [None]:
path = 'res/tmp/000-000'
file_paths = [os.path.join(r, f) for r, _, fs in os.walk(path) for f in fs]
file_paths.sort()
img_paths, mat_paths = list(filter(lambda x: x.endswith('.png'), file_paths)), list(filter(lambda x: x.endswith('.npy'), file_paths))
display(len(file_paths), file_paths[:3])
display(len(img_paths), img_paths[:3])
display(len(mat_paths), mat_paths[:3])
assert all((os.path.splitext(x)[0] == os.path.splitext(y)[0] for x, y in zip(img_paths, mat_paths)))


120000

['000-000\\000074a334c541878360457c672b6c2e\\012.npy',
 '000-000\\000074a334c541878360457c672b6c2e\\012.png',
 '000-000\\000074a334c541878360457c672b6c2e\\013.npy']

60000

['000-000\\000074a334c541878360457c672b6c2e\\012.png',
 '000-000\\000074a334c541878360457c672b6c2e\\013.png',
 '000-000\\000074a334c541878360457c672b6c2e\\014.png']

60000

['000-000\\000074a334c541878360457c672b6c2e\\012.npy',
 '000-000\\000074a334c541878360457c672b6c2e\\013.npy',
 '000-000\\000074a334c541878360457c672b6c2e\\014.npy']

In [88]:
parents = [str(Path(i).parent) for i in img_paths]
paths = list(zip(parents, img_paths, mat_paths))
paths[:3]


[('000-000\\000074a334c541878360457c672b6c2e',
  '000-000\\000074a334c541878360457c672b6c2e\\012.png',
  '000-000\\000074a334c541878360457c672b6c2e\\012.npy'),
 ('000-000\\000074a334c541878360457c672b6c2e',
  '000-000\\000074a334c541878360457c672b6c2e\\013.png',
  '000-000\\000074a334c541878360457c672b6c2e\\013.npy'),
 ('000-000\\000074a334c541878360457c672b6c2e',
  '000-000\\000074a334c541878360457c672b6c2e\\014.png',
  '000-000\\000074a334c541878360457c672b6c2e\\014.npy')]

In [89]:
data = list(list(j) for _, j in itertools.groupby(paths, lambda i: i[0]))
data[0][:3]


[('000-000\\000074a334c541878360457c672b6c2e',
  '000-000\\000074a334c541878360457c672b6c2e\\012.png',
  '000-000\\000074a334c541878360457c672b6c2e\\012.npy'),
 ('000-000\\000074a334c541878360457c672b6c2e',
  '000-000\\000074a334c541878360457c672b6c2e\\013.png',
  '000-000\\000074a334c541878360457c672b6c2e\\013.npy'),
 ('000-000\\000074a334c541878360457c672b6c2e',
  '000-000\\000074a334c541878360457c672b6c2e\\014.png',
  '000-000\\000074a334c541878360457c672b6c2e\\014.npy')]

In [90]:
def invert_transform(T):
    R, t = T[:3, :3], T[:3, 3:]
    R2 = R.T
    t2 = - R2 @ t
    return np.concatenate([np.concatenate([R2, t2], axis=1), np.array([[0, 0, 0, 1]])], axis=0)

# Testing function, should return identity matrix
T = np.load(mat_paths[0])
T2 = invert_transform(T) @ np.concatenate([T, np.array([[0, 0, 0, 1]])], axis=0)
T, T2


(array([[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
          0.00000000e+00],
        [ 0.00000000e+00, -1.34358856e-07,  1.00000012e+00,
         -2.68717713e-07],
        [ 0.00000000e+00, -1.00000012e+00, -1.34358856e-07,
         -2.00000024e+00]]),
 array([[1.        , 0.        , 0.        , 0.        ],
        [0.        , 1.00000024, 0.        , 0.        ],
        [0.        , 0.        , 1.00000024, 0.        ],
        [0.        , 0.        , 0.        , 1.        ]]))

In [91]:
def path_depth(path):
    return len(Path(path).parents)

# Testing function
path_depth('a/b/c'), path_depth('/a/b/c/')


(3, 3)

In [92]:
class NVSDataset(Dataset):
    # Intrinsics are (f, wx)
    def __init__(self, path, intrinsics):
        file_paths = [os.path.join(r, f) for r, _, fs in os.walk(path) for f in fs]
        img_paths, mat_paths = self._filter_extension(file_paths, 'png'), self._filter_extension(file_paths, 'npy')
        assert len(img_paths) == len(mat_paths) and all((os.path.splitext(x)[0] == os.path.splitext(y)[0] for x, y in zip(img_paths, mat_paths))), 'There are unmatched pairs of images and numpy extrinsics arrays'

        img_paths, mat_paths = self._process_paths(img_paths), self._process_paths(mat_paths)

        self.path = path
        self.intrinsics = intrinsics
        self.img_paths = img_paths
        self.mat_paths = mat_paths
    
    @staticmethod
    def _filter_extension(paths, ext):
        res = list(filter(lambda x: x.endswith(f'.{ext}'), paths))
        res.sort()
        return res

    @staticmethod
    def _process_paths(paths):
        return list(list(j) for _, j in itertools.groupby(paths, lambda i: str(Path(i).parent)))

    def __len__(self):
        return len(self.img_paths)
    
    def get_scene(self, i):
        img_paths = self.img_paths[i]
        mat_paths = self.mat_paths[i]
        imgs = [read_image(i) / 255.0 for i in img_paths]
        extrinsics = [torch.from_numpy(np.load(i)) for i in mat_paths]
        f, wx = self.intrinsics
        n = len(img_paths)

        f = torch.tensor([f]).repeat(n)
        wx = torch.tensor([wx]).repeat(n)
        camera_vecs = torch.tensor([[[1.0, 0, 0], [0, 1, 0], [0, 0, -1]]]).repeat(n, 1, 1)
        T = torch.stack(extrinsics)
        imgs = torch.stack(imgs)

        return f, wx, camera_vecs, T, imgs

    def __getitem__(self, i):
        return self.get_scene(i)

    # def __getitem__(self, i):
    #     f, wx, vecs, T, imgs = self.get_scene(i)

    #     source_item = f[:-1], wx[:-1], vecs[:-1], T[:-1], imgs[:-1]
    #     target_item = f[-1:], wx[-1:], vecs[-1:], T[-1:], imgs[-1:]

    #     return source_item, target_item


In [None]:
dataset = NVSDataset(path, (0.035, 0.032))


In [94]:
# dataloader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset), collate_fn=lambda b: b)
dataloader = DataLoader(dataset, batch_size=4, collate_fn=lambda b: b)


In [100]:
next(dataloader.__iter__())[0]


(tensor([0.0350, 0.0350, 0.0350, 0.0350, 0.0350, 0.0350, 0.0350, 0.0350, 0.0350,
         0.0350, 0.0350, 0.0350]),
 tensor([0.0320, 0.0320, 0.0320, 0.0320, 0.0320, 0.0320, 0.0320, 0.0320, 0.0320,
         0.0320, 0.0320, 0.0320]),
 tensor([[[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 1.,  0.,  0.],
          [ 0.,  1.,  0.],
          [ 0.,  0., -1.]],
 
         [[ 

In [None]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device


In [25]:
# torch.set_default_device(device) #TODO Only if needed, since it delays all PyTorch API calls
device, torch.cuda.device_count()

model = nn.Sequential()
model = nn.DataParallel(model)
# model.to(device)
model.device_ids


(DataParallel(
   (module): Sequential()
 ),
 [0])