In [3]:
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass
import os
import sys
import pandas as pd
import re
sys.path.append('..')
from utils.amc_parser import parse_motion_file, Motion, MotionFrame


In [2]:
# Fix torch's global RNG (weight init, torch.randint clip sampling) for reproducibility.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x7bb6143326d0>

In [19]:
@dataclass
class PoseFlowConfig:
    """Hyper-parameters shared by the PoseFlow model, its training loop and the data pipeline."""

    block_size: int = 50  # context window in frames; training clips hold block_size + 1 frames
    frame_rate: int = 10  # fps every motion is resampled to before clipping
    max_iters: int = 5000  # passes over the DataLoader in PoseFlowModel.fit
    n_embd: int = 256
    feature_length: int = 56  # channels per pose = sum of skeleton_structure sizes (no root)
    n_layer: int = 12
    n_head: int = 8
    dropout: float = 0.2
    # Prefer the GPU whenever CUDA is available.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size: int = 64
    lr: float = 1e-4


config = PoseFlowConfig()

In [20]:

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )
        self.attn = nn.MultiheadAttention(config.n_embd, config.n_head, config.dropout)


    def forward(self,x):
        x = x + self.attn(self.ln_1(x), self.ln_1(x), self.ln_1(x))[0]
        x = x + self.mlp(self.ln_2(x))
        return x


