In [1]:
import os, glob
import math, copy, time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from tqdm import tqdm, trange

import numpy as np

import matplotlib
from matplotlib import pyplot as plt

from IPython.display import HTML
from IPython.display import display

from skeleton_models import ntu_rgbd, ntu_ss_1, ntu_ss_2, ntu_ss_3
from graph import Graph
from render import animate, save_animation
from datasets import NTUDataset, Normalize, CropSequence, SelectDimensions, SelectSubSample

# Model components
from zoo_pose_embedding import TwoLayersGCNPoseEmbedding, JoaosDownsampling, SuperSimpleDownsampling
from zoo_action_encoder_units import TransformerEncoderUnit
from zoo_action_decoder_units import TransformerDecoderUnit
from zoo_upsampling import StepByStepUpsampling, JoaosUpsampling, SuperSimpleUpsampling
from model import ActionEmbeddingTransformer, SimplePoseEncoderDecoder
from layers import subsequent_mask

# Pick the compute backend once: CUDA when it is usable, otherwise the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [3]:
# Skeleton definition and the graph built from it.
skeleton_model = ntu_rgbd
adjacency = Graph(skeleton_model)

# The first two axes of the adjacency array give the kernel size and the node count.
conf_kernel_size, conf_num_nodes = adjacency.A.shape[:2]

# Feature budget per node, and the share of it obtained by dividing by the head count.
conf_encoding_per_node = 100
conf_heads = 5
conf_internal_per_node = conf_encoding_per_node // conf_heads

# Display the flattened embedding width (features per node times number of nodes).
print(conf_num_nodes * conf_encoding_per_node)


class BetterThatBestModel(nn.Module):
    """Thin ``nn.Module`` wrapper around an ``ActionEmbeddingTransformer``.

    The wrapped model is assembled from four components, in this order:
    a ``JoaosDownsampling`` pose embedding, a ``TransformerEncoderUnit``,
    a ``TransformerDecoderUnit`` and a ``JoaosUpsampling`` output stage.

    All sizes are read from the module-level ``conf_num_nodes``,
    ``conf_encoding_per_node``, ``conf_internal_per_node`` and ``conf_heads``
    values (and ``device``) at construction time.
    """

    def __init__(self):
        super().__init__()

        # Flattened widths shared by the components below.
        embedding_width = conf_num_nodes * conf_encoding_per_node
        internal_width = conf_num_nodes * conf_internal_per_node

        # Components are created in the same order in which they are handed to
        # the transformer, so parameter creation order is unchanged.
        pose_embedding = JoaosDownsampling(
            conf_num_nodes,
            embedding_width,
            node_channel_in=2,
            device=device,
        )
        encoder_unit = TransformerEncoderUnit(
            heads=conf_heads,
            embedding_in=embedding_width,
            embedding_out=internal_width,
        )
        decoder_unit = TransformerDecoderUnit(
            heads=conf_heads,
            embedding_in=embedding_width,
            embedding_out=internal_width,
            memory_in=embedding_width,
        )
        pose_upsampling = JoaosUpsampling(
            conf_num_nodes,
            embedding_width,
            node_channel_out=2,
            device=device,
        )

        self.model = ActionEmbeddingTransformer(
            pose_embedding, encoder_unit, decoder_unit, pose_upsampling
        )

    def forward(self, x_in, x_out, A, mask):
        """Forward all four arguments, unchanged, to the wrapped model.

        NOTE(review): presumably ``x_in`` feeds the encoder, ``x_out`` the
        decoder, ``A`` is the adjacency tensor and ``mask`` the attention
        mask -- confirm against ``ActionEmbeddingTransformer.forward``.
        """
        return self.model(x_in, x_out, A, mask)
    
# Instantiate the network and move it to the selected device
# (``Module.to`` returns the module itself).
model = BetterThatBestModel().to(device)

# Adjacency as a float tensor living on the same device as the model.
A = torch.from_numpy(adjacency.A).to(device, dtype=torch.float)

# Xavier-uniform initialisation for every parameter with more than one
# dimension; 1-D parameters keep whatever their constructors assigned.
for p in filter(lambda param: param.dim() > 1, model.parameters()):
    nn.init.xavier_uniform_(p)

# Reconstruction objective: L1 loss between the model output and its input.
criterion = nn.L1Loss()

