In [4]:
import os
import gc 
import random 
import torch 
import numpy as np 
import pandas as pd 

---
Dataset and Dataloader 

In [5]:
from torch.utils.data import Dataset, DataLoader

In [6]:
skeleton_data = np.load('./data/skeleton_movements.npz')
skeleton_classes, skeleton_mov = skeleton_data['arr_0'], skeleton_data['arr_1']

In [7]:

from collections import defaultdict
action_dict = defaultdict(list)
for i, a in enumerate(skeleton_classes):
    action_dict[a].append(i)

In [45]:
class SkeletonDataset(Dataset):
    def __init__(self, movements, actions, action_dict):
        super(SkeletonDataset, self).__init__()
        self.movements = movements 
        self.actions = actions
        self.action_dict = action_dict
        self.actionsIDs = list(self.action_dict.keys())

    def __getitem__(self, idx):
        x1 = self.movements[idx, ...]
        action = self.actions[idx]

        partial_idx = random.sample(self.action_dict[action], k=1)[0]
        x2 = self.movements[partial_idx, ...]

        label = self.actionsIDs.index(action)
        return x1, x2, label

    def __len__(self):
        return self.movements.shape[0]


In [46]:
BS = 32

In [47]:
sample_dt = SkeletonDataset(skeleton_mov, skeleton_classes, action_dict)
sample_dl = DataLoader(sample_dt, batch_size=BS, shuffle=True, pin_memory=True, drop_last=True)
for b in sample_dl:
    bx1, bx2, by = b 
    # bx = torch.transpose(bx1, 1, 0)
    bs, seq_len, ft_in = bx1.shape
    print(bx1.shape)
    break 

torch.Size([32, 60, 36])


In [48]:
bx12 = torch.vstack([bx1, bx2])
bx = torch.transpose(bx12, 1, 0)
bx12.shape

torch.Size([64, 60, 36])

In [61]:
bys = torch.cat([by, by], dim=0)
bys.shape

torch.Size([64])

---
Model Initialization 

In [11]:
import random 
from torch import nn 
import torch.nn.functional as F 

In [17]:
class LSTMEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers = 1):
        super(LSTMEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # define LSTM layer
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            num_layers = num_layers)

    def forward(self, x_input):
        '''
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
        '''
        lstm_out, self.hidden = self.lstm(x_input.view(x_input.shape[0], x_input.shape[1], self.input_size))
        return lstm_out, self.hidden     
    
    def init_hidden(self, batch_size):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))

In [18]:
class LSTMDecoder(nn.Module):    
    def __init__(self, input_size, hidden_size, num_layers = 1):
        super(LSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size,
                            num_layers = num_layers)
        self.linear = nn.Linear(hidden_size, input_size)           

    def forward(self, x_input, encoder_hidden_states):
        '''        
        : param x_input:                    should be 2D (batch_size, input_size)
        : param encoder_hidden_states:      hidden states
        : return output, hidden:            output gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence
        '''
        lstm_out, self.hidden = self.lstm(x_input.unsqueeze(0), encoder_hidden_states)
        output = self.linear(lstm_out.squeeze(0))     
        return output, self.hidden

In [67]:
class LSTMAE(nn.Module):

    def __init__(self, input_size, seq_len, hidden_size, batch_size, ae_type='recursive', teacher_forcing_ratio=0.5, device='cpu'):
        super(LSTMAE, self).__init__()
        self.input_size = input_size
        self.seq_len = seq_len
        self.hidden_size = hidden_size 
        self.bs = batch_size
        self.ae_type = ae_type # ['recursive', 'teacher_forcing', 'mixed_teacher_forcing']
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.device = device 

        self.encoder = LSTMEncoder(input_size = input_size, hidden_size = hidden_size)
        self.decoder = LSTMDecoder(input_size = input_size, hidden_size = hidden_size)

        self.encoder_hidden = self.encoder.init_hidden(self.bs)

    def forward(self, x):
        # encoding 
        encoder_output, self.encoder_hidden = self.encoder(x)
        # decoding 
        decoder_input = torch.rand((self.bs, self.input_size), requires_grad=True).to(self.device)#self.encoder_hidden[0].squeeze()
        print(decoder_input.shape)
        decoder_hidden = self.encoder_hidden
        # outputs tensor
        outputs = torch.zeros(self.seq_len, self.bs, self.input_size).to(self.device)

        if self.ae_type == 'recursive':
            for t in range(self.seq_len):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                outputs[t] = decoder_output
                decoder_input = decoder_output
        
        elif self.ae_type == 'teacher_forcing':
            if random.random() < self.teacher_forcing_ratio:
                for t in range(self.seq_len):
                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    outputs[t] = decoder_output
                    decoder_input = x[t, :, :]

            else:
                for t in range(self.seq_len):
                    decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                    outputs[t] = decoder_output
                    decoder_input = decoder_output

        elif self.ae_type == 'mixed_teacher_forcing':
            for t in range(self.seq_len):
                decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
                outputs[t] = decoder_output

                if random.random() < self.teacher_forcing_ratio:
                    decoder_input = x[t, :, :]
                else:
                    decoder_input = decoder_output

        return outputs, encoder_output[-1, ...]

