In [1]:
# find the correct kernels by brute force
# inspired by Doojin's previous code but automated
# TODO: connect it with the configs so it is easier to use with the problematic config file
#       -> instead of adding the configs by hand, just use the config file and the script will find the correct kernels

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from einops.layers.torch import Rearrange
import pdb


In [2]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim, k=3):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, k, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim, k=4):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, k, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv1d(
                inp_channels, out_channels, kernel_size, padding=kernel_size // 2
            ),
            Rearrange("batch channels horizon -> batch channels 1 horizon"),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange("batch channels 1 horizon -> batch channels horizon"),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)


In [3]:
class ResidualTemporalBlock(nn.Module):
    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        self.blocks = nn.ModuleList(
            [
                Conv1dBlock(inp_channels, out_channels, kernel_size),
                Conv1dBlock(out_channels, out_channels, kernel_size),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange("batch t -> batch t 1"),
        )

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1)
            if inp_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, t):
        """
        x : [ batch_size x inp_channels x horizon ]
        t : [ batch_size x embed_dim ]
        returns:
        out : [ batch_size x out_channels x horizon ]
        """
        out = self.blocks[0](x) + self.time_mlp(t)
        out = self.blocks[1](out)
        return out + self.residual_conv(x)

class TemporalUnet(nn.Module):
    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=32,
        dim_mults=(1, 2, 4, 8),
        kernel_size=5,
        upsample_k=4,
        downsample_k=3,
    ):
        super().__init__()

        self.kernel_size = kernel_size
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f"[ models/temporal ] Channel dimensions: {in_out}")

        if isinstance(upsample_k, int):
            upsample_k = [upsample_k] * len(in_out)
        if isinstance(downsample_k, int):
            downsample_k = [downsample_k] * len(in_out)
        time_dim = dims[1]
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, time_dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)
        down_and_up = []

        print(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            if not is_last:
                down_and_up.append(1)
            else:
                down_and_up.append(0)

            layer = nn.ModuleList(
                [
                    ResidualTemporalBlock(
                        dim_in,
                        dim_out,
                        embed_dim=time_dim,
                        horizon=horizon,
                        kernel_size=kernel_size,
                    ),
                    ResidualTemporalBlock(
                        dim_out,
                        dim_out,
                        embed_dim=time_dim,
                        horizon=horizon,
                        kernel_size=kernel_size,
                    ),
                    Downsample1d(dim_out, k=downsample_k[ind])
                    if down_and_up[ind]
                    else nn.Identity(),
                ]
            )
            self.downs.append(layer)

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(
            mid_dim,
            mid_dim,
            embed_dim=time_dim,
            horizon=horizon,
            kernel_size=kernel_size,
        )
        self.mid_block2 = ResidualTemporalBlock(
            mid_dim,
            mid_dim,
            embed_dim=time_dim,
            horizon=horizon,
            kernel_size=kernel_size,
        )

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            layer = nn.ModuleList(
                [
                    ResidualTemporalBlock(
                        dim_out * 2,
                        dim_in,
                        embed_dim=time_dim,
                        horizon=horizon,
                        kernel_size=kernel_size,
                    ),
                    ResidualTemporalBlock(
                        dim_in,
                        dim_in,
                        embed_dim=time_dim,
                        horizon=horizon,
                        kernel_size=kernel_size,
                    ),
                    Upsample1d(dim_in, k=upsample_k[ind])
                    if down_and_up[::-1][ind + 1]
                    else nn.Identity(),
                ]
            )
            self.ups.append(layer)

            if not is_last:
                horizon = horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim_in, dim_in, kernel_size=5),
            nn.Conv1d(dim_in, transition_dim, 1),
        )

    def forward(self, x, cond, time):
        """
        x : [ batch x horizon x transition ]
        """

        x = einops.rearrange(x, "b h t -> b t h")

        t = self.time_mlp(time)
        h = []
        # from diffuser.utils.debug import debug
        # debug()
        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)
        # debug()
        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)
        # debug()
        x = self.final_conv(x)

        x = einops.rearrange(x, "b t h -> b h t")
        return x


In [23]:
success = []

horizon = 21 #11
transition_dim = 4 + 2# + 6
cond_dim = 4
dim = 32 #128
dim_mults = (2, 2, 4, 8)  # (1,4,8) #(1,2,4,8)
kernel_size = 5
# upsample_k = (3,3,3)
# downsample_k = (3,3,3)
out_shape = [4, horizon, transition_dim]

import itertools

elements = [3, 4]
combinations = list(itertools.product(elements, repeat=6))
# combinations = list(itertools.product(elements, repeat=4))

