In [2]:
import time
import numpy as np
import torch
from torch.utils.data import DataLoader
import datasets


In [3]:
def test_load_time(dataset):
  load_times = []
  end = time.time()
  for batch in dataset:
      load_times.append(time.time() - end)
      end=time.time()

  load_times = np.array(load_times)
  print(f"Avg load time: {np.mean(load_times)}, std: {np.std(load_times)}")

In [3]:
# Load the saved Hugging Face dataset from disk and have it return torch tensors.
mydataset = datasets.load_from_disk("dataset/hf_test/scripted_trajectories_50_2024-07-14_14-25-22.hf").with_format("torch")
# Wrap it in a shuffled DataLoader: batches of 256, 4 workers, pinned memory.
mydataloader = DataLoader(mydataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)


In [118]:
# Baseline: iterate the Dataset object itself (no DataLoader batching), so each
# timed step is presumably a single example.
test_load_time(mydataset)

Avg load time: 0.0015681628704071046, std: 0.000694624684764919


In [119]:
# Same measurement through the DataLoader (batch_size=256): each timed step is
# a whole batch, so compare against the per-item figure with that in mind.
test_load_time(mydataloader)

Avg load time: 0.11769705707744016, std: 0.46098371012284395


In [11]:
def embed_gripper(gripper):
  """
  One-hot encode gripper commands.

  The input is flattened to 1-D and shifted by +1 so that the values
  -1, 0, 1 become class indices 0, 1, 2; the result has one row per
  input element and 3 columns.
  """
  class_indices = gripper.flatten() + 1
  return torch.nn.functional.one_hot(class_indices, num_classes=3)

def decode_gripper(gripper):
  """
  Inverse of the one-hot gripper encoding.

  Takes the arg-max along dim 1, shifts it by -1 so class indices 0, 1, 2
  map back to -1, 0, 1, and returns the result as an integer column vector.
  """
  class_indices = gripper.argmax(dim=1)
  signed = class_indices - 1
  return signed.unsqueeze(1).to(int)

## Dataset.map()
Very slow.

Calling .map() is cached for most of these, typically takes around 1m20s.

The image normalisation seems to cause the vast majority of the slowdown.

In [7]:

# Length-6 centre and range tensors used to normalise qpos values
# (broadcast against the qpos tensors inside preprocess_function).
bounds_centre = torch.tensor([0] * 6)
bounds_range = torch.tensor([12] * 6)


def preprocess_function(batch):
    """
    Build the model-ready tensors for a batch (version used with Dataset.map).

    For both the observation and the action, the qpos values are normalised
    as (qpos - bounds_centre) / bounds_range + 0.5, the gripper value is
    one-hot encoded with embed_gripper, and the two parts are stacked
    horizontally as float32, at-least-2D tensors. The pixels are divided by 255.

    Returns a new dict with the keys "preprocessed.observation.state",
    "preprocessed.observation.image" and "preprocessed.action.state".
    """
    def normalize_qpos(qpos):
        return (qpos - bounds_centre) / bounds_range + 0.5

    def as_float_rows(tensor):
        # float32 and at least two dimensions, so hstack joins along columns
        return torch.atleast_2d(tensor.to(torch.float32))

    observation_state = torch.hstack((
        as_float_rows(normalize_qpos(batch["observation.state.qpos"])),
        as_float_rows(embed_gripper(batch["observation.state.gripper"])),
    ))

    action_state = torch.hstack((
        as_float_rows(normalize_qpos(batch["action.qpos"])),
        as_float_rows(embed_gripper(batch["action.gripper"])),
    ))

    scaled_image = batch["observation.pixels"] / 255

    return {
        "preprocessed.observation.state": observation_state,
        "preprocessed.observation.image": scaled_image,
        "preprocessed.action.state": action_state,
    }

# Initial: 1m44s to iterate through
# Remove .to(torch.float32): 2m16s to run map

In [121]:
# Run the batched .map() preprocessing under the torch profiler; the trace is
# handed to the TensorBoard trace handler for runs/dataload_profile.
with torch.profiler.profile(on_trace_ready=torch.profiler.tensorboard_trace_handler("runs/dataload_profile")) as p:
  mydataset = mydataset.map(preprocess_function, batched=True)

