POINT BREAKDOWN FOR PSET
1. One point for implementing the dataset.
2. One point for plotting a batch item queried from the dataset.
3. One point for implementing the model.
4. One point if the model can compute a forward pass without error.
5. One point for the training loop executing without error and the loss
decreasing.
6. Three points for plotting the model’s output after training and writing
two sentences about why that output is reasonable.

Setup and imports

In [20]:
import sys
import os
import imageio
import skimage
import h5py
import io
import requests

import torch, torchvision
from torchvision import datasets, transforms
from torch.utils.data import IterableDataset

import numpy as np
import random
from einops import rearrange, repeat

import matplotlib.pyplot as plt
from PIL import Image



# # check whether run in Colab
# if 'google.colab' in sys.modules:
#     print('Running in Colab.')
#     !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
#     !git clone https://github.com/facebookresearch/mae.git
#     sys.path.append('./mae')
# else:
#     sys.path.append('..')
# import models_mae

### Multiview functions from Problem set 1,2,3

In [18]:
def homogenize_points(points: torch.Tensor):
    """Turn points into homogeneous coordinates by appending a trailing 1.

    Args:
        points: tensor of shape (..., DIM).

    Returns:
        Tensor whose last axis is ``points`` followed by a column of ones.
        The ones are created with ``torch.ones_like`` on a slice of
        ``points``, so they share its dtype and device.
    """
    # Slice (rather than index) so the trailing axis is kept for concatenation.
    last_axis_template = points[..., :1]
    ones_column = torch.ones_like(last_axis_template)
    return torch.cat([points, ones_column], dim=-1)


def homogenize_vecs(vectors: torch.Tensor):
    """Turn direction vectors into homogeneous coordinates by appending a 0.

    A zero in the last slot means the translation column of a 4x4 transform
    has no effect on the vector when the transform is applied.

    Args:
        vectors: tensor of shape (..., DIM).

    Returns:
        Tensor whose last axis is ``vectors`` followed by a column of zeros
        with the same dtype and device as ``vectors``.
    """
    # Slice (rather than index) so the trailing axis is kept for concatenation.
    last_axis_template = vectors[..., :1]
    zeros_column = torch.zeros_like(last_axis_template)
    return torch.cat([vectors, zeros_column], dim=-1)


def unproject(
    xy_pix: torch.Tensor,
    z: torch.Tensor,
    intrinsics: torch.Tensor
    ) -> torch.Tensor:
    """Lift 2D pixel coordinates and per-pixel depth to 3D camera-space points.

    Each homogenized pixel is multiplied by the inverse intrinsics matrix and
    the result is scaled by the pixel's z value.

    Args:
        xy_pix: 2D pixel coordinates of shape (..., 2)
        z: per-pixel depth, defined as z coordinate of shape (..., 1)
        intrinsics: camera intrinsics of shape (..., 3, 3)

    Returns:
        xyz_cam: points in 3D camera coordinates.
    """
    inv_intrinsics = torch.inverse(intrinsics)
    pix_hom = homogenize_points(xy_pix)

    # For every point k: out[k, i] = sum_j K^-1[i, j] * pix_hom[k, j]
    rays = torch.einsum('...ij,...kj->...ki', inv_intrinsics, pix_hom)

    # Scale in place by depth; the einsum output is a fresh tensor, so no
    # caller-owned data is mutated.
    return rays.mul_(z)


def transform_world2cam(xyz_world_hom: torch.Tensor, cam2world: torch.Tensor) -> torch.Tensor:
    """Map homogeneous world-space points into the camera frame.

    The camera pose is given as camera-to-world, so it is inverted before
    being applied.

    Args:
        xyz_world_hom: homogenized 3D points of shape (..., 4)
        cam2world: camera pose of shape (..., 4, 4)

    Returns:
        xyz_cam: the points expressed in camera coordinates.
    """
    return transform_rigid(xyz_world_hom, cam2world.inverse())


def transform_cam2world(xyz_cam_hom: torch.Tensor, cam2world: torch.Tensor) -> torch.Tensor:
    """Map homogeneous camera-space points into the world frame.

    Args:
        xyz_cam_hom: homogenized 3D points of shape (..., 4)
        cam2world: camera pose of shape (..., 4, 4)

    Returns:
        xyz_world: the points expressed in world coordinates.
    """
    xyz_world = transform_rigid(xyz_cam_hom, cam2world)
    return xyz_world


def transform_rigid(xyz_hom: torch.Tensor, T: torch.Tensor) -> torch.Tensor:
    """Apply a 4x4 transform to every point / vector in a (batched) set.

    Args:
        xyz_hom: homogenized 3D points of shape (..., 4)
        T: rigid-body transform matrix of shape (..., 4, 4)

    Returns:
        xyz_trans: transformed points, laid out like ``xyz_hom``.
    """
    # For every point k: out[k, i] = sum_j T[i, j] * xyz_hom[k, j]
    matvec_over_points = '...ij,...kj->...ki'
    xyz_trans = torch.einsum(matvec_over_points, T, xyz_hom)
    return xyz_trans


