In [1]:
from dataset import NearDetDataset3D, TransformerDataset
import matplotlib.pyplot as plt
import torch
import MinkowskiEngine as ME
import numpy as np
from torch.utils.data import DataLoader

In [2]:
def collation_function(data):
    """Collate a list of samples into batched (coords, feats) pairs.

    Each element of ``data`` is indexed with ``'input'`` and ``'target'``;
    each of those entries must unpack into a ``(coords, feats)`` pair.
    For both keys:

    * per-sample coordinates are combined with
      ``ME.utils.batched_coordinates``;
    * per-sample features are concatenated along axis 0 with NumPy and
      converted to a float32 torch tensor.

    Returns:
        dict with keys ``'input'`` and ``'target'``, each mapping to a
        ``(batched_coords, batched_feats)`` tuple.
    """
    def _batch(key):
        # Transpose [(c0, f0), (c1, f1), ...] into (c0, c1, ...), (f0, f1, ...).
        coords_per_sample, feats_per_sample = zip(*(sample[key] for sample in data))
        coords = ME.utils.batched_coordinates(coords_per_sample)
        feats = torch.from_numpy(np.concatenate(feats_per_sample, axis=0)).float()
        return coords, feats

    return {
        'input': _batch('input'),
        'target': _batch('target'),
    }

def make_data_loader(
    dataset,
    batch_size,
    shuffle,
    collation_function,
    num_workers,
    config=None,
    pin_memory=False,
    drop_last=False,
):
    """Build a ``torch.utils.data.DataLoader`` with this notebook's defaults.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether the DataLoader shuffles sample order.
        collation_function: Passed as ``collate_fn``; ``None`` selects
            PyTorch's default collation.
        num_workers: Number of DataLoader worker processes
            (0 loads in the main process).
        config: Unused; kept so existing callers that pass it keep working.
        pin_memory: Forwarded to ``DataLoader``. Defaults to ``False``,
            matching the previously hard-coded value.
        drop_last: Forwarded to ``DataLoader``. Defaults to ``False``
            (the final partial batch is kept), matching the previously
            hard-coded value.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collation_function,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )

In [3]:
# Near-detector 3D dataset, batched with the custom collate defined above.
dataset = NearDetDataset3D()
dataloader = make_data_loader(
    dataset=dataset,
    collation_function=collation_function,
    batch_size=128,
    num_workers=4,
    shuffle=False,
)

# Display the dataset object as this cell's output.
dataset

<dataset.NearDetDataset3D at 0x7f6ce9a25f50>

In [4]:
# No collate_fn is passed, so batches use PyTorch's default collation.
transformer_ds = TransformerDataset()
transformer_dataloader = DataLoader(
    transformer_ds,
    shuffle=False,
    batch_size=2,
)