In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import os, glob

from stgcn import Model, Feeder

In [2]:
in_channels = 3
num_class = 400
edge_importance_weighting = True
graph_args = {
    "layout": "openpose",
    "strategy": "spatial"
}


In [3]:
stgcn_original_model = Model(in_channels, num_class, graph_args, edge_importance_weighting)

In [4]:
stgcn_state_dict = torch.load('./models/stgcn/st_gcn.kinetics-6fa43f73.pth')
stgcn_original_model.load_state_dict(stgcn_state_dict)

<All keys matched successfully>

In [240]:
class NewModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()

        # Load layers from base model
        self.graph = base_model.graph 
        A = torch.tensor(self.graph.A,
                         dtype=torch.float32,
                         requires_grad=False)
        self.register_buffer('A', A)

        
        self.st_gcn_networks = base_model.st_gcn_networks
        self.data_bn = base_model.data_bn
        self.edge_importance = base_model.edge_importance
        
        # For activity recognition
        self.fcn = base_model.fcn
        
        # For point of contact prediction
        self.poc_conv1 = nn.Conv2d(256, 128, kernel_size=1, stride=(1, 1))
        self.poc_conv2 = nn.Conv2d(128, 64, kernel_size=1, stride=(1, 1))
        self.poc_conv3 = nn.Conv2d(64, 1, kernel_size=1, stride=(1, 1))
        

    def forward(self, x):

        # data normalization
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        # forwad
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # BRANCH 1 - Activity recognition 
        x1 = F.avg_pool2d(x, x.size()[2:])
        x1 = x1.view(N, M, -1, 1, 1).mean(dim=1)
        x1 = self.fcn(x1)
        x1 = x1.view(x1.size(0), -1)
        
        # BRANCH 2 - Point of contact prediction
        #print(x.size())
        x2 = F.relu(self.poc_conv1(x))
        #print(x2.size())
        x2 = F.relu(self.poc_conv2(x2))
        #print(x2.size())
        x2 = F.sigmoid(self.poc_conv3(x2))
        x2 = x2.view(x2.size(0), x2.size(2), x2.size(3))

        return x1, x2

In [241]:
new_model = NewModel(stgcn_original_model)

In [242]:
feeder = Feeder( './data/stgcn/val_data.npy',
                 './data/stgcn/val_label.pkl',
                 random_choose=False,
                 random_move=False,
                 window_size=-1,
                 debug=False,
                 mmap=True)

In [243]:
sample = torch.tensor(feeder[0:1][0])

In [244]:
x1, x2 = new_model(sample)

In [245]:
x1.size()

torch.Size([1, 400])

In [246]:
x2.size()

torch.Size([2, 75, 18])

In [247]:
x2

tensor([[[0.5093, 0.4967, 0.5345,  ..., 0.4983, 0.4949, 0.5071],
         [0.5172, 0.4978, 0.5409,  ..., 0.5052, 0.4932, 0.5187],
         [0.5188, 0.5053, 0.5439,  ..., 0.4998, 0.4967, 0.5186],
         ...,
         [0.4831, 0.4933, 0.4964,  ..., 0.4800, 0.4840, 0.4853],
         [0.4830, 0.4916, 0.4962,  ..., 0.4794, 0.4854, 0.4875],
         [0.4827, 0.4912, 0.4941,  ..., 0.4812, 0.4854, 0.4880]],

        [[0.4899, 0.5077, 0.5214,  ..., 0.5089, 0.4986, 0.4900],
         [0.4844, 0.5121, 0.5150,  ..., 0.5019, 0.5069, 0.4847],
         [0.4847, 0.5139, 0.5199,  ..., 0.5021, 0.5087, 0.4871],
         ...,
         [0.4831, 0.4933, 0.4964,  ..., 0.4800, 0.4840, 0.4853],
         [0.4830, 0.4916, 0.4962,  ..., 0.4794, 0.4854, 0.4875],
         [0.4827, 0.4912, 0.4941,  ..., 0.4812, 0.4854, 0.4880]]],
       grad_fn=<ViewBackward>)

In [248]:
sample.size()

torch.Size([1, 3, 300, 18, 2])

In [249]:
def load_sample_from_file(path, length=300):
    file_data = np.load(path)
    data = file_data['arr_0']
    data = np.transpose(data, (2, 0, 1))
    frames = data.shape[1]
    interval = frames/length
    idx = np.floor(np.arange(length)*interval).astype(int)
    touching_points = data[3,idx,:]
    pose_sequence = data[0:3,idx,:]
    return pose_sequence, touching_points