In [9]:
# Print the profiler's averaged per-operator statistics as a table
# (`p` is the profiler object from the profiled .map() cell).
print(p.key_averages().table(row_limit=-1))


-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                aten::reshape         0.00%      11.000us         0.00%      48.000us      24.000us             2  
                   aten::view         0.00%      83.000us         0.00%      83.000us       4.882us            17  
                 aten::unbind         0.16%       4.269ms         0.16%       4.433ms     554.125us             8  
                 aten::select         0.00%     134.000us         0.01%     159.000us       0.026us          6012  
             aten::as_strided         0.00%      50.000us         0.00%      50.000us       0.008us          6055  
                   aten::item         0.00%      40.000us         0.00% 

In [None]:
# WARNING - don't run this
# Adding the profiler makes it take forever and consume an ungodly amount of RAM (python process goes to 50GB RAM)
# Profiles one full timed pass over the dataset, then prints the averaged per-operator table.
with torch.profiler.profile(on_trace_ready=torch.profiler.tensorboard_trace_handler("runs/dataload_profile")) as p:
  test_load_time(mydataset)
# test_load_time(mydataloader)
print(p.key_averages().table(row_limit=-1))


-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------  ------------  ------------  ------------  ------------  ------------  ------------  
       aten::lift_fresh         0.03%       1.492ms         0.03%       1.492ms       0.001us       1618944  
               aten::to        11.25%     638.390ms        81.60%        4.630s       2.860us       1618944  
         aten::_to_copy        27.70%        1.572s        73.22%        4.154s       2.566us       1618944  
    aten::empty_strided         0.15%       8.312ms         0.15%       8.312ms       0.005us       1618944  
            aten::copy_        51.87%        2.943s        51.87%        2.943s       1.818us       1618944  
          aten::detach_         0.30%      17.169ms         0.32%      18.002ms       0.011us       1618944  
          

In [10]:
# Reload the dataset from disk in torch format, then apply the batched
# preprocessing up front with .map() (this time without the profiler).
mydataset = datasets.load_from_disk("dataset/hf_test/scripted_trajectories_50_2024-07-14_14-25-22.hf").with_format("torch")
mydataset = mydataset.map(preprocess_function, batched=True)

In [11]:
# Time iteration over the mapped dataset.
test_load_time(mydataset)


Avg load time: 0.006466682863235474, std: 0.0017202445870103937


## Keep the dataset in RAM
Offers no speed up, so clearly not a disk read speed issue

In [12]:
# Same dataset, but loaded with keep_in_memory=True before selecting the torch format.
ramdataset = datasets.load_from_disk("dataset/hf_test/scripted_trajectories_50_2024-07-14_14-25-22.hf", keep_in_memory=True).with_format("torch")

In [13]:
# Apply the same batched .map() preprocessing to the in-memory copy.
ramdataset = ramdataset.map(preprocess_function, batched=True)


Map:   0%|          | 0/15000 [00:00<?, ? examples/s]

In [14]:
# Time iteration over the in-memory, mapped dataset.
test_load_time(ramdataset)

Avg load time: 0.005494078191121419, std: 0.0016571423263816668


## Transform on the fly
Instead of using `.map()`, apply the transform to each batch as it is loaded, on every iteration, using `set_transform`.
`set_transform` seems to override `with_format("torch")`, so we have to do the torch conversion ourselves.

This is much quicker. .map is broken!


In [6]:
import torchvision