In [36]:
INTER_Ft = 128
model = LSTMAE(input_size=ft_in, seq_len=seq_len, hidden_size=INTER_Ft, batch_size=2*bs, ae_type='recursive')
skeleton_output, vector_output = model(bx.float())
skeleton_output.shape, vector_output.shape 

torch.Size([64, 36])


(torch.Size([60, 64, 36]), torch.Size([64, 128]))

---

In [53]:
from src.utils.losses import SupConLoss

In [49]:
by.shape

torch.Size([32])

In [51]:
f1, f2 = torch.split(vector_output, [bs, bs], dim=0)
cons_output = torch.stack([f1.squeeze(1), f2.squeeze(1)], dim=1)
cons_output.shape

torch.Size([32, 2, 128])

In [70]:
# test on contrastive loss 
con_loss = SupConLoss()
sample_loss = con_loss(cons_output, by)

In [71]:
sample_loss

tensor(6.4546, device='cuda:0', grad_fn=<MeanBackward0>)

---

In [56]:
bx.shape

torch.Size([60, 64, 36])

In [57]:
skeleton_output.shape

torch.Size([60, 64, 36])

In [58]:
bx_tp = torch.transpose(bx, 1, 0)
skel_tp = torch.transpose(skeleton_output, 1, 0)
bx_tp.shape, skel_tp.shape

(torch.Size([64, 60, 36]), torch.Size([64, 60, 36]))

In [59]:
mseLoss = nn.MSELoss()
l2_loss = mseLoss(bx_tp, skel_tp)
l2_loss

tensor(0.2630, dtype=torch.float64, grad_fn=<MseLossBackward0>)

---

In [62]:
from tqdm import tqdm 
from torch.optim import Adam

In [73]:
# build AE-training step 
# ----------------------------parameter mapping -------------------------

dataloader = sample_dl 
phase = 'train'
batch_size = 32
alpha = 0.5

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model = LSTMAE(input_size=ft_in, seq_len=seq_len, hidden_size=INTER_Ft, batch_size=2*bs, ae_type='recursive', device=device)
optimizer = Adam(model.parameters(), lr=0.0001)
conLoss = SupConLoss()
l2Loss = nn.MSELoss()

# -----------------------------------------------------------------------
model = model.train()
model.to(device)

epoch_loss = 0 
total_samples = 0 

with tqdm(dataloader, unit='batch', desc=phase) as tepoch:
    for batch in tepoch:
        x1, x2, labels = batch 
        # post-process 
        xa = torch.transpose(torch.vstack([x1, x2]), 1, 0)
        # device offload 
        xa = xa.float().to(device)
        labels = labels.float()

        # set optimizer grad to zero 
        optimizer.zero_grad()
        # get model prediction 
        with torch.set_grad_enabled(phase=='train'):
            skel_output, ft_output = model(xa)

        # reconstruct the output 
        f1, f2 = torch.split(ft_output, [batch_size, batch_size], dim=0)
        cons_output = torch.stack([f1.squeeze(1), f2.squeeze(1)], dim=1)
        # calc. contrastive loss 
        con_loss = conLoss(cons_output, labels)
        # calc. reconstruction loss 
        l2_loss = l2Loss(xa, skel_output)
        # calc. total loss
        total_loss = alpha*con_loss + (1-alpha)*l2_loss

        if phase == 'train':
            total_loss.backward()
            optimizer.step()

        metrics = {'contrastive loss ': con_loss, 'reconstruction loss': l2_loss}
        with torch.no_grad():
                total_samples += len(labels)
                epoch_loss += total_loss.item()
        
        tepoch.set_postfix(metrics)

epoch_loss = epoch_loss/total_samples
# return epoch_loss 


train:   0%|          | 0/3 [00:00<?, ?batch/s]

torch.Size([64, 36])


train:  67%|██████▋   | 2/3 [00:00<00:00,  4.93batch/s, contrastive loss =tensor(8.5735, device='cuda:0', grad_fn=<MeanBackward0>), reconstruction loss=tensor(0.2647, device='cuda:0', grad_fn=<MseLossBackward0>)]

torch.Size([64, 36])
torch.Size([64, 36])


train: 100%|██████████| 3/3 [00:00<00:00,  5.22batch/s, contrastive loss =tensor(5.2948, device='cuda:0', grad_fn=<MeanBackward0>), reconstruction loss=tensor(0.2636, device='cuda:0', grad_fn=<MseLossBackward0>)]