class PoseFlowModel(nn.Module):
    """Transformer that predicts the next pose from a window of previous poses.

    A pose is a flat vector of ``feature_length`` channels. Inputs are shaped
    (batch, time, feature_length) with time <= ``block_size``.
    """

    def __init__(self, cfg=None):
        """Build the model from `cfg` (defaults to the notebook-level `config`).

        Note: Block() sizes itself from the notebook-level `config`, so with
        n_layer > 0 a custom cfg must agree with it on n_embd/n_head/dropout.
        """
        super().__init__()
        self.config = config if cfg is None else cfg
        cfg = self.config
        self.pose_encoder = nn.Linear(cfg.feature_length, cfg.n_embd)
        self.position_embedding = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.blocks = nn.Sequential(*[Block() for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.feature_length)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Map poses (B, T, feature_length) to outputs of the same shape.

        Only the last time step is trained (by `fit`) and consumed (by `stream`)
        as the prediction of the frame that follows the window.

        Raises:
            ValueError: if T exceeds block_size (there is no position embedding for it).
        """
        T = x.shape[1]
        if T > self.config.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.config.block_size}"
            )
        # Positions live on the input's device, so the model works wherever it was moved.
        positions = torch.arange(T, device=x.device)
        h = self.pose_encoder(x) + self.position_embedding(positions)
        h = self.blocks(h)
        h = self.ln_f(h)
        return self.lm_head(h)

    def generate(self, x, n_frames):
        """Autoregressively produce `n_frames` new frames after the seed window `x`.

        Args:
            x: seed poses, (B, T, feature_length).
            n_frames: number of frames to generate (>= 0).
        Returns:
            Tensor (B, n_frames, feature_length) containing only the generated frames.
        Raises:
            ValueError: if n_frames is negative.
        """
        if n_frames < 0:
            raise ValueError(f"n_frames must be >= 0, got {n_frames}")
        frame_iter = self.stream(x)
        frames = [next(frame_iter) for _ in range(n_frames)]
        if not frames:
            return x.new_empty(x.shape[0], 0, x.shape[2])
        return torch.stack(frames, dim=1)

    @torch.no_grad()
    def stream(self, x):
        """Yield predicted frames forever, feeding each prediction back as input.

        Args:
            x: seed poses, (B, T, feature_length).
        Yields:
            The next predicted frame, shaped (B, feature_length).

        Runs without autograd (torch.no_grad on a generator only applies while the
        generator body executes), so the rolled-out history does not build a graph.
        """
        while True:
            # Keep at most block_size frames: the position embedding has no more slots.
            x = x[:, -self.config.block_size:, :]
            next_frame = self.forward(x)[:, -1, :]
            # (B, C) -> (B, 1, C) so the frame is appended along time for any batch size.
            x = torch.cat([x, next_frame.unsqueeze(1)], dim=1)
            yield next_frame

    def fit(self, input):
        """Train with Adam on next-frame MSE and print the mean loss per pass.

        For each batch of clips (B, T, feature_length) and each prefix length i + 1
        (up to block_size), the output at the prefix's last position is regressed onto
        frame i + 1, taking one optimizer step per prefix.

        Args:
            input: re-iterable source of batches whose first element is the clip tensor
                (e.g. a DataLoader over a TensorDataset). Batches must already be on the
                model's device; fit does not move them.
        Raises:
            ValueError: if a pass yields no batches or a clip has fewer than 2 frames.
        """
        cfg = self.config
        optimizer = torch.optim.Adam(self.parameters(), lr=cfg.lr)
        self.train()  # dropout on, even if the model was left in eval mode
        for epoch in range(cfg.max_iters):
            epoch_losses = []
            for batch in input:
                clips = batch[0]
                # Never index past the last frame available as a target.
                n_steps = min(cfg.block_size, clips.shape[1] - 1)
                if n_steps < 1:
                    raise ValueError("each clip needs at least 2 frames (input + target)")
                step_losses = []
                for i in range(n_steps):
                    pred = self.forward(clips[:, :i + 1, :])[:, -1, :]
                    loss = F.mse_loss(pred, clips[:, i + 1, :])
                    step_losses.append(loss.item())
                    optimizer.zero_grad(set_to_none=True)
                    loss.backward()
                    optimizer.step()
                epoch_losses.append(sum(step_losses) / len(step_losses))
            if not epoch_losses:
                raise ValueError("fit() received no batches")
            print(f'Iter {epoch}, Loss: {sum(epoch_losses)/len(epoch_losses)}')





In [21]:
# Display the device the model and data tensors are moved to in the cells below.
config.device

'cpu'

In [22]:
model = PoseFlowModel().to(config.device)

# Total number of scalar parameters in the model.
total_el = sum(parameter.numel() for parameter in model.parameters())
print(total_el)

9519416


In [None]:
# Recount from zero: adding onto the existing `total_el` would double-count every parameter.
total_el = sum(parameter.numel() for parameter in model.parameters())
print(total_el)

In [None]:
# Train on the clip tensor `data` (one TensorDataset row per clip).
# NOTE(review): `data` is only defined by later cells (the torch.load("data.bin") cell or
# the dataset-building cell); run one of them first or this raises NameError on a fresh kernel.
model.train()
model.fit(DataLoader(TensorDataset(data.to(config.device)), batch_size=config.batch_size))
# The trained weights are not saved here; the checkpoint cell below loads "model.pth".

In [None]:
# Load the cached clip tensor onto the CPU; later cells move it to config.device.
# weights_only=True: the file holds a plain tensor, so no arbitrary unpickling is needed.
data = torch.load("data.bin", map_location="cpu", weights_only=True)

In [27]:
# Restore trained weights, mapping every tensor onto the configured device.
state_dict = torch.load("model.pth", map_location=config.device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [34]:
# Roll the model out from a single seed frame (frame 0 of clip 3) and convert each
# predicted pose into a per-bone dict for the viewer.
# NOTE(review): to_motion_frame is defined in a later cell; run that cell first.
n_generated_frames = 600
seed_frames = data[3, 0:1].unsqueeze(0).to(config.device)  # (1, 1, feature_length)

generated_motion = []
model.eval()
# no_grad: pure inference; otherwise each step's autograd graph stays reachable through
# the fed-back predictions. range() comes first in zip so the endless stream is not
# advanced one extra, discarded step after the last frame.
with torch.no_grad():
    for _, pose in zip(range(n_generated_frames), model.stream(seed_frames)):
        generated_motion.append(to_motion_frame(pose.squeeze(0)))
    generated_motion.append(to_motion_frame(pose.squeeze(0)))
    

In [24]:
from amc_parser.amc_parser import parse_amc, parse_asf
from amc_parser.viewer import Viewer

pygame 2.5.2 (SDL 2.28.2, Python 3.12.3)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [37]:
# Load a skeleton definition and play the generated frames on it.
# NOTE(review): training clips come from subject "05" but this loads subject 01's .asf — confirm intended.
asf_path = os.path.join(data_dir, '01/01.asf')
asf = parse_asf(asf_path)

viewer = Viewer(asf, generated_motion)
viewer.run()

KeyboardInterrupt: 

: 

In [8]:
# Number of channels per bone, in the order bones are flattened into a pose vector.
# "root" (6 channels) is not part of this layout; to_motion_frame fills it with zeros.
skeleton_structure = dict(
    lowerback=3,
    upperback=3,
    thorax=3,
    lowerneck=3,
    upperneck=3,
    head=3,

    rclavicle=2,
    rhumerus=3,
    rradius=1,
    rwrist=1,
    rhand=2,
    rfingers=1,
    rthumb=2,

    lclavicle=2,
    lhumerus=3,
    lradius=1,
    lwrist=1,
    lhand=2,
    lfingers=1,
    lthumb=2,

    rfemur=3,
    rtibia=1,
    rfoot=2,
    rtoes=1,

    lfemur=3,
    ltibia=1,
    lfoot=2,
    ltoes=1,
)

In [9]:
def to_tensor(motion: "Motion", structure: "dict[str, int] | None" = None) -> torch.Tensor:
    """Flatten a parsed motion into a float32 tensor of shape (frames, channels).

    Each row concatenates the channels of every bone in `structure` (defaults to
    `skeleton_structure`) in its iteration order, which is the layout
    `to_motion_frame` unpacks. Bones absent from `structure` (e.g. "root") are dropped.

    Args:
        motion: object whose `frames` is a sequence of bone -> channel-list mappings.
        structure: bone -> channel count; keys/order select the columns, and the
            summed counts give the width of an empty result.
    """
    if structure is None:
        structure = skeleton_structure
    rows = [[channel for bone in structure for channel in frame[bone]] for frame in motion.frames]
    if not rows:
        # torch.tensor([]) would be 1-D; keep the (0, channels) shape downstream code expects.
        return torch.empty(0, sum(structure.values()))
    # Explicit float32 so integer-valued channels still support averaging downstream.
    return torch.tensor(rows, dtype=torch.float32)


In [10]:
def to_motion(tensor: torch.Tensor, frame_rate: "int | None" = None) -> Motion:
    """Convert a pose tensor into a Motion with one frame per row.

    Args:
        tensor: poses shaped (T, C) with C == config.feature_length (56 channels,
            root excluded); each row is unpacked by to_motion_frame.
        frame_rate: frame rate stored on the Motion; defaults to config.frame_rate.
    Returns:
        A new Motion whose `frames` are per-bone dicts.
    """
    motion = Motion()
    motion.frame_rate = config.frame_rate if frame_rate is None else frame_rate
    motion.frames = [to_motion_frame(row) for row in tensor]
    return motion


def to_motion_frame(tensor: torch.Tensor, structure: "dict[str, int] | None" = None) -> "MotionFrame":
    """Unflatten one pose vector into a bone -> channel-list dict.

    Args:
        tensor: 1-D pose of length sum(structure.values()) (56 for skeleton_structure).
        structure: bone -> channel count in flattening order; defaults to skeleton_structure.
    Returns:
        Dict with "root" set to six zeros (root channels are not part of the flattened
        layout) followed by each bone's channels as Python floats.
    Raises:
        ValueError: if the tensor length does not match the structure.
    """
    if structure is None:
        structure = skeleton_structure
    values = tensor.tolist()  # a single device-to-host copy instead of one per bone
    n_expected = sum(structure.values())
    if not isinstance(values, list) or len(values) != n_expected:
        raise ValueError(
            f"expected a 1-D pose of {n_expected} channels, got shape {tuple(tensor.shape)}"
        )
    motion_frame = {'root': [0] * 6}
    idx = 0
    for bone, size in structure.items():
        motion_frame[bone] = values[idx: idx + size]
        idx += size
    return motion_frame


In [11]:
# Subject folders and the HTML index page of the dataset, relative to this notebook.
data_dir, raw_metadata_location = (
    '../all_asfamc/subjects',
    '../all_asfamc/mocap-index.html',
)

In [12]:
# The index page contains several <table>s; the per-trial motion index is the one at position 5.
html_tables = pd.read_html(raw_metadata_location)
raw_metadata = html_tables[5]
raw_metadata

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,,,,,,,,,,
1,"Subject #1 (climb, swing, hang on playground e...","Subject #1 (climb, swing, hang on playground e...","Subject #1 (climb, swing, hang on playground e...",- - asf,- - asf,- - asf,- - asf,- - asf,framerate,Feedback
2,Image,Trial #,Motion Description,,,,,,,
3,,1,"playground - forward jumps, turn around",tvd,c3d,amc,mpg,Animated,120,Feedback
4,,2,playground - climb,tvd,c3d,amc,mpg,Animated,120,Feedback
...,...,...,...,...,...,...,...,...,...,...
2757,,30,Sun Salutation,,c3d,amc,,,120,Feedback
2758,,31,Sun Salutation001,,c3d,amc,,,120,Feedback
2759,,32,Sun Salutation002,,c3d,amc,,,120,Feedback
2760,,33,Walking,,c3d,amc,,,120,Feedback


In [13]:
@dataclass
class MotionInfo:
    """Metadata for one trial row of the mocap index table."""

    frame_rate: int  # value of the table's framerate column
    description: str  # trial description; "" when the cell is empty


# Walk the index table top to bottom: a row whose first cell contains "Subject #N" starts a
# new subject; any later row whose framerate cell (column 8) is an all-digit string is a
# trial of the current subject, keyed by its trial number (column 1).
metadata: dict[str, dict[str, MotionInfo]] = {}
current_subject = None
for record in raw_metadata.itertuples(index=False, name=None):
    first_cell = record[0]
    heading = re.search(r"Subject #(\d+)", first_cell) if isinstance(first_cell, str) else None
    if heading:
        current_subject = heading.group(1)
        metadata[current_subject] = {}
        continue

    rate_cell = record[8]
    if not isinstance(rate_cell, str) or not rate_cell.isdigit():
        continue

    text = record[2]
    trial_no = record[1]
    assert isinstance(current_subject, str)
    assert isinstance(trial_no, str)
    metadata[current_subject][trial_no] = MotionInfo(
        int(rate_cell), text if isinstance(text, str) else ""
    )


In [14]:
def to_target_frame_rate(motion: torch.Tensor, current_frame_rate: int, target_frame_rate: int, average: bool = False) -> torch.Tensor:
    """Downsample a motion tensor (frames along dim 0) to `target_frame_rate`.

    Output frame k starts at source frame floor(k * current / target), so ratios that
    are not whole numbers (e.g. 30 -> 20 fps) still give the requested rate instead of
    being truncated to an integer stride.

    Args:
        motion: frames along dim 0.
        current_frame_rate: frame rate of `motion`.
        target_frame_rate: desired frame rate; must be positive and <= current_frame_rate.
        average: if True, each output frame is the mean of round(ratio) source frames
            starting at its start index (fewer at the end of the clip); otherwise the
            start frame itself is kept.
    Returns:
        `motion` itself when the rates are equal, otherwise a new tensor.
    Raises:
        ValueError: if target_frame_rate is not positive or exceeds current_frame_rate.
    """
    if target_frame_rate <= 0 or current_frame_rate < target_frame_rate:
        raise ValueError(
            f"cannot resample {current_frame_rate} fps to {target_frame_rate} fps"
        )
    if current_frame_rate == target_frame_rate:
        return motion
    ratio = current_frame_rate / target_frame_rate
    n_frames = len(motion)
    if ratio.is_integer():
        # Exact integer stride (e.g. 120 -> 10 fps): plain integer arange.
        starts = torch.arange(0, n_frames, int(ratio))
    else:
        # Fractional stride: float64 positions floored to frame indices.
        starts = torch.arange(0, n_frames, ratio, dtype=torch.float64).long()
    if not average:
        return motion[starts]
    if n_frames == 0:
        # torch.stack([]) would raise; return an empty slice with the same trailing shape.
        return motion[:0]
    window = max(1, round(ratio))
    return torch.stack([motion[s: s + window].mean(0) for s in starts.tolist()])


In [15]:
# Cut random training clips of block_size + 1 consecutive frames (block_size inputs plus
# the next-frame targets) from every .amc trial of the chosen subjects that appears in
# `metadata`, after resampling each trial to config.frame_rate.
training_subjects = ["05"]
clip_len = config.block_size + 1

clips = []
for subject_dir in training_subjects:
    for root, dirs, files in os.walk(os.path.join(data_dir, subject_dir)):
        # os.walk order depends on the filesystem; sort it so the torch.randint draws
        # below (and hence the dataset) are reproducible under the fixed seed.
        dirs.sort()
        base = os.path.basename(root)
        subject = str(int(base)) if base.isdigit() else root
        for file in sorted(files):
            if not file.endswith('.amc'):
                continue
            # Trial number is the part after "_" in the file stem (e.g. "05_01.amc" -> "1").
            index = str(int(file.split(".")[0].split("_")[1]))
            if subject not in metadata or index not in metadata[subject]:
                continue
            motion_info = metadata[subject][index]
            motion_data = to_tensor(parse_motion_file(os.path.join(root, file)))
            motion_data = to_target_frame_rate(motion_data, motion_info.frame_rate, config.frame_rate, average=True)
            if len(motion_data) < clip_len:
                continue  # not enough frames to make a block
            # About one clip per block_size frames; start positions are drawn with replacement.
            n_clips = max(1, len(motion_data) // config.block_size)
            starts = torch.randint(0, len(motion_data) - clip_len + 1, (n_clips,))
            clips.extend(motion_data[s: s + clip_len] for s in starts)

if not clips:
    raise RuntimeError("no .amc trials with enough frames were found for the selected subjects")
data = torch.stack(clips)


In [None]:
# Cache the clip tensor so later sessions can skip the dataset-building cell.
torch.save(obj=data, f="data.bin")

In [16]:
# Sanity-check the clip tensor before training: (n_clips, block_size + 1, feature_length).
assert isinstance(data, torch.Tensor)
assert data.ndim == 3, data.shape
assert data.shape[1:] == (config.block_size + 1, config.feature_length), data.shape
assert torch.isfinite(data).all(), "non-finite values in motion data"
data.shape

torch.Size([20, 51, 56])

In [None]:
# Persist the parsed index. It contains MotionInfo dataclass instances, so reloading needs
# MotionInfo importable and torch.load(..., weights_only=False) on torch >= 2.6.
torch.save(obj=metadata, f="metadata.bin")

In [None]:
# Shape of a single clip: (block_size + 1, feature_length).
data[0].shape