def preprocess_function(batch):
    """
    Build the model-ready tensors for a batch (version used with set_transform).

    The raw batch values are converted to tensors here. For both the
    observation and the action, qpos is normalised as
    (qpos - centre) / range + 0.5 in float32, the gripper value is one-hot
    encoded with embed_gripper and cast to float32, and the two parts are
    stacked horizontally. Each image is converted with torchvision's ToTensor
    and the results are stacked into one tensor.

    Returns a new dict with the keys "preprocessed.observation.state",
    "preprocessed.action.state" and "preprocessed.observation.image".
    """
    centre = torch.tensor([0] * 6)
    span = torch.tensor([12] * 6)

    def normalize_qpos(qpos):
        return (qpos - centre) / span + 0.5

    def build_state(qpos_key, gripper_key):
        # normalised qpos followed by the one-hot gripper columns
        qpos = normalize_qpos(torch.tensor(batch[qpos_key], dtype=torch.float32))
        gripper = embed_gripper(torch.tensor(batch[gripper_key], dtype=int)).to(torch.float32)
        return torch.hstack((qpos, gripper))

    out = {}
    out["preprocessed.observation.state"] = build_state("observation.state.qpos", "observation.state.gripper")
    out["preprocessed.action.state"] = build_state("action.qpos", "action.gripper")

    # ToTensor yields a channel-first float tensor per image; per the
    # torchvision docs it also rescales [0, 255] inputs to [0.0, 1.0].
    to_tensor = torchvision.transforms.ToTensor()
    frames = [to_tensor(image) for image in batch["observation.pixels"]]
    out["preprocessed.observation.image"] = torch.stack(frames)

    return out

# Target device for preprocess_on_device.
device = "mps"

def preprocess_on_device(batch):
    """
    Same preprocessing as preprocess_function, but every tensor is moved to
    `device` before the arithmetic is done.

    qpos is normalised as (qpos - centre) / range + 0.5, the gripper value is
    one-hot encoded with embed_gripper and cast to float32, and both parts are
    stacked horizontally. Images are converted with PILToTensor, stacked,
    moved to the device and then divided by 255.

    Returns a new dict with the keys "preprocessed.observation.state",
    "preprocessed.action.state" and "preprocessed.observation.image".
    """
    centre = torch.tensor([0] * 6).to(device)
    span = torch.tensor([12] * 6).to(device)

    def normalize_qpos(qpos):
        return (qpos - centre) / span + 0.5

    def on_device(values, dtype):
        # build the tensor first, then transfer it to the target device
        return torch.tensor(values, dtype=dtype).to(device)

    def build_state(qpos_key, gripper_key):
        qpos = normalize_qpos(on_device(batch[qpos_key], torch.float32))
        gripper = embed_gripper(on_device(batch[gripper_key], int)).to(torch.float32)
        return torch.hstack((qpos, gripper))

    out = {}
    out["preprocessed.observation.state"] = build_state("observation.state.qpos", "observation.state.gripper")
    out["preprocessed.action.state"] = build_state("action.qpos", "action.gripper")

    # Stack the per-image tensors, transfer once, and only then normalise by 255.
    pil_to_tensor = torchvision.transforms.PILToTensor()
    frames = torch.stack([pil_to_tensor(image) for image in batch["observation.pixels"]], dim=0).to(device)
    out["preprocessed.observation.image"] = frames / 255

    return out

def transform_function(batch):
    """
    Add a "preprocessed.observation.image" entry holding the batch's
    "observation.pixels" divided by 255.

    The batch is modified in place and the same object is returned.
    """
    pixels = batch["observation.pixels"]
    scaled = pixels / 255
    batch["preprocessed.observation.image"] = scaled
    return batch
    

In [7]:
# Load with keep_in_memory=True and without with_format("torch");
# preprocess_function is registered through set_transform instead of being
# applied up front with .map().
mydataset = datasets.load_from_disk("dataset/hf_test/scripted_trajectories_50_2024-07-14_14-25-22.hf", keep_in_memory=True)
# mydataset = mydataset.map(preprocess_function, batched=True)
mydataset.set_transform(preprocess_function)
# mydataset.set_transform(preprocess_on_device)


In [12]:
# Time iteration over the dataset with the on-the-fly transform attached.
test_load_time(mydataset)


Avg load time: 0.001014552625020345, std: 0.0002771132183452647


In [13]:
# Multithreading (presumably num_workers > 0) doesn't work here, probably a Jupyter issue,
# so this DataLoader is created without the num_workers argument.
mydataloader = DataLoader(mydataset, batch_size=256, shuffle=True)


In [14]:
# Time iteration over the DataLoader (batch_size=256): each timed step is a
# whole batch, not a single example.
test_load_time(mydataloader)


Avg load time: 0.2979206756009894, std: 0.02904730607116912
