In [168]:
import os 
import random
import numpy as np

import torch 
from torch import nn 
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import random
import copy
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format='retina'

sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

rcParams['figure.figsize'] = 12, 8

In [169]:
main_dir = "/Users/devin/Documents/FYP/HAR-ZSL-XAI"
data_dir = os.path.join(main_dir,"sequence_data")
class_names = os.listdir(data_dir)
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 1-train_ratio - val_ratio
batch_size = 32

In [170]:
def classname_id(class_name_list):
    id2classname = {k:v for k, v in zip(list(range(len(class_name_list))),class_name_list)}
    classname2id = {v:k for k, v in id2classname.items()}
    return id2classname, classname2id

In [171]:
class_names = os.listdir(os.path.join(main_dir,"data","pose_data"))

In [172]:
id2clsname, clsname2id = classname_id(class_names)

In [173]:
train_file_list = []
val_file_list = []
test_file_list = []

file_list = [os.path.join(data_dir,x) for x in os.listdir(data_dir)]

random.shuffle(file_list)
num_list = len(file_list)

train_range = [0,int(num_list*train_ratio)]
val_range = [int(num_list*train_ratio),int(num_list*(train_ratio+val_ratio))]
test_range = [int(num_list*(train_ratio+val_ratio)),num_list-1]

train_file_list += file_list[train_range[0]:train_range[1]]
val_file_list += file_list[val_range[0]:val_range[1]]
test_file_list += file_list[test_range[0]:test_range[1]]

In [174]:
len(train_file_list),len(val_file_list),len(test_file_list)

(12357, 2648, 2647)

