In [1]:
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F 
import math 

In [2]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D batch of scalar positions.

    Every input value is multiplied by ``dim // 2`` geometrically spaced
    frequencies (running from 1 down to 1/10000); the sines and the cosines
    of those products are concatenated along the last axis.

    Args:
        dim: Requested embedding width. The output carries
            ``2 * (dim // 2)`` features, so an odd ``dim`` yields
            ``dim - 1`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` of shape ``(B,)`` into shape ``(B, 2 * (dim // 2))``."""
        device = x.device
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) ``half_dim - 1`` is zero and
        # the division used to raise ZeroDivisionError; clamp the denominator
        # so the lone frequency is simply 1.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        # Outer product: one row per input value, one column per frequency.
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [3]:
class FiLM(nn.Module):
    """Feature-wise linear modulation (https://arxiv.org/abs/1709.07871).

    A conditioning vector is projected to one scale and one bias per channel,
    which are applied to the feature map as ``scale * x + bias``.
    """

    def __init__(self, out_channels, cond_dim):
        super().__init__()
        self.out_channels = out_channels
        # The projection emits the scales followed by the biases, hence twice
        # as many outputs as there are channels.
        projection = nn.Linear(cond_dim, 2 * out_channels)
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            projection,
            nn.Unflatten(-1, (-1, 1)),
        )

    def forward(self, x, cond):
        """Modulate ``x`` with the scale and bias predicted from ``cond``."""
        params = self.cond_encoder(cond)
        batch = params.shape[0]
        # (B, 2, C, 1): index 0 along dim 1 is the scale, index 1 the bias.
        params = params.reshape(batch, 2, self.out_channels, 1)
        scale, bias = params.unbind(dim=1)
        # The trailing singleton axis broadcasts over the last axis of x.
        return scale * x + bias


In [4]:
class Conv1dBlock(nn.Module):
    """Conv1d followed by GroupNorm and a Mish activation.

    The convolution pads by ``kernel_size // 2`` on each side, which leaves
    the sequence length unchanged for odd kernel sizes.
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        conv = nn.Conv1d(
            inp_channels,
            out_channels,
            kernel_size,
            padding=kernel_size // 2,
        )
        norm = nn.GroupNorm(n_groups, out_channels)
        self.block = nn.Sequential(conv, norm, nn.Mish())

    def forward(self, x):
        """Run ``x`` through the conv / norm / activation stack."""
        return self.block(x)

In [5]:
class CondBlock1D(nn.Module):
    """Two Conv1d -> GroupNorm -> Mish stages with FiLM conditioning between them."""

    def __init__(self,
            in_channels,
            out_channels,
            cond_dim,
            kernel_size=3,
            n_groups=8):
        super().__init__()

        # Both convolutions share the same "half kernel" padding.
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=pad)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=pad)
        self.gn1 = nn.GroupNorm(n_groups, out_channels)
        self.gn2 = nn.GroupNorm(n_groups, out_channels)
        self.film = FiLM(out_channels, cond_dim)

    def forward(self, x, cond):
        """
        Run the first conv stage, modulate the result with FiLM using
        ``cond``, then run the second conv stage and return its output.
        """
        hidden = F.mish(self.gn1(self.conv1(x)))
        hidden = self.film(hidden, cond)
        return F.mish(self.gn2(self.conv2(hidden)))

In [11]:
class DownModule(nn.Module):
    """Encoder stage: a conditional conv block followed by a downsampler.

    When ``is_last`` is true the downsampler is an identity; otherwise it is
    a kernel-3, stride-2, padding-1 convolution.
    """

    def __init__(self, dim_in, dim_out, cond_dim, kernel_size, n_groups, is_last=False):
        super().__init__()
        self.crb = CondBlock1D(
            dim_in,
            dim_out,
            cond_dim=cond_dim,
            kernel_size=kernel_size,
            n_groups=n_groups,
        )
        self.downsample = (
            nn.Identity()
            if is_last
            else nn.Conv1d(dim_out, dim_out, kernel_size=3, stride=2, padding=1)
        )

    def forward(self, x, cond):
        """Return ``(features, downsampled_features)`` for this stage."""
        features = self.crb(x, cond)
        return features, self.downsample(features)

In [12]:
class UpModule(nn.Module):
    """Decoder stage: fuse a skip connection, run a conditional conv block, upsample.

    When ``is_last`` is true the upsampler is an identity; otherwise it is a
    kernel-4, stride-2, padding-1 transposed convolution.
    """

    def __init__(self, dim_in, dim_out, cond_dim, kernel_size, n_groups, is_last=False):
        super().__init__()
        self.upsample = (
            nn.Identity()
            if is_last
            else nn.ConvTranspose1d(dim_out, dim_out, kernel_size=4, stride=2, padding=1)
        )
        self.crb = CondBlock1D(
            dim_in,
            dim_out,
            cond_dim=cond_dim,
            kernel_size=kernel_size,
            n_groups=n_groups,
        )

    def forward(self, x, x_down, cond):
        """Concatenate ``x`` with the skip ``x_down`` on the channel axis, then decode."""
        fused = torch.cat((x, x_down), dim=1)  # U-Net skip connection
        decoded = self.crb(fused, cond)
        return self.upsample(decoded)

In [6]:
class ConditionalUnet1D(nn.Module):
    """1-D U-Net over the time axis, FiLM-conditioned on a diffusion step and a global vector."""

    def __init__(self,
        input_dim,
        global_cond_dim,
        diffusion_step_embed_dim=256,
        down_dims=(256, 512, 1024),
        kernel_size=5,
        n_groups=8
        ):
        """
        input_dim: Dim of actions.
        global_cond_dim: Dim of global conditioning applied with FiLM
          in addition to diffusion step embedding. This is usually obs_horizon * obs_dim
        diffusion_step_embed_dim: Size of positional encoding for diffusion iteration k
        down_dims: Channel size for each UNet level.
          The length of this sequence determines the number of levels.
        kernel_size: Conv kernel size
        n_groups: Number of groups for GroupNorm
        """

        super().__init__()
        # Immutable default above; take a private list copy for slicing/concatenation.
        down_dims = list(down_dims)
        all_dims = [input_dim] + down_dims
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # Every FiLM layer sees the step embedding concatenated with the global condition.
        cond_dim = dsed + global_cond_dim

        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            CondBlock1D(
                mid_dim, mid_dim, cond_dim=cond_dim,
                kernel_size=kernel_size, n_groups=n_groups
            ),
        ])

        down_modules = nn.ModuleList([])
        for ind, (dim_in, dim_out) in enumerate(in_out):
            # The deepest level keeps its resolution (identity downsample).
            is_last = ind >= (len(in_out) - 1)
            down_modules.append(
                DownModule(dim_in, dim_out, cond_dim, kernel_size, n_groups, is_last)
            )

        up_modules = nn.ModuleList([])
        for dim_in, dim_out in reversed(in_out[1:]):
            # There is one up stage per downsampling step, so every up stage
            # must upsample; none of them is ever the identity "last" stage.
            up_modules.append(
                UpModule(dim_out * 2, dim_in, cond_dim, kernel_size, n_groups, is_last=False)
            )

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        # Remembered so forward() can tell whether omitting global_cond is legal.
        self.global_cond_dim = global_cond_dim

        self.diffusion_step_encoder = diffusion_step_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv


    def forward(self,
            sample: torch.Tensor,
            timestep: Union[torch.Tensor, float, int],
            global_cond: Optional[torch.Tensor] = None):
        """
        sample: (B,T,input_dim). For odd kernel_size, T has to be divisible by
          2 ** (number of levels - 1) so that each upsampled feature map matches
          the length of the skip connection it is concatenated with.
        timestep: (B,) tensor, scalar tensor, or Python int/float; diffusion step
        global_cond: (B,global_cond_dim); may be omitted only when the model
          was built with global_cond_dim == 0
        output: (B,T,input_dim)

        Raises ValueError if global_cond is None while global_cond_dim is non-zero.
        """
        # (B,T,C)
        sample = sample.moveaxis(-1, -2)  # (B,C,T)

        # 1. time
        # Accept plain Python numbers and tensors on another device, then
        # broadcast a scalar / single-element step to the whole batch.
        timestep = torch.as_tensor(timestep, device=sample.device)
        timestep = timestep.expand(sample.shape[0])
        positional_feature = self.diffusion_step_encoder(timestep)

        # 2. conditioning vector shared by every FiLM layer
        if global_cond is None:
            if self.global_cond_dim:
                raise ValueError(
                    "global_cond is required because the model was built with "
                    f"global_cond_dim={self.global_cond_dim}"
                )
            global_feature = positional_feature
        else:
            global_feature = torch.cat([positional_feature, global_cond], dim=-1)

        # 3. encoder: keep the pre-downsample features of each level as skips
        x = sample
        h = []
        for down_module in self.down_modules:
            skip, x = down_module(x, global_feature)
            h.append(skip)

        for mid_module in self.mid_modules:
            x = mid_module(x, global_feature)

        # 4. decoder: consume skips deepest-first. There is one fewer up stage
        # than down stages, so the outermost skip stays on the stack unused.
        for up_module in self.up_modules:
            x = up_module(x, h.pop(), global_feature)

        x = self.final_conv(x)

        # (B,C,T)
        x = x.moveaxis(-1, -2)
        # (B,T,C)
        return x


In [7]:
# Action side of the demo problem.
action_dim = 2
action_horizon = 16

# Observation side: each step contributes vision features plus low-dim state.
obs_horizon = 2
lowdim_obs_dim = 2
vision_feature_dim = 512
obs_dim = vision_feature_dim + lowdim_obs_dim

# The global conditioning vector has one obs_dim-sized slot per observation step.
noise_pred_net = ConditionalUnet1D(
    global_cond_dim=obs_dim * obs_horizon,
    input_dim=action_dim,
)

In [8]:
B = 64    # batch size
T = 100   # exclusive upper bound for the sampled diffusion steps
# Derived from the dimensions the network was built with rather than
# hard-coded (it evaluates to 1028), so the two can never drift apart.
obs_cond_dim = obs_dim * obs_horizon

noisy_actions = torch.randn(B, action_horizon, action_dim)
# torch.randint already returns int64, so no extra .long() is needed.
timesteps = torch.randint(0, T, (B,))
obs_cond = torch.randn(B, obs_cond_dim)

noisy_actions.shape, timesteps.shape, obs_cond.shape

(torch.Size([64, 16, 2]), torch.Size([64]), torch.Size([64, 1028]))

In [9]:
# Single forward pass of the network on the dummy batch built above.
noise_pred = noise_pred_net(
    sample=noisy_actions,
    timestep=timesteps,
    global_cond=obs_cond,
)
noise_pred.shape

torch.Size([64, 16, 2])

In [10]:
# Show the module tree of the network (the cell's last expression is displayed).
noise_pred_net

ConditionalUnet1D(
  (mid_modules): ModuleList(
    (0): CondBlock1D(
      (conv1): Conv1d(1024, 1024, kernel_size=(5,), stride=(1,), padding=(2,))
      (conv2): Conv1d(1024, 1024, kernel_size=(5,), stride=(1,), padding=(2,))
      (gn1): GroupNorm(8, 1024, eps=1e-05, affine=True)
      (gn2): GroupNorm(8, 1024, eps=1e-05, affine=True)
      (film): FiLM(
        (cond_encoder): Sequential(
          (0): Mish()
          (1): Linear(in_features=1284, out_features=2048, bias=True)
          (2): Unflatten(dim=-1, unflattened_size=(-1, 1))
        )
      )
    )
  )
  (diffusion_step_encoder): Sequential(
    (0): SinusoidalPosEmb()
    (1): Linear(in_features=256, out_features=1024, bias=True)
    (2): Mish()
    (3): Linear(in_features=1024, out_features=256, bias=True)
  )
  (up_modules): ModuleList(
    (0): UpModule(
      (upsample): ConvTranspose1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
      (crb): CondBlock1D(
        (conv1): Conv1d(2048, 512, kernel_size=(5