In [None]:
# PyTorch, Torchvision
import torch
from torch import nn
from torchvision.transforms import ToPILImage, ToTensor

# Common
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# For train-validation-test split (if we want to do it manually)
import random
from math import floor

# Torchvision converters shared by the cells below:
#   image_to_tensor: PIL image  -> torch tensor (torchvision ToTensor)
#   tensor_to_image: torch tensor -> PIL image  (torchvision ToPILImage)
image_to_tensor = ToTensor()
tensor_to_image = ToPILImage()

In [2]:
# Helper Functions (from Amar)
def get_img_dict(img_dir):
    """
    Group the .png / .tiff files found directly in `img_dir` by file-name prefix.

    The prefix is the part of the file name before the first underscore
    (a name without an underscore is used whole). Entries are visited in sorted
    path order, so every list is sorted and keys appear in the order of their
    first matching file.

    Parameters:
        - img_dir: a pathlib.Path pointing at a directory

    Returns:
        - dict mapping prefix -> list of pathlib.Path objects
    """
    grouped = {}
    for path in sorted(img_dir.iterdir()):
        if not path.name.endswith(('.png', '.tiff')):
            continue
        grouped.setdefault(path.name.split('_')[0], []).append(path)
    return grouped

def get_sample_dict(sample_dir):
    """
    Build a nested dictionary of image files for one video sample.

    For every entry of `sample_dir` whose name contains 'camera' (visited in
    sorted order) the result holds a per-camera dict with:
        - 'scene': get_img_dict() of the camera entry itself
        - one key per entry of the camera entry whose name contains 'obj_'
          (sorted), mapping that entry's name to get_img_dict() of it

    Parameters:
        - sample_dir: a pathlib.Path pointing at one video directory

    Returns:
        - dict mapping camera entry name -> per-camera dict described above
    """
    def _sorted_entries_containing(parent, token):
        # Sorted children of `parent` whose name contains `token`
        return sorted(entry for entry in parent.iterdir() if token in entry.name)

    samples = {}
    for camera in _sorted_entries_containing(sample_dir, 'camera'):
        per_camera = {'scene': get_img_dict(camera)}
        per_camera.update(
            (obj.name, get_img_dict(obj))
            for obj in _sorted_entries_containing(camera, 'obj_')
        )
        samples[camera.name] = per_camera
    return samples

In [3]:
def make_dataset11(videos_dir: str,
                   inds: list[int] | None = None,
                   random_seed: int = 42) -> tuple[torch.Tensor, torch.Tensor, list[dict]]:
    """
    Load one frame per video and return the modal and amodal masks of every object,
    seen from every camera angle, as tensors.

    Parameters:
        - videos_dir: a directory where each sub-directory contains a video from Movi-MC-AC.
          Sub-directories are visited in sorted (path) order; non-directory entries are ignored.
        - inds: one frame index per video (in the sorted order above) telling which frame of
          each video to load into the dataset. If None, one frame per video is drawn at random.
        - random_seed: seed used to draw `inds` when `inds` is None. A private RNG is used, so
          Python's global `random` state is left untouched. Only relevant if inds is None.

    Returns:
        - modal_tensor: float tensor of shape (N, 256, 256) with N = 6 cameras * total number of
          objects over all videos; 1.0 where the object is visible in the scene segmentation,
          0.0 elsewhere.
        - amodal_tensor: float tensor of shape (N, 256, 256) holding the amodal masks as converted
          by `image_to_tensor`; row l matches row l of modal_tensor.
        - video_dicts: the nested file dictionaries from `get_sample_dict`, one per video.
          TODO: abstract this into a separate function
        - TODO: also return a tensor of RGB images
    """
    # Every video in Movi-MC-AC has 24 frames and 6 camera angles
    num_cams = 6
    num_frames = 24

    # Load a dictionary describing each video. Sorting makes the video order (and hence which
    # entry of `inds` belongs to which video) independent of the filesystem's iterdir() order;
    # stray files (e.g. .DS_Store) are skipped instead of crashing get_sample_dict.
    videos_dir = Path(videos_dir)
    video_dirs = sorted(p for p in videos_dir.iterdir() if p.is_dir())
    video_dicts = [get_sample_dict(video_dir) for video_dir in video_dirs]
    num_vids = len(video_dicts)

    # If the user does not specify, take one random frame from each video. This must run after
    # num_vids and num_frames are known. A private RNG yields the same draws as seeding the
    # global `random` module, without modifying the global state.
    if inds is None:
        rng = random.Random(random_seed)
        inds = [rng.randint(0, num_frames - 1) for _ in range(num_vids)]

    # Count the objects in each video from the object directories of the first camera
    num_obj_list = [
        sum(key.startswith('obj') for key in video_dict['camera_0000'])
        for video_dict in video_dicts
    ]

    # One sample per (video, camera, object). sum(num_obj_list) already counts the objects of
    # all videos, so it must NOT be multiplied by num_vids again (that left unfilled rows).
    dataset_len = num_cams * sum(num_obj_list)

    # NOTE(review): assumes every mask image is 256 x 256 — TODO confirm for all videos
    modal_tensor = torch.empty((dataset_len, 256, 256))
    amodal_tensor = torch.empty((dataset_len, 256, 256))

    row = 0  # iterates through row = 0, ..., dataset_len - 1
    for i, video_dict in enumerate(video_dicts):  # for each video
        frame = inds[i]
        for j in range(num_cams):  # for each camera angle
            cam_dict = video_dict[f'camera_{j:04d}']
            # Load the scene's modal segmentation once per camera; `with` closes the file.
            with Image.open(cam_dict['scene']['segmentation'][frame]) as modal_img:
                modal_ids = torch.tensor(np.array(modal_img))
            for k in range(num_obj_list[i]):  # for each object
                with Image.open(cam_dict[f'obj_{k+1:04d}']['segmentation'][frame]) as amodal_img:
                    amodal_mask = image_to_tensor(amodal_img)
                # Pixels equal to k + 1 belong to object k + 1 -> binary modal mask
                modal_tensor[row, :, :] = (modal_ids == (k + 1)).float()
                amodal_tensor[row, :, :] = amodal_mask
                row += 1

    assert row == dataset_len, "every row of the output tensors must be filled"
    return modal_tensor, amodal_tensor, video_dicts