for combination in combinations:
    # print(combination)
    upsample_k = combination[:3]
    downsample_k = combination[3:]
    # upsample_k = combination[:2]
    # downsample_k = combination[2:]
    unet = TemporalUnet(horizon, transition_dim, cond_dim,
                        dim, dim_mults, kernel_size, upsample_k, downsample_k)
    x = torch.randn(out_shape) # B, H, D
    try:
        out = unet(x, None, torch.randn(4,))
        print(list(out.shape))
        assert sum([out.shape[i] == out_shape[i] for i in range(len(out_shape))]) == 3
        success.append((upsample_k, downsample_k))
        print("Success", upsample_k, downsample_k)
    except:
        # print("Failed")
        continue
    
print(success)

[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 64), (64, 64), (64, 128), (128, 256)]
[(6, 64), (64, 64), (64, 128), (128

In [None]:
success = []

horizon = 40 #11
transition_dim = 4 + 2# + 6
cond_dim = 4
dim = 32 #128
dim_mults = (1,4,8) #(1,2,4,8)
kernel_size = 5
# upsample_k = (3,3,3)
# downsample_k = (3,3,3)
out_shape = [4, horizon, transition_dim]

import itertools

elements = [3, 4]
# combinations = list(itertools.product(elements, repeat=6))
combinations = list(itertools.product(elements, repeat=4))

for combination in combinations:
    # print(combination)
    # upsample_k = combination[:3]
    # downsample_k = combination[3:]
    upsample_k = combination[:2]
    downsample_k = combination[2:]
    unet = TemporalUnet(horizon, transition_dim, cond_dim,
                        dim, dim_mults, kernel_size, upsample_k, downsample_k)
    x = torch.randn(out_shape) # B, H, D
    try:
        out = unet(x, None, torch.randn(4,))
        print(list(out.shape))
        assert sum([out.shape[i] == out_shape[i] for i in range(len(out_shape))]) == 3
        success.append((upsample_k, downsample_k))
        print("Success", upsample_k, downsample_k)
    except:
        # print("Failed")
        continue
    
print(success)

[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[ models/temporal ] Channel dimensions: [(6, 32), (32, 128), (128, 256)]
[(6, 32), (32, 128), (128, 256)]
[4, 39, 6]
[ models/temporal ] Channel dimensi

In [5]:
for s in success:
    print(f'"upsample_k": {tuple(s[0])}, "downsample_k":  {tuple(s[1])},')

"upsample_k": (4, 4, 3), "downsample_k":  (3, 3, 3),
"upsample_k": (4, 4, 3), "downsample_k":  (3, 3, 4),
"upsample_k": (4, 4, 3), "downsample_k":  (3, 4, 3),
"upsample_k": (4, 4, 3), "downsample_k":  (3, 4, 4),
"upsample_k": (4, 4, 3), "downsample_k":  (4, 3, 3),
"upsample_k": (4, 4, 3), "downsample_k":  (4, 3, 4),
"upsample_k": (4, 4, 3), "downsample_k":  (4, 4, 3),
"upsample_k": (4, 4, 3), "downsample_k":  (4, 4, 4),
"upsample_k": (4, 4, 4), "downsample_k":  (3, 3, 3),
"upsample_k": (4, 4, 4), "downsample_k":  (3, 3, 4),
"upsample_k": (4, 4, 4), "downsample_k":  (3, 4, 3),
"upsample_k": (4, 4, 4), "downsample_k":  (3, 4, 4),
"upsample_k": (4, 4, 4), "downsample_k":  (4, 3, 3),
"upsample_k": (4, 4, 4), "downsample_k":  (4, 3, 4),
"upsample_k": (4, 4, 4), "downsample_k":  (4, 4, 3),
"upsample_k": (4, 4, 4), "downsample_k":  (4, 4, 4),


In [6]:
for s in success:
    up = s[0]
    down = s[1]
    if list(up)[::-1] == list(down):
        print(f'"upsample_k": {tuple(up)}, "downsample_k":  {tuple(down)},')

"upsample_k": (4, 4, 3), "downsample_k":  (3, 4, 4),
"upsample_k": (4, 4, 4), "downsample_k":  (4, 4, 4),


In [12]:
upsample_k = (3, 3, 3)
downsample_k = (3, 3, 3)
# model_config = utils.Config(
#     'models.TemporalUnet',
#     # savepath=(args.savepath, "model_config.pkl"),
#     horizon=horizon, #args.horizon // args.jump,
#     transition_dim=transition_dim,
#     cond_dim=cond_dim,
#     dim=dim,
#     dim_mults=dim_mults,
#     kernel_size=kernel_size,
#     upsample_k=upsample_k,
#     downsample_k=downsample_k,
# )

unet = TemporalUnet(horizon, transition_dim, cond_dim,
                    dim, dim_mults, kernel_size, upsample_k, downsample_k)

x = torch.randn(out_shape) # B, H, D
out = unet(x, None, torch.randn(4,))


[ models/temporal ] Channel dimensions: [(10, 32), (32, 128), (128, 256)]
[(10, 32), (32, 128), (128, 256)]