# Per-sample preprocessing applied by the dataset.
preprocessing_steps = [
    Normalize(),
    SelectDimensions(2),
    SelectSubSample(skeleton_model),
]
composed = transforms.Compose(preprocessing_steps)

# Alternative dataset location kept for reference:
#   NTUDataset(root_dir='../ntu-rgbd-dataset/Python/raw_npy/')
ntu_dataset = NTUDataset(
    root_dir='../datasets/NTURGB-D/Python/sel_npy/',
    transform=composed,
)
loader = DataLoader(ntu_dataset, batch_size=512, shuffle=True)

# Plain SGD over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)


# Outer loop counts epochs; the progress bar description is refreshed after
# every batch with that batch's loss.
pbar = tqdm(range(200000), desc='Initializing ...')
for epoch in pbar:

    for data in loader:
        data = data.to(device, dtype=torch.float)

        # The batch must be 4-D. Names suggest (batch, time, nodes, channels)
        # -- TODO confirm against NTUDataset. Only t_out is used, to size the mask.
        n_out, t_out, v_out, c_out = data.shape
        mask = subsequent_mask(t_out).to(device, dtype=torch.float)

        # The batch is passed as both model inputs and is also the target.
        optimizer.zero_grad()
        out = model(data, data, A, mask)
        loss = criterion(out, data)

        # Backpropagate and apply the SGD update.
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        pbar.set_description("Curr loss = {:.4f}".format(batch_loss))

    # From here on, `data`, `out` and `loss` refer to the LAST batch of the epoch.
    is_first_epoch = (epoch == 0)
    if is_first_epoch:
        # Save one input sample once, as a visual reference.
        save_animation(data[0], skeleton_model, 'outputs/animations/a_sample_example_epoch_{}.gif'.format(epoch))

    is_report_epoch = (epoch % 250 == 0)
    if is_report_epoch:
        # NOTE(review): this is the last batch's loss, not an epoch average.
        print('Epoch {} loss = {}'.format(epoch, loss.item()))
        # Checkpointing and the per-report input animation are currently disabled:
        # torch.save(model.state_dict(), 'outputs/models/simple_encoder_epoch_{}.pth'.format(epoch))
        # save_animation(data[0], ntu_rgbd, 'outputs/animations/sample_example_epoch_{}.gif'.format(epoch))
        save_animation(out[0], skeleton_model, 'outputs/animations/out_example_epoch_{}.gif'.format(epoch))

2500
Warn: not activated
Warn: not activated


Curr loss = 0.7442:   0%|          | 0/200000 [00:00<?, ?it/s]

Epoch 0 loss = 0.7442317008972168


Curr loss = 0.0718:   0%|          | 250/200000 [08:02<105:00:57,  1.89s/it]

Epoch 250 loss = 0.07176075875759125


Curr loss = 0.0636:   0%|          | 500/200000 [15:46<92:26:20,  1.67s/it] 

Epoch 500 loss = 0.06359687447547913


Curr loss = 0.0554:   0%|          | 750/200000 [23:00<88:48:05,  1.60s/it] 

Epoch 750 loss = 0.05539754405617714


Curr loss = 0.0582:   0%|          | 1000/200000 [30:10<89:52:08,  1.63s/it]

Epoch 1000 loss = 0.058163076639175415


Curr loss = 0.0540:   1%|          | 1250/200000 [37:27<101:33:54,  1.84s/it]

Epoch 1250 loss = 0.054011598229408264


Curr loss = 0.0514:   1%|          | 1500/200000 [44:58<92:35:24,  1.68s/it] 

Epoch 1500 loss = 0.051385171711444855


Curr loss = 0.0493:   1%|          | 1750/200000 [52:49<98:11:35,  1.78s/it] 

Epoch 1750 loss = 0.04931965097784996


Curr loss = 0.0485:   1%|          | 2000/200000 [1:01:18<133:39:06,  2.43s/it]

Epoch 2000 loss = 0.048507824540138245


Curr loss = 0.0474:   1%|          | 2250/200000 [1:09:07<89:17:19,  1.63s/it] 

Epoch 2250 loss = 0.04742930829524994


Curr loss = 0.0466:   1%|▏         | 2500/200000 [1:17:02<105:07:20,  1.92s/it]

Epoch 2500 loss = 0.04655918478965759


Curr loss = 0.0460:   1%|▏         | 2750/200000 [1:25:01<104:01:44,  1.90s/it]

Epoch 2750 loss = 0.04595132917165756