In [None]:
# Can put the sample data from the demo (the folder starting with ff...) in a
# directory called `data` to test this.
# `inds=[1]` has one entry per video (the demo folder is assumed to be the only video in
# ../data) and selects frame index 1, i.e. the SECOND frame, since indices start at 0.
modal_tensor, amodal_tensor, video_dicts = make_dataset11('../data', [1])

# Display the modal & amodal masks of one (camera, object) sample side by side
SAMPLE_IDX = 8  # row of the dataset tensors to visualize
fig, axes = plt.subplots(1, 2)
panels = [(modal_tensor, "Modal Mask"), (amodal_tensor, "Amodal Mask")]
for ax, (masks, title) in zip(axes, panels):
    ax.imshow(tensor_to_image(masks[SAMPLE_IDX]), cmap='gray')
    ax.set_title(title)
    ax.axis('off')  # Turn off axes
fig.suptitle(f"Sample {SAMPLE_IDX}")

# Adjust layout
plt.tight_layout()
plt.show()

In [None]:
def train_val_test_split(X: torch.Tensor,
                         Y: torch.Tensor,
                         props: tuple[float, float, float],
                         random_seed: int = 42) -> tuple[torch.Tensor, ...]:
    '''
    Split two tensors, X and Y, into train, validation, and test datasets according to proportions `props`.

    Parameters:
        - X: input data; samples are indexed along dimension 0
        - Y: output data; must have the same number of samples (dimension 0) as X
        - props: a length 3 tuple which must sum to 1. props = (0.7, 0.2, 0.1) specifies 70% of the data
        in training, 20% of the data in validation, and 10% in testing. Any of these three percentages can
        be 0.
        - random_seed: Integer seeding a private RNG; Python's global `random` state is not modified.

    Return:
        - 6 tensors, in this order: X_train, Y_train, X_val, Y_val, X_test, Y_test.
          Test and validation sizes are floor(n * prop), raised to 1 when the prop is positive;
          training receives all remaining samples, in shuffled order.

    Raises:
        - ValueError: if props is not 3 non-negative numbers summing to 1, if X and Y have different
          sample sizes, or if the validation and test splits together need more than n samples.

    Details:

    Validation data should be used for hyperparameter tuning and model selection, while test data should
    be used for the final model evaluation (e.g. to check the model isn't under- or overfitting).
    '''
    if len(props) != 3:
        raise ValueError("props must have exactly 3 entries")
    if any(p < 0 for p in props):
        raise ValueError("props must be non-negative")
    # Make sure proportions sum to 1, within floating point error
    if abs(sum(props) - 1) >= 1e-4:
        raise ValueError("props must sum to 1")

    if X.shape[0] != Y.shape[0]:
        raise ValueError("X and Y must have the same sample size.")

    n = X.shape[0]

    # floor() can round a small but non-zero proportion down to 0 samples;
    # guarantee at least one sample in every split that was requested.
    test_size = floor(n * props[2])
    if props[2] > 1e-4:
        test_size = max(test_size, 1)

    val_size = floor(n * props[1])
    if props[1] > 1e-4:
        val_size = max(val_size, 1)

    # The "at least 1" bump can ask for more samples than exist (e.g. n = 1); fail clearly
    # instead of deep inside random.sample.
    if test_size + val_size > n:
        raise ValueError(
            f"Cannot take {test_size} test and {val_size} validation samples from only {n} samples."
        )

    # A private RNG produces the same draws as seeding the global `random` module, but cannot
    # leave the global state re-seeded if an error occurs part-way through.
    rng = random.Random(random_seed)
    data_inds = range(n)

    test_inds = rng.sample(data_inds, test_size)
    train_val_inds = set(data_inds) - set(test_inds)
    val_inds = rng.sample(list(train_val_inds), val_size)
    train_inds = list(train_val_inds - set(val_inds))
    rng.shuffle(train_inds)

    return (X[train_inds], Y[train_inds],
            X[val_inds], Y[val_inds],
            X[test_inds], Y[test_inds])

In [None]:
# Example of splitting empty tensors: 1001 samples with props (0.7, 0.2, 0.1)
# gives 701 / 200 / 100 samples in train / validation / test.
res = train_val_test_split(torch.empty((1001, 200, 200)),
                           torch.empty((1001, 200, 200)),
                           (0.7, 0.2, 0.1))
# Display the shapes of (X_train, Y_train, X_val, Y_val, X_test, Y_test)
tuple(t.shape for t in res)

(torch.Size([701, 200, 200]),
 torch.Size([701, 200, 200]),
 torch.Size([200, 200, 200]),
 torch.Size([200, 200, 200]),
 torch.Size([100, 200, 200]),
 torch.Size([100, 200, 200]))