In [1]:
import torch
from torch import nn
from data import default_skeleton, hierarchical_order
from transforms3d.euler import euler2mat
from typing import Optional
import torch.nn.functional as F

In [6]:
class ForwardKinematics(nn.Module):
    def __init__(self):
        super().__init__()
        self.bones = nn.ModuleDict({
            name: Bone(bone.axis, bone.direction, bone.length)
            for (name, bone) in default_skeleton.bone_map.items()
            if not name == "root"
        })
    def forward(self, x):
        # (B, 282) -> (B, 93)
        bone_cache = {}
        i = 0
        for bone in hierarchical_order:
            if bone == "root":
                tail_position = x[:, i:i+3]
                i += 3
                global_rotation = x[:, i:i + 9].view(-1, 3, 3)
                i += 9
                bone_cache[bone] = (global_rotation, tail_position)
            else:
                parent = default_skeleton.bone_map[bone].parent
                assert parent is not None
                assert parent.name in bone_cache
                local_rotation = x[:, i: i + 9].view(-1, 3, 3)
                i += 9
                bone_cache[bone] = self.bones[bone](
                    *bone_cache[parent.name], local_rotation)

        return torch.cat([bone_cache[bone][1] for bone in hierarchical_order], dim=1)

class Bone(nn.Module):
    def __init__(self, axis, direction, length):
        super().__init__()
        axis = torch.tensor(axis).deg2rad().to(torch.float32)
        bind = torch.tensor(euler2mat(*axis.tolist())).to(torch.float32)
        inverse_bind = bind.inverse().to(torch.float32)
        direction = torch.tensor(direction).to(torch.float32)
        length = torch.tensor(length).to(torch.float32)
        self.register_buffer("bind", bind)
        self.register_buffer("inverse_bind", inverse_bind)
        self.register_buffer("direction", direction)
        self.register_buffer("length", length)

    def forward(self, parent_global_transform, parent_tail_position, local_rotation):
        global_transform = parent_global_transform @ self.bind @ local_rotation @ self.inverse_bind
        tail_position = parent_tail_position + self.length * \
            (global_transform @ self.direction)
        return global_transform, tail_position

class FeedFowardBlock(nn.Module):
    def __init__(self, input_embedding_size: int, hidden_embedding_size: int, output_embedding_size: int, dropout: float):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_embedding_size, hidden_embedding_size),
            nn.GELU(),
            nn.Linear(hidden_embedding_size, output_embedding_size),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)

class AttentionBlock(nn.Module):
    def __init__(
        self,
        input_embedding_size: int,
        hidden_embedding_size: int,
        output_embedding_size: int,
        attention_head_count: int,
        dropout: float,
        device='cpu'
    ):
        super().__init__()
        self.device = device

        self.ln_1 = nn.LayerNorm(input_embedding_size)
        self.ln_2 = nn.LayerNorm(input_embedding_size)

        self.mlp = FeedFowardBlock(
            input_embedding_size, hidden_embedding_size, output_embedding_size, dropout)

        self.attn = nn.MultiheadAttention(
            input_embedding_size, attention_head_count, dropout, batch_first=True)

    def forward(self, query: torch.Tensor, key_value: Optional[torch.Tensor] = None, mask = True) -> torch.Tensor:
        length = query.size(1)

        attn_mask = torch.triu(torch.ones(length, length) *
                          float('-inf'), diagonal=1).to(self.device) if mask else None

        key_value = key_value if key_value is not None else query

        x, _ = self.attn(self.ln_1(query), self.ln_1(key_value), self.ln_1(key_value), attn_mask=attn_mask)

        x = self.mlp(self.ln_2(x))

        return x