Curr loss = 0.0455:   2%|▏         | 3000/200000 [1:32:45<88:58:17,  1.63s/it] 

Epoch 3000 loss = 0.04553656280040741


Curr loss = 0.0469:   2%|▏         | 3250/200000 [1:40:20<88:26:35,  1.62s/it] 

Epoch 3250 loss = 0.046893153339624405


Curr loss = 0.0472:   2%|▏         | 3500/200000 [1:47:43<91:31:30,  1.68s/it] 

Epoch 3500 loss = 0.04722341150045395


Curr loss = 0.0439:   2%|▏         | 3750/200000 [1:54:58<87:25:09,  1.60s/it] 

Epoch 3750 loss = 0.04394150152802467


Curr loss = 0.0432:   2%|▏         | 4000/200000 [2:02:06<89:41:40,  1.65s/it] 

Epoch 4000 loss = 0.0432048998773098


Curr loss = 0.0420:   2%|▏         | 4250/200000 [2:09:10<87:25:24,  1.61s/it] 

Epoch 4250 loss = 0.04196389764547348


Curr loss = 0.0458:   2%|▏         | 4500/200000 [2:16:14<86:41:56,  1.60s/it] 

Epoch 4500 loss = 0.04583263024687767


Curr loss = 0.0446:   2%|▏         | 4750/200000 [2:23:16<86:59:33,  1.60s/it] 

Epoch 4750 loss = 0.04458450525999069


Curr loss = 0.0433:   2%|▎         | 5000/200000 [2:30:20<87:12:35,  1.61s/it] 

Epoch 5000 loss = 0.04333155229687691


Curr loss = 0.0401:   3%|▎         | 5250/200000 [2:37:23<87:23:09,  1.62s/it] 

Epoch 5250 loss = 0.040099408477544785


Curr loss = 0.0404:   3%|▎         | 5500/200000 [2:44:25<87:09:52,  1.61s/it] 

Epoch 5500 loss = 0.04038780555129051


Curr loss = 0.0332:   3%|▎         | 5750/200000 [2:51:23<84:52:48,  1.57s/it] 

Epoch 5750 loss = 0.03322248160839081


Curr loss = 0.0265:   3%|▎         | 6000/200000 [2:58:16<84:40:11,  1.57s/it] 

Epoch 6000 loss = 0.02651612088084221


Curr loss = 0.0248:   3%|▎         | 6250/200000 [3:05:09<85:43:56,  1.59s/it] 

Epoch 6250 loss = 0.024833746254444122


Curr loss = 0.0242:   3%|▎         | 6500/200000 [3:12:02<83:59:37,  1.56s/it] 

Epoch 6500 loss = 0.02421506866812706


Curr loss = 0.0304:   3%|▎         | 6750/200000 [3:18:55<84:27:14,  1.57s/it] 

Epoch 6750 loss = 0.030377520248293877


Curr loss = 0.0300:   4%|▎         | 7000/200000 [3:25:47<88:18:17,  1.65s/it] 

Epoch 7000 loss = 0.029960719868540764


Curr loss = 0.0283:   4%|▎         | 7250/200000 [3:32:38<85:12:04,  1.59s/it] 

Epoch 7250 loss = 0.028275303542613983


Curr loss = 0.0274:   4%|▍         | 7500/200000 [3:39:31<84:31:50,  1.58s/it] 

Epoch 7500 loss = 0.02736840397119522


Curr loss = 0.0263:   4%|▍         | 7750/200000 [3:46:25<83:57:48,  1.57s/it] 

Epoch 7750 loss = 0.026272529736161232


Curr loss = 0.0245:   4%|▍         | 7908/200000 [3:51:12<93:36:24,  1.75s/it] 


KeyboardInterrupt: 

In [None]:
# Re-establish the skeleton graph and size constants used by the cells below.
skeleton_model = ntu_rgbd
adjacency = Graph(skeleton_model)

# Axis 0 of the adjacency array is the kernel size, axis 1 the number of nodes.
conf_kernel_size, conf_num_nodes = adjacency.A.shape[:2]

# Per-node feature count and its integer share per head.
conf_encoding_per_node = 100
conf_heads = 5
conf_internal_per_node = conf_encoding_per_node // conf_heads

# Print the flattened embedding width.
print(conf_num_nodes * conf_encoding_per_node)





In [None]:
print('Using {}'.format(device))