def get_unnormalized_cam_ray_directions(xy_pix:torch.Tensor,
                                        intrinsics:torch.Tensor) -> torch.Tensor:
    """Return camera-space ray directions through the given pixels.

    The directions are obtained by unprojecting every pixel at depth 1 and
    are not normalized to unit length.

    Args:
        xy_pix: 2D pixel coordinates of shape (..., 2)
        intrinsics: camera intrinsics of shape (..., 3, 3)

    Returns:
        Ray directions in camera coordinates, one per pixel.
    """
    unit_depth = torch.ones_like(xy_pix[..., :1])
    return unproject(xy_pix, unit_depth, intrinsics=intrinsics)


def get_world_rays(xy_pix: torch.Tensor,
                   intrinsics: torch.Tensor,
                   cam2world: torch.Tensor,
                   ) -> torch.Tensor:
    """Compute world-space ray origins and directions for a batch of pixels.

    Args:
        xy_pix: pixel coordinates; axis 1 is used as the number of rays, so a
            (batch, num_rays, 2) layout is expected.
        intrinsics: camera intrinsics, passed on to the unprojection.
        cam2world: camera-to-world poses; the origin is tiled with the
            pattern 'b ch -> b num_rays ch', so a (batch, 4, 4) layout is
            expected.

    Returns:
        Tuple ``(ray_origins, ray_directions)``: the camera origin repeated
        once per ray, and the (unnormalized) ray directions rotated into
        world coordinates.
    """
    # Ray directions in the camera frame, one per pixel.
    dirs_cam = get_unnormalized_cam_ray_directions(xy_pix, intrinsics)

    # Homogenize with a trailing 0 so the pose's translation is ignored,
    # then move the directions to the world frame.
    dirs_world_hom = transform_cam2world(homogenize_vecs(dirs_cam), cam2world)

    # The camera origin in world coordinates is the translation column of
    # the pose. It has shape (batch, 3) while the directions have shape
    # (batch, num_rays, 3), so tile it along a new ray axis.
    origins = repeat(cam2world[..., :3, -1],
                     'b ch -> b num_rays ch',
                     num_rays=dirs_cam.shape[1])

    return origins, dirs_world_hom[..., :3]


def get_opencv_pixel_coordinates(
    y_resolution: int,
    x_resolution: int,
    ):
    """For an image with y_resolution and x_resolution, return a tensor of pixel coordinates
    normalized to lie in [0, 1], with the origin (0, 0) in the top left corner,
    the x-axis pointing right, the y-axis pointing down, and the bottom right corner
    being at (1, 1).

    Args:
        y_resolution: number of pixel rows.
        x_resolution: number of pixel columns.

    Returns:
        xy_pix: a meshgrid of values from [0, 1] of shape
                (y_resolution, x_resolution, 2), where ``xy_pix[row, col]``
                holds ``(x, y)``.
    """
    xs = torch.linspace(0, 1, steps=x_resolution)
    ys = torch.linspace(0, 1, steps=y_resolution)

    # 'xy' indexing returns grids of shape (len(ys), len(xs)), i.e. already in
    # (row, col) layout, so no permute is needed. Passing `indexing`
    # explicitly also avoids the warning torch emits when it is omitted.
    grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')

    xy_pix = torch.stack([grid_x.float(), grid_y.float()], dim=-1)
    return xy_pix

# Dataset 

We'll use the dataset from the paper "Scene Representation Networks, Sitzmann et al. 2019" which includes both 3D scenes and camera poses. 

In [3]:
import os
import requests
import h5py

# Local path where the downloaded file is saved
local_file_path = 'cars_train.hdf5'