class Transformer(nn.Module):
    def __init__(
        self,
        number_of_layers: int,
        input_embedding_size: int,
        hidden_embedding_size: int,
        output_embedding_size: int,
        attention_head_count: int,
        dropout: float
    ):
        super().__init__()
        self.layers = nn.ModuleList([AttentionBlock(input_embedding_size, 
                                                    hidden_embedding_size,
                                                    input_embedding_size, 
                                                    attention_head_count, 
                                                    dropout) for _ in range(number_of_layers)])
        
        self.ln = nn.LayerNorm(input_embedding_size)
        self.projection = FeedFowardBlock(input_embedding_size, hidden_embedding_size, output_embedding_size, dropout)
        
    def forward(self, x: torch.Tensor, encoder_input:Optional[torch.Tensor] = None, mask=True) -> torch.Tensor:
        x_ = x
        for layer in self.layers:
            x_ = layer(x_, key_value=encoder_input, mask=mask)
        
        x_ = self.ln(x_)
        x_ = x + x_
        x_ = self.projection(x_) 
        return x_


class Denoiser(nn.Module):
    def __init__(
        self,
        encoder_layers: int = 16,
        decoder_layers: int = 16,
        cross_attention_layers: int = 16,
        attention_head_count: int = 8,
        input_embedding_size: int = 256,
        block_size: int = 60,
        feature_length: int = 90,
        timesteps: int = 300,
        dropout: float = 0.1
    ):
        super().__init__()

        self.encoder = Transformer(encoder_layers, input_embedding_size, input_embedding_size * 2, input_embedding_size, attention_head_count, dropout)
        self.decoder = Transformer(decoder_layers, input_embedding_size, input_embedding_size * 2, input_embedding_size, attention_head_count, dropout)
        self.cross_attention = Transformer(cross_attention_layers, input_embedding_size, input_embedding_size * 2, input_embedding_size, attention_head_count, dropout)

        self.positional_embedding = nn.Embedding(block_size, input_embedding_size)
        self.time_embedding = nn.Embedding(timesteps, input_embedding_size)
        self.feature_embedding = nn.Linear(feature_length, input_embedding_size)

    def forward(self, x: torch.Tensor, c: torch.Tensor, c_i: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x = (b, block_size, feature_length) c = (b, context_length, feature_length) c_i = (b, context_length,)   t = (b,)

        time_embd = self.time_embedding(t).unsqueeze(1)

        x_feature_embd = self.feature_embedding(x)
        x_position_embd = self.positional_embedding(torch.arange(x.size(1)))

        c_feature_embd = self.feature_embedding(c)
        c_position_embd = self.positional_embedding(c_i)

        x = x_feature_embd + x_position_embd + time_embd
        c = c_feature_embd + c_position_embd + time_embd

        c = self.encoder(c, mask=False)
        x = self.decoder(x, mask=True)
        x = self.cross_attention(x, c, mask=False)

        return x


def linear_beta_schedule(timesteps: int):
    scale = 1000 / timesteps
    start = scale * 0.0001
    end = scale * 0.02
    return torch.linspace(start, end, timesteps)


def extract(a, t, x_shape):
    b = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(b, *((1, ) * (len(x_shape) - 1)))


class Diffusion(nn.Module):
    def __init__(self, timesteps: int):
        super().__init__()
        betas = linear_beta_schedule(timesteps)
        alphas = 1. - betas

        alphas_cumprod = torch.cumprod(alphas, 0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod",
                             torch.sqrt(1. - alphas_cumprod))
        
        self.denoiser = Denoiser()

    def forward_diffusion_sample(self, x_0, t):
        noise = torch.randn_like(x_0)
        return extract(self.sqrt_alphas_cumprod, t, x_0.shape) * x_0 + extract(self.sqrt_one_minus_alphas_cumprod, t, x_0.shape) * noise, noise

In [7]:
diffusion = Diffusion(400)

count = 0
for p in diffusion.parameters():
    count += p.numel()
print(f"{count * 4 / 1024 / 1024:.2f} Mb")

99.97 Mb


In [9]:
c = torch.randn(64, 10, 90)
c_i = torch.randint(0, 60, (64, 10,))   
x = torch.randn(64, 60, 90)
t = torch.randint(0, 300, (64,))


# encoded_input = pose_encoder(input_poses, mask=False)
# decoded_poses = pose_decoder(initial_noise)
# output = cross_attention(decoded_poses, encoder_input=encoded_input, mask=False)


# a = diffusion.denoiser(x, c, c_i, t)

diffusion.forward_diffusion_sample(x, t)[0].shape



torch.Size([64, 60, 90])