# Flattened embedding width shared by the down- and up-sampling stages.
embedding_width = conf_encoding_per_node * conf_num_nodes

# Build the two stages in the order they are handed to the encoder/decoder.
pose_encoder = JoaosDownsampling(
    conf_num_nodes,
    embedding_width,
    node_channel_in=2,
    device=device,
)
pose_decoder = JoaosUpsampling(
    conf_num_nodes,
    embedding_width,
    node_channel_out=2,
    device=device,
)
model = SimplePoseEncoderDecoder(pose_encoder, pose_decoder)

# Adjacency as a float tensor on the target device; the model follows it there.
A = torch.from_numpy(adjacency.A).to(device, dtype=torch.float)
model = model.to(device)

# Xavier-uniform initialisation for parameters with more than one dimension.
for p in filter(lambda param: param.dim() > 1, model.parameters()):
    nn.init.xavier_uniform_(p)

# L1 reconstruction loss.
criterion = nn.L1Loss()

# Per-sample preprocessing applied by the dataset.
preprocessing_steps = [
    Normalize(),
    SelectDimensions(2),
    SelectSubSample(skeleton_model),
]
composed = transforms.Compose(preprocessing_steps)

# Alternative dataset location kept for reference:
#   NTUDataset(root_dir='../ntu-rgbd-dataset/Python/raw_npy/')
ntu_dataset = NTUDataset(
    root_dir='../datasets/NTURGB-D/Python/sel_npy/',
    transform=composed,
)
loader = DataLoader(ntu_dataset, batch_size=512, shuffle=True)

# Plain SGD over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)


# Train the pose encoder/decoder to reconstruct its input.
# The progress bar counts epochs; its description shows the latest batch loss.
pbar = tqdm(range(200000), desc='Initializing ...')
for epoch in pbar:

    for data in loader:
        data = data.to(device, dtype=torch.float)

        # No attention mask is built here: this model is invoked with the
        # batch and the adjacency tensor only, so a mask would never be used.
        optimizer.zero_grad()
        out = model(data, A)

        # The input batch is its own reconstruction target.
        loss = criterion(out, data)
        loss.backward()

        # update parameters
        optimizer.step()
        pbar.set_description("Curr loss = {:.4f}".format(loss.item()))

    # Below, `data`, `out` and `loss` refer to the last batch of the epoch.
    if epoch == 0:
        # Save one input sample once, as a visual reference.
        save_animation(data[0], skeleton_model, 'outputs/animations/a_sample_example_epoch_{}.gif'.format(epoch))

    if epoch % 3000 == 0:
        # NOTE(review): this reports the last batch's loss, not an epoch average.
        print('Epoch {} loss = {}'.format(epoch, loss.item()))
        # torch.save(model.state_dict(), 'outputs/models/simple_encoder_epoch_{}.pth'.format(epoch))
        # save_animation(data[0], ntu_rgbd, 'outputs/animations/sample_example_epoch_{}.gif'.format(epoch))
        save_animation(out[0], skeleton_model, 'outputs/animations/out_example_epoch_{}.gif'.format(epoch))

In [None]:
# Fetch the first sample a single time and reuse it for every statistic,
# instead of indexing the dataset again for each print.
first_sample = ntu_dataset[0]

# Minimum, maximum and mean over the first two axes, for last-axis
# index 1 and then index 0 (same order as before).
print(first_sample[:, :, 1].min())
print(first_sample[:, :, 0].min())
print(first_sample[:, :, 1].max())
print(first_sample[:, :, 0].max())
print(first_sample[:, :, 1].mean())
print(first_sample[:, :, 0].mean())

In [None]:
# Take the first sample and shift last-axis entries 0 and 1 so that each has
# a minimum of zero.
# NOTE(review): this writes into the object returned by the dataset; if the
# dataset hands out cached arrays, the stored sample is modified too -- verify.
a = ntu_dataset[0]
for channel in (0, 1):
    a[:, :, channel] = a[:, :, channel] - a[:, :, channel].min()

In [None]:
# Display the shape of the first dataset sample.
ntu_dataset[0].shape

In [None]:
# Display the largest value found at last-axis index 0 of `a`.
a[:,:,0].max()

In [None]:
# Print min, max and mean of `a`, each for last-axis index 1 and then index 0.
for stat_name in ('min', 'max', 'mean'):
    for channel in (1, 0):
        channel_values = a[:, :, channel]
        print(getattr(channel_values, stat_name)())