In [1]:
from typing import Tuple, Sequence, Dict, Union, Optional, Callable
from torchsummary import summary
from my_unet import ConditionalUnet1D
from my_vision_encoder import ResidualBlock, ResNetFe
from my_ddpm import MyScheduler, MyDDPM
import torch
import torch.nn as nn

In [2]:
def replace_submodules(
        root_module: nn.Module,
        predicate: Callable[[nn.Module], bool],
        func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace all submodules selected by the predicate with
    the output of func.

    The tree is rewritten in place; a new object is returned only when
    root_module itself matches the predicate. A module instance that is
    registered under several names (a shared layer) is converted once and
    the same replacement is installed under every one of its names.

    root_module: Module tree to rewrite.
    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    Returns the module to use in place of root_module.
    Raises AssertionError if a matching module is still present afterwards
    (e.g. when func returns a module that itself satisfies predicate).
    """
    if predicate(root_module):
        return func(root_module)

    # Snapshot every (qualified name, module) hit before mutating anything.
    # remove_duplicate=False keeps *all* names of a shared module; with
    # de-duplication only the first alias is listed, the remaining aliases
    # keep pointing at the old module and the final check fails.
    targets = [(name, module) for name, module
        in root_module.named_modules(remove_duplicate=False)
        if predicate(module)]

    # id(source module) -> its replacement. `targets` keeps the sources
    # alive, so the ids stay unique for the duration of the loop.
    replacements = {}
    for name, src_module in targets:
        parent_name, _, key = name.rpartition('.')
        parent_module = root_module
        if parent_name:
            parent_module = root_module.get_submodule(parent_name)
        in_sequential = isinstance(parent_module, nn.Sequential)
        if in_sequential:
            current = parent_module[int(key)]
        else:
            current = getattr(parent_module, key)
        if current is not src_module:
            # This slot was already swapped through another path that
            # reaches the same parent container.
            continue
        if id(src_module) not in replacements:
            replacements[id(src_module)] = func(src_module)
        tgt_module = replacements[id(src_module)]
        if in_sequential:
            parent_module[int(key)] = tgt_module
        else:
            setattr(parent_module, key, tgt_module)

    # verify that all modules are replaced
    assert not any(predicate(m) for m in root_module.modules()), \
        "replace_submodules: matching submodules remain after replacement"
    return root_module

def replace_bn_with_gn(
    root_module: nn.Module,
    features_per_group: int=16) -> nn.Module:
    """
    Replace all BatchNorm2d layers with GroupNorm.

    Each replacement is a freshly constructed nn.GroupNorm with the same
    channel count; nothing is copied over from the BatchNorm layer.

    root_module: Module tree whose nn.BatchNorm2d layers are swapped.
    features_per_group: Target number of channels per group (>= 1).
        Layers with fewer channels than this get a single group.

    Returns the converted module. This is root_module itself unless
    root_module is a BatchNorm2d, in which case it is the new GroupNorm.
    Raises ValueError if features_per_group is smaller than 1.
    """
    if features_per_group < 1:
        raise ValueError(
            f"features_per_group must be >= 1, got {features_per_group}")

    def _to_group_norm(bn: nn.Module) -> nn.GroupNorm:
        # Integer division yields 0 groups for narrow layers, which
        # GroupNorm cannot be built with; clamp to one group instead.
        num_groups = max(1, bn.num_features // features_per_group)
        return nn.GroupNorm(
            num_groups=num_groups,
            num_channels=bn.num_features)

    # Return the helper's result rather than root_module: when the root
    # itself is a BatchNorm2d the helper hands back a new object.
    return replace_submodules(
        root_module=root_module,
        predicate=lambda x: isinstance(x, nn.BatchNorm2d),
        func=_to_group_norm
    )


In [3]:
# Build the ResNetFe image encoder, swap its BatchNorm2d layers for GroupNorm
# and move it to the GPU when one is present (falls back to CPU otherwise,
# instead of crashing on a hard .cuda()).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vision_encoder = replace_bn_with_gn(ResNetFe(ResidualBlock, [2, 2])).to(device)

n_params = sum(p.numel() for p in vision_encoder.parameters())
print(f"Number of parameters: {n_params:,}")

Number of parameters: 749,120


In [None]:
# Layer-by-layer summary of the vision encoder.
# NOTE(review): torchsummary adds its own leading batch axis to input_size, so
# the 64 here is presumably consumed as seq_len by the encoder's 5-D input
# path rather than as the batch size — confirm this is intended.
input_size = (64, 3, 96, 96)
summary(vision_encoder, input_size=input_size)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 48, 48]           9,408
         GroupNorm-2           [-1, 64, 48, 48]             128
              ReLU-3           [-1, 64, 48, 48]               0
         MaxPool2d-4           [-1, 64, 24, 24]               0
            Conv2d-5           [-1, 64, 24, 24]          36,864
         GroupNorm-6           [-1, 64, 24, 24]             128
            Conv2d-7           [-1, 64, 24, 24]          36,864
         GroupNorm-8           [-1, 64, 24, 24]             128
     ResidualBlock-9           [-1, 64, 24, 24]               0
           Conv2d-10           [-1, 64, 24, 24]          36,864
        GroupNorm-11           [-1, 64, 24, 24]             128
           Conv2d-12           [-1, 64, 24, 24]          36,864
        GroupNorm-13           [-1, 64, 24, 24]             128
    ResidualBlock-14           [-1, 64,

In [5]:
# Push a random image batch through the encoder as a shape check.
# The batch is created directly on whichever device the encoder lives on,
# which avoids a host allocation plus copy and a hard CUDA requirement.
encoder_device = next(vision_encoder.parameters()).device
sample = torch.randn(64, 3, 96, 96, device=encoder_device)
output = vision_encoder(sample)
sample.shape, output.shape

(torch.Size([64, 3, 96, 96]), torch.Size([64, 512]))

In [6]:
# Replay the first half of the encoder's forward pass by hand; `self` and `x`
# are bound at notebook level so the statements read like the method body.
self = vision_encoder
x = sample

# A 4-D input gets a singleton axis at position 1 so the unpack below
# always sees five dimensions.
if x.dim() == 4:
    x = x.unsqueeze(1)
batch_size, seq_len, channels, height, width = x.shape
x = x.view(batch_size * seq_len, channels, height, width)

# stem: conv1 -> bn1 -> relu -> maxpool
x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

x.shape

torch.Size([64, 64, 24, 24])

In [None]:
def forward(self, x):
    """
    Stand-alone copy of the vision encoder's forward pass.

    self: object exposing conv1, bn1, relu, maxpool, layer1, layer2,
        avgpool and fc as callables.
    x: 4-D tensor, or 5-D tensor unpacked as
        (batch_size, seq_len, channels, height, width). The two leading
        axes are folded into one before the convolutional stem; a 4-D
        input is treated as seq_len == 1.

    Returns the output of fc applied to the pooled features flattened
    from axis 1 onwards.
    Raises ValueError (from the shape unpack) if x is neither 4-D nor 5-D.
    """
    if len(x.shape) == 4:
        x = x.unsqueeze(1)
    batch_size, seq_len, channels, height, width = x.shape
    # reshape instead of view: view raises when the two leading axes cannot
    # be merged without a copy (e.g. x produced by a transpose); reshape
    # returns the same view when possible and copies only when it has to.
    x = x.reshape(batch_size * seq_len, channels, height, width)

    # stem
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    # residual stages
    x = self.layer1(x)
    x = self.layer2(x)

    # head
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x

In [None]:
# torchsummary prepends its own batch axis, so input_size must be the
# per-sample shape (channels, height, width). layer1 is fed the stem output,
# which is [-1, 64, 24, 24]; passing (64, 64, 24, 24) would hand its Conv2d
# layers a 5-D tensor.
summary(vision_encoder.layer1, input_size=(64, 24, 24))

In [19]:
# --- conditioning / action dimensions ---
obs_horizon = 2
vision_feature_dim = 512
# agent_pos is 2 dimensional
lowdim_obs_dim = 2
# observation feature has 514 dims in total per step
obs_dim = vision_feature_dim + lowdim_obs_dim
action_dim = 2

# create network object; use the GPU when one is present and fall back to
# CPU otherwise instead of crashing on a hard .cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
noise_pred_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim*obs_horizon
).to(device)

In [None]:
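# Presumably left disabled because torchsummary only feeds the tensor(s) built
# from input_size, whereas noise_pred_net is called with
# (noisy_actions, timesteps, global_cond=...); NoiseWrapper is used instead.
# summary(noise_pred_net, input_size=(64,16, 2))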

In [22]:
# Dummy inputs for a single denoising call. Feature widths come from the
# configuration names (action_dim, obs_dim, obs_horizon) rather than repeated
# literals, and every tensor is created on the network's own device.
net_device = next(noise_pred_net.parameters()).device
noisy_actions = torch.randn(64, 16, action_dim, device=net_device)
obs_cond = torch.randn(64, obs_dim * obs_horizon, device=net_device)
# torch.randint already returns int64, so no .long() cast is needed
timesteps = torch.randint(0, 100, (64,), device=net_device)
noise_pred = noise_pred_net(noisy_actions, timesteps, global_cond=obs_cond)

noisy_actions.shape, timesteps.shape, obs_cond.shape, noise_pred.shape

(torch.Size([64, 16, 2]),
 torch.Size([64]),
 torch.Size([64, 1028]),
 torch.Size([64, 16, 2]))

In [38]:
class NoiseWrapper(nn.Module):
    """
    Single-input adapter around a noise-prediction network.

    The wrapped network is called as
    ``net(noisy_actions, timesteps, global_cond=...)``. This wrapper fills in
    random ``timesteps`` and ``global_cond`` so that tools which can supply
    only one input tensor can still drive a forward pass.

    noise_pred_net: network to wrap.
    global_cond_dim: width of the random conditioning vector.
    num_timesteps: exclusive upper bound of the random integer timesteps.
    """
    def __init__(self, noise_pred_net: nn.Module,
                 global_cond_dim: int = 1028,
                 num_timesteps: int = 100):
        super().__init__()
        self.noise_pred_net = noise_pred_net
        self.global_cond_dim = global_cond_dim
        self.num_timesteps = num_timesteps

    def forward(self, noisy_actions):
        """
        Run the wrapped network on noisy_actions.

        noisy_actions: tensor whose last two axes are handed to the network
            unchanged; any extra leading axes are folded into the first one.

        Returns whatever the wrapped network returns.
        """
        # Use the tensor that was actually passed in instead of replacing it
        # with a fresh random CUDA batch. Inputs with more than three axes
        # (e.g. when a caller prepends an additional batch axis) are
        # flattened down to three so the network still gets a 3-D batch.
        if noisy_actions.dim() > 3:
            noisy_actions = noisy_actions.reshape(-1, *noisy_actions.shape[-2:])
        batch_size = noisy_actions.shape[0]
        device = noisy_actions.device

        # Random conditioning and timesteps, sized and placed to match the input.
        obs_cond = torch.randn(batch_size, self.global_cond_dim, device=device)
        timesteps = torch.randint(
            0, self.num_timesteps, (batch_size,), device=device)
        noise_pred = self.noise_pred_net(
            noisy_actions, timesteps, global_cond=obs_cond)
        return noise_pred

In [39]:
# Smoke test: call the wrapper with the earlier action batch and look at the
# shape of what comes back.
nw = NoiseWrapper(noise_pred_net)
out = nw(noisy_actions)
out.shape

torch.Size([64, 16, 2])

In [40]:
# Per-layer summary of the noise-prediction network, driven through the
# single-input wrapper.
summary(nw, input_size=(64, 16, 2))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
  SinusoidalPosEmb-1                  [-1, 256]               0
            Linear-2                 [-1, 1024]         263,168
              Mish-3                 [-1, 1024]               0
            Linear-4                  [-1, 256]         262,400
            Conv1d-5              [-1, 256, 16]           2,816
         GroupNorm-6              [-1, 256, 16]             512
              Mish-7              [-1, 256, 16]               0
       Conv1dBlock-8              [-1, 256, 16]               0
              Mish-9                 [-1, 1284]               0
           Linear-10                  [-1, 512]         657,920
        Unflatten-11               [-1, 512, 1]               0
             FiLM-12              [-1, 256, 16]               0
           Conv1d-13              [-1, 256, 16]         327,936
        GroupNorm-14              [-1, 

In [43]:
# Display the first entry of down_modules: a DownModule made of a conditional
# residual block followed by a stride-2 Conv1d downsample (repr shown below).
noise_pred_net.down_modules[0]

DownModule(
  (crb): ConditionalResidualBlock1D(
    (block1): Conv1dBlock(
      (block): Sequential(
        (0): Conv1d(2, 256, kernel_size=(5,), stride=(1,), padding=(2,))
        (1): GroupNorm(8, 256, eps=1e-05, affine=True)
        (2): Mish()
      )
    )
    (block2): Conv1dBlock(
      (block): Sequential(
        (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,))
        (1): GroupNorm(8, 256, eps=1e-05, affine=True)
        (2): Mish()
      )
    )
    (film): FiLM(
      (cond_encoder): Sequential(
        (0): Mish()
        (1): Linear(in_features=1284, out_features=512, bias=True)
        (2): Unflatten(dim=-1, unflattened_size=(-1, 1))
      )
    )
    (residual_conv): Conv1d(2, 256, kernel_size=(1,), stride=(1,))
  )
  (downsample): Conv1d(256, 256, kernel_size=(3,), stride=(2,), padding=(1,))
)

In [44]:
# Display the first entry of up_modules: an UpModule made of a stride-2
# ConvTranspose1d upsample followed by a conditional residual block
# (repr shown below).
noise_pred_net.up_modules[0]

UpModule(
  (upsample): ConvTranspose1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
  (crb): ConditionalResidualBlock1D(
    (block1): Conv1dBlock(
      (block): Sequential(
        (0): Conv1d(2048, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        (1): GroupNorm(8, 512, eps=1e-05, affine=True)
        (2): Mish()
      )
    )
    (block2): Conv1dBlock(
      (block): Sequential(
        (0): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        (1): GroupNorm(8, 512, eps=1e-05, affine=True)
        (2): Mish()
      )
    )
    (film): FiLM(
      (cond_encoder): Sequential(
        (0): Mish()
        (1): Linear(in_features=1284, out_features=1024, bias=True)
        (2): Unflatten(dim=-1, unflattened_size=(-1, 1))
      )
    )
    (residual_conv): Conv1d(2048, 512, kernel_size=(1,), stride=(1,))
  )
)