In [250]:
class GrabFeeder(torch.utils.data.Dataset):
    def __init__(self, data_path, selected_classes):
        self.data_path = data_path
        self.selected_classes = selected_classes
        self.length = 300
        self.load_data()
        
    def load_data(self):
        data = []
        class_label = []
        touching_points = []
        for i in range(len(self.selected_classes)):
            selected_class = self.selected_classes[i]
            class_files = glob.glob(self.data_path + '/*'+ selected_class +'*.npz')
            for class_file in class_files:
                poses, tp = load_sample_from_file(class_file, length=self.length)
                data.append(poses)
                class_label.append(i)
                touching_points.append(tp)
                
        self.class_label = class_label
        self.tp = np.array(touching_points)
        self.tp = self.tp.reshape(-1, int(self.length/4), 4, 18)
        self.tp = self.tp.max(axis=2)
        
        self.data = np.array(data)
        self.data = self.data.reshape((-1, 3, self.length, 18, 1))
                
    def __len__(self):
        teste = len(self.class_label)
        return teste
        
    def __getitem__(self, index):
        data_numpy = self.data[index]
        label = self.class_label[index]
        tp = self.tp[index]

        return data_numpy, tp, label

In [251]:
selected_classes = ['offhand', 'eat', 'drink']
grab_feeder = GrabFeeder('../../datasets/grab_skeleton/', selected_classes)

In [252]:
from torch.utils.data import DataLoader 

grab_data_loader = DataLoader(grab_feeder, batch_size=5, shuffle=True)

In [253]:
import torch.optim as optim

cross_entropy = nn.CrossEntropyLoss()
bce = nn.BCELoss()
optimizer = optim.SGD(
                new_model.parameters(),
                lr=0.01,
                momentum=0.9,
                nesterov=True,
                weight_decay=0.0001)

In [255]:
new_model.train()
loss_value = []
iter_info = {}

for data, tp, label in grab_data_loader:
    

    data = data.float()
    label = label.long()
    tp = tp.float()
    

    
    # forward
    out1, out2 = new_model(data)
    loss1 = cross_entropy(out1, label)
    loss2 = bce(out2, tp)
    loss = loss1 + loss2

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # statistics
    iter_info['loss'] = loss.data.item()
    iter_info['lr'] = '{:.6f}'.format(0.01)
    loss_value.append(iter_info['loss'])
    print(iter_info['loss'])

1.0162608623504639
1.2812612056732178
1.4361464977264404
1.1583032608032227
1.2963591814041138
1.8524937629699707
0.7327706813812256
0.870172381401062
0.7078036069869995
0.5761072039604187
1.0221362113952637
0.7195138335227966
0.7396740913391113
0.7662004232406616
1.1179569959640503
1.1183362007141113
1.0909100770950317
0.5565000772476196
0.3607386350631714
1.2589396238327026
1.9918822050094604
0.3790687620639801
0.2648668885231018
1.3971385955810547
0.335785448551178
2.6895077228546143
0.44877856969833374
1.195904016494751
0.36288949847221375
2.0643327236175537
0.5781333446502686
1.3457486629486084


In [138]:
data, tp, label = grab_feeder[0:2]

In [139]:
sample_n = torch.tensor(data[0:2]).float()

In [140]:
sample_n.shape

torch.Size([2, 3, 300, 18, 1])

In [141]:
x1, x2 = new_model(sample_n)

torch.Size([2, 256, 75, 18])
torch.Size([2, 128, 75, 18])
torch.Size([2, 64, 75, 18])


In [135]:
x1.size()

torch.Size([2, 400])

In [120]:
x2.size()

torch.Size([2, 1, 75, 18])

In [132]:
tp.shape

(2, 300, 18)

In [134]:
x2.view(x2.size(0), x2.size(2), x2.size(3)).size()

torch.Size([2, 75, 18])

In [144]:
teste = tp.reshape(2, int(300/4), 4, 18)

In [147]:
teste = tp.reshape(2, int(300/4), 4, 18)
teste.max(axis=2).shape

(2, 75, 18)

In [215]:
a = np.array([[1.0,0.0,1.0],[0.0,1.0,0.0]])
b = np.array([[1.0,1.0,1.0],[0.0,1.0,0.0]])

In [216]:
loss = nn.BCELoss()

In [217]:
a = torch.tensor(a)
b = torch.tensor(b)
loss(a,b)

tensor(16.6667, dtype=torch.float64)