In [175]:
train_file_list = train_file_list[:(len(train_file_list)//batch_size)*batch_size]
val_file_list = val_file_list[:(len(val_file_list)//batch_size)*batch_size]
test_file_list = test_file_list[:(len(test_file_list)//batch_size)*batch_size]

In [176]:
len(train_file_list),len(val_file_list),len(test_file_list)

(12352, 2624, 2624)

In [177]:
class SkeletonDataset(Dataset):
    def __init__(self, file_list,class2id, transform=None, target_transform=None):
        self.file_list = file_list
        self.transform = transform
        self.class2id = class2id
        self.target_transform = target_transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        a_file = np.load(self.file_list[idx])
        action_type = self.file_list[idx].strip().split("/")[-1].split("_cls_")[0]
        coords, vid_size = a_file["coords"],a_file["video_size"]
        
        shape = coords.shape
        
        coords = torch.from_numpy(coords).float()
        
        coords = torch.reshape(coords, (shape[0], shape[1]*shape[2]))
        label = torch.clone(coords)
        
        if self.transform:
            coords = self.transform(coords)
        if self.target_transform:
            label = self.target_transform(coords)
        return coords, label, self.class2id[action_type]

In [178]:
train_data = SkeletonDataset(train_file_list,clsname2id)
val_data = SkeletonDataset(val_file_list,clsname2id)
test_data = SkeletonDataset(test_file_list,clsname2id)

In [179]:
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=False)
val_dl = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [180]:
class BiLSTMEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers = 1):
        super(BiLSTMEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # define LSTM layer
        self.linear1 = nn.Linear(self.input_size, 128)
        self.linear2 = nn.Linear(128, 256)
        self.linear3 = nn.Linear(256, 512)
        self.lstm = nn.LSTM(input_size = 512, hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=True,
                            batch_first=True)
        
        

    def forward(self, x_input):
        '''
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
        '''
        x = self.linear1(x_input)
        x = self.linear2(x)
        x = self.linear3(x)
        
        lstm_out, self.hidden = self.lstm(x)
        hidden_transformed = torch.concat(self.hidden,0)
        hidden_transformed = torch.transpose(hidden_transformed,0,1)
        hidden_transformed = torch.flatten(hidden_transformed,start_dim=1)
        
        return lstm_out, hidden_transformed
    
class BiLSTMDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers = 1):
        super(BiLSTMDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # define LSTM layer
        self.linear1 = nn.Linear(128, self.input_size)
        self.linear2 = nn.Linear(256, 128)
        self.linear3 = nn.Linear(512, 256)
        self.linear4 = nn.Linear(1024, 512)
        
        self.lstm = nn.LSTM(input_size = 512, hidden_size = self.hidden_size,
                            num_layers = self.num_layers, bidirectional=True,
                            batch_first=True)
        
        

    def forward(self,encoder_hidden):
        '''
        : param x_input:               input of shape (seq_len, # in batch, input_size)
        : return lstm_out, hidden:     lstm_out gives all the hidden states in the sequence; hidden gives the hidden state and cell state for the last element in the sequence                         
        '''
        
        hidden_shape = encoder_hidden.shape
        
        hidden = encoder_hidden.view((hidden_shape[0],4,512))
        hidden = torch.transpose(hidden,1,0)
        h1,h2,c1,c2 = torch.unbind(hidden,0)
        h,c = torch.stack((h1,h2)),torch.stack((c1,c2))
        
        dummy_input = torch.rand((32,50,512), requires_grad=True)
        
        lstm_out, self.hidden = self.lstm(dummy_input,(h,c))
        x = self.linear4(lstm_out)
        x = self.linear3(x)
        x = self.linear2(x)
        x = self.linear1(x)
        
        return x

class BiLSTMEncDecModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers = 1):
        super(BiLSTMEncDecModel, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.encoder = BiLSTMEncoder(input_size,hidden_size)
        self.decoder = BiLSTMDecoder(input_size,hidden_size)
        
    def forward(self,x):
        lstm_out,embedding = self.encoder(x)
        decoder_out = self.decoder(embedding)
        
        return decoder_out
        

In [181]:
encoder = BiLSTMEncoder(99,512)
decoder = BiLSTMDecoder(99,512)

bilstm_model = BiLSTMEncDecModel(99,512)

In [182]:
lstm_out, embedding = encoder(torch.randn((32,50,99)))

In [183]:
decoder_out = decoder(embedding)

In [184]:
model_out = model(torch.randn((32,50,99)))

In [185]:
model_out.shape

torch.Size([32, 50, 99])

In [186]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [189]:
def train_model(model, train_dataset, val_dataset, n_epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    std_loss = nn.L1Loss(reduction='sum').to(device)
    #contrastive_loss = SupConLoss(contrast_mode="one").to(device)
    history = dict(train=[], val=[])
    
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 10000.0
  
    for epoch in range(1, n_epochs + 1):
        model = model.train()

        train_losses = []
        for in_seq,tar_seq,action in tqdm(train_dataset):
            optimizer.zero_grad()
            
            in_seq = in_seq.to(device)
            tar_seq = tar_seq.to(device)
            seq_pred = model(in_seq)
            
            loss = std_loss(seq_pred, tar_seq)
            #loss += 0.5*contrastive_loss(embed,labels=sample_label.view(-1))
            #print(contrastive_loss(embed,labels=sample_label.view(-1)))

            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        val_losses = []
        model = model.eval()
        with torch.no_grad():
            for in_seq,tar_seq,action in val_dataset:

                in_seq = in_seq.to(device)
                tar_seq = tar_seq.to(device)
                seq_pred = model(in_seq)

                loss = std_loss(seq_pred, tar_seq)
                #loss += 0.5*contrastive_loss(embed,labels=sample_label.view(-1))
                val_losses.append(loss.item())

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)

        history['train'].append(train_loss)
        history['val'].append(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())

        print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss}')

    model.load_state_dict(best_model_wts)
    return model.eval(), history

In [None]:
model, history = train_model(
  bilstm_model, 
  train_dl, 
  val_dl, 
  n_epochs=500
)

 74%|███████▍  | 286/386 [01:39<00:42,  2.36it/s]