Merged
45 commits
f58f5a8
First version of small model
Flova Jan 28, 2025
aeb6cd5
Add patching
Flova Jan 28, 2025
e8db54f
Add basic distillation training script
Flova Jan 28, 2025
cb936c2
Fix warning
Flova Jan 28, 2025
a116338
Fix distillation
Flova Jan 28, 2025
2226934
Add distilled model to plot
Flova Jan 30, 2025
b6955f5
Merge branch 'main' into feature/destillation
Flova Jan 30, 2025
d0b7b63
Rename file
Flova Jan 30, 2025
9fc7f78
Fix joint names
Flova Jan 30, 2025
191cf78
Use distilled hyperparam
Flova Jan 30, 2025
34bcf83
Migrate distillation to new hyperparameter standard
Flova Jan 30, 2025
f4e74b0
Current WIP
Flova Jan 30, 2025
8dd6023
Make it possible to only load some of the features from the dataset
Flova Feb 4, 2025
5b0b7f4
Add decoder only training
Flova Feb 4, 2025
577a105
Format
Flova Feb 4, 2025
56b3b58
Shorter pretraining
Flova Feb 5, 2025
b6d626d
Merge branch 'main' into feature/small_model
Flova Feb 5, 2025
19dd73b
Reverse to transformer decoder
Flova Feb 5, 2025
8f53490
Merge branch 'feature/small_model' into feature/destillation
Flova Feb 5, 2025
f5b51b5
Fix other scripts
Flova Feb 5, 2025
420158a
Remove profiling
Flova Feb 5, 2025
12f95c0
Be able to load pretrained decoders
Flova Feb 5, 2025
d29db0e
Apply formatting
Flova Feb 5, 2025
213d285
Optimize data transfer
Flova Feb 5, 2025
c0e7755
Remove pinned memory
Flova Feb 5, 2025
73c9b6c
Add wandb
Flova Feb 6, 2025
92f0693
Fix wandb and add larger model
Flova Feb 9, 2025
ca4312b
Add wandb logs to gitignore
Flova Feb 9, 2025
048cdca
Fix model loading in distillation
Flova Feb 9, 2025
af70a87
Avoid printing out all the params
Flova Feb 9, 2025
e2b2299
Add wandb to distill
Flova Feb 9, 2025
8c7965d
Sort keys
Flova Feb 11, 2025
9fb40c6
Change ros runtime
Flova Feb 11, 2025
ada1b33
Fix five dim input
Flova Feb 20, 2025
3454ee1
Add different resolutions, fix five dim imu, add current training con…
Flova Feb 20, 2025
9b0f14e
Fix image padding
Flova Feb 20, 2025
5d27aa9
Add support for other imu input
Flova Feb 27, 2025
69f21fa
Cleanup
Flova Feb 27, 2025
da0c1ac
Normalize image during data loading
Flova Feb 27, 2025
114f1eb
Sample image at correct rate
Flova Feb 27, 2025
f114ea7
Fix preprocessing pipeline
Flova Feb 27, 2025
633a3be
Fix image norm
Flova Mar 24, 2025
3f8de56
Use internal joint command buffer
Flova Apr 4, 2025
1e5f909
Merge branch 'main' into feature/destillation
Flova Apr 16, 2025
be4869e
Remove TODO
Flova Apr 16, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -209,5 +209,8 @@ ENV/
# Torch models
*.pth

# Wandb Logs
ddlitlab2024/ml/training/wandb/

# Input data
input/
5 changes: 3 additions & 2 deletions ddlitlab2024/dataset/models.py
@@ -11,8 +11,8 @@


class RobotState(str, Enum):
POSITIONING = "POSITIONING"
PLAYING = "PLAYING"
POSITIONING = "POSITIONING"
STOPPED = "STOPPED"
UNKNOWN = "UNKNOWN"

@@ -219,7 +219,8 @@ class JointStates(Base):
Index(None, "recording_id", asc("stamp")),
)

def get_ordered_joint_names(self) -> list[str]:
@staticmethod
def get_ordered_joint_names() -> list[str]:
return [
JointStates.head_pan.name,
JointStates.head_tilt.name,
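This hunk (shown truncated) turns get_ordered_joint_names into a @staticmethod, so the canonical joint order can be read without first constructing a mapped JointStates row. A minimal sketch of the pattern, assuming a trimmed stand-in class; only head_pan and head_tilt are visible in the diff, the rest of the real joint list is omitted here:

class JointStatesSketch:
    # Stand-in for the SQLAlchemy model; columns omitted.

    @staticmethod
    def get_ordered_joint_names() -> list[str]:
        # Fixed ordering that downstream tensor columns rely on.
        return ["head_pan", "head_tilt"]

# No instance needed, e.g. inside the dataset constructor:
joint_names = JointStatesSketch.get_ordered_joint_names()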
146 changes: 98 additions & 48 deletions ddlitlab2024/dataset/pytorch.py
@@ -4,14 +4,15 @@
from collections.abc import Iterable
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal
from typing import Literal, Optional

import cv2
import numpy as np
import pandas as pd
import torch
from profilehooks import profile
from tabulate import tabulate
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2

from ddlitlab2024 import DB_PATH
from ddlitlab2024.dataset import logger
@@ -40,15 +41,15 @@ class DDLITLab2024Dataset(Dataset):
@dataclass
class Result:
joint_command: torch.Tensor
joint_command_history: torch.Tensor
joint_state: torch.Tensor
rotation: torch.Tensor
game_state: torch.Tensor
image_data: torch.Tensor
image_stamps: torch.Tensor
joint_command_history: Optional[torch.Tensor]
joint_state: Optional[torch.Tensor]
rotation: Optional[torch.Tensor]
game_state: Optional[torch.Tensor]
image_data: Optional[torch.Tensor]
image_stamps: Optional[torch.Tensor]

def shapes(self) -> dict[str, tuple[int, ...]]:
return {k: v.shape for k, v in asdict(self).items()}
return {k: v.shape for k, v in asdict(self).items() if v is not None}

def __init__(
self,
@@ -61,8 +62,14 @@ def __init__(
sampling_rate: int = 100,
max_fps_video: int = 10,
num_frames_video: int = 50,
trajectory_stride: int = 10,
image_resolution: int = 480,
trajectory_stride: int = 1,
num_joints: int = 20,
use_images: bool = True,
use_imu: bool = True,
use_joint_states: bool = True,
use_action_history: bool = True,
use_game_state: bool = True,
):
# Initialize the database connection
self.db_connection: sqlite3.Connection = db_connection if db_connection else connect_to_db()
@@ -76,8 +83,15 @@ def __init__(
self.sampling_rate = sampling_rate
self.max_fps_video = max_fps_video
self.num_frames_video = num_frames_video
self.image_resolution = image_resolution
self.trajectory_stride = trajectory_stride
self.num_joints = num_joints
self.joint_names = JointStates.get_ordered_joint_names()
self.use_images = use_images
self.use_imu = use_imu
self.use_joint_states = use_joint_states
self.use_action_history = use_action_history
self.use_game_state = use_game_state

# Print out metadata
cursor = self.db_connection.cursor()
@@ -100,7 +114,9 @@ def __init__(
assert num_data_points > 0, "Recording length is negative or zero"
total_samples_before = self.num_samples
# Calculate the number of batches that can be build from the recording including the stride
self.num_samples += int(num_data_points / self.trajectory_stride)
self.num_samples += int(
(num_data_points - self.num_samples_joint_trajectory_future) / self.trajectory_stride
)
# Store the boundaries of the samples for later retrieval
self.sample_boundaries.append((total_samples_before, self.num_samples, recording_id))

@@ -119,7 +135,7 @@ def query_joint_data(
)

# Convert to numpy array, keep only the joint angle columns in alphabetical order
raw_joint_data = raw_joint_data[JointStates.get_ordered_joint_names()].to_numpy(dtype=np.float32)
raw_joint_data = raw_joint_data[self.joint_names].to_numpy(dtype=np.float32)

assert raw_joint_data.shape[1] == self.num_joints, "The number of joints is not correct"

@@ -155,7 +171,7 @@ def query_joint_data_history(
return raw_joint_data

def query_image_data(
self, recording_id: int, end_time_stamp: float, context_len: float, num_frames: int
self, recording_id: int, end_time_stamp: float, context_len: float, num_frames: int, resolution: int
) -> tuple[torch.Tensor, torch.Tensor]:
# Get the image data
cursor = self.db_connection.cursor()
@@ -178,25 +194,36 @@
stamps = []
image_data = []

# Define the preprocessing pipeline
preprocessing = v2.Compose(
[
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
)

# Get the raw image data
for stamp, data in response:
# Deserialize the image data
image = np.frombuffer(data, dtype=np.uint8).reshape(480, 480, 3)
# Make chw from hwc
image = np.moveaxis(image, -1, 0)
# Resize the image
image = cv2.resize(image, (resolution, resolution), interpolation=cv2.INTER_AREA)
# Apply the preprocessing pipeline
image = preprocessing(image)
# Append to the list
image_data.append(image)
stamps.append(stamp)

# Apply zero padding if necessary
if len(image_data) < num_frames:
image_data = [
np.zeros((3, 480, 480), dtype=np.uint8) for _ in range(num_frames - len(image_data))
torch.zeros((3, resolution, resolution), dtype=torch.float32)
for _ in range(num_frames - len(image_data))
] + image_data
stamps = [end_time_stamp - context_len for _ in range(num_frames - len(stamps))] + stamps

# Convert to tensor
image_data = torch.from_numpy(np.stack(image_data, axis=0)).float()
image_data = torch.stack(image_data, axis=0)
stamps = torch.tensor(stamps)

return stamps, image_data
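A self-contained sketch of the per-frame pipeline added above: resize with OpenCV, run the torchvision transforms.v2 chain (ToImage, ToDtype, ImageNet normalisation), then zero-pad short clips. It assumes a torchvision release that ships transforms.v2 (0.16 or newer); the random array stands in for a decoded database image and the 128 px resolution is arbitrary:

import cv2
import numpy as np
import torch
from torchvision.transforms import v2

resolution = 128
preprocessing = v2.Compose([
    v2.ToImage(),                                     # HWC ndarray -> CHW image tensor
    v2.ToDtype(torch.float32, scale=True),            # uint8 [0, 255] -> float32 [0, 1]
    v2.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet statistics
])

raw = np.random.randint(0, 256, (480, 480, 3), dtype=np.uint8)   # fake decoded frame
frame = cv2.resize(raw, (resolution, resolution), interpolation=cv2.INTER_AREA)
frame = preprocessing(frame)                          # tensor of shape (3, 128, 128)

# Zero-pad short clips so every sample stacks to a fixed number of frames.
num_frames = 4
frames = [torch.zeros(3, resolution, resolution)] * (num_frames - 1) + [frame]
clip = torch.stack(frames, dim=0)                     # (num_frames, 3, 128, 128)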
@@ -244,7 +271,7 @@ def query_imu_data(self, recording_id: int, end_sample: int, num_samples: int) -
case rep:
raise NotImplementedError(f"Unknown IMU representation {rep}")

return torch.from_numpy(imu_data)
return torch.from_numpy(imu_data).float()

def query_current_game_state(self, recording_id: int, stamp: float) -> torch.Tensor:
cursor = self.db_connection.cursor()
@@ -265,7 +292,6 @@ def query_current_game_state(self, recording_id: int, stamp: float) -> torch.Ten

return torch.tensor(int(game_state))

@profile
def __getitem__(self, idx: int) -> Result:
# Find the recording that contains the sample
for start_sample, end_sample, recording_id in self.sample_boundaries:
@@ -288,20 +314,30 @@
stamp = sample_joint_command_index / self.sampling_rate

# Get the image data
image_stamps, image_data = self.query_image_data(
recording_id,
stamp,
# The duration is used to narrow down the query for a faster retrieval, so we consider it as an upper bound
(self.num_frames_video + 1) / self.max_fps_video,
self.num_frames_video,
)
# Some sanity checks
assert all([stamp >= image_stamp for image_stamp in image_stamps]), "The image data is not synchronized"
assert len(image_stamps) == self.num_frames_video, "The image data is not the correct length"
assert image_data.shape == (self.num_frames_video, 3, 480, 480), "The image data has the wrong shape"
assert (
image_stamps[0] >= stamp - (self.num_frames_video + 1) / self.max_fps_video
), "The image data is not synchronized"
if self.use_images:
image_stamps, image_data = self.query_image_data(
recording_id,
stamp,
# The duration is used to narrow down the query for a faster retrieval,
# so we consider it as an upper bound
(self.num_frames_video + 1) / self.max_fps_video,
self.num_frames_video,
self.image_resolution,
)
# Some sanity checks
assert all([stamp >= image_stamp for image_stamp in image_stamps]), "The image data is not synchronized"
assert len(image_stamps) == self.num_frames_video, "The image data is not the correct length"
assert image_data.shape == (
self.num_frames_video,
3,
self.image_resolution,
self.image_resolution,
), "The image data has the wrong shape"
assert (
image_stamps[0] >= stamp - (self.num_frames_video + 1) / self.max_fps_video
), "The image data is not synchronized"
else:
image_stamps, image_data = None, None

# Get the joint command target (future)
joint_command = self.query_joint_data(
@@ -310,20 +346,32 @@
assert len(joint_command) == self.num_samples_joint_trajectory_future, "The joint command has the wrong length"

# Get the joint command history
joint_command_history = self.query_joint_data_history(
recording_id, sample_joint_command_index, self.num_samples_joint_trajectory, "JointCommands"
)
if self.use_action_history:
joint_command_history = self.query_joint_data_history(
recording_id, sample_joint_command_index, self.num_samples_joint_trajectory, "JointCommands"
)
else:
joint_command_history = None

# Get the joint state
joint_state = self.query_joint_data_history(
recording_id, sample_joint_command_index, self.num_samples_joint_states, "JointStates"
)
if self.use_joint_states:
joint_state = self.query_joint_data_history(
recording_id, sample_joint_command_index, self.num_samples_joint_states, "JointStates"
)
else:
joint_state = None

# Get the robot rotation (IMU data)
robot_rotation = self.query_imu_data(recording_id, sample_joint_command_index, self.num_samples_imu)
if self.use_imu:
robot_rotation = self.query_imu_data(recording_id, sample_joint_command_index, self.num_samples_imu)
else:
robot_rotation = None

# Get the game state
game_state = self.query_current_game_state(recording_id, stamp)
if self.use_game_state:
game_state = self.query_current_game_state(recording_id, stamp)
else:
game_state = None

return self.Result(
joint_command=joint_command,
@@ -339,12 +387,14 @@ def __getitem__(self, idx: int) -> Result:
def collate_fn(batch: Iterable[Result]) -> Result:
return DDLITLab2024Dataset.Result(
joint_command=torch.stack([x.joint_command for x in batch]),
joint_command_history=torch.stack([x.joint_command_history for x in batch]),
joint_state=torch.stack([x.joint_state for x in batch]),
image_data=torch.stack([x.image_data for x in batch]),
image_stamps=torch.stack([x.image_stamps for x in batch]),
rotation=torch.stack([x.rotation for x in batch]),
game_state=torch.tensor([x.game_state for x in batch]),
joint_command_history=torch.stack([x.joint_command_history for x in batch])
if batch[0].joint_command_history is not None
else None,
joint_state=torch.stack([x.joint_state for x in batch]) if batch[0].joint_state is not None else None,
image_data=torch.stack([x.image_data for x in batch]) if batch[0].image_data is not None else None,
image_stamps=torch.stack([x.image_stamps for x in batch]) if batch[0].image_stamps is not None else None,
rotation=torch.stack([x.rotation for x in batch]) if batch[0].rotation is not None else None,
game_state=torch.tensor([x.game_state for x in batch]) if batch[0].game_state is not None else None,
)


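The None-aware collate_fn above lets disabled modalities stay None through batching instead of breaking torch.stack. A minimal, hypothetical sketch of the same pattern with toy classes (not the real dataset):

from dataclasses import dataclass
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset


@dataclass
class Sample:
    joint_command: torch.Tensor
    image_data: Optional[torch.Tensor]


class ToyDataset(Dataset):
    def __init__(self, use_images: bool):
        self.use_images = use_images

    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> Sample:
        image = torch.zeros(3, 16, 16) if self.use_images else None
        return Sample(joint_command=torch.zeros(10, 20), image_data=image)


def collate(batch: list[Sample]) -> Sample:
    # Stack a field only if the samples actually carry it, otherwise keep None.
    return Sample(
        joint_command=torch.stack([s.joint_command for s in batch]),
        image_data=torch.stack([s.image_data for s in batch])
        if batch[0].image_data is not None
        else None,
    )


loader = DataLoader(ToyDataset(use_images=False), batch_size=4, collate_fn=collate)
batch = next(iter(loader))
print(batch.joint_command.shape, batch.image_data)    # torch.Size([4, 10, 20]) None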
34 changes: 25 additions & 9 deletions ddlitlab2024/ml/inference/plot.py
@@ -25,7 +25,7 @@
# Parse the command line arguments
parser = argparse.ArgumentParser(description="Inference Plot")
parser.add_argument("checkpoint", type=str, help="Path to the checkpoint to load")
parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps")
parser.add_argument("--steps", type=int, default=30, help="Number of denoising steps (not used for distilled)")
parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
args = parser.parse_args()

@@ -55,8 +55,12 @@
image_encoder_type=ImageEncoderType(params["image_encoder_type"]),
num_image_sequence_encoder_layers=params["num_image_sequence_encoder_layers"],
image_context_length=params["image_context_length"],
image_use_final_avgpool=params.get("image_use_final_avgpool", True),
image_resolution=params.get("image_resolution", 480),
num_decoder_layers=params["num_decoder_layers"],
trajectory_prediction_length=params["trajectory_prediction_length"],
use_gamestate=params["use_gamestate"],
encoder_patch_size=params["encoder_patch_size"],
).to(device)
normalizer = Normalizer(model.mean, model.std)
model.load_state_dict(checkpoint["model_state_dict"])
@@ -76,6 +80,13 @@
num_samples_joint_trajectory=params["action_context_length"],
num_samples_imu=params["imu_context_length"],
num_samples_joint_states=params["joint_state_context_length"],
imu_representation=IMUEncoder.OrientationEmbeddingMethod(params["imu_orientation_embedding_method"]),
use_action_history=params["use_action_history"],
use_imu=params["use_imu"],
use_joint_states=params["use_joint_states"],
use_images=params["use_images"],
use_game_state=params["use_gamestate"],
image_resolution=params.get("image_resolution", 480),
)

# Create DataLoader object
@@ -104,15 +115,20 @@
noisy_trajectory = torch.randn_like(joint_targets).to(device)
trajectory = noisy_trajectory

# Perform the denoising process
scheduler.set_timesteps(args.steps)
for t in scheduler.timesteps:
if params.get("distilled_decoder", False):
# Directly predict the trajectory based on the noise
with torch.no_grad():
# Predict the noise residual
noise_pred = model(batch, trajectory, torch.tensor([t], device=device))

# Update the trajectory based on the predicted noise and the current step of the denoising process
trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample
trajectory = model(batch, noisy_trajectory, torch.tensor([0], device=device))
else:
# Perform the denoising process
scheduler.set_timesteps(args.steps)
for t in scheduler.timesteps:
with torch.no_grad():
# Predict the noise residual
noise_pred = model(batch, trajectory, torch.tensor([t], device=device))

# Update the trajectory based on the predicted noise and the current step of the denoising process
trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample

# Undo the normalization
trajectory = normalizer.denormalize(trajectory)
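The new branch above is the behavioural core of the distillation work: the distilled decoder maps noise to a trajectory in a single forward pass at t = 0, while the original model still needs the full scheduler loop. A toy sketch of both paths with a dummy model; the diffusers DDIMScheduler is only an assumed stand-in whose interface (set_timesteps, step, prev_sample) matches the calls in the script, and the project's actual scheduler class is not visible in this hunk:

import torch
from diffusers import DDIMScheduler  # assumed stand-in, see note above


def toy_model(trajectory: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the trajectory model: predicts noise (teacher path)
    # or the clean trajectory (distilled student path).
    return torch.zeros_like(trajectory)


distilled = True
noisy = torch.randn(1, 100, 20)   # (batch, trajectory length, joints), illustrative shape

if distilled:
    # Student: one forward pass goes straight from noise to a trajectory.
    trajectory = toy_model(noisy, torch.tensor([0]))
else:
    # Teacher: iterative denoising, one scheduler step per timestep.
    scheduler = DDIMScheduler()
    scheduler.set_timesteps(30)
    trajectory = noisy
    for t in scheduler.timesteps:
        noise_pred = toy_model(trajectory, t)
        trajectory = scheduler.step(noise_pred, t, trajectory).prev_sample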