Diffusion-based Policy Learning for RL

This notebook implements Diffusion Policy, a diffusion model that predicts robot action sequences in reinforcement learning tasks.

This example implements a robot control model for pushing a T-shaped block into a target area. The model takes the current state observation (robot x/y position, block x/y position and block angle) as input and outputs a trajectory of the next 16 (x, y) actions to follow. This script was contributed by [Dorsa Rohani](https://github.com/DorsaRoh) and the notebook by [Parag Ekbote](https://github.com/ParagEkbote).

In [2]:
!pip install torch==2.0.1+cu117 \
  torchvision==0.15.2+cu117 \
  torchaudio==2.0.2+cu117 \
  git+https://github.com/rail-berkeley/d4rl.git \
  gym==0.23.1 \
  protobuf==3.20.1 \
  einops \
  mediapy \
  Pillow==9.0.0 \
  -f https://download.pytorch.org/whl/torch_stable.html


Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting git+https://github.com/rail-berkeley/d4rl.git
  Cloning https://github.com/rail-berkeley/d4rl.git to /tmp/pip-req-build-tdkn3r22
  Running command git clone --filter=blob:none --quiet https://github.com/rail-berkeley/d4rl.git /tmp/pip-req-build-tdkn3r22

  Resolved https://github.com/rail-berkeley/d4rl.git to commit 89141a689b0353b0dac3da5cba60da4b1b16254d
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting torch==2.0.1+cu117
  Downloading https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp310-cp310-linux_x86_64.whl (1843.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 GB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hCollecting torchvision==0.15.2+cu117
  Downloading https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp310-cp310-linux_x86_64.whl (6.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6

In [6]:
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from diffusers import DDPMScheduler

class ObservationEncoder(nn.Module):
    """Two-layer MLP that embeds a state vector.

    Maps ``(..., state_dim)`` -> ``(..., 256)`` through a 512-unit hidden
    layer with a ReLU. The layers live in ``self.net`` (an ``nn.Sequential``),
    so checkpoint keys are ``net.0.*`` and ``net.2.*``.
    """

    def __init__(self, state_dim):
        super().__init__()
        hidden_width, embed_width = 512, 256
        layers = [
            nn.Linear(state_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, embed_width),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        embedding = self.net(x)
        return embedding

class ObservationProjection(nn.Module):
    """Affine projection from an observation embedding to a 32-d conditioning vector.

    The weight has 512 input columns, but a 256-wide input (the width produced
    by ``ObservationEncoder``) is also accepted. Such an input is right-padded
    with zeros to 512 before the affine map, so only the first 256 weight
    columns contribute. Parameter names and shapes (``weight``: (32, 512),
    ``bias``: (32,)) are fixed so existing checkpoints load unchanged.
    """

    def __init__(self):
        super().__init__()
        # Unscaled randn init is only a placeholder; DiffusionPolicy tries to
        # overwrite it from the pretrained checkpoint.
        self.weight = nn.Parameter(torch.randn(32, 512))
        self.bias = nn.Parameter(torch.zeros(32))

    def forward(self, x):
        """Project ``x`` of shape (..., 256) or (..., 512) to (..., 32)."""
        if x.size(-1) == 256:
            pad_width = self.weight.shape[1] - x.size(-1)
            # new_zeros inherits x's dtype and device, so padding never promotes
            # a low-precision input (e.g. bf16) away from the parameters' dtype.
            padding = x.new_zeros((*x.shape[:-1], pad_width))
            x = torch.cat([x, padding], dim=-1)
        return nn.functional.linear(x, self.weight, self.bias)

class UNet1D(nn.Module):
    """Small 1-D convolutional denoiser with one skip connection and a time embedding.

    Despite the ``down``/``up`` names, no stage changes the sequence length:
    every Conv1d uses ``kernel_size=3, padding=1``, so the input length ``L``
    is preserved end to end. The submodule names (``down1``, ``time_mlp``,
    ``mid``, ``up1``) and their Sequential layouts define the state-dict keys
    that checkpoint-transfer code targets, so they must not be renamed.

    Args:
        in_channels: channels of the input ``x``.
        out_channels: channels of the prediction.
        hidden_channels: width of every internal conv / MLP layer.
    """

    def __init__(self, in_channels, out_channels, hidden_channels=128):
        super().__init__()

        # Input stage (length-preserving; named "down" for checkpoint mapping)
        self.down1 = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # Time embedding: scalar timestep -> hidden_channels
        self.time_mlp = nn.Sequential(
            nn.Linear(1, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels)
        )

        # Middle stage, applied after the time embedding is added
        self.mid = nn.Sequential(
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # Output stage: consumes [mid, skip] concatenated on the channel axis
        self.up1 = nn.Sequential(
            nn.Conv1d(2 * hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, out_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, t):
        """Run the denoiser on ``x`` at diffusion timestep ``t``.

        Args:
            x: tensor of shape (B, in_channels, L).
            t: a Python number, a 0-d tensor, a (B,) or (1,) tensor, or a
               (B, 1) tensor. A single timestep is broadcast over the batch.

        Returns:
            Tensor of shape (B, out_channels, L).
        """
        # Coerce t to a (B or 1, 1) column on the time MLP's device and dtype.
        # Casting to the weight dtype (rather than always float32) keeps
        # double/half models working; moving the device handles CPU timesteps
        # coming from the scheduler.
        mlp_weight = self.time_mlp[0].weight
        t = torch.as_tensor(t, device=mlp_weight.device)
        if t.dim() == 0:
            t = t.view(1)
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        t = t.to(dtype=mlp_weight.dtype)

        # (B or 1, H) -> (B or 1, H, L) so it can be added to the feature map
        t_emb = self.time_mlp(t)
        t_emb = t_emb.unsqueeze(-1).expand(-1, -1, x.shape[-1])

        d1 = self.down1(x)  # [B, H, L]
        mid = self.mid(d1 + t_emb)  # [B, H, L]

        # Skip connection: concatenate input-stage features with mid features
        return self.up1(torch.cat([mid, d1], dim=1))  # [B, out_channels, L]

class DiffusionPolicy:
    """Diffusion-based policy that samples a short action trajectory from an observation.

    Wraps an observation encoder, a conditioning projection, a 1-D conv
    denoiser (``UNet1D``) and a DDPM noise scheduler. ``predict`` turns one or
    more 5-d observations into a 16-step trajectory of 2-d actions by running
    100 reverse-diffusion steps.

    Observations are normalized with min 0 and max (512, 512, 512, 512, 2*pi).
    The notebook's demo labels the five fields robot x, robot y, block x,
    block y and block angle. Actions are mapped back to [0, 512].

    Args:
        state_dim: input width of the observation encoder. NOTE: the
            normalization statistics are hard-coded for 5 dimensions, so
            other values will fail inside ``predict``.
        device: torch device for all modules and tensors.
    """

    # Action steps predicted per call, and reverse-diffusion steps at inference.
    ACTION_HORIZON = 16
    NUM_INFERENCE_STEPS = 100

    def __init__(self, state_dim=5, device="cpu"):
        self.device = device

        # Min/max ranges used by normalize_data / unnormalize_data ([-1, 1] mapping)
        self.stats = {
            "obs": {
                "min": torch.zeros(5, device=device),
                "max": torch.tensor([512, 512, 512, 512, 2 * np.pi], device=device)
            },
            "action": {
                "min": torch.zeros(2, device=device),
                "max": torch.full((2,), 512.0, device=device)
            },
        }

        self.obs_encoder = ObservationEncoder(state_dim).to(device)
        self.obs_projection = ObservationProjection().to(device)

        self.model = UNet1D(
            in_channels=34,  # 2 action channels + 32 context channels
            out_channels=2,  # x,y coordinates
            hidden_channels=128
        ).to(device)

        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=100,
            beta_schedule="squaredcos_cap_v2"
        )

        self._load_pretrained(device)

    def _load_pretrained(self, device):
        """Best-effort load of the published checkpoint; keeps random init on any failure."""
        try:
            checkpoint_path = hf_hub_download("dorsar/diffusion_policy", "push_tblock.pt")
            # SECURITY: torch.load unpickles the file, which can execute arbitrary
            # code. Only load checkpoints from a repository you trust.
            checkpoint = torch.load(checkpoint_path, map_location=device)

            self.obs_encoder.load_state_dict(self._fix_state_dict(checkpoint["encoder_state_dict"]))
            self.obs_projection.load_state_dict(self._fix_state_dict(checkpoint["projection_state_dict"]))

            # The checkpoint's denoiser uses different parameter names (see the
            # mapping in _transfer_weights), so only matching tensors are copied.
            self._transfer_weights(checkpoint["model_state_dict"])

        except Exception as e:
            # Deliberately broad so the notebook still runs offline or when the
            # checkpoint layout changes -- just without pretrained weights.
            print(f"Warning: Could not load pre-trained weights: {e}")
            print("The model will use randomly initialized weights.")

    def _fix_state_dict(self, state_dict):
        """Return a copy of ``state_dict`` with a leading ``'module.'`` removed from each key.

        That prefix typically appears when a model is saved while wrapped in
        ``nn.DataParallel`` / ``DistributedDataParallel``. Only a leading
        occurrence is stripped; ``'module.'`` elsewhere in a key is preserved.
        """
        prefix = "module."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _transfer_weights(self, original_state_dict):
        """Copy checkpoint tensors into ``self.model`` where names and shapes line up.

        A checkpoint key that starts with one of the ``layer_mapping`` prefixes
        (followed by ``'.'``) has that prefix rewritten to the matching
        ``UNet1D`` submodule path. The tensor is copied only if the rewritten
        key exists in ``self.model`` with the same shape; everything else keeps
        its current initialization. Prints the number of copied tensors, and a
        warning when nothing was copied.
        """
        custom_state_dict = self.model.state_dict()

        # Mapping between original (checkpoint) and custom module prefixes
        layer_mapping = {
            'down_blocks.0.resnets.0': 'down1.0',
            'down_blocks.0.resnets.1': 'down1.2',
            'mid_block.resnets.0': 'mid.0',
            'mid_block.resnets.1': 'mid.2',
            'up_blocks.0.resnets.0': 'up1.0',
            'up_blocks.0.resnets.1': 'up1.2',
        }

        transferred = set()
        for orig_name, param in original_state_dict.items():
            for orig_prefix, custom_prefix in layer_mapping.items():
                # Require a '.' boundary so e.g. 'resnets.1' never matches 'resnets.10',
                # and rewrite only the prefix, not later occurrences.
                if not orig_name.startswith(orig_prefix + "."):
                    continue
                custom_name = custom_prefix + orig_name[len(orig_prefix):]
                target = custom_state_dict.get(custom_name)
                if target is not None and target.shape == param.shape:
                    # state_dict() tensors share storage with the model's parameters
                    target.copy_(param)
                    transferred.add(custom_name)
                break  # a key can match at most one prefix

        self.model.load_state_dict(custom_state_dict, strict=False)

        print(f"Transferred weights for {len(transferred)} layers")
        if not transferred:
            print("Warning: no UNet weights matched the checkpoint; the denoiser stays randomly initialized.")

    def normalize_data(self, data, stats):
        """Linearly map ``data`` from [stats['min'], stats['max']] to [-1, 1]."""
        return ((data - stats["min"]) / (stats["max"] - stats["min"])) * 2 - 1

    def unnormalize_data(self, ndata, stats):
        """Inverse of ``normalize_data``: map [-1, 1] back to [stats['min'], stats['max']]."""
        return ((ndata + 1) / 2) * (stats["max"] - stats["min"]) + stats["min"]

    @torch.no_grad()
    def predict(self, observation, generator=None):
        """Sample an action trajectory conditioned on ``observation``.

        Args:
            observation: array-like or tensor of shape (5,) or (B, 5) in raw
                (unnormalized) units. It is cast to the encoder's parameter
                dtype, so NumPy float64 arrays are accepted.
            generator: optional ``torch.Generator`` for reproducible sampling
                (must live on ``self.device``). ``None`` uses the global RNG.

        Returns:
            Tensor of shape (B, 16, 2) holding per-step (x, y) actions in raw units.
        """
        param_dtype = next(self.obs_encoder.parameters()).dtype
        observation = torch.as_tensor(observation, device=self.device).to(param_dtype)
        if observation.dim() == 1:
            observation = observation.unsqueeze(0)
        batch_size = observation.shape[0]

        normalized_obs = self.normalize_data(observation, self.stats["obs"])

        # Per-sample 32-d context, broadcast along the action horizon: (B, 32, H)
        cond = self.obs_projection(self.obs_encoder(normalized_obs))
        cond = cond.view(batch_size, -1, 1).expand(-1, -1, self.ACTION_HORIZON)

        # Start from Gaussian noise in normalized action space: (B, 2, H)
        action = torch.randn(
            (batch_size, 2, self.ACTION_HORIZON), device=self.device, generator=generator
        )

        # Only pass the generator through when given, so the default call is unchanged.
        step_kwargs = {} if generator is None else {"generator": generator}

        self.noise_scheduler.set_timesteps(self.NUM_INFERENCE_STEPS)
        for t in self.noise_scheduler.timesteps:
            t = t.to(self.device)
            model_input = torch.cat([action, cond], dim=1)
            model_output = self.model(model_input, t)
            action = self.noise_scheduler.step(model_output, t, action, **step_kwargs).prev_sample

        # (B, 2, H) -> (B, H, 2), then back to raw action units
        action = action.transpose(1, 2)
        return self.unnormalize_data(action, self.stats["action"])

if __name__ == "__main__":
    # Both weight init and diffusion sampling are random; seed so the printed
    # trajectory is reproducible across re-runs of this cell.
    torch.manual_seed(0)
    policy = DiffusionPolicy()

    # Test with sample observation
    obs = torch.tensor([[
        256.0,  # robot arm x position
        256.0,  # robot arm y position
        200.0,  # block x position
        300.0,  # block y position
        np.pi / 2,  # block angle
    ]])

    # NOTE: if the loader reports "Transferred weights for 0 layers", the
    # denoiser is untrained and this trajectory is effectively noise.
    action = policy.predict(obs)
    print("Action shape:", action.shape)
    print("\nPredicted trajectory:")
    for i, (x, y) in enumerate(action[0].tolist()):
        print(f"Step {i:2d}: x={x:6.1f}, y={y:6.1f}")

Transferred weights for 0 layers
Action shape: torch.Size([1, 16, 2])

Predicted trajectory:
Step  0: x=  48.9, y= 449.9
Step  1: x= 118.0, y=  49.0
Step  2: x= 191.8, y= 110.5
Step  3: x= 501.7, y= 512.0
Step  4: x=   0.0, y= 425.6
Step  5: x= 378.3, y=   0.0
Step  6: x=  39.3, y=   0.5
Step  7: x= 474.6, y= 372.0
Step  8: x=  17.0, y= 398.2
Step  9: x=  30.2, y= 369.9
Step 10: x=  11.4, y= 503.2
Step 11: x= 512.0, y= 424.3
Step 12: x= 415.7, y= 508.0
Step 13: x= 357.8, y= 503.9
Step 14: x= 294.6, y= 512.0
Step 15: x= 219.7, y=  87.5