# Download the SRNs-cars dataset unless a copy is already present.
if not os.path.exists(local_file_path):
    # URL of the HDF5 file on the web
    url = 'https://drive.google.com/uc?id=1SBjlsizq0sFNkCZxMQh-pNRi0HyFozKb'

    # Write to a temporary name first: an interrupted download must not leave
    # a truncated file that passes the os.path.exists() check on the next run.
    partial_file_path = local_file_path + '.part'

    # Stream the response so the whole dataset is never held in memory, and
    # bound how long a stalled connection may hang.
    with requests.get(url, stream=True, timeout=60) as response:
        content_type = response.headers.get('Content-Type', '')

        if response.status_code != 200:
            # Report the failure instead of silently doing nothing.
            print(f"Download of '{local_file_path}' failed "
                  f"(HTTP status {response.status_code}).")
        elif 'text/html' in content_type:
            # An HTML payload cannot be the HDF5 file (the server returned a
            # web page instead of the data), so do not save it as one.
            print(f"Download of '{local_file_path}' failed: the server "
                  "returned an HTML page instead of the file.")
        else:
            with open(partial_file_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    file.write(chunk)
            # Move the finished download into place.
            os.replace(partial_file_path, local_file_path)
            print(f"File '{local_file_path}' downloaded successfully.")

File 'cars_train.hdf5' downloaded successfully.


In [None]:
#! pip install gdown

In [43]:
# import gdown
# import os 
#Download SRNs-cars dataset in colab 
#if not os.path.exists("/content/cars_train.hdf5"):
    #!gdown 1TIxIBN1EN9FHsH_7_lGFXF9SBYFHn_rl

## Query a single batch from the dataset and plot an element of that batch.

### Dataloader 

In [6]:
from skimage.transform import resize

def parse_rgb(hdf5_dataset):
    """Decode an encoded image stored in an HDF5 dataset.

    The dataset's raw bytes are handed to ``imageio`` as an in-memory file.
    Only the first three channels are kept, and the result is converted with
    ``skimage.img_as_float32``.

    Args:
        hdf5_dataset: dataset whose contents are the bytes of an image file.

    Returns:
        The decoded image as a float32 array of shape (H, W, 3).
    """
    encoded = io.BytesIO(hdf5_dataset[...].tobytes())
    # Drop any channel past the third (e.g. alpha).
    decoded = imageio.imread(encoded)[:, :, :3]
    return skimage.img_as_float32(decoded)


def parse_intrinsics(hdf5_dataset):
    """Build a 3x3 intrinsics matrix from a text blob stored in HDF5.

    Only the first line of the UTF-8 text is used. It must hold exactly four
    whitespace-separated numbers: focal length, cx, cy and a fourth value
    that is ignored.

    Args:
        hdf5_dataset: dataset whose raw bytes are the intrinsics text.

    Returns:
        Tensor ``[[f, 0, cx], [0, f, cy], [0, 0, 1]]``.
    """
    text = hdf5_dataset[...].tobytes().decode('utf-8')
    first_line = text.split('\n')[0]

    focal, center_x, center_y, _ = (float(token) for token in first_line.split())

    return torch.tensor([
        [focal, 0., center_x],
        [0., focal, center_y],
        [0., 0., 1.],
    ])


def parse_pose(hdf5_dataset):
    """Read a 4x4 pose matrix from a text blob stored in HDF5.

    The first line of the ASCII text must hold at least 16 numbers separated
    by single spaces; they fill the matrix in row-major order.

    Args:
        hdf5_dataset: dataset whose raw bytes are the pose text.

    Returns:
        float32 tensor of shape (4, 4).
    """
    text = bytearray(hdf5_dataset[...]).decode('ascii')
    tokens = text.splitlines()[0].split(" ")

    # Index explicitly so that a line with fewer than 16 entries raises
    # instead of silently yielding a short matrix.
    values = [float(tokens[k]) for k in range(16)]

    matrix = np.array(values, dtype=np.float32).reshape(4, 4)
    return torch.from_numpy(matrix)

In [None]:
class SRNsCars(IterableDataset):
    """Endless stream of multi-view samples read from an SRNs-cars HDF5 file.

    Every item yielded by the iterator is a tuple ``(input_list, rgb_stacked)``:

    * ``input_list`` is a dict keyed by the (int) index of each sampled view.
      Each value is a dict with the keys ``'cam2world'``, ``'intrinsics'``,
      ``'x_pix'`` and ``'idx'``.
    * ``rgb_stacked`` is a tensor holding the flattened colors of the sampled
      views, one row of shape (num_pixels, 3) per view, in sampling order.

    Args:
        max_num_instances: if truthy, keep only the first this-many instances
            (in sorted key order).
        img_sidelength: if given, crops are resampled (nearest neighbor) to
            this side length; otherwise the cropped resolution is kept.
        num_images: number of views sampled per yielded item.
        hdf5_path: location of the HDF5 file to read.
    """

    def __init__(self, max_num_instances=None, img_sidelength=None, num_images = 4,
                 hdf5_path='cars_train.hdf5'):
        super().__init__()
        # NOTE: the file handle stays open for the lifetime of the dataset.
        self.f = h5py.File(hdf5_path, 'r')
        self.instances = sorted(self.f.keys())
        print('instances', self.instances)

        self.img_sidelength = img_sidelength
        self.num_images = num_images

        if max_num_instances:
            self.instances = self.instances[:max_num_instances]

    def __len__(self):
        """Return the number of instances available for sampling."""
        return len(self.instances)

    @staticmethod
    def _sample_observation_indices(num_available, num_images):
        """Draw ``num_images`` view indices from ``range(num_available)``.

        Indices are distinct whenever enough views exist. The yielded inputs
        are keyed by view index, so a repeated index would overwrite an
        earlier entry and leave the dict shorter than the stacked images.
        Repeats are only allowed when more views are requested than exist.
        """
        allow_repeats = num_images > num_available
        return np.random.choice(num_available, size=num_images, replace=allow_repeats)

    def _process_one_img(self, instance, rgb_ds, pose_ds, idx):
        """Load one view and return ``(model_input, rgb)`` as torch tensors."""
        rgb = parse_rgb(rgb_ds)

        x_pix = get_opencv_pixel_coordinates(*rgb.shape[:2])

        # There is a lot of white-space around the cars - we'll thus crop the images a bit:
        rgb = rgb[32:-32, 32:-32]
        x_pix = x_pix[32:-32, 32:-32]

        # Nearest-neighbor downsampling of *both* the
        # RGB image and the pixel coordinates. This is better than down-
        # sampling RGB only and then generating corresponding pixel coordinates,
        # which generates "fake rays", i.e., rays that the camera
        # didn't actually capture with wrong colors. Instead, this simply picks a
        # subset of the "true" camera rays.
        if self.img_sidelength is not None and rgb.shape[0] != self.img_sidelength:
            target_shape = (self.img_sidelength, self.img_sidelength)
            rgb = resize(rgb, target_shape, anti_aliasing=False, order=0)
            # Hand skimage a numpy array rather than a torch tensor.
            x_pix = resize(x_pix.numpy(), target_shape, anti_aliasing=False, order=0)

        # Convert unconditionally: without a resize `rgb` is still a numpy
        # array, which torch.stack in __iter__ cannot handle, and after a
        # resize `x_pix` is a numpy array as well.
        rgb = torch.as_tensor(rgb)
        x_pix = torch.as_tensor(x_pix)

        # Flatten the two spatial axes into a single per-pixel axis.
        x_pix = rearrange(x_pix, 'i j c -> (i j) c')
        rgb = rearrange(rgb, 'i j c -> (i j) c')

        c2w = parse_pose(pose_ds)

        # Parsed afresh for every view so each view owns its own tensor.
        intrinsics = parse_intrinsics(instance['intrinsics.txt'])
        intrinsics[:2, :3] /= 128. # Normalize intrinsics from resolution-specific intrinsics for 128x128

        model_input = {
            'cam2world': c2w,  # Camera-to-world poses
            'intrinsics': intrinsics,  # Camera intrinsics
            'x_pix': x_pix,  # Pixel coordinates
            'idx': torch.tensor([idx]),  # Index of the sampled car
        }
        return model_input, rgb

    def __iter__(self, override_idx=None):
        """Yield ``(input_list, rgb_stacked)`` tuples forever.

        Args:
            override_idx: if not None, always sample this instance index
                instead of a random one.
        """
        while True:
            if override_idx is not None:
                idx = override_idx
            else:
                idx = random.randint(0, len(self.instances)-1)

            instance = self.f[self.instances[idx]]
            rgbs_ds = instance['rgb']
            c2ws_ds = instance['pose']

            rgb_keys = list(rgbs_ds.keys())
            c2w_keys = list(c2ws_ds.keys())

            obs_idx = self._sample_observation_indices(len(rgb_keys), self.num_images)

            rgb_list = []
            input_list = {}
            for o_idx in obs_idx:
                # Plain ints make the dict keys print and compare cleanly;
                # lookups with numpy integers still hit the same entries.
                o_idx = int(o_idx)
                model_input, rgb = self._process_one_img(
                    instance,
                    rgbs_ds[rgb_keys[o_idx]],
                    c2ws_ds[c2w_keys[o_idx]],
                    idx,
                )
                rgb_list.append(rgb)
                input_list[o_idx] = model_input

            yield input_list, torch.stack(rgb_list)

### Plot single input (N sub images)

In [None]:
# Side length (in pixels) of each view and number of views per sample.
sl = 32
num_images = 5

# Draw one sample: `mi` maps view index -> model inputs, `rgb` holds the
# flattened colors of every view.
dataset = SRNsCars(img_sidelength=sl, num_images=num_images)
mi, rgb = next(iter(dataset))

# Un-flatten each view back into an (sl, sl, 3) image for display.
imgs = rgb.view(rgb.shape[0], sl, sl, 3)

fig, ax = plt.subplots(1, num_images, figsize=(30, 6))
for r, axis in enumerate(ax[:num_images]):
    axis.imshow(imgs[r])
    axis.set_title(f'Input {r}')

plt.show()

# Every view of a sample comes from the same scene, so read the scene index
# from the first one.
pose_keys = list(mi)
first_view = mi[pose_keys[0]]
print('Scene number:', int(first_view['idx']))
print('Pose indices:', pose_keys)